Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:08

0001 #ifndef TMVA_SOFIE_ROPERATOR_Shape
0002 #define TMVA_SOFIE_ROPERATOR_Shape
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 #include<sstream>
0010 #include<vector>
0011 #include <iterator>
0012 #include<string>
0013 namespace TMVA{
0014 namespace Experimental{
0015 namespace SOFIE{
0016 
0017 class ROperator_Shape final : public ROperator
0018 {
0019 
0020 private:
0021 
0022    /* Attributes*/
0023    int fStart = 0;
0024    int fEnd = -1;
0025    std::string fNX;
0026    std::string fNY;
0027    std::vector<size_t> fShape;
0028    std::vector<size_t> fOutput_shape;
0029 
0030 public:
0031    ROperator_Shape(){}
0032    ROperator_Shape(int start, int end, std::string nameX, std::string nameY):
0033    fStart(start) ,fEnd(end), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){}
0034 
0035    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0036       return input;
0037    }
0038 
0039    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0040       std::vector<std::vector<size_t>>  ret;
0041       ret[0].push_back(input[0].size());
0042       return ret;
0043    }
0044 
0045    void Initialize(RModel& model){
0046       if (model.CheckIfTensorAlreadyExist(fNX) == false){   //input must be a graph input, or already initialized intermediate tensor
0047          throw std::runtime_error("TMVA SOFIE Shape Op Input Tensor is not found in model");
0048       }
0049       fShape = model.GetTensorShape(fNX);
0050       size_t length = ConvertShapeToLength(fShape);
0051       if (fStart < 0) fStart += length;
0052       if (fEnd < 0) fEnd += length;
0053       fOutput_shape = { size_t(fEnd - fStart) + 1};
0054       model.AddIntermediateTensor(fNY, ETensorType::INT64, fOutput_shape);
0055    }
0056 
0057    std::string Generate(std::string OpName){
0058       OpName = "op_" + OpName;
0059       if (fShape.empty()) {
0060          throw std::runtime_error("TMVA SOFIE Shape op called to Generate without being initialized first");
0061       }
0062       std::stringstream out;
0063 
0064       out << "\n//------ Shape\n";
0065       // add a dummy statement to avoid warning for unused input
0066       out << SP << "(void) tensor_" << fNX << ";\n";
0067       size_t length = ConvertShapeToLength(fOutput_shape);
0068       for (size_t id = 0; id < length; id++) {
0069          out << SP << "tensor_" << fNY << "["<< id << "] = " << fShape[fStart+id] << ";\n";
0070       }
0071       return out.str();
0072    }
0073 
0074 };
0075 
0076 }//SOFIE
0077 }//Experimental
0078 }//TMVA
0079 
0080 
0081 #endif //TMVA_SOFIE_ROPERATOR_Shape