File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
0014
0015 #include "onnx/version_converter/adapters/adapter.h"
0016
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019
0020 class Reshape_5_4 final : public Adapter {
0021 public:
0022 explicit Reshape_5_4() : Adapter("Reshape", OpSetID(5), OpSetID(4)) {}
0023
0024 void adapt_reshape_5_4(std::shared_ptr<Graph> graph, Node* node) const {
0025
0026 const ArrayRef<Value*>& inputs = node->inputs();
0027
0028
0029 Value* const_val = inputs[1];
0030 Node* node_ptr = const_val->node();
0031 if (node_ptr->kind() == kConstant) {
0032
0033 const std::vector<int64_t>& int64s = node_ptr->t(kvalue).int64s();
0034 if (int64s.empty()) {
0035
0036 std::string raw_data = node_ptr->t(kvalue).raw();
0037 ONNX_ASSERTM(
0038 raw_data.size() != 0 && raw_data.size() % 8 == 0,
0039 "Raw Data must be non-empty and size must be a multiple of 8");
0040 int64_t* raw = (int64_t*)const_cast<char*>(raw_data.c_str());
0041 node->is_(kshape, std::vector<int64_t>(raw, raw + node_ptr->t(kvalue).size_from_dim(0)));
0042 } else {
0043 node->is_(kshape, std::forward<const std::vector<int64_t>>(int64s));
0044 }
0045
0046 node->removeInput(1);
0047 if (const_val->uses().size() < 1) {
0048 node_ptr->destroy();
0049 }
0050 } else {
0051
0052 for (const auto& initializer : graph->initializers()) {
0053 if (initializer.name() == inputs[1]->uniqueName()) {
0054 node->is_(kshape, std::forward<const std::vector<int64_t>>(initializer.int64s()));
0055 node->removeInput(1);
0056
0057 if (const_val->uses().size() < 1)
0058 graph->eraseInitializerAndInput(const_val);
0059 break;
0060 }
0061 }
0062 }
0063 ONNX_ASSERTM(node->hasAttribute(kshape), "No initializer or constant input to Reshape node found");
0064 }
0065
0066 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0067 adapt_reshape_5_4(graph, node);
0068 return node;
0069 }
0070 };
0071
0072 }
0073 }