File indexing completed on 2025-02-22 10:42:46
0001
0002
0003
0004
0005 #pragma once
0006
0007 #include <climits>
0008 #include <cstring>
0009 #include <functional>
0010 #include <initializer_list>
0011 #include <iostream>
0012 #include <limits>
0013 #include <map>
0014 #include <memory>
0015 #include <ostream>
0016 #include <set>
0017 #include <string>
0018 #include <tuple>
0019 #include <unordered_map>
0020 #include <unordered_set>
0021 #include <utility>
0022 #include <vector>
0023
0024 #include "onnx/common/common.h"
0025 #include "onnx/common/constants.h"
0026 #include "onnx/defs/shape_inference.h"
0027
0028 namespace ONNX_NAMESPACE {
0029
0030 struct FunctionBodyBuildContext {
0031 virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
0032 virtual bool hasInput(int inputIndex) const = 0;
0033 virtual bool hasOutput(int inputIndex) const = 0;
0034
0035
0036 virtual const TypeProto* getInputType(int inputIndex) const = 0;
0037 virtual ~FunctionBodyBuildContext() {}
0038 };
0039
0040 struct FunctionBodyBuildContextImpl : public FunctionBodyBuildContext {
0041
0042
0043
0044
0045
0046 FunctionBodyBuildContextImpl(const NodeProto& node_proto, const std::vector<TypeProto>& input_types = {})
0047 : node_proto_(node_proto), input_types_(input_types) {
0048 for (auto& attr : node_proto.attribute()) {
0049 attributesByName_[attr.name()] = &attr;
0050 }
0051 }
0052
0053 const AttributeProto* getAttribute(const std::string& name) const override {
0054 auto iter = attributesByName_.find(name);
0055 if (iter == attributesByName_.end()) {
0056 return nullptr;
0057 } else {
0058 return iter->second;
0059 }
0060 }
0061
0062 bool hasInput(int inputIndex) const override {
0063 if (inputIndex >= node_proto_.input_size())
0064 return false;
0065 return node_proto_.input(inputIndex) != "";
0066 }
0067
0068 bool hasOutput(int inputIndex) const override {
0069 if (inputIndex >= node_proto_.output_size())
0070 return false;
0071 return node_proto_.output(inputIndex) != "";
0072 }
0073
0074 const TypeProto* getInputType(int inputIndex) const override {
0075 if (inputIndex < 0)
0076 return nullptr;
0077 size_t j = static_cast<size_t>(inputIndex);
0078 if (j >= input_types_.size())
0079 return nullptr;
0080
0081 if (input_types_[j].value_case() == TypeProto::ValueCase::VALUE_NOT_SET)
0082 return nullptr;
0083 return &input_types_[j];
0084 }
0085
0086 std::unordered_map<std::string, const AttributeProto*> attributesByName_;
0087
0088 NodeProto node_proto_;
0089 std::vector<TypeProto> input_types_;
0090 };
0091
// Callable that inspects a build context and returns a bool.
// NOTE(review): not used elsewhere in the visible part of this header.
0092 using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>;
0093
// Forward declaration so the builder signature below can name OpSchema.
0094 class OpSchema;
// Callable that fills a FunctionProto given a build context and a schema and
// returns a bool (presumably success/failure — confirm at the call sites).
0095 using ContextDependentFunctionBodyBuilder =
0096 std::function<bool(const FunctionBodyBuildContext&, const OpSchema&, FunctionProto&)>;
0097
0098 class SchemaError final : public std::runtime_error {
0099 public:
0100 using std::runtime_error::runtime_error;
0101
0102 SchemaError(const std::string& message) : std::runtime_error(message) {}
0103
0104 const char* what() const noexcept override {
0105 if (!expanded_message_.empty()) {
0106 return expanded_message_.c_str();
0107 }
0108 return std::runtime_error::what();
0109 }
0110
0111 void AppendContext(const std::string& context) {
0112 expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
0113 }
0114
0115 private:
0116 std::string expanded_message_;
0117 };
0118
// Throws (via ONNX_THROW_EX) a SchemaError whose message is built by passing
// all macro arguments to MakeString.
// NOTE(review): the expansion already ends with ';', so `fail_schema(x);`
// produces an extra empty statement. Avoid using it as the sole statement of
// an unbraced if/else.
0119 #define fail_schema(...) ONNX_THROW_EX(ONNX_NAMESPACE::SchemaError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
0120
// Operator-set (opset) version number.
0121 using OperatorSetVersion = int;
0122
// Unordered set of DataType values (DataType is declared elsewhere).
0123 using DataTypeSet = std::unordered_set<DataType>;
0124
0125
0126
// Maps a string key to a (DataTypeSet, string) pair.
// Presumably: type-parameter name -> (allowed types, description) — confirm
// against the code that fills OpSchema::type_constraints_.
0127 using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>;
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148 class OpSchema final {
0149 public:
// Sentinel meaning "no since-version has been set". It is the initial value of
// since_version_ and the default opset_version argument of several members.
0150 static constexpr int kUninitializedSinceVersion = -1;
0151
// Arity option of a formal parameter. `Single` is the default used by the
// FormalParameter constructors and the Input()/Output() overloads.
// NOTE(review): precise semantics of each value are not visible in this
// header; the names suggest exactly-one / may-be-omitted / one-or-more.
0152 enum FormalParameterOption : uint8_t {
0153
0154
0155 Single = 0,
0156
0157
0158 Optional = 1,
0159
0160
0161
0162 Variadic = 2,
0163 };
// Differentiability tag attached to a formal parameter. `Unknown` is the
// default used by the FormalParameter constructors and Input()/Output().
// NOTE(review): how each value is consumed is not visible in this header.
0164 enum DifferentiationCategory : uint8_t {
0165
0166
0167
0168
0169 Unknown = 0,
0170
0171
0172 Differentiable = 1,
0173
0174
0175 NonDifferentiable = 2
0176 };
0177
0178
0179
0180 class FormalParameter final {
0181 public:
0182
0183 FormalParameter() = default;
0184
0185 explicit FormalParameter(
0186 std::string name,
0187 DataTypeSet allowed_type_set,
0188 std::string type_str,
0189 const std::string& description,
0190 FormalParameterOption param_option = Single,
0191 bool is_homogeneous = true,
0192 int min_arity = 1,
0193 DifferentiationCategory differentiation_category = Unknown)
0194 : name_(std::move(name)),
0195 type_set_(std::move(allowed_type_set)),
0196 type_str_(std::move(type_str)),
0197 #ifndef __ONNX_NO_DOC_STRINGS
0198 description_(description),
0199 #endif
0200 param_option_(param_option),
0201 is_homogeneous_(is_homogeneous),
0202 min_arity_(min_arity),
0203 differentiation_category_(differentiation_category) {
0204 #ifdef __ONNX_NO_DOC_STRINGS
0205 ONNX_UNUSED_PARAMETER(description);
0206 #endif
0207 }
0208
0209 explicit FormalParameter(
0210 std::string name,
0211 const std::string& description,
0212 std::string type_str,
0213 FormalParameterOption param_option = Single,
0214 bool is_homogeneous = true,
0215 int min_arity = 1,
0216 DifferentiationCategory differentiation_category = Unknown)
0217 : name_(std::move(name)),
0218 type_str_(std::move(type_str)),
0219 #ifndef __ONNX_NO_DOC_STRINGS
0220 description_(description),
0221 #endif
0222 param_option_(param_option),
0223 is_homogeneous_(is_homogeneous),
0224 min_arity_(min_arity),
0225 differentiation_category_(differentiation_category) {
0226 #ifdef __ONNX_NO_DOC_STRINGS
0227 ONNX_UNUSED_PARAMETER(description);
0228 #endif
0229 }
0230
0231
0232 const std::string& GetName() const;
0233
0234
0235 const DataTypeSet& GetTypes() const;
0236
0237
0238 const std::string& GetTypeStr() const;
0239
0240
0241 const std::string& GetDescription() const;
0242
0243
0244 FormalParameterOption GetOption() const;
0245
0246
0247 bool GetIsHomogeneous() const;
0248
0249
0250 int GetMinArity() const;
0251
0252
0253 DifferentiationCategory GetDifferentiationCategory() const;
0254
0255 private:
0256 friend class OpSchema;
0257
0258 DataTypeSet& MutableTypes();
0259
0260
0261 std::string name_;
0262
0263
0264
0265 DataTypeSet type_set_;
0266
0267
0268
0269
0270 std::string type_str_;
0271
0272
0273 std::string description_;
0274
0275
0276 FormalParameterOption param_option_;
0277
0278
0279
0280 bool is_homogeneous_;
0281
0282
0283 int min_arity_;
0284
0285
0286
0287
0288 DifferentiationCategory differentiation_category_;
0289 };
0290
// Support level of a schema. The OpSchema constructor in this header sets
// COMMON.
0291 enum class SupportType : uint8_t {
0292 COMMON,
0293 EXPERIMENTAL,
0294
0295 };
0296
0297 OpSchema() : OpSchema("unknown", "unknown", 0) {}
0298 OpSchema(std::string name, std::string file, int line)
0299 : name_(std::move(name)), file_(std::move(file)), line_(line), support_(SupportType::COMMON) {}
0300
0301
0302
0303
0304 const std::string& file() const {
0305 return file_;
0306 }
0307
0308
0309
0310
0311 int line() const {
0312 return line_;
0313 }
0314
0315
0316
0317
0318 SupportType support_level() const {
0319 return support_;
0320 }
0321
0322
0323
0324
0325 const char* doc() const {
0326 return doc_.empty() ? nullptr : doc_.c_str();
0327 }
0328
0329
// Defined outside this header. Presumably checks the input/output types seen
// through the inference context against this schema — confirm in the .cc file.
0330 void CheckInputOutputType(struct InferenceContext&) const;
0331
0332
0333
0334
0335
// Defined outside this header. Presumably validates `node` against this schema
// and reports violations by throwing (see fail_schema) — confirm in the .cc file.
0336 void Verify(const NodeProto& node) const;
0337
0338
0339
0340
0341
0342
0343
0344
0345
0346
0347
0348
0349
0350
0351
0352
0353
// Builder-style setter defined outside this header; the getter overload
// SinceVersion() returns since_version_. OpSchemaRegisterOnce calls this with
// 1 when the version is still kUninitializedSinceVersion.
0354 OpSchema& SinceVersion(OperatorSetVersion n);
0355
0356
0357
0358
0359
0360
// Builder-style setter defined outside this header; presumably sets
// deprecated_ (read back by Deprecated()/deprecated()) — confirm in the .cc file.
0361 OpSchema& Deprecate();
0362
0363 bool Deprecated() const {
0364 return deprecated_;
0365 }
0366
0367
0368
0369
// Builder-style setter defined outside this header; takes the set of permitted
// input counts. Presumably installs a predicate in num_inputs_allowed_ —
// confirm in the .cc file.
0370 OpSchema& NumInputs(std::set<int> allowed_input_nums);
0371
0372
0373
0374
// Output-side counterpart of NumInputs (presumably feeds num_outputs_allowed_).
0375 OpSchema& NumOutputs(std::set<int> allowed_output_nums);
0376
0377
0378
0379
0380
0381 OpSchema& TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
0382 InferenceFunction GetTypeAndShapeInferenceFunction() const {
0383 return tensor_inference_function_ ? tensor_inference_function_ : dummyInferenceFunction;
0384 }
0385
0386 OpSchema& PartialDataPropagationFunction(DataPropagationFunction dataProgationFunction);
0387 DataPropagationFunction GetDataPropagationFunction() const {
0388 return data_propagation_function_ ? data_propagation_function_ : dummyDataPropagationFunction;
0389 }
0390
0391
// Builder-style setter defined outside this header; presumably stores
// `supportType` in support_ (read back by support_level()).
0392 OpSchema& SetSupportLevel(SupportType supportType);
0393
0394
0395
0396 OpSchema& SetDoc(const char* doc) {
0397 #ifndef __ONNX_NO_DOC_STRINGS
0398 SetDoc(std::string(doc));
0399 #else
0400 ONNX_UNUSED_PARAMETER(doc);
0401 #endif
0402
0403 return *this;
0404 }
0405
0406 OpSchema& SetDoc(const std::string& doc) {
0407 #ifndef __ONNX_NO_DOC_STRINGS
0408 doc_ = doc;
0409 #else
0410 ONNX_UNUSED_PARAMETER(doc);
0411 #endif
0412 return *this;
0413 }
0414
0415
// Builder-style setters defined outside this header. Presumably they store
// into name_, (file_, line_) and domain_ respectively and return *this —
// confirm in the .cc file.
0416 OpSchema& SetName(const char* name);
0417 OpSchema& SetName(std::string name);
0418
0419
0420 OpSchema& SetLocation(const char* file, int line);
0421 OpSchema& SetLocation(std::string file, int line);
0422
0423
0424
0425 OpSchema& SetDomain(const char* domain);
0426 OpSchema& SetDomain(std::string domain);
0427
0428 struct Attribute final {
0429 Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)
0430 : name(std::move(name_)),
0431 description(std::move(description_)),
0432 type(type_),
0433 required(required_),
0434 default_value() {}
0435
0436 Attribute(std::string name_, std::string description_, AttributeProto default_value_)
0437 : name(std::move(name_)),
0438 description(std::move(description_)),
0439 type(default_value_.type()),
0440 required(false),
0441 default_value(std::move(default_value_)) {}
0442
0443 const std::string name;
0444 const std::string description;
0445 AttributeProto::AttributeType type;
0446 bool required;
0447 AttributeProto default_value;
0448 };
0449
// Attr() overload set: builder-style declarations, all defined outside this
// header. Presumably each registers an attribute in attributes_ and returns
// *this — confirm in the .cc file.
0450 OpSchema& Attr(Attribute attr);
0451
0452
// Declares, for one default-value type `TypeName`, three Attr() overloads:
// std::string name/description with a scalar default, const char*
// name/description with a scalar default, and std::string name/description
// with a std::vector<TypeName> default.
0453 #define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName) \
0454 OpSchema& Attr( \
0455 std::string name, std::string description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
0456 \
0457 OpSchema& Attr( \
0458 const char* name, const char* description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
0459 OpSchema& Attr( \
0460 std::string name, \
0461 std::string description, \
0462 AttributeProto::AttributeType type, \
0463 const std::vector<TypeName>& defaultValue);
0464
// Instantiations of the overload triple for each supported default-value type.
0465 ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t)
0466 ATTR_SETTER_WITH_DEFAULT_VALUE(float)
0467 ATTR_SETTER_WITH_DEFAULT_VALUE(std::string)
0468 ATTR_SETTER_WITH_DEFAULT_VALUE(TensorProto)
0469 ATTR_SETTER_WITH_DEFAULT_VALUE(GraphProto)
0470 ATTR_SETTER_WITH_DEFAULT_VALUE(TypeProto)
0471
// Overload taking an extra free-text `conditionExplanation`.
// NOTE(review): how that text is used is not visible in this header.
0472 OpSchema& Attr(
0473 std::string name,
0474 std::string description,
0475 std::string conditionExplanation,
0476 AttributeProto::AttributeType attr_type);
0477
0478
// Overloads without a default value; `required` defaults to true.
0479 OpSchema& Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required = true);
0480
0481
0482 OpSchema& Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required = true);
0483
// Presumably sets allows_unchecked_attributes_ — confirm in the .cc file.
0484 OpSchema& AllowUncheckedAttributes();
0485
0486
// One type constraint as given to the schema: the type-parameter string, the
// list of allowed type strings, and a description.
struct TypeConstraintParam final {
  // Takes all three fields by value and moves them into place.
  TypeConstraintParam(
      std::string type_param_str_,
      std::vector<std::string> allowed_type_strs_,
      std::string description_)
      : type_param_str(std::move(type_param_str_)),
        allowed_type_strs(std::move(allowed_type_strs_)),
        description(std::move(description_)) {}

