Back to home page

EIC code displayed by LXR

 
 

    


Warning: the file /include/onnx/defs/shape_inference.h was not indexed, or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).

0001 /*
0002  * SPDX-License-Identifier: Apache-2.0
0003  */
0004 
0005 #pragma once
0006 
0007 #include <algorithm>
0008 #include <functional>
0009 #include <string>
0010 #include <utility>
0011 #include <vector>
0012 
0013 #include "onnx/defs/data_type_utils.h"
0014 #include "onnx/proto_utils.h"
0015 #include "onnx/string_utils.h"
0016 
0017 namespace ONNX_NAMESPACE {
0018 
// Shorthand for a single dimension entry of a TensorShapeProto.
using Dim = TensorShapeProto_Dimension;
0020 
// Options controlling how type and shape inference behave.
struct ShapeInferenceOptions {
  // When true, checks the type-equality for input and output.
  bool check_type;
  // Error mode:
  //   1: any node-level shape inference error is thrown.
  //   0: node-level shape inference errors are not thrown, but other errors
  //      (e.g. merging an existing shape with an inferred one) still are.
  int error_mode;
  // When true, enables data propagation for a limited set of operators so
  // that shape computations can be performed statically.
  bool enable_data_propagation;

  // Defaults: no type check, error mode 0, no data propagation.
  // `strict_mode_val` initializes `error_mode`.
  ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
      : check_type{check_type_val}, error_mode{strict_mode_val}, enable_data_propagation{data_prop_val} {}
};
0034 
// Maintains a SymbolTable for symbolic shape inference: an abstract interface
// that hands out symbol names guaranteed not to collide with existing ones.
// NOTE: keep the declaration order of the virtual members stable; reordering
// them changes the vtable layout (ABI) seen by existing implementations.
class SymbolTable {
 public:
  // Adds existing symbols from a main graph or subgraph
  virtual void addFromGraph(const GraphProto& g) = 0;
  // Creates a new symbol that does not duplicate any existing one, using the
  // default prefix "unk__".
  // NOTE: a derived class that overrides createNew(const std::string&) hides
  // this overload for calls made through the derived type unless it adds
  // `using SymbolTable::createNew;`.
  std::string createNew() {
    return createNew("unk__");
  }
  // Creates a new symbol derived from `symbol_prefix` that does not duplicate
  // any existing one (the exact naming scheme is up to the implementation).
  virtual std::string createNew(const std::string& symbol_prefix) = 0;
  virtual ~SymbolTable() = default;
};
0047 
// Runs type/shape inference on a graph, e.g. the one returned by
// InferenceContext::getGraphAttributeInferencer for a graph-valued attribute.
class GraphInferencer {
 public:
  // Perform inferencing on the graph contained in GraphInferencer.
  // Returns the graph output types post-inferencing.
  // inputTypes: types of the graph's inputs.
  // inputData: constant values for the graph's inputs — NOTE(review): presumably
  // matched positionally with inputTypes, with nullptr for unknown values; confirm.
  virtual std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) = 0;
  virtual ~GraphInferencer() = default;
};
0057 
0058 // Exception class used for handling errors in type and shape inference
0059 
0060 class InferenceError final : public std::runtime_error {
0061  public:
0062   using std::runtime_error::runtime_error;
0063 
0064   InferenceError(const std::string& message) : std::runtime_error(message) {}
0065 
0066   const char* what() const noexcept override {
0067     if (!expanded_message_.empty()) {
0068       return expanded_message_.c_str();
0069     }
0070     return std::runtime_error::what();
0071   }
0072 
0073   void AppendContext(const std::string& context) {
0074     expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
0075   }
0076 
0077  private:
0078   std::string expanded_message_;
0079 };
0080 
// Raises (via ONNX_THROW_EX, defined elsewhere) an InferenceError whose message
// is "[TypeInferenceError] " followed by the arguments concatenated by MakeString.
// NOTE(review): the expansion already ends in ';', so the usual call form
// `fail_type_inference(...);` produces an extra empty statement. An unbraced
// `if (c) fail_type_inference(...); else ...` therefore does not compile —
// brace such branches.
#define fail_type_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));

