Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:54

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for ops whose `axes` attribute is converted into an input fed by a Constant node
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <string>
0013 #include <vector>
0014 
0015 #include "onnx/version_converter/adapters/adapter.h"
0016 
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019 
0020 class AxesAttributeToInput : public Adapter {
0021  public:
0022   explicit AxesAttributeToInput(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
0023       : Adapter(op_name, initial, target) {}
0024 
0025   void attrToInput(std::shared_ptr<Graph> graph, Node* node, std::vector<int64_t> axes) const {
0026     Tensor t;
0027     t.elem_type() = TensorProto_DataType_INT64;
0028     t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())};
0029     auto& data = t.int64s();
0030     for (auto a : axes) {
0031       data.emplace_back(a);
0032     }
0033 
0034     Node* constant = graph->create(kConstant);
0035     constant->insertBefore(node);
0036     constant->t_(kvalue, t);
0037     node->addInput(constant->output());
0038   }
0039 
0040   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0041     if (node->hasAttribute(kaxes)) {
0042       attrToInput(graph, node, node->is(kaxes));
0043       node->removeAttribute(kaxes);
0044     }
0045     return node;
0046   }
0047 };
0048 
0049 } // namespace version_conversion
0050 } // namespace ONNX_NAMESPACE