Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:23:01

0001 #ifndef TMVA_SOFIE_ROPERATOR_RELU
0002 #define TMVA_SOFIE_ROPERATOR_RELU
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013 
0014 template <typename T>
0015 class ROperator_Relu final : public ROperator
0016 {
0017 
0018 private:
0019 
0020    std::string fNX;
0021    std::string fNY;
0022    std::vector<Dim> fShape;
0023 
0024 public:
0025    ROperator_Relu(){}
0026    ROperator_Relu(std::string nameX, std::string nameY):
0027       fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){}
0028 
0029    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0030       return input;
0031    }
0032 
0033    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0034       auto ret = input; //suggest copy to compiler
0035       return ret;
0036    }
0037 
0038    void Initialize(RModel& model){
0039       if (model.CheckIfTensorAlreadyExist(fNX) == false){   //input must be a graph input, or already initialized intermediate tensor
0040          throw std::runtime_error("TMVA SOFIE Relu Op Input Tensor " + fNX + " is not found in model");
0041       }
0042 
0043       fShape = model.GetDynamicTensorShape(fNX);
0044 
0045       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
0046    }
0047 
0048 
0049    std::string Generate(std::string OpName){
0050       OpName = "op_" + OpName;
0051       if (fShape.empty()) {
0052          throw std::runtime_error("TMVA SOFIE Operator Relu called to Generate without being initialized first");
0053       }
0054       std::stringstream out;
0055       auto length = ConvertDynamicShapeToLength(fShape);
0056       out << "\n//------ RELU\n";
0057       out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
0058       out << SP << SP << "tensor_" << fNY << "[id] = ((tensor_" << fNX << "[id] > 0 )? tensor_" << fNX << "[id] : 0);\n";
0059       out << SP << "}\n";
0060       return out.str();
0061    }
0062 
0063 };
0064 
0065 }//SOFIE
0066 }//Experimental
0067 }//TMVA
0068 
0069 
0070 #endif //TMVA_SOFIE_ROPERATOR_RELU