// Same as fail_type_inference, but prefixes the message with "[ShapeInferenceError] ".
#define fail_shape_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));
0086 
// Interface through which an operator's type/shape inference function queries
// its node (attributes, input types, constant input data) and writes the
// inferred output types.
// NOTE: keep the declaration order of the virtual members stable; reordering
// them changes the vtable layout (ABI) seen by existing implementations.
struct InferenceContext {
  // Returns the attribute named `name`; helpers in this file treat a nullptr
  // result as "attribute not present".
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  // Number of inputs of the node.
  virtual size_t getNumInputs() const = 0;
  // Type of input `index`; may be nullptr when the type is unknown (see hasInput).
  virtual const TypeProto* getInputType(size_t index) const = 0;
  // Returns true if input `index` is supplied.
  virtual bool hasInput(size_t index) const {
    // The default implementation below is used for backward-compatibility
    // for implementations of InferenceContext that don't provide an explicit
    // implementation. This works for normal usage, but may be imprecise in
    // the edge-case where an input is supplied but has no known type.
    // However, inference-methods work only under the assumption that the
    // input-types of all inputs are known.
    return ((index < getNumInputs()) && (getInputType(index) != nullptr));
  }
  // Statically known value of input `index` — NOTE(review): presumably nullptr
  // when the value is not known; confirm in the implementations.
  virtual const TensorProto* getInputData(size_t index) const = 0;
  // Number of outputs of the node.
  virtual size_t getNumOutputs() const = 0;
  // Mutable type of output `index`; inference functions write their results here.
  virtual TypeProto* getOutputType(size_t index) = 0;
  // Returns an inferencer for the graph held in attribute `attribute_name`.
  virtual GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) = 0;
  virtual ~InferenceContext() {}
  // Sparse-tensor counterpart of getInputData (presumably nullptr when unknown).
  virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0;
  // Gets the shape inputs computed by partial data propagation.
  virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0;
  // To display a name the user can use to narrow its search.
  // Appended to error messages by the helpers in this file; empty by default.
  virtual std::string getDisplayName() const {
    return "";
  }
};
0113 
// We use data propagation to perform partial evaluation of the model, to compute statically
// known information about tensor values. It is intended to improve the precision of shape
// inference. We reuse TensorShapeProto to represent the statically known values. One
// limitation of this is that TensorShapeProto can represent only integer values.
// As an example, data-propagation is intended to handle code-fragments like below:
//   shape = Shape(X)
//   batchsize = Slice(shape, [0], [1])
//   newshape = Concat (batchsize, [1024, 1024])
//   Z = Reshape(Y, newshape)
// If the shape of X is statically known, then data-propagation should be able to determine
// the value of newshape, as well as the shape of Z.
// NOTE: keep the declaration order of the virtual members stable (vtable/ABI).
struct DataPropagationContext {
  // Returns the attribute named `name`; the getAttribute helper in this file
  // treats nullptr as "attribute not present".
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  // Number of inputs of the node.
  virtual size_t getNumInputs() const = 0;
  // Type of input `index`; generic helpers such as hasInputShape check it for nullptr.
  virtual const TypeProto* getInputType(size_t index) const = 0;
  // Number of outputs of the node.
  virtual size_t getNumOutputs() const = 0;
  // Type of output `index`; read-only here, unlike InferenceContext::getOutputType.
  virtual const TypeProto* getOutputType(size_t index) const = 0;
  virtual ~DataPropagationContext() {}
  // Statically known values of input `index`, encoded as a TensorShapeProto
  // (presumably one dimension entry per element, nullptr when unknown — TODO confirm).
  virtual const TensorShapeProto* getInputData(size_t index) = 0;
  // Records the statically known values computed for output `index`.
  virtual void addOutputData(size_t index, TensorShapeProto&& tp) = 0;
};
0135 
0136 using InferenceFunction = std::function<void(InferenceContext&)>;
0137 using DataPropagationFunction = std::function<void(DataPropagationContext&)>;
0138 
0139 // This no-op inference function is used for operators without an
0140 // inference implementation.
0141 inline void dummyInferenceFunction(InferenceContext&){};
0142 
0143 // This no-op data propagation function is used for operators without a defined data propagator
0144 inline void dummyDataPropagationFunction(DataPropagationContext&){};
0145 
0146 template <typename T>
0147 inline bool getRepeatedAttribute(InferenceContext& ctx, std::string attr_name, std::vector<T>& values) {
0148   const auto* attr = ctx.getAttribute(attr_name);
0149   if (attr) {
0150     values = RetrieveValues<T>(*attr);
0151     return true;
0152   } else {
0153     return false;
0154   }
0155 }
0156 
0157 inline int64_t getAttribute(InferenceContext& ctx, const std::string& attributeName, int64_t defaultValue) {
0158   auto attr_proto = ctx.getAttribute(attributeName);
0159   if ((nullptr != attr_proto) && attr_proto->has_i())
0160     return attr_proto->i();
0161   return defaultValue;
0162 }
0163 
0164 inline int64_t getAttribute(DataPropagationContext& ctx, const std::string& attributeName, int64_t defaultValue) {
0165   auto attr_proto = ctx.getAttribute(attributeName);
0166   if ((nullptr != attr_proto) && attr_proto->has_i())
0167     return attr_proto->i();
0168   return defaultValue;
0169 }
0170 
0171 inline std::string
0172 getAttribute(InferenceContext& ctx, const std::string& attributeName, const std::string& defaultValue) {
0173   auto attr_proto = ctx.getAttribute(attributeName);
0174   if ((nullptr != attr_proto) && attr_proto->has_s())
0175     return attr_proto->s();
0176   return defaultValue;
0177 }
0178 
0179 inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, TensorShapeProto::Dimension dim2) {
0180   TensorShapeProto::Dimension result;
0181   if (dim1.has_dim_value() && dim2.has_dim_value()) {
0182     result.set_dim_value(dim1.dim_value() * dim2.dim_value());
0183   } else if (dim1.has_dim_value() && (dim1.dim_value() == 1)) {
0184     return dim2;
0185   } else if (dim2.has_dim_value() && (dim2.dim_value() == 1)) {
0186     return dim1;
0187   }
0188   return result;
0189 }
0190 
// Formats the elements of a container as a string (defined elsewhere).
template <typename Container>
std::string stringify(const Container& elements);

