File indexing completed on 2025-01-18 10:11:04
0001 #ifndef TMVA_SOFIE_RMODELPARSER_ONNX
0002 #define TMVA_SOFIE_RMODELPARSER_ONNX
0003
#include "TMVA/RModel.hxx"

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
0009
0010
// Forward declarations of the ONNX message types (presumably the protobuf-generated
// classes from onnx.proto). This header only names them through references, so
// their full definitions are not needed here.
namespace onnx {
class GraphProto;
class NodeProto;
} // namespace onnx
0015
0016 namespace TMVA {
0017 namespace Experimental {
0018 namespace SOFIE {
0019
0020 class RModelParser_ONNX;
0021
0022 using ParserFuncSignature =
0023 std::function<std::unique_ptr<ROperator>(RModelParser_ONNX & , const onnx::NodeProto & )>;
0024 using ParserFuseFuncSignature =
0025 std::function<std::unique_ptr<ROperator> (RModelParser_ONNX& , const onnx::NodeProto& , const onnx::NodeProto& )>;
0026
0027 class RModelParser_ONNX {
0028 public:
0029 struct OperatorsMapImpl;
0030
0031 private:
0032 bool fVerbose = false;
0033
0034 std::unique_ptr<OperatorsMapImpl> fOperatorsMapImpl;
0035
0036 std::unordered_map<std::string, ETensorType> fTensorTypeMap;
0037
0038 public:
0039
0040 void RegisterOperator(const std::string &name, ParserFuncSignature func);
0041
0042
0043 bool IsRegisteredOperator(const std::string &name);
0044
0045
0046 std::vector<std::string> GetRegisteredOperators();
0047
0048
0049 void RegisterTensorType(const std::string & , ETensorType );
0050
0051
0052 bool IsRegisteredTensorType(const std::string & );
0053
0054
0055 ETensorType GetTensorType(const std::string &name);
0056
0057
0058 std::unique_ptr<ROperator> ParseOperator(const size_t , const onnx::GraphProto & ,
0059 const std::vector<size_t> & );
0060
0061 public:
0062 RModelParser_ONNX() noexcept;
0063
0064 RModel Parse(std::string filename, bool verbose = false);
0065
0066 ~RModelParser_ONNX();
0067 };
0068
0069 }
0070 }
0071 }
0072
0073 #endif