Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:58

0001 #ifndef TMVA_SOFIE_ROPERATOR_Cast
0002 #define TMVA_SOFIE_ROPERATOR_Cast
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013 
0014 
0015 template <typename T>
0016 class ROperator_Cast final : public ROperator
0017 {
0018 
0019 private:
0020 
0021    std::string fNX;
0022    std::string fNY;
0023    std::vector<size_t> fShape;
0024    std::string fAttrType = "float";
0025 
0026 public:
0027    ROperator_Cast(){}
0028    ROperator_Cast(std::string attr_type,std::string nameX, std::string nameY):
0029    fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)),
0030    fAttrType(attr_type) {}
0031 
0032    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0033       return input;
0034    }
0035 
0036    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0037       auto ret = input; //suggest copy to compiler
0038       return ret;
0039    }
0040 
0041    void Initialize(RModel& model){
0042        //input must be a graph input, or already initialized intermediate tensor
0043       if (model.CheckIfTensorAlreadyExist(fNX) == false){
0044         throw std::runtime_error("TMVA SOFIE Cast Op Input Tensor is not found in model");
0045       }
0046       fShape = model.GetTensorShape(fNX);
0047       model.AddIntermediateTensor(fNY, ConvertStringToType(fAttrType), fShape);
0048    }
0049 
0050 
0051    std::string Generate(std::string OpName){
0052       OpName = "op_" + OpName;
0053       if (fShape.empty()) {
0054          throw std::runtime_error("TMVA SOFIE Cast called to Generate without being initialized first");
0055       }
0056       std::stringstream out;
0057       size_t length = ConvertShapeToLength(fShape);
0058 
0059       // out << SP << ETensorType << " " << OpName << "_attr = "  << fattr << ";\n";
0060       out << "\n//------ CAST\n";
0061       out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
0062 
0063       out << SP << SP << "tensor_" << fNY << "[id] = static_cast<"<< fAttrType << ">(tensor_" << fNX << "[id]);\n";
0064 
0065       out << SP << "}\n";
0066       return out.str();
0067    }
0068 
0069 };
0070 
0071 }//SOFIE
0072 }//Experimental
0073 }//TMVA
0074 
0075 
0076 #endif //TMVA_SOFIE_ROPERATOR_Cast