// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
namespace shape_inference {

using ModelLocalFunctionsMap = std::unordered_map<std::string, const FunctionProto*>;

// We reuse TensorShapeProto to propagate statically known (partial) information about
// the values of tensors. It is intended for tensors used to store shape information
// (the return values of ops like Shape and input values of ops like Reshape/Expand).

// A DataValueMap is used to store the statically known (partial) values of variables.
using DataValueMap = std::unordered_map<std::string, TensorShapeProto>;

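// In this encoding, each element of a (0D or 1D) integer tensor value maps to one
// dimension of the TensorShapeProto: dim_value for a statically known element,
// dim_param for a symbolically known one. For example, if data propagation determines
// that a Shape node produces the value [N, 3, 224] (first element only symbolic), that
// fact could be recorded roughly as follows ("shape_out" is a hypothetical value name):
//
//   DataValueMap known_values;
//   TensorShapeProto value;
//   value.add_dim()->set_dim_param("N");   // symbolically known element
//   value.add_dim()->set_dim_value(3);     // statically known element
//   value.add_dim()->set_dim_value(224);   // statically known element
//   known_values["shape_out"] = value;
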
class SymbolTableImpl : public SymbolTable {
 public:
  SymbolTableImpl() : index_(0) {}

  void addFromGraph(const GraphProto& g) override {
    AddExistingSymbolicDims(g.input());
    AddExistingSymbolicDims(g.output());
    AddExistingSymbolicDims(g.value_info());
  }
  // Creates a new unique symbol with the given prefix and adds it to the SymbolTable
  // Returns the newly created symbol
  std::string createNew(const std::string& symbol_prefix) override {
    std::string newSymbol;
    do {
      newSymbol = symbol_prefix + std::to_string(index_++);
    } while (existing_symbols.count(newSymbol) > 0);
    existing_symbols.insert(newSymbol);
    return newSymbol;
  }

 private:
  unsigned int index_;
  std::unordered_set<std::string> existing_symbols;

  // TypeProto_Tensor or TypeProto_SparseTensor
  template <typename TensorTypeProto>
  void AddExistingSymbolicDims(const TensorTypeProto& tensorType) {
    if (tensorType.has_shape()) {
      for (int i = 0; i < tensorType.shape().dim_size(); ++i) {
        if (tensorType.shape().dim(i).has_dim_param()) {
          existing_symbols.insert(tensorType.shape().dim(i).dim_param());
        }
      }
    }
  }

  void AddExistingSymbolicDims(const TypeProto& typeProto) {
    const auto val_case = typeProto.value_case();
    switch (val_case) {
      case TypeProto::kTensorType:
        AddExistingSymbolicDims(typeProto.tensor_type());
        break;
      case TypeProto::kSparseTensorType:
        AddExistingSymbolicDims(typeProto.sparse_tensor_type());
        break;
      case TypeProto::kSequenceType:
        AddExistingSymbolicDims(typeProto.sequence_type().elem_type());
        break;
      case TypeProto::kOptionalType:
        AddExistingSymbolicDims(typeProto.optional_type().elem_type());
        break;
      case TypeProto::kMapType:
        AddExistingSymbolicDims(typeProto.map_type().value_type());
        break;
      default:
        break;
    }
  }

  void AddExistingSymbolicDims(const google::protobuf::RepeatedPtrField<ValueInfoProto>& protos) {
    for (const auto& proto : protos) {
      AddExistingSymbolicDims(proto.type());
    }
  }
};

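// A minimal usage sketch (assuming a GraphProto `graph` whose value infos already use
// dim_params such as "batch", and a hypothetical `output_shape` being filled in):
//
//   SymbolTableImpl symbol_table;
//   symbol_table.addFromGraph(graph);                      // reserve "batch", etc.
//   std::string fresh = symbol_table.createNew("unk__");   // e.g. "unk__0", guaranteed unused
//   output_shape->mutable_dim(0)->set_dim_param(fresh);
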
struct GraphInferenceContext {
  GraphInferenceContext(
      const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
      const std::unordered_map<std::string, int> opset_imports_in,
      SymbolTable* symbol_table_in = nullptr,
      const ModelLocalFunctionsMap& model_local_functions_in = {},
      const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
      DataValueMap* generated_shape_data_by_name_in = nullptr,
      const int ir_version_in = IR_VERSION)
      : outer_scope_value_types_by_name{&outer_scope_value_types_by_name_in},
        opset_imports{opset_imports_in},
        symbol_table{symbol_table_in},
        model_local_functions{model_local_functions_in},
        schema_registry{schema_registry_in},
        generated_shape_data_by_name{generated_shape_data_by_name_in},
        ir_version{ir_version_in} {}

  const std::unordered_map<std::string, TypeProto*>* outer_scope_value_types_by_name;
  const std::unordered_map<std::string, int> opset_imports;
  SymbolTable* symbol_table;
  const ModelLocalFunctionsMap& model_local_functions;
  const ISchemaRegistry* schema_registry;
  DataValueMap* generated_shape_data_by_name;
  const int ir_version;
};

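// The context bundles per-graph state shared while inferring one graph: outer-scope
// value types (for subgraphs of ops such as If/Loop/Scan), the opset imports in effect,
// an optional symbol table, model-local functions, the schema registry, optionally
// propagated values, and the IR version. A rough construction sketch (the opset version
// used here is only an example):
//
//   std::unordered_map<std::string, TypeProto*> outer_scope_types;  // from the parent graph
//   std::unordered_map<std::string, int> opset_imports{{"", 21}};   // example default-domain opset
//   SymbolTableImpl symbol_table;
//   GraphInferenceContext graph_ctx(outer_scope_types, opset_imports, &symbol_table);
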
class GraphInferencerImpl : public GraphInferencer {
 public:
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context) : g_{&g}, context_{&context}, options_() {}
  GraphInferencerImpl(GraphProto& g, GraphInferenceContext& context, const ShapeInferenceOptions& options)
      : g_{&g}, context_{&context}, options_(options) {}

  std::vector<const TypeProto*> doInferencing(
      const std::vector<const TypeProto*>& inputTypes,
      const std::vector<const TensorProto*>& inputData) override;

 private:
  GraphProto* g_;
  GraphInferenceContext* context_;
  ShapeInferenceOptions options_;
};

struct InferenceContextImpl : public InferenceContext {
  InferenceContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      const std::unordered_map<std::string, const SparseTensorProto*>& inputSparseDataByName,
      const ShapeInferenceOptions& options,
      DataValueMap* generatedShapeData = nullptr,
      GraphInferenceContext* graphInferenceContext = nullptr)
      : graphInferenceContext_{graphInferenceContext}, options_(options) {
    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
      if (attr.has_g()) {
        // Need a mutable GraphProto to run inferencing on this attribute
        graphProtoAttributesByName_[attr.name()] = attr.mutable_g();
      }
    }

    for (const auto& input : n.input()) {
      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      // Input data can be in one of three containers:
      // inputDataByName - used when the input is a TensorProto
      // inputSparseDataByName - used when the input is a SparseTensorProto
      // generatedShapeData - used when the input value was generated by partial data propagation
      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
        allInputSparseData_.push_back(nullptr);
        allShapeInputData_.push_back(nullptr);
      } else {
        allInputData_.push_back(nullptr);
        const auto inputSparseDataIter = inputSparseDataByName.find(input);
        if (inputSparseDataIter != inputSparseDataByName.cend()) {
          allInputSparseData_.push_back(inputSparseDataIter->second);
          allShapeInputData_.push_back(nullptr);
        } else {
          allInputSparseData_.push_back(nullptr);
          if (generatedShapeData != nullptr) {
            const auto inputShapeDataIter = generatedShapeData->find(input);
            if (inputShapeDataIter != generatedShapeData->cend()) {
              allShapeInputData_.push_back(&inputShapeDataIter->second);
            } else {
              allShapeInputData_.push_back(nullptr);
            }
          } else {
            allShapeInputData_.push_back(nullptr);
          }
        }
      }
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  const TensorProto* getInputData(size_t index) const override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputData_[index];
  }

  const TensorShapeProto* getSymbolicInput(size_t index) const override {
    if (index >= allShapeInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }

    return allShapeInputData_[index];
  }

  const SparseTensorProto* getInputSparseData(size_t index) const override {
    if (index >= allInputSparseData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputSparseData_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  TypeProto* getOutputType(size_t index) override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  GraphInferencer* getGraphAttributeInferencer(const std::string& attr_name) override {
    if (!graphInferenceContext_) {
      fail_type_inference("GraphProto attribute inferencing is not enabled in this InferenceContextImpl instance.");
    }

    GraphInferencer* inferencer = nullptr;

    auto entry = graphAttributeInferencers_.find(attr_name);
    if (entry == graphAttributeInferencers_.cend()) {
      // Create a GraphInferencer instance and cache it
      auto attrNameToGraphProto = graphProtoAttributesByName_.find(attr_name);
      if (attrNameToGraphProto == graphProtoAttributesByName_.cend()) {
        fail_type_inference("Attribute ", attr_name, " does not contain a graph.");
      }

      std::unique_ptr<GraphInferencer> new_inferencer{
          new GraphInferencerImpl(*attrNameToGraphProto->second, *graphInferenceContext_, options_)};

      inferencer = new_inferencer.get();
      graphAttributeInferencers_.emplace(attr_name, std::move(new_inferencer));
    } else {
      inferencer = entry->second.get();
    }

    return inferencer;
  }

  std::vector<const TensorProto*> allInputData_;
  std::vector<const SparseTensorProto*> allInputSparseData_;
  std::vector<const TensorShapeProto*> allShapeInputData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
  std::unordered_map<std::string, GraphProto*> graphProtoAttributesByName_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  GraphInferenceContext* graphInferenceContext_;

  // mutable because it is an internal cache of GraphInferencer instances
  mutable std::unordered_map<std::string, std::unique_ptr<GraphInferencer>> graphAttributeInferencers_;
  ShapeInferenceOptions options_;
};

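// InferenceContextImpl packages one node's inputs (types, constant data, sparse data,
// and propagated values) behind the InferenceContext interface consumed by an operator's
// type-and-shape inference function. A rough sketch of how it is typically driven for a
// single node (`node`, the by-name maps, `options`, `generated_shape_data`, `graph_ctx`,
// `schema_registry` and `opset_version` are assumed to have been prepared by the caller):
//
//   InferenceContextImpl ctx(
//       node, value_types_by_name, input_data_by_name, input_sparse_data_by_name,
//       options, &generated_shape_data, &graph_ctx);
//   const auto* schema = schema_registry->GetSchema(node.op_type(), opset_version, node.domain());
//   if (schema != nullptr && schema->has_type_and_shape_inference_function()) {
//     schema->GetTypeAndShapeInferenceFunction()(ctx);  // writes into ctx.getOutputType(i)
//   }
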
struct DataPropagationContextImpl : public DataPropagationContext {
  DataPropagationContextImpl(
      NodeProto& n,
      const std::unordered_map<std::string, TypeProto*>& valueTypesByName,
      const std::unordered_map<std::string, const TensorProto*>& inputDataByName,
      DataValueMap& generatedShapeData)
      : generatedShapeData_(generatedShapeData) {
    size_t input_idx = 0;

    for (auto& attr : *n.mutable_attribute()) {
      attributesByName_[attr.name()] = &attr;
    }

    for (const auto& input : n.input()) {
      inputIndexToNameMap_.insert({input_idx++, input});

      auto valueTypesIter = valueTypesByName.find(input);
      if (valueTypesIter != valueTypesByName.end()) {
        allInputTypes_.push_back(valueTypesIter->second);
      } else {
        allInputTypes_.push_back(nullptr);
      }

      const auto inputDataIter = inputDataByName.find(input);
      if (inputDataIter != inputDataByName.cend()) {
        allInputData_.push_back(inputDataIter->second);
      } else {
        allInputData_.push_back(nullptr);
      }
    }

    size_t output_idx = 0;
    for (const auto& output : n.output()) {
      outputIndexToNameMap_.insert({output_idx++, output});
    }

    allOutputTypes_.resize(n.output_size());
  }

  const AttributeProto* getAttribute(const std::string& name) const override {
    auto iter = attributesByName_.find(name);
    if (iter == attributesByName_.end()) {
      return nullptr;
    } else {
      return iter->second;
    }
  }

  size_t getNumInputs() const override {
    return allInputTypes_.size();
  }

  const TypeProto* getInputType(size_t index) const override {
    if (index >= allInputTypes_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return allInputTypes_[index];
  }

  size_t getNumOutputs() const override {
    return allOutputTypes_.size();
  }

  const TypeProto* getOutputType(size_t index) const override {
    if (index >= allOutputTypes_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    return &allOutputTypes_[index];
  }

  // Convert an integer vector into a TensorShapeProto
  template <typename INTEGER>
  void vectorToTensorShapeProto(const std::vector<INTEGER>& input_vals, TensorShapeProto& converted_tsp) const {
    for (unsigned int i = 0; i < input_vals.size(); ++i) {
      converted_tsp.mutable_dim()->Add()->set_dim_value(input_vals[i]);
    }
  }

  const TensorShapeProto* getInputData(size_t index) override {
    if (index >= allInputData_.size()) {
      ONNX_THROW("Input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    const std::string input_name = inputIndexToNameMap_.at(index);
    // Gets it from previous data propagation
    auto iter = generatedShapeData_.find(input_name);
    if (iter != generatedShapeData_.end()) {
      return &iter->second;
    }
    // Otherwise, gets it from the initializer if it exists
    const auto* input_data = allInputData_[index];
    // Only a scalar (0D tensor) or a 1D tensor can be converted for now
    // TODO: It should support tensors with more dimensions on demand
    if (input_data != nullptr && (input_data->dims_size() == 0 || input_data->dims_size() == 1)) {
      TensorShapeProto tsp;

      if (input_data->data_type() == TensorProto_DataType_INT64) {
        vectorToTensorShapeProto(ParseData<int64_t>(input_data), tsp);
      } else if (input_data->data_type() == TensorProto_DataType_INT32) {
        vectorToTensorShapeProto(ParseData<int32_t>(input_data), tsp);
      } else {
        // Only integer types are supported to form a shape
        return nullptr;
      }

      // Adds this TensorShapeProto from the initializer into generatedShapeData
      // for future use
      auto result = generatedShapeData_.insert({input_name, std::move(tsp)});
      if (result.second) {
        return &(result.first->second);
      }
    }
    return nullptr;
  }
  void addOutputData(size_t index, TensorShapeProto&& tsp) override {
    if (index >= outputIndexToNameMap_.size()) {
      ONNX_THROW("Output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds.");
    }
    auto result = generatedShapeData_.insert({outputIndexToNameMap_.at(index), std::move(tsp)});
    if (!result.second) {
      fail_shape_inference("Data for output " + ONNX_NAMESPACE::to_string(index) + " already exists.");
    }
  }

  std::vector<const TensorProto*> allInputData_;
  std::unordered_map<size_t, std::string> inputIndexToNameMap_;
  std::unordered_map<size_t, std::string> outputIndexToNameMap_;
  std::vector<const TypeProto*> allInputTypes_;
  std::vector<TypeProto> allOutputTypes_;
  DataValueMap& generatedShapeData_;
  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
};

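// The data-propagation counterpart of the node-inference sketch above: getInputData()
// exposes a statically known input value as a TensorShapeProto (recovered from an
// integer initializer or from earlier propagation), and addOutputData() records what
// the operator's data propagation function derives. Roughly, with the same assumed
// caller-prepared state as above:
//
//   DataPropagationContextImpl dp_ctx(node, value_types_by_name, input_data_by_name, generated_shape_data);
//   if (schema != nullptr && schema->has_data_propagation_function()) {
//     schema->GetDataPropagationFunction()(dp_ctx);  // may call dp_ctx.addOutputData(...)
//   }
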
void checkShapesAndTypes(const TypeProto_Sequence& inferredType, const TypeProto_Sequence& existingType);

void checkShapesAndTypes(const TypeProto& inferredType, const TypeProto& existingType);

template <typename TensorTypeProto>
void GenerateSymbolicShape(TensorTypeProto* inferredType, SymbolTable& symbolTable);

void MaterializeSymbolicShape(TypeProto* inferredType, SymbolTable& symbolTable);

void mergeShapesAndTypes(const TypeProto_Tensor& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto_SparseTensor& inferredType, TypeProto_SparseTensor* existingType);

void mergeShapesAndTypes(const TypeProto_Sequence& inferredType, TypeProto_Tensor* existingType);

void mergeShapesAndTypes(const TypeProto& inferredType, TypeProto* existingType);

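// checkShapesAndTypes verifies that an inferred type is compatible with an existing
// declaration, while mergeShapesAndTypes refines the existing type in place, keeping
// the more precise information per dimension. As a rough illustration of the intended
// merge semantics (a sketch, not a normative specification): if the existing tensor
// type has shape ["N", ?] and inference produces [?, 3], merging is expected to refine
// the existing shape to ["N", 3].
//
//   TypeProto_Tensor existing, inferred;
//   existing.set_elem_type(TensorProto_DataType_FLOAT);
//   inferred.set_elem_type(TensorProto_DataType_FLOAT);
//   existing.mutable_shape()->add_dim()->set_dim_param("N");
//   existing.mutable_shape()->add_dim();                     // unknown dimension
//   inferred.mutable_shape()->add_dim();                     // unknown dimension
//   inferred.mutable_shape()->add_dim()->set_dim_value(3);
//   mergeShapesAndTypes(inferred, &existing);                // existing shape becomes ["N", 3]
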
///
/// ModelLocalFunctionsMap is a map from function id to the model-local FunctionProto.
/// All the ONNX helper utilities expect the function id to be <function_proto.domain>:<function_proto.name>.
///
void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& in_model_functions = {});

void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    DataValueMap* generated_shape_data_by_name = nullptr);

void InferShapes(
    const std::string& model_path,
    const std::string& save_path = "",
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    const ShapeInferenceOptions& options = {},
    DataValueMap* generated_shape_data_by_name = nullptr);

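// The three overloads above run inference over a whole graph, an in-memory model, or a
// model file on disk. The most common entry point is the ModelProto overload, which
// updates the model in place; with the default arguments above, a minimal sketch is:
//
//   ModelProto model;
//   // ... load or build `model` ...
//   InferShapes(model);  // default schema registry, default options
//   // Inferred types for intermediate values are recorded in model.graph().value_info().
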
///
/// ModelLocalFunctionsMap is a map from function id to the model-local FunctionProto.
/// All the ONNX helper utilities expect the function id to be <function_proto.domain>:<function_proto.name>.
///
void InferShapeForFunctionNode(
    const FunctionProto& func,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

///
/// ModelLocalFunctionsMap is a map from function id to the model-local FunctionProto.
/// All the ONNX helper utilities expect the function id to be <function_proto.domain>:<function_proto.name>.
///
void InferShapeForFunctionNode(
    const FunctionProto& func_proto,
    const std::unordered_map<std::string, int>& func_opset_imports,
    const ISchemaRegistry* schema_registry,
    InferenceContext& ctx,
    const ShapeInferenceOptions& options = {},
    const ModelLocalFunctionsMap& model_local_functions_map = {},
    SymbolTable* symbolTable = nullptr,
    DataValueMap* generated_shape_data_by_name = nullptr);

///
/// Apply type-and-shape-inference based checks to a Function body.
/// Returns the inferred types of the outputs of the function.
/// Inference depends on the types of the inputs of the function as well as
/// the attribute values supplied.
/// A TypeProto with value_case() == TypeProto::ValueCase::VALUE_NOT_SET is used
/// for missing optional parameters.
///
std::vector<TypeProto> InferFunctionOutputTypes(
    const FunctionProto& func_proto,
    const std::vector<TypeProto>& input_types,
    const std::vector<AttributeProto>& attributes);

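// A minimal calling sketch, assuming `my_function_proto` is a model-local FunctionProto
// taking a single float tensor input and no attributes:
//
//   TypeProto float_input;
//   float_input.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
//   float_input.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_param("N");
//
//   std::vector<TypeProto> output_types =
//       InferFunctionOutputTypes(my_function_proto, {float_input}, /*attributes=*/{});
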
std::string GetErrorWithNodeInfo(const NodeProto& n, const std::runtime_error& err);

void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbolTable);

} // namespace shape_inference
} // namespace ONNX_NAMESPACE