File indexing completed on 2025-04-06 08:55:47
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <string>
0013 #include <vector>
0014
0015 #include "onnx/defs/tensor_proto_util.h"
0016 #include "onnx/defs/tensor_util.h"
0017 #include "onnx/version_converter/adapters/adapter.h"
0018
0019 namespace ONNX_NAMESPACE {
0020 namespace version_conversion {
0021
0022 struct Upsample_9_8 final : public Adapter {
0023 explicit Upsample_9_8() : Adapter("Upsample", OpSetID(9), OpSetID(8)) {}
0024
0025 void adapt_upsample_9_8(std::shared_ptr<Graph> graph, Node* node) const {
0026 const ArrayRef<Value*>& inputs = node->inputs();
0027 const std::vector<Tensor>& initializers = graph->initializers();
0028
0029 ONNX_ASSERTM(inputs.size() == 2, "Upsample in opset 9 needs to have 2 inputs.");
0030 std::string scale_input_name = node->inputs()[1]->uniqueName();
0031
0032 for (size_t i = 0; i < initializers.size(); i++) {
0033 if (initializers[i].name() == inputs[1]->uniqueName()) {
0034 std::vector<float> value = ParseData<float>(&initializers[i]);
0035 std::vector<double> d_values;
0036 d_values.reserve(value.size());
0037 for (size_t j = 0; j < value.size(); j++) {
0038 d_values.push_back(static_cast<double>(value[j]));
0039 }
0040 node->fs_(kscales, const_cast<std::vector<double>&&>(d_values));
0041
0042 node->removeInput(1);
0043 graph->eraseInitializer(initializers[i].name());
0044 for (size_t j = 0; j < graph->inputs().size(); j++) {
0045 if (graph->inputs()[j]->uniqueName() == scale_input_name) {
0046 graph->eraseInput(j);
0047 break;
0048 }
0049 }
0050 return;
0051 }
0052 }
0053
0054 for (Node* op : graph->nodes()) {
0055 if (op->kind() == kConstant && op->outputs()[0]->uniqueName() == scale_input_name) {
0056 std::vector<float> value = ParseData<float>(&op->t(kvalue));
0057 std::vector<double> d_values;
0058 d_values.reserve(value.size());
0059 for (size_t j = 0; j < value.size(); j++) {
0060 d_values.push_back(static_cast<double>(value[j]));
0061 }
0062 node->fs_(kscales, const_cast<std::vector<double>&&>(d_values));
0063 node->removeInput(1);
0064 op->destroy();
0065 return;
0066 }
0067 }
0068
0069 ONNX_ASSERTM(false, "Unsuppported conversion due to unavailable input: scale");
0070 }
0071
0072 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0073 adapt_upsample_9_8(graph, node);
0074 return node;
0075 }
0076 };
0077
0078 }
0079 }