// Returns an (element type, length) pair describing the value of an attribute
// (defined elsewhere) — NOTE(review): exact encoding of both ints to be confirmed
// against the implementation.
std::pair<int, int> getAttributeProtoElemTypeAndLength(const AttributeProto* attr_proto);

// Like getAttributeProtoElemTypeAndLength, but looks up the node's attribute
// among `attribute_names` (defined elsewhere) — NOTE(review): behavior when
// none or several of the names are present is not visible here.
std::pair<int, int> getAttributeElementTypeAndLength(
    const InferenceContext& ctx,
    const std::initializer_list<std::string>& attribute_names);
0199 
0200 inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, int64_t dim2) {
0201   TensorShapeProto::Dimension result;
0202   if (dim1.has_dim_value()) {
0203     result.set_dim_value(dim1.dim_value() * dim2);
0204   } else if (dim2 == 1) {
0205     return dim1;
0206   }
0207   return result;
0208 }
0209 
0210 inline TensorShapeProto::Dimension operator/(TensorShapeProto::Dimension dim1, int64_t dim2) {
0211   TensorShapeProto::Dimension result;
0212   if (dim1.has_dim_value()) {
0213     result.set_dim_value(dim1.dim_value() / dim2);
0214   } else if (dim2 == 1) {
0215     return dim1;
0216   }
0217   return result;
0218 }
0219 
0220 // if from >= upto_exclusive, return 1.
0221 // Caller must make sure upto_exclusive is less than or equal to shape.size()
0222 // Caller must make sure from>=0
0223 inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
0224   TensorShapeProto::Dimension dim;
0225   dim.set_dim_value(1);
0226   for (int i = from; i < upto_exclusive; ++i) {
0227     dim = dim * shape.dim(i);
0228   }
0229   return dim;
0230 }
0231 
0232 inline int32_t getTensorElementType(const TypeProto& type) {
0233   int32_t result = TensorProto::UNDEFINED;
0234   const auto value_case = type.value_case();
0235   if (value_case == TypeProto::kTensorType) {
0236     result = type.tensor_type().elem_type();
0237   } else if (value_case == TypeProto::kSparseTensorType) {
0238     result = type.sparse_tensor_type().elem_type();
0239   }
0240   return result;
0241 }
0242 
0243 inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_case, TypeProto& type) {
0244   if (value_case == TypeProto::kTensorType) {
0245     type.mutable_tensor_type()->set_elem_type(elem_type);
0246   } else if (value_case == TypeProto::kSparseTensorType) {
0247     type.mutable_sparse_tensor_type()->set_elem_type(elem_type);
0248   }
0249 }
0250 
// The following are defined elsewhere. The descriptions are inferred from their
// names — NOTE(review): confirm against the implementations.

// Propagates the element type of `input_type` to `output_type`, with validation.
void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type);

// Propagates the element type of input `inputIndex` to output `outputIndex`.
void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);

