/*
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "onnx/defs/data_type_utils.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {

using Dim = TensorShapeProto_Dimension;

struct ShapeInferenceOptions {
  // Check type equality between declared and inferred types.
  bool check_type;
  // Error-handling mode: non-zero means node-level shape-inference errors are
  // thrown; zero suppresses them (other errors, e.g. failures while merging
  // inferred and declared shapes, still throw).
  int error_mode;
  // Enable data propagation, which lets selected operators compute output
  // values (not just shapes) symbolically.
  bool enable_data_propagation;
  ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
      : check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val) {}
};
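
// Example (illustrative): request strict type checking and data propagation.
//   ShapeInferenceOptions options(/*check_type_val=*/true, /*strict_mode_val=*/1, /*data_prop_val=*/true);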

// A SymbolTable tracks the symbolic dimension names already in use, so that
// fresh, non-conflicting symbols can be generated during inference.
class SymbolTable {
 public:
  // Register the symbolic dims that already appear in the given graph.
  virtual void addFromGraph(const GraphProto& g) = 0;
  // Create a new unique symbol with the default prefix "unk__".
  std::string createNew() {
    return createNew("unk__");
  }
  // Create a new unique symbol with the given prefix.
  virtual std::string createNew(const std::string& symbol_prefix) = 0;
  virtual ~SymbolTable() = default;
};

class GraphInferencer {
 public:
  // Perform inferencing on the graph contained in GraphInferencer.
  // Returns the graph output types post-inferencing.
  virtual std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) = 0;
  virtual ~GraphInferencer() = default;
};

// Exception class used to communicate type/shape inference errors;
// AppendContext() lets callers attach node or graph context to the message.
class InferenceError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  InferenceError(const std::string& message) : std::runtime_error(message) {}

  const char* what() const noexcept override {
    if (!expanded_message_.empty()) {
      return expanded_message_.c_str();
    }
    return std::runtime_error::what();
  }

  void AppendContext(const std::string& context) {
    expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
  }

 private:
  std::string expanded_message_;
};
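
// Example (illustrative): enriching an inference error with node context.
//   try {
//     fail_shape_inference("rank mismatch");
//   } catch (InferenceError& err) {
//     err.AppendContext("node Conv_0");  // what() now ends with "==> Context: node Conv_0"
//     throw;
//   }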

#define fail_type_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));

#define fail_shape_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));

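// Example (illustrative): both macros take MakeString-style argument lists.
//   if (input_rank != 2)
//     fail_shape_inference("expected rank 2, got ", input_rank);
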
struct InferenceContext {
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual bool hasInput(size_t index) const {
    // The default implementation is kept for backward compatibility: an input
    // counts as present when its index is in range and its type is known.
    // Implementations that can distinguish a missing optional input from an
    // input with unknown type may override this.
    return ((index < getNumInputs()) && (getInputType(index) != nullptr));
  }
  virtual const TensorProto* getInputData(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  virtual TypeProto* getOutputType(size_t index) = 0;
  virtual GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) = 0;
  virtual ~InferenceContext() {}
  virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0;
  // Return the symbolically known value of a shape-valued input, if available;
  // nullptr otherwise.
  virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0;
  // Display name of the node being inferred, used to make error messages more
  // informative; empty by default.
  virtual std::string getDisplayName() const {
    return "";
  }
};

// DataPropagationContext is the data-propagation analogue of InferenceContext:
// it exposes statically known input values, encoded as TensorShapeProto, and
// lets an operator record computed output values via addOutputData. This is
// what allows, for example, the output of a Shape node to flow into a
// downstream Reshape during inference.
struct DataPropagationContext {
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  virtual const TypeProto* getOutputType(size_t index) const = 0;
  virtual ~DataPropagationContext() {}
  virtual const TensorShapeProto* getInputData(size_t index) = 0;
  virtual void addOutputData(size_t index, TensorShapeProto&& tp) = 0;
};

using InferenceFunction = std::function<void(InferenceContext&)>;
using DataPropagationFunction = std::function<void(DataPropagationContext&)>;

// A no-op inference function, usable as a placeholder.
inline void dummyInferenceFunction(InferenceContext&) {}

// A no-op data-propagation function, usable as a placeholder.
inline void dummyDataPropagationFunction(DataPropagationContext&) {}

// Read a repeated attribute (ints, floats, strings, ...) into `values`.
// Returns true if the attribute is present, false otherwise.
template <typename T>
inline bool getRepeatedAttribute(InferenceContext& ctx, std::string attr_name, std::vector<T>& values) {
  const auto* attr = ctx.getAttribute(attr_name);
  if (attr) {
    values = RetrieveValues<T>(*attr);
    return true;
  } else {
    return false;
  }
}
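
// Example (illustrative): reading a Conv-style "strides" attribute.
//   std::vector<int64_t> strides;
//   if (getRepeatedAttribute(ctx, "strides", strides)) {
//     // strides now holds the attribute's ints
//   }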

inline int64_t getAttribute(InferenceContext& ctx, const std::string& attributeName, int64_t defaultValue) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr != attr_proto) && attr_proto->has_i())
    return attr_proto->i();
  return defaultValue;
}

inline int64_t getAttribute(DataPropagationContext& ctx, const std::string& attributeName, int64_t defaultValue) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr != attr_proto) && attr_proto->has_i())
    return attr_proto->i();
  return defaultValue;
}

inline std::string
getAttribute(InferenceContext& ctx, const std::string& attributeName, const std::string& defaultValue) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr != attr_proto) && attr_proto->has_s())
    return attr_proto->s();
  return defaultValue;
}

// Product of two dimensions. If both are known, the result is their product;
// if either is the constant 1, the other is returned; otherwise the result is
// left unknown.
inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, TensorShapeProto::Dimension dim2) {
  TensorShapeProto::Dimension result;
  if (dim1.has_dim_value() && dim2.has_dim_value()) {
    result.set_dim_value(dim1.dim_value() * dim2.dim_value());
  } else if (dim1.has_dim_value() && (dim1.dim_value() == 1)) {
    return dim2;
  } else if (dim2.has_dim_value() && (dim2.dim_value() == 1)) {
    return dim1;
  }
  return result;
}
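
// Example (illustrative): dimension arithmetic.
//   TensorShapeProto::Dimension a, b;
//   a.set_dim_value(3);
//   b.set_dim_param("N");
//   auto c = a * b;  // c is unknown: b has no value and a != 1
//   a.set_dim_value(1);
//   auto d = a * b;  // d is the symbolic dim "N"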

template <typename Container>
std::string stringify(const Container& elements);

std::pair<int, int> getAttributeProtoElemTypeAndLength(const AttributeProto* attr_proto);

std::pair<int, int> getAttributeElementTypeAndLength(
    const InferenceContext& ctx,
    const std::initializer_list<std::string>& attribute_names);

inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, int64_t dim2) {
  TensorShapeProto::Dimension result;
  if (dim1.has_dim_value()) {
    result.set_dim_value(dim1.dim_value() * dim2);
  } else if (dim2 == 1) {
    return dim1;
  }
  return result;
}

inline TensorShapeProto::Dimension operator/(TensorShapeProto::Dimension dim1, int64_t dim2) {
  TensorShapeProto::Dimension result;
  if (dim1.has_dim_value()) {
    result.set_dim_value(dim1.dim_value() / dim2);
  } else if (dim2 == 1) {
    return dim1;
  }
  return result;
}

// Multiply shape.dim(from) ... shape.dim(upto_exclusive - 1) together, using
// the Dimension product semantics defined above (so any unknown, non-1 factor
// makes the result unknown).
inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
  TensorShapeProto::Dimension dim;
  dim.set_dim_value(1);
  for (int i = from; i < upto_exclusive; ++i) {
    dim = dim * shape.dim(i);
  }
  return dim;
}
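
// Example (illustrative): flattening all but the first axis of shape [2, 3, 4].
//   auto flat = multiplyDims(shape, 1, shape.dim_size());  // flat.dim_value() == 12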

inline int32_t getTensorElementType(const TypeProto& type) {
  int32_t result = TensorProto::UNDEFINED;
  const auto value_case = type.value_case();
  if (value_case == TypeProto::kTensorType) {
    result = type.tensor_type().elem_type();
  } else if (value_case == TypeProto::kSparseTensorType) {
    result = type.sparse_tensor_type().elem_type();
  }
  return result;
}

inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_case, TypeProto& type) {
  if (value_case == TypeProto::kTensorType) {
    type.mutable_tensor_type()->set_elem_type(elem_type);
  } else if (value_case == TypeProto::kSparseTensorType) {
    type.mutable_sparse_tensor_type()->set_elem_type(elem_type);
  }
}

void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type);

void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);

void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);

inline void propagateElemTypeFromDtypeToOutput(
    InferenceContext& ctx,
    const int data_type,
    size_t outputIndex,
    TypeProto::ValueCase expected_value_case) {
  const auto attribute_tensor_datatype = data_type;
  auto output_type = ctx.getOutputType(outputIndex);
  const auto output_value_case = output_type->value_case();
  if (output_value_case == TypeProto::VALUE_NOT_SET || output_value_case == expected_value_case) {
    setTensorElementType(attribute_tensor_datatype, expected_value_case, *output_type);
  } else {
    // An output whose value case is already set to something else cannot be overwritten.
    fail_type_inference(
        "Output ",
        outputIndex,
        " expected to have: ",
        expected_value_case,
        " or UNDEFINED. Got: ",
        output_value_case,
        " in ",
        ctx.getDisplayName(),
        ".");
  }
}

inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const int data_type, size_t outputIndex) {
  propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, TypeProto::kTensorType);
}

inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const AttributeProto* attr, size_t outputIndex) {
  int32_t data_type = TensorProto::UNDEFINED;
  TypeProto::ValueCase expected_value_case = TypeProto::VALUE_NOT_SET;
  const auto attr_type = attr->type();
  if (attr_type == AttributeProto::TENSOR) {
    if (attr->t().dims().size() != 1) {
      fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), ".");
    }
    data_type = attr->t().data_type();
    expected_value_case = TypeProto::kTensorType;
  } else if (attr_type == AttributeProto::SPARSE_TENSOR) {
    if (attr->sparse_tensor().dims().size() != 1) {
      fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), ".");
    }
    data_type = attr->sparse_tensor().values().data_type();
    expected_value_case = TypeProto::kSparseTensorType;
  } else {
    fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), ".");
  }

  propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case);
}

inline bool hasShape(const TypeProto& type) {
  if (type.has_tensor_type()) {
    return type.tensor_type().has_shape();
  } else if (type.has_sparse_tensor_type()) {
    return type.sparse_tensor_type().has_shape();
  } else if (type.has_sequence_type() && type.sequence_type().has_elem_type()) {
    return hasShape(type.sequence_type().elem_type());
  } else if (type.has_optional_type() && type.optional_type().has_elem_type()) {
    return hasShape(type.optional_type().elem_type());
  }
  return false;
}

template <typename Context>
inline bool hasInputShape(const Context& ctx, size_t n) {
  return ctx.getNumInputs() > static_cast<size_t>(n) && ctx.getInputType(n) && hasShape(*ctx.getInputType(n));
}

template <typename Context>
inline bool hasNInputShapes(const Context& ctx, size_t n) {
  for (size_t i = 0; i < n; i++) {
    if (!hasInputShape(ctx, i)) {
      return false;
    }
  }
  return true;
}
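
// Example (illustrative): a binary op bails out until both input shapes are known.
//   if (!hasNInputShapes(ctx, 2)) {
//     return;  // shapes not available yet; nothing to infer
//   }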

inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t n) {
  const auto* input_type = ctx.getInputType(n);
  const auto value_case = input_type->value_case();
  if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
    fail_type_inference("Input ", n, " expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
  }
  if (!hasShape(*input_type)) {
    fail_shape_inference("Input ", n, " must have a non-null shape in ", ctx.getDisplayName(), ".");
  }
  if (value_case == TypeProto::kTensorType) {
    return input_type->tensor_type().shape();
  } else {
    return input_type->sparse_tensor_type().shape();
  }
}

inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size_t n) {
  const auto* input_type = ctx.getInputType(n);

  if (input_type == nullptr) {
    return nullptr;
  }

  const auto value_case = input_type->value_case();
  if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
    fail_type_inference("Input ", n, " expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
  }
  if (value_case == TypeProto::kTensorType) {
    return &input_type->tensor_type().shape();
  } else {
    return &input_type->sparse_tensor_type().shape();
  }
}

// Copy a single dimension (fromDimIndex) of the input's shape to the end of
// the output's shape; input and output must have the same (sparse) tensor
// value case.
inline void appendSingleDimCopiedFromInputTypeToOutputType(
    InferenceContext& ctx,
    size_t inputIndex,
    size_t outputIndex,
    size_t fromDimIndex) {
  auto output_type = ctx.getOutputType(outputIndex);
  const auto output_value_case = output_type->value_case();
  auto input_type = ctx.getInputType(inputIndex);
  const auto input_value_case = input_type->value_case();
  if (output_value_case != input_value_case) {
    fail_type_inference(
        "Input: ",
        inputIndex,
        " type: ",
        input_value_case,
        " does not match type of output: ",
        outputIndex,
        " type: ",
        output_value_case,
        " in ",
        ctx.getDisplayName(),
        ".");
  }
  if (TypeProto::kTensorType == input_value_case) {
    auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else if (TypeProto::kSparseTensorType == input_value_case) {
    auto* dim = output_type->mutable_sparse_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->sparse_tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else {
    fail_type_inference(
        "Input ",
        inputIndex,
        " and Output ",
        outputIndex,
        " expected to have tensor or sparse tensor type in ",
        ctx.getDisplayName(),
        ".");
  }
}

inline void propagateShape(const TypeProto* from_type, TypeProto* to_type) {
  const auto from_type_case = from_type->value_case();
  const auto to_type_case = to_type->value_case();
  if (from_type_case != to_type_case) {
    fail_shape_inference(
        "Mismatch between inferred and declared type. Inferred=", from_type_case, " Declared=", to_type_case);
  }

  if (TypeProto::kTensorType == from_type_case || TypeProto::kSparseTensorType == from_type_case) {
    // If the source shape is known, copy it over; otherwise leave the target
    // shape untouched.
    if (hasShape(*from_type)) {
      if (TypeProto::kTensorType == from_type_case) {
        *to_type->mutable_tensor_type()->mutable_shape() = from_type->tensor_type().shape();
      } else {
        *to_type->mutable_sparse_tensor_type()->mutable_shape() = from_type->sparse_tensor_type().shape();
      }
    }
  } else if (TypeProto::kSequenceType == from_type_case) {
    propagateShape(&from_type->sequence_type().elem_type(), to_type->mutable_sequence_type()->mutable_elem_type());
  } else if (TypeProto::kOptionalType == from_type_case) {
    propagateShape(&from_type->optional_type().elem_type(), to_type->mutable_optional_type()->mutable_elem_type());
  } else if (TypeProto::kMapType == from_type_case) {
    propagateShape(&from_type->map_type().value_type(), to_type->mutable_map_type()->mutable_value_type());
  } else {
    fail_shape_inference("Unsupported Source/Target type=", from_type_case);
  }
}

inline void propagateShapeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
  auto output_type = ctx.getOutputType(outputIndex);
  auto input_type = ctx.getInputType(inputIndex);

  propagateShape(input_type, output_type);
}

inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) {
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  if (!hasNInputShapes(ctx, 1)) {
    return;
  }
  propagateShapeFromInputToOutput(ctx, 0, 0);
}
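
// Example (illustrative): a unary elementwise op can use this helper directly
// when its schema is registered (assuming the OpSchema registration API from
// onnx/defs/schema.h):
//   schema.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);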

inline void
updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType, TypeProto::ValueCase expected_type) {
  auto output_type = ctx.getOutputType(outputIndex);
  if (output_type == nullptr) {
    fail_type_inference("Output ", outputIndex, " is null");
  }
  if (output_type->value_case() == expected_type || output_type->value_case() == TypeProto::VALUE_NOT_SET) {
    setTensorElementType(elemType, expected_type, *output_type);
  } else {
    // An output whose value case is already set to something else cannot be changed.
    fail_type_inference(
        "Output ",
        outputIndex,
        " expected to have tensor or sparse tensor type: ",
        expected_type,
        " in ",
        ctx.getDisplayName(),
        ".");
  }
}

inline void updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType) {
  updateOutputElemType(ctx, outputIndex, elemType, TypeProto::kTensorType);
}

// Infer the output's element type from an integer attribute that encodes a
// TensorProto_DataType (as, e.g., Cast's "to" attribute does). If the
// attribute is absent, default_value is used; a default of UNDEFINED means
// the attribute is required.
inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TypeProto::ValueCase expected_type,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if (nullptr == attr_proto) {
    if (default_value != TensorProto::UNDEFINED) {
      updateOutputElemType(ctx, outputIndex, default_value, expected_type);
      return;
    } else {
      fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), ".");
    }
  }
  if (!attr_proto->has_i()) {
    fail_type_inference(
        "Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), ".");
  }
  auto attr_value = attr_proto->i();
  auto elem_type = static_cast<TensorProto_DataType>(attr_value);
  if (!TensorProto_DataType_IsValid(elem_type)) {
    fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), ".");
  }
  updateOutputElemType(ctx, outputIndex, elem_type, expected_type);
}

inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  propagateElemTypeFromAttributeToOutput(ctx, attributeName, outputIndex, TypeProto::kTensorType, default_value);
}
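
// Example (illustrative): a Cast-like op whose "to" attribute stores the
// target element type.
//   propagateElemTypeFromAttributeToOutput(ctx, "to", 0);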

inline TensorShapeProto* getTensorMutableShape(TypeProto::ValueCase value_case, TypeProto& type) {
  if (value_case == TypeProto::kTensorType) {
    return type.mutable_tensor_type()->mutable_shape();
  } else if (value_case == TypeProto::kSparseTensorType) {
    // Use the sparse-tensor branch here; calling mutable_tensor_type() for a
    // sparse tensor would silently switch the type's value case.
    return type.mutable_sparse_tensor_type()->mutable_shape();
  }
  return nullptr;
}

inline TensorShapeProto*
getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto output_type = ctx.getOutputType(n);
  if (output_type == nullptr) {
    fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), ".");
  }
  const auto output_value_case = output_type->value_case();
  if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
    return getTensorMutableShape(output_value_case, *output_type);
  } else if (output_value_case == TypeProto::VALUE_NOT_SET) {
    return getTensorMutableShape(default_type, *output_type);
  } else {
    fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), ".");
  }
}

inline void appendDim(TensorShapeProto* shape, int64_t dim_value) {
  shape->add_dim()->set_dim_value(dim_value);
}

inline void updateOutputShape(
    InferenceContext& ctx,
    size_t outputIndex,
    const TensorShapeProto& shape,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
  *output_shape = shape;
}

inline void updateOutputShape(
    InferenceContext& ctx,
    size_t outputIndex,
    const TensorProto& tensorProto,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
  for (auto d : tensorProto.dims()) {
    auto* dim = output_shape->add_dim();
    dim->set_dim_value(d);
  }
}

inline void updateOutputShape(
    InferenceContext& ctx,
    size_t outputIndex,
    std::initializer_list<TensorShapeProto::Dimension> dims,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
  for (auto& d : dims) {
    auto* dim = output_shape->add_dim();
    *dim = d;
  }
}

// Read the value of a "shape tensor" input (e.g. the second input of Reshape)
// if it is statically or symbolically known; `found` reports whether such a
// value was available.
TensorShapeProto getShapeInput(const InferenceContext& ctx, size_t input_index, bool& found);


// Infer the output shape from an INTS attribute that explicitly lists the
// output dimensions.
inline void propagateShapeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TypeProto::ValueCase default_type = TypeProto::kTensorType) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
      (attr_proto->type() != AttributeProto_AttributeType_INTS)) {
    fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), ".");
  }
  auto& int_list = attr_proto->ints();
  TensorShapeProto shape;
  for (auto dim_size : int_list) {
    if (dim_size < 0) {
      fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), ".");
    }
    shape.add_dim()->set_dim_value(dim_size);
  }

  updateOutputShape(ctx, outputIndex, shape, default_type);
}

inline void multidirectionalBroadcastShapeInference(
    const std::vector<const TensorShapeProto*>& shapes,
    TensorShapeProto& resultShape) {
  // The result rank is the maximum rank over all inputs.
  int result_shape_size = 0;
  for (size_t i = 0; i < shapes.size(); ++i) {
    if (shapes[i]->dim_size() > result_shape_size) {
      result_shape_size = shapes[i]->dim_size();
    }
  }

  for (int i = 0; i < result_shape_size; ++i) {
    int64_t dim_value = 1;
    TensorShapeProto_Dimension symbolic_dim;
    int num_symbolic_dims = 0;
    for (size_t j = 0; j < shapes.size(); ++j) {
      if (i < result_shape_size - shapes[j]->dim_size()) {
        // Shape j is rank-deficient at this position: it is implicitly
        // padded with 1s on the left, which never constrain the result.
        continue;
      }

      auto dim_i_j = shapes[j]->dim(i - result_shape_size + shapes[j]->dim_size());
      if (dim_i_j.has_dim_value()) {
        if (dim_i_j.dim_value() != 1) {
          if (dim_value != dim_i_j.dim_value() && dim_value != 1) {
            fail_shape_inference("Incompatible dimensions");
          } else {
            dim_value = dim_i_j.dim_value();
          }
        }
      } else {
        if (num_symbolic_dims == 0) {
          symbolic_dim = dim_i_j;
          ++num_symbolic_dims;
        } else if (dim_i_j.dim_param() != symbolic_dim.dim_param()) {
          ++num_symbolic_dims;
        }
      }
    }

    if (dim_value != 1 || num_symbolic_dims == 0) {
      resultShape.add_dim()->set_dim_value(dim_value);
    } else if (num_symbolic_dims == 1) {
      // Exactly one distinct symbolic dim participates; propagate it.
      *resultShape.add_dim() = symbolic_dim;
    } else {
      // Multiple distinct symbolic dims: the result dim is unknown.
      resultShape.add_dim();
    }
  }
}

inline void bidirectionalBroadcastShapeInference(
    const TensorShapeProto& shapeL,
    const TensorShapeProto& shapeR,
    TensorShapeProto& resultShape) {
  std::vector<const TensorShapeProto*> shapes;
  shapes.push_back(&shapeL);
  shapes.push_back(&shapeR);
  multidirectionalBroadcastShapeInference(shapes, resultShape);
}
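
// Example (illustrative): broadcasting [3, 1] with [1, 4].
//   TensorShapeProto a, b, out;
//   a.add_dim()->set_dim_value(3); a.add_dim()->set_dim_value(1);
//   b.add_dim()->set_dim_value(1); b.add_dim()->set_dim_value(4);
//   bidirectionalBroadcastShapeInference(a, b, out);  // out is [3, 4]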

// Merge the dimension info from source_dim into target_dim:
//   - if both have concrete values, they must match (else shape inference fails);
//   - if exactly one side has a concrete value, it becomes the target's value;
//   - otherwise the target keeps its dim_param, or inherits the source's.
// dim_index is only used to make the failure message more precise.
inline void mergeInDimensionInfo(
    const TensorShapeProto_Dimension& source_dim,
    TensorShapeProto_Dimension& target_dim,
    int dim_index) {
  // A concrete dim_value on either side wins; two concrete values must agree.
  if (source_dim.has_dim_value()) {
    auto source_value = source_dim.dim_value();
    if (target_dim.has_dim_value()) {
      auto target_value = target_dim.dim_value();
      if (target_value != source_value) {
        fail_shape_inference(
            "Can't merge shape info. "
            "Both inferred and declared dimension have values but they differ. Inferred=",
            source_value,
            " Declared=",
            target_value,
            " Dimension=",
            dim_index);
      }
    } else {
      target_dim.set_dim_value(source_value);
    }
  } else if (target_dim.has_dim_value()) {
    // Target already has a concrete value; the symbolic source adds nothing.
  } else if (target_dim.has_dim_param()) {
    // Target already has a symbolic name; keep it.
  } else if (source_dim.has_dim_param()) {
    target_dim.set_dim_param(source_dim.dim_param());
  }
}
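
// Example (illustrative): a concrete inferred value refines a symbolic declared dim.
//   TensorShapeProto_Dimension inferred, declared;
//   inferred.set_dim_value(8);
//   declared.set_dim_param("batch");
//   mergeInDimensionInfo(inferred, declared, 0);  // declared now has dim_value 8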

void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);

// Merge the shape information from the source type into the target type:
// where both sides have a shape, the ranks must match and each dimension is
// merged with mergeInDimensionInfo semantics; otherwise the known shape wins.
void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target);

void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target);

// Return a copy of `proto` whose tensor shape has dimension `removed_dim` removed.
inline TypeProto RemoveIthDimensionFromShape(const TypeProto& proto, int removed_dim) {
  TypeProto t(proto);
  auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
  mutable_shape->clear_dim();

  const auto& dims = proto.tensor_type().shape().dim();

  for (int j = 0, end = dims.size(); j < end; ++j) {
    if (j != removed_dim)
      (*mutable_shape->add_dim()) = dims.Get(j);
  }

  return t;
}

// Return a copy of `proto` whose tensor shape has the first `num_dimensions`
// dimensions removed.
inline TypeProto RemoveDimensionsFromShape(const TypeProto& proto, int num_dimensions) {
  TypeProto t(proto);
  auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
  mutable_shape->clear_dim();

  const auto& dims = proto.tensor_type().shape().dim();

  // Start copying after the removed prefix.
  for (int j = num_dimensions, end = dims.size(); j < end; ++j) {
    (*mutable_shape->add_dim()) = dims.Get(j);
  }

  return t;
}

// A searchable way to do narrowing casts of values (in the style of GSL's narrow_cast).
template <class T, class U>
static constexpr T narrow_cast(U&& u) noexcept {
  return static_cast<T>(std::forward<U>(u));
}

inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expected_rank) {
  // The rank is checked only if a shape is present.
  if (hasInputShape(ctx, input_index)) {
    auto rank = getInputShape(ctx, input_index).dim_size();
    if (rank != expected_rank) {
      fail_shape_inference(
          "Input ",
          input_index,
          " expected to have rank ",
          expected_rank,
          " but has rank ",
          rank,
          " in ",
          ctx.getDisplayName(),
          ".");
    }
  }
}

// The helpers below perform unification of dimension information: checking
// that known dimensions agree, and filling in unknown dimensions from known
// ones, failing shape inference on any mismatch.

inline void checkDimEquality(int64_t value1, int64_t value2) {
  if (value1 != value2) {
    fail_shape_inference("Dimension mismatch in unification between ", value1, " and ", value2);
  }
}

inline void unifyDim(const Dim& dim1, const Dim& dim2) {
  if (dim1.has_dim_value() && dim2.has_dim_value())
    checkDimEquality(dim1.dim_value(), dim2.dim_value());
}

// Unify source_dim into target_dim: known values must agree; a known value or
// dim_param propagates into an unconstrained target.
inline void unifyDim(const Dim& source_dim, Dim& target_dim) {
  if (source_dim.has_dim_value()) {
    auto source_value = source_dim.dim_value();
    if (target_dim.has_dim_value()) {
      auto target_value = target_dim.dim_value();
      checkDimEquality(source_value, target_value);
    } else {
      target_dim.set_dim_value(source_value);
    }
  } else if (target_dim.has_dim_value()) {
    // Target already constrained to a value; the source is symbolic or
    // unknown, so there is nothing further to check.
  } else if (target_dim.has_dim_param()) {
    // Target keeps its existing symbolic name.
  } else if (source_dim.has_dim_param()) {
    target_dim.set_dim_param(source_dim.dim_param());
  }
}
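
// Example (illustrative): unifying an unknown target with a known source.
//   Dim source, target;
//   source.set_dim_value(64);
//   unifyDim(source, target);  // target now has dim_value 64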

inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_index, Dim& dim) {
  // Unification happens only if the input shape is present.
  if (hasInputShape(ctx, input_index)) {
    auto& input_shape = getInputShape(ctx, input_index);
    // The shape is expected to have rank greater than dim_index.
    if (input_shape.dim_size() <= dim_index) {
      fail_shape_inference(
          "Input ",
          input_index,
          " expected to have rank >",
          dim_index,
          " but has rank ",
          input_shape.dim_size(),
          " in ",
          ctx.getDisplayName(),
          ".");
    }
    const Dim& input_dim = input_shape.dim(dim_index);
    unifyDim(input_dim, dim);
  }
}

// Unify dim with a constant value: a known dim must equal it; an unknown dim
// is set to it.
inline void unifyDim(Dim& dim, int64_t value) {
  if (dim.has_dim_value()) {
    checkDimEquality(dim.dim_value(), value);
  } else
    dim.set_dim_value(value);
}

// The Union* helpers compute the "union" of two type/shape infos: information
// is kept only where the two sides agree and is dropped where they differ.
// This is useful when one output can come from multiple sources, such as the
// two branches of an If node.

void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);

// Union the full type info; element types must be compatible, and shapes are
// unioned as above.
void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type);

// Convert negative axes to their non-negative equivalents, in place, given the rank.
template <typename Axes>
void adjustNegativeAxes(Axes& axes, int rank) {
  std::transform(
      axes.begin(), axes.end(), axes.begin(), [&](int64_t axis) -> int64_t { return axis < 0 ? axis + rank : axis; });
}
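
// Example (illustrative): with rank 4, axes {-1, 0, -2} become {3, 0, 2}.
//   std::vector<int64_t> axes = {-1, 0, -2};
//   adjustNegativeAxes(axes, 4);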

// Validate that every axis is within [-rank, rank - 1].
template <typename Axes>
void checkAxesRange(Axes& axes, int rank) {
  for (auto axis : axes) {
    if (axis < -rank || axis > (rank - 1))
      fail_shape_inference("Unexpected axis value: ", axis, ". Expected range [", -rank, ", ", rank, ")");
  }
}

// Validate that no axis (after normalization) is referenced more than once.
// Assumes the axes have already been range-checked against rank.
template <typename Axes>
void checkDuplicateAxes(Axes& axes, int rank) {
  std::vector<bool> tmp(rank, false);
  for (auto axis : axes) {
    int actual_axis = static_cast<int>(axis < 0 ? axis + rank : axis);
    if (tmp[actual_axis])
      fail_shape_inference("Axis ", axis, " is referred to more than once.");
    tmp[actual_axis] = true;
  }
}

} // namespace ONNX_NAMESPACE