Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Upsample in default domain from version 6 to 7
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <utility>
0013 #include <vector>
0014 
0015 #include "onnx/version_converter/adapters/adapter.h"
0016 
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019 
0020 struct Upsample_6_7 final : public Adapter {
0021   explicit Upsample_6_7() : Adapter("Upsample", OpSetID(6), OpSetID(7)) {}
0022 
0023   void adapt_upsample_6_7(std::shared_ptr<Graph>, Node* node) const {
0024     Symbol width_scale_symbol = Symbol("width_scale");
0025     Symbol height_scale_symbol = Symbol("height_scale");
0026     ONNX_ASSERTM(
0027         node->hasAttribute(width_scale_symbol) && node->hasAttribute(height_scale_symbol),
0028         "Upsample in opset 1 needs to have width_scale and height_scale attributes");
0029 
0030     auto width_scale = node->f(width_scale_symbol);
0031     auto height_scale = node->f(height_scale_symbol);
0032 
0033     auto input_shape = node->inputs()[0]->sizes();
0034     ONNX_ASSERTM(input_shape.size() == 4, "Upsample in opset 1 supports only 4D input tensor");
0035     std::vector<double> scales = {1.0, 1.0, height_scale, width_scale};
0036 
0037     node->fs_(kscales, std::move(scales));
0038     node->removeAttribute(width_scale_symbol);
0039     node->removeAttribute(height_scale_symbol);
0040   }
0041 
0042   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0043     adapt_upsample_6_7(graph, node);
0044     return node;
0045   }
0046 };
0047 
0048 } // namespace version_conversion
0049 } // namespace ONNX_NAMESPACE