File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <cstdint>
#include <memory>
0012
0013 namespace ONNX_NAMESPACE {
0014 namespace version_conversion {
0015
0016 class Scatter_10_11 final : public Adapter {
0017 public:
0018 explicit Scatter_10_11() : Adapter("Scatter", OpSetID(10), OpSetID(11)) {}
0019
0020 Node* adapt_scatter_10_11(std::shared_ptr<Graph> graph, Node* node) const {
0021 int axis = node->hasAttribute(kaxis) ? node->i(kaxis) : 0;
0022
0023
0024 Node* scatter_elements = graph->create(kScatterElements);
0025 scatter_elements->i_(kaxis, axis);
0026 scatter_elements->addInput(node->inputs()[0]);
0027 scatter_elements->addInput(node->inputs()[1]);
0028 scatter_elements->addInput(node->inputs()[2]);
0029 node->replaceAllUsesWith(scatter_elements);
0030
0031 scatter_elements->insertBefore(node);
0032 node->destroy();
0033
0034 return scatter_elements;
0035 }
0036
0037 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0038 return adapt_scatter_10_11(graph, node);
0039 }
0040 };
0041
0042 }
0043 }