Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:54

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Interface for Op Version Adapters
0008 
0009 #pragma once
0010 
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "onnx/onnx_pb.h"
#include "onnx/version_converter/helper.h"
0017 
0018 namespace ONNX_NAMESPACE {
0019 namespace version_conversion {
0020 
0021 class Adapter {
0022  private:
0023   std::string name_;
0024   OpSetID initial_version_;
0025   OpSetID target_version_;
0026 
0027  public:
0028   virtual ~Adapter() noexcept = default;
0029 
0030   explicit Adapter(const std::string& name, const OpSetID& initial_version, const OpSetID& target_version)
0031       : name_(name), initial_version_(initial_version), target_version_(target_version) {}
0032 
0033   // This will almost always return its own node argument after modifying it in place.
0034   // The only exception are adapters for deprecated operators: in this case the input
0035   // node must be destroyed and a new one must be created and returned. See e.g.
0036   // upsample_9_10.h
0037   virtual Node* adapt(std::shared_ptr<Graph> /*graph*/, Node* node) const = 0;
0038 
0039   const std::string& name() const {
0040     return name_;
0041   }
0042 
0043   const OpSetID& initial_version() const {
0044     return initial_version_;
0045   }
0046 
0047   const OpSetID& target_version() const {
0048     return target_version_;
0049   }
0050 };
0051 
0052 using NodeTransformerFunction = std::function<Node*(std::shared_ptr<Graph>, Node* node)>;
0053 
0054 class GenericAdapter final : public Adapter {
0055  public:
0056   GenericAdapter(const char* op, int64_t from, int64_t to, NodeTransformerFunction transformer)
0057       : Adapter(op, OpSetID(from), OpSetID(to)), transformer_(transformer) {}
0058 
0059   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0060     return transformer_(graph, node);
0061   }
0062 
0063  private:
0064   NodeTransformerFunction transformer_;
0065 };
0066 
0067 } // namespace version_conversion
0068 } // namespace ONNX_NAMESPACE