File indexing completed on 2025-09-15 09:12:02
0001 #ifndef TMVA_EXPERIMENTAL_SOFIE_ROPERATOR_BASIC_UNARY
0002 #define TMVA_EXPERIMENTAL_SOFIE_ROPERATOR_BASIC_UNARY
0003
0004 #include <TMVA/ROperator.hxx>
0005 #include <TMVA/RModel.hxx>
0006 #include <TMVA/SOFIE_common.hxx>
0007
0008 namespace TMVA {
0009 namespace Experimental {
0010 namespace SOFIE {
0011
/// Kinds of element-wise unary operations that ROperator_BasicUnary can generate
/// code for. Enumerators keep their original order, so their implicit underlying
/// values are 0 (kReciprocal) through 7 (kAbs).
enum class EBasicUnaryOperator {
   kReciprocal, ///< 1/x
   kSqrt,       ///< std::sqrt(x)
   kNeg,        ///< -x
   kExp,        ///< std::exp(x)
   kLog,        ///< std::log(x)
   kSin,        ///< std::sin(x)
   kCos,        ///< std::cos(x)
   kAbs         ///< std::abs(x)
};
0013
0014 template <typename T, EBasicUnaryOperator Op>
0015 struct UnaryOpTraits {
0016 };
0017
0018 template <typename T>
0019 struct UnaryOpTraits<T, EBasicUnaryOperator::kReciprocal> {
0020 static std::string Name() { return "Reciprocal"; }
0021 static std::string Op(const std::string &X) { return "1/" + X; }
0022 };
0023
0024 template <typename T>
0025 struct UnaryOpTraits<T, EBasicUnaryOperator::kSqrt> {
0026 static std::string Name() { return "Sqrt"; }
0027 static std::string Op(const std::string &X) { return "std::sqrt(" + X + ")"; }
0028 };
0029
0030 template <typename T>
0031 struct UnaryOpTraits<T, EBasicUnaryOperator::kNeg> {
0032 static std::string Name() { return "Neg"; }
0033 static std::string Op(const std::string &X) { return "-" + X; }
0034 };
0035
0036 template <typename T>
0037 struct UnaryOpTraits<T, EBasicUnaryOperator::kExp> {
0038 static std::string Name() { return "Exp"; }
0039 static std::string Op(const std::string &X) { return "std::exp(" + X + ")"; }
0040 };
0041
0042 template <typename T>
0043 struct UnaryOpTraits<T, EBasicUnaryOperator::kLog> {
0044 static std::string Name() { return "Log"; }
0045 static std::string Op(const std::string &X) { return "std::log(" + X + ")"; }
0046 };
0047
0048 template <typename T>
0049 struct UnaryOpTraits<T, EBasicUnaryOperator::kSin> {
0050 static std::string Name() { return "Sin"; }
0051 static std::string Op(const std::string &X) { return "std::sin(" + X + ")"; }
0052 };
0053
0054 template <typename T>
0055 struct UnaryOpTraits<T, EBasicUnaryOperator::kCos> {
0056 static std::string Name() { return "Cos"; }
0057 static std::string Op(const std::string &X) { return "std::cos(" + X + ")"; }
0058 };
0059
0060 template <typename T>
0061 struct UnaryOpTraits<T, EBasicUnaryOperator::kAbs> {
0062 static std::string Name() { return "Abs"; }
0063 static std::string Op(const std::string &X) { return "std::abs(" + X + ")"; }
0064 };
0065
0066 template <typename T, EBasicUnaryOperator Op>
0067 class ROperator_BasicUnary final : public ROperator {
0068 private:
0069 std::string fNX;
0070 std::string fNY;
0071
0072 std::vector<size_t> fShapeX;
0073 std::vector<size_t> fShapeY;
0074
0075 public:
0076 ROperator_BasicUnary() {}
0077
0078 ROperator_BasicUnary(std::string nameX, std::string nameY)
0079 : fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
0080 {
0081 fInputTensorNames = { fNX };
0082 fOutputTensorNames = { fNY };
0083 }
0084
0085 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override { return input; }
0086
0087 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
0088
0089 void Initialize(RModel& model) override {
0090 if (!model.CheckIfTensorAlreadyExist(fNX)) {
0091 throw std::runtime_error("TMVA::SOFIE - Tensor " + fNX + " not found.");
0092 }
0093 fShapeX = model.GetTensorShape(fNX);
0094 fShapeY = ShapeInference({fShapeX})[0];
0095 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0096 }
0097
0098 std::string Generate(std::string OpName) override
0099 {
0100 OpName = "op_" + OpName;
0101 std::stringstream out;
0102
0103 out << SP << "\n//---- Operator" << UnaryOpTraits<T, Op>::Name() << " " << OpName << "\n";
0104 size_t length = ConvertShapeToLength(fShapeX);
0105 out << SP << "for (size_t i = 0; i < " << length << "; i++) {\n";
0106 out << SP << SP << "tensor_" << fNY << "[i] = " << UnaryOpTraits<T, Op>::Op("tensor_" + fNX + "[i]") << ";\n";
0107 out << SP << "}\n";
0108 return out.str();
0109 }
0110
0111 std::vector<std::string> GetStdLibs() override {
0112 if (Op == EBasicUnaryOperator::kSqrt || Op == EBasicUnaryOperator::kExp || Op == EBasicUnaryOperator::kLog) {
0113 return { std::string("cmath") };
0114 } else {
0115 return {};
0116 }
0117 }
0118 };
0119
0120 }
0121 }
0122 }
0123
0124 #endif