// Like propagateElemTypeFromInputToOutput, presumably restricted to tensor inputs.
void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);
0256 
0257 inline void propagateElemTypeFromDtypeToOutput(
0258     InferenceContext& ctx,
0259     const int data_type,
0260     size_t outputIndex,
0261     TypeProto::ValueCase expected_value_case) {
0262   const auto attribute_tensor_datatype = data_type;
0263   auto output_type = ctx.getOutputType(outputIndex);
0264   const auto output_value_case = output_type->value_case();
0265   if (output_value_case == TypeProto::VALUE_NOT_SET || output_value_case == expected_value_case) {
0266     setTensorElementType(attribute_tensor_datatype, expected_value_case, *output_type);
0267   } else {
0268     // This is not expected to happen
0269     fail_type_inference(
0270         "Output ",
0271         outputIndex,
0272         " expected to have: ",
0273         expected_value_case,
0274         " or UNDEFINED. Got: ",
0275         output_value_case,
0276         " in ",
0277         ctx.getDisplayName(),
0278         ".");
0279   }
0280 }
0281 
0282 inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const int data_type, size_t outputIndex) {
0283   propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, TypeProto::kTensorType);
0284 }
0285 
0286 inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const AttributeProto* attr, size_t outputIndex) {
0287   int32_t data_type = TensorProto::UNDEFINED;
0288   TypeProto::ValueCase expected_value_case = TypeProto::VALUE_NOT_SET;
0289   const auto attr_type = attr->type();
0290   if (attr_type == AttributeProto::TENSOR) {
0291     if (attr->t().dims().size() != 1) {
0292       fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), ".");
0293     }
0294     data_type = attr->t().data_type();
0295     expected_value_case = TypeProto::kTensorType;
0296   } else if (attr_type == AttributeProto::SPARSE_TENSOR) {
0297     if (attr->sparse_tensor().dims().size() != 1) {
0298       fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), ".");
0299     }
0300     data_type = attr->sparse_tensor().values().data_type();
0301     expected_value_case = TypeProto::kSparseTensorType;
0302   } else {
0303     fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), ".");
0304   }
0305 
0306   propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case);
0307 }
0308 
0309 inline bool hasShape(const TypeProto& type) {
0310   if (type.has_tensor_type()) {
0311     return type.tensor_type().has_shape();
0312   } else if (type.has_sparse_tensor_type()) {
0313     return type.sparse_tensor_type().has_shape();
0314   } else if (type.has_sequence_type() && type.sequence_type().has_elem_type()) {
0315     return hasShape(type.sequence_type().elem_type());
0316   } else if (type.has_optional_type() && type.optional_type().has_elem_type()) {
0317     return hasShape(type.optional_type().elem_type());
0318   }
0319   return false;
0320 }
0321 
// Returns true if input `n` exists, has a known type, and that type carries
// shape information (see hasShape). Works with any context type exposing
// getNumInputs() and getInputType().
template <typename Context>
inline bool hasInputShape(const Context& ctx, size_t n) {
  if (n >= ctx.getNumInputs()) {
    return false;
  }
  const auto* input_type = ctx.getInputType(n);
  return input_type != nullptr && hasShape(*input_type);
}
0326 
// Returns true if each of the first `n` inputs has shape information
// (trivially true for n == 0).
template <typename Context>
inline bool hasNInputShapes(const Context& ctx, size_t n) {
  size_t checked = 0;
  while (checked < n && hasInputShape(ctx, checked)) {
    ++checked;
  }
  return checked == n;
}
0336 
0337 inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t n) {
0338   const auto* input_type = ctx.getInputType(n);
0339   const auto value_case = input_type->value_case();
0340   if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
0341     fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
0342   }
0343   if (!hasShape(*input_type)) {
0344     fail_shape_inference("Input ", n, " must have a non null shape in ", ctx.getDisplayName(), ".");
0345   }
0346   if (value_case == TypeProto::kTensorType) {
0347     return input_type->tensor_type().shape();
0348   } else {
0349     return input_type->sparse_tensor_type().shape();
0350   }
0351 }
0352 
0353 inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size_t n) {
0354   const auto* input_type = ctx.getInputType(n);
0355 
0356   if (input_type == nullptr) {
0357     return nullptr;
0358   }
0359 
0360   const auto value_case = input_type->value_case();
0361   if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
0362     fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
0363   }
0364   if (value_case == TypeProto::kTensorType) {
0365     return &input_type->tensor_type().shape();
0366   } else {
0367     return &input_type->sparse_tensor_type().shape();
0368   }
0369 }
0370 
0371 // Caller must make sure fromDimIndex is strictly less than shape.dim_size()
0372 inline void appendSingleDimCopiedFromInputTypeToOutputType(
0373     InferenceContext& ctx,
0374     size_t inputIndex,
0375     size_t outputIndex,
0376     size_t fromDimIndex) {
0377   auto output_type = ctx.getOutputType(outputIndex);
0378   const auto output_value_case = output_type->value_case();
0379   auto input_type = ctx.getInputType(inputIndex);
0380   const auto input_value_case = input_type->value_case();
0381   if (output_value_case != input_value_case) {
0382     fail_type_inference(
0383         "Input: ",
0384         inputIndex,
0385         " type: ",
0386         input_value_case,
0387         " does not match type of output: ",
0388         outputIndex,
0389         "type: ",
0390         output_value_case,
0391         " in ",
0392         ctx.getDisplayName(),
0393         ".");
0394   }
0395   if (TypeProto::kTensorType == input_value_case) {
0396     auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim();
0397     *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
0398   } else if (TypeProto::kSparseTensorType == input_value_case) {
0399     auto* dim = output_type->mutable_sparse_tensor_type()->mutable_shape()->add_dim();
0400     *dim = input_type->sparse_tensor_type().shape().dim(static_cast<int>(fromDimIndex));
0401   } else {
0402     fail_type_inference(
0403         "Input ",
0404         inputIndex,
0405         " and Output ",
0406         outputIndex,
0407         " expected to have tensor or sparse tensor type in ",
0408         ctx.getDisplayName(),
0409         ".");
0410   }
0411 }
0412 
0413 inline void propagateShape(const TypeProto* from_type, TypeProto* to_type) {
0414   const auto from_type_case = from_type->value_case();
0415   const auto to_type_case = to_type->value_case();
0416   if (from_type_case != to_type_case) {
0417     fail_shape_inference(
0418         "Mismatch between inferred and declared type. Inferred=", from_type_case, " Declared=", to_type_case);
0419   }
0420 
0421   if (TypeProto::kTensorType == from_type_case || TypeProto::kSparseTensorType == from_type_case) {
0422     // If input shape is "unknown", the corresponding should be "unknown" too.
0423     // The way to make output shape unknown is not to assign it any value.
0424     if (hasShape(*from_type)) {
0425       if (TypeProto::kTensorType == from_type_case) {
0426         *to_type->mutable_tensor_type()->mutable_shape() = from_type->tensor_type().shape();
0427       } else {
0428         *to_type->mutable_sparse_tensor_type()->mutable_shape() = from_type->sparse_tensor_type().shape();
0429       }
0430     }
0431   } else if (TypeProto::kSequenceType == from_type_case) {
0432     propagateShape(&from_type->sequence_type().elem_type(), to_type->mutable_sequence_type()->mutable_elem_type());
0433   } else if (TypeProto::kOptionalType == from_type_case) {
0434     propagateShape(&from_type->optional_type().elem_type(), to_type->mutable_optional_type()->mutable_elem_type());
0435   } else if (TypeProto::kMapType == from_type_case) {
0436     propagateShape(&from_type->map_type().value_type(), to_type->mutable_map_type()->mutable_value_type());
0437   } else {
0438     fail_shape_inference("Unsupported Source/Target type=", from_type_case);
0439   }
0440 }
0441 
0442 inline void propagateShapeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
0443   auto output_type = ctx.getOutputType(outputIndex);
0444   auto input_type = ctx.getInputType(inputIndex);
0445 
0446   propagateShape(input_type, output_type);
0447 }
0448 
0449 inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) {
0450   propagateElemTypeFromInputToOutput(ctx, 0, 0);
0451   if (!hasNInputShapes(ctx, 1)) {
0452     return;
0453   }
0454   propagateShapeFromInputToOutput(ctx, 0, 0);
0455 }
0456 
0457 inline void
0458 updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType, TypeProto::ValueCase expected_type) {
0459   auto output_type = ctx.getOutputType(outputIndex);
0460   if (output_type == nullptr) {
0461     fail_type_inference("Output ", outputIndex, " is null");
0462   }
0463   if (output_type->value_case() == expected_type || output_type->value_case() == TypeProto::VALUE_NOT_SET) {
0464     setTensorElementType(elemType, expected_type, *output_type);
0465   } else {
0466     // This is not expected to happen
0467     fail_type_inference(
0468         "Output ",
0469         outputIndex,
0470         " expected to have tensor or sparse tensor type: ",
0471         expected_type,
0472         " in ",
0473         ctx.getDisplayName(),
0474         ".");
0475   }
0476 }
0477 
0478 inline void updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType) {
0479   updateOutputElemType(ctx, outputIndex, elemType, TypeProto::kTensorType);
0480 }
0481 
0482 // Infer type of an output from the value of a specified attribute, which is
0483 // expected to have a valid value representing a TensorProto_DataType.
0484 inline void propagateElemTypeFromAttributeToOutput(
0485     InferenceContext& ctx,
0486     const std::string& attributeName,
0487     size_t outputIndex,
0488     TypeProto::ValueCase expected_type,
0489     TensorProto_DataType default_value = TensorProto::UNDEFINED) {
0490   auto attr_proto = ctx.getAttribute(attributeName);
0491   if (nullptr == attr_proto) { // attribute not present
0492     if (default_value != TensorProto::UNDEFINED) {
0493       updateOutputElemType(ctx, outputIndex, default_value, expected_type);
0494       return;
0495     } else {
0496       fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), ".");
0497     }
0498   }
0499   if (!attr_proto->has_i()) {
0500     fail_type_inference(
0501         "Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), ".");
0502   }
0503   auto attr_value = attr_proto->i();
0504   auto elem_type = static_cast<TensorProto_DataType>(attr_value);
0505   if (!TensorProto_DataType_IsValid(elem_type)) {
0506     fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), ".");
0507   }
0508   updateOutputElemType(ctx, outputIndex, elem_type, expected_type);
0509 }
0510 
0511 inline void propagateElemTypeFromAttributeToOutput(
0512     InferenceContext& ctx,
0513     const std::string& attributeName,
0514     size_t outputIndex,
0515     TensorProto_DataType default_value = TensorProto::UNDEFINED) {
0516   propagateElemTypeFromAttributeToOutput(ctx, attributeName, outputIndex, TypeProto::kTensorType, default_value);
0517 }
0518 
0519 inline TensorShapeProto* getTensorMutableShape(TypeProto::ValueCase value_case, TypeProto& type) {
0520   if (value_case == TypeProto::kTensorType) {
0521     return type.mutable_tensor_type()->mutable_shape();
0522   } else if (value_case == TypeProto::kSparseTensorType) {
0523     return type.mutable_tensor_type()->mutable_shape();
0524   }
0525   return nullptr;
0526 }
0527 
0528 inline TensorShapeProto*
0529 getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0530   auto output_type = ctx.getOutputType(n);
0531   if (output_type == nullptr) {
0532     fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), ".");
0533   }
0534   const auto output_value_case = output_type->value_case();
0535   if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
0536     return getTensorMutableShape(output_value_case, *output_type);
0537   } else if (output_value_case == TypeProto::VALUE_NOT_SET) {
0538     return getTensorMutableShape(default_type, *output_type);
0539   } else {
0540     fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), ".");
0541   }
0542 }
0543 
0544 inline void appendDim(TensorShapeProto* shape, int64_t dim_value) {
0545   shape->add_dim()->set_dim_value(dim_value);
0546 }
0547 
0548 inline void updateOutputShape(
0549     InferenceContext& ctx,
0550     size_t outputIndex,
0551     const TensorShapeProto& shape,
0552     TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0553   auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
0554   *output_shape = shape;
0555 }
0556 
0557 inline void updateOutputShape(
0558     InferenceContext& ctx,
0559     size_t outputIndex,
0560     const TensorProto& tensorProto,
0561     TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0562   auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
0563   for (auto d : tensorProto.dims()) {
0564     auto* dim = output_shape->add_dim();
0565     dim->set_dim_value(d);
0566   }
0567 }
0568 
0569 inline void updateOutputShape(
0570     InferenceContext& ctx,
0571     size_t outputIndex,
0572     std::initializer_list<TensorShapeProto::Dimension> dims,
0573     TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0574   auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
0575   for (auto& d : dims) {
0576     auto* dim = output_shape->add_dim();
0577     *dim = d;
0578   }
0579 }
0580 
// Returns the value of input `input_index` interpreted as a shape (defined elsewhere).
// Sources are tried in order: the initializer (constant input data), then symbolic
// data from data propagation, then rank inference.
// `found` is set to true when one of these succeeds. Otherwise it is set to false,
// and the returned TensorShapeProto is meaningless.
TensorShapeProto getShapeInput(const InferenceContext& ctx, size_t input_index, bool& found);
0586 
0587 // Infer shape of an output from the value of a specified attribute, which is
0588 // expected to be a list of integers specifying a valid shape.
0589 inline void propagateShapeFromAttributeToOutput(
0590     InferenceContext& ctx,
0591     const std::string& attributeName,
0592     size_t outputIndex,
0593     TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0594   auto attr_proto = ctx.getAttribute(attributeName);
0595   if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
0596       (attr_proto->type() != AttributeProto_AttributeType_INTS)) {
0597     fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), ".");
0598   }
0599   auto& int_list = attr_proto->ints();
0600   TensorShapeProto shape;
0601   for (auto dim_size : int_list) {
0602     if (dim_size < 0) {
0603       fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), ".");
0604     }
0605     shape.add_dim()->set_dim_value(dim_size);
0606   }
0607 
0608   updateOutputShape(ctx, outputIndex, shape, default_type);
0609 }
0610 
0611 inline void multidirectionalBroadcastShapeInference(
0612     const std::vector<const TensorShapeProto*>& shapes,
0613     TensorShapeProto& resultShape) {
0614   int result_shape_size = 0;
0615   // Get the result shape size.
0616   for (size_t i = 0; i < shapes.size(); ++i) {
0617     if (shapes[i]->dim_size() > result_shape_size) {
0618       result_shape_size = shapes[i]->dim_size();
0619     }
0620   }
0621 
0622   for (int i = 0; i < result_shape_size; ++i) {
0623     int64_t dim_value = 1;
0624     TensorShapeProto_Dimension symbolic_dim;
0625     int num_symbolic_dims = 0;
0626     for (size_t j = 0; j < shapes.size(); ++j) {
0627       if (i < result_shape_size - shapes[j]->dim_size()) {
0628         // Shape j will be filled with 1 at dimension i;
0629         continue;
0630       }
0631 
0632       auto dim_i_j = shapes[j]->dim(i - result_shape_size + shapes[j]->dim_size());
0633       if (dim_i_j.has_dim_value()) {
0634         if (dim_i_j.dim_value() != 1) {
0635           if (dim_value != dim_i_j.dim_value() && dim_value != 1) {
0636             fail_shape_inference("Incompatible dimensions");
0637           } else {
0638             dim_value = dim_i_j.dim_value();
0639           }
0640         }
0641       } else {
0642         if (num_symbolic_dims == 0) {
0643           symbolic_dim = dim_i_j;
0644           ++num_symbolic_dims;
0645         } else if (dim_i_j.dim_param() != symbolic_dim.dim_param()) {
0646           ++num_symbolic_dims;
0647         }
0648       }
0649     }
0650 
0651     if (dim_value != 1 || num_symbolic_dims == 0) {
0652       resultShape.add_dim()->set_dim_value(dim_value);
0653     } else if (num_symbolic_dims == 1) {
0654       *resultShape.add_dim() = symbolic_dim;
0655     } else {
0656       resultShape.add_dim();
0657     }
0658   }
0659 }
0660 
0661 inline void bidirectionalBroadcastShapeInference(
0662     const TensorShapeProto& shapeL,
0663     const TensorShapeProto& shapeR,
0664     TensorShapeProto& resultShape) {
0665   std::vector<const TensorShapeProto*> shapes;
0666   shapes.push_back(&shapeL);
0667   shapes.push_back(&shapeR);
0668   multidirectionalBroadcastShapeInference(shapes, resultShape);
0669 }
0670 
0671 /*
0672 Merge the dimension information from two TensorShapeProto_Dimension instances.
0673 Values are merged into target from source.
0674 If target has no dimension information, copy from source.
0675 If source has no dimension information, ignore source.
0676 If both have dimension information:
0677  - Prefer values over params. If both have values, values must match.
0678  - Prefer target param over source param if mismatched.
0679 Fail if there are mismatches in dimension values.
0680 Currently, there is no way to refine/update dimension information for the
0681 source from information available in the target.
0682 */
0683 inline void mergeInDimensionInfo(
0684     const TensorShapeProto_Dimension& source_dim,
0685     TensorShapeProto_Dimension& target_dim,
0686     int dim_index) {
0687   // if source has value, merge into target
0688   // else if target has value, preserve it
0689   // else merge params
0690   if (source_dim.has_dim_value()) {
0691     auto source_value = source_dim.dim_value();
0692     if (target_dim.has_dim_value()) {
0693       auto target_value = target_dim.dim_value();
0694       if (target_value != source_value) {
0695         fail_shape_inference(
0696             "Can't merge shape info. "
0697             "Both inferred and declared dimension have values but they differ. Inferred=",
0698             source_value,
0699             " Declared=",
0700             target_value,
0701             " Dimension=",
0702             dim_index);
0703       }
0704     } else {
0705       target_dim.set_dim_value(source_value);
0706     }
0707   } else if (target_dim.has_dim_value()) {
0708     // if target has a value we preserve it so do nothing
0709   } else if (target_dim.has_dim_param()) {
0710     // prefer target param over source
0711   } else if (source_dim.has_dim_param()) {
0712     target_dim.set_dim_param(source_dim.dim_param());
0713   }
0714 }
0715 
// Merges the dimensions of `source_shape` into the shape of a tensor type
// (defined elsewhere) — NOTE(review): presumably following the same
// per-dimension rules as the TypeProto_Tensor overload documented below.
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

