File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <vector>
0013
0014 namespace ONNX_NAMESPACE {
0015 namespace version_conversion {
0016
0017 class Resize_10_11 final : public Adapter {
0018 public:
0019 explicit Resize_10_11() : Adapter("Resize", OpSetID(10), OpSetID(11)) {}
0020
0021 void adapt_resize_10_11(std::shared_ptr<Graph> graph, Node* node) const {
0022 int input_rank = node->inputs()[0]->sizes().size();
0023
0024 Value* scales_input = node->inputs()[1];
0025 node->addInput(scales_input);
0026
0027 Tensor t;
0028 t.sizes() = std::vector<int64_t>{2 * input_rank};
0029 t.elem_type() = TensorProto_DataType_FLOAT;
0030 auto& data = t.floats();
0031
0032 for (int i = 0; i < input_rank; i++)
0033 data.emplace_back(0);
0034 for (int i = 0; i < input_rank; i++)
0035 data.emplace_back(1);
0036
0037 Node* constant = graph->create(kConstant);
0038 constant->insertBefore(node);
0039 constant->t_(kvalue, t);
0040 node->replaceInput(1, constant->output());
0041 }
0042
0043 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0044 adapt_resize_10_11(graph, node);
0045 return node;
0046 }
0047 };
0048
0049 }
0050 }