Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-15 09:02:59

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 #pragma once
0008 
#include <cstdint>
#include <utility>

#include "onnx/defs/shape_inference.h"
0012 
0013 namespace ONNX_NAMESPACE {
0014 
0015 inline void appendDimToTensorShapeProto(TensorShapeProto& tsp, const TensorShapeProto* input_data, int index) {
0016   if (index >= input_data->dim_size() || index < -input_data->dim_size()) {
0017     fail_shape_inference("indices must be in [-rank, rank-1].");
0018   } else {
0019     *tsp.add_dim() = input_data->dim((index < 0) ? input_data->dim_size() + index : index);
0020   }
0021 }
0022 
0023 // Returns true if the given axis attribute is 0
0024 inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) {
0025   auto axisAttr = ctx.getAttribute("axis");
0026   // if axis is not defined
0027   if (!axisAttr) {
0028     if (defaultZero) {
0029       return true;
0030     } else {
0031       fail_shape_inference("Required attribute axis is missing");
0032       return false;
0033     }
0034   }
0035   int axis = static_cast<int>(axisAttr->i());
0036   auto input_data_0 = ctx.getInputData(0);
0037   if (input_data_0 == nullptr) {
0038     return false;
0039   }
0040   int rank = input_data_0->dim_size();
0041   if (axis < -rank || axis >= rank) {
0042     fail_shape_inference("axis must be in [-rank, rank-1].");
0043     return false;
0044   }
0045   if (axis < 0) {
0046     axis += rank;
0047   }
0048   // Only supports axis = 0 since the data comes from Shape
0049   return axis == 0;
0050 }
0051 
0052 inline void PropagateShapeDataFromInputToOutput(DataPropagationContext& ctx, int idx) {
0053   // propagate input data
0054   const auto input_data = ctx.getInputData(idx);
0055   if (input_data != nullptr) {
0056     TensorShapeProto tsp;
0057     tsp.CopyFrom(*input_data);
0058     ctx.addOutputData(0, std::move(tsp));
0059   }
0060 }
0061 
0062 inline void GatherOp13DataPropagator(DataPropagationContext& ctx) {
0063   if (!axisIsZero(ctx, true)) {
0064     return;
0065   }
0066   const auto input_data = ctx.getInputData(0);
0067   if (input_data == nullptr) {
0068     return;
0069   }
0070   const auto input_indices = ctx.getInputData(1);
0071   if (input_data == nullptr || input_indices == nullptr) {
0072     return;
0073   }
0074   TensorShapeProto tsp;
0075   for (int i = 0; i < input_indices->dim_size(); ++i) {
0076     if (input_indices->dim(i).has_dim_value()) {
0077       appendDimToTensorShapeProto(tsp, input_data, input_indices->dim(i).dim_value());
0078     } else {
0079       return;
0080     }
0081   }
0082   if (tsp.dim_size() > 0) {
0083     ctx.addOutputData(0, std::move(tsp));
0084   }
0085 }
0086 
0087 } // namespace ONNX_NAMESPACE