Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:46

0001 /*
0002  * SPDX-License-Identifier: Apache-2.0
0003  */
0004 
0005 #pragma once
0006 
0007 #include <algorithm>
0008 #include <functional>
0009 #include <string>
0010 #include <utility>
0011 #include <vector>
0012 
0013 #include "onnx/defs/data_type_utils.h"
0014 #include "onnx/proto_utils.h"
0015 #include "onnx/string_utils.h"
0016 
0017 namespace ONNX_NAMESPACE {
0018 
0019 using Dim = TensorShapeProto_Dimension;
0020 
// Bundle of flags controlling how model-level shape inference behaves.
// Passed down to the inference driver; defaults preserve the lenient,
// no-data-propagation behavior.
struct ShapeInferenceOptions {
  // Checks the type-equality for input and output
  bool check_type;
  // 1: Will throw any node level shape infer errors
  // 0: Won't throw node-level shape infer errors, but other errors
  // like merging existing shape with inferred etc are thrown
  int error_mode;
  // Enables data propagation for limited operators
  // to perform shape computation
  bool enable_data_propagation;
  ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
      : check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val){};
};
0034 
// Maintains a SymbolTable for symbolic shape inference
class SymbolTable {
 public:
  // Adds existing symbols from a main graph or subgraph
  virtual void addFromGraph(const GraphProto& g) = 0;
  // Creates a new symbol which is not duplicate as any existing one
  std::string createNew() {
    // "unk__" is the default prefix for freshly generated symbolic dims.
    return createNew("unk__");
  }
  // Creates a new, unused symbol starting with `symbol_prefix`.
  virtual std::string createNew(const std::string& symbol_prefix) = 0;
  virtual ~SymbolTable() = default;
};
0047 
// Interface for running type/shape inference over a (sub)graph, e.g. the
// body of an If/Loop/Scan node.
class GraphInferencer {
 public:
  // Perform inferencing on the graph contained in GraphInferencer.
  // Returns the graph output types post-inferencing.
  // `inputTypes`/`inputData` are positional: entry i describes graph input i;
  // entries may be nullptr when unknown.
  virtual std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) = 0;
  virtual ~GraphInferencer() = default;
};
0057 
// Exception class used for handling errors in type and shape inference

class InferenceError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  InferenceError(const std::string& message) : std::runtime_error(message) {}

  // Returns the context-augmented message if AppendContext() has been
  // called, otherwise the original message.
  const char* what() const noexcept override {
    if (!expanded_message_.empty()) {
      return expanded_message_.c_str();
    }
    return std::runtime_error::what();
  }

  // Augments the stored message with caller-supplied context (e.g. which
  // node was being inferred); visible via subsequent what() calls.
  void AppendContext(const std::string& context) {
    expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
  }

 private:
  // Lazily-built "<original>\n\n==> Context: <context>" message.
  std::string expanded_message_;
};
0080 
// Raises an InferenceError tagged "[TypeInferenceError]" built from the
// variadic message pieces (via MakeString).
#define fail_type_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));

// Raises an InferenceError tagged "[ShapeInferenceError]".
#define fail_shape_inference(...) \
  ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));
0086 
// Abstract view of the node being inferred: its attributes, input
// types/values, and mutable output types. Implemented by the inference
// driver; consumed by per-operator inference functions.
struct InferenceContext {
  // Returns the attribute with the given name, or nullptr if absent.
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  // Returns the type of input `index`, or nullptr when unknown.
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual bool hasInput(size_t index) const {
    // The default implementation below is used for backward-compatibility
    // for implementations of InferenceContext that don't provide an explicit
    // implementation. This works for normal usage, but may be imprecise in
    // the edge-case where an input is supplied but has no known type.
    // However, inference-methods work only under the assumption that the
    // input-types of all inputs are known.
    return ((index < getNumInputs()) && (getInputType(index) != nullptr));
  }
  // Statically-known value of input `index` (e.g. an initializer), if any.
  virtual const TensorProto* getInputData(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  // Mutable: inference functions record their results here.
  virtual TypeProto* getOutputType(size_t index) = 0;
  // Inferencer for the subgraph held by the named graph-valued attribute.
  virtual GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) = 0;
  virtual ~InferenceContext() {}
  // Sparse counterpart of getInputData.
  virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0;
  // Gets the shape inputs computed by partial data propagation.
  virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0;
};
0109 
0110 // We use data propagation to perform partial evaluation of the model, to compute statically
0111 // known information about tensor values. It is intended to improve the precision of shape
0112 // inference. We reuse TensorShapeProto to represent the statically known values. One
0113 // limitation of this is that TensorShapeProto can represent only integer values.
0114 // As an example, data-propagation is intended to handle code-fragments like below:
0115 //   shape = Shape(X)
0116 //   batchsize = Slice(shape, [0], [1])
0117 //   newshape = Concat (batchsize, [1024, 1024])
0118 //   Z = Reshape(Y, newshape)
0119 // If the shape of X is statically known, then data-propagation should be able to determine
0120 // the value of newshape, as well as the shape of Z.
// Abstract view of a node during data propagation (see comment above):
// read-only attributes/types plus propagated values encoded as
// TensorShapeProto (integer-only by construction).
struct DataPropagationContext {
  // Returns the attribute with the given name, or nullptr if absent.
  virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
  virtual size_t getNumInputs() const = 0;
  virtual const TypeProto* getInputType(size_t index) const = 0;
  virtual size_t getNumOutputs() const = 0;
  virtual const TypeProto* getOutputType(size_t index) const = 0;
  virtual ~DataPropagationContext() {}
  // Propagated (partially evaluated) value of input `index`, if known.
  virtual const TensorShapeProto* getInputData(size_t index) = 0;
  // Records the propagated value of output `index`.
  virtual void addOutputData(size_t index, TensorShapeProto&& tp) = 0;
};
0131 
// Signatures of the per-operator inference / data-propagation callbacks
// registered with operator schemas.
using InferenceFunction = std::function<void(InferenceContext&)>;
using DataPropagationFunction = std::function<void(DataPropagationContext&)>;

