File indexing completed on 2025-09-18 09:32:33
0001 #ifndef TMVA_SOFIE_RMODELPARSER_ONNX
0002 #define TMVA_SOFIE_RMODELPARSER_ONNX
0003
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "TMVA/RModel.hxx"
0009
0010
// Forward declarations of the ONNX message types (presumably the
// protobuf-generated classes). This header only refers to them through
// references and std::unique_ptr, so their complete definitions are not
// needed here.
namespace onnx {
class ModelProto;
class GraphProto;
class NodeProto;
} // namespace onnx
0016
0017 namespace TMVA {
0018 namespace Experimental {
0019 namespace SOFIE {
0020
0021 class RModelParser_ONNX;
0022
0023 using ParserFuncSignature =
0024 std::function<std::unique_ptr<ROperator>(RModelParser_ONNX & , const onnx::NodeProto & )>;
0025 using ParserFuseFuncSignature =
0026 std::function<std::unique_ptr<ROperator> (RModelParser_ONNX& , const onnx::NodeProto& , const onnx::NodeProto& )>;
0027
0028 class RModelParser_ONNX {
0029 public:
0030 struct OperatorsMapImpl;
0031
0032 private:
0033
0034 bool fVerbose = false;
0035
0036 std::unique_ptr<OperatorsMapImpl> fOperatorsMapImpl;
0037
0038 std::unordered_map<std::string, ETensorType> fTensorTypeMap;
0039
0040 std::vector<bool> fFusedOperators;
0041
0042
0043 public:
0044
0045 void RegisterOperator(const std::string &name, ParserFuncSignature func);
0046
0047
0048 bool IsRegisteredOperator(const std::string &name);
0049
0050
0051 std::vector<std::string> GetRegisteredOperators();
0052
0053
0054 void RegisterTensorType(const std::string & , ETensorType );
0055
0056
0057 bool IsRegisteredTensorType(const std::string & );
0058
0059
0060 bool Verbose() const {
0061 return fVerbose;
0062 }
0063
0064
0065 ETensorType GetTensorType(const std::string &name);
0066
0067
0068 std::unique_ptr<ROperator> ParseOperator(const size_t , const onnx::GraphProto & ,
0069 const std::vector<size_t> & , const std::vector<int> & );
0070
0071
0072 void CheckGraph(const onnx::GraphProto & g, int & level, std::map<std::string, int> & missingOperators);
0073
0074
0075 void ParseONNXGraph(RModel & model, const onnx::GraphProto & g, std::string name = "");
0076
0077 std::unique_ptr<onnx::ModelProto> LoadModel(std::string filename);
0078
0079 public:
0080
0081 RModelParser_ONNX() noexcept;
0082
0083 RModel Parse(std::string filename, bool verbose = false);
0084
0085
0086 bool CheckModel(std::string filename, bool verbose = false);
0087
0088 ~RModelParser_ONNX();
0089 };
0090
0091 }
0092 }
0093 }
0094
0095 #endif