File indexing completed on 2025-12-06 10:01:45
0001 #ifndef TMVA_SOFIE_ROPERATOR_Custom
0002 #define TMVA_SOFIE_ROPERATOR_Custom
0003
0004
0005 #include "TMVA/SOFIE_common.hxx"
0006 #include "TMVA/ROperator.hxx"
0007 #include "TMVA/RModel.hxx"
0008
0009 namespace TMVA{
0010 namespace Experimental{
0011 namespace SOFIE{
0012
0013
0014 template<typename T>
0015 class ROperator_Custom final : public ROperator
0016 {
0017
0018 private:
0019 std::string fOpName;
0020 std::vector<std::string> fInputNames;
0021 std::vector<std::string> fOutputNames;
0022 std::vector<std::vector<std::size_t>> fOutputShapes;
0023 std::vector<std::size_t> fInputSizes;
0024 std::string fHeaderName;
0025 ETensorType fInputType;
0026
0027 public:
0028 ROperator_Custom(){}
0029 ROperator_Custom(std::string OpName, std::vector<std::string>Inputs, std::vector<std::string>Outputs, std::vector<std::vector<std::size_t>> OutputShapes, std::string HeaderName){
0030 fOpName = OpName;
0031 fOutputShapes = OutputShapes;
0032 fHeaderName = HeaderName;
0033 for(auto& it:Inputs){
0034 fInputNames.emplace_back(UTILITY::Clean_name(it));
0035 fInputTensorNames.emplace_back(fInputNames.back());
0036 }
0037 for(auto& it:Outputs){
0038 fOutputNames.emplace_back(UTILITY::Clean_name(it));
0039 fOutputTensorNames.emplace_back(fOutputNames.back());
0040 }
0041 }
0042
0043 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>>) override {return {{}};};
0044 std::vector<ETensorType> TypeInference(std::vector<ETensorType>) override { return {};};
0045
0046 void Initialize(RModel& model) override {
0047 model.AddNeededCustomHeader(fHeaderName);
0048 fInputType = model.GetTensorType(fInputNames[0]);
0049
0050 for(auto& it:fInputNames){
0051 if (model.CheckIfTensorAlreadyExist(it) == false){
0052 throw std::runtime_error("TMVA SOFIE Custom " + fOpName + " Op Input Tensor " + it + " is not found in model");
0053 }
0054 fInputSizes.push_back(ConvertShapeToLength(model.GetTensorShape(it)));
0055 }
0056
0057 if(fOutputNames.size() != fOutputShapes.size()){
0058 throw std::runtime_error("TMVA SOFIE Custom "+ fOpName + " Op was not intialized with the names/shapes of all the output tensors");
0059 }
0060
0061 for(long unsigned int i=0; i<fOutputNames.size(); ++i){
0062 model.AddIntermediateTensor(std::string(fOutputNames[i]), ETensorType::FLOAT, fOutputShapes[i]);
0063 }
0064
0065
0066 model.UpdateOutputTensorList(fInputNames, fOutputNames);
0067
0068 if (model.Verbose()) {
0069 std::cout << "Custom operator using " << fHeaderName;
0070 for (auto & i : fInputNames) std::cout << " " << i;
0071 std::cout << " ---> ";
0072 for (auto & i : fOutputNames) std::cout << " " << i;
0073 std::cout << "\n";
0074 }
0075 model.AddNeededCustomHeader("ROOT/RSpan.hxx");
0076 }
0077
0078 std::string Generate(std::string OpName) override {
0079 OpName = "op_" + OpName;
0080 std::stringstream out;
0081 out << "\n//------ "<<fOpName<<" \n";
0082 std::string args;
0083 for(long unsigned int i = 0; i<fInputNames.size(); ++i){
0084 args+="std::span<const "+ConvertTypeToString(fInputType)+">(tensor_"+std::string(fInputNames[i])+", "+fInputSizes[i]+"),";
0085 }
0086
0087 for(long unsigned int i = 0; i<fOutputNames.size(); ++i){
0088 args+="std::span<"+TensorType<T>::Name()+">(tensor_"+std::string(fOutputNames[i])+", "+ConvertShapeToLength(fOutputShapes[i])+"),";
0089 }
0090 args.pop_back();
0091 out << SP << fOpName<<"::Compute("+args+");\n";
0092 return out.str();
0093 }
0094
0095 };
0096
0097
0098 }
0099 }
0100 }
0101
0102
0103 #endif