File indexing completed on 2025-09-15 09:02:59
0001
0002
0003
0004
0005
0006
0007 #pragma once
0008
0009 #include <utility>
0010
0011 #include "onnx/defs/shape_inference.h"
0012
0013 namespace ONNX_NAMESPACE {
0014
0015 inline void appendDimToTensorShapeProto(TensorShapeProto& tsp, const TensorShapeProto* input_data, int index) {
0016 if (index >= input_data->dim_size() || index < -input_data->dim_size()) {
0017 fail_shape_inference("indices must be in [-rank, rank-1].");
0018 } else {
0019 *tsp.add_dim() = input_data->dim((index < 0) ? input_data->dim_size() + index : index);
0020 }
0021 }
0022
0023
0024 inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) {
0025 auto axisAttr = ctx.getAttribute("axis");
0026
0027 if (!axisAttr) {
0028 if (defaultZero) {
0029 return true;
0030 } else {
0031 fail_shape_inference("Required attribute axis is missing");
0032 return false;
0033 }
0034 }
0035 int axis = static_cast<int>(axisAttr->i());
0036 auto input_data_0 = ctx.getInputData(0);
0037 if (input_data_0 == nullptr) {
0038 return false;
0039 }
0040 int rank = input_data_0->dim_size();
0041 if (axis < -rank || axis >= rank) {
0042 fail_shape_inference("axis must be in [-rank, rank-1].");
0043 return false;
0044 }
0045 if (axis < 0) {
0046 axis += rank;
0047 }
0048
0049 return axis == 0;
0050 }
0051
0052 inline void PropagateShapeDataFromInputToOutput(DataPropagationContext& ctx, int idx) {
0053
0054 const auto input_data = ctx.getInputData(idx);
0055 if (input_data != nullptr) {
0056 TensorShapeProto tsp;
0057 tsp.CopyFrom(*input_data);
0058 ctx.addOutputData(0, std::move(tsp));
0059 }
0060 }
0061
0062 inline void GatherOp13DataPropagator(DataPropagationContext& ctx) {
0063 if (!axisIsZero(ctx, true)) {
0064 return;
0065 }
0066 const auto input_data = ctx.getInputData(0);
0067 if (input_data == nullptr) {
0068 return;
0069 }
0070 const auto input_indices = ctx.getInputData(1);
0071 if (input_data == nullptr || input_indices == nullptr) {
0072 return;
0073 }
0074 TensorShapeProto tsp;
0075 for (int i = 0; i < input_indices->dim_size(); ++i) {
0076 if (input_indices->dim(i).has_dim_value()) {
0077 appendDimToTensorShapeProto(tsp, input_data, input_indices->dim(i).dim_value());
0078 } else {
0079 return;
0080 }
0081 }
0082 if (tsp.dim_size() > 0) {
0083 ctx.addOutputData(0, std::move(tsp));
0084 }
0085 }
0086
0087 }