File indexing completed on 2025-12-16 10:20:26
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <stdlib.h>
0012
0013 #include <iostream>
0014 #include <map>
0015 #include <memory>
0016 #include <string>
0017 #include <unordered_map>
0018 #include <utility>
0019
0020 #include "onnx/common/ir.h"
0021 #include "onnx/common/ir_pb_converter.h"
0022 #include "onnx/defs/schema.h"
0023 #include "onnx/proto_utils.h"
0024 #include "onnx/version_converter/adapters/adapter.h"
0025
0026 namespace ONNX_NAMESPACE {
0027 namespace version_conversion {
0028
0029
0030 class BaseVersionConverter {
0031
0032
0033 protected:
0034 std::unordered_map<
0035 std::string,
0036 std::unordered_map<std::string, std::unordered_map<std::string, std::unique_ptr<Adapter>>>>
0037 adapters;
0038
0039
0040 std::unordered_map<std::string, std::unordered_map<std::string, std::map<int64_t, const OpSchema*>>> all_schemas;
0041
0042 public:
0043 BaseVersionConverter() = default;
0044
0045 virtual ~BaseVersionConverter() = default;
0046
0047
0048
0049
0050
0051 const Adapter& adapter_lookup(const Node* op, const OpSetID& initial_version, const OpSetID& target_version) const {
0052 const std::string op_name = op->kind().toString();
0053 const std::string initial = initial_version.toString();
0054 const std::string target = target_version.toString();
0055
0056
0057
0058 const auto op_adapters = adapters.find(op_name);
0059 if (op_adapters != adapters.end()) {
0060
0061
0062
0063 const auto target_map = op_adapters->second.find(initial);
0064 if (target_map != op_adapters->second.end()) {
0065
0066 const auto adapter_ptr = target_map->second.find(target);
0067 if (adapter_ptr != target_map->second.end()) {
0068 return *(adapter_ptr->second);
0069 } else {
0070 ONNX_ASSERTM(false, "No Adapter To Version %s for %s", target.c_str(), op_name.c_str());
0071 }
0072 } else {
0073 ONNX_ASSERTM(false, "No Adapter From Version %s for %s", initial.c_str(), op_name.c_str());
0074 }
0075 } else {
0076
0077 ONNX_ASSERTM(false, "No Adapter For %s", op_name.c_str());
0078 }
0079 }
0080
0081 virtual ModelProto
0082 convert_version(const ModelProto& mp_in, const OpSetID& initial_version, const OpSetID& target_version) const = 0;
0083
0084 void registerAdapter(std::unique_ptr<Adapter> a_ptr) {
0085 const OpSetID& iv = a_ptr->initial_version();
0086 const OpSetID& tv = a_ptr->target_version();
0087 adapters[a_ptr->name()][iv.toString()][tv.toString()] = std::move(a_ptr);
0088 }
0089
0090 void registerAdapter(const char* op, int64_t from, int64_t to, NodeTransformerFunction transformer) {
0091 registerAdapter(std::make_unique<GenericAdapter>(op, from, to, transformer));
0092 }
0093 };
0094
0095 }
0096 }