File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <vector>
0013
0014 #include "onnx/version_converter/adapters/adapter.h"
0015
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018
0019 class Split_12_13 : public Adapter {
0020 public:
0021 explicit Split_12_13() : Adapter("Split", OpSetID(12), OpSetID(13)) {}
0022
0023 void attrToInput(std::shared_ptr<Graph> graph, Node* node, std::vector<int64_t> axes) const {
0024 Tensor t;
0025 t.elem_type() = TensorProto_DataType_INT64;
0026 t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())};
0027 auto& data = t.int64s();
0028 for (auto a : axes) {
0029 data.emplace_back(a);
0030 }
0031 Node* constant = graph->create(kConstant);
0032 constant->insertBefore(node);
0033 constant->t_(kvalue, t);
0034 node->addInput(constant->output());
0035 }
0036
0037 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0038 if (node->hasAttribute(ksplit)) {
0039 attrToInput(graph, node, node->is(ksplit));
0040 node->removeAttribute(ksplit);
0041 }
0042 return node;
0043 }
0044 };
0045
0046 }
0047 }