// Sparse-tensor counterpart of the TensorShapeProto / TypeProto_Tensor overload.
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);

/*
Merge the shape information from two TypeProto_Tensor instances.
Values are merged into target from source.
If target has no shape information, copy from source.
If source has no shape information, ignore source.
If both have shape information:
- merge each TensorShapeProto_Dimension separately.
- Prefer values over params. If both have values, values must match.
- Prefer target param over source param if mismatched.
Fail if there are mismatches in number of dimensions or dimension values.
*/
void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target);

// Sparse-tensor counterpart of the TypeProto_Tensor overload.
void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target);
0734 
0735 // Return a copy of a type, with a specified dimension removed from its shape.
0736 inline TypeProto RemoveIthDimensionFromShape(const TypeProto& proto, int removed_dim) {
0737   TypeProto t(proto);
0738   auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
0739   mutable_shape->clear_dim();
0740 
0741   const auto& dims = proto.tensor_type().shape().dim();
0742 
0743   for (int j = 0, end = dims.size(); j < end; ++j) {
0744     if (j != removed_dim)
0745       (*mutable_shape->add_dim()) = dims.Get(j);
0746   }
0747 
0748   return t;
0749 }
0750 
0751 // Return a copy of a type, with specified number of dimensions removed from the
0752 // beginning.
0753 inline TypeProto RemoveDimensionsFromShape(const TypeProto& proto, int num_dimensions) {
0754   TypeProto t(proto);
0755   auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
0756   mutable_shape->clear_dim();
0757 
0758   const auto& dims = proto.tensor_type().shape().dim();
0759 
0760   // skip first num_dimensions
0761   for (int j = num_dimensions, end = dims.size(); j < end; ++j) {
0762     (*mutable_shape->add_dim()) = dims.Get(j);
0763   }
0764 
0765   return t;
0766 }
0767 
// Unchecked narrowing conversion, copied from GSL:
// https://github.com/microsoft/GSL/blob/main/include/gsl/util
// It is a plain static_cast that perfectly forwards its argument; the name
// documents at the call site that loss of range/precision is intentional.
template <class To, class From>
static constexpr auto narrow_cast(From&& value) noexcept -> To {
  return static_cast<To>(std::forward<From>(value));
}
0774 
0775 inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expected_rank) {
0776   // We check the rank only if a rank is known for the input:
0777   if (hasInputShape(ctx, input_index)) {
0778     auto rank = getInputShape(ctx, input_index).dim_size();
0779     if (rank != expected_rank) {
0780       fail_shape_inference(
0781           "Input ",
0782           input_index,
0783           " expected to have rank ",
0784           expected_rank,
0785           " but has rank ",
0786           rank,
0787           " in ",
0788           ctx.getDisplayName(),
0789           ".");
0790     }
0791   }
0792 }
0793 
0794 // Unification (between dimensions and/or shapes) is at the heart of
0795 // shape-inference. The current inference algorithm can check input
0796 // shapes/dimensions of a node and update the output shapes/dimensions. It
0797 // cannot currently update input shapes and dimensions (even though in some
0798 // contexts this inference is possible). Hence, we have the variants below to
0799 // support "const" and "mutable" dimensions/shapes in unification.
0800 
0801 inline void checkDimEquality(int64_t value1, int64_t value2) {
0802   if (value1 != value2) {
0803     fail_shape_inference("Dimension mismatch in unification between ", value1, " and ", value2);
0804   }
0805 }
0806 
0807 inline void unifyDim(const Dim& dim1, const Dim& dim2) {
0808   if (dim1.has_dim_value() && dim2.has_dim_value())
0809     checkDimEquality(dim1.dim_value(), dim2.dim_value());
0810 }
0811 
0812 // TODO: The functionality of unifyDim is similar to that of
0813 // mergeInDimensionInfo. However, the error messages are different. Leaving this
0814 // duplication in-place to preserve error message content.
0815 inline void unifyDim(const Dim& source_dim, Dim& target_dim) {
0816   if (source_dim.has_dim_value()) {
0817     auto source_value = source_dim.dim_value();
0818     if (target_dim.has_dim_value()) {
0819       auto target_value = target_dim.dim_value();
0820       checkDimEquality(source_value, target_value);
0821     } else {
0822       target_dim.set_dim_value(source_value);
0823     }
0824   } else if (target_dim.has_dim_value()) {
0825     // if target has a value we preserve it.
0826     // we cannot set source dim value.
0827   } else if (target_dim.has_dim_param()) {
0828     // prefer target param over source
0829     // we cannot currently unify the dim_params
0830   } else if (source_dim.has_dim_param()) {
0831     target_dim.set_dim_param(source_dim.dim_param());
0832   }
0833 }
0834 
0835 inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_index, Dim& dim) {
0836   // We unify the dimensions only if it is available for specified input:
0837   if (hasInputShape(ctx, input_index)) {
0838     auto& input_shape = getInputShape(ctx, input_index);
0839     // This shape is expected to have rank > dim_index:
0840     if (input_shape.dim_size() <= dim_index) {
0841       fail_shape_inference(
0842           "Input ",
0843           input_index,
0844           " expected to have rank >",
0845           dim_index,
0846           " but has rank ",
0847           input_shape.dim_size(),
0848           " in ",
0849           ctx.getDisplayName(),
0850           ".");
0851     }
0852     const Dim& input_dim = input_shape.dim(dim_index);
0853     // Now, unify dim and input_dim:
0854     unifyDim(input_dim, dim);
0855   }
0856 }
0857 
0858 // unifyDim: unifies a dimension with a constant value. If the dimension
0859 // already has a value, we check for equality of new value with old value.
0860 inline void unifyDim(Dim& dim, int64_t value) {
0861   if (dim.has_dim_value()) {
0862     checkDimEquality(dim.dim_value(), value);
0863   } else
0864     dim.set_dim_value(value);
0865 }
0866 
// target-shape = Union (target-shape, source-shape)
// Per the examples below: dimensions on which both shapes agree are kept,
// differing dimensions become unknown (None), and a rank mismatch leaves no
// shape at all.
// Example 1: same rank, different dimensions
//    source shape: (2, 3, 4, 'x')
//    target shape: (2, 'y', 5, 'x')
//    output shape: (2, None, None, 'x')
// Example 2: different rank
//    source shape: (2, 3, 4, 'x')
//    target shape: (2, 3, 4)
//    output shape: None
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

