Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:54

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 #pragma once
0008 
0009 #include <memory>
0010 #include <string>
0011 #include <vector>
0012 
0013 #include "onnx/version_converter/adapters/adapter.h"
0014 
0015 namespace ONNX_NAMESPACE {
0016 namespace version_conversion {
0017 
0018 class AxisAttributeToInput : public Adapter {
0019  public:
0020   AxisAttributeToInput(
0021       const std::string& op_name,
0022       const OpSetID& initial,
0023       const OpSetID& target,
0024       size_t axis_index,
0025       int64_t default_axis)
0026       : Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {}
0027 
0028   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0029     if (node->hasAttribute(kaxis)) {
0030       AttrToInput(graph, node, node->i(kaxis), this->axis_index);
0031       node->removeAttribute(kaxis);
0032       return node;
0033     }
0034 
0035     // Fill in the default value for axis
0036     AttrToInput(graph, node, default_axis, this->axis_index);
0037     return node;
0038   }
0039 
0040  private:
0041   size_t axis_index;
0042   int64_t default_axis;
0043 
0044   void AttrToInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis, size_t axis_index) const {
0045     const ArrayRef<Value*>& inputs = node->inputs();
0046 
0047     // Add the optional inputs if they don't exist
0048     for (size_t i = inputs.size(); i < axis_index; ++i) {
0049       Node* empty_input = graph->create(kUndefined);
0050       empty_input->insertBefore(node);
0051       node->addInput(empty_input->output());
0052     }
0053 
0054     // Add the axis input
0055     Node* constant = CreateAxisInput(graph, node, axis);
0056     node->addInput(constant->output());
0057   }
0058 
0059   Node* CreateAxisInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis) const {
0060     Tensor t;
0061     t.elem_type() = TensorProto_DataType_INT64;
0062     t.sizes() = std::vector<int64_t>{};
0063     auto& data = t.int64s();
0064     data.emplace_back(axis);
0065 
0066     Node* constant = graph->create(kConstant);
0067     constant->insertBefore(node);
0068     constant->t_(kvalue, t);
0069     return constant;
0070   }
0071 };
0072 
0073 } // namespace version_conversion
0074 } // namespace ONNX_NAMESPACE