File indexing completed on 2025-04-03 08:57:54
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "onnx/version_converter/adapters/adapter.h"
0017
0018 namespace ONNX_NAMESPACE {
0019 namespace version_conversion {
0020
0021 class AxesInputToAttribute : public Adapter {
0022 public:
0023 explicit AxesInputToAttribute(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
0024 : Adapter(op_name, initial, target) {}
0025
0026 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0027
0028 const ArrayRef<Value*>& inputs = node->inputs();
0029
0030
0031 Value* const_val = inputs[1];
0032 Node* node_ptr = const_val->node();
0033 if (node_ptr->kind() == kConstant) {
0034
0035 const std::vector<int64_t>& int64s = node_ptr->t(kvalue).int64s();
0036 if (int64s.empty()) {
0037
0038 std::string raw_data = node_ptr->t(kvalue).raw();
0039 ONNX_ASSERTM(
0040 raw_data.size() != 0 && raw_data.size() % 8 == 0,
0041 "Raw Data must be non-empty and size must be a multiple of 8");
0042 int64_t* raw = (int64_t*)const_cast<char*>(raw_data.c_str());
0043 node->is_(kaxes, std::vector<int64_t>(raw, raw + node_ptr->t(kvalue).size_from_dim(0)));
0044 } else {
0045 node->is_(kaxes, std::forward<const std::vector<int64_t>>(int64s));
0046 }
0047
0048 node->removeInput(1);
0049 if (const_val->uses().size() < 1) {
0050 node_ptr->destroy();
0051 }
0052 } else {
0053
0054 for (const auto& initializer : graph->initializers()) {
0055 if (initializer.name() == inputs[1]->uniqueName()) {
0056 node->is_(kaxes, std::forward<const std::vector<int64_t>>(initializer.int64s()));
0057 node->removeInput(1);
0058
0059 if (const_val->uses().size() < 1)
0060 graph->eraseInitializerAndInput(const_val);
0061 break;
0062 }
0063 }
0064 }
0065 ONNX_ASSERTM(node->hasAttribute(kaxes), "No initializer or constant input to node found");
0066 return node;
0067 }
0068 };
0069
0070 }
0071 }