File indexing completed on 2025-04-03 08:57:54
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <functional>
#include <memory>
#include <string>
#include <utility>
0014
0015 #include "onnx/onnx_pb.h"
0016 #include "onnx/version_converter/helper.h"
0017
0018 namespace ONNX_NAMESPACE {
0019 namespace version_conversion {
0020
0021 class Adapter {
0022 private:
0023 std::string name_;
0024 OpSetID initial_version_;
0025 OpSetID target_version_;
0026
0027 public:
0028 virtual ~Adapter() noexcept = default;
0029
0030 explicit Adapter(const std::string& name, const OpSetID& initial_version, const OpSetID& target_version)
0031 : name_(name), initial_version_(initial_version), target_version_(target_version) {}
0032
0033
0034
0035
0036
0037 virtual Node* adapt(std::shared_ptr<Graph> , Node* node) const = 0;
0038
0039 const std::string& name() const {
0040 return name_;
0041 }
0042
0043 const OpSetID& initial_version() const {
0044 return initial_version_;
0045 }
0046
0047 const OpSetID& target_version() const {
0048 return target_version_;
0049 }
0050 };
0051
0052 using NodeTransformerFunction = std::function<Node*(std::shared_ptr<Graph>, Node* node)>;
0053
0054 class GenericAdapter final : public Adapter {
0055 public:
0056 GenericAdapter(const char* op, int64_t from, int64_t to, NodeTransformerFunction transformer)
0057 : Adapter(op, OpSetID(from), OpSetID(to)), transformer_(transformer) {}
0058
0059 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0060 return transformer_(graph, node);
0061 }
0062
0063 private:
0064 NodeTransformerFunction transformer_;
0065 };
0066
0067 }
0068 }