// This no-op inference function is used for operators without an
// inference implementation.
inline void dummyInferenceFunction(InferenceContext&){};

// This no-op data propagation function is used for operators without a defined data propagator
inline void dummyDataPropagationFunction(DataPropagationContext&){};
0141 
0142 template <typename T>
0143 inline bool getRepeatedAttribute(InferenceContext& ctx, std::string attr_name, std::vector<T>& values) {
0144   const auto* attr = ctx.getAttribute(attr_name);
0145   if (attr) {
0146     values = RetrieveValues<T>(*attr);
0147     return true;
0148   } else {
0149     return false;
0150   }
0151 }
0152 
0153 inline int64_t getAttribute(InferenceContext& ctx, const std::string& attributeName, int64_t defaultValue) {
0154   auto attr_proto = ctx.getAttribute(attributeName);
0155   if ((nullptr != attr_proto) && attr_proto->has_i())
0156     return attr_proto->i();
0157   return defaultValue;
0158 }
0159 
0160 inline int64_t getAttribute(DataPropagationContext& ctx, const std::string& attributeName, int64_t defaultValue) {
0161   auto attr_proto = ctx.getAttribute(attributeName);
0162   if ((nullptr != attr_proto) && attr_proto->has_i())
0163     return attr_proto->i();
0164   return defaultValue;
0165 }
0166 
0167 inline std::string
0168 getAttribute(InferenceContext& ctx, const std::string& attributeName, const std::string& defaultValue) {
0169   auto attr_proto = ctx.getAttribute(attributeName);
0170   if ((nullptr != attr_proto) && attr_proto->has_s())
0171     return attr_proto->s();
0172   return defaultValue;
0173 }
0174 
0175 inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, TensorShapeProto::Dimension dim2) {
0176   TensorShapeProto::Dimension result;
0177   if (dim1.has_dim_value() && dim2.has_dim_value()) {
0178     result.set_dim_value(dim1.dim_value() * dim2.dim_value());
0179   } else if (dim1.has_dim_value() && (dim1.dim_value() == 1)) {
0180     return dim2;
0181   } else if (dim2.has_dim_value() && (dim2.dim_value() == 1)) {
0182     return dim1;
0183   }
0184   return result;
0185 }
0186 
// Renders `elements` as a string (declaration only; defined elsewhere).
template <typename Container>
std::string stringify(const Container& elements);

// Looks up attributes named in `attribute_names` and returns an
// (element type, length) pair (declaration only; defined elsewhere).
std::pair<int, int> getAttributeElementTypeAndLength(
    const InferenceContext& ctx,
    const std::initializer_list<std::string>& attribute_names);
