File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "onnx/version_converter/adapters/adapter.h"
0016
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019
0020 struct Upsample_6_7 final : public Adapter {
0021 explicit Upsample_6_7() : Adapter("Upsample", OpSetID(6), OpSetID(7)) {}
0022
0023 void adapt_upsample_6_7(std::shared_ptr<Graph>, Node* node) const {
0024 Symbol width_scale_symbol = Symbol("width_scale");
0025 Symbol height_scale_symbol = Symbol("height_scale");
0026 ONNX_ASSERTM(
0027 node->hasAttribute(width_scale_symbol) && node->hasAttribute(height_scale_symbol),
0028 "Upsample in opset 1 needs to have width_scale and height_scale attributes");
0029
0030 auto width_scale = node->f(width_scale_symbol);
0031 auto height_scale = node->f(height_scale_symbol);
0032
0033 auto input_shape = node->inputs()[0]->sizes();
0034 ONNX_ASSERTM(input_shape.size() == 4, "Upsample in opset 1 supports only 4D input tensor");
0035 std::vector<double> scales = {1.0, 1.0, height_scale, width_scale};
0036
0037 node->fs_(kscales, std::move(scales));
0038 node->removeAttribute(width_scale_symbol);
0039 node->removeAttribute(height_scale_symbol);
0040 }
0041
0042 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0043 adapt_upsample_6_7(graph, node);
0044 return node;
0045 }
0046 };
0047
0048 }
0049 }