File indexing completed on 2025-09-18 09:32:36
0001 #ifndef TMVA_SOFIE_ROPERATOR_IDENTITY
0002 #define TMVA_SOFIE_ROPERATOR_IDENTITY
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014 template <typename T>
0015 class ROperator_Identity final : public ROperator
0016 {
0017
0018 private:
0019
0020 bool fIsInputInitialized = false;
0021 std::string fNX;
0022 std::string fNY;
0023 std::vector<size_t> fShape;
0024
0025 public:
0026 ROperator_Identity(){}
0027 ROperator_Identity(std::string nameX, std::string nameY):
0028 fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){
0029 fInputTensorNames = { fNX };
0030 fOutputTensorNames = { fNY };
0031 }
0032
0033 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0034 return input;
0035 }
0036
0037 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0038 auto ret = input;
0039 return ret;
0040 }
0041
0042 void Initialize(RModel& model) override {
0043
0044 if (model.CheckIfTensorAlreadyExist(fNX) == false){
0045 throw std::runtime_error("TMVA SOFIE Identity Op Input Tensor is not found in model");
0046 }
0047 fShape = model.GetTensorShape(fNX);
0048 if (model.IsInitializedTensor(fNX)) {
0049
0050
0051
0052 if (model.IsConstantTensor(fNX)) {
0053 auto inputData = static_cast<T*>(model.GetInitializedTensorData(fNX).get());
0054 model.AddConstantTensor<T>(fNY, fShape, inputData);
0055 fIsOutputConstant = true;
0056 } else {
0057 fIsInputInitialized = true;
0058
0059
0060 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
0061 }
0062 } else
0063 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
0064 }
0065
0066 std::string GenerateInitCode() override {
0067
0068 if (!fIsInputInitialized) return "";
0069 std::stringstream out;
0070 out << "\n//------ IDENTITY\n";
0071
0072 out << SP << SP << "tensor_" << fNY << " = tensor_" << fNX << ";\n";
0073 return out.str();
0074 }
0075
0076
0077 std::string Generate(std::string OpName) override {
0078 if (fIsOutputConstant || fIsInputInitialized) return "";
0079 OpName = "op_" + OpName;
0080 if (fShape.empty()) {
0081 throw std::runtime_error("TMVA SOFIE Operator Identity called to Generate without being initialized first");
0082 }
0083 std::stringstream out;
0084 out << "\n//------ IDENTITY\n";
0085
0086 out << SP << SP << "tensor_" << fNY << " = tensor_" << fNX << ";\n";
0087 return out.str();
0088 }
0089
0090 };
0091
0092 }
0093 }
0094 }
0095
0096
0097 #endif