File indexing completed on 2025-04-03 08:57:54
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <string>
0013 #include <vector>
0014
0015 #include "onnx/version_converter/adapters/adapter.h"
0016
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019
0020 class AxesAttributeToInput : public Adapter {
0021 public:
0022 explicit AxesAttributeToInput(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
0023 : Adapter(op_name, initial, target) {}
0024
0025 void attrToInput(std::shared_ptr<Graph> graph, Node* node, std::vector<int64_t> axes) const {
0026 Tensor t;
0027 t.elem_type() = TensorProto_DataType_INT64;
0028 t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())};
0029 auto& data = t.int64s();
0030 for (auto a : axes) {
0031 data.emplace_back(a);
0032 }
0033
0034 Node* constant = graph->create(kConstant);
0035 constant->insertBefore(node);
0036 constant->t_(kvalue, t);
0037 node->addInput(constant->output());
0038 }
0039
0040 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0041 if (node->hasAttribute(kaxes)) {
0042 attrToInput(graph, node, node->is(kaxes));
0043 node->removeAttribute(kaxes);
0044 }
0045 return node;
0046 }
0047 };
0048
0049 }
0050 }