Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 09:03:58

0001 // Copyright (c) ONNX Project Contributors
0002 //
0003 // SPDX-License-Identifier: Apache-2.0
0004 
0005 #pragma once
0006 
0007 #include <stdexcept>
0008 #include <string>
0009 #include <unordered_map>
0010 #include <unordered_set>
0011 #include <utility>
0012 
0013 #include "onnx/defs/function.h"
0014 #include "onnx/defs/schema.h"
0015 #include "onnx/onnx-data_pb.h"
0016 #include "onnx/onnx-operators_pb.h"
0017 #include "onnx/onnx_pb.h"
0018 #include "onnx/string_utils.h"
0019 
0020 namespace ONNX_NAMESPACE {
0021 namespace checker {
/// Exception type used to report a failed model/proto validation.
///
/// what() returns the message the exception was constructed with, followed by
/// every context string attached through AppendContext(), in attach order.
class ValidationError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  /// Returns the original message plus any context appended so far.
  const char* what() const noexcept override {
    if (!expanded_message_.empty()) {
      return expanded_message_.c_str();
    }
    return std::runtime_error::what();
  }

  /// Appends a "\n\n==> Context: <context>" section to the message.
  ///
  /// Each call extends the message produced by earlier calls rather than
  /// replacing it. When an error is caught, annotated and rethrown at several
  /// levels, the context from every level is therefore kept, not just the
  /// last one attached.
  void AppendContext(const std::string& context) {
    // Start from what() (not std::runtime_error::what()) so previously
    // appended context is preserved.
    std::string message = what();
    message += "\n\n==> Context: ";
    message += context;
    expanded_message_ = std::move(message);
  }

 private:
  // Base message plus all appended context; empty until AppendContext runs.
  std::string expanded_message_;
};
0038 
// fail_check(args...): builds a message from args with
// ONNX_NAMESPACE::MakeString, wraps it in an
// ONNX_NAMESPACE::checker::ValidationError and hands it to ONNX_THROW_EX.
// NOTE(review): the expansion already ends in ';', so `fail_check(x);` yields
// an extra empty statement. Inside an if/else, brace the branch, otherwise
// the stray ';' detaches the `else`.
#define fail_check(...)                         \
  ONNX_THROW_EX(                                \
      ONNX_NAMESPACE::checker::ValidationError( \
          ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
0041 
0042 class CheckerContext final {
0043  public:
0044   int get_ir_version() const {
0045     return ir_version_;
0046   }
0047   void set_ir_version(int v) {
0048     ir_version_ = v;
0049   }
0050   const std::unordered_map<std::string, int>& get_opset_imports() const {
0051     return opset_imports_;
0052   }
0053   void set_opset_imports(std::unordered_map<std::string, int> imps) {
0054     opset_imports_ = std::move(imps);
0055   }
0056   bool is_main_graph() const {
0057     return is_main_graph_;
0058   }
0059   void set_is_main_graph(bool is_main_graph) {
0060     is_main_graph_ = is_main_graph;
0061   }
0062 
0063   void set_schema_registry(const ISchemaRegistry* schema_registry) {
0064     schema_registry_ = schema_registry;
0065   }
0066 
0067   const ISchemaRegistry* get_schema_registry() const {
0068     return schema_registry_;
0069   }
0070 
0071   void set_model_dir(const std::string& model_dir) {
0072     model_dir_ = model_dir;
0073   }
0074 
0075   std::string get_model_dir() const {
0076     return model_dir_;
0077   }
0078 
0079   bool skip_opset_compatibility_check() const {
0080     return skip_opset_compatibility_check_;
0081   }
0082 
0083   void set_skip_opset_compatibility_check(bool value) {
0084     skip_opset_compatibility_check_ = value;
0085   }
0086 
0087   bool check_custom_domain() const {
0088     return check_custom_domain_;
0089   }
0090 
0091   void set_check_custom_domain(bool value) {
0092     check_custom_domain_ = value;
0093   }
0094 
0095   explicit CheckerContext() : ir_version_(-1) {}
0096 
0097  private:
0098   int ir_version_;
0099   std::unordered_map<std::string, int> opset_imports_;
0100   bool is_main_graph_ = true;
0101   const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance();
0102   std::string model_dir_;
0103   bool skip_opset_compatibility_check_ = false;
0104   bool check_custom_domain_ = false;
0105 };
0106 
/// Names defined in one graph scope, optionally chained to the scope of the
/// parent graph so that outer-scope names can be looked up as well.
class LexicalScopeContext final {
 public:
  LexicalScopeContext() = default;

  // Construct an instance with the lexical scope from the parent graph to allow
  // lookup of names from that scope via this_or_ancestor_graph_has.
  // The caller must ensure parent_context remains valid for the entire lifetime
  // of the new instance. Alternatively, if that cannot be guaranteed, create an
  // instance with the default constructor and populate output_names with the
  // values from the parent scope so the values are copied instead.
  LexicalScopeContext(const LexicalScopeContext& parent_context) : parent_context_{&parent_context} {}

  // Makes parent_context the parent scope of this instance; output_names is
  // left untouched. The same lifetime requirement as the constructor applies.
  // Self-assignment is a no-op. Making an instance its own parent would send
  // this_or_ancestor_graph_has into an endless loop for any name not in
  // output_names.
  LexicalScopeContext& operator=(const LexicalScopeContext& parent_context) {
    if (&parent_context != this) {
      parent_context_ = &parent_context;
    }
    return *this;
  }

  // Records name as defined in this scope.
  void add(const std::string& name) {
    output_names.insert(name);
  }

  // True if name was added to this scope (ancestors are not consulted).
  bool this_graph_has(const std::string& name) const {
    return output_names.find(name) != output_names.cend();
  }

  // True if name is defined in this scope or any ancestor scope.
  // The parent chain is walked iteratively, so deep nesting does not grow
  // the call stack.
  bool this_or_ancestor_graph_has(const std::string& name) const {
    for (const LexicalScopeContext* scope = this; scope != nullptr; scope = scope->parent_context_) {
      if (scope->this_graph_has(name)) {
        return true;
      }
    }
    return false;
  }

  // public for backwards compatibility. please prefer the public interface of
  // this class over directly changing output_names
  std::unordered_set<std::string> output_names;

 private:
  // Non-owning pointer to the enclosing scope; nullptr for a root scope.
  const LexicalScopeContext* parent_context_{nullptr};
};
0142 
0143 using IR_VERSION_TYPE = decltype(Version::IR_VERSION);
0144 void check_value_info(const ValueInfoProto& value_info, const CheckerContext&);
0145 void check_tensor(const TensorProto& tensor, const CheckerContext&);
0146 void check_sparse_tensor(const SparseTensorProto& sparse_tensor, const CheckerContext&);
0147 void check_sequence(const SequenceProto& sequence, const CheckerContext&);
0148 void check_map(const MapProto& map, const CheckerContext&);
0149 void check_optional(const OptionalProto& opt, const CheckerContext&);
0150 void check_attribute(const AttributeProto& attr, const CheckerContext&, const LexicalScopeContext&);
0151 void check_node(const NodeProto& node, const CheckerContext&, const LexicalScopeContext&);
0152 void check_graph(const GraphProto& graph, const CheckerContext&, const LexicalScopeContext&);
0153 void check_function(const FunctionProto& function, const CheckerContext&, const LexicalScopeContext&);
0154 
0155 // Check schema compatibility for 2 opset versions for a given node.
0156 // Checks whether the schema for 2 versions is same, this is true when the opschema
0157 // does not change between versions.
0158 void check_opset_compatibility(
0159     const NodeProto& node,
0160     const CheckerContext& ctx,
0161     const std::unordered_map<std::string, int>& func_opset_imports,
0162     const std::unordered_map<std::string, int>& model_opset_imports);
0163 
0164 // Checks all model local functions present in ModelProto
0165 void check_model_local_functions(
0166     const ModelProto& model,
0167     const CheckerContext& ctx,
0168     const LexicalScopeContext& parent_lex);
0169 
0170 void check_model(
0171     const ModelProto& model,
0172     bool full_check = false,
0173     bool skip_opset_compatibility_check = false,
0174     bool check_custom_domain = false);
0175 void check_model(
0176     const std::string& model_path,
0177     bool full_check = false,
0178     bool skip_opset_compatibility_check = false,
0179     bool check_custom_domain = false);
0180 std::string resolve_external_data_location(
0181     const std::string& base_dir,
0182     const std::string& location,
0183     const std::string& tensor_name);
0184 bool check_is_experimental_op(const NodeProto& node);
0185 
0186 } // namespace checker
0187 } // namespace ONNX_NAMESPACE