Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:04

0001 #ifndef TMVA_SOFIE_RMODEL
0002 #define TMVA_SOFIE_RMODEL
0003 
#include "TMVA/RModel_Base.hxx"
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"

#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
0007 
0008 namespace TMVA {
0009 namespace Experimental {
0010 namespace SOFIE {
0011 
0012 class RModel final : public RModel_Base {
0013 
0014 private:
0015    std::unordered_map<std::string, InputTensorInfo>
0016       fInputTensorInfos; // input tensors where shape is not defined or other graph inputs?
0017    std::unordered_map<std::string, TensorInfo> fReadyInputTensorInfos; // input tensors where shape is full defined
0018    std::unordered_map<std::string, InitializedTensor> fInitializedTensors;
0019    std::unordered_map<std::string, TensorInfo> fIntermediateTensorInfos;
0020    std::unordered_map<std::string, DynamicTensorInfo> fDynamicTensorInfos;
0021    std::unordered_map<std::string, std::string>
0022       fShapeParams; // parameters defining the dynamic shape (e.g. batch size), store also its default value
0023    std::vector<std::string> fOutputTensorNames;
0024    std::vector<std::string> fInputTensorNames; // input tensor names using ONNX order
0025 
0026    std::vector<std::unique_ptr<ROperator>> fOperators;
0027 
0028    const std::string SP = "   ";
0029 
0030 public:
0031    // Rule of five: explicitly define move semantics, disallow copy
0032    RModel(RModel &&other);
0033    RModel &operator=(RModel &&other);
0034    RModel(const RModel &other) = delete;
0035    RModel &operator=(const RModel &other) = delete;
0036    ~RModel() = default;
0037 
0038    /**
0039        Default constructor. Needed to allow serialization of ROOT objects. See
0040        https://root.cern/manual/io_custom_classes/#restrictions-on-types-root-io-can-handle
0041    */
0042    RModel() = default;
0043    RModel(std::string name, std::string parsedtime) : RModel_Base(name, parsedtime) {}
0044 
0045    // For GNN Functions usage
0046    RModel(std::string function_name) : RModel_Base(function_name) {}
0047 
0048    const std::vector<size_t> &GetTensorShape(std::string name);
0049    std::vector<Dim> GetDynamicTensorShape(std::string name);
0050    const ETensorType &GetTensorType(std::string name);
0051 
0052    bool CheckIfTensorAlreadyExist(std::string tensor_name);
0053    void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<Dim> shape);
0054    void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<size_t> shape);
0055    void AddOperator(std::unique_ptr<ROperator> op, int order_execution = -1);
0056    void AddOperatorReference(ROperator *op, int order_execution = -1)
0057    {
0058       std::unique_ptr<ROperator> tmp(op);
0059       AddOperator(std::move(tmp), order_execution);
0060    }
0061    void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape,
0062                              std::shared_ptr<void> data);
0063 
0064    template <typename T>
0065    void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape, T *raw_data)
0066    {
0067       int size = 1;
0068       for (auto item : shape) {
0069          size *= (int)item;
0070       }
0071       std::shared_ptr<void> data(malloc(size * sizeof(T)), free);
0072       std::memcpy(data.get(), raw_data, size * sizeof(T));
0073       AddInitializedTensor(tensor_name, type, shape, data);
0074    }
0075 
0076    // Check if a tensor is initialized
0077    bool IsInitializedTensor(const std::string &name) const;
0078    bool IsDynamicTensor(const std::string &name) const;
0079    bool IsInputTensor(const std::string &name) const;
0080 
0081    // Add intermediate tensor
0082    void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector<Dim> dim_shape);
0083    void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape);
0084    // Add an intermediate dynamic tensor
0085    void AddDynamicTensor(std::string tensor_name, ETensorType type, std::vector<Dim> shape);
0086 
0087    void AddInputTensorName(std::string name);
0088    void AddOutputTensorNameList(std::vector<std::string> output_tensor_names);
0089    void
0090    UpdateOutputTensorList(std::vector<std::string> curr_output_tensor, std::vector<std::string> modify_output_tensor);
0091    void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape,
0092                                 std::shared_ptr<void> data);
0093    std::shared_ptr<void> GetInitializedTensorData(std::string tensor_name);
0094 
0095    void Initialize(int batchSize = -1, bool verbose = false);
0096    void GenerateInitializedTensorInfo();
0097    void GenerateIntermediateTensorInfo();
0098    void GenerateDynamicTensorInfo();
0099    void GenerateOutput();
0100    void Generate(std::underlying_type_t<Options> options, int batchSize = -1, long pos = 0);
0101    void Generate(Options options = Options::kDefault, int batchSize = -1, int pos = 0)
0102    {
0103       Generate(static_cast<std::underlying_type_t<Options>>(options), batchSize, pos);
0104    }
0105 
0106    const std::vector<std::string> &GetInputTensorNames() const { return fInputTensorNames; }
0107    const std::vector<std::string> &GetOutputTensorNames() const { return fOutputTensorNames; }
0108 
0109    void ReadInitializedTensorsFromFile(long);
0110    long WriteInitializedTensorsToFile(std::string filename = "");
0111 
0112    void PrintIntermediateTensors();
0113    void PrintOutputTensors();
0114    void OutputGenerated(std::string filename = "", bool append = false);
0115    std::vector<std::string> GetOutputTensorNames() { return fOutputTensorNames; }
0116    void SetFilename(std::string filename) { fName = filename; }
0117 
0118    /*
0119       template <typename T>
0120       void AddInitializedTensor(std::string tensor_name, RTensor<T> new_tensor){
0121          //a view only
0122          T obj;
0123          if (fInitializedTensors.find(tensor_name) != fInitializedTensors.end()){
0124             throw std::runtime_error("TMVA-SOFIE: initialized tensor with name " + tensor_name + " already exists \n");
0125          }
0126          InitializedTensor new_tensor_ {GetTemplatedType(obj), new_tensor.GetShape() ,
0127       static_cast<void>(new_tensor.GetData())}; fInitializedTensors[tensor_name] = new_tensor_;
0128       }
0129    */
0130 
0131    void PrintRequiredInputTensors();
0132    void PrintInitializedTensors();
0133    void PrintDynamicTensors();
0134    void HeadInitializedTensors(std::string name, int n_print = 50);
0135 
0136    bool UseSession() const { return fUseSession; }
0137 
0138    // Use the ClassDef macro to allow definition of custom streaming
0139    ClassDefNV(RModel, 2);
0140 };
0141 
0142 } // namespace SOFIE
0143 } // namespace Experimental
0144 } // namespace TMVA
0145 
0146 #endif // TMVA_SOFIE_RMODEL