Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
// Adapter for Split 12 -> 13: moves the `split` attribute into an input
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <vector>
0013 
0014 #include "onnx/version_converter/adapters/adapter.h"
0015 
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018 
0019 class Split_12_13 : public Adapter {
0020  public:
0021   explicit Split_12_13() : Adapter("Split", OpSetID(12), OpSetID(13)) {}
0022 
0023   void attrToInput(std::shared_ptr<Graph> graph, Node* node, std::vector<int64_t> axes) const {
0024     Tensor t;
0025     t.elem_type() = TensorProto_DataType_INT64;
0026     t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())};
0027     auto& data = t.int64s();
0028     for (auto a : axes) {
0029       data.emplace_back(a);
0030     }
0031     Node* constant = graph->create(kConstant);
0032     constant->insertBefore(node);
0033     constant->t_(kvalue, t);
0034     node->addInput(constant->output());
0035   }
0036 
0037   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0038     if (node->hasAttribute(ksplit)) {
0039       attrToInput(graph, node, node->is(ksplit));
0040       node->removeAttribute(ksplit);
0041     }
0042     return node;
0043   }
0044 };
0045 
0046 } // namespace version_conversion
0047 } // namespace ONNX_NAMESPACE