Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-06 08:55:47

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Upsample in default domain from version 9 to 8
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <string>
0013 #include <vector>
0014 
0015 #include "onnx/defs/tensor_proto_util.h"
0016 #include "onnx/defs/tensor_util.h"
0017 #include "onnx/version_converter/adapters/adapter.h"
0018 
0019 namespace ONNX_NAMESPACE {
0020 namespace version_conversion {
0021 
0022 struct Upsample_9_8 final : public Adapter {
0023   explicit Upsample_9_8() : Adapter("Upsample", OpSetID(9), OpSetID(8)) {}
0024 
0025   void adapt_upsample_9_8(std::shared_ptr<Graph> graph, Node* node) const {
0026     const ArrayRef<Value*>& inputs = node->inputs();
0027     const std::vector<Tensor>& initializers = graph->initializers();
0028 
0029     ONNX_ASSERTM(inputs.size() == 2, "Upsample in opset 9 needs to have 2 inputs.");
0030     std::string scale_input_name = node->inputs()[1]->uniqueName();
0031 
0032     for (size_t i = 0; i < initializers.size(); i++) {
0033       if (initializers[i].name() == inputs[1]->uniqueName()) {
0034         std::vector<float> value = ParseData<float>(&initializers[i]);
0035         std::vector<double> d_values;
0036         d_values.reserve(value.size());
0037         for (size_t j = 0; j < value.size(); j++) {
0038           d_values.push_back(static_cast<double>(value[j]));
0039         }
0040         node->fs_(kscales, const_cast<std::vector<double>&&>(d_values));
0041 
0042         node->removeInput(1);
0043         graph->eraseInitializer(initializers[i].name());
0044         for (size_t j = 0; j < graph->inputs().size(); j++) {
0045           if (graph->inputs()[j]->uniqueName() == scale_input_name) {
0046             graph->eraseInput(j);
0047             break;
0048           }
0049         }
0050         return;
0051       }
0052     }
0053 
0054     for (Node* op : graph->nodes()) {
0055       if (op->kind() == kConstant && op->outputs()[0]->uniqueName() == scale_input_name) {
0056         std::vector<float> value = ParseData<float>(&op->t(kvalue));
0057         std::vector<double> d_values;
0058         d_values.reserve(value.size());
0059         for (size_t j = 0; j < value.size(); j++) {
0060           d_values.push_back(static_cast<double>(value[j]));
0061         }
0062         node->fs_(kscales, const_cast<std::vector<double>&&>(d_values));
0063         node->removeInput(1);
0064         op->destroy();
0065         return;
0066       }
0067     }
0068 
0069     ONNX_ASSERTM(false, "Unsuppported conversion due to unavailable input: scale");
0070   }
0071 
0072   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0073     adapt_upsample_9_8(graph, node);
0074     return node;
0075   }
0076 };
0077 
0078 } // namespace version_conversion
0079 } // namespace ONNX_NAMESPACE