File indexing completed on 2025-04-03 08:57:54
0001
0002
0003
0004
0005
0006
0007 #pragma once
0008
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "onnx/version_converter/adapters/adapter.h"
0015
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018 class AxisInputToAttribute : public Adapter {
0019 public:
0020 explicit AxisInputToAttribute(
0021 const std::string& op_name,
0022 const OpSetID& initial,
0023 const OpSetID& target,
0024 size_t axis_index,
0025 int64_t default_axis)
0026 : Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {}
0027
0028 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0029 if (!HasAxisInput(node)) {
0030 node->i_(kaxis, this->default_axis);
0031 return EnsureAndReturnNode(node);
0032 }
0033
0034 const ArrayRef<Value*>& inputs = node->inputs();
0035 Value* axis_val = inputs[this->axis_index];
0036 Node* axis_node = axis_val->node();
0037
0038 if (axis_node->kind() == kConstant) {
0039 HandleConstantNode(node, axis_node, axis_val);
0040 return EnsureAndReturnNode(node);
0041 }
0042
0043 if (graph->is_constant_initializer(axis_val)) {
0044 HandleInitializerNode(graph, node, axis_val);
0045 return EnsureAndReturnNode(node);
0046 }
0047
0048 ONNX_ASSERTM(false, "Axis input must be a constant or initializer for promotion to attribute.");
0049 }
0050
0051 private:
0052 size_t axis_index;
0053 int64_t default_axis;
0054
0055 bool HasAxisInput(const Node* node) const {
0056 const ArrayRef<const Value*>& inputs = node->inputs();
0057 return inputs.size() > this->axis_index && inputs[this->axis_index]->node()->kind() != kUndefined;
0058 }
0059
0060 void HandleConstantNode(Node* node, Node* axis_node, Value* axis_val) const {
0061 const std::vector<int64_t>& int64s = axis_node->t(kvalue).int64s();
0062 if (int64s.empty()) {
0063 std::string raw_data = axis_node->t(kvalue).raw();
0064 ONNX_ASSERTM(
0065 raw_data.size() != 0 && raw_data.size() % 8 == 0,
0066 "Raw Data must be non-empty and size must be a multiple of 8");
0067 const int64_t* raw = reinterpret_cast<const int64_t*>(raw_data.c_str());
0068 node->i_(kaxis, raw[0]);
0069 } else {
0070 node->i_(kaxis, int64s.at(0));
0071 }
0072 node->removeInput(this->axis_index);
0073 if (axis_val->uses().size() < 1) {
0074 axis_node->destroy();
0075 }
0076 }
0077
0078 void HandleInitializerNode(std::shared_ptr<Graph> graph, Node* node, Value* axis_val) const {
0079 const std::string initializer_name = axis_val->uniqueName();
0080 for (const auto& initializer : graph->initializers()) {
0081 if (initializer.name() == initializer_name) {
0082 node->i_(kaxis, initializer.int64s().at(0));
0083 node->removeInput(this->axis_index);
0084
0085 if (axis_val->uses().size() < 1)
0086 graph->eraseInitializer(initializer_name);
0087 break;
0088 }
0089 }
0090 }
0091
0092 inline Node* EnsureAndReturnNode(Node* node) const {
0093 ONNX_ASSERTM(node->hasAttribute(kaxis), "Axis attribute not created. This may be a bug.");
0094 return node;
0095 }
0096 };
0097
0098 }
0099 }