File indexing completed on 2025-01-18 10:11:06
0001 #ifndef TMVA_SOFIE_ROPERATOR_IDENTITY
0002 #define TMVA_SOFIE_ROPERATOR_IDENTITY
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014 template <typename T>
0015 class ROperator_Identity final : public ROperator
0016 {
0017
0018 private:
0019
0020 std::string fNX;
0021 std::string fNY;
0022 std::vector<size_t> fShape;
0023
0024 public:
0025 ROperator_Identity(){}
0026 ROperator_Identity(std::string nameX, std::string nameY):
0027 fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){}
0028
0029 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0030 return input;
0031 }
0032
0033 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0034 auto ret = input;
0035 return ret;
0036 }
0037
0038 void Initialize(RModel& model){
0039
0040 if (model.CheckIfTensorAlreadyExist(fNX) == false){
0041 throw std::runtime_error("TMVA SOFIE Identity Op Input Tensor is not found in model");
0042 }
0043 fShape = model.GetTensorShape(fNX);
0044 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
0045 }
0046
0047
0048 std::string Generate(std::string OpName){
0049 OpName = "op_" + OpName;
0050 if (fShape.empty()) {
0051 throw std::runtime_error("TMVA SOFIE Operator Identity called to Generate without being initialized first");
0052 }
0053 std::stringstream out;
0054 out << "\n//------ IDENTITY\n";
0055
0056 out << SP << SP << "tensor_" << fNY << " = tensor_" << fNX << ";\n";
0057 return out.str();
0058 }
0059
0060 };
0061
0062 }
0063 }
0064 }
0065
0066
0067 #endif