Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:54

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 #pragma once
0008 
0009 #include <memory>
0010 #include <string>
0011 #include <utility>
0012 #include <vector>
0013 
0014 #include "onnx/version_converter/adapters/adapter.h"
0015 
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018 class AxisInputToAttribute : public Adapter {
0019  public:
0020   explicit AxisInputToAttribute(
0021       const std::string& op_name,
0022       const OpSetID& initial,
0023       const OpSetID& target,
0024       size_t axis_index,
0025       int64_t default_axis)
0026       : Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {}
0027 
0028   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0029     if (!HasAxisInput(node)) {
0030       node->i_(kaxis, this->default_axis);
0031       return EnsureAndReturnNode(node);
0032     }
0033 
0034     const ArrayRef<Value*>& inputs = node->inputs();
0035     Value* axis_val = inputs[this->axis_index];
0036     Node* axis_node = axis_val->node();
0037 
0038     if (axis_node->kind() == kConstant) {
0039       HandleConstantNode(node, axis_node, axis_val);
0040       return EnsureAndReturnNode(node);
0041     }
0042 
0043     if (graph->is_constant_initializer(axis_val)) {
0044       HandleInitializerNode(graph, node, axis_val);
0045       return EnsureAndReturnNode(node);
0046     }
0047 
0048     ONNX_ASSERTM(false, "Axis input must be a constant or initializer for promotion to attribute.");
0049   }
0050 
0051  private:
0052   size_t axis_index;
0053   int64_t default_axis;
0054 
0055   bool HasAxisInput(const Node* node) const {
0056     const ArrayRef<const Value*>& inputs = node->inputs();
0057     return inputs.size() > this->axis_index && inputs[this->axis_index]->node()->kind() != kUndefined;
0058   }
0059 
0060   void HandleConstantNode(Node* node, Node* axis_node, Value* axis_val) const {
0061     const std::vector<int64_t>& int64s = axis_node->t(kvalue).int64s();
0062     if (int64s.empty()) {
0063       std::string raw_data = axis_node->t(kvalue).raw();
0064       ONNX_ASSERTM(
0065           raw_data.size() != 0 && raw_data.size() % 8 == 0,
0066           "Raw Data must be non-empty and size must be a multiple of 8");
0067       const int64_t* raw = reinterpret_cast<const int64_t*>(raw_data.c_str());
0068       node->i_(kaxis, raw[0]);
0069     } else {
0070       node->i_(kaxis, int64s.at(0));
0071     }
0072     node->removeInput(this->axis_index);
0073     if (axis_val->uses().size() < 1) {
0074       axis_node->destroy();
0075     }
0076   }
0077 
0078   void HandleInitializerNode(std::shared_ptr<Graph> graph, Node* node, Value* axis_val) const {
0079     const std::string initializer_name = axis_val->uniqueName();
0080     for (const auto& initializer : graph->initializers()) {
0081       if (initializer.name() == initializer_name) {
0082         node->i_(kaxis, initializer.int64s().at(0));
0083         node->removeInput(this->axis_index);
0084         // Remove initializer
0085         if (axis_val->uses().size() < 1)
0086           graph->eraseInitializer(initializer_name);
0087         break;
0088       }
0089     }
0090   }
0091 
0092   inline Node* EnsureAndReturnNode(Node* node) const {
0093     ONNX_ASSERTM(node->hasAttribute(kaxis), "Axis attribute not created. This may be a bug.");
0094     return node;
0095   }
0096 };
0097 
0098 } // namespace version_conversion
0099 } // namespace ONNX_NAMESPACE