Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-06 10:01:45

0001 #ifndef TMVA_SOFIE_ROPERATOR_Custom
0002 #define TMVA_SOFIE_ROPERATOR_Custom
0003 
0004 
#include <cstddef>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"
0008 
0009 namespace TMVA{
0010 namespace Experimental{
0011 namespace SOFIE{
0012 
0013 
0014 template<typename T>
0015 class ROperator_Custom final : public ROperator
0016 {
0017 
0018 private:
0019     std::string fOpName;
0020     std::vector<std::string> fInputNames;
0021     std::vector<std::string> fOutputNames;
0022     std::vector<std::vector<std::size_t>> fOutputShapes;
0023     std::vector<std::size_t> fInputSizes;
0024     std::string fHeaderName;
0025     ETensorType fInputType;
0026 
0027 public:
0028     ROperator_Custom(){}
0029     ROperator_Custom(std::string OpName, std::vector<std::string>Inputs, std::vector<std::string>Outputs, std::vector<std::vector<std::size_t>> OutputShapes, std::string HeaderName){
0030         fOpName = OpName;
0031         fOutputShapes = OutputShapes;
0032         fHeaderName = HeaderName;
0033         for(auto& it:Inputs){
0034             fInputNames.emplace_back(UTILITY::Clean_name(it));
0035             fInputTensorNames.emplace_back(fInputNames.back());
0036         }
0037         for(auto& it:Outputs){
0038             fOutputNames.emplace_back(UTILITY::Clean_name(it));
0039             fOutputTensorNames.emplace_back(fOutputNames.back());
0040         }
0041     }
0042 
0043     std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) override {return {{}};};
0044     std::vector<ETensorType> TypeInference(std::vector<ETensorType>) override { return {};};
0045 
0046    void Initialize(RModel& model) override {
0047       model.AddNeededCustomHeader(fHeaderName);
0048       fInputType = model.GetTensorType(fInputNames[0]);
0049 
0050       for(auto& it:fInputNames){
0051         if (model.CheckIfTensorAlreadyExist(it) == false){
0052          throw std::runtime_error("TMVA SOFIE Custom " + fOpName + " Op Input Tensor " + it + " is not found in model");
0053         }
0054         fInputSizes.push_back(ConvertShapeToLength(model.GetTensorShape(it)));
0055       }
0056 
0057       if(fOutputNames.size() != fOutputShapes.size()){
0058         throw std::runtime_error("TMVA SOFIE Custom "+ fOpName + " Op was not intialized with the names/shapes of all the output tensors");
0059       }
0060 
0061       for(long unsigned int i=0; i<fOutputNames.size(); ++i){
0062         model.AddIntermediateTensor(std::string(fOutputNames[i]), ETensorType::FLOAT, fOutputShapes[i]);
0063       }
0064 
0065 
0066       model.UpdateOutputTensorList(fInputNames, fOutputNames);
0067 
0068       if (model.Verbose()) {
0069          std::cout << "Custom operator using " << fHeaderName;
0070          for (auto & i : fInputNames) std::cout << " " << i;
0071          std::cout << " ---> ";
0072          for (auto & i : fOutputNames) std::cout << " " << i;
0073          std::cout << "\n";
0074       }
0075       model.AddNeededCustomHeader("ROOT/RSpan.hxx");
0076    }
0077 
0078     std::string Generate(std::string OpName) override {
0079       OpName = "op_" + OpName;
0080       std::stringstream out;
0081       out << "\n//------ "<<fOpName<<" \n";
0082       std::string args;
0083       for(long unsigned int i = 0; i<fInputNames.size(); ++i){
0084         args+="std::span<const "+ConvertTypeToString(fInputType)+">(tensor_"+std::string(fInputNames[i])+", "+fInputSizes[i]+"),";
0085       }
0086 
0087       for(long unsigned int i = 0; i<fOutputNames.size(); ++i){
0088         args+="std::span<"+TensorType<T>::Name()+">(tensor_"+std::string(fOutputNames[i])+", "+ConvertShapeToLength(fOutputShapes[i])+"),";
0089       }
0090       args.pop_back();
0091       out << SP << fOpName<<"::Compute("+args+");\n";
0092       return out.str();
0093    }
0094 
0095 };
0096 
0097 
0098 }//SOFIE
0099 }//Experimental
0100 }//TMVA
0101 
0102 
0103 #endif //TMVA_SOFIE_ROPERATOR_Custom