File indexing completed on 2025-11-03 09:52:25
0001
0002
0003
0004
0005 #pragma once
0006
0007 #include <climits>
0008 #include <cstring>
0009 #include <functional>
0010 #include <initializer_list>
0011 #include <iostream>
0012 #include <limits>
0013 #include <map>
0014 #include <memory>
0015 #include <ostream>
0016 #include <set>
0017 #include <string>
0018 #include <string_view>
0019 #include <tuple>
0020 #include <unordered_map>
0021 #include <unordered_set>
0022 #include <utility>
0023 #include <vector>
0024
0025 #include "onnx/common/common.h"
0026 #include "onnx/common/constants.h"
0027 #include "onnx/defs/shape_inference.h"
0028
0029 namespace ONNX_NAMESPACE {
0030
0031 struct FunctionBodyBuildContext {
0032 virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
0033 virtual bool hasInput(int inputIndex) const = 0;
0034 virtual bool hasOutput(int inputIndex) const = 0;
0035
0036
0037 virtual const TypeProto* getInputType(int inputIndex) const = 0;
0038 virtual ~FunctionBodyBuildContext() {}
0039 };
0040
0041 struct FunctionBodyBuildContextImpl : public FunctionBodyBuildContext {
0042
0043
0044
0045
0046
0047 FunctionBodyBuildContextImpl(const NodeProto& node_proto, const std::vector<TypeProto>& input_types = {})
0048 : node_proto_(node_proto), input_types_(input_types) {
0049 for (auto& attr : node_proto.attribute()) {
0050 attributesByName_[attr.name()] = &attr;
0051 }
0052 }
0053
0054 const AttributeProto* getAttribute(const std::string& name) const override {
0055 auto iter = attributesByName_.find(name);
0056 if (iter == attributesByName_.end()) {
0057 return nullptr;
0058 } else {
0059 return iter->second;
0060 }
0061 }
0062
0063 bool hasInput(int inputIndex) const override {
0064 if (inputIndex >= node_proto_.input_size())
0065 return false;
0066 return node_proto_.input(inputIndex) != "";
0067 }
0068
0069 bool hasOutput(int inputIndex) const override {
0070 if (inputIndex >= node_proto_.output_size())
0071 return false;
0072 return node_proto_.output(inputIndex) != "";
0073 }
0074
0075 const TypeProto* getInputType(int inputIndex) const override {
0076 if (inputIndex < 0)
0077 return nullptr;
0078 size_t j = static_cast<size_t>(inputIndex);
0079 if (j >= input_types_.size())
0080 return nullptr;
0081
0082 if (input_types_[j].value_case() == TypeProto::ValueCase::VALUE_NOT_SET)
0083 return nullptr;
0084 return &input_types_[j];
0085 }
0086
0087 std::unordered_map<std::string, const AttributeProto*> attributesByName_;
0088
0089 NodeProto node_proto_;
0090 std::vector<TypeProto> input_types_;
0091 };
0092
// Callable that inspects a (mutable) build context and returns a bool.
0093 using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>;
0094
// Forward declaration: the builder signature below refers to OpSchema before it is defined.
0095 class OpSchema;
// Callable that receives the build context and the schema and fills in the FunctionProto
// passed by reference. NOTE(review): the bool result presumably signals success — confirm.
0096 using ContextDependentFunctionBodyBuilder =
0097 std::function<bool(const FunctionBodyBuildContext&, const OpSchema&, FunctionProto&)>;
0098
0099 class SchemaError final : public std::runtime_error {
0100 public:
0101 using std::runtime_error::runtime_error;
0102
0103 SchemaError(const std::string& message) : std::runtime_error(message) {}
0104
0105 const char* what() const noexcept override {
0106 if (!expanded_message_.empty()) {
0107 return expanded_message_.c_str();
0108 }
0109 return std::runtime_error::what();
0110 }
0111
0112 void AppendContext(const std::string& context) {
0113 expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
0114 }
0115
0116 private:
0117 std::string expanded_message_;
0118 };
0119
// Builds a message from the arguments with MakeString and passes a SchemaError carrying it
// to ONNX_THROW_EX. The expansion already ends in ';', so `fail_schema(...);` yields an
// extra empty statement — take care when using it as the sole body of an unbraced if/else.
0120 #define fail_schema(...) ONNX_THROW_EX(ONNX_NAMESPACE::SchemaError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
0121
// Operator-set version numbers are plain ints.
0122 using OperatorSetVersion = int;
0123
// Unordered set of DataType values.
0124 using DataTypeSet = std::unordered_set<DataType>;
0125
0126
0127
// Maps a string key to a (DataTypeSet, string) pair.
// NOTE(review): presumably type-parameter name -> (allowed types, description) — confirm.
0128 using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>;
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148
0149 class OpSchema final {
0150 public:
// Sentinel (-1) used as the default `opset_version` / `requested_opset_version`
// argument by several member functions of this class.
0151 static constexpr int kUninitializedSinceVersion = -1;
0152
// Arity kind of a formal parameter, stored in one byte.
0153 enum FormalParameterOption : uint8_t {
0154
0155
// Default option used by the FormalParameter constructors.
0156 Single = 0,
0157
0158
// NOTE(review): presumably the argument may be omitted — confirm.
0159 Optional = 1,
0160
0161
0162
// NOTE(review): presumably accepts a variable number of arguments (see min_arity) — confirm.
0163 Variadic = 2,
0164 };
// Differentiability tag attached to a formal parameter, stored in one byte.
0165 enum DifferentiationCategory : uint8_t {
0166
0167
0168
0169
// Default category used by the FormalParameter constructors.
0170 Unknown = 0,
0171
0172
0173 Differentiable = 1,
0174
0175
0176 NonDifferentiable = 2
0177 };
0178
0179
0180
0181 class FormalParameter final {
0182 public:
0183
0184 FormalParameter() = default;
0185
0186 explicit FormalParameter(
0187 std::string name,
0188 DataTypeSet allowed_type_set,
0189 std::string type_str,
0190 const std::string& description,
0191 FormalParameterOption param_option = Single,
0192 bool is_homogeneous = true,
0193 int min_arity = 1,
0194 DifferentiationCategory differentiation_category = Unknown)
0195 : name_(std::move(name)),
0196 type_set_(std::move(allowed_type_set)),
0197 type_str_(std::move(type_str)),
0198 #ifndef __ONNX_NO_DOC_STRINGS
0199 description_(description),
0200 #endif
0201 param_option_(param_option),
0202 is_homogeneous_(is_homogeneous),
0203 min_arity_(min_arity),
0204 differentiation_category_(differentiation_category) {
0205 #ifdef __ONNX_NO_DOC_STRINGS
0206 ONNX_UNUSED_PARAMETER(description);
0207 #endif
0208 }
0209
0210 explicit FormalParameter(
0211 std::string name,
0212 const std::string& description,
0213 std::string type_str,
0214 FormalParameterOption param_option = Single,
0215 bool is_homogeneous = true,
0216 int min_arity = 1,
0217 DifferentiationCategory differentiation_category = Unknown)
0218 : name_(std::move(name)),
0219 type_str_(std::move(type_str)),
0220 #ifndef __ONNX_NO_DOC_STRINGS
0221 description_(description),
0222 #endif
0223 param_option_(param_option),
0224 is_homogeneous_(is_homogeneous),
0225 min_arity_(min_arity),
0226 differentiation_category_(differentiation_category) {
0227 #ifdef __ONNX_NO_DOC_STRINGS
0228 ONNX_UNUSED_PARAMETER(description);
0229 #endif
0230 }
0231
0232
0233 const std::string& GetName() const;
0234
0235
0236 const DataTypeSet& GetTypes() const;
0237
0238
0239 const std::string& GetTypeStr() const;
0240
0241
0242 const std::string& GetDescription() const;
0243
0244
0245 FormalParameterOption GetOption() const;
0246
0247
0248 bool GetIsHomogeneous() const;
0249
0250
0251 int GetMinArity() const;
0252
0253
0254 DifferentiationCategory GetDifferentiationCategory() const;
0255
0256 private:
0257 friend class OpSchema;
0258
0259 DataTypeSet& MutableTypes();
0260
0261
0262 std::string name_;
0263
0264
0265
0266 DataTypeSet type_set_;
0267
0268
0269
0270
0271 std::string type_str_;
0272
0273
0274 std::string description_;
0275
0276
0277 FormalParameterOption param_option_;
0278
0279
0280
0281 bool is_homogeneous_;
0282
0283
0284 int min_arity_;
0285
0286
0287
0288
0289 DifferentiationCategory differentiation_category_;
0290 };
0291
// Support level of an operator schema, stored in one byte.
// COMMON is the value every OpSchema constructor in this class assigns.
0292 enum class SupportType : uint8_t {
0293 COMMON,
0294 EXPERIMENTAL,
0295
0296 };
0297
// Default constructor: delegates with name "unknown", file "unknown" and line 0.
0298 OpSchema() : OpSchema("unknown", "unknown", 0) {}
// Stores the schema name and the source location (file, line) it was declared at;
// the support level starts as SupportType::COMMON.
0299 OpSchema(std::string name, std::string file, int line)
0300 : name_(std::move(name)), file_(std::move(file)), line_(line), support_(SupportType::COMMON) {}
0301
0302
0303
0304
// Returns the file name stored for this schema.
0305 const std::string& file() const {
0306 return file_;
0307 }
0308
0309
0310
0311
// Returns the line number stored for this schema.
0312 int line() const {
0313 return line_;
0314 }
0315
0316
0317
0318
// Returns the stored support level.
0319 SupportType support_level() const {
0320 return support_;
0321 }
0322
0323
0324
0325
0326 const char* doc() const {
0327 return doc_.empty() ? nullptr : doc_.c_str();
0328 }
0329
0330
// Defined outside this header section.
// NOTE(review): presumably checks the input/output types in the inference context against
// this schema — confirm against the definition.
0331 void CheckInputOutputType(struct InferenceContext&) const;
0332
0333
0334
0335
0336
// Defined outside this header section.
// NOTE(review): presumably validates `node` against this schema and reports violations via
// SchemaError — confirm against the definition.
0337 void Verify(const NodeProto& node) const;
0338
0339
0340
0341
0342
0343
0344
0345
0346
0347
0348
0349
0350
0351
0352
0353
0354
// Builder-style setter (returns *this by the file's convention; defined elsewhere).
// NOTE(review): presumably records the operator-set version this schema was introduced in.
0355 OpSchema& SinceVersion(OperatorSetVersion n);
0356
0357
0358
0359
0360
0361
// NOTE(review): presumably sets the deprecated flag read by Deprecated()/deprecated() — confirm.
0362 OpSchema& Deprecate();
0363
// Returns the stored deprecated flag.
0364 bool Deprecated() const {
0365 return deprecated_;
0366 }
0367
0368
0369
0370
// Takes the set of allowed input counts by value; defined elsewhere.
0371 OpSchema& NumInputs(std::set<int> allowed_input_nums);
0372
0373
0374
0375
// Takes the set of allowed output counts by value; defined elsewhere.
0376 OpSchema& NumOutputs(std::set<int> allowed_output_nums);
0377
0378
0379
0380
0381
// Registers the type-and-shape inference function; defined elsewhere.
0382 OpSchema& TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
0383 InferenceFunction GetTypeAndShapeInferenceFunction() const {
0384 return tensor_inference_function_ ? tensor_inference_function_ : dummyInferenceFunction;
0385 }
0386
// Registers the data propagation function; defined elsewhere.
0387 OpSchema& PartialDataPropagationFunction(DataPropagationFunction dataProgationFunction);
0388 DataPropagationFunction GetDataPropagationFunction() const {
0389 return data_propagation_function_ ? data_propagation_function_ : dummyDataPropagationFunction;
0390 }
0391
0392
// Builder-style setter for the support level; defined elsewhere.
0393 OpSchema& SetSupportLevel(SupportType supportType);
0394
0395
0396
0397 OpSchema& SetDoc(const char* doc) {
0398 #ifndef __ONNX_NO_DOC_STRINGS
0399 SetDoc(std::string(doc));
0400 #else
0401 ONNX_UNUSED_PARAMETER(doc);
0402 #endif
0403
0404 return *this;
0405 }
0406
0407 OpSchema& SetDoc(const std::string& doc) {
0408 #ifndef __ONNX_NO_DOC_STRINGS
0409 doc_ = doc;
0410 #else
0411 ONNX_UNUSED_PARAMETER(doc);
0412 #endif
0413 return *this;
0414 }
0415
0416
// Builder-style setters for the schema name; defined elsewhere.
0417 OpSchema& SetName(const char* name);
0418 OpSchema& SetName(std::string name);
0419
0420
// Builder-style setters for the (file, line) location; defined elsewhere.
0421 OpSchema& SetLocation(const char* file, int line);
0422 OpSchema& SetLocation(std::string file, int line);
0423
0424
0425
// Builder-style setters for the domain returned by domain(); defined elsewhere.
0426 OpSchema& SetDomain(const char* domain);
0427 OpSchema& SetDomain(std::string domain);
0428
// Description of one attribute accepted by an operator: name, description, type,
// whether it is required, and an optional default value.
0429 struct Attribute final {
// Attribute without a default value; `default_value` is value-initialized.
0430 Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)
0431 : name(std::move(name_)),
0432 description(std::move(description_)),
0433 type(type_),
0434 required(required_),
0435 default_value() {}
0436
// Attribute with a default value: the type is taken from the default and the
// attribute is marked as not required. `type` is declared before `default_value`,
// so the default's type is read before the value is moved from.
0437 Attribute(std::string name_, std::string description_, AttributeProto default_value_)
0438 : name(std::move(name_)),
0439 description(std::move(description_)),
0440 type(default_value_.type()),
0441 required(false),
0442 default_value(std::move(default_value_)) {}
0443
// name and description are const: they cannot be reassigned after construction.
0444 const std::string name;
0445 const std::string description;
0446 AttributeProto::AttributeType type;
0447 bool required;
0448 AttributeProto default_value;
0449 };
0450
// Registers a fully built Attribute; defined elsewhere.
0451 OpSchema& Attr(Attribute attr);
0452
0453
// Declares, for one value type, three Attr overloads that take a default value:
// std::string name/description with a scalar default, const char* name/description with a
// scalar default, and std::string name/description with a std::vector of defaults.
// The macro is not #undef'd within the visible part of this header.
0454 #define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName) \
0455 OpSchema& Attr( \
0456 std::string name, std::string description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
0457 \
0458 OpSchema& Attr( \
0459 const char* name, const char* description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
0460 OpSchema& Attr( \
0461 std::string name, \
0462 std::string description, \
0463 AttributeProto::AttributeType type, \
0464 const std::vector<TypeName>& defaultValue);
0465
// Overload sets for each supported default-value type.
0466 ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t)
0467 ATTR_SETTER_WITH_DEFAULT_VALUE(float)
0468 ATTR_SETTER_WITH_DEFAULT_VALUE(std::string)
0469 ATTR_SETTER_WITH_DEFAULT_VALUE(TensorProto)
0470 ATTR_SETTER_WITH_DEFAULT_VALUE(GraphProto)
0471 ATTR_SETTER_WITH_DEFAULT_VALUE(TypeProto)
0472
// Overload that takes an extra `conditionExplanation` string.
// NOTE(review): presumably for attributes whose requirement depends on a condition — confirm.
0473 OpSchema& Attr(
0474 std::string name,
0475 std::string description,
0476 std::string conditionExplanation,
0477 AttributeProto::AttributeType attr_type);
0478
0479
// Attribute without a default value; `required` defaults to true.
0480 OpSchema& Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required = true);
0481
0482
// Same as above, taking C strings.
0483 OpSchema& Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required = true);
0484
// NOTE(review): presumably disables validation of attributes not declared on the schema — confirm.
0485 OpSchema& AllowUncheckedAttributes();
0486
0487
// One type constraint: a type-parameter string, the list of allowed type strings,
// and a description. All three are moved in from the constructor arguments.
0488 struct TypeConstraintParam final {
0489 TypeConstraintParam(
0490 std::string type_param_str_,
0491 std::vector<std::string> allowed_type_strs_,
0492 std::string description_)
0493 : type_param_str(std::move(type_param_str_)),
0494 allowed_type_strs(std::move(allowed_type_strs_)),
0495 description(std::move(description_)) {}
0496
0497
// Type-parameter string.
0498 std::string type_param_str;
0499
0500
// Allowed type strings for this parameter.
0501 std::vector<std::string> allowed_type_strs;
0502
// Description text.
0503 std::string description;
0504 };
0505
0506
0507
0508
0509
0510
0511
0512
0513
0514
0515
0516
0517
0518
0519
0520
0521
0522
0523
0524
0525
0526
0527
0528
0529
0530
0531
// Input()/Output() declare the formal parameter at position `n`; all overloads are defined
// elsewhere. The optional trailing arguments default to Single, homogeneous, min_arity 1
// and Unknown differentiation category, matching the FormalParameter constructors.
0532 OpSchema& Input(int n, FormalParameter formal_parameter);
0533
0534 OpSchema& Input(
0535 int n,
0536 std::string name,
0537 const std::string& description,
0538 std::string type_str,
0539 FormalParameterOption param_option = Single,
0540 bool is_homogeneous = true,
0541 int min_arity = 1,
0542 DifferentiationCategory differentiation_category = Unknown);
0543
0544
// C-string flavour of the overload above.
0545 OpSchema& Input(
0546 int n,
0547 const char* name,
0548 const char* description,
0549 const char* type_str,
0550 FormalParameterOption param_option = Single,
0551 bool is_homogeneous = true,
0552 int min_arity = 1,
0553 DifferentiationCategory differentiation_category = Unknown);
0554
0555 OpSchema& Output(int n, FormalParameter formal_parameter);
0556
0557 OpSchema& Output(
0558 int n,
0559 std::string name,
0560 const std::string& description,
0561 std::string type_str,
0562 FormalParameterOption param_option = Single,
0563 bool is_homogeneous = true,
0564 int min_arity = 1,
0565 DifferentiationCategory differentiation_category = Unknown);
0566
0567
// C-string flavour of the overload above.
0568 OpSchema& Output(
0569 int n,
0570 const char* name,
0571 const char* description,
0572 const char* type_str,
0573 FormalParameterOption param_option = Single,
0574 bool is_homogeneous = true,
0575 int min_arity = 1,
0576 DifferentiationCategory differentiation_category = Unknown);
0577
// Declares a type constraint: type string, allowed type strings, description.
0578 OpSchema& TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description);
0579
0580
// C-string flavour of the overload above.
0581 OpSchema&
0582 TypeConstraint(const char* type_str, std::initializer_list<const char*> constraints, const char* description);
0583
0584
0585
0586
// Forwards to numeric_types_for_math_reduction_ir9(): both names return the same list.
0587 static const std::vector<std::string>& numeric_types_for_math_reduction_ir10() {
0588 return numeric_types_for_math_reduction_ir9();
0589 }
0590
// Type strings accepted by math reductions at the "ir9" level. The list is a
// function-local static built on first use; the same object is returned on every call.
static const std::vector<std::string>& numeric_types_for_math_reduction_ir9() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint32)", "tensor(uint64)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return kTypes;
}
0607
// Type strings accepted by math reductions at the "ir4" level (adds bfloat16 to
// the base list). Built once as a function-local static.
static const std::vector<std::string>& numeric_types_for_math_reduction_ir4() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint32)", "tensor(uint64)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"};
  return kTypes;
}
0620
// Base list of type strings accepted by math reductions. Built once as a
// function-local static.
static const std::vector<std::string>& numeric_types_for_math_reduction() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint32)", "tensor(uint64)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)"};
  return kTypes;
}
0632
// Numeric tensor type strings at the "ir10" level: the ir9 list plus uint4 and
// int4. Built once as a function-local static.
static const std::vector<std::string>& all_numeric_types_ir10() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)",
      "tensor(uint4)", "tensor(int4)"};
  return kTypes;
}
0655
// Numeric tensor type strings at the "ir9" level: the ir4 list plus the four
// float8 variants. Built once as a function-local static.
static const std::vector<std::string>& all_numeric_types_ir9() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return kTypes;
}
0676
// Numeric tensor type strings at the "ir4" level: the base list plus bfloat16.
// Built once as a function-local static.
static const std::vector<std::string>& all_numeric_types_ir4() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"};
  return kTypes;
}
0693
// Base list of numeric tensor type strings. Built once as a function-local static.
static const std::vector<std::string>& all_numeric_types() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)"};
  return kTypes;
}
0709
// Sequence-of-tensor type strings for the base numeric element types.
// Built once as a function-local static.
static const std::vector<std::string>& all_numeric_sequence_types() {
  static const std::vector<std::string> kTypes = {
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))"};
  return kTypes;
}
0725
// Base list of tensor type strings (numeric, string, bool and complex).
// Built once as a function-local static.
static const std::vector<std::string>& all_tensor_types() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(string)",
      "tensor(bool)", "tensor(complex64)", "tensor(complex128)"};
  return kTypes;
}
0745
// Tensor type strings at the "ir4" level: the base list with bfloat16 inserted
// before float16. Built once as a function-local static.
static const std::vector<std::string>& all_tensor_types_ir4() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(string)", "tensor(bool)", "tensor(complex64)", "tensor(complex128)"};
  return kTypes;
}
0766
// Integer, floating-point (including bfloat16) and bool tensor type strings;
// no string or complex types. Built once as a function-local static.
static const std::vector<std::string>& all_non_complex_numeric_types_plus_bool_ir4() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(bool)"};
  return kTypes;
}
0784
// Floating-point tensor type strings at the "ir4" level.
// Built once as a function-local static.
static const std::vector<std::string>& all_float_types_ir4() {
  static const std::vector<std::string> kTypes = {
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)"};
  return kTypes;
}
0790
// The "ir4" floating-point tensor type strings followed by int8 and uint8.
// Built once as a function-local static.
static const std::vector<std::string>& all_float_types_plus_Xint8_ir4() {
  static const std::vector<std::string> kTypes = {
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)",
      "tensor(int8)",
      "tensor(uint8)"};
  return kTypes;
}
0796
// Floating-point tensor type strings at the "ir9" level: the ir4 list plus the
// four float8 variants. Built once as a function-local static.
static const std::vector<std::string>& all_float_types_ir9() {
  static const std::vector<std::string> kTypes = {
      "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return kTypes;
}
0809
// Forwards to all_float_types_ir9(): both names return the same list.
0810 static const std::vector<std::string>& all_float_types_ir10() {
0811 return all_float_types_ir9();
0812 }
0813
// Tensor type strings at the "ir9" level: the ir4 list plus the four float8
// variants. Built once as a function-local static.
static const std::vector<std::string>& all_tensor_types_ir9() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)",
      "tensor(int16)", "tensor(int32)", "tensor(int64)", "tensor(bfloat16)", "tensor(float16)",
      "tensor(float)", "tensor(double)", "tensor(string)", "tensor(bool)", "tensor(complex64)",
      "tensor(complex128)", "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)"};
  return kTypes;
}
0823
// Tensor type strings at the "ir10" level: the ir9 list plus uint4 and int4.
// Built once as a function-local static.
static const std::vector<std::string>& all_tensor_types_ir10() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)",
      "tensor(int16)", "tensor(int32)", "tensor(int64)", "tensor(bfloat16)", "tensor(float16)",
      "tensor(float)", "tensor(double)", "tensor(string)", "tensor(bool)", "tensor(complex64)",
      "tensor(complex128)", "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)",
      "tensor(float8e5m2fnuz)", "tensor(uint4)", "tensor(int4)"};
  return kTypes;
}
0836
// The "ir10" tensor type strings without the two complex types.
// Built once as a function-local static.
static const std::vector<std::string>& all_non_complex_tensor_types_ir10() {
  static const std::vector<std::string> kTypes = {
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)",
      "tensor(int16)", "tensor(int32)", "tensor(int64)", "tensor(bfloat16)", "tensor(float16)",
      "tensor(float)", "tensor(double)", "tensor(string)", "tensor(bool)", "tensor(float8e4m3fn)",
      "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)", "tensor(uint4)",
      "tensor(int4)"};
  return kTypes;
}
0846
// Sequence-of-tensor type strings for the base tensor element types.
// Built once as a function-local static.
static const std::vector<std::string>& all_tensor_sequence_types() {
  static const std::vector<std::string> kTypes = {
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))", "seq(tensor(string))",
      "seq(tensor(bool))", "seq(tensor(complex64))", "seq(tensor(complex128))"};
  return kTypes;
}
0866
// Sequence-of-tensor type strings at the "ir4" level (adds bfloat16).
// Built once as a function-local static.
static const std::vector<std::string>& all_tensor_sequence_types_ir4() {
  static const std::vector<std::string> kTypes = {
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(bfloat16))", "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))",
      "seq(tensor(string))", "seq(tensor(bool))", "seq(tensor(complex64))", "seq(tensor(complex128))"};
  return kTypes;
}
0887
// Sequence-of-tensor type strings at the "ir9" level (adds the float8 variants).
// Built once as a function-local static.
static const std::vector<std::string>& all_tensor_sequence_types_ir9() {
  static const std::vector<std::string> kTypes = {
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(bfloat16))", "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))",
      "seq(tensor(string))", "seq(tensor(bool))", "seq(tensor(complex64))", "seq(tensor(complex128))",
      "seq(tensor(float8e4m3fn))", "seq(tensor(float8e4m3fnuz))", "seq(tensor(float8e5m2))",
      "seq(tensor(float8e5m2fnuz))"};
  return kTypes;
}
0899
// Sequence-of-tensor type strings at the "ir10" level (adds uint4 and int4).
// Built once as a function-local static.
static const std::vector<std::string>& all_tensor_sequence_types_ir10() {
  static const std::vector<std::string> kTypes = {
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(bfloat16))", "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))",
      "seq(tensor(string))", "seq(tensor(bool))", "seq(tensor(complex64))", "seq(tensor(complex128))",
      "seq(tensor(float8e4m3fn))", "seq(tensor(float8e4m3fnuz))", "seq(tensor(float8e5m2))",
      "seq(tensor(float8e5m2fnuz))", "seq(tensor(uint4))", "seq(tensor(int4))"};
  return kTypes;
}
0912
// Optional type strings: optional sequences of tensors first, then optional
// tensors, over the base element types. Built once as a function-local static.
static const std::vector<std::string>& all_optional_types() {
  static const std::vector<std::string> kTypes = {
      "optional(seq(tensor(uint8)))", "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))", "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))", "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))", "optional(seq(tensor(int64)))",
      "optional(seq(tensor(float16)))", "optional(seq(tensor(float)))",
      "optional(seq(tensor(double)))", "optional(seq(tensor(string)))",
      "optional(seq(tensor(bool)))", "optional(seq(tensor(complex64)))",
      "optional(seq(tensor(complex128)))",
      "optional(tensor(uint8))", "optional(tensor(uint16))",
      "optional(tensor(uint32))", "optional(tensor(uint64))",
      "optional(tensor(int8))", "optional(tensor(int16))",
      "optional(tensor(int32))", "optional(tensor(int64))",
      "optional(tensor(float16))", "optional(tensor(float))",
      "optional(tensor(double))", "optional(tensor(string))",
      "optional(tensor(bool))", "optional(tensor(complex64))",
      "optional(tensor(complex128))"};
  return kTypes;
}
0927
// Optional type strings at the "ir4" level (adds bfloat16 to both the sequence
// and the tensor group). Built once as a function-local static.
static const std::vector<std::string>& all_optional_types_ir4() {
  static const std::vector<std::string> kTypes = {
      "optional(seq(tensor(uint8)))", "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))", "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))", "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))", "optional(seq(tensor(int64)))",
      "optional(seq(tensor(bfloat16)))", "optional(seq(tensor(float16)))",
      "optional(seq(tensor(float)))", "optional(seq(tensor(double)))",
      "optional(seq(tensor(string)))", "optional(seq(tensor(bool)))",
      "optional(seq(tensor(complex64)))", "optional(seq(tensor(complex128)))",
      "optional(tensor(uint8))", "optional(tensor(uint16))",
      "optional(tensor(uint32))", "optional(tensor(uint64))",
      "optional(tensor(int8))", "optional(tensor(int16))",
      "optional(tensor(int32))", "optional(tensor(int64))",
      "optional(tensor(bfloat16))", "optional(tensor(float16))",
      "optional(tensor(float))", "optional(tensor(double))",
      "optional(tensor(string))", "optional(tensor(bool))",
      "optional(tensor(complex64))", "optional(tensor(complex128))"};
  return kTypes;
}
0943
// Optional type strings at the "ir9" level: the ir4 list plus optional tensors
// of the four float8 variants (no float8 sequence entries are listed).
// Built once as a function-local static.
static const std::vector<std::string>& all_optional_types_ir9() {
  static const std::vector<std::string> kTypes = {
      "optional(seq(tensor(uint8)))", "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))", "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))", "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))", "optional(seq(tensor(int64)))",
      "optional(seq(tensor(bfloat16)))", "optional(seq(tensor(float16)))",
      "optional(seq(tensor(float)))", "optional(seq(tensor(double)))",
      "optional(seq(tensor(string)))", "optional(seq(tensor(bool)))",
      "optional(seq(tensor(complex64)))", "optional(seq(tensor(complex128)))",
      "optional(tensor(uint8))", "optional(tensor(uint16))",
      "optional(tensor(uint32))", "optional(tensor(uint64))",
      "optional(tensor(int8))", "optional(tensor(int16))",
      "optional(tensor(int32))", "optional(tensor(int64))",
      "optional(tensor(bfloat16))", "optional(tensor(float16))",
      "optional(tensor(float))", "optional(tensor(double))",
      "optional(tensor(string))", "optional(tensor(bool))",
      "optional(tensor(complex64))", "optional(tensor(complex128))",
      "optional(tensor(float8e4m3fn))", "optional(tensor(float8e4m3fnuz))",
      "optional(tensor(float8e5m2))", "optional(tensor(float8e5m2fnuz))"};
  return kTypes;
}
0960
// Optional type strings at the "ir10" level: the ir9 list plus optional tensors
// of uint4 and int4. Built once as a function-local static.
static const std::vector<std::string>& all_optional_types_ir10() {
  static const std::vector<std::string> kTypes = {
      "optional(seq(tensor(uint8)))", "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))", "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))", "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))", "optional(seq(tensor(int64)))",
      "optional(seq(tensor(bfloat16)))", "optional(seq(tensor(float16)))",
      "optional(seq(tensor(float)))", "optional(seq(tensor(double)))",
      "optional(seq(tensor(string)))", "optional(seq(tensor(bool)))",
      "optional(seq(tensor(complex64)))", "optional(seq(tensor(complex128)))",
      "optional(tensor(uint8))", "optional(tensor(uint16))",
      "optional(tensor(uint32))", "optional(tensor(uint64))",
      "optional(tensor(int8))", "optional(tensor(int16))",
      "optional(tensor(int32))", "optional(tensor(int64))",
      "optional(tensor(bfloat16))", "optional(tensor(float16))",
      "optional(tensor(float))", "optional(tensor(double))",
      "optional(tensor(string))", "optional(tensor(bool))",
      "optional(tensor(complex64))", "optional(tensor(complex128))",
      "optional(tensor(float8e4m3fn))", "optional(tensor(float8e4m3fnuz))",
      "optional(tensor(float8e5m2))", "optional(tensor(float8e5m2fnuz))",
      "optional(tensor(uint4))", "optional(tensor(int4))"};
  return kTypes;
}
0978
0979
0980
// Takes a callable that receives a mutable OpSchema; defined elsewhere.
// NOTE(review): presumably invokes `populator(*this)` when it is non-empty — confirm.
0981 OpSchema& FillUsing(const std::function<void(OpSchema&)>& populator);
0982
// Stream output for a schema; declared as a friend so it can read private members.
0983 friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema);
0984
// Read-only accessors. Each returns the corresponding stored member unchanged.
0985 const std::string& domain() const {
0986 return domain_;
0987 }
0988
// Attributes keyed by name (ordered map).
0989 const std::map<std::string, Attribute>& attributes() const {
0990 return attributes_;
0991 }
0992
0993
// Formal input parameters, in positional order.
0994 const std::vector<FormalParameter>& inputs() const {
0995 return inputs_;
0996 }
0997
0998
// Formal output parameters, in positional order.
0999 const std::vector<FormalParameter>& outputs() const {
1000 return outputs_;
1001 }
1002
1003 const std::vector<TypeConstraintParam>& typeConstraintParams() const {
1004 return type_constraint_params_;
1005 }
1006
1007 const TypeConstraintMap& typeConstraintMap() const {
1008 return type_constraints_;
1009 }
1010
1011 const std::string& Name() const {
1012 return name_;
1013 }
1014
// SinceVersion() and since_version() return the same member.
1015 OperatorSetVersion SinceVersion() const {
1016 return since_version_;
1017 }
1018
1019 int since_version() const {
1020 return since_version_;
1021 }
1022
// Same value as Deprecated().
1023 bool deprecated() const {
1024 return deprecated_;
1025 }
1026
// Stored bounds on the number of inputs and outputs.
1027 int min_input() const {
1028 return min_input_;
1029 }
1030 int max_input() const {
1031 return max_input_;
1032 }
1033 int min_output() const {
1034 return min_output_;
1035 }
1036 int max_output() const {
1037 return max_output_;
1038 }
1039
1040 bool has_type_and_shape_inference_function() const {
1041 return tensor_inference_function_ ? true : false;
1042 }
1043
1044 bool has_data_propagation_function() const {
1045 return data_propagation_function_ ? true : false;
1046 }
1047
1048 std::vector<int> function_opset_versions() const {
1049 std::vector<int> opset_versions;
1050 std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it = opset_version_to_function_body_.cbegin();
1051 for (; it != opset_version_to_function_body_.cend(); ++it) {
1052 opset_versions.push_back(it->first);
1053 }
1054 return opset_versions;
1055 }
1056
1057 bool HasFunction() const {
1058 return !opset_version_to_function_body_.empty();
1059 }
1060
// Overloads that attach a function body to this schema (declarations only).
// `opset_version` defaults to kUninitializedSinceVersion in all three.
// NOTE(review): presumably each call stores the body under `opset_version` in
// opset_version_to_function_body_ and returns *this — confirm in the .cc.

