Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Resize in default domain from version 10 to 11
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <vector>
0013 
0014 namespace ONNX_NAMESPACE {
0015 namespace version_conversion {
0016 
0017 class Resize_10_11 final : public Adapter {
0018  public:
0019   explicit Resize_10_11() : Adapter("Resize", OpSetID(10), OpSetID(11)) {}
0020 
0021   void adapt_resize_10_11(std::shared_ptr<Graph> graph, Node* node) const {
0022     int input_rank = node->inputs()[0]->sizes().size();
0023 
0024     Value* scales_input = node->inputs()[1];
0025     node->addInput(scales_input);
0026 
0027     Tensor t;
0028     t.sizes() = std::vector<int64_t>{2 * input_rank};
0029     t.elem_type() = TensorProto_DataType_FLOAT;
0030     auto& data = t.floats();
0031 
0032     for (int i = 0; i < input_rank; i++)
0033       data.emplace_back(0);
0034     for (int i = 0; i < input_rank; i++)
0035       data.emplace_back(1);
0036 
0037     Node* constant = graph->create(kConstant);
0038     constant->insertBefore(node);
0039     constant->t_(kvalue, t);
0040     node->replaceInput(1, constant->output());
0041   }
0042 
0043   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0044     adapt_resize_10_11(graph, node);
0045     return node;
0046   }
0047 };
0048 
0049 } // namespace version_conversion
0050 } // namespace ONNX_NAMESPACE