Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:05

0001 #ifndef TMVA_EXPERIMENTAL_SOFIE_ROPERATOR_BASIC_UNARY
0002 #define TMVA_EXPERIMENTAL_SOFIE_ROPERATOR_BASIC_UNARY
0003 
0004 #include <TMVA/ROperator.hxx>
0005 #include <TMVA/RModel.hxx>
0006 #include <TMVA/SOFIE_common.hxx>
0007 
0008 namespace TMVA {
0009 namespace Experimental {
0010 namespace SOFIE {
0011 
/// Elementwise unary operations that ROperator_BasicUnary can generate code for.
enum class EBasicUnaryOperator { kReciprocal, kSqrt, kNeg, kExp, kLog };

/// Compile-time description of how an EBasicUnaryOperator is rendered into
/// generated C++ source.
///
/// The primary template is intentionally empty: only the specializations below
/// provide members, so naming an operator without a specialization fails at
/// compile time.
///
/// Every specialization exposes:
///  - Name(): human-readable operator name (used in generated comments).
///  - Op(X):  a C++ expression string that applies the operator to the C++
///            expression string X. Prefix operators wrap X in parentheses so a
///            compound X keeps its meaning: e.g. Neg::Op("-a") must yield
///            "-(-a)", not "--a" (which C++ parses as a pre-decrement).
template <typename T, EBasicUnaryOperator Op>
struct UnaryOpTraits {
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kReciprocal> {
   static std::string Name() { return "Reciprocal"; }
   // Parenthesized so that the whole of X is inverted ("1/(a+b)", not "1/a+b").
   // Note: the literal 1 is an int, so if X evaluates to an integer type this
   // is integer division.
   static std::string Op(const std::string &X) { return "1/(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kSqrt> {
   static std::string Name() { return "Sqrt"; }
   static std::string Op(const std::string &X) { return "std::sqrt(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kNeg> {
   static std::string Name() { return "Neg"; }
   // Parenthesized so that the whole of X is negated and a leading '-' in X
   // cannot merge with ours into the decrement operator "--".
   static std::string Op(const std::string &X) { return "-(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kExp> {
   static std::string Name() { return "Exp"; }
   static std::string Op(const std::string &X) { return "std::exp(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kLog> {
   static std::string Name() { return "Log"; }
   static std::string Op(const std::string &X) { return "std::log(" + X + ")"; }
};
0047 
0048 template <typename T, EBasicUnaryOperator Op>
0049 class ROperator_BasicUnary final : public ROperator {
0050 private:
0051    std::string fNX;
0052    std::string fNY;
0053 
0054    std::vector<size_t> fShapeX;
0055    std::vector<size_t> fShapeY;
0056 
0057 public:
0058    ROperator_BasicUnary() {}
0059 
0060    ROperator_BasicUnary(std::string nameX, std::string nameY)
0061       : fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
0062    {}
0063 
0064    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override { return input; }
0065 
0066    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
0067 
0068    void Initialize(RModel &model) override
0069    {
0070       if (!model.CheckIfTensorAlreadyExist(fNX)) {
0071          throw std::runtime_error("TMVA::SOFIE - Tensor " + fNX + " not found.");
0072       }
0073       fShapeX = model.GetTensorShape(fNX);
0074       fShapeY = ShapeInference({fShapeX})[0];
0075       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0076    }
0077 
0078    std::string Generate(std::string OpName) override
0079    {
0080       OpName = "op_" + OpName;
0081       std::stringstream out;
0082 
0083       out << SP << "\n//---- Operator" << UnaryOpTraits<T, Op>::Name() << " " << OpName << "\n";
0084       size_t length = ConvertShapeToLength(fShapeX);
0085       out << SP << "for (size_t i = 0; i < " << length << "; i++) {\n";
0086       out << SP << SP << "tensor_" << fNY << "[i] = " << UnaryOpTraits<T, Op>::Op("tensor_" + fNX + "[i]") << ";\n";
0087       out << SP << "}\n";
0088       return out.str();
0089    }
0090 };
0091 
0092 } // namespace SOFIE
0093 } // namespace Experimental
0094 } // namespace TMVA
0095 
0096 #endif