File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012
0013 #include "onnx/version_converter/adapters/remove_consumed_inputs.h"
0014
0015 namespace ONNX_NAMESPACE {
0016 namespace version_conversion {
0017
0018 class Reshape_4_5 final : public RemoveConsumedInputs {
0019 public:
0020 explicit Reshape_4_5() : RemoveConsumedInputs("Reshape", OpSetID(4), OpSetID(5)) {}
0021
0022 void adapt_reshape_4_5(std::shared_ptr<Graph> graph, Node* node) const {
0023
0024
0025 Tensor t;
0026 t.elem_type() = TensorProto_DataType_INT64;
0027 auto& data = t.int64s();
0028
0029 for (int64_t shape : node->is(kshape)) {
0030 data.emplace_back(shape);
0031 }
0032
0033 Node* constant = graph->create(kConstant);
0034 constant->insertBefore(node);
0035 constant->t_(kvalue, t);
0036 node->addInput(constant->output());
0037
0038 node->removeAttribute(kshape);
0039 }
0040
0041 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0042 RemoveConsumedInputs::adapt(graph, node);
0043 adapt_reshape_4_5(graph, node);
0044 return node;
0045 }
0046 };
0047
0048 }
0049 }