0193 
0194 inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, int64_t dim2) {
0195   TensorShapeProto::Dimension result;
0196   if (dim1.has_dim_value()) {
0197     result.set_dim_value(dim1.dim_value() * dim2);
0198   } else if (dim2 == 1) {
0199     return dim1;
0200   }
0201   return result;
0202 }
0203 
0204 inline TensorShapeProto::Dimension operator/(TensorShapeProto::Dimension dim1, int64_t dim2) {
0205   TensorShapeProto::Dimension result;
0206   if (dim1.has_dim_value()) {
0207     result.set_dim_value(dim1.dim_value() / dim2);
0208   } else if (dim2 == 1) {
0209     return dim1;
0210   }
0211   return result;
0212 }
0213 
// Returns the product of shape.dim(from) .. shape.dim(upto_exclusive - 1),
// using the symbolic-aware operator* above (an unknown factor other than 1
// makes the result unknown).
// if from >= upto_exclusive, return 1.
// Caller must make sure upto_exclusive is less than or equal to shape.size()
// Caller must make sure from>=0
inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
  TensorShapeProto::Dimension dim;
  // Multiplicative identity.
  dim.set_dim_value(1);
  for (int i = from; i < upto_exclusive; ++i) {
    dim = dim * shape.dim(i);
  }
  return dim;
}
0225 
0226 inline int32_t getTensorElementType(const TypeProto& type) {
0227   int32_t result = TensorProto::UNDEFINED;
0228   const auto value_case = type.value_case();
0229   if (value_case == TypeProto::kTensorType) {
0230     result = type.tensor_type().elem_type();
0231   } else if (value_case == TypeProto::kSparseTensorType) {
0232     result = type.sparse_tensor_type().elem_type();
0233   }
0234   return result;
0235 }
0236 
0237 inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_case, TypeProto& type) {
0238   if (value_case == TypeProto::kTensorType) {
0239     type.mutable_tensor_type()->set_elem_type(elem_type);
0240   } else if (value_case == TypeProto::kSparseTensorType) {
0241     type.mutable_sparse_tensor_type()->set_elem_type(elem_type);
0242   }
0243 }
0244 
// Declarations only; definitions live in the corresponding .cc file.
// Copies the element type of input_type to output_type with validation.
void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type);

// Copies the element type of input `inputIndex` to output `outputIndex`.
void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);

