Warning: file /include/root/TMVA/RModelParser_ONNX.hxx was not indexed,
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001 #ifndef TMVA_SOFIE_RMODELPARSER_ONNX
0002 #define TMVA_SOFIE_RMODELPARSER_ONNX
0003 
#include "TMVA/RModel.hxx"

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
0009 
0010 
// Forward declarations of the onnx:: message types named in this header.
// They only appear here behind references or std::unique_ptr, so their full
// definitions are not needed in this header (presumably they come from the
// generated ONNX protobuf headers included by the implementation — confirm).
namespace onnx {
class GraphProto;
class ModelProto;
class NodeProto;
} // namespace onnx
0016 
0017 namespace TMVA {
0018 namespace Experimental {
0019 namespace SOFIE {
0020 
0021 class RModelParser_ONNX;
0022 
0023 using ParserFuncSignature =
0024    std::function<std::unique_ptr<ROperator>(RModelParser_ONNX & , const onnx::NodeProto & )>;
0025 using ParserFuseFuncSignature =
0026    std::function<std::unique_ptr<ROperator> (RModelParser_ONNX& , const onnx::NodeProto& , const onnx::NodeProto& )>;
0027 
0028 class RModelParser_ONNX {
0029 public:
0030    struct OperatorsMapImpl;
0031 
0032 private:
0033 
0034    bool fVerbose = false;
0035    
0036    std::unique_ptr<OperatorsMapImpl> fOperatorsMapImpl;
0037    
0038    std::unordered_map<std::string, ETensorType> fTensorTypeMap;
0039    
0040    std::vector<bool> fFusedOperators;
0041 
0042 
0043 public:
0044    
0045    void RegisterOperator(const std::string &name, ParserFuncSignature func);
0046 
0047    
0048    bool IsRegisteredOperator(const std::string &name);
0049 
0050    
0051    std::vector<std::string> GetRegisteredOperators();
0052 
0053    
0054    void RegisterTensorType(const std::string & , ETensorType );
0055 
0056    
0057    bool IsRegisteredTensorType(const std::string & );
0058 
0059    
0060    bool Verbose() const {
0061       return fVerbose;
0062    }
0063 
0064    
0065    ETensorType GetTensorType(const std::string &name);
0066 
0067    
0068    std::unique_ptr<ROperator> ParseOperator(const size_t , const onnx::GraphProto & ,
0069                                             const std::vector<size_t> & , const std::vector<int> & );
0070 
0071    
0072    void CheckGraph(const onnx::GraphProto & g, int & level, std::map<std::string, int> & missingOperators);
0073 
0074    
0075    void ParseONNXGraph(RModel & model, const onnx::GraphProto & g, std::string  name = "");
0076 
0077    std::unique_ptr<onnx::ModelProto> LoadModel(std::string filename);
0078 
0079 public:
0080 
0081    RModelParser_ONNX() noexcept;
0082 
0083    RModel Parse(std::string filename, bool verbose = false);
0084 
0085    
0086    bool CheckModel(std::string filename, bool verbose = false);
0087 
0088    ~RModelParser_ONNX();
0089 };
0090 
0091 } 
0092 } 
0093 } 
0094 
0095 #endif