// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
namespace shape_inference {

using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>;

// We reuse TensorShapeProto to propagate statically known (partial) information about
// the values of tensors. It is intended for tensors used to store shape information
// (the return values of ops like Shape and input values of ops like Reshape/Expand).

// A DataValueMap is used to store the statically known (partial) values of variables.
using DataValueMap = std::unordered_map<std::string, TensorShapeProto>;

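// Illustrative sketch (not part of this header): a statically known tensor
// value is encoded dimension-by-dimension in a TensorShapeProto, mixing
// concrete and symbolic entries. For example, if Shape(X) is known to yield
// [batch, 8, 128] where only "batch" is symbolic (the names below are made up):
//
//   TensorShapeProto value;
//   value.add_dim()->set_dim_param("batch"); // statically unknown entry
//   value.add_dim()->set_dim_value(8);       // statically known entry
//   value.add_dim()->set_dim_value(128);
//   DataValueMap known_values;
//   known_values["shape_of_X"] = value;      // keyed by the value's name
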
class SymbolTableImpl : public SymbolTable {
 public:
  SymbolTableImpl() : index_(0) {}

  void addFromGraph(const GraphProto& g) override {
    AddExistingSymbolicDims(g.input());
    AddExistingSymbolicDims(g.output());
    AddExistingSymbolicDims(g.value_info());
  }
  // Creates a new unique symbol with the given prefix and adds it to the SymbolTable
  // Returns the newly created symbol
  std::string createNew(const std::string& symbol_prefix) override {
    std::string newSymbol;
    do {
      newSymbol = symbol_prefix + std::to_string(index_++);
    } while (existing_symbols.count(newSymbol) > 0);
    existing_symbols.insert(newSymbol);
    return newSymbol;
  }

 private:
  unsigned int index_;
  std::unordered_set<std::string> existing_symbols;

  // TypeProto_Tensor or TypeProto_SparseTensor
  template <typename TensorTypeProto>
  void AddExistingSymbolicDims(const TensorTypeProto& tensorType) {
    if (tensorType.has_shape()) {
      for (int i = 0; i < tensorType.shape().dim_size(); ++i) {
        if (tensorType.shape().dim(i).has_dim_param()) {
          existing_symbols.insert(tensorType.shape().dim(i).dim_param());
        }
      }
    }
  }

  void AddExistingSymbolicDims(const TypeProto& typeProto) {
    const auto val_case = typeProto.value_case();
    switch (val_case) {
      case TypeProto::kTensorType:
        AddExistingSymbolicDims(typeProto.tensor_type());
        break;
      case TypeProto::kSparseTensorType:
        AddExistingSymbolicDims(typeProto.sparse_tensor_type());
        break;
      case TypeProto::kSequenceType:
        AddExistingSymbolicDims(typeProto.sequence_type().elem_type());
        break;
      case TypeProto::kOptionalType:
        AddExistingSymbolicDims(typeProto.optional_type().elem_type());
        break;
      case TypeProto::kMapType:
        AddExistingSymbolicDims(typeProto.map_type().value_type());
        break;
      default:
        break;
    }
  }

  void AddExistingSymbolicDims(const google::protobuf::RepeatedPtrField<ValueInfoProto>& protos) {
    for (const auto& proto : protos) {
      AddExistingSymbolicDims(proto.type());
    }
  }
};
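
// Usage sketch (illustrative): seed the table with the symbolic dims already
// present in a graph, then mint fresh names that cannot collide with them.
// The prefix "unk__" is only an example, not a convention mandated here.
//
//   SymbolTableImpl symbols;
//   symbols.addFromGraph(graph);                  // records existing dim_params
//   std::string s0 = symbols.createNew("unk__");  // e.g. "unk__0"
//   std::string s1 = symbols.createNew("unk__");  // e.g. "unk__1", never reused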

struct GraphInferenceContext {
  GraphInferenceContext(
      const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
      const std::unordered_map<std::string, int> opset_imports_in,
      SymbolTable* symbol_table_in = nullptr,
      const ModelLocalFunctionsMap& model_local_functions_in = {},
      const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
      DataValueMap* generated_shape_data_by_name_in = nullptr,
      const int ir_version_in = IR_VERSION)
      : outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in},
        opset_imports{opset_imports_in},
        symbol_table{symbol_table_in},
        model_local_functions{model_local_functions_in},
        schema_registry{schema_registry_in},
        generated_shape_data_by_name{generated_shape_data_by_name_in},
        ir_version{ir_version_in} {}

  const std::unordered_map<std::string, TypeProto*>* outer_scope_value_types_by_name;
  const std::unordered_map<std::string, int> opset_imports;
  SymbolTable* symbol_table;
  const ModelLocalFunctionsMap& model_local_functions;
  const ISchemaRegistry* schema_registry;
  DataValueMap* generated_shape_data_by_name;
  const int ir_version;
};
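
// Construction sketch (illustrative): a minimal context for a top-level graph.
// The opset map below is hypothetical; real callers derive it from the model's
// opset_import field.
//
//   std::unordered_map<std::string, TypeProto*> outer_types; // outer-scope values
//   std::unordered_map<std::string, int> opsets = {{"", 21}};
//   SymbolTableImpl symbols;
//   GraphInferenceContext graph_ctx(outer_types, opsets, &symbols);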

class GraphInferencerImpl : public GraphInferencer {
 public:
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context) : g_{&g}, context_{&context}, options_() {}
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context, const ShapeInferenceOptions& options)
      : g_{&g}, context_{&context}, options_(options) {}

  std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) override;

 private:
  GraphProto* g_;
  GraphInferenceContext* context_;
  ShapeInferenceOptions options_;
};

struct InferenceContextImpl : public InferenceContext {
  InferenceContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      const std::unordered_map<std::string, const SparseTensorProto*>& inputSparseDataByName,
      const ShapeInferenceOptions& options,
      DataValueMap* generatedShapeData = nullptr,
      GraphInferenceContext* graphInferenceContext = nullptr)
      : graphInferenceContext_{graphInferenceContext}, options_(options), node_(&n) {
    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
      if (attr.has_g()) {
        // need a mutable GraphProto to run inferencing on this attribute
        graphProtoAttributesByName_[attr.name()] = attr.mutable_g();
      }
    }

    for (const auto& input : n.input()) {
      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      // Input data can live in one of three containers:
      //   inputDataByName       - when the input is a TensorProto
      //   inputSparseDataByName - when the input is a SparseTensorProto
      //   generatedShapeData    - when the input value was generated by partial data propagation
      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
        allInputSparseData_.push_back(nullptr);
        allShapeInputData_.push_back(nullptr);
      } else {
        allInputData_.push_back(nullptr);
        const auto inputSparseDataIter = inputSparseDataByName.find(input);
        if (inputSparseDataIter != inputSparseDataByName.cend()) {
          allInputSparseData_.push_back(inputSparseDataIter->second);
          allShapeInputData_.push_back(nullptr);
        } else {
          allInputSparseData_.push_back(nullptr);
          if (generatedShapeData != nullptr) {
            const auto inputShapeDataIter = generatedShapeData->find(input);
            if (inputShapeDataIter != generatedShapeData->cend()) {
              allShapeInputData_.push_back(&inputShapeDataIter->second);
            } else {
              allShapeInputData_.push_back(nullptr);
            }
          } else {
            allShapeInputData_.push_back(nullptr);
          }
        }
      }
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  const TensorProto* getInputData(size_t index) const override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputData_[index];
  }

  const TensorShapeProto* getSymbolicInput(size_t index) const override {
    if (index >= allShapeInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }

    return allShapeInputData_[index];
  }

  const SparseTensorProto* getInputSparseData(size_t index) const override {
    if (index >= allInputSparseData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputSparseData_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  TypeProto* getOutputType(size_t index) override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  GraphInferencer* getGraphAttributeInferencer(const std::string& attr_name) override {
    if (!graphInferenceContext_) {
      fail_type_inference("GraphProto attribute inferencing is not enabled in this InferenceContextImpl instance.");
    }

    GraphInferencer* inferencer = nullptr;

    auto entry = graphAttributeInferencers_.find(attr_name);
    if (entry == graphAttributeInferencers_.cend()) {
      // create GraphInferencer instance
      auto attrNameToGraphProto = graphProtoAttributesByName_.find(attr_name);
      if (attrNameToGraphProto == graphProtoAttributesByName_.cend()) {
        fail_type_inference("Attribute ", attr_name, " does not contain a graph.");
      }

      std::unique_ptr<GraphInferencer> new_inferencer{
          new GraphInferencerImpl(*attrNameToGraphProto->second, *graphInferenceContext_, options_)};

      inferencer = new_inferencer.get();
      graphAttributeInferencers_.emplace(attr_name, std::move(new_inferencer));
    } else {
      inferencer = entry->second.get();
    }

    return inferencer;
  }

  std::string getDisplayName() const override {
    if (node_ == nullptr)
      return "";
    if (node_->domain().empty()) {
      if (node_->name().empty())
        return MakeString("node ", node_->op_type());
      return MakeString("node ", node_->op_type(), " (", node_->name(), ")");
    }
    if (node_->name().empty())
      return MakeString("node ", node_->op_type(), "[", node_->domain(), "]");
    return MakeString("node ", node_->op_type(), "[", node_->domain(), "]", " (", node_->name(), ")");
  }

  std::vector<const TensorProto*> allInputData_;
  std::vector<const SparseTensorProto*> allInputSparseData_;
  std::vector<const TensorShapeProto*> allShapeInputData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
  std::unordered_map<std::string, GraphProto*> graphProtoAttributesByName_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  GraphInferenceContext* graphInferenceContext_;

  // mutable as internal cache of GraphInferencer instances
  mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_;
  ShapeInferenceOptions options_;
  NodeProto* node_;
};
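
// Consumption sketch (illustrative): a typical per-op type-and-shape inference
// function sees this class only through the InferenceContext interface.
// InferIdentity below is a made-up example, not an ONNX op schema.
//
//   void InferIdentity(InferenceContext& ctx) {
//     if (ctx.getNumInputs() < 1 || ctx.getNumOutputs() < 1)
//       return;
//     const TypeProto* in = ctx.getInputType(0);
//     if (in == nullptr)
//       return;                     // input type unknown; nothing to infer
//     *ctx.getOutputType(0) = *in;  // propagate type and shape unchanged
//   }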

struct DataPropagationContextImpl : public DataPropagationContext {
  DataPropagationContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      DataValueMap& generatedShapeData)
      : generatedShapeData_(generatedShapeData) {
    size_t input_idx = 0;

    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
    }

    for (const auto& input : n.input()) {
      inputIndexToNameMap_.insert({input_idx++, input});

      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
      } else {
        allInputData_.push_back(nullptr);
      }
    }

    size_t output_idx = 0;
    for (const auto& output : n.output()) {
      outputIndexToNameMap_.insert({output_idx++, output});
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  const TypeProto* getOutputType(size_t index) const override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  // Convert integer vector into TensorShapeProto
  template <typename INTEGER>
  void vectorToTensorShapeProto(const std::vector<INTEGER>& input_vals, TensorShapeProto& converted_tsp) const {
    for (unsigned int i = 0; i < input_vals.size(); ++i) {
      converted_tsp.mutable_dim()->Add()->set_dim_value(input_vals[i]);
    }
  }

  const TensorShapeProto* getInputData(size_t index) override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    const std::string input_name = inputIndexToNameMap_.at(index);
    // First, try the result of a previous data-propagation pass
    auto iter = generatedShapeData_.find(input_name);
    if (iter != generatedShapeData_.end()) {
      return &iter->second;
    }
    // Otherwise, fall back to the initializer if one exists
    const auto* input_data = allInputData_[index];
    // Only a scalar (0-D tensor) or a 1-D tensor can be converted for now
    // TODO: support tensors with more dimensions on demand
    if (input_data != nullptr && (input_data->dims_size() == 0 || input_data->dims_size() == 1)) {
      TensorShapeProto tsp;

      if (input_data->data_type() == TensorProto_DataType_INT64) {
        vectorToTensorShapeProto(ParseData<int64_t>(input_data), tsp);
      } else if (input_data->data_type() == TensorProto_DataType_INT32) {
        vectorToTensorShapeProto(ParseData<int32_t>(input_data), tsp);
      } else {
        // Only integer types are supported to form a shape
        return nullptr;
      }

      // Add this TensorShapeProto built from the initializer into generatedShapeData
      // for future use
      auto result = generatedShapeData_.insert({input_name, std::move(tsp)});
      if (result.second) {
        return &(result.first->second);
      }
    }
    return nullptr;
  }

  void addOutputData(size_t index, TensorShapeProto&& tsp) override {
    if (index >= outputIndexToNameMap_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    auto result = generatedShapeData_.insert({outputIndexToNameMap_.at(index), std::move(tsp)});
    if (!result.second) {
      fail_shape_inference("Data for output " + ONNX_NAMESPACE::to_string(index) + " already exists.");
    }
  }

  std::vector<const TensorProto*> allInputData_;
  std::unordered_map<size_t, std::string> inputIndexToNameMap_;
  std::unordered_map<size_t, std::string> outputIndexToNameMap_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  DataValueMap& generatedShapeData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
};
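
// Propagation sketch (illustrative): a data-propagation function for a
// hypothetical pass-through op, written against the DataPropagationContext
// interface this class implements.
//
//   void PropagatePassThrough(DataPropagationContext& ctx) {
//     const TensorShapeProto* in = ctx.getInputData(0);
//     if (in == nullptr)
//       return;                             // value not statically known
//     TensorShapeProto out = *in;           // copy the known (partial) value
//     ctx.addOutputData(0, std::move(out)); // record it for downstream nodes
//   }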

void checkShapesAndTypes(const TypeProto_Sequence& inferredType, const TypeProto_Sequence& existingType);

void checkShapesAndTypes(const TypeProto& inferredType, const TypeProto& existingType);

template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferredType, SymbolTable& symbolTable);

void MaterializeSymbolicShape(TypeProto* inferredType, SymbolTable& symbolTable);

void mergeShapesAndTypes(const TypeProto_Tensor& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferredType, TypeProto_SparseTensor* existingType);

void mergeShapesAndTypes(const TypeProto_Sequence& inferredType, TypeProto_Sequence* existingType);

void mergeShapesAndTypes(const TypeProto& inferredType, TypeProto* existingType);
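
// Worked example (illustrative; the exact rules live in the definitions of
// these functions): merging refines existingType in place with whatever
// inferredType pins down, and mismatched concrete values fail inference
// rather than being silently overwritten (see checkShapesAndTypes).
//
//   existing: float tensor of shape ["N", ?, 4]
//   inferred: float tensor of shape [ 2,  3, 4]
//   merged:   float tensor of shape [ 2,  3, 4]  // concrete dims win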

///
/// ModelLocalFunctionsMap maps a function id to the corresponding model-local FunctionProto.
/// All ONNX helper utilities expect the function id to be <function_proto.domain>:<function_proto.name>.
///
void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& in_model_functions = {});

void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    DataValueMap* generated_shape_data_by_name = nullptr);

void InferShapes(
    const std::string& model_path,
    const std::string& save_path = "",
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    DataValueMap* generated_shape_data_by_name = nullptr);
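
// Usage sketch (illustrative): whole-model inference mutates the model in
// place, recording inferred types in the graph's value_info entries. A
// DataValueMap can be passed as the last argument to collect values produced
// by data propagation, when that is enabled via ShapeInferenceOptions.
//
//   ModelProto model = /* loaded elsewhere */;
//   InferShapes(model);  // default registry and default options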

///
/// ModelLocalFunctionsMap maps a function id to the corresponding model-local FunctionProto.
/// All ONNX helper utilities expect the function id to be <function_proto.domain>:<function_proto.name>.
///
void InferShapeForFunctionNode(
    const FunctionProto& func,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

///
/// ModelLocalFunctionsMap maps a function id to the corresponding model-local FunctionProto.
/// All ONNX helper utilities expect the function id to be <function_proto.domain>:<function_proto.name>.
///
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

///
/// Applies type-and-shape-inference-based checks to a function body and
/// returns the inferred types of the function's outputs.
/// Inference depends on the types of the function's inputs as well as
/// the attribute values supplied.
/// A TypeProto with value_case() == TypeProto::ValueCase::VALUE_NOT_SET is used
/// for missing optional parameters.
///
std::vector<TypeProto> InferFunctionOutputTypes(
    const FunctionProto& func_proto,
    const std::vector<TypeProto>& input_types,
    const std::vector<AttributeProto>& attributes);

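// Usage sketch (illustrative; all identifiers below are placeholders): check a
// function body against concrete input types. An omitted optional input is
// represented by a TypeProto whose value_case() is VALUE_NOT_SET, e.g. a
// default-constructed TypeProto.
//
//   std::vector<TypeProto> input_types = /* one entry per function input */;
//   std::vector<AttributeProto> attributes;  // attribute values, if any
//   std::vector<TypeProto> output_types =
//       InferFunctionOutputTypes(my_function_proto, input_types, attributes);
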
std::string GetErrorWithNodeInfo(const NodeProto& n, const std::runtime_error& err);

void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbolTable);

} // namespace shape_inference
} // namespace ONNX_NAMESPACE