// Body given as a list of nodes.
1061 OpSchema& FunctionBody(const std::vector<NodeProto>& func_nodes, int opset_version = kUninitializedSinceVersion);
1062
// Body given as a list of nodes plus an explicit list of opset ids.
1063 OpSchema& FunctionBody(
1064 const std::vector<NodeProto>& func_nodes,
1065 const std::vector<OperatorSetIdProto>& opsets,
1066 int opset_version = kUninitializedSinceVersion);
1067
// Body given as a C string (NOTE(review): presumably textual ONNX syntax that
// is parsed by the definition — confirm).
1068 OpSchema& FunctionBody(const char* func_body, int opset_version = kUninitializedSinceVersion);
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
// Returns a pointer to a stored FunctionProto for `requested_opset_version`
// (default: kUninitializedSinceVersion). Declaration only.
// NOTE(review): the selection rule, the meaning of `validate`, and whether
// nullptr is returned when nothing matches are defined out of line — confirm.
1087 const FunctionProto* GetFunction(
1088 int requested_opset_version = OpSchema::kUninitializedSinceVersion,
1089 bool validate = false) const;
1090
1091 std::vector<int> context_dependent_function_opset_versions() const {
1092 std::vector<int> opset_versions;
1093 std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it = opset_version_to_function_builder_.cbegin();
1094 for (; it != opset_version_to_function_builder_.cend(); ++it) {
1095 opset_versions.push_back(it->first);
1096 }
1097 return opset_versions;
1098 }
1099
1100 bool HasContextDependentFunction() const {
1101 return !opset_version_to_function_builder_.empty();
1102 }
1103
1104 bool HasContextDependentFunctionWithOpsetVersion(int opset_version) const {
1105 return opset_version_to_function_builder_.find(opset_version) != opset_version_to_function_builder_.end();
1106 }
1107
// Attaches a context-dependent function body builder (declaration only).
// `opset_version` defaults to kUninitializedSinceVersion.
// NOTE(review): presumably stores the builder in
// opset_version_to_function_builder_ and returns *this — confirm in the .cc.
1108 OpSchema& SetContextDependentFunctionBodyBuilder(
1109 ContextDependentFunctionBodyBuilder,
1110 int opset_version = kUninitializedSinceVersion);
1111
// Builds a function body for the given build context into `function_proto`
// (declaration only). `requested_opset_version` defaults to
// kUninitializedSinceVersion.
// NOTE(review): the bool result presumably reports whether a body was
// produced — confirm against the definition.
1112 bool BuildContextDependentFunction(
1113 const FunctionBodyBuildContext& ctx,
1114 FunctionProto& function_proto,
1115 int requested_opset_version = OpSchema::kUninitializedSinceVersion) const;
1116
1117
1118
1119
1120
// Finalizes the schema (declaration only). OpSchemaRegisterImpl calls this
// before the schema is inserted into the registry.
// NOTE(review): what exactly is validated or derived is defined out of line.
1121 void Finalize();
1122
1123
// Fills in `function_body` from this schema (declaration only).
// NOTE(review): presumably populates name/domain/inputs/outputs/attributes of
// the FunctionProto — confirm in the definition.
1124 void BuildFunction(FunctionProto& function_body) const;
1125
1126 private:
// Private helper operating on a list of formal parameters passed by pointer
// (declaration only).
// NOTE(review): presumably resolves each parameter's type string into its
// allowed type set — confirm in the definition.
1127 void ParseAndSetTypes(
1128 std::vector<OpSchema::FormalParameter>* formalParameters);
// Private validation helper for a function body (declaration only).
// `updated_ops` is an optional out-parameter (defaults to nullptr).
// NOTE(review): "Functon" is a typo for "Function" in the identifier; renaming
// would need the out-of-line definition and callers updated together.
// NOTE(review): presumably checks that ops referenced by `function` have not
// changed between `function_since_version` and `requested_opset_version`.
1129 bool ValidateReferencedOpsInFuncton(
1130 const FunctionProto* function,
1131 int requested_opset_version,
1132 int function_since_version,
1133 std::set<std::string>* updated_ops = nullptr) const;
// Private helper that adjusts `function_proto` for `opset_version`
// (declaration only).
// NOTE(review): presumably rewrites the opset_import entry — confirm.
1134 void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const;
1135
1136
1137
1138
1139
1140
// Builds a message prefix string from `node_name` (declaration only).
// NOTE(review): presumably shared by the Verify* helpers for error text.
1141 std::string VerifyFailPrefix(std::string_view node_name) const;
1142
1143
1144
1145
1146
1147
// Checks an input count for a node; `node_name` defaults to an empty view
// (declaration only).
// NOTE(review): presumably fails when `input_num` is outside
// [min_input_, max_input_] or rejected by num_inputs_allowed_ — confirm.
1148 void VerifyInputNum(int input_num, std::string_view node_name = "") const;
1149
1150
1151
1152
1153
1154
// Checks an output count for a node; `node_name` defaults to an empty view
// (declaration only).
// NOTE(review): presumably fails when `output_num` is outside
// [min_output_, max_output_] or rejected by num_outputs_allowed_ — confirm.
1155 void VerifyOutputNum(int output_num, std::string_view node_name = "") const;
1156
// ---- Identity and documentation ----
// name_ is what Name() returns. file_ and line_ (below) are presumably the
// definition site recorded via SetLocation(__FILE__, __LINE__) — confirm.
1157 std::string name_;
1158 std::string file_;
1159 std::string doc_;
1160
// ---- Signature ----
// domain_ starts out as ONNX_DOMAIN; returned by domain().
1161 std::string domain_ = ONNX_DOMAIN;
// Returned by attributes(); keyed by attribute name.
1162 std::map<std::string, Attribute> attributes_{};
1163 bool allows_unchecked_attributes_ = false;
// Returned by inputs() / outputs().
1164 std::vector<FormalParameter> inputs_;
1165 std::vector<FormalParameter> outputs_;
// Returned by typeConstraintParams() / typeConstraintMap().
1166 std::vector<TypeConstraintParam> type_constraint_params_;
1167 TypeConstraintMap type_constraints_;
1168 int line_ = 0;
// NOTE(review): no in-class initializer here; verify every constructor sets it.
1169 SupportType support_;
// Arity bounds returned by min_input()/max_input()/min_output()/max_output().
1170 int min_input_ = 0;
1171 int max_input_ = 0;
1172 int min_output_ = 0;
1173 int max_output_ = 0;
1174
// ---- Versioning and behavior hooks ----
// Returned by SinceVersion()/since_version(); starts uninitialized.
1175 OperatorSetVersion since_version_ = kUninitializedSinceVersion;
// Value-initialized to false; returned by deprecated().
1176 bool deprecated_{};
// Arity predicates; by default every count is accepted.
1177 std::function<bool(int)> num_inputs_allowed_ = [](int) { return true; };
1178 std::function<bool(int)> num_outputs_allowed_ = [](int) { return true; };
// Empty (false) until set; tested by has_type_and_shape_inference_function()
// and has_data_propagation_function() respectively.
1179 InferenceFunction tensor_inference_function_;
1180 DataPropagationFunction data_propagation_function_;
1181
// ---- Function bodies, keyed by opset version ----
// Static bodies (see function_opset_versions()/HasFunction()).
1182 std::map<int, std::shared_ptr<FunctionProto>> opset_version_to_function_body_;
// Context-dependent builders (see HasContextDependentFunction()).
1183 std::map<int, ContextDependentFunctionBodyBuilder> opset_version_to_function_builder_;
1184 };
1185
1186
1187
// Three-level registry table: operator name -> domain -> opset version ->
// schema. The innermost level is an ordered std::map, so versions iterate in
// ascending order; the two outer levels are unordered.
1188 using OpName_Domain_Version_Schema_Map =
1189 std::unordered_map<std::string, std::unordered_map<std::string, std::map<OperatorSetVersion, OpSchema>>>;
1190
1191 class ISchemaRegistry {
1192 public:
1193 virtual ~ISchemaRegistry() = default;
1194
1195 virtual const OpSchema*
1196 GetSchema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) const = 0;
1197 };
1198
1199
1200
1201
1202 class OpSchemaRegistry final : public ISchemaRegistry {
1203 public:
1204
1205
1206 class DomainToVersionRange final {
1207 public:
1208 DomainToVersionRange() {
1209
1210
1211
1212 map_[ONNX_DOMAIN] = std::make_pair(1, 22);
1213 map_[AI_ONNX_ML_DOMAIN] = std::make_pair(1, 5);
1214 map_[AI_ONNX_TRAINING_DOMAIN] = std::make_pair(1, 1);
1215
1216
1217
1218 map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 1);
1219
1220
1221
1222 last_release_version_map_[ONNX_DOMAIN] = 22;
1223 last_release_version_map_[AI_ONNX_ML_DOMAIN] = 5;
1224 last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1;
1225 last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1;
1226 }
1227
1228 const std::unordered_map<std::string, std::pair<int, int>>& Map() const {
1229 return map_;
1230 }
1231
1232 const std::unordered_map<std::string, int>& LastReleaseVersionMap() const {
1233 return last_release_version_map_;
1234 }
1235
1236
1237
1238
1239
1240
1241
1242
1243 void
1244 AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) {
1245 std::lock_guard<std::mutex> lock(mutex_);
1246 if (map_.count(domain) != 0) {
1247 std::stringstream err;
1248 err << "Trying to add a domain to DomainToVersion map, but the domain is already exist with version range ("
1249 << map_.at(domain).first << ", " << map_.at(domain).second << "). domain: \"" << domain << "\""
1250 << std::endl;
1251 fail_schema(err.str());
1252 }
1253 if (last_release_version_map_.count(domain) != 0) {
1254 std::stringstream err;
1255 err << "Trying to add a domain to LastReleaseVersion map, but the domain is already exist with last version: "
1256 << last_release_version_map_.at(domain) << ", domain: \"" << domain << "\"" << std::endl;
1257 fail_schema(err.str());
1258 }
1259 map_[domain] = std::make_pair(min_version, max_version);
1260
1261
1262 if (last_release_version == -1) {
1263 last_release_version = max_version;
1264 }
1265 last_release_version_map_[domain] = last_release_version;
1266 }
1267
1268 void
1269 UpdateDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) {
1270 std::lock_guard<std::mutex> lock(mutex_);
1271 if (map_.count(domain) == 0) {
1272 std::stringstream err;
1273 err << "Trying to update a domain in DomainToVersion map, but the domain has not been add. domain: \"" << domain
1274 << "\"" << std::endl;
1275 fail_schema(err.str());
1276 }
1277 if (last_release_version_map_.count(domain) == 0) {
1278 std::stringstream err;
1279 err << "Trying to update a domain in LastReleaseVersion map, but the domain has not been add. domain: \""
1280 << domain << "\"" << std::endl;
1281 fail_schema(err.str());
1282 }
1283 map_.at(domain).first = min_version;
1284 map_.at(domain).second = max_version;
1285
1286 if (last_release_version == -1) {
1287 last_release_version = max_version;
1288 }
1289 last_release_version_map_.at(domain) = last_release_version;
1290 }
1291
1292 static DomainToVersionRange& Instance();
1293
1294 private:
1295
1296 std::unordered_map<std::string, std::pair<int, int>> map_;
1297
1298
1299
1300
1301 std::unordered_map<std::string, int> last_release_version_map_;
1302
1303 std::mutex mutex_;
1304 };
1305
1306 class OpSchemaRegisterOnce final {
1307 public:
1308
1309 OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1310 OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema);
1311 }
1312 static void
1313 OpSchemaRegisterNoExcept(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1314 ONNX_TRY {
1315 OpSchemaRegisterImpl(std::move(op_schema), opset_version_to_load, fail_duplicate_schema);
1316 }
1317 ONNX_CATCH(const std::exception& e) {
1318 ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; });
1319 }
1320 }
1321 static void
1322 OpSchemaRegisterImpl(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1323 op_schema.Finalize();
1324 auto& m = GetMapWithoutEnsuringRegistration();
1325 auto& op_name = op_schema.Name();
1326 auto& op_domain = op_schema.domain();
1327 auto& schema_ver_map = m[op_name][op_domain];
1328 auto ver = op_schema.SinceVersion();
1329 if (OpSchema::kUninitializedSinceVersion == ver) {
1330 op_schema.SinceVersion(1);
1331 ver = op_schema.SinceVersion();
1332 }
1333
1334
1335 if (schema_ver_map.count(ver)) {
1336 if (fail_duplicate_schema) {
1337 const auto& schema = schema_ver_map[ver];
1338 std::stringstream err;
1339 err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1340 << ") from file " << op_schema.file() << " line " << op_schema.line()
1341 << ", but it is already registered from file " << schema.file() << " line " << schema.line() << std::endl;
1342 fail_schema(err.str());
1343 }
1344 return;
1345 }
1346
1347 if (opset_version_to_load != 0) {
1348
1349 if (ver > opset_version_to_load) {
1350 return;
1351 }
1352
1353
1354 if (!schema_ver_map.empty()) {
1355 int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load);
1356 if (max_registered_ver_le_target >= ver) {
1357 return;
1358 }
1359 }
1360 }
1361
1362 CheckDomainAndVersionToRegister(op_schema, op_name, op_domain);
1363 schema_ver_map.insert(std::pair<int, OpSchema&&>(ver, std::move(op_schema)));
1364 }
1365
1366 private:
1367
1368 static int GetMaxRegisteredVerWithinTarget(const std::map<OperatorSetVersion, OpSchema>& m, int target_ver) {
1369
1370
1371 for (auto&& it = m.rbegin(); it != m.rend(); it++) {
1372 const auto& registered_ver = it->first;
1373 if (registered_ver <= target_ver) {
1374 return registered_ver;
1375 }
1376 }
1377 return -1;
1378 }
1379
1380 static void CheckDomainAndVersionToRegister(
1381 const OpSchema& op_schema,
1382 const std::string& op_name,
1383 const std::string& op_domain) {
1384 auto ver_range_map = DomainToVersionRange::Instance().Map();
1385 auto ver_range_it = ver_range_map.find(op_domain);
1386 auto ver = op_schema.SinceVersion();
1387
1388 if (ver_range_it == ver_range_map.end()) {
1389 std::stringstream err;
1390 err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1391 << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its domain is not"
1392 << " known by the checker." << std::endl;
1393
1394 fail_schema(err.str());
1395 }
1396 auto lower_bound_incl = ver_range_it->second.first;
1397 auto upper_bound_incl = ver_range_it->second.second;
1398 if (!(lower_bound_incl <= ver && upper_bound_incl >= ver)) {
1399 std::stringstream err;
1400 err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1401 << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its version is not "
1402 << "in the inclusive range [" << lower_bound_incl << ", " << upper_bound_incl
1403 << "] (usually, this means you "
1404 << "bumped the operator version but "
1405 << "forgot to update the version range in DomainToVersionRange "
1406 << "in onnx/defs/schema.h)." << std::endl;
1407 fail_schema(err.str());
1408 }
1409 }
1410 };
1411
1412 static void
1413 OpSchemaDeregister(const std::string& op_type, const int version, const std::string& domain = ONNX_DOMAIN) {
1414 auto& schema_map = GetMapWithoutEnsuringRegistration();
1415 if (schema_map.count(op_type) && schema_map[op_type].count(domain) && schema_map[op_type][domain].count(version)) {
1416 schema_map[op_type][domain].erase(version);
1417 } else {
1418 std::stringstream err;
1419 err << "Attempting to deregister an unregistered schema with name: " << op_type << " domain: " << domain
1420 << " version: " << version << std::endl;
1421 fail_schema(err.str());
1422 }
1423 }
1424
1425
1426
1427 static void OpSchemaDeregisterAll(const std::string& domain = ONNX_DOMAIN) {
1428 auto& schema_map = GetMapWithoutEnsuringRegistration();
1429
1430
1431 for (auto&& schema_map_pair : schema_map) {
1432 auto& domain_map = schema_map_pair.second;
1433 if (domain_map.count(domain)) {
1434 auto& opset_version_schema_map = domain_map[domain];
1435
1436 opset_version_schema_map.clear();
1437 domain_map.erase(domain);
1438 }
1439 }
1440 }
1441
1442
1443
1444 static const OpSchema* Schema(const std::string& key, const std::string& domain = ONNX_DOMAIN) {
1445 auto& m = map();
1446 if (m.count(key) && m[key].count(domain)) {
1447 const auto& schema_ver_map = m[key][domain];
1448 if (!schema_ver_map.empty()) {
1449 return &m[key][domain].rbegin()->second;
1450 }
1451 }
1452 return nullptr;
1453 }
1454
1455
1456
1457
1458 static const OpSchema*
1459 Schema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) {
1460 auto& m = map();
1461 if (m.count(key) && m[key].count(domain)) {
1462 const auto& schema_ver_map = m[key][domain];
1463 if (!schema_ver_map.empty()) {
1464 auto pos = m[key][domain].lower_bound(maxInclusiveVersion);
1465 if (m[key][domain].begin() == pos && pos->first > maxInclusiveVersion) {
1466
1467 return nullptr;
1468 }
1469 if (m[key][domain].end() == pos || pos->first > maxInclusiveVersion) {
1470
1471
1472 pos--;
1473 }
1474
1475
1476 return &(pos->second);
1477 }
1478 }
1479 return nullptr;
1480 }
1481
// Accessor for a registry object, returned by pointer (defined out of line).
// NOTE(review): presumably a process-wide singleton — confirm in the .cc.
1482 static OpSchemaRegistry* Instance();
1483
1484 const OpSchema* GetSchema(
1485 const std::string& key,
1486 const int maxInclusiveVersion,
1487 const std::string& domain = ONNX_DOMAIN) const override {
1488 return Schema(key, maxInclusiveVersion, domain);
1489 }
1490 static void SetLoadedSchemaVersion(int target_version) {
1491 loaded_schema_version = target_version;
1492 }
1493 static int GetLoadedSchemaVersion() {
1494 return loaded_schema_version;
1495 }
1496
1497 private:
1498
1499
// Private default constructor: code outside the class cannot create a
// registry object directly and has to go through Instance().
1500 OpSchemaRegistry() = default;
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
// Both accessors return a mutable reference to the name -> domain -> version
// -> schema table (defined out of line). In this header the registration and
// deregistration paths use GetMapWithoutEnsuringRegistration(), while the
// lookup and enumeration paths use map().
// NOTE(review): presumably map() additionally makes sure the built-in schemas
// have been registered before returning — confirm in the .cc.
1512 static OpName_Domain_Version_Schema_Map& GetMapWithoutEnsuringRegistration();
1513 static OpName_Domain_Version_Schema_Map& map();
// Backing store for SetLoadedSchemaVersion()/GetLoadedSchemaVersion().
1514 static int loaded_schema_version;
1515
1516 public:
1517 static const std::vector<OpSchema> get_all_schemas_with_history() {
1518 std::vector<OpSchema> r;
1519 for (auto& x : map()) {
1520 for (auto& y : x.second) {
1521 for (auto& z : y.second) {
1522 r.emplace_back(z.second);
1523 }
1524 }
1525 }
1526 return r;
1527 }
1528
1529 static const std::vector<OpSchema> get_all_schemas() {
1530 std::vector<OpSchema> r;
1531 for (auto& x : map()) {
1532 for (auto& y : x.second) {
1533 auto& version2schema = y.second;
1534 if (!version2schema.empty()) {
1535 r.emplace_back(version2schema.rbegin()->second);
1536 }
1537 }
1538 }
1539 return r;
1540 }
1541 };
1542
// Free-function registration entry points (declarations only): one overload
// takes the schema by const reference, the other by rvalue reference.
// Defaults: opset_version_to_load = 0, fail_duplicate_schema = true,
// fail_with_exception = false.
// NOTE(review): presumably `fail_with_exception` selects between the throwing
// OpSchemaRegisterImpl path and the OpSchemaRegisterNoExcept path — confirm.
1543 void RegisterSchema(
1544 const OpSchema& schema,
1545 int opset_version_to_load = 0,
1546 bool fail_duplicate_schema = true,
1547 bool fail_with_exception = false);
1548 void RegisterSchema(
1549 OpSchema&& schema,
1550 int opset_version_to_load = 0,
1551 bool fail_duplicate_schema = true,
1552 bool fail_with_exception = false);
// Free-function counterpart for removing one (op_type, version, domain)
// schema (declaration only). Unlike OpSchemaRegistry::OpSchemaDeregister,
// `domain` has no default here.
1553 void DeregisterSchema(const std::string& op_type, int version, const std::string& domain);
1554
1555
1556
1557 template <class T>
1558 void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1559 T::ForEachSchema([opset_version_to_load, fail_duplicate_schema](OpSchema&& schema) {
1560 RegisterSchema(std::move(schema), opset_version_to_load, fail_duplicate_schema);
1561 });
1562 };
1563
1564
1565
1566
// Primary template, declared but not defined here. ONNX_OPERATOR_SET_SCHEMA_EX
// emits one explicit specialization per operator, using a forward-declared tag
// class as `T`.
1567 template <typename T>
1568 OpSchema GetOpSchema();
1569
// Defines operator `name` at version `ver` in the default ONNX domain
// (tag "Onnx", domain string ONNX_DOMAIN), counted in the debug opset tally.
1570 #define ONNX_OPERATOR_SET_SCHEMA(name, ver, impl) ONNX_OPERATOR_SET_SCHEMA_EX(name, Onnx, ONNX_DOMAIN, ver, true, impl)
1571
// Same as ONNX_OPERATOR_SET_SCHEMA, for the ML domain
// (tag "OnnxML", domain string AI_ONNX_ML_DOMAIN).
1572 #define ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl) \
1573 ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxML, AI_ONNX_ML_DOMAIN, ver, true, impl)
1574
// Same as ONNX_OPERATOR_SET_SCHEMA, for the training domain
// (tag "OnnxTraining", domain string AI_ONNX_TRAINING_DOMAIN).
1575 #define ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
1576 ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxTraining, AI_ONNX_TRAINING_DOMAIN, ver, true, impl)
1577
// Same as ONNX_OPERATOR_SET_SCHEMA, for the preview training domain
// (tag "OnnxPreview", domain string AI_ONNX_PREVIEW_TRAINING_DOMAIN).
1578 #define ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
1579 ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxPreview, AI_ONNX_PREVIEW_TRAINING_DOMAIN, ver, true, impl)
1580
1581
1582
1583
1584
1585
1586
// Core schema-definition macro. For operator `name` it emits:
//   1. a forward declaration of the tag class
//      ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name);
//   2. an explicit specialization of GetOpSchema<tag>() that returns `impl`
//      after setting name, domain string, since-version and source location;
//   3. a namespace-scope size_t `dbg_count_check_<name>_<domain>_ver<ver>`
//      whose initializer bumps the debug opset counter when
//      `dbg_included_in_static_opset` is true, and is 0 otherwise.
// `domain` is the identifier fragment used in the tag class name; `domain_str`
// is the domain string stored in the schema.
1587 #define ONNX_OPERATOR_SET_SCHEMA_EX(name, domain, domain_str, ver, dbg_included_in_static_opset, impl) \
1588 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name); \
1589 template <> \
1590 OpSchema GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)>() { \
1591 return impl.SetName(#name).SetDomain(domain_str).SinceVersion(ver).SetLocation(__FILE__, __LINE__); \
1592 } \
1593 size_t dbg_count_check_##name##_##domain##_ver##ver = \
1594 (dbg_included_in_static_opset) ? ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() : 0;
// Debug-only tally of operator schemas defined through the opset macros.
// With NDEBUG the increment macro collapses to the constant 0 and neither
// ONNX_DBG_GET_COUNT_IN_OPSETS nor the tracker class exists.
#ifdef NDEBUG
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() 0
#else
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().IncrementCount()
#define ONNX_DBG_GET_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().GetCount()

// Plain (non-atomic) counter behind the two macros above.
class DbgOperatorSetTracker {
 public:
  // Accessor for the shared tracker (defined out of line).
  static DbgOperatorSetTracker& Instance();

  // Adds one to the tally and returns the updated value.
  size_t IncrementCount() {
    count_ += 1;
    return count_;
  }

  // Current tally.
  size_t GetCount() const {
    return count_;
  }

 private:
  size_t count_ = 0;
};
#endif
1617
1618
// Pastes together the tag class identifier `<name>_<domain>_ver<ver>` used to
// specialize GetOpSchema.
1619 #define ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name) name##_##domain##_ver##ver
1620
1621
// Tag class identifier for an operator in the preview domain (domain fragment
// fixed to "OnnxPreview").
1622 #define ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(ver, name) \
1623 ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxPreview, ver, name)
1624
1625
// In-place substring replacement on `s` (declaration only).
// NOTE(review): presumably replaces every occurrence of `from` with `to` and
// returns the number of replacements — confirm in the definition.
1626 size_t ReplaceAll(std::string& s, const char* from, const char* to);
1627
// Marks a declaration as intentionally unused. Expands to the GNU `unused`
// attribute when __GNUC__ is defined and to nothing otherwise.
1628 #ifdef __GNUC__
1629 #define ONNX_UNUSED __attribute__((__unused__))
1630 #else
1631 #define ONNX_UNUSED
1632 #endif
1633
1634
// Declares a file-local static OpSchemaRegisterOnce object that is
// copy-initialized from `OpSchema(#name, __FILE__, __LINE__)`; whatever the
// user chains after the macro is applied to that OpSchema expression before
// the implicit conversion runs the registration.
// The two-step helper forces __COUNTER__ to be expanded before token pasting,
// so each use gets a distinct variable name.
1635 #define ONNX_OPERATOR_SCHEMA(name) ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name)
1636 #define ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)
1637 #define ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name) \
1638 static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce(op_schema_register_once##name##Counter) ONNX_UNUSED = \
1639 OpSchema(#name, __FILE__, __LINE__)
1640
1641
// NOTE(review): redundant re-declaration — the identical prototype is already
// declared earlier in this header. Harmless, but one of the two can be dropped.
1642 size_t ReplaceAll(std::string& s, const char* from, const char* to);
1643
// Returns the standard documentation paragraph for operators that have
// optional inputs/outputs. The text ends with a newline.
inline std::string GenerateOptionalArgumentsDoc() {
  std::string doc = "This operator has **optional** inputs/outputs. ";
  doc += "See [the doc](IR.md) for more details about the representation of optional arguments. ";
  doc += "An empty string may be used in the place of an actual argument's name to indicate a missing argument. ";
  doc += "Trailing optional arguments (those not followed by an argument that is present) ";
  doc += "may also be simply omitted.\n";
  return doc;
}
1652
// Returns the standard documentation sentence for operators that support
// multidirectional (Numpy-style) broadcasting.
inline std::string GenerateBroadcastingDocMul() {
  std::string doc = "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**;";
  doc += " for more details please check [the doc](Broadcasting.md).";
  return doc;
}
1657
// Returns the standard documentation sentence for operators that support
// unidirectional broadcasting, naming the tensor `from` that must be
// broadcastable to the tensor `to`. Both arguments are appended as C strings.
inline std::string GenerateBroadcastingDocUni(const char* from, const char* to) {
  std::string doc = "This operator supports **unidirectional broadcasting** (";
  doc += from;
  doc += " should be unidirectional broadcastable to ";
  doc += to;
  doc += "); for more details please check [the doc](Broadcasting.md).";
  return doc;
}
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
// Wraps an operator doc string expression. Expands to the expression itself
// unless __ONNX_NO_DOC_STRINGS is defined, in which case it expands to an
// empty string literal.
1677 #ifndef __ONNX_NO_DOC_STRINGS
1678 #define GET_OP_DOC_STR(doc_str) (doc_str)
1679 #else
1680 #define GET_OP_DOC_STR(doc_str) ("")
1681 #endif
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
// Wraps statements that build a doc string. The statements are emitted inside
// a `do { ... } while (0)` block unless __ONNX_NO_DOC_STRINGS is defined, in
// which case the macro expands to nothing and the statements are dropped.
// The expansion carries no trailing semicolon; the call site supplies it.
1702 #ifndef __ONNX_NO_DOC_STRINGS
1703 #define POPULATE_OP_DOC_STR(DocPopulatorCode) \
1704 do { \
1705 DocPopulatorCode \
1706 } while (0)
1707 #else
1708 #define POPULATE_OP_DOC_STR(DocPopulatorCode)
1709 #endif
1710
1711 }