Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Reshape in default domain from version 5 to 4
0008 
0009 #pragma once
0010 
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
0014 
0015 #include "onnx/version_converter/adapters/adapter.h"
0016 
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019 
0020 class Reshape_5_4 final : public Adapter {
0021  public:
0022   explicit Reshape_5_4() : Adapter("Reshape", OpSetID(5), OpSetID(4)) {}
0023 
0024   void adapt_reshape_5_4(std::shared_ptr<Graph> graph, Node* node) const {
0025     // Identify if shape is statically determined; if so, feed as attribute
0026     const ArrayRef<Value*>& inputs = node->inputs();
0027     // Get shape from initializer or constant operator, not actual shape
0028     // Identify whether we have a Constant Op or an Initializer
0029     Value* const_val = inputs[1];
0030     Node* node_ptr = const_val->node();
0031     if (node_ptr->kind() == kConstant) {
0032       // Get value attribute of kConstant
0033       const std::vector<int64_t>& int64s = node_ptr->t(kvalue).int64s();
0034       if (int64s.empty()) {
0035         // Also handle raw data
0036         std::string raw_data = node_ptr->t(kvalue).raw();
0037         ONNX_ASSERTM(
0038             raw_data.size() != 0 && raw_data.size() % 8 == 0,
0039             "Raw Data must be non-empty and size must be a multiple of 8");
0040         int64_t* raw = (int64_t*)const_cast<char*>(raw_data.c_str());
0041         node->is_(kshape, std::vector<int64_t>(raw, raw + node_ptr->t(kvalue).size_from_dim(0)));
0042       } else {
0043         node->is_(kshape, std::forward<const std::vector<int64_t>>(int64s));
0044       }
0045       // If Constant node isn't used anywhere else, remove it
0046       node->removeInput(1);
0047       if (const_val->uses().size() < 1) {
0048         node_ptr->destroy();
0049       }
0050     } else {
0051       // Get Value name, find Initializer with same name
0052       for (const auto& initializer : graph->initializers()) {
0053         if (initializer.name() == inputs[1]->uniqueName()) {
0054           node->is_(kshape, std::forward<const std::vector<int64_t>>(initializer.int64s()));
0055           node->removeInput(1);
0056           // Remove initializer
0057           if (const_val->uses().size() < 1)
0058             graph->eraseInitializerAndInput(const_val);
0059           break;
0060         }
0061       }
0062     }
0063     ONNX_ASSERTM(node->hasAttribute(kshape), "No initializer or constant input to Reshape node found");
0064   }
0065 
0066   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0067     adapt_reshape_5_4(graph, node);
0068     return node;
0069   }
0070 };
0071 
0072 } // namespace version_conversion
0073 } // namespace ONNX_NAMESPACE