// Same as above, for a sparse-tensor target.
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);

// target-type = Union (target-type, source-type)
// target and source are required to have the same type.
// Example 1: same tensor type, different shape
//    source: tensor elem_type: int64, shape: (2, 3, 4, 'x')
//    target: tensor elem_type: int64, shape: (2, 'y', 5, 'x')
//    output: tensor elem_type: int64, shape: (2, None, None, 'x')
// Example 2: same sequence type, different shape
//    source: sequence of tensor, elem_type: float, shape: (2, 3, 4)
//    target: sequence of tensor, elem_type: float, shape: None
//    output: sequence of tensor, elem_type: float, shape: None
void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type);
0891 
// adjustNegativeAxes: rewrite each negative axis in place as `axis + rank`
// (Python-style indexing from the end). Non-negative axes are kept. No range
// checking is done: an axis below -rank stays negative.
template <typename Axes>
void adjustNegativeAxes(Axes& axes, int rank) {
  for (auto&& axis : axes) {
    // Work in int64_t regardless of the container's element type.
    const int64_t value = axis;
    axis = value < 0 ? value + rank : value;
  }
}
0898 
// checkAxesRange: fail shape inference on the first axis that lies outside the
// half-open range [-rank, rank).
template <typename Axes>
void checkAxesRange(Axes& axes, int rank) {
  for (const auto& axis : axes) {
    const bool in_range = axis >= -rank && axis <= (rank - 1);
    if (!in_range) {
      fail_shape_inference("Unexpected axis value: ", axis, ". Expected range [", -rank, ", ", rank, ")");
    }
  }
}
0907 
// checkDuplicateAxes: Check that there are no duplicated axes.
// Negative axes are normalized by adding `rank` before comparison, so e.g. -1
// and rank-1 refer to the same axis. An axis outside [-rank, rank) fails with
// an "Unexpected axis value" error rather than being used as an out-of-bounds
// index, so callers need not range-check the axes beforehand.
template <typename Axes>
void checkDuplicateAxes(Axes& axes, int rank) {
  // A non-positive rank admits no valid axis; never size the vector from a
  // negative int (it would convert to a huge size_t).
  std::vector<bool> seen(rank > 0 ? static_cast<size_t>(rank) : 0, false);
  for (auto axis : axes) {
    // Normalize in 64-bit arithmetic so a huge int64 axis cannot wrap onto a
    // valid index through narrowing to int.
    const int64_t actual_axis = axis < 0 ? static_cast<int64_t>(axis) + rank : static_cast<int64_t>(axis);
    if (actual_axis < 0 || actual_axis >= rank) {
      fail_shape_inference("Unexpected axis value: ", axis, ". Expected range [", -rank, ", ", rank, ")");
    } else if (seen[static_cast<size_t>(actual_axis)]) {
      fail_shape_inference("Axis ", axis, " is referred to more than once.");
    } else {
      seen[static_cast<size_t>(actual_axis)] = true;
    }
  }
}
0919 
0920 } // namespace ONNX_NAMESPACE