Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-15 09:12:02

0001 #ifndef TMVA_EXPERIMENTAL_SOFIE_ROPERATOR_BASIC_UNARY
0002 #define TMVA_EXPERIMENTAL_SOFIE_ROPERATOR_BASIC_UNARY
0003 
0004 #include <TMVA/ROperator.hxx>
0005 #include <TMVA/RModel.hxx>
0006 #include <TMVA/SOFIE_common.hxx>
0007 
0008 namespace TMVA {
0009 namespace Experimental {
0010 namespace SOFIE {
0011 
/// Element-wise unary operators supported by ROperator_BasicUnary.
enum class EBasicUnaryOperator { kReciprocal, kSqrt, kNeg, kExp, kLog, kSin, kCos, kAbs };

/// Code-generation traits for a basic unary operator.
///
/// Each specialization provides:
///  - Name(): the operator name written into the generated-code comment;
///  - Op(X):  a C++ expression string that applies the operator to the
///            expression string X.
/// The primary template has no members, so requesting Name()/Op() for an
/// operator without a specialization fails at compile time.
template <typename T, EBasicUnaryOperator Op>
struct UnaryOpTraits {
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kReciprocal> {
   static std::string Name() { return "Reciprocal"; }
   // X is parenthesized so compound operands keep their meaning:
   // "1/(a*b)" rather than "1/a*b".
   static std::string Op(const std::string &X) { return "1/(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kSqrt> {
   static std::string Name() { return "Sqrt"; }
   static std::string Op(const std::string &X) { return "std::sqrt(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kNeg> {
   static std::string Name() { return "Neg"; }
   // X is parenthesized: plain concatenation would turn an operand such as "-y"
   // into the decrement token "--y", and "a+b" into "-a+b".
   static std::string Op(const std::string &X) { return "-(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kExp> {
   static std::string Name() { return "Exp"; }
   static std::string Op(const std::string &X) { return "std::exp(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kLog> {
   static std::string Name() { return "Log"; }
   static std::string Op(const std::string &X) { return "std::log(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kSin> {
   static std::string Name() { return "Sin"; }
   static std::string Op(const std::string &X) { return "std::sin(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kCos> {
   static std::string Name() { return "Cos"; }
   static std::string Op(const std::string &X) { return "std::cos(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kAbs> {
   static std::string Name() { return "Abs"; }
   static std::string Op(const std::string &X) { return "std::abs(" + X + ")"; }
};
0065 
0066 template <typename T, EBasicUnaryOperator Op>
0067 class ROperator_BasicUnary final : public ROperator {
0068 private:
0069    std::string fNX;
0070    std::string fNY;
0071 
0072    std::vector<size_t> fShapeX;
0073    std::vector<size_t> fShapeY;
0074 
0075 public:
0076    ROperator_BasicUnary() {}
0077 
0078    ROperator_BasicUnary(std::string nameX, std::string nameY)
0079       : fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
0080    {
0081          fInputTensorNames =  { fNX };
0082          fOutputTensorNames = { fNY };
0083    }
0084 
0085    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override { return input; }
0086 
0087    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
0088 
0089    void Initialize(RModel& model) override {
0090       if (!model.CheckIfTensorAlreadyExist(fNX)) {
0091          throw std::runtime_error("TMVA::SOFIE - Tensor " + fNX + " not found.");
0092       }
0093       fShapeX = model.GetTensorShape(fNX);
0094       fShapeY = ShapeInference({fShapeX})[0];
0095       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0096    }
0097 
0098    std::string Generate(std::string OpName) override
0099    {
0100       OpName = "op_" + OpName;
0101       std::stringstream out;
0102 
0103       out << SP << "\n//---- Operator" << UnaryOpTraits<T, Op>::Name() << " " << OpName << "\n";
0104       size_t length = ConvertShapeToLength(fShapeX);
0105       out << SP << "for (size_t i = 0; i < " << length << "; i++) {\n";
0106       out << SP << SP << "tensor_" << fNY << "[i] = " << UnaryOpTraits<T, Op>::Op("tensor_" + fNX + "[i]") << ";\n";
0107       out << SP << "}\n";
0108       return out.str();
0109    }
0110 
0111    std::vector<std::string> GetStdLibs() override {
0112       if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog) {
0113          return { std::string("cmath") };
0114       } else {
0115          return {};
0116       }
0117    }
0118 };
0119 
0120 } // namespace SOFIE
0121 } // namespace Experimental
0122 } // namespace TMVA
0123 
0124 #endif