File indexing completed on 2025-02-22 10:42:45
0001
0002
0003
0004
0005
0006
#include <cstdint>
#include <utility>

#include "onnx/defs/shape_inference.h"
0010
0011 namespace ONNX_NAMESPACE {
0012
0013 inline void appendDimToTensorShapeProto(TensorShapeProto& tsp, const TensorShapeProto* input_data, int index) {
0014 if (index >= input_data->dim_size() || index < -input_data->dim_size()) {
0015 fail_shape_inference("indices must be in [-rank, rank-1].");
0016 } else {
0017 *tsp.add_dim() = input_data->dim((index < 0) ? input_data->dim_size() + index : index);
0018 }
0019 }
0020
0021
0022 inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) {
0023 auto axisAttr = ctx.getAttribute("axis");
0024
0025 if (!axisAttr) {
0026 if (defaultZero) {
0027 return true;
0028 } else {
0029 fail_shape_inference("Required attribute axis is missing");
0030 return false;
0031 }
0032 }
0033 int axis = static_cast<int>(axisAttr->i());
0034 auto input_data_0 = ctx.getInputData(0);
0035 if (input_data_0 == nullptr) {
0036 return false;
0037 }
0038 int rank = input_data_0->dim_size();
0039 if (axis < -rank || axis >= rank) {
0040 fail_shape_inference("axis must be in [-rank, rank-1].");
0041 return false;
0042 }
0043 if (axis < 0) {
0044 axis += rank;
0045 }
0046
0047 return axis == 0;
0048 }
0049
0050 inline void PropagateShapeDataFromInputToOutput(DataPropagationContext& ctx, int idx) {
0051
0052 const auto input_data = ctx.getInputData(idx);
0053 if (input_data != nullptr) {
0054 TensorShapeProto tsp;
0055 tsp.CopyFrom(*input_data);
0056 ctx.addOutputData(0, std::move(tsp));
0057 }
0058 }
0059
0060 inline void GatherOp13DataPropagator(DataPropagationContext& ctx) {
0061 if (!axisIsZero(ctx, true)) {
0062 return;
0063 }
0064 const auto input_data = ctx.getInputData(0);
0065 if (input_data == nullptr) {
0066 return;
0067 }
0068 const auto input_indices = ctx.getInputData(1);
0069 if (input_data == nullptr || input_indices == nullptr) {
0070 return;
0071 }
0072 TensorShapeProto tsp;
0073 for (int i = 0; i < input_indices->dim_size(); ++i) {
0074 if (input_indices->dim(i).has_dim_value()) {
0075 appendDimToTensorShapeProto(tsp, input_data, input_indices->dim(i).dim_value());
0076 } else {
0077 return;
0078 }
0079 }
0080 if (tsp.dim_size() > 0) {
0081 ctx.addOutputData(0, std::move(tsp));
0082 }
0083 }
0084
0085 }