File indexing completed on 2025-04-06 08:55:47
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <vector>
0013
0014 namespace ONNX_NAMESPACE {
0015 namespace version_conversion {
0016
0017 class Slice_9_10 final : public Adapter {
0018 public:
0019 explicit Slice_9_10() : Adapter("Slice", OpSetID(9), OpSetID(10)) {}
0020
0021 void attrToInput(std::shared_ptr<Graph> graph, Node* node, const std::vector<int64_t>& attr) const {
0022 Tensor t;
0023 t.elem_type() = TensorProto_DataType_INT64;
0024 t.sizes() = std::vector<int64_t>{static_cast<int64_t>(attr.size())};
0025 auto& data = t.int64s();
0026 for (auto a : attr) {
0027 data.emplace_back(a);
0028 }
0029 Node* constant = graph->create(kConstant);
0030 constant->insertBefore(node);
0031 constant->t_(kvalue, t);
0032 node->addInput(constant->output());
0033 }
0034
0035 void adapt_slice_9_10(std::shared_ptr<Graph> graph, Node* node) const {
0036 attrToInput(graph, node, node->is(kstarts));
0037 node->removeAttribute(kstarts);
0038 attrToInput(graph, node, node->is(kends));
0039 node->removeAttribute(kends);
0040
0041 if (node->hasAttribute(kaxes)) {
0042 attrToInput(graph, node, node->is(kaxes));
0043 node->removeAttribute(kaxes);
0044 }
0045 }
0046
0047 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0048 adapt_slice_9_10(graph, node);
0049 return node;
0050 }
0051 };
0052
0053 }
0054 }