#pragma once

#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "onnx/defs/data_type_utils.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {

using Dim = TensorShapeProto_Dimension;

struct ShapeInferenceOptions {
  // Check the type-equality of graph inputs/outputs against inferred types.
  bool check_type;
  // Controls how strictly errors found during node-level inference are
  // reported: 0 ignores node-level failures, higher values throw them.
  int error_mode;
  // Enables data propagation (computing statically known output values, not
  // just shapes, for the operators that support it).
  bool enable_data_propagation;
  ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
      : check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val) {}
};
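
// Illustrative usage (a sketch; the chosen flag values are arbitrary, not
// recommended defaults):
//
//   ShapeInferenceOptions options(/*check_type_val=*/true, /*strict_mode_val=*/1,
//                                 /*data_prop_val=*/true);
//   // `options` can then be passed to the shape-inference entry points that
//   // accept a ShapeInferenceOptions argument.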

// Maintains a symbol table for symbolic shape inference.
class SymbolTable {
 public:
  // Adds the symbols already present in a (sub)graph.
  virtual void addFromGraph(const GraphProto& g) = 0;
  // Creates a new symbol that does not clash with any existing one.
  std::string createNew() {
    return createNew("unk__");
  }
  virtual std::string createNew(const std::string& symbol_prefix) = 0;
  virtual ~SymbolTable() = default;
};

// Performs inference on a graph nested inside a node attribute (e.g. an If
// or Loop body).
class GraphInferencer {
 public:
  // Performs inference on the contained graph; inputData holds the statically
  // known input values, where available. Returns the inferred output types.
  virtual std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) = 0;
  virtual ~GraphInferencer() = default;
};

// Exception class used to report type- and shape-inference errors.
class InferenceError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  InferenceError(const std::string& message) : std::runtime_error(message) {}

  const char* what() const noexcept override {
    if (!expanded_message_.empty()) {
      return expanded_message_.c_str();
    }
    return std::runtime_error::what();
  }

  // Appends context (e.g. which node was being processed) to the message
  // reported by what().
  void AppendContext(const std::string& context) {
    expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
  }

 private:
  std::string expanded_message_;
};

#define fail_type_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));

#define fail_shape_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));

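// Both macros construct an InferenceError whose message is prefixed with the
// error kind and throw it via ONNX_THROW_EX. An illustrative use inside an
// inference function (the check and message are hypothetical):
//
//   if (rank != 2) {
//     fail_shape_inference("Expected a rank-2 input, got rank ", rank);
//   }
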
struct InferenceContext {
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual bool hasInput(size_t index) const {
    // Default implementation, kept for backward compatibility with contexts
    // that do not override this method: an input is treated as present when
    // its index is in range and its type is known.
    return ((index < getNumInputs()) && (getInputType(index) != nullptr));
  }
  virtual const TensorProto* getInputData(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  virtual TypeProto* getOutputType(size_t index) = 0;
  virtual GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) = 0;
  virtual ~InferenceContext() {}
  virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0;
  // Returns the shape of the given input as computed by partial data
  // propagation, if available.
  virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0;
};
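
// A minimal sketch of an operator inference function written against this
// interface (the operator and its semantics are hypothetical; the helpers
// used are defined later in this header):
//
//   void MyOpInference(InferenceContext& ctx) {
//     // Output 0 gets the element type of input 0.
//     propagateElemTypeFromInputToOutput(ctx, 0, 0);
//     // Copy the shape through when input 0 has one.
//     if (hasInputShape(ctx, 0)) {
//       propagateShapeFromInputToOutput(ctx, 0, 0);
//     }
//   }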

// Data propagation performs a partial evaluation of the model to compute
// statically known information about tensor values, improving the precision
// of shape inference. TensorShapeProto is reused to represent the statically
// known values, which limits propagation to integer values. For example,
// data propagation is intended to handle code fragments like:
//   shape = Shape(X)
//   batchsize = Slice(shape, [0], [1])
//   newshape = Concat(batchsize, [1024, 1024])
//   Z = Reshape(Y, newshape)
// If the shape of X is statically known, data propagation can determine the
// value of newshape, and hence the shape of Z.
struct DataPropagationContext {
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  virtual const TypeProto* getOutputType(size_t index) const = 0;
  virtual ~DataPropagationContext() {}
  virtual const TensorShapeProto* getInputData(size_t index) = 0;
  virtual void addOutputData(size_t index, TensorShapeProto&& tp) = 0;
};

using InferenceFunction = std::function<void(InferenceContext&)>;
using DataPropagationFunction = std::function<void(DataPropagationContext&)>;

// No-op inference function, used for operators without an inference
// implementation.
inline void dummyInferenceFunction(InferenceContext&) {}

// No-op data propagation function, used for operators without a data
// propagation implementation.
inline void dummyDataPropagationFunction(DataPropagationContext&) {}

template <typename T>
inline bool getRepeatedAttribute(InferenceContext& ctx, const std::string& attr_name, std::vector<T>& values) {
  const auto* attr = ctx.getAttribute(attr_name);
  if (attr) {
    values = RetrieveValues<T>(*attr);
    return true;
  } else {
    return false;
  }
}
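
// Illustrative usage (a sketch; the attribute name is hypothetical):
//
//   std::vector<int64_t> kernel_shape;
//   if (getRepeatedAttribute(ctx, "kernel_shape", kernel_shape)) {
//     // kernel_shape now holds the attribute's values.
//   }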

inline int64_t getAttribute(InferenceContext& ctx, const std::string& attributeName, int64_t defaultValue) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr != attr_proto) && attr_proto->has_i())
    return attr_proto->i();
  return defaultValue;
}

inline int64_t getAttribute(DataPropagationContext& ctx, const std::string& attributeName, int64_t defaultValue) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr != attr_proto) && attr_proto->has_i())
    return attr_proto->i();
  return defaultValue;
}

inline std::string
getAttribute(InferenceContext& ctx, const std::string& attributeName, const std::string& defaultValue) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr != attr_proto) && attr_proto->has_s())
    return attr_proto->s();
  return defaultValue;
}

// Multiplies two dimensions: if both values are known the product is
// returned; if either is the constant 1 the other dimension is returned
// unchanged; otherwise the result is left unknown (unset).
inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, TensorShapeProto::Dimension dim2) {
  TensorShapeProto::Dimension result;
  if (dim1.has_dim_value() && dim2.has_dim_value()) {
    result.set_dim_value(dim1.dim_value() * dim2.dim_value());
  } else if (dim1.has_dim_value() && (dim1.dim_value() == 1)) {
    return dim2;
  } else if (dim2.has_dim_value() && (dim2.dim_value() == 1)) {
    return dim1;
  }
  return result;
}
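
// Worked example (a sketch) of the rules above:
//
//   Dim a, b;
//   a.set_dim_value(1);
//   b.set_dim_param("N");
//   Dim c = a * b; // c has dim_param "N": multiplying by 1 preserves b
//   a.set_dim_value(4);
//   Dim d = a * b; // d is unset: 4 * "N" cannot be computed statically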

template <typename Container>
std::string stringify(const Container& elements);

std::pair<int, int> getAttributeElementTypeAndLength(
    const InferenceContext& ctx,
    const std::initializer_list<std::string>& attribute_names);

// Multiplies a dimension by a constant: the result is known only when dim1
// has a known value; multiplying an unknown dim by 1 returns it unchanged.
inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, int64_t dim2) {
  TensorShapeProto::Dimension result;
  if (dim1.has_dim_value()) {
    result.set_dim_value(dim1.dim_value() * dim2);
  } else if (dim2 == 1) {
    return dim1;
  }
  return result;
}

// Divides a dimension by a constant (integer division): the result is known
// only when dim1 has a known value; dividing an unknown dim by 1 returns it
// unchanged.
inline TensorShapeProto::Dimension operator/(TensorShapeProto::Dimension dim1, int64_t dim2) {
  TensorShapeProto::Dimension result;
  if (dim1.has_dim_value()) {
    result.set_dim_value(dim1.dim_value() / dim2);
  } else if (dim2 == 1) {
    return dim1;
  }
  return result;
}


// Returns the product of the dimensions shape.dim(from) ..
// shape.dim(upto_exclusive - 1), using the Dimension multiplication defined
// above, so symbolic factors may make the result symbolic or unknown.
inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
  TensorShapeProto::Dimension dim;
  dim.set_dim_value(1);
  for (int i = from; i < upto_exclusive; ++i) {
    dim = dim * shape.dim(i);
  }
  return dim;
}
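
// Illustrative usage (a sketch), e.g. for a Flatten-like op that keeps dim 0
// and collapses the rest (getInputShape/updateOutputShape are defined below):
//
//   const TensorShapeProto& in = getInputShape(ctx, 0);
//   TensorShapeProto out;
//   *out.add_dim() = in.dim(0);
//   *out.add_dim() = multiplyDims(in, 1, in.dim_size());
//   updateOutputShape(ctx, 0, out);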

inline int32_t getTensorElementType(const TypeProto& type) {
  int32_t result = TensorProto::UNDEFINED;
  const auto value_case = type.value_case();
  if (value_case == TypeProto::kTensorType) {
    result = type.tensor_type().elem_type();
  } else if (value_case == TypeProto::kSparseTensorType) {
    result = type.sparse_tensor_type().elem_type();
  }
  return result;
}

inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_case, TypeProto& type) {
  if (value_case == TypeProto::kTensorType) {
    type.mutable_tensor_type()->set_elem_type(elem_type);
  } else if (value_case == TypeProto::kSparseTensorType) {
    type.mutable_sparse_tensor_type()->set_elem_type(elem_type);
  }
}

void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type);

void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);

void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);

inline void propagateElemTypeFromDtypeToOutput(
    InferenceContext& ctx,
    const int data_type,
    size_t outputIndex,
    TypeProto::ValueCase expected_value_case) {
  const auto attribute_tensor_datatype = data_type;
  auto output_type = ctx.getOutputType(outputIndex);
  const auto output_value_case = output_type->value_case();
  if (output_value_case == TypeProto::VALUE_NOT_SET || output_value_case == expected_value_case) {
    setTensorElementType(attribute_tensor_datatype, expected_value_case, *output_type);
  } else {
    // A different value_case was already set on the output; flag the conflict.
    fail_type_inference(
        "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case);
  }
}

inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const int data_type, size_t outputIndex) {
  propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, TypeProto::kTensorType);
}

inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const AttributeProto* attr, size_t outputIndex) {
  int32_t data_type = TensorProto::UNDEFINED;
  TypeProto::ValueCase expected_value_case = TypeProto::VALUE_NOT_SET;
  const auto attr_type = attr->type();
  if (attr_type == AttributeProto::TENSOR) {
    if (attr->t().dims().size() != 1) {
      fail_type_inference("Attribute expected to have a one-dim tensor");
    }
    data_type = attr->t().data_type();
    expected_value_case = TypeProto::kTensorType;
  } else if (attr_type == AttributeProto::SPARSE_TENSOR) {
    if (attr->sparse_tensor().dims().size() != 1) {
      fail_type_inference("Attribute expected to have a one-dim sparse tensor");
    }
    data_type = attr->sparse_tensor().values().data_type();
    expected_value_case = TypeProto::kSparseTensorType;
  } else {
    fail_type_inference("Attribute expected to have tensor or sparse tensor type");
  }

  propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case);
}

inline bool hasShape(const TypeProto& type) {
  if (type.has_tensor_type()) {
    return type.tensor_type().has_shape();
  } else if (type.has_sparse_tensor_type()) {
    return type.sparse_tensor_type().has_shape();
  } else if (type.has_sequence_type() && type.sequence_type().has_elem_type()) {
    return hasShape(type.sequence_type().elem_type());
  } else if (type.has_optional_type() && type.optional_type().has_elem_type()) {
    return hasShape(type.optional_type().elem_type());
  }
  return false;
}

template <typename Context>
inline bool hasInputShape(const Context& ctx, size_t n) {
  return ctx.getNumInputs() > n && ctx.getInputType(n) && hasShape(*ctx.getInputType(n));
}

template <typename Context>
inline bool hasNInputShapes(const Context& ctx, size_t n) {
  for (size_t i = 0; i < n; i++) {
    if (!hasInputShape(ctx, i)) {
      return false;
    }
  }
  return true;
}

inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t n) {
  const auto* input_type = ctx.getInputType(n);
  const auto value_case = input_type->value_case();
  if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
    fail_type_inference("Input ", n, " expected to have tensor or sparse tensor type");
  }
  if (value_case == TypeProto::kTensorType) {
    return input_type->tensor_type().shape();
  } else {
    return input_type->sparse_tensor_type().shape();
  }
}

inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size_t n) {
  const auto* input_type = ctx.getInputType(n);

  if (input_type == nullptr) {
    return nullptr;
  }

  const auto value_case = input_type->value_case();
  if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
    fail_type_inference("Input ", n, " expected to have tensor or sparse tensor type");
  }
  if (value_case == TypeProto::kTensorType) {
    return &input_type->tensor_type().shape();
  } else {
    return &input_type->sparse_tensor_type().shape();
  }
}

// Appends the fromDimIndex-th dimension of the input's shape to the output's
// shape, for tensor and sparse-tensor types.
inline void appendSingleDimCopiedFromInputTypeToOutputType(
    InferenceContext& ctx,
    size_t inputIndex,
    size_t outputIndex,
    size_t fromDimIndex) {
  auto output_type = ctx.getOutputType(outputIndex);
  const auto output_value_case = output_type->value_case();
  auto input_type = ctx.getInputType(inputIndex);
  const auto input_value_case = input_type->value_case();
  if (output_value_case != input_value_case) {
    fail_type_inference(
        "Input: ",
        inputIndex,
        " type: ",
        input_value_case,
        " does not match type of output: ",
        outputIndex,
        " type: ",
        output_value_case);
  }
  if (TypeProto::kTensorType == input_value_case) {
    auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else if (TypeProto::kSparseTensorType == input_value_case) {
    auto* dim = output_type->mutable_sparse_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->sparse_tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else {
    fail_type_inference(
        "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type");
  }
}

inline void propagateShape(const TypeProto* from_type, TypeProto* to_type) {
  const auto from_type_case = from_type->value_case();
  const auto to_type_case = to_type->value_case();
  if (from_type_case != to_type_case) {
    fail_shape_inference(
        "Mismatch between inferred and declared type. Inferred=", from_type_case, " Declared=", to_type_case);
  }

  if (TypeProto::kTensorType == from_type_case || TypeProto::kSparseTensorType == from_type_case) {
    // Propagate a shape only when the source has one; leaving the target's
    // shape unset keeps it "unknown".
    if (hasShape(*from_type)) {
      if (TypeProto::kTensorType == from_type_case) {
        *to_type->mutable_tensor_type()->mutable_shape() = from_type->tensor_type().shape();
      } else {
        *to_type->mutable_sparse_tensor_type()->mutable_shape() = from_type->sparse_tensor_type().shape();
      }
    }
  } else if (TypeProto::kSequenceType == from_type_case) {
    propagateShape(&from_type->sequence_type().elem_type(), to_type->mutable_sequence_type()->mutable_elem_type());
  } else if (TypeProto::kOptionalType == from_type_case) {
    propagateShape(&from_type->optional_type().elem_type(), to_type->mutable_optional_type()->mutable_elem_type());
  } else if (TypeProto::kMapType == from_type_case) {
    propagateShape(&from_type->map_type().value_type(), to_type->mutable_map_type()->mutable_value_type());
  } else {
    fail_shape_inference("Unsupported Source/Target type=", from_type_case);
  }
}

inline void propagateShapeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto output_type = ctx.getOutputType(outputIndex);
  auto input_type = ctx.getInputType(inputIndex);

  propagateShape(input_type, output_type);
}

inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) {
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  if (!hasNInputShapes(ctx, 1)) {
    return;
  }
  propagateShapeFromInputToOutput(ctx, 0, 0);
}
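
// propagateShapeAndTypeFromFirstInput matches the InferenceFunction signature,
// so element-wise unary operators can plug it in directly. A registration
// sketch (the operator name is hypothetical):
//
//   ONNX_OPERATOR_SET_SCHEMA(
//       MyUnaryOp,
//       1,
//       OpSchema()
//           .Input(0, "X", "Input tensor", "T")
//           .Output(0, "Y", "Output tensor", "T")
//           .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));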

inline void
updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType, TypeProto::ValueCase expected_type) {
  auto output_type = ctx.getOutputType(outputIndex);
  if (output_type == nullptr) {
    fail_type_inference("Output ", outputIndex, " is null");
  }
  if (output_type->value_case() == expected_type || output_type->value_case() == TypeProto::VALUE_NOT_SET) {
    setTensorElementType(elemType, expected_type, *output_type);
  } else {
    // The output was already set to a different, incompatible value_case.
    fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type: ", expected_type);
  }
}

inline void updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType) {
  updateOutputElemType(ctx, outputIndex, elemType, TypeProto::kTensorType);
}

// Infers the type of an output from the value of a specified attribute, which
// is expected to hold a valid TensorProto_DataType value.
inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TypeProto::ValueCase expected_type,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if (nullptr == attr_proto) {
    if (default_value != TensorProto::UNDEFINED) {
      updateOutputElemType(ctx, outputIndex, default_value, expected_type);
      return;
    } else {
      fail_type_inference("Value of attribute ", attributeName, " not specified");
    }
  }
  if (!attr_proto->has_i()) {
    fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type.");
  }
  auto attr_value = attr_proto->i();
  auto elem_type = static_cast<TensorProto_DataType>(attr_value);
  if (!TensorProto_DataType_IsValid(elem_type)) {
    fail_type_inference("Attribute ", attributeName, " does not specify a valid type.");
  }
  updateOutputElemType(ctx, outputIndex, elem_type, expected_type);
}

inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  propagateElemTypeFromAttributeToOutput(ctx, attributeName, outputIndex, TypeProto::kTensorType, default_value);
}

inline TensorShapeProto* getTensorMutableShape(TypeProto::ValueCase value_case, TypeProto& type) {
  if (value_case == TypeProto::kTensorType) {
    return type.mutable_tensor_type()->mutable_shape();
  } else if (value_case == TypeProto::kSparseTensorType) {
    // Note: use the sparse-tensor branch here; returning the dense tensor's
    // shape for a sparse type would modify the wrong oneof member.
    return type.mutable_sparse_tensor_type()->mutable_shape();
  }
  return nullptr;
}

inline TensorShapeProto*
getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto output_type = ctx.getOutputType(n);
  if (output_type == nullptr) {
    fail_type_inference("Output ", n, " expected to have tensor or sparse tensor type");
  }
  const auto output_value_case = output_type->value_case();
  if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
    return getTensorMutableShape(output_value_case, *output_type);
  } else if (output_value_case == TypeProto::VALUE_NOT_SET) {
    return getTensorMutableShape(default_type, *output_type);
  } else {
    fail_type_inference("Output ", n, " expected to have tensor or sparse tensor type");
  }
}

inline void appendDim(TensorShapeProto* shape, int64_t dim_value) {
  shape->add_dim()->set_dim_value(dim_value);
}

inline void updateOutputShape(
    InferenceContext& ctx,
    size_t outputIndex,
    const TensorShapeProto& shape,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
  *output_shape = shape;
}

inline void updateOutputShape(
    InferenceContext& ctx,
    size_t outputIndex,
    const TensorProto& tensorProto,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
  for (auto d : tensorProto.dims()) {
    auto* dim = output_shape->add_dim();
    dim->set_dim_value(d);
  }
}

inline void updateOutputShape(
    InferenceContext& ctx,
    size_t outputIndex,
    std::initializer_list<TensorShapeProto::Dimension> dims,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
  for (auto& d : dims) {
    auto* dim = output_shape->add_dim();
    *dim = d;
  }
}

// Reads the value of a shape-valued input (for example, the "shape" input of
// Reshape) as a TensorShapeProto when it can be determined statically;
// `found` reports whether such a value was available.
TensorShapeProto getShapeInput(const InferenceContext& ctx, size_t input_index, bool& found);

// Infers the output shape from an INTS attribute that lists dimension sizes.
inline void propagateShapeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
      (attr_proto->type() != AttributeProto_AttributeType_INTS)) {
    fail_shape_inference("Attribute ", attributeName, " should specify a shape");
  }
  auto& int_list = attr_proto->ints();
  TensorShapeProto shape;
  for (auto dim_size : int_list) {
    if (dim_size < 0) {
      fail_shape_inference("Negative values are not allowed in a shape specification");
    }
    shape.add_dim()->set_dim_value(dim_size);
  }

  updateOutputShape(ctx, outputIndex, shape, default_type);
}

inline void multidirectionalBroadcastShapeInference(
    const std::vector<const TensorShapeProto*>& shapes,
    TensorShapeProto& resultShape) {
  int result_shape_size = 0;
  // The result has the maximum rank over all input shapes.
  for (size_t i = 0; i < shapes.size(); ++i) {
    if (shapes[i]->dim_size() > result_shape_size) {
      result_shape_size = shapes[i]->dim_size();
    }
  }

  for (int i = 0; i < result_shape_size; ++i) {
    int64_t dim_value = 1;
    TensorShapeProto_Dimension symbolic_dim;
    int num_symbolic_dims = 0;
    for (size_t j = 0; j < shapes.size(); ++j) {
      if (i < result_shape_size - shapes[j]->dim_size()) {
        // Shape j is implicitly padded with 1 at dimension i.
        continue;
      }

      auto dim_i_j = shapes[j]->dim(i - result_shape_size + shapes[j]->dim_size());
      if (dim_i_j.has_dim_value()) {
        if (dim_i_j.dim_value() != 1) {
          if (dim_value != dim_i_j.dim_value() && dim_value != 1) {
            fail_shape_inference("Incompatible dimensions");
          } else {
            dim_value = dim_i_j.dim_value();
          }
        }
      } else {
        if (num_symbolic_dims == 0) {
          symbolic_dim = dim_i_j;
          ++num_symbolic_dims;
        } else if (dim_i_j.dim_param() != symbolic_dim.dim_param()) {
          ++num_symbolic_dims;
        }
      }
    }

    if (dim_value != 1 || num_symbolic_dims == 0) {
      resultShape.add_dim()->set_dim_value(dim_value);
    } else if (num_symbolic_dims == 1) {
      *resultShape.add_dim() = symbolic_dim;
    } else {
      // Multiple distinct symbolic dims: the result dim is left unknown.
      resultShape.add_dim();
    }
  }
}

inline void bidirectionalBroadcastShapeInference(
    const TensorShapeProto& shapeL,
    const TensorShapeProto& shapeR,
    TensorShapeProto& resultShape) {
  std::vector<const TensorShapeProto*> shapes;
  shapes.push_back(&shapeL);
  shapes.push_back(&shapeR);
  multidirectionalBroadcastShapeInference(shapes, resultShape);
}
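
// Worked example (a sketch): broadcasting [3, 1] against [5] under the
// NumPy-style rules implemented above yields [3, 5].
//
//   TensorShapeProto l, r, out;
//   l.add_dim()->set_dim_value(3);
//   l.add_dim()->set_dim_value(1);
//   r.add_dim()->set_dim_value(5);
//   bidirectionalBroadcastShapeInference(l, r, out); // out is [3, 5]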

// Merges the dimension information of source_dim into target_dim:
//  - If both have values, the values must match (otherwise inference fails).
//  - A known value is preferred over a symbolic param: a source value fills an
//    unset target, while an existing target value or param is preserved.
//  - A source param is copied only when the target has no information at all.
inline void mergeInDimensionInfo(
    const TensorShapeProto_Dimension& source_dim,
    TensorShapeProto_Dimension& target_dim,
    int dim_index) {
  // If the source has a value, merge it into the target; otherwise prefer
  // whatever information the target already carries.
  if (source_dim.has_dim_value()) {
    auto source_value = source_dim.dim_value();
    if (target_dim.has_dim_value()) {
      auto target_value = target_dim.dim_value();
      if (target_value != source_value) {
        fail_shape_inference(
            "Can't merge shape info. "
            "Both inferred and declared dimension have values but they differ. Inferred=",
            source_value,
            " Declared=",
            target_value,
            " Dimension=",
            dim_index);
      }
    } else {
      target_dim.set_dim_value(source_value);
    }
  } else if (target_dim.has_dim_value()) {
    // The target has a value and the source does not: keep the target value.
  } else if (target_dim.has_dim_param()) {
    // The target has a param and the source has no value: keep the target param.
  } else if (source_dim.has_dim_param()) {
    target_dim.set_dim_param(source_dim.dim_param());
  }
}

// Merges source shape information into the target type, applying the
// per-dimension rules of mergeInDimensionInfo; fails if the ranks or any
// known dimension values are incompatible.
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);

// Merges the shape information of the source type into the target type,
// under the same rules as above.
void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target);

void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target);

// Returns a copy of the tensor type with dimension `removed_dim` removed from
// its shape.
inline TypeProto RemoveIthDimensionFromShape(const TypeProto& proto, int removed_dim) {
  TypeProto t(proto);
  auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
  mutable_shape->clear_dim();

  const auto& dims = proto.tensor_type().shape().dim();

  for (int j = 0, end = dims.size(); j < end; ++j) {
    if (j != removed_dim)
      (*mutable_shape->add_dim()) = dims.Get(j);
  }

  return t;
}

// Returns a copy of the tensor type with the first `num_dimensions` leading
// dimensions removed from its shape.
inline TypeProto RemoveDimensionsFromShape(const TypeProto& proto, int num_dimensions) {
  TypeProto t(proto);
  auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
  mutable_shape->clear_dim();

  const auto& dims = proto.tensor_type().shape().dim();

  // Copy only the dimensions after the first num_dimensions.
  for (int j = num_dimensions, end = dims.size(); j < end; ++j) {
    (*mutable_shape->add_dim()) = dims.Get(j);
  }

  return t;
}
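
// Worked example (a sketch): starting from a tensor type t with shape
// [N, C, H, W], RemoveIthDimensionFromShape(t, 1) yields shape [N, H, W] and
// RemoveDimensionsFromShape(t, 2) yields shape [H, W]. Both helpers assume t
// is a dense tensor type with a shape.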

// A searchable way to perform narrowing casts (in the spirit of
// gsl::narrow_cast).
template <class T, class U>
static constexpr T narrow_cast(U&& u) noexcept {
  return static_cast<T>(std::forward<U>(u));
}

inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expected_rank) {
  // The check applies only when the input's shape (and hence rank) is known.
  if (hasInputShape(ctx, input_index)) {
    auto rank = getInputShape(ctx, input_index).dim_size();
    if (rank != expected_rank) {
      fail_shape_inference("Input ", input_index, " expected to have rank ", expected_rank, " but has rank ", rank);
    }
  }
}

// Unification (between dimensions and/or shapes) is at the heart of shape
// inference. The current algorithm can check the input shapes/dimensions of
// a node and update the output shapes/dimensions; it cannot update input
// shapes. Hence the variants below, supporting both "const" and "mutable"
// dimensions/shapes in unification.

inline void checkDimEquality(int64_t value1, int64_t value2) {
  if (value1 != value2) {
    fail_shape_inference("Dimension mismatch in unification between ", value1, " and ", value2);
  }
}

inline void unifyDim(const Dim& dim1, const Dim& dim2) {
  if (dim1.has_dim_value() && dim2.has_dim_value())
    checkDimEquality(dim1.dim_value(), dim2.dim_value());
}

// Unifies source_dim into target_dim, following the same precedence rules as
// mergeInDimensionInfo.
inline void unifyDim(const Dim& source_dim, Dim& target_dim) {
  if (source_dim.has_dim_value()) {
    auto source_value = source_dim.dim_value();
    if (target_dim.has_dim_value()) {
      auto target_value = target_dim.dim_value();
      checkDimEquality(source_value, target_value);
    } else {
      target_dim.set_dim_value(source_value);
    }
  } else if (target_dim.has_dim_value()) {
    // The target has a value and the source does not: keep the target value.
  } else if (target_dim.has_dim_param()) {
    // The target has a param and the source has no value: keep the target param.
  } else if (source_dim.has_dim_param()) {
    target_dim.set_dim_param(source_dim.dim_param());
  }
}

inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_index, Dim& dim) {
  // Unification happens only when the specified input has a known shape.
  if (hasInputShape(ctx, input_index)) {
    auto& input_shape = getInputShape(ctx, input_index);
    // A too-small rank indicates an inconsistency that shape inference of
    // earlier ops should have caught; report it here as well.
    if (input_shape.dim_size() <= dim_index) {
      fail_shape_inference(
          "Input ", input_index, " expected to have rank >", dim_index, " but has rank ", input_shape.dim_size());
    }
    const Dim& input_dim = input_shape.dim(dim_index);
    // Unify the input dimension into dim.
    unifyDim(input_dim, dim);
  }
}

// Unifies a dimension with a constant value: a known value must match, an
// unknown one is set to the value.
inline void unifyDim(Dim& dim, int64_t value) {
  if (dim.has_dim_value()) {
    checkDimEquality(dim.dim_value(), value);
  } else
    dim.set_dim_value(value);
}

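// Illustrative usage (a sketch, for a hypothetical MatMul-like op on 2-D
// inputs): unify the shared inner dimension of the two inputs; unification
// fails shape inference if the known values disagree.
//
//   Dim K;
//   unifyInputDim(ctx, 0, 1, K); // K unified with dim 1 of input 0
//   unifyInputDim(ctx, 1, 0, K); // ... and with dim 0 of input 1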

// Computes a "union" of shape information: the target is relaxed so that it
// is consistent with both its current information and the source shape
// (dimensions that disagree lose their value/param). Useful for constructs
// such as If, where an output may come from either branch.
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);

// Computes the union of type information, applying UnionShapeInfo to the
// contained shapes; fails if the types are fundamentally incompatible (e.g.,
// different element types).
void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type);

// Translates negative axis values into the positive range [0, rank) by
// adding the rank.
template <typename Axes>
void adjustNegativeAxes(Axes& axes, int rank) {
  std::transform(
      axes.begin(), axes.end(), axes.begin(), [&](int64_t axis) -> int64_t { return axis < 0 ? axis + rank : axis; });
}

// Checks that each axis lies within the valid range [-rank, rank).
template <typename Axes>
void checkAxesRange(Axes& axes, int rank) {
  for (auto axis : axes) {
    if (axis < -rank || axis > (rank - 1))
      fail_shape_inference("Unexpected axis value: ", axis, ". Expected range [", -rank, ", ", rank, ")");
  }
}

// Checks that no axis is referenced more than once; assumes the axes have
// already been validated to lie within [-rank, rank).
template <typename Axes>
void checkDuplicateAxes(Axes& axes, int rank) {
  std::vector<bool> tmp(rank, false);
  for (auto axis : axes) {
    int actual_axis = static_cast<int>(axis < 0 ? axis + rank : axis);
    if (tmp[actual_axis])
      fail_shape_inference("Axis ", axis, " is referred to more than once.");
    tmp[actual_axis] = true;
  }
}
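
// Illustrative usage (a sketch): validating and normalizing an "axes"
// attribute before using it, with `rank` taken from the input shape:
//
//   std::vector<int64_t> axes;
//   getRepeatedAttribute(ctx, "axes", axes);
//   checkAxesRange(axes, rank);      // each axis must lie in [-rank, rank)
//   checkDuplicateAxes(axes, rank);  // no axis may appear twice
//   adjustNegativeAxes(axes, rank);  // map negatives into [0, rank)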

} // namespace ONNX_NAMESPACE