Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Upsample in default domain from version 8 to 9
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <vector>
0013 
0014 #include "onnx/version_converter/adapters/adapter.h"
0015 
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018 
0019 struct Upsample_8_9 final : public Adapter {
0020   explicit Upsample_8_9() : Adapter("Upsample", OpSetID(8), OpSetID(9)) {}
0021 
0022   void adapt_upsample_8_9(std::shared_ptr<Graph> graph, Node* node) const {
0023     Symbol input_dirs = Symbol("scales");
0024     int dim = (int)(node->fs(kscales).size());
0025     Tensor t;
0026     t.elem_type() = TensorProto_DataType_FLOAT;
0027     t.sizes() = std::vector<int64_t>{dim};
0028     auto& data = t.floats();
0029 
0030     if (node->hasAttribute(input_dirs)) {
0031       for (double scale : node->fs(kscales)) {
0032         data.emplace_back((float)scale);
0033       }
0034 
0035       Node* constant = graph->create(kConstant);
0036       constant->insertBefore(node);
0037       constant->t_(kvalue, t);
0038       node->addInput(constant->output());
0039       node->removeAttribute(kscales);
0040     }
0041   }
0042 
0043   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0044     adapt_upsample_8_9(graph, node);
0045     return node;
0046   }
0047 };
0048 
0049 } // namespace version_conversion
0050 } // namespace ONNX_NAMESPACE