  std::string type_param_str;                  // type-parameter string
  std::vector<std::string> allowed_type_strs;  // allowed type strings
  std::string description;                     // free-text description
};
0504
0505
0506
0507
0508
0509
0510
0511
0512
0513
0514
0515
0516
0517
0518
0519
0520
0521
0522
0523
0524
0525
0526
0527
0528
0529
0530
// Input() overload set: builder-style declarations defined outside this
// header. `n` is an int position; presumably the formal parameter is stored at
// index n of inputs_ — confirm in the .cc file.
0531 OpSchema& Input(int n, FormalParameter formal_parameter);
0532
// Overload taking the FormalParameter fields individually; defaults mirror the
// FormalParameter constructors (Single, homogeneous, min arity 1, Unknown).
0533 OpSchema& Input(
0534 int n,
0535 std::string name,
0536 const std::string& description,
0537 std::string type_str,
0538 FormalParameterOption param_option = Single,
0539 bool is_homogeneous = true,
0540 int min_arity = 1,
0541 DifferentiationCategory differentiation_category = Unknown);
0542
0543
// Same as above with C-string arguments.
0544 OpSchema& Input(
0545 int n,
0546 const char* name,
0547 const char* description,
0548 const char* type_str,
0549 FormalParameterOption param_option = Single,
0550 bool is_homogeneous = true,
0551 int min_arity = 1,
0552 DifferentiationCategory differentiation_category = Unknown);
0553
// Output() overload set: builder-style declarations defined outside this
// header, parallel to the Input() overloads. Presumably the formal parameter is
// stored at index n of outputs_ — confirm in the .cc file.
0554 OpSchema& Output(int n, FormalParameter formal_parameter);
0555
// Overload taking the FormalParameter fields individually; defaults mirror the
// FormalParameter constructors (Single, homogeneous, min arity 1, Unknown).
0556 OpSchema& Output(
0557 int n,
0558 std::string name,
0559 const std::string& description,
0560 std::string type_str,
0561 FormalParameterOption param_option = Single,
0562 bool is_homogeneous = true,
0563 int min_arity = 1,
0564 DifferentiationCategory differentiation_category = Unknown);
0565
0566
// Same as above with C-string arguments.
0567 OpSchema& Output(
0568 int n,
0569 const char* name,
0570 const char* description,
0571 const char* type_str,
0572 FormalParameterOption param_option = Single,
0573 bool is_homogeneous = true,
0574 int min_arity = 1,
0575 DifferentiationCategory differentiation_category = Unknown);
0576
// Builder-style declarations defined outside this header. The arguments match
// the fields of TypeConstraintParam (type string, allowed type strings,
// description); presumably they are recorded in type_constraint_params_ and
// type_constraints_ — confirm in the .cc file.
0577 OpSchema& TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description);
0578
0579
// Same as above with C-string arguments.
0580 OpSchema&
0581 TypeConstraint(const char* type_str, std::initializer_list<const char*> constraints, const char* description);
0582
0583
0584
0585
// Returns a function-local static list of 12 tensor type strings: 32/64-bit
// signed and unsigned integers, float16/float/double, bfloat16 and the four
// float8 variants. The same vector is returned on every call.
// NOTE(review): the `_ir9` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& numeric_types_for_math_reduction_ir9() {
  static const std::vector<std::string> numeric_types_for_math_reduction_ir9{
      "tensor(uint32)", "tensor(uint64)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return numeric_types_for_math_reduction_ir9;
}
0602
// Returns a function-local static list of 8 tensor type strings: 32/64-bit
// signed and unsigned integers, float16/float/double and bfloat16.
// NOTE(review): the `_ir4` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& numeric_types_for_math_reduction_ir4() {
  static const std::vector<std::string> numeric_types_for_math_reduction_ir4{
      "tensor(uint32)", "tensor(uint64)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"};
  return numeric_types_for_math_reduction_ir4;
}
0615
// Returns a function-local static list of 7 tensor type strings: 32/64-bit
// signed and unsigned integers plus float16/float/double.
static const std::vector<std::string>& numeric_types_for_math_reduction() {
  static const std::vector<std::string> numeric_types_for_math_reduction{
      "tensor(uint32)", "tensor(uint64)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)"};
  return numeric_types_for_math_reduction;
}
0627
// Returns a function-local static list of 16 tensor type strings: all 8/16/
// 32/64-bit signed and unsigned integers, float16/float/double, bfloat16 and
// the four float8 variants.
// NOTE(review): the `_ir9` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_numeric_types_ir9() {
  static const std::vector<std::string> all_numeric_types_ir9{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return all_numeric_types_ir9;
}
0648
// Returns a function-local static list of 12 tensor type strings: all 8/16/
// 32/64-bit signed and unsigned integers, float16/float/double and bfloat16.
// NOTE(review): the `_ir4` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_numeric_types_ir4() {
  static const std::vector<std::string> all_numeric_types_ir4{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"};
  return all_numeric_types_ir4;
}
0665
// Returns a function-local static list of 11 tensor type strings: all 8/16/
// 32/64-bit signed and unsigned integers plus float16/float/double.
static const std::vector<std::string>& all_numeric_types() {
  static const std::vector<std::string> all_numeric_types{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)"};
  return all_numeric_types;
}
0681
// Returns a function-local static list of 11 sequence type strings, one
// `seq(tensor(T))` per numeric T (all integer widths, float16/float/double).
static const std::vector<std::string>& all_numeric_sequence_types() {
  static const std::vector<std::string> all_numeric_sequence_types{
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))"};
  return all_numeric_sequence_types;
}
0697
// Returns a function-local static list of 15 tensor type strings: the 11
// numeric types followed by string, bool, complex64 and complex128.
static const std::vector<std::string>& all_tensor_types() {
  static const std::vector<std::string> all_tensor_types{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(string)", "tensor(bool)", "tensor(complex64)", "tensor(complex128)"};
  return all_tensor_types;
}
0717
// Returns a function-local static list of 16 tensor type strings: the integer
// types, bfloat16, float16/float/double, then string, bool, complex64 and
// complex128.
// NOTE(review): the `_ir4` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_tensor_types_ir4() {
  static const std::vector<std::string> all_tensor_types_ir4{
      "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
      "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(string)", "tensor(bool)", "tensor(complex64)", "tensor(complex128)"};
  return all_tensor_types_ir4;
}
0738
// Returns a function-local static list of the 4 floating-point tensor type
// strings: bfloat16, float16, float and double.
// NOTE(review): the `_ir4` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_float_types_ir4() {
  static const std::vector<std::string> all_float_types_ir4{
      "tensor(bfloat16)",
      "tensor(float16)",
      "tensor(float)",
      "tensor(double)"};
  return all_float_types_ir4;
}
0744
// Returns a function-local static list of 8 floating-point tensor type
// strings: bfloat16, float16, float, double and the four float8 variants.
// NOTE(review): the `_ir9` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_float_types_ir9() {
  static const std::vector<std::string> all_float_types_ir9{
      "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return all_float_types_ir9;
}
0757
// Returns a function-local static list of 20 tensor type strings: the 16
// entries of the ir4 tensor list followed by the four float8 variants.
// NOTE(review): the `_ir9` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_tensor_types_ir9() {
  static const std::vector<std::string> all_tensor_types_ir9{
      "tensor(uint8)", "tensor(uint16)",
      "tensor(uint32)", "tensor(uint64)",
      "tensor(int8)", "tensor(int16)",
      "tensor(int32)", "tensor(int64)",
      "tensor(bfloat16)", "tensor(float16)",
      "tensor(float)", "tensor(double)",
      "tensor(string)", "tensor(bool)",
      "tensor(complex64)", "tensor(complex128)",
      "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)",
      "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
  return all_tensor_types_ir9;
}
0767
// Returns a function-local static list of 15 sequence type strings, one
// `seq(tensor(T))` per T: the 11 numeric types, then string, bool, complex64
// and complex128.
static const std::vector<std::string>& all_tensor_sequence_types() {
  static const std::vector<std::string> all_tensor_sequence_types{
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))",
      "seq(tensor(string))", "seq(tensor(bool))", "seq(tensor(complex64))", "seq(tensor(complex128))"};
  return all_tensor_sequence_types;
}
0787
// Returns a function-local static list of 16 sequence type strings, one
// `seq(tensor(T))` per T: the integer types, bfloat16, float16/float/double,
// then string, bool, complex64 and complex128.
// NOTE(review): the `_ir4` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_tensor_sequence_types_ir4() {
  static const std::vector<std::string> all_tensor_sequence_types_ir4{
      "seq(tensor(uint8))", "seq(tensor(uint16))", "seq(tensor(uint32))", "seq(tensor(uint64))",
      "seq(tensor(int8))", "seq(tensor(int16))", "seq(tensor(int32))", "seq(tensor(int64))",
      "seq(tensor(bfloat16))", "seq(tensor(float16))", "seq(tensor(float))", "seq(tensor(double))",
      "seq(tensor(string))", "seq(tensor(bool))", "seq(tensor(complex64))", "seq(tensor(complex128))"};
  return all_tensor_sequence_types_ir4;
}
0808
// Returns a function-local static list of 20 sequence type strings, one
// `seq(tensor(T))` per T: the 16 entries of the ir4 sequence list followed by
// the four float8 variants. The local is named after this function, as in the
// other type-list helpers (it was previously a copy-pasted `..._ir4`).
// NOTE(review): the `_ir9` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_tensor_sequence_types_ir9() {
  static const std::vector<std::string> all_tensor_sequence_types_ir9 = {
      "seq(tensor(uint8))",        "seq(tensor(uint16))",        "seq(tensor(uint32))",
      "seq(tensor(uint64))",       "seq(tensor(int8))",          "seq(tensor(int16))",
      "seq(tensor(int32))",        "seq(tensor(int64))",         "seq(tensor(bfloat16))",
      "seq(tensor(float16))",      "seq(tensor(float))",         "seq(tensor(double))",
      "seq(tensor(string))",       "seq(tensor(bool))",          "seq(tensor(complex64))",
      "seq(tensor(complex128))",   "seq(tensor(float8e4m3fn))",  "seq(tensor(float8e4m3fnuz))",
      "seq(tensor(float8e5m2))",   "seq(tensor(float8e5m2fnuz))"};
  return all_tensor_sequence_types_ir9;
}
0820
// Returns a function-local static list of 30 optional type strings: first
// `optional(seq(tensor(T)))` for 15 element types, then `optional(tensor(T))`
// for the same 15 types (11 numeric, string, bool, complex64, complex128).
static const std::vector<std::string>& all_optional_types() {
  static const std::vector<std::string> all_optional_types{
      "optional(seq(tensor(uint8)))", "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))", "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))", "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))", "optional(seq(tensor(int64)))",
      "optional(seq(tensor(float16)))", "optional(seq(tensor(float)))",
      "optional(seq(tensor(double)))", "optional(seq(tensor(string)))",
      "optional(seq(tensor(bool)))", "optional(seq(tensor(complex64)))",
      "optional(seq(tensor(complex128)))",
      "optional(tensor(uint8))", "optional(tensor(uint16))",
      "optional(tensor(uint32))", "optional(tensor(uint64))",
      "optional(tensor(int8))", "optional(tensor(int16))",
      "optional(tensor(int32))", "optional(tensor(int64))",
      "optional(tensor(float16))", "optional(tensor(float))",
      "optional(tensor(double))", "optional(tensor(string))",
      "optional(tensor(bool))", "optional(tensor(complex64))",
      "optional(tensor(complex128))"};
  return all_optional_types;
}
0835
// Returns a function-local static list of 32 optional type strings: first
// `optional(seq(tensor(T)))` for 16 element types, then `optional(tensor(T))`
// for the same 16 types (integers, bfloat16, float16/float/double, string,
// bool, complex64, complex128). The local is named after this function, as in
// the other type-list helpers (it was previously `all_optional_types`).
// NOTE(review): the `_ir4` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_optional_types_ir4() {
  static const std::vector<std::string> all_optional_types_ir4 = {
      "optional(seq(tensor(uint8)))",      "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))",     "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))",       "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))",      "optional(seq(tensor(int64)))",
      "optional(seq(tensor(bfloat16)))",   "optional(seq(tensor(float16)))",
      "optional(seq(tensor(float)))",      "optional(seq(tensor(double)))",
      "optional(seq(tensor(string)))",     "optional(seq(tensor(bool)))",
      "optional(seq(tensor(complex64)))",  "optional(seq(tensor(complex128)))",
      "optional(tensor(uint8))",           "optional(tensor(uint16))",
      "optional(tensor(uint32))",          "optional(tensor(uint64))",
      "optional(tensor(int8))",            "optional(tensor(int16))",
      "optional(tensor(int32))",           "optional(tensor(int64))",
      "optional(tensor(bfloat16))",        "optional(tensor(float16))",
      "optional(tensor(float))",           "optional(tensor(double))",
      "optional(tensor(string))",          "optional(tensor(bool))",
      "optional(tensor(complex64))",       "optional(tensor(complex128))"};
  return all_optional_types_ir4;
}
0851
// Returns a function-local static list of 36 optional type strings: first
// `optional(seq(tensor(T)))` for 16 element types, then `optional(tensor(T))`
// for those 16 types plus the four float8 variants. The local is named after
// this function, as in the other type-list helpers (it was previously
// `all_optional_types`).
// NOTE(review): there are no `optional(seq(tensor(float8...)))` entries, even
// though all_tensor_sequence_types_ir9() lists the float8 sequence types. The
// list is kept as is because changing it would change which types schemas
// accept; confirm whether the omission is intentional.
// NOTE(review): the `_ir9` suffix presumably refers to an ONNX IR version.
static const std::vector<std::string>& all_optional_types_ir9() {
  static const std::vector<std::string> all_optional_types_ir9 = {
      "optional(seq(tensor(uint8)))",      "optional(seq(tensor(uint16)))",
      "optional(seq(tensor(uint32)))",     "optional(seq(tensor(uint64)))",
      "optional(seq(tensor(int8)))",       "optional(seq(tensor(int16)))",
      "optional(seq(tensor(int32)))",      "optional(seq(tensor(int64)))",
      "optional(seq(tensor(bfloat16)))",   "optional(seq(tensor(float16)))",
      "optional(seq(tensor(float)))",      "optional(seq(tensor(double)))",
      "optional(seq(tensor(string)))",     "optional(seq(tensor(bool)))",
      "optional(seq(tensor(complex64)))",  "optional(seq(tensor(complex128)))",
      "optional(tensor(uint8))",           "optional(tensor(uint16))",
      "optional(tensor(uint32))",          "optional(tensor(uint64))",
      "optional(tensor(int8))",            "optional(tensor(int16))",
      "optional(tensor(int32))",           "optional(tensor(int64))",
      "optional(tensor(bfloat16))",        "optional(tensor(float16))",
      "optional(tensor(float))",           "optional(tensor(double))",
      "optional(tensor(string))",          "optional(tensor(bool))",
      "optional(tensor(complex64))",       "optional(tensor(complex128))",
      "optional(tensor(float8e4m3fn))",    "optional(tensor(float8e4m3fnuz))",
      "optional(tensor(float8e5m2))",      "optional(tensor(float8e5m2fnuz))"};
  return all_optional_types_ir9;
}
0868
0869
0870
// Defined outside this header. Takes a callable that receives an OpSchema&;
// presumably invokes it on *this and returns *this — confirm in the .cc file.
0871 OpSchema& FillUsing(const std::function<void(OpSchema&)>& populator);
0872
// Stream-output operator; declared a friend so it can read private members.
0873 friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema);
0874
0875 const std::string& domain() const {
0876 return domain_;
0877 }
0878
0879 const std::map<std::string, Attribute>& attributes() const {
0880 return attributes_;
0881 }
0882
0883
0884 const std::vector<FormalParameter>& inputs() const {
0885 return inputs_;
0886 }
0887
0888
0889 const std::vector<FormalParameter>& outputs() const {
0890 return outputs_;
0891 }
0892
0893 const std::vector<TypeConstraintParam>& typeConstraintParams() const {
0894 return type_constraint_params_;
0895 }
0896
0897 const TypeConstraintMap& typeConstraintMap() const {
0898 return type_constraints_;
0899 }
0900
0901 const std::string& Name() const {
0902 return name_;
0903 }
0904
0905 OperatorSetVersion SinceVersion() const {
0906 return since_version_;
0907 }
0908
0909 int since_version() const {
0910 return since_version_;
0911 }
0912
0913 bool deprecated() const {
0914 return deprecated_;
0915 }
0916
0917 int min_input() const {
0918 return min_input_;
0919 }
0920 int max_input() const {
0921 return max_input_;
0922 }
0923 int min_output() const {
0924 return min_output_;
0925 }
0926 int max_output() const {
0927 return max_output_;
0928 }
0929
0930 bool has_type_and_shape_inference_function() const {
0931 return tensor_inference_function_ ? true : false;
0932 }
0933
0934 bool has_data_propagation_function() const {
0935 return data_propagation_function_ ? true : false;
0936 }
0937
0938 std::vector<int> function_opset_versions() const {
0939 std::vector<int> opset_versions;
0940 std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it = opset_version_to_function_body_.cbegin();
0941 for (; it != opset_version_to_function_body_.cend(); ++it) {
0942 opset_versions.push_back(it->first);
0943 }
0944 return opset_versions;
0945 }
0946
0947 bool HasFunction() const {
0948 return !opset_version_to_function_body_.empty();
0949 }
0950
// FunctionBody() overload set: builder-style declarations defined outside this
// header. `opset_version` defaults to kUninitializedSinceVersion. Presumably
// each stores a function body in opset_version_to_function_body_ under that
// version — confirm in the .cc file.
0951 OpSchema& FunctionBody(const std::vector<NodeProto>& func_nodes, int opset_version = kUninitializedSinceVersion);
0952
// Overload that additionally takes a list of OperatorSetIdProto entries.
0953 OpSchema& FunctionBody(
0954 const std::vector<NodeProto>& func_nodes,
0955 const std::vector<OperatorSetIdProto>& opsets,
0956 int opset_version = kUninitializedSinceVersion);
0957
// Overload taking the body as text (format not visible in this header).
0958 OpSchema& FunctionBody(const char* func_body, int opset_version = kUninitializedSinceVersion);
0959
0960
0961
0962
0963
0964
0965
0966
0967
0968
0969
0970
0971
0972
0973
0974
0975
0976
// Defined outside this header. Returns a pointer to a stored FunctionProto for
// the requested opset version (default: kUninitializedSinceVersion).
// NOTE(review): the selection rule, the effect of `validate`, and whether
// nullptr can be returned are not visible here — confirm in the .cc file.
0977 const FunctionProto* GetFunction(
0978 int requested_opset_version = OpSchema::kUninitializedSinceVersion,
0979 bool validate = false) const;
0980
0981 std::vector<int> context_dependent_function_opset_versions() const {
0982 std::vector<int> opset_versions;
0983 std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it = opset_version_to_function_builder_.cbegin();
0984 for (; it != opset_version_to_function_builder_.cend(); ++it) {
0985 opset_versions.push_back(it->first);
0986 }
0987 return opset_versions;
0988 }
0989
0990 bool HasContextDependentFunction() const {
0991 return !opset_version_to_function_builder_.empty();
0992 }
0993
0994 bool HasContextDependentFunctionWithOpsetVersion(int opset_version) const {
0995 return opset_version_to_function_builder_.find(opset_version) != opset_version_to_function_builder_.end();
0996 }
0997
// Builder-style setter defined outside this header; presumably stores the
// builder in opset_version_to_function_builder_ under `opset_version` —
// confirm in the .cc file.
0998 OpSchema& SetContextDependentFunctionBodyBuilder(
0999 ContextDependentFunctionBodyBuilder,
1000 int opset_version = kUninitializedSinceVersion);
1001
// Defined outside this header. Fills `function_proto` using `ctx`; presumably
// by running the builder registered for the requested opset version and
// returning its result — confirm in the .cc file.
1002 bool BuildContextDependentFunction(
1003 const FunctionBodyBuildContext& ctx,
1004 FunctionProto& function_proto,
1005 int requested_opset_version = OpSchema::kUninitializedSinceVersion) const;
1006
1007
1008
1009
1010
// Defined outside this header. OpSchemaRegisterOnce calls it on a schema
// before registering it; what it computes is not visible here.
1011 void Finalize();
1012
1013
// Defined outside this header; fills `function_body` (details not visible).
1014 void BuildFunction(FunctionProto& function_body) const;
1015
1016 private:
// Private helpers, all defined outside this header.
// Presumably resolves the type strings of the given formal parameters into
// their type sets (FormalParameter::MutableTypes is friend-accessible).
1017 void ParseAndSetTypes(
1018 std::vector<OpSchema::FormalParameter>* formalParameters);
// NOTE(review): "Functon" is a typo for "Function" in the declared name; it
// must match the out-of-line definition, so it is left untouched here.
1019 bool ValidateReferencedOpsInFuncton(
1020 const FunctionProto* function,
1021 int requested_opset_version,
1022 int function_since_version,
1023 std::set<std::string>* updated_ops = nullptr) const;
1024 void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const;
1025
// Identity and source location (set by the constructor).
1026 std::string name_;
1027 std::string file_;
// Documentation text; written by SetDoc() unless doc strings are compiled out.
1028 std::string doc_;
1029
// Operator domain; starts as ONNX_DOMAIN.
1030 std::string domain_ = ONNX_DOMAIN;
1031 std::map<std::string, Attribute> attributes_{};
1032 bool allows_unchecked_attributes_ = false;
1033 std::vector<FormalParameter> inputs_;
1034 std::vector<FormalParameter> outputs_;
1035 std::vector<TypeConstraintParam> type_constraint_params_;
1036 TypeConstraintMap type_constraints_;
1037 int line_ = 0;
// No in-class initializer: both constructors set it to SupportType::COMMON.
1038 SupportType support_;
1039 int min_input_ = 0;
1040 int max_input_ = 0;
1041 int min_output_ = 0;
1042 int max_output_ = 0;
1043
1044 OperatorSetVersion since_version_ = kUninitializedSinceVersion;
1045 bool deprecated_{};
// Count predicates; by default every input/output count is accepted.
1046 std::function<bool(int)> num_inputs_allowed_ = [](int) { return true; };
1047 std::function<bool(int)> num_outputs_allowed_ = [](int) { return true; };
// Empty until registered; the getters fall back to the dummy functions.
1048 InferenceFunction tensor_inference_function_;
1049 DataPropagationFunction data_propagation_function_;
1050
// Function bodies and context-dependent builders, keyed by opset version.
1051 std::map<int, std::shared_ptr<FunctionProto>> opset_version_to_function_body_;
1052 std::map<int, ContextDependentFunctionBodyBuilder> opset_version_to_function_builder_;
1053 };
1054
1055
1056
// Three-level lookup: string -> string -> (OperatorSetVersion -> OpSchema).
// OpSchemaRegisterOnce indexes it as m[op_name][op_domain], so the outer key is
// the operator name and the middle key the domain.
1057 using OpName_Domain_Version_Schema_Map =
1058 std::unordered_map<std::string, std::unordered_map<std::string, std::map<OperatorSetVersion, OpSchema>>>;
1059
// Abstract lookup interface for operator schemas.
1060 class ISchemaRegistry {
1061 public:
1062 virtual ~ISchemaRegistry() = default;
1063
// Looks up a schema by `key` within `domain` (default ONNX_DOMAIN), bounded by
// `maxInclusiveVersion`.
// NOTE(review): the exact version-selection rule and the not-found result
// (presumably nullptr) depend on the implementation — confirm there.
1064 virtual const OpSchema*
1065 GetSchema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) const = 0;
1066 };
1067
1068
1069
1070
1071 class OpSchemaRegistry final : public ISchemaRegistry {
1072 public:
1073
1074
1075 class DomainToVersionRange final {
1076 public:
1077 DomainToVersionRange() {
1078
1079
1080
1081 map_[ONNX_DOMAIN] = std::make_pair(1, 20);
1082 map_[AI_ONNX_ML_DOMAIN] = std::make_pair(1, 4);
1083 map_[AI_ONNX_TRAINING_DOMAIN] = std::make_pair(1, 1);
1084
1085
1086
1087 map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 1);
1088
1089
1090
1091 last_release_version_map_[ONNX_DOMAIN] = 20;
1092 last_release_version_map_[AI_ONNX_ML_DOMAIN] = 4;
1093 last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1;
1094 last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1;
1095 }
1096
1097 const std::unordered_map<std::string, std::pair<int, int>>& Map() const {
1098 return map_;
1099 }
1100
1101 const std::unordered_map<std::string, int>& LastReleaseVersionMap() const {
1102 return last_release_version_map_;
1103 }
1104
1105
1106
1107
1108
1109
1110
1111
1112 void
1113 AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) {
1114 std::lock_guard<std::mutex> lock(mutex_);
1115 assert(map_.end() == map_.find(domain));
1116 map_[domain] = std::make_pair(min_version, max_version);
1117
1118
1119 if (last_release_version == -1)
1120 last_release_version = max_version;
1121 assert(last_release_version_map_.end() == last_release_version_map_.find(domain));
1122 last_release_version_map_[domain] = last_release_version;
1123 }
1124
1125 static DomainToVersionRange& Instance();
1126
1127 private:
1128
1129 std::unordered_map<std::string, std::pair<int, int>> map_;
1130
1131
1132
1133
1134 std::unordered_map<std::string, int> last_release_version_map_;
1135
1136 std::mutex mutex_;
1137 };
1138
1139 class OpSchemaRegisterOnce final {
1140 public:
1141 OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1142 ONNX_TRY {
1143 op_schema.Finalize();
1144 auto& m = GetMapWithoutEnsuringRegistration();
1145 auto& op_name = op_schema.Name();
1146 auto& op_domain = op_schema.domain();
1147 auto& schema_ver_map = m[op_name][op_domain];
1148 auto ver = op_schema.SinceVersion();
1149 if (OpSchema::kUninitializedSinceVersion == ver) {
1150 op_schema.SinceVersion(1);
1151 ver = op_schema.SinceVersion();
1152 }
1153
1154
1155 if (schema_ver_map.count(ver)) {
1156 if (fail_duplicate_schema) {
1157 const auto& schema = schema_ver_map[ver];
1158 std::stringstream err;
1159 err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1160 << ") from file " << op_schema.file() << " line " << op_schema.line()
1161 << ", but it is already registered from file " << schema.file() << " line " << schema.line()
1162 << std::endl;
1163 fail_schema(err.str());
1164 }
1165 return;
1166 }
1167
1168 if (opset_version_to_load != 0) {
1169
1170 if (ver > opset_version_to_load)
1171 return;
1172
1173
1174 if (!schema_ver_map.empty()) {
1175 int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load);
1176 if (max_registered_ver_le_target >= ver)
1177 return;
1178 }
1179 }
1180
1181 CheckDomainAndVersionToRegister(op_schema, op_name, op_domain);
1182 schema_ver_map.insert(std::pair<int, OpSchema&&>(ver, std::move(op_schema)));
1183 }
1184 ONNX_CATCH(const std::exception& e) {
1185 ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; });
1186 }
1187 }
1188
1189 private:
1190
1191 static int GetMaxRegisteredVerWithinTarget(const std::map<OperatorSetVersion, OpSchema>& m, int target_ver) {
1192
1193
1194 for (auto&& it = m.rbegin(); it != m.rend(); it++) {
1195 const auto& registered_ver = it->first;
1196 if (registered_ver <= target_ver) {
1197 return registered_ver;
1198 }
1199 }
1200 return -1;
1201 }
1202
1203 static void CheckDomainAndVersionToRegister(
1204 const OpSchema& op_schema,
1205 const std::string& op_name,
1206 const std::string& op_domain) {
1207 auto ver_range_map = DomainToVersionRange::Instance().Map();
1208 auto ver_range_it = ver_range_map.find(op_domain);
1209 auto ver = op_schema.SinceVersion();
1210
1211 if (ver_range_it == ver_range_map.end()) {
1212 std::stringstream err;
1213 err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1214 << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its domain is not"
1215 << " known by the checker." << std::endl;
1216
1217 fail_schema(err.str());
1218 }
1219 auto lower_bound_incl = ver_range_it->second.first;
1220 auto upper_bound_incl = ver_range_it->second.second;
1221 if (!(lower_bound_incl <= ver && upper_bound_incl >= ver)) {
1222 std::stringstream err;
1223 err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1224 << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its version is not "
1225 << "in the inclusive range [" << lower_bound_incl << ", " << upper_bound_incl
1226 << "] (usually, this means you "
1227 << "bumped the operator version but "
1228 << "forgot to update the version range in DomainToVersionRange "
1229 << "in onnx/defs/schema.h)." << std::endl;
1230 fail_schema(err.str());
1231 }
1232 }
1233 };
1234
1235
1236
1237 static void OpSchemaDeregisterAll(const std::string& domain = ONNX_DOMAIN) {
1238 auto& schema_map = GetMapWithoutEnsuringRegistration();
1239
1240
1241 for (auto&& schema_map_pair : schema_map) {
1242 auto& domain_map = schema_map_pair.second;
1243 if (domain_map.count(domain)) {
1244 auto& opset_version_schema_map = domain_map[domain];
1245
1246 opset_version_schema_map.clear();
1247 domain_map.erase(domain);
1248 }
1249 }
1250 }
1251
1252
1253
1254 static const OpSchema* Schema(const std::string& key, const std::string& domain = ONNX_DOMAIN) {
1255 auto& m = map();
1256 if (m.count(key) && m[key].count(domain)) {
1257 const auto& schema_ver_map = m[key][domain];
1258 if (!schema_ver_map.empty()) {
1259 return &m[key][domain].rbegin()->second;
1260 }
1261 }
1262 return nullptr;
1263 }
1264
1265
1266
1267
1268 static const OpSchema*
1269 Schema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) {
1270 auto& m = map();
1271 if (m.count(key) && m[key].count(domain)) {
1272 const auto& schema_ver_map = m[key][domain];
1273 if (!schema_ver_map.empty()) {
1274 auto pos = m[key][domain].lower_bound(maxInclusiveVersion);
1275 if (m[key][domain].begin() == pos && pos->first > maxInclusiveVersion) {
1276
1277 return nullptr;
1278 }
1279 if (m[key][domain].end() == pos || pos->first > maxInclusiveVersion) {
1280
1281
1282 pos--;
1283 }
1284
1285
1286 return &(pos->second);
1287 }
1288 }
1289 return nullptr;
1290 }
1291
1292 static OpSchemaRegistry* Instance();
1293
1294 const OpSchema* GetSchema(
1295 const std::string& key,
1296 const int maxInclusiveVersion,
1297 const std::string& domain = ONNX_DOMAIN) const override {
1298 return Schema(key, maxInclusiveVersion, domain);
1299 }
1300 static void SetLoadedSchemaVersion(int target_version) {
1301 loaded_schema_version = target_version;
1302 }
1303 static int GetLoadedSchemaVersion() {
1304 return loaded_schema_version;
1305 }
1306
1307 private:
1308
1309
1310 OpSchemaRegistry() = default;
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322 static OpName_Domain_Version_Schema_Map& GetMapWithoutEnsuringRegistration();
1323 static OpName_Domain_Version_Schema_Map& map();
1324 static int loaded_schema_version;
1325
1326 public:
1327 static const std::vector<OpSchema> get_all_schemas_with_history() {
1328 std::vector<OpSchema> r;
1329 for (auto& x : map()) {
1330 for (auto& y : x.second) {
1331 for (auto& z : y.second) {
1332 r.emplace_back(z.second);
1333 }
1334 }
1335 }
1336 return r;
1337 }
1338
1339 static const std::vector<OpSchema> get_all_schemas() {
1340 std::vector<OpSchema> r;
1341 for (auto& x : map()) {
1342 for (auto& y : x.second) {
1343 auto& version2schema = y.second;
1344 r.emplace_back(version2schema.rbegin()->second);
1345 }
1346 }
1347 return r;
1348 }
1349 };
1350
1351 void RegisterSchema(OpSchema schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true);
1352
1353
1354
1355 template <class T>
1356 void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1357 T::ForEachSchema([opset_version_to_load, fail_duplicate_schema](OpSchema&& schema) {
1358 RegisterSchema(schema, opset_version_to_load, fail_duplicate_schema);
1359 });
1360 };
1361
1362
1363
1364
1365 template <typename T>
1366 OpSchema GetOpSchema();
1367
// Per-domain shorthands for ONNX_OPERATOR_SET_SCHEMA_EX. Each one fixes the
// domain tag and domain string and passes `true` for
// dbg_included_in_static_opset.
#define ONNX_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, Onnx, ONNX_DOMAIN, ver, true, impl)

