File indexing completed on 2025-04-03 08:57:54
0001
0002
0003
0004
0005
0006
0007 #pragma once
0008
0009 #include <memory>
0010 #include <string>
0011 #include <vector>
0012
0013 #include "onnx/version_converter/adapters/adapter.h"
0014
0015 namespace ONNX_NAMESPACE {
0016 namespace version_conversion {
0017
0018 class AxisAttributeToInput : public Adapter {
0019 public:
0020 AxisAttributeToInput(
0021 const std::string& op_name,
0022 const OpSetID& initial,
0023 const OpSetID& target,
0024 size_t axis_index,
0025 int64_t default_axis)
0026 : Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {}
0027
0028 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0029 if (node->hasAttribute(kaxis)) {
0030 AttrToInput(graph, node, node->i(kaxis), this->axis_index);
0031 node->removeAttribute(kaxis);
0032 return node;
0033 }
0034
0035
0036 AttrToInput(graph, node, default_axis, this->axis_index);
0037 return node;
0038 }
0039
0040 private:
0041 size_t axis_index;
0042 int64_t default_axis;
0043
0044 void AttrToInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis, size_t axis_index) const {
0045 const ArrayRef<Value*>& inputs = node->inputs();
0046
0047
0048 for (size_t i = inputs.size(); i < axis_index; ++i) {
0049 Node* empty_input = graph->create(kUndefined);
0050 empty_input->insertBefore(node);
0051 node->addInput(empty_input->output());
0052 }
0053
0054
0055 Node* constant = CreateAxisInput(graph, node, axis);
0056 node->addInput(constant->output());
0057 }
0058
0059 Node* CreateAxisInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis) const {
0060 Tensor t;
0061 t.elem_type() = TensorProto_DataType_INT64;
0062 t.sizes() = std::vector<int64_t>{};
0063 auto& data = t.int64s();
0064 data.emplace_back(axis);
0065
0066 Node* constant = graph->create(kConstant);
0067 constant->insertBefore(node);
0068 constant->t_(kvalue, t);
0069 return constant;
0070 }
0071 };
0072
0073 }
0074 }