Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-06 08:55:47

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Slice in default domain from version 9 to 10
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <vector>
0013 
0014 namespace ONNX_NAMESPACE {
0015 namespace version_conversion {
0016 
0017 class Slice_9_10 final : public Adapter {
0018  public:
0019   explicit Slice_9_10() : Adapter("Slice", OpSetID(9), OpSetID(10)) {}
0020 
0021   void attrToInput(std::shared_ptr<Graph> graph, Node* node, const std::vector<int64_t>& attr) const {
0022     Tensor t;
0023     t.elem_type() = TensorProto_DataType_INT64;
0024     t.sizes() = std::vector<int64_t>{static_cast<int64_t>(attr.size())};
0025     auto& data = t.int64s();
0026     for (auto a : attr) {
0027       data.emplace_back(a);
0028     }
0029     Node* constant = graph->create(kConstant);
0030     constant->insertBefore(node);
0031     constant->t_(kvalue, t);
0032     node->addInput(constant->output());
0033   }
0034 
0035   void adapt_slice_9_10(std::shared_ptr<Graph> graph, Node* node) const {
0036     attrToInput(graph, node, node->is(kstarts));
0037     node->removeAttribute(kstarts);
0038     attrToInput(graph, node, node->is(kends));
0039     node->removeAttribute(kends);
0040 
0041     if (node->hasAttribute(kaxes)) {
0042       attrToInput(graph, node, node->is(kaxes));
0043       node->removeAttribute(kaxes);
0044     }
0045   }
0046 
0047   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0048     adapt_slice_9_10(graph, node);
0049     return node;
0050   }
0051 };
0052 
0053 } // namespace version_conversion
0054 } // namespace ONNX_NAMESPACE