#pragma once

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
namespace shape_inference {

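// Model-local functions, keyed by a string identifier (typically derived from the function's
// domain and name) and mapping to the corresponding FunctionProto.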
using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>;

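// Maps a value name to its statically known contents represented as a TensorShapeProto. Entries
// are produced by data propagation (e.g. the output of a Shape node) and reused during inference.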
using DataValueMap = std::unordered_map<std::string, TensorShapeProto>;

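// SymbolTable that records every symbolic dimension name (dim_param) already used in a graph and
// can create fresh, non-conflicting symbols for newly inferred dimensions.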
class SymbolTableImpl : public SymbolTable {
 public:
  SymbolTableImpl() : index_(0) {}

  void addFromGraph(const GraphProto& g) override {
    AddExistingSymbolicDims(g.input());
    AddExistingSymbolicDims(g.output());
    AddExistingSymbolicDims(g.value_info());
  }

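  // Creates, records, and returns a new symbol of the form <symbol_prefix><counter>, skipping any
  // candidate name that already exists in the table.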
  std::string createNew(const std::string& symbol_prefix) override {
    std::string newSymbol;
    do {
      newSymbol = symbol_prefix + std::to_string(index_++);
    } while (existing_symbols.count(newSymbol) > 0);
    existing_symbols.insert(newSymbol);
    return newSymbol;
  }

 private:
  unsigned int index_;
  std::unordered_set<std::string> existing_symbols;

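  // Records the dim_param of every symbolic dimension in a tensor-like type.
  // TensorTypeProto is expected to be TypeProto_Tensor or TypeProto_SparseTensor.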
  template <typename TensorTypeProto>
  void AddExistingSymbolicDims(const TensorTypeProto& tensorType) {
    if (tensorType.has_shape()) {
      for (int i = 0; i < tensorType.shape().dim_size(); ++i) {
        if (tensorType.shape().dim(i).has_dim_param()) {
          existing_symbols.insert(tensorType.shape().dim(i).dim_param());
        }
      }
    }
  }

  void AddExistingSymbolicDims(const TypeProto& typeProto) {
    const auto val_case = typeProto.value_case();
    switch (val_case) {
      case TypeProto::kTensorType:
        AddExistingSymbolicDims(typeProto.tensor_type());
        break;
      case TypeProto::kSparseTensorType:
        AddExistingSymbolicDims(typeProto.sparse_tensor_type());
        break;
      case TypeProto::kSequenceType:
        AddExistingSymbolicDims(typeProto.sequence_type().elem_type());
        break;
      case TypeProto::kOptionalType:
        AddExistingSymbolicDims(typeProto.optional_type().elem_type());
        break;
      case TypeProto::kMapType:
        AddExistingSymbolicDims(typeProto.map_type().value_type());
        break;
      default:
        break;
    }
  }

  void AddExistingSymbolicDims(const google::protobuf::RepeatedPtrField<ValueInfoProto>& protos) {
    for (const auto& proto : protos) {
      AddExistingSymbolicDims(proto.type());
    }
  }
};

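// Bundles everything a subgraph (e.g. the body of an If/Loop graph attribute) needs for
// inference: value types visible from the enclosing scope, opset imports, the shared symbol
// table, model-local functions, the schema registry, propagated shape data, and the IR version.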
struct GraphInferenceContext {
  GraphInferenceContext(
      const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
      const std::unordered_map<std::string, int> opset_imports_in,
      SymbolTable* symbol_table_in = nullptr,
      const ModelLocalFunctionsMap& model_local_functions_in = {},
      const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
      DataValueMap* generated_shape_data_by_name_in = nullptr,
      const int ir_version_in = IR_VERSION)
      : outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in},
        opset_imports{opset_imports_in},
        symbol_table{symbol_table_in},
        model_local_functions{model_local_functions_in},
        schema_registry{schema_registry_in},
        generated_shape_data_by_name{generated_shape_data_by_name_in},
        ir_version{ir_version_in} {}

  const std::unordered_map<std::string, TypeProto*>* outer_scope_value_types_by_name;
  const std::unordered_map<std::string, int> opset_imports;
  SymbolTable* symbol_table;
  const ModelLocalFunctionsMap& model_local_functions;
  const ISchemaRegistry* schema_registry;
  DataValueMap* generated_shape_data_by_name;
  const int ir_version;
};

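// GraphInferencer that runs type/shape inference over a GraphProto (typically one stored in a
// graph attribute), resolving outer-scope names through the given GraphInferenceContext.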
class GraphInferencerImpl : public GraphInferencer {
 public:
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context) : g_{&g}, context_{&context}, options_() {}
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context, const ShapeInferenceOptions& options)
      : g_{&g}, context_{&context}, options_(options) {}

  std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) override;

 private:
  GraphProto* g_;
  GraphInferenceContext* context_;
  ShapeInferenceOptions options_;
};

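// InferenceContext implementation backed by a single NodeProto: it indexes the node's attributes
// and, for every input, records its known type and any constant data (dense, sparse, or
// propagated) available for it.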
struct InferenceContextImpl : public InferenceContext {
  InferenceContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      const std::unordered_map<std::string, const SparseTensorProto*>& inputSparseDataByName,
      const ShapeInferenceOptions& options,
      DataValueMap* generatedShapeData = nullptr,
      GraphInferenceContext* graphInferenceContext = nullptr)
      : graphInferenceContext_{graphInferenceContext}, options_(options) {
    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
      if (attr.has_g()) {
        // A mutable GraphProto is required so inference can be run on the graph attribute itself.
        graphProtoAttributesByName_[attr.name()] = attr.mutable_g();
      }
    }

    for (const auto& input : n.input()) {
      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

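      // Constant data for an input may come from one of three sources: dense initializers
      // (inputDataByName), sparse initializers (inputSparseDataByName), or values produced by
      // data propagation (generatedShapeData). At most one of the three per-input vectors
      // receives a non-null entry for any given input.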
      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
        allInputSparseData_.push_back(nullptr);
        allShapeInputData_.push_back(nullptr);
      } else {
        allInputData_.push_back(nullptr);
        const auto inputSparseDataIter = inputSparseDataByName.find(input);
        if (inputSparseDataIter != inputSparseDataByName.cend()) {
          allInputSparseData_.push_back(inputSparseDataIter->second);
          allShapeInputData_.push_back(nullptr);
        } else {
          allInputSparseData_.push_back(nullptr);
          if (generatedShapeData != nullptr) {
            const auto inputShapeDataIter = generatedShapeData->find(input);
            if (inputShapeDataIter != generatedShapeData->cend()) {
              allShapeInputData_.push_back(&inputShapeDataIter->second);
            } else {
              allShapeInputData_.push_back(nullptr);
            }
          } else {
            allShapeInputData_.push_back(nullptr);
          }
        }
      }
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  const TensorProto* getInputData(size_t index) const override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputData_[index];
  }

  const TensorShapeProto* getSymbolicInput(size_t index) const override {
    if (index >= allShapeInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allShapeInputData_[index];
  }

  const SparseTensorProto* getInputSparseData(size_t index) const override {
    if (index >= allInputSparseData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputSparseData_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  TypeProto* getOutputType(size_t index) override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  GraphInferencer* getGraphAttributeInferencer(const std::string& attr_name) override {
    if (!graphInferenceContext_) {
      fail_type_inference("GraphProto attribute inferencing is not enabled in this InferenceContextImpl instance.");
    }

    GraphInferencer* inferencer = nullptr;

    auto entry = graphAttributeInferencers_.find(attr_name);
    if (entry == graphAttributeInferencers_.cend()) {
      // First lookup for this attribute: create an inferencer for its graph and cache it.
      auto attrNameToGraphProto = graphProtoAttributesByName_.find(attr_name);
      if (attrNameToGraphProto == graphProtoAttributesByName_.cend()) {
        fail_type_inference("Attribute ", attr_name, " does not contain a graph.");
      }

      std::unique_ptr<GraphInferencer> new_inferencer{
          new GraphInferencerImpl(*attrNameToGraphProto->second, *graphInferenceContext_, options_)};

      inferencer = new_inferencer.get();
      graphAttributeInferencers_.emplace(attr_name, std::move(new_inferencer));
    } else {
      inferencer = entry->second.get();
    }

    return inferencer;
  }

  std::vector<const TensorProto*> allInputData_;
  std::vector<const SparseTensorProto*> allInputSparseData_;
  std::vector<const TensorShapeProto*> allShapeInputData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
  std::unordered_map<std::string, GraphProto*> graphProtoAttributesByName_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  GraphInferenceContext* graphInferenceContext_;

  // Lazily created inferencers for graph attributes, cached per attribute name.
  mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_;
  ShapeInferenceOptions options_;
};

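// DataPropagationContext implementation backed by a single NodeProto: exposes its attributes,
// input types, and statically known input values so an operator's data-propagation function can
// record computed output values into generatedShapeData_.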
struct DataPropagationContextImpl : public DataPropagationContext {
  DataPropagationContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      DataValueMap& generatedShapeData)
      : generatedShapeData_(generatedShapeData) {
    size_t input_idx = 0;

    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
    }

    for (const auto& input : n.input()) {
      inputIndexToNameMap_.insert({input_idx++, input});

      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
      } else {
        allInputData_.push_back(nullptr);
      }
    }

    size_t output_idx = 0;
    for (const auto& output : n.output()) {
      outputIndexToNameMap_.insert({output_idx++, output});
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  const TypeProto* getOutputType(size_t index) const override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

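  // Copies a vector of integers into a TensorShapeProto, one dim_value per element.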
  template <typename INTEGER>
  void vectorToTensorShapeProto(const std::vector<INTEGER>& input_vals, TensorShapeProto& converted_tsp) const {
    for (unsigned int i = 0; i < input_vals.size(); ++i) {
      converted_tsp.mutable_dim()->Add()->set_dim_value(input_vals[i]);
    }
  }

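  // Returns the statically known values of an input as a TensorShapeProto: previously propagated
  // values take precedence; otherwise a 0-D or 1-D INT32/INT64 constant input is converted and
  // cached in generatedShapeData_. Returns nullptr when the values are not statically known.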
  const TensorShapeProto* getInputData(size_t index) override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    const std::string input_name = inputIndexToNameMap_.at(index);

    auto iter = generatedShapeData_.find(input_name);
    if (iter != generatedShapeData_.end()) {
      return &iter->second;
    }

    const auto* input_data = allInputData_[index];
    // Only scalar (0-D) and 1-D constant tensors can be represented as a TensorShapeProto.
    if (input_data != nullptr && (input_data->dims_size() == 0 || input_data->dims_size() == 1)) {
      TensorShapeProto tsp;
      if (input_data->data_type() == TensorProto_DataType_INT64) {
        vectorToTensorShapeProto(ParseData<int64_t>(input_data), tsp);
      } else if (input_data->data_type() == TensorProto_DataType_INT32) {
        vectorToTensorShapeProto(ParseData<int32_t>(input_data), tsp);
      } else {
        // Element types other than INT32/INT64 are not propagated as shape data.
        return nullptr;
      }

      // Cache the converted values so later lookups by name can reuse them.
      auto result = generatedShapeData_.insert({input_name, std::move(tsp)});
      if (result.second) {
        return &(result.first->second);
      }
    }
    return nullptr;
  }

  void addOutputData(size_t index, TensorShapeProto&& tsp) override {
    if (index >= outputIndexToNameMap_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    auto result = generatedShapeData_.insert({outputIndexToNameMap_.at(index), std::move(tsp)});
    if (!result.second) {
      fail_shape_inference("Data for output " + ONNX_NAMESPACE::to_string(index) + " already exists.");
    }
  }

  std::vector<const TensorProto*> allInputData_;
  std::unordered_map<size_t, std::string> inputIndexToNameMap_;
  std::unordered_map<size_t, std::string> outputIndexToNameMap_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  DataValueMap& generatedShapeData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
};

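// Checks that an inferred type/shape is compatible with the existing (declared) one and fails
// type or shape inference when the two conflict.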
void checkShapesAndTypes(const TypeProto_Sequence& inferredType, const TypeProto_Sequence& existingType);

void checkShapesAndTypes(const TypeProto& inferredType, const TypeProto& existingType);

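// Assigns fresh symbols from the symbol table to inferred dimensions that have neither a concrete
// value nor an existing symbol; MaterializeSymbolicShape applies this across the tensor-like
// fields of a TypeProto.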
template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferredType, SymbolTable& symbolTable);

void MaterializeSymbolicShape(TypeProto* inferredType, SymbolTable& symbolTable);

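// Merges newly inferred type/shape information into existingType, filling in previously unknown
// element types and dimensions.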
void mergeShapesAndTypes(const TypeProto_Tensor& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferredType, TypeProto_SparseTensor* existingType);

void mergeShapesAndTypes(const TypeProto_Sequence& inferredType, TypeProto_Sequence* existingType);

void mergeShapesAndTypes(const TypeProto& inferredType, TypeProto* existingType);

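// Performs type and shape inference on the graph in place, resolving operator schemas through
// schema_registry and model-local functions through in_model_functions.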
void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& in_model_functions = {});

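// Overload that infers shapes for an entire model in place, taking opset imports and model-local
// functions from the ModelProto itself.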
void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    DataValueMap* generated_shape_data_by_name = nullptr);

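// A minimal usage sketch of the ModelProto overload (illustrative only; obtaining the ModelProto
// is outside the scope of this header):
//
//   ModelProto model;
//   // ... populate `model`, e.g. by parsing a serialized .onnx file ...
//   shape_inference::InferShapes(model);
//   // Inferred types/shapes for intermediate values are added to the graph's value_info.
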
// Loads the model from model_path, infers its shapes, and (when save_path is non-empty) writes
// the updated model to save_path.
void InferShapes(
    const std::string& model_path,
    const std::string& save_path = "",
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    DataValueMap* generated_shape_data_by_name = nullptr);

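// Infers the output types/shapes for a node whose op is implemented by the given FunctionProto,
// using the opset imports declared by the function itself. Results are written back to ctx.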
void InferShapeForFunctionNode(
    const FunctionProto& func,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

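// Overload that uses an explicitly supplied opset-import map for the function body instead of the
// imports declared in the FunctionProto.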
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

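// Runs inference over a function body given concrete input types and attribute values, and
// returns the inferred types of the function's outputs.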
std::vector<TypeProto> InferFunctionOutputTypes(
    const FunctionProto& func_proto,
    const std::vector<TypeProto>& input_types,
    const std::vector<AttributeProto>& attributes);

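// Builds an error message that combines identifying information about node n with err.what().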
std::string GetErrorWithNodeInfo(const NodeProto& n, const std::runtime_error& err);

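// Recursively walks g and its subgraph attributes, registering every existing symbolic dimension
// name in symbolTable.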
void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbolTable);

} // namespace shape_inference
} // namespace ONNX_NAMESPACE