File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <vector>
0013
0014 #include "onnx/version_converter/adapters/adapter.h"
0015
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018
0019 struct Upsample_8_9 final : public Adapter {
0020 explicit Upsample_8_9() : Adapter("Upsample", OpSetID(8), OpSetID(9)) {}
0021
0022 void adapt_upsample_8_9(std::shared_ptr<Graph> graph, Node* node) const {
0023 Symbol input_dirs = Symbol("scales");
0024 int dim = (int)(node->fs(kscales).size());
0025 Tensor t;
0026 t.elem_type() = TensorProto_DataType_FLOAT;
0027 t.sizes() = std::vector<int64_t>{dim};
0028 auto& data = t.floats();
0029
0030 if (node->hasAttribute(input_dirs)) {
0031 for (double scale : node->fs(kscales)) {
0032 data.emplace_back((float)scale);
0033 }
0034
0035 Node* constant = graph->create(kConstant);
0036 constant->insertBefore(node);
0037 constant->t_(kvalue, t);
0038 node->addInput(constant->output());
0039 node->removeAttribute(kscales);
0040 }
0041 }
0042
0043 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0044 adapt_upsample_8_9(graph, node);
0045 return node;
0046 }
0047 };
0048
0049 }
0050 }