File indexing completed on 2025-01-18 10:11:05
0001 #ifndef TMVA_EXPERIMENTAL_SOFIE_ROPERATOR_BASIC_UNARY
0002 #define TMVA_EXPERIMENTAL_SOFIE_ROPERATOR_BASIC_UNARY
0003
0004 #include <TMVA/ROperator.hxx>
0005 #include <TMVA/RModel.hxx>
0006 #include <TMVA/SOFIE_common.hxx>
0007
0008 namespace TMVA {
0009 namespace Experimental {
0010 namespace SOFIE {
0011
0012 enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog };
0013
0014 template <typename T, EBasicUnaryOperator Op>
0015 struct UnaryOpTraits {
0016 };
0017
0018 template <typename T>
0019 struct UnaryOpTraits<T, EBasicUnaryOperator::kReciprocal> {
0020 static std::string Name() { return "Reciprocal"; }
0021 static std::string Op(const std::string &X) { return "1/" + X; }
0022 };
0023
0024 template <typename T>
0025 struct UnaryOpTraits<T, EBasicUnaryOperator::kSqrt> {
0026 static std::string Name() { return "Sqrt"; }
0027 static std::string Op(const std::string &X) { return "std::sqrt(" + X + ")"; }
0028 };
0029
0030 template <typename T>
0031 struct UnaryOpTraits<T, EBasicUnaryOperator::kNeg> {
0032 static std::string Name() { return "Neg"; }
0033 static std::string Op(const std::string &X) { return "-" + X; }
0034 };
0035
0036 template <typename T>
0037 struct UnaryOpTraits<T, EBasicUnaryOperator::kExp> {
0038 static std::string Name() { return "Exp"; }
0039 static std::string Op(const std::string &X) { return "std::exp(" + X + ")"; }
0040 };
0041
0042 template <typename T>
0043 struct UnaryOpTraits<T, EBasicUnaryOperator::kLog> {
0044 static std::string Name() { return "Log"; }
0045 static std::string Op(const std::string &X) { return "std::log(" + X + ")"; }
0046 };
0047
0048 template <typename T, EBasicUnaryOperator Op>
0049 class ROperator_BasicUnary final : public ROperator {
0050 private:
0051 std::string fNX;
0052 std::string fNY;
0053
0054 std::vector<size_t> fShapeX;
0055 std::vector<size_t> fShapeY;
0056
0057 public:
0058 ROperator_BasicUnary() {}
0059
0060 ROperator_BasicUnary(std::string nameX, std::string nameY)
0061 : fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
0062 {}
0063
0064 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override { return input; }
0065
0066 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
0067
0068 void Initialize(RModel &model) override
0069 {
0070 if (!model.CheckIfTensorAlreadyExist(fNX)) {
0071 throw std::runtime_error("TMVA::SOFIE - Tensor " + fNX + " not found.");
0072 }
0073 fShapeX = model.GetTensorShape(fNX);
0074 fShapeY = ShapeInference({fShapeX})[0];
0075 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0076 }
0077
0078 std::string Generate(std::string OpName) override
0079 {
0080 OpName = "op_" + OpName;
0081 std::stringstream out;
0082
0083 out << SP << "\n//---- Operator" << UnaryOpTraits<T, Op>::Name() << " " << OpName << "\n";
0084 size_t length = ConvertShapeToLength(fShapeX);
0085 out << SP << "for (size_t i = 0; i < " << length << "; i++) {\n";
0086 out << SP << SP << "tensor_" << fNY << "[i] = " << UnaryOpTraits<T, Op>::Op("tensor_" + fNX + "[i]") << ";\n";
0087 out << SP << "}\n";
0088 return out.str();
0089 }
0090 };
0091
0092 }
0093 }
0094 }
0095
0096 #endif