Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Reshape in default domain from version 4 to 5
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 
0013 #include "onnx/version_converter/adapters/remove_consumed_inputs.h"
0014 
0015 namespace ONNX_NAMESPACE {
0016 namespace version_conversion {
0017 
0018 class Reshape_4_5 final : public RemoveConsumedInputs {
0019  public:
0020   explicit Reshape_4_5() : RemoveConsumedInputs("Reshape", OpSetID(4), OpSetID(5)) {}
0021 
0022   void adapt_reshape_4_5(std::shared_ptr<Graph> graph, Node* node) const {
0023     // Create Input from Attribute - add as Initializer
0024     // Create tensor for value attribute
0025     Tensor t;
0026     t.elem_type() = TensorProto_DataType_INT64;
0027     auto& data = t.int64s();
0028     // Turn shapes attribute into tensor
0029     for (int64_t shape : node->is(kshape)) {
0030       data.emplace_back(shape);
0031     }
0032     // Add value as input to node
0033     Node* constant = graph->create(kConstant);
0034     constant->insertBefore(node);
0035     constant->t_(kvalue, t);
0036     node->addInput(constant->output());
0037     // Remove kshape attribute
0038     node->removeAttribute(kshape);
0039   }
0040 
0041   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0042     RemoveConsumedInputs::adapt(graph, node);
0043     adapt_reshape_4_5(graph, node);
0044     return node;
0045   }
0046 };
0047 
0048 } // namespace version_conversion
0049 } // namespace ONNX_NAMESPACE