File indexing completed on 2025-01-30 10:22:59
0001 #ifndef TMVA_SOFIE_ROPERATOR_Custom
0002 #define TMVA_SOFIE_ROPERATOR_Custom
0003
0004
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
0008
0009 namespace TMVA{
0010 namespace Experimental{
0011 namespace SOFIE{
0012
0013
0014 template<typename T>
0015 class ROperator_Custom final : public ROperator
0016 {
0017
0018 private:
0019 std::string fOpName;
0020 std::vector<std::string> fInputNames;
0021 std::vector<std::string> fOutputNames;
0022 std::vector<std::vector<std::size_t>> fOutputShapes;
0023 std::string fHeaderName;
0024
0025 public:
0026 ROperator_Custom(){}
0027 ROperator_Custom(std::string OpName, std::vector<std::string>Inputs, std::vector<std::string>Outputs, std::vector<std::vector<std::size_t>> OutputShapes, std::string HeaderName){
0028 fOpName = OpName;
0029 fOutputShapes = OutputShapes;
0030 fHeaderName = HeaderName;
0031 for(auto& it:Inputs){
0032 fInputNames.emplace_back(UTILITY::Clean_name(it));
0033 }
0034 for(auto& it:Outputs){
0035 fOutputNames.emplace_back(UTILITY::Clean_name(it));
0036 }
0037 }
0038
0039 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) {return {{}};};
0040 std::vector<ETensorType> TypeInference(std::vector<ETensorType>){ return {};};
0041
0042 void Initialize(RModel& model){
0043 model.AddNeededCustomHeader(fHeaderName);
0044 for(auto& it:fInputNames){
0045 if (model.CheckIfTensorAlreadyExist(it) == false){
0046 throw std::runtime_error("TMVA SOFIE Custom " + fOpName + " Op Input Tensor " + it + " is not found in model");
0047 }
0048 }
0049
0050 if(fOutputNames.size() != fOutputShapes.size()){
0051 throw std::runtime_error("TMVA SOFIE Custom "+ fOpName + " Op was not intialized with the names/shapes of all the output tensors");
0052 }
0053
0054 for(long unsigned int i=0; i<fOutputNames.size(); ++i){
0055 model.AddIntermediateTensor(fOutputNames[i], ETensorType::FLOAT, fOutputShapes[i]);
0056 }
0057 model.UpdateOutputTensorList(fInputNames, fOutputNames);
0058 }
0059
0060 std::string Generate(std::string OpName){
0061 OpName = "op_" + OpName;
0062 std::stringstream out;
0063 out << "\n//------ "<<fOpName<<" \n";
0064 std::string args;
0065 for(long unsigned int i = 0; i<fInputNames.size(); ++i){
0066 args+="fTensor_"+fInputNames[i]+",";
0067 }
0068
0069 for(long unsigned int i = 0; i<fOutputNames.size(); ++i){
0070 args+="fTensor_"+fOutputNames[i]+",";
0071 }
0072 args.pop_back();
0073 out << SP << fOpName<<"::Compute("+args+");\n";
0074 return out.str();
0075 }
0076
0077 };
0078
0079
0080 }
0081 }
0082 }
0083
0084
0085 #endif