Back to home page

EIC code displayed by LXR

 
 

    


Warning: file /include/root/TMVA/ROperator_Constant.hxx was not indexed or was modified since the last indexation (in which case cross-reference links may be missing, inaccurate, or erroneous).

0001 #ifndef TMVA_SOFIE_ROPERATOR_Constant
0002 #define TMVA_SOFIE_ROPERATOR_Constant
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013 
0014 template<typename T>
0015 class ROperator_Constant final : public ROperator
0016 {
0017 
0018 private:
0019 
0020    std::string fNX;
0021    std::string fNY;
0022    std::vector<size_t> fShape;
0023    std::vector<T> fValues;
0024    std::string fAttrType;
0025    bool fIsConstantOfShape = false;
0026 
0027 public:
0028    ROperator_Constant(){}
0029 
0030    ROperator_Constant(const std::string & type, const std::vector<T> & values, const std::vector<size_t> & shape, std::string nameX, std::string nameY):
0031       fNX(UTILITY::Clean_name(nameX)),
0032       fNY(UTILITY::Clean_name(nameY)),
0033       fShape(shape),
0034       fValues(values),
0035       fAttrType(type)
0036       {
0037          fInputTensorNames = { };
0038          fOutputTensorNames = { };
0039       }
0040 
0041    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0042       return input;
0043    }
0044 
0045    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0046       auto ret = input; //suggest copy to compiler
0047       return ret;
0048    }
0049 
0050    void Initialize(RModel& model) override {
0051        //input must be a graph input, or already initialized intermediate tensor
0052       size_t length = 1;
0053       if (!fNX.empty()) {
0054          // case of ConstantOfShape (since no inputs in case of Constant operator)
0055          fIsConstantOfShape  = true;
0056          if (model.CheckIfTensorAlreadyExist(fNX) == false){
0057            throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor is not found in model");
0058          }
0059          // get output shape from input values:
0060          // can work only if input is a constant or initialized tensor (or dynamic one)
0061          auto dptr = model.GetInitializedTensorData(fNX);
0062          auto input_tensor = static_cast<int64_t *>(dptr.get());
0063          auto input_shape = model.GetTensorShape(fNX);
0064          if (input_shape.size() > 1 )
0065             throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor has invalid shape");
0066          if (input_tensor != nullptr && !input_shape.empty()) {
0067             fShape = std::vector<size_t> (input_shape[0]);
0068             for (size_t i = 0; i < fShape.size(); i++)
0069                fShape[i] = input_tensor[i];
0070          } else
0071             fShape = {1};  // scalar case
0072 
0073          length = ConvertShapeToLength(fShape);
0074          if (fValues.size() != 1)
0075             throw std::runtime_error("TMVA SOFIE ConstantOfShape Op value Tensor has invalid size " + std::to_string(fValues.size()));
0076 
0077          T value = fValues[0];
0078          fValues = std::vector<T>(length, value);
0079 
0080       } else {
0081          // case of constant operator
0082          // in case of standard constant the shape is provided as input
0083          length = ConvertShapeToLength(fShape);
0084          if (length != fValues.size())
0085             throw std::runtime_error("TMVA SOFIE Constant Op has invalid shape : " + ConvertShapeToString(fShape) +
0086                                  " with " + std::to_string(fValues.size()) + " values");
0087       }
0088 
0089       // we need to create an initialized tensor of type constant to flag to not save it in a weight file
0090       // but keep its initialization in the generated code. The values might also be needed in initializing the
0091       // following operators using as input Constant or ConstantOfShape
0092        // resize fValues to shape length
0093       model.AddConstantTensor(fNY, fShape, fValues);
0094       if (model.Verbose()) {
0095          std::cout << "adding constant tensor " << fNY << " with shape " << ConvertShapeToString(fShape)
0096          << " and values [";
0097          for (auto v : fValues) std::cout << " " << v;
0098          std::cout << "]" << std::endl;
0099       }
0100    }
0101 
0102    std::string Generate(std::string /* OpName */) override {
0103       // no code to generate here. Tensor are defined in Session constructor
0104       return "//---------------------------------------\n";
0105    }
0106 };
0107 
0108 }//SOFIE
0109 }//Experimental
0110 }//TMVA
0111 
0112 
0113 #endif //TMVA_SOFIE_ROPERATOR_Constant