// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0
0005 #pragma once
0006
0007 #include <map>
0008 #include <memory>
0009 #include <string>
0010 #include <unordered_map>
0011 #include <unordered_set>
0012 #include <utility>
0013 #include <vector>
0014
0015 #include "onnx/defs/function.h"
0016 #include "onnx/defs/schema.h"
0017 #include "onnx/proto_utils.h"
0018 #include "onnx/string_utils.h"
0019
0020 namespace ONNX_NAMESPACE {
0021 namespace shape_inference {
0022
// Maps a function key (domain:name as built by the inference code) to the
// model-local FunctionProto it refers to. Pointers are non-owning; the
// ModelProto that declares the functions must outlive the map.
using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>;

// Maps a value name to a shape "computed" for it by data propagation
// (e.g. the output of a Shape node materialized as a TensorShapeProto).
using DataValueMap = std::unordered_map<std::string, TensorShapeProto>;
0031
0032 class SymbolTableImpl : public SymbolTable {
0033 public:
0034 SymbolTableImpl() : index_(0) {}
0035
0036 void addFromGraph(const GraphProto& g) override {
0037 AddExistingSymbolicDims(g.input());
0038 AddExistingSymbolicDims(g.output());
0039 AddExistingSymbolicDims(g.value_info());
0040 }
0041
0042
0043 std::string createNew(const std::string& symbol_prefix) override {
0044 std::string newSymbol;
0045 do {
0046 newSymbol = symbol_prefix + std::to_string(index_++);
0047 } while (existing_symbols.count(newSymbol) > 0);
0048 existing_symbols.insert(newSymbol);
0049 return newSymbol;
0050 }
0051
0052 private:
0053 unsigned int index_;
0054 std::unordered_set<std::string> existing_symbols;
0055
0056
0057 template <typename TensorTypeProto>
0058 void AddExistingSymbolicDims(const TensorTypeProto& tensorType) {
0059 if (tensorType.has_shape()) {
0060 for (int i = 0; i < tensorType.shape().dim_size(); ++i) {
0061 if (tensorType.shape().dim(i).has_dim_param()) {
0062 existing_symbols.insert(tensorType.shape().dim(i).dim_param());
0063 }
0064 }
0065 }
0066 }
0067
0068 void AddExistingSymbolicDims(const TypeProto& typeProto) {
0069 const auto val_case = typeProto.value_case();
0070 switch (val_case) {
0071 case TypeProto::kTensorType:
0072 AddExistingSymbolicDims(typeProto.tensor_type());
0073 break;
0074 case TypeProto::kSparseTensorType:
0075 AddExistingSymbolicDims(typeProto.sparse_tensor_type());
0076 break;
0077 case TypeProto::kSequenceType:
0078 AddExistingSymbolicDims(typeProto.sequence_type().elem_type());
0079 break;
0080 case TypeProto::kOptionalType:
0081 AddExistingSymbolicDims(typeProto.optional_type().elem_type());
0082 break;
0083 case TypeProto::kMapType:
0084 AddExistingSymbolicDims(typeProto.map_type().value_type());
0085 break;
0086 default:
0087 break;
0088 }
0089 }
0090
0091 void AddExistingSymbolicDims(const google::protobuf::RepeatedPtrField<ValueInfoProto>& protos) {
0092 for (const auto& proto : protos) {
0093 AddExistingSymbolicDims(proto.type());
0094 }
0095 }
0096 };
0097
0098 struct GraphInferenceContext {
0099 GraphInferenceContext(
0100 const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
0101 const std::unordered_map<std::string, int> opset_imports_in,
0102 SymbolTable* symbol_table_in = nullptr,
0103 const ModelLocalFunctionsMap& model_local_functions_in = {},
0104 const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
0105 DataValueMap* generated_shape_data_by_name_in = nullptr,
0106 const int ir_version_in = IR_VERSION)
0107 : outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in},
0108 opset_imports{opset_imports_in},
0109 symbol_table{symbol_table_in},
0110 model_local_functions{model_local_functions_in},
0111 schema_registry{schema_registry_in},
0112 generated_shape_data_by_name{generated_shape_data_by_name_in},
0113 ir_version{ir_version_in} {}
0114
0115 const std::unordered_map<std::string, TypeProto*>* outer_scope_value_types_by_name;
0116 const std::unordered_map<std::string, int> opset_imports;
0117 SymbolTable* symbol_table;
0118 const ModelLocalFunctionsMap& model_local_functions;
0119 const ISchemaRegistry* schema_registry;
0120 DataValueMap* generated_shape_data_by_name;
0121 const int ir_version;
0122 };
0123
0124 class GraphInferencerImpl : public GraphInferencer {
0125 public:
0126 GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context) : g_{&g}, context_{&context}, options_() {}
0127 GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context, const ShapeInferenceOptions& options)
0128 : g_{&g}, context_{&context}, options_(options) {}
0129
0130 std::vector<const TypeProto*> doInferencing(
0131 const std::vector<const TypeProto*>& inputTypes,
0132 const std::vector<const TensorProto*>& inputData) override;
0133
0134 private:
0135 GraphProto* g_;
0136 GraphInferenceContext* context_;
0137 ShapeInferenceOptions options_;
0138 };
0139
// InferenceContext over a single NodeProto: exposes the node's attributes,
// input types/data, and mutable output type slots to an op's type-and-shape
// inference function.
//
// Invariant: allInputTypes_, allInputData_, allInputSparseData_, and
// allShapeInputData_ are index-aligned with n.input(). For each input at
// most one of the three data vectors holds a non-null entry, resolved in
// priority order: dense initializer, then sparse initializer, then a shape
// generated by data propagation.
struct InferenceContextImpl : public InferenceContext {
  InferenceContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      const std::unordered_map<std::string, const SparseTensorProto*>& inputSparseDataByName,
      const ShapeInferenceOptions& options,
      DataValueMap* generatedShapeData = nullptr,
      GraphInferenceContext* graphInferenceContext = nullptr)
      : graphInferenceContext_{graphInferenceContext}, options_(options), node_(&n) {
    // Index attributes by name; graph-valued attributes are additionally
    // indexed separately so getGraphAttributeInferencer can find them.
    // Pointers point into `n`, which must outlive this context.
    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
      if (attr.has_g()) {

        graphProtoAttributesByName_[attr.name()] = attr.mutable_g();
      }
    }

    for (const auto& input : n.input()) {
      // Type may be unknown (nullptr) for values not yet inferred.
      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      // Resolve input data with dense initializers taking priority over
      // sparse ones, and generated (propagated) shapes used only as a last
      // resort. Exactly one push_back per vector per input keeps the three
      // data vectors aligned with allInputTypes_.
      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
        allInputSparseData_.push_back(nullptr);
        allShapeInputData_.push_back(nullptr);
      } else {
        allInputData_.push_back(nullptr);
        const auto inputSparseDataIter = inputSparseDataByName.find(input);
        if (inputSparseDataIter != inputSparseDataByName.cend()) {
          allInputSparseData_.push_back(inputSparseDataIter->second);
          allShapeInputData_.push_back(nullptr);
        } else {
          allInputSparseData_.push_back(nullptr);
          if (generatedShapeData != nullptr) {
            const auto inputShapeDataIter = generatedShapeData->find(input);
            if (inputShapeDataIter != generatedShapeData->cend()) {
              allShapeInputData_.push_back(&inputShapeDataIter->second);
            } else {
              allShapeInputData_.push_back(nullptr);
            }
          } else {
            allShapeInputData_.push_back(nullptr);
          }
        }
      }
    }

    // One default-constructed (empty) TypeProto slot per declared output.
    allOutputTypes_.resize(n.output_size());
  }

  // Returns the attribute with the given name, or nullptr if absent.
  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  // Returns the input's type, or nullptr if unknown. Throws on bad index.
  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  // Returns the input's dense tensor data, or nullptr if not statically known.
  const TensorProto* getInputData(size_t index) const override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputData_[index];
  }

  // Returns the shape produced for this input by data propagation, or
  // nullptr if none was generated.
  const TensorShapeProto* getSymbolicInput(size_t index) const override {
    if (index >= allShapeInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }

    return allShapeInputData_[index];
  }

  // Returns the input's sparse tensor data, or nullptr if not statically known.
  const SparseTensorProto* getInputSparseData(size_t index) const override {
    if (index >= allInputSparseData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputSparseData_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  // Mutable slot the op's inference function writes the inferred type into.
  TypeProto* getOutputType(size_t index) override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  // Lazily creates (and caches) a GraphInferencerImpl for the graph-valued
  // attribute `attr_name`. Fails type inference if subgraph inferencing was
  // not enabled (no graphInferenceContext_) or the attribute has no graph.
  GraphInferencer* getGraphAttributeInferencer(const std::string& attr_name) override {
    if (!graphInferenceContext_) {
      fail_type_inference("GraphProto attribute inferencing is not enabled in this InferenceContextImpl instance.");
    }

    GraphInferencer* inferencer = nullptr;

    auto entry = graphAttributeInferencers_.find(attr_name);
    if (entry == graphAttributeInferencers_.cend()) {
      // First request for this attribute: build an inferencer and cache it.
      auto attrNameToGraphProto = graphProtoAttributesByName_.find(attr_name);
      if (attrNameToGraphProto == graphProtoAttributesByName_.cend()) {
        fail_type_inference("Attribute ", attr_name, " does not contain a graph.");
      }

      std::unique_ptr<GraphInferencer> new_inferencer{
          new GraphInferencerImpl(*attrNameToGraphProto->second, *graphInferenceContext_, options_)};

      inferencer = new_inferencer.get();
      graphAttributeInferencers_.emplace(attr_name, std::move(new_inferencer));
    } else {
      inferencer = entry->second.get();
    }

    return inferencer;
  }

  // Human-readable node label for diagnostics, e.g.
  // `node OpType[domain] (name)`, omitting empty domain/name parts.
  std::string getDisplayName() const override {
    if (node_ == nullptr)
      return "";
    if (node_->domain().empty()) {
      if (node_->name().empty())
        return MakeString("node ", node_->op_type());
      return MakeString("node ", node_->op_type(), " (", node_->name(), ")");
    }
    if (node_->name().empty())
      return MakeString("node ", node_->op_type(), "[", node_->domain(), "]");
    return MakeString("node ", node_->op_type(), "[", node_->domain(), "]", " (", node_->name(), ")");
  }

  // All vectors below are index-aligned with the node's inputs/outputs.
  // Pointer members are non-owning views into the node / caller-owned maps.
  std::vector<const TensorProto*> allInputData_;
  std::vector<const SparseTensorProto*> allInputSparseData_;
  std::vector<const TensorShapeProto*> allShapeInputData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
  std::unordered_map<std::string, GraphProto*> graphProtoAttributesByName_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_; // owned; filled in by inference
  GraphInferenceContext* graphInferenceContext_;

  // Cache for getGraphAttributeInferencer (mutable: lazily populated).
  mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_;
  ShapeInferenceOptions options_;
  NodeProto* node_;
};
0307
0308 struct DataPropagationContextImpl : public DataPropagationContext {
0309 DataPropagationContextImpl(
0310 NodeProto& n,
0311 const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
0312 const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
0313 DataValueMap& generatedShapeData)
0314 : generatedShapeData_(generatedShapeData) {
0315 size_t input_idx = 0;
0316
0317 for (auto& attr : *n.mutable_attribute()) {
0318 attributesByName_[attr.name()] = &attr;
0319 }
0320
0321 for (const auto& input : n.input()) {
0322 inputIndexToNameMap_.insert({input_idx++, input});
0323
0324 auto valueTypesIter = valueTypesByName.find(input);
0325 if (valueTypesIter != valueTypesByName.end()) {
0326 allInputTypes_.push_back(valueTypesIter->second);
0327 } else {
0328 allInputTypes_.push_back(nullptr);
0329 }
0330
0331 const auto inputDataIter = inputDataByName.find(input);
0332 if (inputDataIter != inputDataByName.cend()) {
0333 allInputData_.push_back(inputDataIter->second);
0334 } else {
0335 allInputData_.push_back(nullptr);
0336 }
0337 }
0338
0339 size_t output_idx = 0;
0340 for (const auto& output : n.output()) {
0341 outputIndexToNameMap_.insert({output_idx++, output});
0342 }
0343
0344 allOutputTypes_.resize(n.output_size());
0345 }
0346
0347 const AttributeProto* getAttribute(const std::string& name) const override {
0348 auto iter = attributesByName_.find(name);
0349 if (iter == attributesByName_.end()) {
0350 return nullptr;
0351 } else {
0352 return iter->second;
0353 }
0354 }
0355
0356 size_t getNumInputs() const override {
0357 return allInputTypes_.size();
0358 }
0359
0360 const TypeProto* getInputType(size_t index) const override {
0361 if (index >= allInputTypes_.size()) {
0362 ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
0363 }
0364 return allInputTypes_[index];
0365 }
0366
0367 size_t getNumOutputs() const override {
0368 return allOutputTypes_.size();
0369 }
0370
0371 const TypeProto* getOutputType(size_t index) const override {
0372 if (index >= allOutputTypes_.size()) {
0373 ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
0374 }
0375 return &allOutputTypes_[index];
0376 }
0377
0378
0379 template <typename INTEGER>
0380 void vectorToTensorShapeProto(const std::vector<INTEGER>& input_vals, TensorShapeProto& converted_tsp) const {
0381 for (unsigned int i = 0; i < input_vals.size(); ++i) {
0382 converted_tsp.mutable_dim()->Add()->set_dim_value(input_vals[i]);
0383 }
0384 }
0385
0386 const TensorShapeProto* getInputData(size_t index) override {
0387 if (index >= allInputData_.size()) {
0388 ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
0389 }
0390 const std::string input_name = inputIndexToNameMap_.at(index);
0391
0392 auto iter = generatedShapeData_.find(input_name);
0393 if (iter != generatedShapeData_.end()) {
0394 return &iter->second;
0395 }
0396
0397 const auto* input_data = allInputData_[index];
0398
0399
0400 if (input_data != nullptr && (input_data->dims_size() == 0 || input_data->dims_size() == 1)) {
0401 TensorShapeProto tsp;
0402
0403 if (input_data->data_type() == TensorProto_DataType_INT64) {
0404 vectorToTensorShapeProto(ParseData<int64_t>(input_data), tsp);
0405 } else if (input_data->data_type() == TensorProto_DataType_INT32) {
0406 vectorToTensorShapeProto(ParseData<int32_t>(input_data), tsp);
0407 } else {
0408
0409 return nullptr;
0410 }
0411
0412
0413
0414 auto result = generatedShapeData_.insert({input_name, std::move(tsp)});
0415 if (result.second) {
0416 return &(result.first->second);
0417 }
0418 }
0419 return nullptr;
0420 }
0421
0422 void addOutputData(size_t index, TensorShapeProto&& tsp) override {
0423 if (index >= outputIndexToNameMap_.size()) {
0424 ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
0425 }
0426 auto result = generatedShapeData_.insert({outputIndexToNameMap_.at(index), std::move(tsp)});
0427 if (!result.second) {
0428 fail_shape_inference("Data for input " + ONNX_NAMESPACE::to_string(index) + " already exists.");
0429 }
0430 }
0431
0432 std::vector<const TensorProto*> allInputData_;
0433 std::unordered_map<size_t, std::string> inputIndexToNameMap_;
0434 std::unordered_map<size_t, std::string> outputIndexToNameMap_;
0435 std::vector<const TypeProto*> allInputTypes_;
0436 std::vector<TypeProto> allOutputTypes_;
0437 DataValueMap& generatedShapeData_;
0438 std::unordered_map<std::string, const AttributeProto*> attributesByName_;
0439 };
0440
// ---------------------------------------------------------------------------
// Free-function API. Definitions live in the corresponding .cc file; the
// summaries below are inferred from the signatures — confirm against the
// implementations before relying on details.
// ---------------------------------------------------------------------------

// Checks that an inferred sequence type is compatible with an existing one;
// presumably fails inference on mismatch.
void checkShapesAndTypes(const TypeProto_Sequence& inferredType, const TypeProto_Sequence& existingType);

// Checks that an inferred type is compatible with an existing type.
void checkShapesAndTypes(const TypeProto& inferredType, const TypeProto& existingType);

// Replaces unknown dimensions in `inferredType` with fresh symbols from
// `symbolTable` (template over tensor-like type protos).
template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferredType, SymbolTable& symbolTable);