// Tensor-specific variant of the above.
void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);
0250 
0251 inline void propagateElemTypeFromDtypeToOutput(
0252     InferenceContext& ctx,
0253     const int data_type,
0254     size_t outputIndex,
0255     TypeProto::ValueCase expected_value_case) {
0256   const auto attribute_tensor_datatype = data_type;
0257   auto output_type = ctx.getOutputType(outputIndex);
0258   const auto output_value_case = output_type->value_case();
0259   if (output_value_case == TypeProto::VALUE_NOT_SET || output_value_case == expected_value_case) {
0260     setTensorElementType(attribute_tensor_datatype, expected_value_case, *output_type);
0261   } else {
0262     // This is not expected to happen
0263     fail_type_inference(
0264         "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case);
0265   }
0266 }
0267 
// Convenience overload defaulting the expected output kind to dense tensor.
inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const int data_type, size_t outputIndex) {
  propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, TypeProto::kTensorType);
}
0271 
0272 inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const AttributeProto* attr, size_t outputIndex) {
0273   int32_t data_type = TensorProto::UNDEFINED;
0274   TypeProto::ValueCase expected_value_case = TypeProto::VALUE_NOT_SET;
0275   const auto attr_type = attr->type();
0276   if (attr_type == AttributeProto::TENSOR) {
0277     if (attr->t().dims().size() != 1) {
0278       fail_type_inference("Attribute expected to have a one-dim tensor");
0279     }
0280     data_type = attr->t().data_type();
0281     expected_value_case = TypeProto::kTensorType;
0282   } else if (attr_type == AttributeProto::SPARSE_TENSOR) {
0283     if (attr->sparse_tensor().dims().size() != 1) {
0284       fail_type_inference("Attribute expected to have a one-dim sparse tensor");
0285     }
0286     data_type = attr->sparse_tensor().values().data_type();
0287     expected_value_case = TypeProto::kSparseTensorType;
0288   } else {
0289     fail_type_inference("Attribute expected to have tensor or sparse tensor type");
0290   }
0291 
0292   propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case);
0293 }
0294 
0295 inline bool hasShape(const TypeProto& type) {
0296   if (type.has_tensor_type()) {
0297     return type.tensor_type().has_shape();
0298   } else if (type.has_sparse_tensor_type()) {
0299     return type.sparse_tensor_type().has_shape();
0300   } else if (type.has_sequence_type() && type.sequence_type().has_elem_type()) {
0301     return hasShape(type.sequence_type().elem_type());
0302   } else if (type.has_optional_type() && type.optional_type().has_elem_type()) {
0303     return hasShape(type.optional_type().elem_type());
0304   }
0305   return false;
0306 }
0307 
// True when input `n` exists, has a known type, and that type has a shape.
// Fix: dropped the redundant static_cast<size_t>(n) — `n` is already size_t.
template <typename Context>
inline bool hasInputShape(const Context& ctx, size_t n) {
  return ctx.getNumInputs() > n && ctx.getInputType(n) && hasShape(*ctx.getInputType(n));
}
0312 
// True when each of the first `n` inputs has shape information.
template <typename Context>
inline bool hasNInputShapes(const Context& ctx, size_t n) {
  bool all_present = true;
  // Stop probing as soon as one input lacks a shape.
  for (size_t i = 0; i < n && all_present; ++i) {
    all_present = hasInputShape(ctx, i);
  }
  return all_present;
}
0322 
0323 inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t n) {
0324   const auto* input_type = ctx.getInputType(n);
0325   const auto value_case = input_type->value_case();
0326   if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
0327     fail_type_inference("Attribute expected to have tensor or sparse tensor type");
0328   }
0329   if (value_case == TypeProto::kTensorType) {
0330     return input_type->tensor_type().shape();
0331   } else {
0332     return input_type->sparse_tensor_type().shape();
0333   }
0334 }
0335 
0336 inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size_t n) {
0337   const auto* input_type = ctx.getInputType(n);
0338 
0339   if (input_type == nullptr) {
0340     return nullptr;
0341   }
0342 
0343   const auto value_case = input_type->value_case();
0344   if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
0345     fail_type_inference("Attribute expected to have tensor or sparse tensor type");
0346   }
0347   if (value_case == TypeProto::kTensorType) {
0348     return &input_type->tensor_type().shape();
0349   } else {
0350     return &input_type->sparse_tensor_type().shape();
0351   }
0352 }
0353 
// Appends dimension `fromDimIndex` of input `inputIndex`'s shape to the end
// of output `outputIndex`'s shape. Input and output must share the same
// value case (dense or sparse tensor).
// Caller must make sure fromDimIndex is strictly less than shape.dim_size()
inline void appendSingleDimCopiedFromInputTypeToOutputType(
    InferenceContext& ctx,
    size_t inputIndex,
    size_t outputIndex,
    size_t fromDimIndex) {
  auto output_type = ctx.getOutputType(outputIndex);
  const auto output_value_case = output_type->value_case();
  auto input_type = ctx.getInputType(inputIndex);
  const auto input_value_case = input_type->value_case();
  if (output_value_case != input_value_case) {
    fail_type_inference(
        "Input: ",
        inputIndex,
        " type: ",
        input_value_case,
        " does not match type of output: ",
        outputIndex,
        "type: ",
        output_value_case);
  }
  if (TypeProto::kTensorType == input_value_case) {
    auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else if (TypeProto::kSparseTensorType == input_value_case) {
    auto* dim = output_type->mutable_sparse_tensor_type()->mutable_shape()->add_dim();
    *dim = input_type->sparse_tensor_type().shape().dim(static_cast<int>(fromDimIndex));
  } else {
    fail_type_inference(
        "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type");
  }
}
0386 
// Copies shape information from `from_type` to `to_type`, recursing through
// sequence/optional/map element types. Fails when the two value cases
// differ or are unsupported. Leaves the output shape untouched (i.e.
// "unknown") when the source has no shape.
inline void propagateShape(const TypeProto* from_type, TypeProto* to_type) {
  const auto from_type_case = from_type->value_case();
  const auto to_type_case = to_type->value_case();
  if (from_type_case != to_type_case) {
    fail_shape_inference(
        "Mismatch between inferred and declared type. Inferred=", from_type_case, " Declared=", to_type_case);
  }

  if (TypeProto::kTensorType == from_type_case || TypeProto::kSparseTensorType == from_type_case) {
    // If input shape is "unknown", the corresponding should be "unknown" too.
    // The way to make output shape unknown is not to assign it any value.
    if (hasShape(*from_type)) {
      if (TypeProto::kTensorType == from_type_case) {
        *to_type->mutable_tensor_type()->mutable_shape() = from_type->tensor_type().shape();
      } else {
        *to_type->mutable_sparse_tensor_type()->mutable_shape() = from_type->sparse_tensor_type().shape();
      }
    }
  } else if (TypeProto::kSequenceType == from_type_case) {
    propagateShape(&from_type->sequence_type().elem_type(), to_type->mutable_sequence_type()->mutable_elem_type());
  } else if (TypeProto::kOptionalType == from_type_case) {
    propagateShape(&from_type->optional_type().elem_type(), to_type->mutable_optional_type()->mutable_elem_type());
  } else if (TypeProto::kMapType == from_type_case) {
    // For maps, shape information lives on the value type.
    propagateShape(&from_type->map_type().value_type(), to_type->mutable_map_type()->mutable_value_type());
  } else {
    fail_shape_inference("Unsupported Source/Target type=", from_type_case);
  }
}
0415 
0416 inline void propagateShapeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
0417   auto output_type = ctx.getOutputType(outputIndex);
0418   auto input_type = ctx.getInputType(inputIndex);
0419 
0420   propagateShape(input_type, output_type);
0421 }
0422 
0423 inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) {
0424   propagateElemTypeFromInputToOutput(ctx, 0, 0);
0425   if (!hasNInputShapes(ctx, 1)) {
0426     return;
0427   }
0428   propagateShapeFromInputToOutput(ctx, 0, 0);
0429 }
0430 
0431 inline void
0432 updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType, TypeProto::ValueCase expected_type) {
0433   auto output_type = ctx.getOutputType(outputIndex);
0434   if (output_type == nullptr) {
0435     fail_type_inference("Output ", outputIndex, " is null");
0436   }
0437   if (output_type->value_case() == expected_type || output_type->value_case() == TypeProto::VALUE_NOT_SET) {
0438     setTensorElementType(elemType, expected_type, *output_type);
0439   } else {
0440     // This is not expected to happen
0441     fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type: ", expected_type);
0442   }
0443 }
0444 
// Convenience overload defaulting the expected output kind to dense tensor.
inline void updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType) {
  updateOutputElemType(ctx, outputIndex, elemType, TypeProto::kTensorType);
}
0448 
// Infer type of an output from the value of a specified attribute, which is
// expected to have a valid value representing a TensorProto_DataType.
// If the attribute is absent, `default_value` is used when provided;
// otherwise inference fails.
inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TypeProto::ValueCase expected_type,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  auto attr_proto = ctx.getAttribute(attributeName);
  if (nullptr == attr_proto) { // attribute not present
    if (default_value != TensorProto::UNDEFINED) {
      updateOutputElemType(ctx, outputIndex, default_value, expected_type);
      return;
    } else {
      // fail_type_inference raises, so attr_proto is non-null past this block
      // (assumes ONNX_THROW_EX does not return — TODO confirm for
      // exception-disabled builds).
      fail_type_inference("Value of attribute ", attributeName, " not specified");
    }
  }
  if (!attr_proto->has_i()) {
    fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type.");
  }
  auto attr_value = attr_proto->i();
  auto elem_type = static_cast<TensorProto_DataType>(attr_value);
  if (!TensorProto_DataType_IsValid(elem_type)) {
    fail_type_inference("Attribute ", attributeName, " does not specify a valid type.");
  }
  updateOutputElemType(ctx, outputIndex, elem_type, expected_type);
}
0476 
// Convenience overload defaulting the expected output kind to dense tensor.
inline void propagateElemTypeFromAttributeToOutput(
    InferenceContext& ctx,
    const std::string& attributeName,
    size_t outputIndex,
    TensorProto_DataType default_value = TensorProto::UNDEFINED) {
  propagateElemTypeFromAttributeToOutput(ctx, attributeName, outputIndex, TypeProto::kTensorType, default_value);
}
0484 
0485 inline TensorShapeProto* getTensorMutableShape(TypeProto::ValueCase value_case, TypeProto& type) {
0486   if (value_case == TypeProto::kTensorType) {
0487     return type.mutable_tensor_type()->mutable_shape();
0488   } else if (value_case == TypeProto::kSparseTensorType) {
0489     return type.mutable_tensor_type()->mutable_shape();
0490   }
0491   return nullptr;
0492 }
0493 
0494 inline TensorShapeProto*
0495 getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0496   auto output_type = ctx.getOutputType(n);
0497   if (output_type == nullptr) {
0498     fail_type_inference("Output ", n, " expected to have tensor or sparse type");
0499   }
0500   const auto output_value_case = output_type->value_case();
0501   if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
0502     return getTensorMutableShape(output_value_case, *output_type);
0503   } else if (output_value_case == TypeProto::VALUE_NOT_SET) {
0504     return getTensorMutableShape(default_type, *output_type);
0505   } else {
0506     fail_type_inference("Output ", n, " expected to have tensor type");
0507   }
0508 }
0509 
// Appends a dimension with the given concrete value to `shape`.
inline void appendDim(TensorShapeProto* shape, int64_t dim_value) {
  shape->add_dim()->set_dim_value(dim_value);
}
0513 
0514 inline void updateOutputShape(
0515     InferenceContext& ctx,
0516     size_t outputIndex,
0517     const TensorShapeProto& shape,
0518     TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0519   auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
0520   *output_shape = shape;
0521 }
0522 
0523 inline void updateOutputShape(
0524     InferenceContext& ctx,
0525     size_t outputIndex,
0526     const TensorProto& tensorProto,
0527     TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0528   auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
0529   for (auto d : tensorProto.dims()) {
0530     auto* dim = output_shape->add_dim();
0531     dim->set_dim_value(d);
0532   }
0533 }
0534 
0535 inline void updateOutputShape(
0536     InferenceContext& ctx,
0537     size_t outputIndex,
0538     std::initializer_list<TensorShapeProto::Dimension> dims,
0539     TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0540   auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
0541   for (auto& d : dims) {
0542     auto* dim = output_shape->add_dim();
0543     *dim = d;
0544   }
0545 }
0546 
// Get shape input by first checking initializer and then propagated symbolic data.
// If neither is available, try rank inference.
// When one of above succeeds, `true` is stored in `found`.
// Otherwise, `false` is stored, which means that returned TensorShapeProto does not make sense.
// (Declaration only; defined elsewhere.)
TensorShapeProto getShapeInput(const InferenceContext& ctx, size_t input_index, bool& found);
0552 
0553 // Infer shape of an output from the value of a specified attribute, which is
0554 // expected to be a list of integers specifying a valid shape.
0555 inline void propagateShapeFromAttributeToOutput(
0556     InferenceContext& ctx,
0557     const std::string& attributeName,
0558     size_t outputIndex,
0559     TypeProto::ValueCase default_type = TypeProto::kTensorType) {
0560   auto attr_proto = ctx.getAttribute(attributeName);
0561   if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
0562       (attr_proto->type() != AttributeProto_AttributeType_INTS)) {
0563     fail_shape_inference("Attribute ", attributeName, " should specify a shape");
0564   }
0565   auto& int_list = attr_proto->ints();
0566   TensorShapeProto shape;
0567   for (auto dim_size : int_list) {
0568     if (dim_size < 0) {
0569       fail_shape_inference("Negative values are not allowed in a shape specification");
0570     }
0571     shape.add_dim()->set_dim_value(dim_size);
0572   }
0573 
0574   updateOutputShape(ctx, outputIndex, shape, default_type);
0575 }
0576 
// Computes the result of multidirectional (numpy-style) broadcasting over
// `shapes`, appending one dimension at a time to `resultShape`. Dimensions
// are aligned from the right; shorter shapes are implicitly padded with 1s.
// Per output dimension: concrete values must agree or be 1; a single
// distinct symbolic dim is propagated; conflicting symbolic dims yield an
// anonymous (unset) dim.
inline void multidirectionalBroadcastShapeInference(
    const std::vector<const TensorShapeProto*>& shapes,
    TensorShapeProto& resultShape) {
  int result_shape_size = 0;
  // Get the result shape size.
  for (size_t i = 0; i < shapes.size(); ++i) {
    if (shapes[i]->dim_size() > result_shape_size) {
      result_shape_size = shapes[i]->dim_size();
    }
  }

  for (int i = 0; i < result_shape_size; ++i) {
    int64_t dim_value = 1;
    TensorShapeProto_Dimension symbolic_dim;
    // Counts *distinct* symbolic dims seen at this position (same dim_param
    // repeated counts once).
    int num_symbolic_dims = 0;
    for (size_t j = 0; j < shapes.size(); ++j) {
      if (i < result_shape_size - shapes[j]->dim_size()) {
        // Shape j will be filled with 1 at dimension i;
        continue;
      }

      auto dim_i_j = shapes[j]->dim(i - result_shape_size + shapes[j]->dim_size());
      if (dim_i_j.has_dim_value()) {
        if (dim_i_j.dim_value() != 1) {
          if (dim_value != dim_i_j.dim_value() && dim_value != 1) {
            fail_shape_inference("Incompatible dimensions");
          } else {
            dim_value = dim_i_j.dim_value();
          }
        }
      } else {
        if (num_symbolic_dims == 0) {
          symbolic_dim = dim_i_j;
          ++num_symbolic_dims;
        } else if (dim_i_j.dim_param() != symbolic_dim.dim_param()) {
          ++num_symbolic_dims;
        }
      }
    }

    if (dim_value != 1 || num_symbolic_dims == 0) {
      // A concrete dimension (possibly 1) wins, or everything was 1.
      resultShape.add_dim()->set_dim_value(dim_value);
    } else if (num_symbolic_dims == 1) {
      // Exactly one symbolic dim (and all concrete dims were 1): keep it.
      *resultShape.add_dim() = symbolic_dim;
    } else {
      // Multiple conflicting symbolic dims: emit an unknown dim.
      resultShape.add_dim();
    }
  }
}
0626 
0627 inline void bidirectionalBroadcastShapeInference(
0628     const TensorShapeProto& shapeL,
0629     const TensorShapeProto& shapeR,
0630     TensorShapeProto& resultShape) {
0631   std::vector<const TensorShapeProto*> shapes;
0632   shapes.push_back(&shapeL);
0633   shapes.push_back(&shapeR);
0634   multidirectionalBroadcastShapeInference(shapes, resultShape);
0635 }
0636 
0637 /*
0638 Merge the dimension information from two TensorShapeProto_Dimension instances.
0639 Values are merged into target from source.
0640 If target has no dimension information, copy from source.
0641 If source has no dimension information, ignore source.
0642 If both have dimension information:
0643  - Prefer values over params. If both have values, values must match.
0644  - Prefer target param over source param if mismatched.
0645 Fail if there are mismatches in dimension values.
0646 Currently, there is no way to refine/update dimension information for the
0647 source from information available in the target.
0648 */
0649 inline void mergeInDimensionInfo(
0650     const TensorShapeProto_Dimension& source_dim,
0651     TensorShapeProto_Dimension& target_dim,
0652     int dim_index) {
0653   // if source has value, merge into target
0654   // else if target has value, preserve it
0655   // else merge params
0656   if (source_dim.has_dim_value()) {
0657     auto source_value = source_dim.dim_value();
0658     if (target_dim.has_dim_value()) {
0659       auto target_value = target_dim.dim_value();
0660       if (target_value != source_value) {
0661         fail_shape_inference(
0662             "Can't merge shape info. "
0663             "Both inferred and declared dimension have values but they differ. Inferred=",
0664             source_value,
0665             " Declared=",
0666             target_value,
0667             " Dimension=",
0668             dim_index);
0669       }
0670     } else {
0671       target_dim.set_dim_value(source_value);
0672     }
0673   } else if (target_dim.has_dim_value()) {
0674     // if target has a value we preserve it so do nothing
0675   } else if (target_dim.has_dim_param()) {
0676     // prefer target param over source
0677   } else if (source_dim.has_dim_param()) {
0678     target_dim.set_dim_param(source_dim.dim_param());
0679   }
0680 }
0681 
// Merges `source_shape` into the shape of `target_type` (dense variant).
// Declarations only; defined elsewhere.
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);

