Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Scatter in default domain from version 10 to 11
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 
0013 namespace ONNX_NAMESPACE {
0014 namespace version_conversion {
0015 
0016 class Scatter_10_11 final : public Adapter {
0017  public:
0018   explicit Scatter_10_11() : Adapter("Scatter", OpSetID(10), OpSetID(11)) {}
0019 
0020   Node* adapt_scatter_10_11(std::shared_ptr<Graph> graph, Node* node) const {
0021     int axis = node->hasAttribute(kaxis) ? node->i(kaxis) : 0;
0022 
0023     // Replace the node with an equivalent ScatterElements node
0024     Node* scatter_elements = graph->create(kScatterElements);
0025     scatter_elements->i_(kaxis, axis);
0026     scatter_elements->addInput(node->inputs()[0]);
0027     scatter_elements->addInput(node->inputs()[1]);
0028     scatter_elements->addInput(node->inputs()[2]);
0029     node->replaceAllUsesWith(scatter_elements);
0030 
0031     scatter_elements->insertBefore(node);
0032     node->destroy();
0033 
0034     return scatter_elements;
0035   }
0036 
0037   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0038     return adapt_scatter_10_11(graph, node);
0039   }
0040 };
0041 
0042 } // namespace version_conversion
0043 } // namespace ONNX_NAMESPACE