#define ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxML, AI_ONNX_ML_DOMAIN, ver, true, impl)

#define ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxTraining, AI_ONNX_TRAINING_DOMAIN, ver, true, impl)

#define ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxPreview, AI_ONNX_PREVIEW_TRAINING_DOMAIN, ver, true, impl)
1378
1379
1380
1381
1382
1383
1384
// Expands to:
//  * a forward declaration of the tag class name##_##domain##_ver##ver;
//  * a GetOpSchema<> specialization for that tag which returns `impl` with
//    its name, domain, since-version and source location filled in;
//  * a size_t variable initialized from ONNX_DBG_INCREMENT_COUNT_IN_OPSETS()
//    when dbg_included_in_static_opset is true, and from 0 otherwise.
#define ONNX_OPERATOR_SET_SCHEMA_EX(name, domain, domain_str, ver, dbg_included_in_static_opset, impl) \
  class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name);                                         \
  template <>                                                                                           \
  OpSchema GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)>() {                      \
    return impl.SetName(#name).SetDomain(domain_str).SinceVersion(ver).SetLocation(__FILE__, __LINE__); \
  }                                                                                                     \
  size_t dbg_count_check_##name##_##domain##_ver##ver =                                                 \
      (dbg_included_in_static_opset) ? ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() : 0;
// With NDEBUG defined, ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() is the constant 0
// and neither ONNX_DBG_GET_COUNT_IN_OPSETS nor DbgOperatorSetTracker exists.
// Otherwise both macros forward to the DbgOperatorSetTracker instance.
#ifdef NDEBUG
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() 0
#else
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().IncrementCount()
#define ONNX_DBG_GET_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().GetCount()

// Simple counter that starts at zero.
class DbgOperatorSetTracker {
 public:
  static DbgOperatorSetTracker& Instance();

  // Adds one to the counter and returns the new value.
  size_t IncrementCount() {
    count_ += 1;
    return count_;
  }

  // Returns the current counter value.
  size_t GetCount() const {
    return count_;
  }

 private:
  size_t count_{0};
};
#endif
1415
1416
// Builds the identifier name##_##domain##_ver##ver; ONNX_OPERATOR_SET_SCHEMA_EX
// uses it as the tag class for its GetOpSchema<> specialization.
#define ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name) name##_##domain##_ver##ver

// Same identifier with the domain tag fixed to OnnxPreview.
#define ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(ver, name) \
  ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxPreview, ver, name)
