Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:59

0001 #ifndef TMVA_SOFIE_ROPERATOR_Custom
0002 #define TMVA_SOFIE_ROPERATOR_Custom
0003 
0004 
0005 #include "TMVA/SOFIE_common.hxx"
0006 #include "TMVA/ROperator.hxx"
0007 #include "TMVA/RModel.hxx"
0008 
0009 namespace TMVA{
0010 namespace Experimental{
0011 namespace SOFIE{
0012 
0013 
0014 template<typename T>
0015 class ROperator_Custom final : public ROperator
0016 {   
0017 
0018 private:
0019     std::string fOpName;
0020     std::vector<std::string> fInputNames;
0021     std::vector<std::string> fOutputNames;
0022     std::vector<std::vector<std::size_t>> fOutputShapes;
0023     std::string fHeaderName;
0024 
0025 public:
0026     ROperator_Custom(){}
0027     ROperator_Custom(std::string OpName, std::vector<std::string>Inputs, std::vector<std::string>Outputs, std::vector<std::vector<std::size_t>> OutputShapes, std::string HeaderName){
0028         fOpName = OpName;
0029         fOutputShapes = OutputShapes;
0030         fHeaderName = HeaderName;
0031         for(auto& it:Inputs){
0032             fInputNames.emplace_back(UTILITY::Clean_name(it));
0033         }
0034         for(auto& it:Outputs){
0035             fOutputNames.emplace_back(UTILITY::Clean_name(it));
0036         }
0037     }
0038 
0039     std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) {return {{}};};
0040     std::vector<ETensorType> TypeInference(std::vector<ETensorType>){ return {};};
0041 
0042     void Initialize(RModel& model){
0043       model.AddNeededCustomHeader(fHeaderName);
0044       for(auto& it:fInputNames){
0045         if (model.CheckIfTensorAlreadyExist(it) == false){  
0046          throw std::runtime_error("TMVA SOFIE Custom " + fOpName + " Op Input Tensor " + it + " is not found in model");
0047         }
0048       }
0049 
0050       if(fOutputNames.size() != fOutputShapes.size()){
0051         throw std::runtime_error("TMVA SOFIE Custom "+ fOpName + " Op was not intialized with the names/shapes of all the output tensors");
0052       }
0053 
0054       for(long unsigned int i=0; i<fOutputNames.size(); ++i){
0055         model.AddIntermediateTensor(fOutputNames[i], ETensorType::FLOAT, fOutputShapes[i]);
0056       }
0057       model.UpdateOutputTensorList(fInputNames, fOutputNames);
0058    }
0059 
0060     std::string Generate(std::string OpName){
0061       OpName = "op_" + OpName;
0062       std::stringstream out;
0063       out << "\n//------ "<<fOpName<<" \n";
0064       std::string args;
0065       for(long unsigned int i = 0; i<fInputNames.size(); ++i){
0066         args+="fTensor_"+fInputNames[i]+",";
0067       }
0068       
0069       for(long unsigned int i = 0; i<fOutputNames.size(); ++i){
0070         args+="fTensor_"+fOutputNames[i]+",";
0071       }
0072       args.pop_back();
0073       out << SP << fOpName<<"::Compute("+args+");\n";
0074       return out.str();
0075    }
0076 
0077 };
0078 
0079 
0080 }//SOFIE
0081 }//Experimental
0082 }//TMVA
0083 
0084 
0085 #endif //TMVA_SOFIE_ROPERATOR_Custom