// Applies symbol generation across all shape-bearing variants of a TypeProto.
void MaterializeSymbolicShape(TypeProto* inferredType, SymbolTable& symbolTable);

// Merges inferred shape/type information into the existing record in place.
void mergeShapesAndTypes(const TypeProto_Tensor& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferredType, TypeProto_SparseTensor* existingType);

// NOTE(review): this overload takes a TypeProto_Tensor* target for a
// sequence source — asymmetric with the overloads above. Looks suspicious,
// but the declaration must match the out-of-view definition; verify there
// before changing it.
void mergeShapesAndTypes(const TypeProto_Sequence& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto& inferredType, TypeProto* existingType);

// Runs shape/type inference over `g` in place, using the given opset imports,
// schema registry, and any model-local functions.
void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& in_model_functions = {});

// Convenience overload: infers shapes for a whole model, optionally
// collecting propagated shape data into `generated_shape_data_by_name`.
void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    DataValueMap* generated_shape_data_by_name = nullptr);

// Convenience overload: loads the model from `model_path`, infers shapes, and
// (when `save_path` is non-empty) presumably writes the result back out.
void InferShapes(
    const std::string& model_path,
    const std::string& save_path = "",
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    DataValueMap* generated_shape_data_by_name = nullptr);

// Infers shapes for a node that invokes function `func`, using the opset
// imports carried by the function itself.
void InferShapeForFunctionNode(
    const FunctionProto& func,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

// As above, but with the caller supplying the opset imports to use for the
// function body (overriding/augmenting those in `func_proto`).
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

// Computes the output types of a function body given concrete input types
// and attribute values; returns one TypeProto per function output.
std::vector<TypeProto> InferFunctionOutputTypes(
    const FunctionProto& func_proto,
    const std::vector<TypeProto>& input_types,
    const std::vector<AttributeProto>& attributes);

// Formats `err` with identifying information about node `n` for diagnostics.
std::string GetErrorWithNodeInfo(const NodeProto& n, const std::runtime_error& err);

// Recursively walks `g` (including subgraphs, presumably) registering every
// existing symbolic dimension into `symbolTable`.
void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbolTable);

} // namespace shape_inference
} // namespace ONNX_NAMESPACE