Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:04

0001 #ifndef TMVA_SOFIE_RMODELPARSER_ONNX
0002 #define TMVA_SOFIE_RMODELPARSER_ONNX
0003 
#include "TMVA/RModel.hxx"

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
0009 
// Forward declarations of the ONNX graph and node types. This header only
// refers to them through references, so their full definitions are not needed.
namespace onnx {
class GraphProto;
class NodeProto;
} // namespace onnx
0015 
0016 namespace TMVA {
0017 namespace Experimental {
0018 namespace SOFIE {
0019 
0020 class RModelParser_ONNX;
0021 
0022 using ParserFuncSignature =
0023    std::function<std::unique_ptr<ROperator>(RModelParser_ONNX & /*parser*/, const onnx::NodeProto & /*nodeproto*/)>;
0024 using ParserFuseFuncSignature =
0025    std::function<std::unique_ptr<ROperator> (RModelParser_ONNX& /*parser*/, const onnx::NodeProto& /*firstnode*/, const onnx::NodeProto& /*secondnode*/)>;
0026 
0027 class RModelParser_ONNX {
0028 public:
0029    struct OperatorsMapImpl;
0030 
0031 private:
0032    bool fVerbose = false;
0033    // Registered operators
0034    std::unique_ptr<OperatorsMapImpl> fOperatorsMapImpl;
0035    // Type of the tensors
0036    std::unordered_map<std::string, ETensorType> fTensorTypeMap;
0037 
0038 public:
0039    // Register an ONNX operator
0040    void RegisterOperator(const std::string &name, ParserFuncSignature func);
0041 
0042    // Check if the operator is registered
0043    bool IsRegisteredOperator(const std::string &name);
0044 
0045    // List of registered operators
0046    std::vector<std::string> GetRegisteredOperators();
0047 
0048    // Set the type of the tensor
0049    void RegisterTensorType(const std::string & /*name*/, ETensorType /*type*/);
0050 
0051    // Check if the type of the tensor is registered
0052    bool IsRegisteredTensorType(const std::string & /*name*/);
0053 
0054    // Get the type of the tensor
0055    ETensorType GetTensorType(const std::string &name);
0056 
0057    // Parse the index'th node from the ONNX graph
0058    std::unique_ptr<ROperator> ParseOperator(const size_t /*index*/, const onnx::GraphProto & /*graphproto*/,
0059                                             const std::vector<size_t> & /*nodes*/);
0060 
0061 public:
0062    RModelParser_ONNX() noexcept;
0063 
0064    RModel Parse(std::string filename, bool verbose = false);
0065 
0066    ~RModelParser_ONNX();
0067 };
0068 
0069 } // namespace SOFIE
0070 } // namespace Experimental
0071 } // namespace TMVA
0072 
0073 #endif // TMVA_SOFIE_RMODELPARSER_ONNX