Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 10:20:26

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Version converter interface for ONNX models between different opset versions.
0008 
0009 #pragma once
0010 
0011 #include <stdlib.h>
0012 
0013 #include <iostream>
0014 #include <map>
0015 #include <memory>
0016 #include <string>
0017 #include <unordered_map>
0018 #include <utility>
0019 
0020 #include "onnx/common/ir.h"
0021 #include "onnx/common/ir_pb_converter.h"
0022 #include "onnx/defs/schema.h"
0023 #include "onnx/proto_utils.h"
0024 #include "onnx/version_converter/adapters/adapter.h"
0025 
0026 namespace ONNX_NAMESPACE {
0027 namespace version_conversion {
0028 
0029 // TODO: Consider creating interface for this class.
0030 class BaseVersionConverter {
0031   // Schema for adapters: {<op_name>:{<from_domain>$<from_version>:{<to_domain>
0032   // <to_version>: adapter}}}
0033  protected:
0034   std::unordered_map<
0035       std::string,
0036       std::unordered_map<std::string, std::unordered_map<std::string, std::unique_ptr<Adapter>>>>
0037       adapters;
0038 
0039   // Map of All Versions of format {op_name: {domain: {version: schema}}}
0040   std::unordered_map<std::string, std::unordered_map<std::string, std::map<int64_t, const OpSchema*>>> all_schemas;
0041 
0042  public:
0043   BaseVersionConverter() = default;
0044 
0045   virtual ~BaseVersionConverter() = default;
0046 
0047   // adapter_lookup should be called in convert_version when the user would
0048   // like to identify the proper registered adapter in the adapters map for
0049   // a given Node from a certain version to another. It should only be called
0050   // when the user knows that an adapter should exist for the given context.
0051   const Adapter& adapter_lookup(const Node* op, const OpSetID& initial_version, const OpSetID& target_version) const {
0052     const std::string op_name = op->kind().toString();
0053     const std::string initial = initial_version.toString();
0054     const std::string target = target_version.toString();
0055     // Find appropriate adapter in adapters map for provided initial and target versions
0056     // TODO: Consider abstracting elements of this that are specific to
0057     // DefaultConverter to separate methods here and maintain the procedure in Base Converter
0058     const auto op_adapters = adapters.find(op_name);
0059     if (op_adapters != adapters.end()) {
0060       // If we're adapting downwards, we just want to find the one downwards
0061       // adapter implemented for initial_version. If we're adapting upwards, we
0062       // want to actually use the SinceVersion value for the given op.
0063       const auto target_map = op_adapters->second.find(initial);
0064       if (target_map != op_adapters->second.end()) {
0065         // Either adapt from SinceVersion or Incompatible Breaking Change
0066         const auto adapter_ptr = target_map->second.find(target);
0067         if (adapter_ptr != target_map->second.end()) {
0068           return *(adapter_ptr->second);
0069         } else {
0070           ONNX_ASSERTM(false, "No Adapter To Version %s for %s", target.c_str(), op_name.c_str());
0071         }
0072       } else {
0073         ONNX_ASSERTM(false, "No Adapter From Version %s for %s", initial.c_str(), op_name.c_str());
0074       }
0075     } else {
0076       // No adapters exist for the given op
0077       ONNX_ASSERTM(false, "No Adapter For %s", op_name.c_str());
0078     }
0079   }
0080 
0081   virtual ModelProto
0082   convert_version(const ModelProto& mp_in, const OpSetID& initial_version, const OpSetID& target_version) const = 0;
0083 
0084   void registerAdapter(std::unique_ptr<Adapter> a_ptr) {
0085     const OpSetID& iv = a_ptr->initial_version();
0086     const OpSetID& tv = a_ptr->target_version();
0087     adapters[a_ptr->name()][iv.toString()][tv.toString()] = std::move(a_ptr);
0088   }
0089 
0090   void registerAdapter(const char* op, int64_t from, int64_t to, NodeTransformerFunction transformer) {
0091     registerAdapter(std::make_unique<GenericAdapter>(op, from, to, transformer));
0092   }
0093 };
0094 
0095 } // namespace version_conversion
0096 } // namespace ONNX_NAMESPACE