Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:45

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
#include <cstdint>
#include <utility>

#include "onnx/defs/shape_inference.h"
0010 
0011 namespace ONNX_NAMESPACE {
0012 
0013 inline void appendDimToTensorShapeProto(TensorShapeProto& tsp, const TensorShapeProto* input_data, int index) {
0014   if (index >= input_data->dim_size() || index < -input_data->dim_size()) {
0015     fail_shape_inference("indices must be in [-rank, rank-1].");
0016   } else {
0017     *tsp.add_dim() = input_data->dim((index < 0) ? input_data->dim_size() + index : index);
0018   }
0019 }
0020 
0021 // Returns true if the given axis attribute is 0
0022 inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) {
0023   auto axisAttr = ctx.getAttribute("axis");
0024   // if axis is not defined
0025   if (!axisAttr) {
0026     if (defaultZero) {
0027       return true;
0028     } else {
0029       fail_shape_inference("Required attribute axis is missing");
0030       return false;
0031     }
0032   }
0033   int axis = static_cast<int>(axisAttr->i());
0034   auto input_data_0 = ctx.getInputData(0);
0035   if (input_data_0 == nullptr) {
0036     return false;
0037   }
0038   int rank = input_data_0->dim_size();
0039   if (axis < -rank || axis >= rank) {
0040     fail_shape_inference("axis must be in [-rank, rank-1].");
0041     return false;
0042   }
0043   if (axis < 0) {
0044     axis += rank;
0045   }
0046   // Only supports axis = 0 since the data comes from Shape
0047   return axis == 0;
0048 }
0049 
0050 inline void PropagateShapeDataFromInputToOutput(DataPropagationContext& ctx, int idx) {
0051   // propagate input data
0052   const auto input_data = ctx.getInputData(idx);
0053   if (input_data != nullptr) {
0054     TensorShapeProto tsp;
0055     tsp.CopyFrom(*input_data);
0056     ctx.addOutputData(0, std::move(tsp));
0057   }
0058 }
0059 
0060 inline void GatherOp13DataPropagator(DataPropagationContext& ctx) {
0061   if (!axisIsZero(ctx, true)) {
0062     return;
0063   }
0064   const auto input_data = ctx.getInputData(0);
0065   if (input_data == nullptr) {
0066     return;
0067   }
0068   const auto input_indices = ctx.getInputData(1);
0069   if (input_data == nullptr || input_indices == nullptr) {
0070     return;
0071   }
0072   TensorShapeProto tsp;
0073   for (int i = 0; i < input_indices->dim_size(); ++i) {
0074     if (input_indices->dim(i).has_dim_value()) {
0075       appendDimToTensorShapeProto(tsp, input_data, input_indices->dim(i).dim_value());
0076     } else {
0077       return;
0078     }
0079   }
0080   if (tsp.dim_size() > 0) {
0081     ctx.addOutputData(0, std::move(tsp));
0082   }
0083 }
0084 
0085 } // namespace ONNX_NAMESPACE