1422
1423
1424 size_t ReplaceAll(std::string& s, const char* from, const char* to);
1425
// ONNX_UNUSED expands to the GNU `unused` attribute when __GNUC__ is defined
// and to nothing otherwise.
#if defined(__GNUC__)
#define ONNX_UNUSED __attribute__((__unused__))
#else
#define ONNX_UNUSED
#endif
1431
1432
// ONNX_OPERATOR_SCHEMA(name) declares a static OpSchemaRegisterOnce object
// copy-initialized from OpSchema(#name, __FILE__, __LINE__). The macro ends
// without a semicolon, so the initializer expression can be continued at
// the use site.
// The _UNIQ_HELPER layer lets __COUNTER__ expand to a number before it is
// pasted into the variable name by _UNIQ.
#define ONNX_OPERATOR_SCHEMA(name) ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)                                                                      \
  static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce(op_schema_register_once##name##Counter) ONNX_UNUSED = \
      OpSchema(#name, __FILE__, __LINE__)
1438
1439
1440 size_t ReplaceAll(std::string& s, const char* from, const char* to);
1441
// Returns the fixed documentation paragraph (ending in a newline) that
// describes how optional inputs/outputs are represented.
inline std::string GenerateOptionalArgumentsDoc() {
  static const char* const kOptionalArgumentsDoc =
      "This operator has **optional** inputs/outputs. "
      "See [the doc](IR.md) for more details about the representation of optional arguments. "
      "An empty string may be used in the place of an actual argument's name to indicate a missing argument. "
      "Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.\n";
  return std::string(kOptionalArgumentsDoc);
}
1450
// Returns the fixed documentation sentence for operators that support
// multidirectional broadcasting.
inline std::string GenerateBroadcastingDocMul() {
  std::string doc("This operator supports **multidirectional (i.e., Numpy-style) broadcasting**;");
  doc += " for more details please check [the doc](Broadcasting.md).";
  return doc;
}
1455
// Returns the documentation sentence for operators that support
// unidirectional broadcasting, with `from` and `to` inserted as the names of
// the source and target of the broadcast.
// Both arguments are appended as C strings and must not be null.
inline std::string GenerateBroadcastingDocUni(const char* from, const char* to) {
  std::string doc("This operator supports **unidirectional broadcasting** (");
  doc.append(from);
  doc.append(" should be unidirectional broadcastable to ");
  doc.append(to);
  doc.append("); for more details please check [the doc](Broadcasting.md).");
  return doc;
}
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
// GET_OP_DOC_STR(doc_str) yields doc_str unchanged, or an empty string
// literal when __ONNX_NO_DOC_STRINGS is defined.
#ifdef __ONNX_NO_DOC_STRINGS
#define GET_OP_DOC_STR(doc_str) ("")
#else
#define GET_OP_DOC_STR(doc_str) (doc_str)
#endif
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
// POPULATE_OP_DOC_STR(code) runs `code` once inside a do { } while (0)
// statement, or expands to nothing when __ONNX_NO_DOC_STRINGS is defined.
#ifdef __ONNX_NO_DOC_STRINGS
#define POPULATE_OP_DOC_STR(DocPopulatorCode)
#else
#define POPULATE_OP_DOC_STR(DocPopulatorCode) \
  do {                                        \
    DocPopulatorCode                          \
  } while (0)
#endif
1508
1509 }