File indexing completed on 2025-02-21 10:05:38
0001
0002
0003
0004
0005 #pragma once
0006
0007 #include <stdexcept>
0008 #include <string>
0009 #include <unordered_map>
0010 #include <unordered_set>
0011 #include <utility>
0012
0013 #include "onnx/defs/function.h"
0014 #include "onnx/defs/schema.h"
0015 #include "onnx/onnx-data_pb.h"
0016 #include "onnx/onnx-operators_pb.h"
0017 #include "onnx/onnx_pb.h"
0018 #include "onnx/string_utils.h"
0019
0020 namespace ONNX_NAMESPACE {
0021 namespace checker {
0022 class ValidationError final : public std::runtime_error {
0023 public:
0024 using std::runtime_error::runtime_error;
0025 const char* what() const noexcept override {
0026 if (!expanded_message_.empty()) {
0027 return expanded_message_.c_str();
0028 }
0029 return std::runtime_error::what();
0030 }
0031 void AppendContext(const std::string& context) {
0032 expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
0033 }
0034
0035 private:
0036 std::string expanded_message_;
0037 };
0038
// Reports a checker failure: concatenates all arguments with
// ONNX_NAMESPACE::MakeString, wraps the result in a checker::ValidationError
// and hands it to ONNX_THROW_EX (defined elsewhere; it decides how the error
// is raised, e.g. a C++ throw — confirm behavior for no-exception builds).
//
// Usage: fail_check("Field '", field_name, "' is required but missing.");
//
// NOTE(review): the expansion ends with its own ';', so an unbraced
// `if (cond) fail_check(...); else ...` does not compile. Brace such branches.
#define fail_check(...)                         \
  ONNX_THROW_EX(                                \
      ONNX_NAMESPACE::checker::ValidationError( \
          ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
0041
0042 class CheckerContext final {
0043 public:
0044 int get_ir_version() const {
0045 return ir_version_;
0046 }
0047 void set_ir_version(int v) {
0048 ir_version_ = v;
0049 }
0050 const std::unordered_map<std::string, int>& get_opset_imports() const {
0051 return opset_imports_;
0052 }
0053 void set_opset_imports(std::unordered_map<std::string, int> imps) {
0054 opset_imports_ = std::move(imps);
0055 }
0056 bool is_main_graph() const {
0057 return is_main_graph_;
0058 }
0059 void set_is_main_graph(bool is_main_graph) {
0060 is_main_graph_ = is_main_graph;
0061 }
0062
0063 void set_schema_registry(const ISchemaRegistry* schema_registry) {
0064 schema_registry_ = schema_registry;
0065 }
0066
0067 const ISchemaRegistry* get_schema_registry() const {
0068 return schema_registry_;
0069 }
0070
0071 void set_model_dir(const std::string& model_dir) {
0072 model_dir_ = model_dir;
0073 }
0074
0075 std::string get_model_dir() const {
0076 return model_dir_;
0077 }
0078
0079 bool skip_opset_compatibility_check() const {
0080 return skip_opset_compatibility_check_;
0081 }
0082
0083 void set_skip_opset_compatibility_check(bool value) {
0084 skip_opset_compatibility_check_ = value;
0085 }
0086
0087 explicit CheckerContext() : ir_version_(-1) {}
0088
0089 private:
0090 int ir_version_;
0091 std::unordered_map<std::string, int> opset_imports_;
0092 bool is_main_graph_ = true;
0093 const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance();
0094 std::string model_dir_;
0095 bool skip_opset_compatibility_check_ = false;
0096 };
0097
// Names visible in one graph's lexical scope, optionally chained to the scope
// of an enclosing graph.
//
// IMPORTANT: the copy constructor and copy assignment operator do NOT copy.
// Both make their argument the *parent* scope of this object and store only its
// address. The parent must therefore stay alive for as long as this object may
// be queried. Passing a temporary creates a dangling parent.
class LexicalScopeContext final {
 public:
  LexicalScopeContext() = default;

  // Creates a child scope that can look names up in `parent_context` through
  // this_or_ancestor_graph_has. The caller must keep `parent_context` alive for
  // the lifetime of the new instance. If that cannot be guaranteed,
  // default-construct instead and copy the parent's output_names, so the names
  // are copied rather than referenced.
  LexicalScopeContext(const LexicalScopeContext& parent_context) : parent_context_{&parent_context} {}

  // Re-parents this scope onto `parent_context`. Names already added to this
  // scope are kept. The same lifetime requirement as the constructor applies.
  // Self-assignment is a no-op: making a scope its own parent would make every
  // this_or_ancestor_graph_has miss recurse forever.
  LexicalScopeContext& operator=(const LexicalScopeContext& parent_context) {
    if (this != &parent_context) {
      parent_context_ = &parent_context;
    }
    return *this;
  }

  // Records `name` as defined in this scope.
  void add(const std::string& name) {
    output_names.insert(name);
  }

  // True if `name` was added to this scope (parents are not consulted).
  bool this_graph_has(const std::string& name) const {
    return output_names.find(name) != output_names.cend();
  }

  // True if `name` is defined in this scope or in any scope up the parent chain.
  bool this_or_ancestor_graph_has(const std::string& name) const {
    return this_graph_has(name) || (parent_context_ != nullptr && parent_context_->this_or_ancestor_graph_has(name));
  }

  // Public for backwards compatibility. Prefer add()/this_graph_has() over
  // manipulating this set directly.
  std::unordered_set<std::string> output_names;

 private:
  // Enclosing scope, or nullptr for a root scope. Not owned.
  const LexicalScopeContext* parent_context_{nullptr};
};
0133
0134 using IR_VERSION_TYPE = decltype(Version::IR_VERSION);
0135 void check_value_info(const ValueInfoProto& value_info, const CheckerContext&);
0136 void check_tensor(const TensorProto& tensor, const CheckerContext&);
0137 void check_sparse_tensor(const SparseTensorProto& sparse_tensor, const CheckerContext&);
0138 void check_sequence(const SequenceProto& sequence, const CheckerContext&);
0139 void check_map(const MapProto& map, const CheckerContext&);
0140 void check_optional(const OptionalProto& opt, const CheckerContext&);
0141 void check_attribute(const AttributeProto& attr, const CheckerContext&, const LexicalScopeContext&);
0142 void check_node(const NodeProto& node, const CheckerContext&, const LexicalScopeContext&);
0143 void check_graph(const GraphProto& graph, const CheckerContext&, const LexicalScopeContext&);
0144 void check_function(const FunctionProto& function, const CheckerContext&, const LexicalScopeContext&);
0145
0146
0147
0148
0149 void check_opset_compatibility(
0150 const NodeProto& node,
0151 const CheckerContext& ctx,
0152 const std::unordered_map<std::string, int>& func_opset_imports,
0153 const std::unordered_map<std::string, int>& model_opset_imports);
0154
0155
0156 void check_model_local_functions(
0157 const ModelProto& model,
0158 const CheckerContext& ctx,
0159 const LexicalScopeContext& parent_lex);
0160
0161 void check_model(const ModelProto& model, bool full_check = false, bool skip_opset_compatibility_check = false);
0162 void check_model(const std::string& model_path, bool full_check = false, bool skip_opset_compatibility_check = false);
0163
0164 bool check_is_experimental_op(const NodeProto& node);
0165
0166 }
0167 }