// Sparse-tensor variant of the above.
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);

/*
Merge the shape information from two TypeProto_Tensor instances.
Values are merged into target from source.
If target has no shape information, copy from source.
If source has no shape information, ignore source.
If both have shape information:
- merge each TensorShapeProto_Dimension separately.
- Prefer values over params. If both have values, values must match.
- Prefer target param over source param if mismatched.
Fail if there are mismatches in number of dimensions or dimension values.
*/
void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target);

void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target);
0700 
0701 // Return a copy of a type, with a specified dimension removed from its shape.
0702 inline TypeProto RemoveIthDimensionFromShape(const TypeProto& proto, int removed_dim) {
0703   TypeProto t(proto);
0704   auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
0705   mutable_shape->clear_dim();
0706 
0707   const auto& dims = proto.tensor_type().shape().dim();
0708 
0709   for (int j = 0, end = dims.size(); j < end; ++j) {
0710     if (j != removed_dim)
0711       (*mutable_shape->add_dim()) = dims.Get(j);
0712   }
0713 
0714   return t;
0715 }
0716 
0717 // Return a copy of a type, with specified number of dimensions removed from the
0718 // beginning.
0719 inline TypeProto RemoveDimensionsFromShape(const TypeProto& proto, int num_dimensions) {
0720   TypeProto t(proto);
0721   auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
0722   mutable_shape->clear_dim();
0723 
0724   const auto& dims = proto.tensor_type().shape().dim();
0725 
0726   // skip first num_dimensions
0727   for (int j = num_dimensions, end = dims.size(); j < end; ++j) {
0728     (*mutable_shape->add_dim()) = dims.Get(j);
0729   }
0730 
0731   return t;
0732 }
0733 
0734 // copied from GSL:
0735 // https://github.com/microsoft/GSL/blob/main/include/gsl/util
0736 template <class T, class U>
0737 static constexpr T narrow_cast(U&& u) noexcept {
0738   return static_cast<T>(std::forward<U>(u));
0739 }
0740 
0741 inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expected_rank) {
0742   // We check the rank only if a rank is known for the input:
0743   if (hasInputShape(ctx, input_index)) {
0744     auto rank = getInputShape(ctx, input_index).dim_size();
0745     if (rank != expected_rank) {
0746       fail_shape_inference("Input ", input_index, " expected to have rank ", expected_rank, " but has rank ", rank);
0747     }
0748   }
0749 }
0750 
0751 // Unification (between dimensions and/or shapes) is at the heart of
0752 // shape-inference. The current inference algorithm can check input
0753 // shapes/dimensions of a node and update the output shapes/dimensions. It
0754 // cannot currently update input shapes and dimensions (even though in some
0755 // contexts this inference is possible). Hence, we have the variants below to
0756 // support "const" and "mutable" dimensions/shapes in unification.
0757 
0758 inline void checkDimEquality(int64_t value1, int64_t value2) {
0759   if (value1 != value2) {
0760     fail_shape_inference("Dimension mismatch in unification between ", value1, " and ", value2);
0761   }
0762 }
0763 
0764 inline void unifyDim(const Dim& dim1, const Dim& dim2) {
0765   if (dim1.has_dim_value() && dim2.has_dim_value())
0766     checkDimEquality(dim1.dim_value(), dim2.dim_value());
0767 }
0768 
0769 // TODO: The functionality of unifyDim is similar to that of
0770 // mergeInDimensionInfo. However, the error messages are different. Leaving this
0771 // duplication in-place to preserve error message content.
0772 inline void unifyDim(const Dim& source_dim, Dim& target_dim) {
0773   if (source_dim.has_dim_value()) {
0774     auto source_value = source_dim.dim_value();
0775     if (target_dim.has_dim_value()) {
0776       auto target_value = target_dim.dim_value();
0777       checkDimEquality(source_value, target_value);
0778     } else {
0779       target_dim.set_dim_value(source_value);
0780     }
0781   } else if (target_dim.has_dim_value()) {
0782     // if target has a value we preserve it.
0783     // we cannot set source dim value.
0784   } else if (target_dim.has_dim_param()) {
0785     // prefer target param over source
0786     // we cannot currently unify the dim_params
0787   } else if (source_dim.has_dim_param()) {
0788     target_dim.set_dim_param(source_dim.dim_param());
0789   }
0790 }
0791 
0792 inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_index, Dim& dim) {
0793   // We unify the dimensions only if it is available for specified input:
0794   if (hasInputShape(ctx, input_index)) {
0795     auto& input_shape = getInputShape(ctx, input_index);
0796     // This shape is expected to have rank > dim_index:
0797     if (input_shape.dim_size() <= dim_index) {
0798       fail_shape_inference(
0799           "Input ", input_index, " expected to have rank >", dim_index, " but has rank ", input_shape.dim_size());
0800     }
0801     const Dim& input_dim = input_shape.dim(dim_index);
0802     // Now, unify dim and input_dim:
0803     unifyDim(input_dim, dim);
0804   }
0805 }
0806 
0807 // unifyDim: unifies a dimension with a constant value. If the dimension
0808 // already has a value, we check for equality of new value with old value.
0809 inline void unifyDim(Dim& dim, int64_t value) {
0810   if (dim.has_dim_value()) {
0811     checkDimEquality(dim.dim_value(), value);
0812   } else
0813     dim.set_dim_value(value);
0814 }
0815 
0816 // target-shape = Union (target-shape, source_shape)
0817 // Example 1: same rank, different dimensions
0818 //    input1 shape: (2, 3, 4, 'x')
0819 //    input2 shape: (2, 'y', 5, 'x')
0820 //    output shape: (2, None, None, 'x')
0821 // Example 2: different rank
0822 //    input1 shape: (2, 3, 4, 'x')
0823 //    input2 shape: (2, 3, 4)
0824 //    output shape: None
0825 void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);
0826 
0827 void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);
0828 
0829 // target-type = Union (target-type, source-type)
0830 // target and source are required to have the same type.
0831 // Example 1: same tensor type, different shape
0832 //    source: tensor elem_type: int64, shape: (2, 3, 4, 'x')
0833 //    target: tensor elem_type: int64, shape: (2, 'y', 5, 'x')
0834 //    output: tensor elem_type: int64, shape: (2, None, None, 'x')
0835 // Example 2: same sequence type, different shape
0836 //    source: sequence of tensor, elem_type: float, shape: (2, 3, 4)
0837 //    target: sequence of tensor, elem_type: float, shape: None
0838 //    output: sequence of tensor, elem_type: float, shape: None
0839 void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type);
0840 
0841 // adjustNegativeAxes: Negative axes values are translated to the right axis in the positive range
template <typename Axes>
void adjustNegativeAxes(Axes& axes, int rank) {
  // Translate each negative axis into its equivalent positive index; axes
  // that are already non-negative are left unchanged.
  for (auto& axis : axes) {
    if (axis < 0) {
      axis += rank;
    }
  }
}
0847 
0848 // checkAxesRange: Checks that values are within the range [-rank, rank)
template <typename Axes>
void checkAxesRange(Axes& axes, int rank) {
  // Every axis must lie in the half-open interval [-rank, rank).
  for (const auto axis : axes) {
    const bool in_range = (axis >= -rank) && (axis <= rank - 1);
    if (!in_range) {
      fail_shape_inference("Unexpected axis value: ", axis, ". Expected range [", -rank, ", ", rank, ")");
    }
  }
}
0856 
0857 // checkDuplicateAxes: Check that there are no duplicated axes
0858 template <typename Axes>
0859 void checkDuplicateAxes(Axes& axes, int rank) {
0860   std::vector<bool> tmp(rank, false);
0861   for (auto axis : axes) {
0862     int actual_axis = axis < 0 ? axis + rank : axis;
0863     if (tmp[actual_axis])
0864       fail_shape_inference("Axis ", axis, " is referred to more than once.");
0865     tmp[actual_axis] = true;
0866   }
0867 }
0868 
0869 } // namespace ONNX_NAMESPACE