Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-18 09:32:33

0001 #ifndef TMVA_SOFIE_RMODELPARSER_ONNX
0002 #define TMVA_SOFIE_RMODELPARSER_ONNX
0003 
#include "TMVA/RModel.hxx"

#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
0009 
// Forward declarations of the ONNX message classes (presumably the
// protobuf-generated types from the ONNX headers). This header only names
// them through references and std::unique_ptr return types, so their full
// definitions are not needed here.
namespace onnx {
class ModelProto;
class GraphProto;
class NodeProto;
} // namespace onnx
0016 
0017 namespace TMVA {
0018 namespace Experimental {
0019 namespace SOFIE {
0020 
0021 class RModelParser_ONNX;
0022 
0023 using ParserFuncSignature =
0024    std::function<std::unique_ptr<ROperator>(RModelParser_ONNX & /*parser*/, const onnx::NodeProto & /*nodeproto*/)>;
0025 using ParserFuseFuncSignature =
0026    std::function<std::unique_ptr<ROperator> (RModelParser_ONNX& /*parser*/, const onnx::NodeProto& /*firstnode*/, const onnx::NodeProto& /*secondnode*/)>;
0027 
0028 class RModelParser_ONNX {
0029 public:
0030    struct OperatorsMapImpl;
0031 
0032 private:
0033 
0034    bool fVerbose = false;
0035    // Registered operators
0036    std::unique_ptr<OperatorsMapImpl> fOperatorsMapImpl;
0037    // Type of the tensors
0038    std::unordered_map<std::string, ETensorType> fTensorTypeMap;
0039    // flag list of fused operators
0040    std::vector<bool> fFusedOperators;
0041 
0042 
0043 public:
0044    // Register an ONNX operator
0045    void RegisterOperator(const std::string &name, ParserFuncSignature func);
0046 
0047    // Check if the operator is registered
0048    bool IsRegisteredOperator(const std::string &name);
0049 
0050    // List of registered operators (in alphabetical order)
0051    std::vector<std::string> GetRegisteredOperators();
0052 
0053    // Set the type of the tensor
0054    void RegisterTensorType(const std::string & /*name*/, ETensorType /*type*/);
0055 
0056    // Check if the type of the tensor is registered
0057    bool IsRegisteredTensorType(const std::string & /*name*/);
0058 
0059    // check verbosity
0060    bool Verbose() const {
0061       return fVerbose;
0062    }
0063 
0064    // Get the type of the tensor
0065    ETensorType GetTensorType(const std::string &name);
0066 
0067    // Parse the index'th node from the ONNX graph
0068    std::unique_ptr<ROperator> ParseOperator(const size_t /*index*/, const onnx::GraphProto & /*graphproto*/,
0069                                             const std::vector<size_t> & /*nodes*/, const std::vector<int> & /* children */);
0070 
0071    // check a graph for missing operators
0072    void CheckGraph(const onnx::GraphProto & g, int & level, std::map<std::string, int> & missingOperators);
0073 
0074    // parse the ONNX graph
0075    void ParseONNXGraph(RModel & model, const onnx::GraphProto & g, std::string  name = "");
0076 
0077    std::unique_ptr<onnx::ModelProto> LoadModel(std::string filename);
0078 
0079 public:
0080 
0081    RModelParser_ONNX() noexcept;
0082 
0083    RModel Parse(std::string filename, bool verbose = false);
0084 
0085    // check the model for missing operators - return false in case some operator implementation is missing
0086    bool CheckModel(std::string filename, bool verbose = false);
0087 
0088    ~RModelParser_ONNX();
0089 };
0090 
0091 } // namespace SOFIE
0092 } // namespace Experimental
0093 } // namespace TMVA
0094 
0095 #endif // TMVA_SOFIE_RMODELPARSER_ONNX