Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-11-03 09:52:25

0001 /*
0002  * SPDX-License-Identifier: Apache-2.0
0003  */
0004 
0005 #pragma once
0006 
0007 #include <climits>
0008 #include <cstring>
0009 #include <functional>
0010 #include <initializer_list>
0011 #include <iostream>
0012 #include <limits>
0013 #include <map>
0014 #include <memory>
0015 #include <ostream>
0016 #include <set>
0017 #include <string>
0018 #include <string_view>
0019 #include <tuple>
0020 #include <unordered_map>
0021 #include <unordered_set>
0022 #include <utility>
0023 #include <vector>
0024 
0025 #include "onnx/common/common.h"
0026 #include "onnx/common/constants.h"
0027 #include "onnx/defs/shape_inference.h"
0028 
0029 namespace ONNX_NAMESPACE {
0030 
0031 struct FunctionBodyBuildContext {
0032   virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
0033   virtual bool hasInput(int inputIndex) const = 0;
0034   virtual bool hasOutput(int inputIndex) const = 0;
0035   // getInputType(i) should return null for missing optional inputs, or if
0036   // type-inference could not infer the input-type (erroneous model).
0037   virtual const TypeProto* getInputType(int inputIndex) const = 0;
0038   virtual ~FunctionBodyBuildContext() {}
0039 };
0040 
0041 struct FunctionBodyBuildContextImpl : public FunctionBodyBuildContext {
0042   // Input_types: use a default TypeProto for missing types. We use a different convention
0043   // here (from FunctionBodyBuildContext) to simplify python interoperability.
0044   // The default value for input_types is included only for backward compatibility.
0045   // It can be used for functions that do not depend on the type-context, but
0046   // will not be sufficient for functions that do use the type-context.
0047   FunctionBodyBuildContextImpl(const NodeProto& node_proto, const std::vector<TypeProto>& input_types = {})
0048       : node_proto_(node_proto), input_types_(input_types) {
0049     for (auto& attr : node_proto.attribute()) {
0050       attributesByName_[attr.name()] = &attr;
0051     }
0052   }
0053 
0054   const AttributeProto* getAttribute(const std::string& name) const override {
0055     auto iter = attributesByName_.find(name);
0056     if (iter == attributesByName_.end()) {
0057       return nullptr;
0058     } else {
0059       return iter->second;
0060     }
0061   }
0062 
0063   bool hasInput(int inputIndex) const override {
0064     if (inputIndex >= node_proto_.input_size())
0065       return false;
0066     return node_proto_.input(inputIndex) != "";
0067   }
0068 
0069   bool hasOutput(int inputIndex) const override {
0070     if (inputIndex >= node_proto_.output_size())
0071       return false;
0072     return node_proto_.output(inputIndex) != "";
0073   }
0074 
0075   const TypeProto* getInputType(int inputIndex) const override {
0076     if (inputIndex < 0)
0077       return nullptr;
0078     size_t j = static_cast<size_t>(inputIndex);
0079     if (j >= input_types_.size())
0080       return nullptr;
0081     // Convert default value (no variant set) into null.
0082     if (input_types_[j].value_case() == TypeProto::ValueCase::VALUE_NOT_SET)
0083       return nullptr;
0084     return &input_types_[j];
0085   }
0086 
0087   std::unordered_map<std::string, const AttributeProto*> attributesByName_;
0088 
0089   NodeProto node_proto_;
0090   std::vector<TypeProto> input_types_;
0091 };
0092 
0093 using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>;
0094 
0095 class OpSchema;
0096 using ContextDependentFunctionBodyBuilder =
0097     std::function<bool(const FunctionBodyBuildContext&, const OpSchema&, FunctionProto&)>;
0098 
// Exception for malformed operator schemas. A context string can be attached
// after construction; what() then reports the original message followed by
// "\n\n==> Context: " and that context.
class SchemaError final : public std::runtime_error {
 public:
  using std::runtime_error::runtime_error;

  SchemaError(const std::string& message) : std::runtime_error(message) {}

  const char* what() const noexcept override {
    // Until AppendContext() has run, only the base message is available.
    return expanded_message_.empty() ? std::runtime_error::what() : expanded_message_.c_str();
  }

  // Rebuilds the expanded message from the base message, so only the most
  // recently supplied context is kept.
  void AppendContext(const std::string& context) {
    std::string expanded(std::runtime_error::what());
    expanded += "\n\n==> Context: ";
    expanded += context;
    expanded_message_.swap(expanded);
  }

 private:
  // Base message plus context; empty until AppendContext() is called.
  std::string expanded_message_;
};
0119 
0120 #define fail_schema(...) ONNX_THROW_EX(ONNX_NAMESPACE::SchemaError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
0121 
0122 using OperatorSetVersion = int;
0123 
0124 using DataTypeSet = std::unordered_set<DataType>;
0125 
0126 // Type constraint map. Key is type string. Value is data type set and
0127 // description.
0128 using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>;
0129 
0130 /**
0131  * @brief A class to record the schema of an op.
0132  *
0133  * OpSchema records the common interface of an op specified by its name.
0134  *
0135  * To register an OpSchema, one can use the macro ONNX_OPERATOR_SCHEMA(name) and
0136  * then append the various functions in the class. For example, for an op
0137  * that takes in two inputs, one output, and the first input and output
0138  * could be in-place, can be written as
0139  *
0140  *     ONNX_OPERATOR_SCHEMA(name)
0141  *         .NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}});
0142  *
0143  * To manufacture methods that may be used to register an OpSchema
0144  * non-statically, the following may be used:
0145  *
0146  *     ONNX_OPERATOR_SET_SCHEMA(name, version, OpSchema()
0147  *         .NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}}));
0148  */
0149 class OpSchema final {
0150  public:
0151   static constexpr int kUninitializedSinceVersion = -1;
0152   // Formal parameter options.
0153   enum FormalParameterOption : uint8_t {
0154     // The formal parameter is single and not optional.
0155     // Number of supplied actual parameters must be 1.
0156     Single = 0,
0157     // The formal parameter is single and optional.
0158     // Number of supplied actual parameters may be 0 or 1.
0159     Optional = 1,
0160     // The formal parameter is variadic.
0161     // Number of supplied actual parameters must be N or more, where
0162     // the minimum value N is indicated separately (default value 1).
0163     Variadic = 2,
0164   };
0165   enum DifferentiationCategory : uint8_t {
0166     // Whether this formal parameter is differentiable or not cannot
0167     // be statically determined. It also covers variadic formal
0168     // parameters which contain both of differentiable and
0169     // non-differentiable variables.
0170     Unknown = 0,
0171     // This formal parameter is differentiable. That is, this formal
0172     // parameter can be differentiable input of Gradient operator.
0173     Differentiable = 1,
0174     // This formal parameter is not differentiable. That is, this formal
0175     // parameter can not be differentiable input of Gradient operator.
0176     NonDifferentiable = 2
0177   };
0178 
0179   // Formal parameter represenation, including input/output name, typeStr,
0180   // description, and type constraints.
0181   class FormalParameter final {
0182    public:
0183     // Constructor.
0184     FormalParameter() = default;
0185 
0186     explicit FormalParameter(
0187         std::string name,
0188         DataTypeSet allowed_type_set,
0189         std::string type_str,
0190         const std::string& description,
0191         FormalParameterOption param_option = Single,
0192         bool is_homogeneous = true,
0193         int min_arity = 1,
0194         DifferentiationCategory differentiation_category = Unknown)
0195         : name_(std::move(name)),
0196           type_set_(std::move(allowed_type_set)),
0197           type_str_(std::move(type_str)),
0198 #ifndef __ONNX_NO_DOC_STRINGS
0199           description_(description),
0200 #endif
0201           param_option_(param_option),
0202           is_homogeneous_(is_homogeneous),
0203           min_arity_(min_arity),
0204           differentiation_category_(differentiation_category) {
0205 #ifdef __ONNX_NO_DOC_STRINGS
0206       ONNX_UNUSED_PARAMETER(description);
0207 #endif
0208     }
0209 
0210     explicit FormalParameter(
0211         std::string name,
0212         const std::string& description,
0213         std::string type_str,
0214         FormalParameterOption param_option = Single,
0215         bool is_homogeneous = true,
0216         int min_arity = 1,
0217         DifferentiationCategory differentiation_category = Unknown)
0218         : name_(std::move(name)),
0219           type_str_(std::move(type_str)),
0220 #ifndef __ONNX_NO_DOC_STRINGS
0221           description_(description),
0222 #endif
0223           param_option_(param_option),
0224           is_homogeneous_(is_homogeneous),
0225           min_arity_(min_arity),
0226           differentiation_category_(differentiation_category) {
0227 #ifdef __ONNX_NO_DOC_STRINGS
0228       ONNX_UNUSED_PARAMETER(description);
0229 #endif
0230     }
0231 
0232     // Get formal parameter name.
0233     const std::string& GetName() const;
0234 
0235     // Get allowed data types.
0236     const DataTypeSet& GetTypes() const;
0237 
0238     // Get formal parameter type string.
0239     const std::string& GetTypeStr() const;
0240 
0241     // Get formal parameter description.
0242     const std::string& GetDescription() const;
0243 
0244     // Get the parameter option, it could be Single, Optional or Variadic.
0245     FormalParameterOption GetOption() const;
0246 
0247     // Get whether a variadic parameter requires all to be of same type
0248     bool GetIsHomogeneous() const;
0249 
0250     // Get minimum arity. Applicable only in the Variadic case.
0251     int GetMinArity() const;
0252 
0253     // Get the differentiation property of this formal parameter.
0254     DifferentiationCategory GetDifferentiationCategory() const;
0255 
0256    private:
0257     friend class OpSchema;
0258 
0259     DataTypeSet& MutableTypes();
0260 
0261     // Formal parameter name.
0262     std::string name_;
0263 
0264     // A set of data types supported for <*this> formal parameter.
0265     // It should contain at least one element if this formal parameter is good.
0266     DataTypeSet type_set_;
0267 
0268     // The <parameter type> string specified when registring an op.
0269     // It could be a supported data type or a type constraint key, which
0270     // maps to a set of supported data types.
0271     std::string type_str_;
0272 
0273     // Formal parameter description.
0274     std::string description_;
0275 
0276     // Formal parameter option.
0277     FormalParameterOption param_option_;
0278 
0279     // For variadic parameters, a flag indicating if all parameters must be of
0280     // same type
0281     bool is_homogeneous_;
0282 
0283     // Minimum number of parameters expected. Applicable only for Variadic.
0284     int min_arity_;
0285 
0286     // True if this parameter can be an differentiable inputs of Gradient.
0287     // Otherwise, using this parameter as an differentiable inputs of Gradient
0288     // is prohibited.
0289     DifferentiationCategory differentiation_category_;
0290   };
0291 
0292   enum class SupportType : uint8_t {
0293     COMMON, // Supported by all frameworks that support this IR.
0294     EXPERIMENTAL, // This OP is experimental and can be changed or removed in
0295                   // the future.
0296   };
0297 
0298   OpSchema() : OpSchema("unknown", "unknown", 0) {}
0299   OpSchema(std::string name, std::string file, int line)
0300       : name_(std::move(name)), file_(std::move(file)), line_(line), support_(SupportType::COMMON) {}
0301 
0302   /**
0303    * @brief Returns the file that the op schema is registered from.
0304    */
0305   const std::string& file() const {
0306     return file_;
0307   }
0308 
0309   /**
0310    * @brief Returns the line in file that the op schema is registered from.
0311    */
0312   int line() const {
0313     return line_;
0314   }
0315 
0316   /**
0317    * @brief Returns the support level of the op schema.
0318    */
0319   SupportType support_level() const {
0320     return support_;
0321   }
0322 
0323   /**
0324    * @brief Returns the docstring of the op schema.
0325    */
0326   const char* doc() const {
0327     return doc_.empty() ? nullptr : doc_.c_str();
0328   }
0329 
0330   // Check if input and output types fall into valid set and match each other
0331   void CheckInputOutputType(struct InferenceContext&) const;
0332 
0333   /**
0334    * @brief Verifies if a NodeProto matches the pattern specified in
0335    * the schema.
0336    */
0337   void Verify(const NodeProto& node) const;
0338 
0339   // Functions to set the property of the operator schemas.
0340   // Sets the number of inputs, either a fixed number or a min and a max.
0341 
0342   /**
0343    * The earliest operator set version which this operator was
0344    * present in.  If an operator has had no BC-breaking changes,
0345    * this is simply the first operator set the operator was a member
0346    * of; if it has had BC-breaking changes, then for the semantics
0347    * /as described/ in the OpSchema entry, this version describes
0348    * the operator set which introduced the BC-breaking change.
0349    *
0350    * For example, suppose op Foo was added in v3, and had a BC-breaking
0351    * change in v6.  Then there will be an op schema entry for Foo with
0352    * SinceVersion(3), and another, updated op schema entry for Foo
0353    * with SinceVersion(6).
0354    */
0355   OpSchema& SinceVersion(OperatorSetVersion n); // aka int
0356 
0357   /**
0358    * Marks this op as deprecated as of it's since_version. This will cause the
0359    * Schema() lookup functions to return nullptr when the version is in the
0360    * deprecated range.
0361    */
0362   OpSchema& Deprecate();
0363 
0364   bool Deprecated() const {
0365     return deprecated_;
0366   }
0367 
0368   /**
0369    * @brief Input could be one of the values specified in allowed_input_nums.
0370    */
0371   OpSchema& NumInputs(std::set<int> allowed_input_nums);
0372 
0373   /**
0374    * @brief Output could be one of the values specified in allowed_output_nums.
0375    */
0376   OpSchema& NumOutputs(std::set<int> allowed_output_nums);
0377 
0378   // Shape Inference
0379   //
0380   // Note that signatures are defined to allow for forward-declaring
0381   // any structs used from ir.h
0382   OpSchema& TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
0383   InferenceFunction GetTypeAndShapeInferenceFunction() const {
0384     return tensor_inference_function_ ? tensor_inference_function_ : dummyInferenceFunction;
0385   }
0386 
0387   OpSchema& PartialDataPropagationFunction(DataPropagationFunction dataProgationFunction);
0388   DataPropagationFunction GetDataPropagationFunction() const {
0389     return data_propagation_function_ ? data_propagation_function_ : dummyDataPropagationFunction;
0390   }
0391 
0392   // Set the support level for the op schema.
0393   OpSchema& SetSupportLevel(SupportType supportType);
0394 
0395   // Functions to do documentation for the operator schema.
0396   // This may be disabled to save memory.
0397   OpSchema& SetDoc(const char* doc) {
0398 #ifndef __ONNX_NO_DOC_STRINGS
0399     SetDoc(std::string(doc));
0400 #else
0401     ONNX_UNUSED_PARAMETER(doc);
0402 #endif
0403 
0404     return *this;
0405   }
0406 
0407   OpSchema& SetDoc(const std::string& doc) {
0408 #ifndef __ONNX_NO_DOC_STRINGS
0409     doc_ = doc;
0410 #else
0411     ONNX_UNUSED_PARAMETER(doc);
0412 #endif
0413     return *this;
0414   }
0415 
0416   // Functions to specify name for the operator schema.
0417   OpSchema& SetName(const char* name);
0418   OpSchema& SetName(std::string name);
0419 
0420   // Functions to specify code location for the operator schema.
0421   OpSchema& SetLocation(const char* file, int line);
0422   OpSchema& SetLocation(std::string file, int line);
0423 
0424   // Functions to specify domain for the operator schema.
0425   // Default domain value (ONNX_DOMAIN) means it's ONNX domain.
0426   OpSchema& SetDomain(const char* domain);
0427   OpSchema& SetDomain(std::string domain);
0428 
0429   struct Attribute final {
0430     Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)
0431         : name(std::move(name_)),
0432           description(std::move(description_)),
0433           type(type_),
0434           required(required_),
0435           default_value() {}
0436 
0437     Attribute(std::string name_, std::string description_, AttributeProto default_value_)
0438         : name(std::move(name_)),
0439           description(std::move(description_)),
0440           type(default_value_.type()),
0441           required(false),
0442           default_value(std::move(default_value_)) {}
0443 
0444     const std::string name;
0445     const std::string description;
0446     AttributeProto::AttributeType type;
0447     bool required;
0448     AttributeProto default_value;
0449   };
0450 
0451   OpSchema& Attr(Attribute attr);
0452 
0453 // Register "optional" attribute with default value.
0454 #define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName)                                                                    \
0455   OpSchema& Attr(                                                                                                   \
0456       std::string name, std::string description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
0457   /* non-STL wrapper to reduce binary size */                                                                       \
0458   OpSchema& Attr(                                                                                                   \
0459       const char* name, const char* description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
0460   OpSchema& Attr(                                                                                                   \
0461       std::string name,                                                                                             \
0462       std::string description,                                                                                      \
0463       AttributeProto::AttributeType type,                                                                           \
0464       const std::vector<TypeName>& defaultValue);
0465 
0466   ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t)
0467   ATTR_SETTER_WITH_DEFAULT_VALUE(float)
0468   ATTR_SETTER_WITH_DEFAULT_VALUE(std::string)
0469   ATTR_SETTER_WITH_DEFAULT_VALUE(TensorProto)
0470   ATTR_SETTER_WITH_DEFAULT_VALUE(GraphProto)
0471   ATTR_SETTER_WITH_DEFAULT_VALUE(TypeProto)
0472 
0473   OpSchema& Attr(
0474       std::string name,
0475       std::string description,
0476       std::string conditionExplanation,
0477       AttributeProto::AttributeType attr_type);
0478 
0479   // Register "required" attribute without default value.
0480   OpSchema& Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required = true);
0481 
0482   // Non-STL wrapper to reduce binary size
0483   OpSchema& Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required = true);
0484 
0485   OpSchema& AllowUncheckedAttributes();
0486 
0487   // Type constraint.
0488   struct TypeConstraintParam final {
0489     TypeConstraintParam(
0490         std::string type_param_str_,
0491         std::vector<std::string> allowed_type_strs_,
0492         std::string description_)
0493         : type_param_str(std::move(type_param_str_)),
0494           allowed_type_strs(std::move(allowed_type_strs_)),
0495           description(std::move(description_)) {}
0496 
0497     // Type parameter string, for example, "T", "T1", etc.
0498     std::string type_param_str;
0499     // Allowed type strings for <*this> type parameter, for example,
0500     // "tensor(float)".
0501     std::vector<std::string> allowed_type_strs;
0502     // Type parameter description.
0503     std::string description;
0504   };
0505 
0506   // Grammar for type strings used in Input(), Output().
0507   // <type> ::= <data_type> |
0508   //            tensor(<data_type>) |
0509   //            seq(<type>) |
0510   //            map(<data_type>, <type>) |
0511   //            <type_parameter>
0512   // <data_type> :: = float | int32 | string | bool | uint8
0513   //                | int8 | uint16 | int16 | int64 | float16 | double
0514   // <type_parameter> ::= any type parameter string, say "T".
0515   //
0516   // NOTE: 1) <type_parameter> will always be together with a type constraints
0517   // specification.
0518   //       2) <type> ::= <data_type> means the data is scalar (zero dimension).
0519   //
0520   // Example:
0521   // ONNX_OPERATOR_SET_SCHEMA(Sum, 1, OpSchema()
0522   // .Input(0, "input_a", "the first input", "T")
0523   // .Input(1, "input_b", "the second input", "T")
0524   // .Output(0, "sum", "the sum of two numbers", "T")
0525   // .TypeConstraint("T", {"float", "double", "int32"}, "allowed data types for
0526   // sum."))
0527   //
0528   // Optional = true means that the input might have empty input value
0529   // (represented as "") in the graph even though the later inputs have values.
0530   // It's useful for complex situation when there are several independent
0531   // optional inputs.
0532   OpSchema& Input(int n, FormalParameter formal_parameter);
0533 
0534   OpSchema& Input(
0535       int n,
0536       std::string name,
0537       const std::string& description,
0538       std::string type_str,
0539       FormalParameterOption param_option = Single,
0540       bool is_homogeneous = true,
0541       int min_arity = 1,
0542       DifferentiationCategory differentiation_category = Unknown);
0543 
0544   // Non-STL wrapper to reduce binary size
0545   OpSchema& Input(
0546       int n,
0547       const char* name,
0548       const char* description,
0549       const char* type_str,
0550       FormalParameterOption param_option = Single,
0551       bool is_homogeneous = true,
0552       int min_arity = 1,
0553       DifferentiationCategory differentiation_category = Unknown);
0554 
0555   OpSchema& Output(int n, FormalParameter formal_parameter);
0556 
0557   OpSchema& Output(
0558       int n,
0559       std::string name,
0560       const std::string& description,
0561       std::string type_str,
0562       FormalParameterOption param_option = Single,
0563       bool is_homogeneous = true,
0564       int min_arity = 1,
0565       DifferentiationCategory differentiation_category = Unknown);
0566 
0567   // Non-STL wrapper to reduce binary size
0568   OpSchema& Output(
0569       int n,
0570       const char* name,
0571       const char* description,
0572       const char* type_str,
0573       FormalParameterOption param_option = Single,
0574       bool is_homogeneous = true,
0575       int min_arity = 1,
0576       DifferentiationCategory differentiation_category = Unknown);
0577 
0578   OpSchema& TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description);
0579 
0580   // Non-STL wrapper to reduce binary size
0581   OpSchema&
0582   TypeConstraint(const char* type_str, std::initializer_list<const char*> constraints, const char* description);
0583 
0584   // Convenience members for types
0585 
0586   // All high-precision numeric types.
0587   static const std::vector<std::string>& numeric_types_for_math_reduction_ir10() {
0588     return numeric_types_for_math_reduction_ir9();
0589   }
0590 
0591   static const std::vector<std::string>& numeric_types_for_math_reduction_ir9() {
0592     static const std::vector<std::string> numeric_types_for_math_reduction_ir9 = {
0593         "tensor(uint32)",
0594         "tensor(uint64)",
0595         "tensor(int32)",
0596         "tensor(int64)",
0597         "tensor(float16)",
0598         "tensor(float)",
0599         "tensor(double)",
0600         "tensor(bfloat16)",
0601         "tensor(float8e4m3fn)",
0602         "tensor(float8e4m3fnuz)",
0603         "tensor(float8e5m2)",
0604         "tensor(float8e5m2fnuz)"};
0605     return numeric_types_for_math_reduction_ir9;
0606   }
0607 
0608   static const std::vector<std::string>& numeric_types_for_math_reduction_ir4() {
0609     static const std::vector<std::string> numeric_types_for_math_reduction_ir4 = {
0610         "tensor(uint32)",
0611         "tensor(uint64)",
0612         "tensor(int32)",
0613         "tensor(int64)",
0614         "tensor(float16)",
0615         "tensor(float)",
0616         "tensor(double)",
0617         "tensor(bfloat16)"};
0618     return numeric_types_for_math_reduction_ir4;
0619   }
0620 
0621   static const std::vector<std::string>& numeric_types_for_math_reduction() {
0622     static const std::vector<std::string> numeric_types_for_math_reduction = {
0623         "tensor(uint32)",
0624         "tensor(uint64)",
0625         "tensor(int32)",
0626         "tensor(int64)",
0627         "tensor(float16)",
0628         "tensor(float)",
0629         "tensor(double)"};
0630     return numeric_types_for_math_reduction;
0631   }
0632 
0633   static const std::vector<std::string>& all_numeric_types_ir10() {
0634     static const std::vector<std::string> all_numeric_types_ir10 = {
0635         "tensor(uint8)",
0636         "tensor(uint16)",
0637         "tensor(uint32)",
0638         "tensor(uint64)",
0639         "tensor(int8)",
0640         "tensor(int16)",
0641         "tensor(int32)",
0642         "tensor(int64)",
0643         "tensor(float16)",
0644         "tensor(float)",
0645         "tensor(double)",
0646         "tensor(bfloat16)",
0647         "tensor(float8e4m3fn)",
0648         "tensor(float8e4m3fnuz)",
0649         "tensor(float8e5m2)",
0650         "tensor(float8e5m2fnuz)",
0651         "tensor(uint4)",
0652         "tensor(int4)"};
0653     return all_numeric_types_ir10;
0654   }
0655 
0656   static const std::vector<std::string>& all_numeric_types_ir9() {
0657     static const std::vector<std::string> all_numeric_types_ir9 = {
0658         "tensor(uint8)",
0659         "tensor(uint16)",
0660         "tensor(uint32)",
0661         "tensor(uint64)",
0662         "tensor(int8)",
0663         "tensor(int16)",
0664         "tensor(int32)",
0665         "tensor(int64)",
0666         "tensor(float16)",
0667         "tensor(float)",
0668         "tensor(double)",
0669         "tensor(bfloat16)",
0670         "tensor(float8e4m3fn)",
0671         "tensor(float8e4m3fnuz)",
0672         "tensor(float8e5m2)",
0673         "tensor(float8e5m2fnuz)"};
0674     return all_numeric_types_ir9;
0675   }
0676 
0677   static const std::vector<std::string>& all_numeric_types_ir4() {
0678     static const std::vector<std::string> all_numeric_types_ir4 = {
0679         "tensor(uint8)",
0680         "tensor(uint16)",
0681         "tensor(uint32)",
0682         "tensor(uint64)",
0683         "tensor(int8)",
0684         "tensor(int16)",
0685         "tensor(int32)",
0686         "tensor(int64)",
0687         "tensor(float16)",
0688         "tensor(float)",
0689         "tensor(double)",
0690         "tensor(bfloat16)"};
0691     return all_numeric_types_ir4;
0692   }
0693 
0694   static const std::vector<std::string>& all_numeric_types() {
0695     static const std::vector<std::string> all_numeric_types = {
0696         "tensor(uint8)",
0697         "tensor(uint16)",
0698         "tensor(uint32)",
0699         "tensor(uint64)",
0700         "tensor(int8)",
0701         "tensor(int16)",
0702         "tensor(int32)",
0703         "tensor(int64)",
0704         "tensor(float16)",
0705         "tensor(float)",
0706         "tensor(double)"};
0707     return all_numeric_types;
0708   }
0709 
0710   static const std::vector<std::string>& all_numeric_sequence_types() {
0711     static const std::vector<std::string> all_numeric_sequence_types = {
0712         "seq(tensor(uint8))",
0713         "seq(tensor(uint16))",
0714         "seq(tensor(uint32))",
0715         "seq(tensor(uint64))",
0716         "seq(tensor(int8))",
0717         "seq(tensor(int16))",
0718         "seq(tensor(int32))",
0719         "seq(tensor(int64))",
0720         "seq(tensor(float16))",
0721         "seq(tensor(float))",
0722         "seq(tensor(double))"};
0723     return all_numeric_sequence_types;
0724   }
0725 
0726   static const std::vector<std::string>& all_tensor_types() {
0727     static const std::vector<std::string> all_tensor_types = {
0728         "tensor(uint8)",
0729         "tensor(uint16)",
0730         "tensor(uint32)",
0731         "tensor(uint64)",
0732         "tensor(int8)",
0733         "tensor(int16)",
0734         "tensor(int32)",
0735         "tensor(int64)",
0736         "tensor(float16)",
0737         "tensor(float)",
0738         "tensor(double)",
0739         "tensor(string)",
0740         "tensor(bool)",
0741         "tensor(complex64)",
0742         "tensor(complex128)"};
0743     return all_tensor_types;
0744   }
0745 
0746   static const std::vector<std::string>& all_tensor_types_ir4() {
0747     static const std::vector<std::string> all_tensor_types_ir4 = {
0748         "tensor(uint8)",
0749         "tensor(uint16)",
0750         "tensor(uint32)",
0751         "tensor(uint64)",
0752         "tensor(int8)",
0753         "tensor(int16)",
0754         "tensor(int32)",
0755         "tensor(int64)",
0756         "tensor(bfloat16)",
0757         "tensor(float16)",
0758         "tensor(float)",
0759         "tensor(double)",
0760         "tensor(string)",
0761         "tensor(bool)",
0762         "tensor(complex64)",
0763         "tensor(complex128)"};
0764     return all_tensor_types_ir4;
0765   }
0766 
0767   static const std::vector<std::string>& all_non_complex_numeric_types_plus_bool_ir4() {
0768     static const std::vector<std::string> all_non_complex_numeric_types_plus_bool_ir4 = {
0769         "tensor(uint8)",
0770         "tensor(uint16)",
0771         "tensor(uint32)",
0772         "tensor(uint64)",
0773         "tensor(int8)",
0774         "tensor(int16)",
0775         "tensor(int32)",
0776         "tensor(int64)",
0777         "tensor(bfloat16)",
0778         "tensor(float16)",
0779         "tensor(float)",
0780         "tensor(double)",
0781         "tensor(bool)"};
0782     return all_non_complex_numeric_types_plus_bool_ir4;
0783   }
0784 
0785   static const std::vector<std::string>& all_float_types_ir4() {
0786     static const std::vector<std::string> all_float_types_ir4 = {
0787         "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)"};
0788     return all_float_types_ir4;
0789   }
0790 
0791   static const std::vector<std::string>& all_float_types_plus_Xint8_ir4() {
0792     static const std::vector<std::string> all_float_types_ir4 = {
0793         "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(int8)", "tensor(uint8)"};
0794     return all_float_types_ir4;
0795   }
0796 
0797   static const std::vector<std::string>& all_float_types_ir9() {
0798     static const std::vector<std::string> all_float_types_ir9 = {
0799         "tensor(bfloat16)",
0800         "tensor(float16)",
0801         "tensor(float)",
0802         "tensor(double)",
0803         "tensor(float8e4m3fn)",
0804         "tensor(float8e4m3fnuz)",
0805         "tensor(float8e5m2)",
0806         "tensor(float8e5m2fnuz)"};
0807     return all_float_types_ir9;
0808   }
0809 
0810   static const std::vector<std::string>& all_float_types_ir10() {
0811     return all_float_types_ir9();
0812   }
0813 
0814   static const std::vector<std::string>& all_tensor_types_ir9() {
0815     static const std::vector<std::string> all_tensor_types_ir9 = {
0816         "tensor(uint8)",        "tensor(uint16)",         "tensor(uint32)",     "tensor(uint64)",
0817         "tensor(int8)",         "tensor(int16)",          "tensor(int32)",      "tensor(int64)",
0818         "tensor(bfloat16)",     "tensor(float16)",        "tensor(float)",      "tensor(double)",
0819         "tensor(string)",       "tensor(bool)",           "tensor(complex64)",  "tensor(complex128)",
0820         "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
0821     return all_tensor_types_ir9;
0822   }
0823 
0824   static const std::vector<std::string>& all_tensor_types_ir10() {
0825     static const std::vector<std::string> all_tensor_types_ir10 = {
0826         "tensor(uint8)",      "tensor(uint16)",         "tensor(uint32)",
0827         "tensor(uint64)",     "tensor(int8)",           "tensor(int16)",
0828         "tensor(int32)",      "tensor(int64)",          "tensor(bfloat16)",
0829         "tensor(float16)",    "tensor(float)",          "tensor(double)",
0830         "tensor(string)",     "tensor(bool)",           "tensor(complex64)",
0831         "tensor(complex128)", "tensor(float8e4m3fn)",   "tensor(float8e4m3fnuz)",
0832         "tensor(float8e5m2)", "tensor(float8e5m2fnuz)", "tensor(uint4)",
0833         "tensor(int4)"};
0834     return all_tensor_types_ir10;
0835   }
0836 
0837   static const std::vector<std::string>& all_non_complex_tensor_types_ir10() {
0838     static const std::vector<std::string> all_non_complex_tensor_types_ir10 = {
0839         "tensor(uint8)",      "tensor(uint16)",         "tensor(uint32)",       "tensor(uint64)",
0840         "tensor(int8)",       "tensor(int16)",          "tensor(int32)",        "tensor(int64)",
0841         "tensor(bfloat16)",   "tensor(float16)",        "tensor(float)",        "tensor(double)",
0842         "tensor(string)",     "tensor(bool)",           "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)",
0843         "tensor(float8e5m2)", "tensor(float8e5m2fnuz)", "tensor(uint4)",        "tensor(int4)"};
0844     return all_non_complex_tensor_types_ir10;
0845   }
0846 
0847   static const std::vector<std::string>& all_tensor_sequence_types() {
0848     static const std::vector<std::string> all_tensor_sequence_types = {
0849         "seq(tensor(uint8))",
0850         "seq(tensor(uint16))",
0851         "seq(tensor(uint32))",
0852         "seq(tensor(uint64))",
0853         "seq(tensor(int8))",
0854         "seq(tensor(int16))",
0855         "seq(tensor(int32))",
0856         "seq(tensor(int64))",
0857         "seq(tensor(float16))",
0858         "seq(tensor(float))",
0859         "seq(tensor(double))",
0860         "seq(tensor(string))",
0861         "seq(tensor(bool))",
0862         "seq(tensor(complex64))",
0863         "seq(tensor(complex128))"};
0864     return all_tensor_sequence_types;
0865   }
0866 
0867   static const std::vector<std::string>& all_tensor_sequence_types_ir4() {
0868     static const std::vector<std::string> all_tensor_sequence_types_ir4 = {
0869         "seq(tensor(uint8))",
0870         "seq(tensor(uint16))",
0871         "seq(tensor(uint32))",
0872         "seq(tensor(uint64))",
0873         "seq(tensor(int8))",
0874         "seq(tensor(int16))",
0875         "seq(tensor(int32))",
0876         "seq(tensor(int64))",
0877         "seq(tensor(bfloat16))",
0878         "seq(tensor(float16))",
0879         "seq(tensor(float))",
0880         "seq(tensor(double))",
0881         "seq(tensor(string))",
0882         "seq(tensor(bool))",
0883         "seq(tensor(complex64))",
0884         "seq(tensor(complex128))"};
0885     return all_tensor_sequence_types_ir4;
0886   }
0887 
0888   static const std::vector<std::string>& all_tensor_sequence_types_ir9() {
0889     static const std::vector<std::string> all_tensor_sequence_types_ir9 = {
0890         "seq(tensor(uint8))",      "seq(tensor(uint16))",        "seq(tensor(uint32))",
0891         "seq(tensor(uint64))",     "seq(tensor(int8))",          "seq(tensor(int16))",
0892         "seq(tensor(int32))",      "seq(tensor(int64))",         "seq(tensor(bfloat16))",
0893         "seq(tensor(float16))",    "seq(tensor(float))",         "seq(tensor(double))",
0894         "seq(tensor(string))",     "seq(tensor(bool))",          "seq(tensor(complex64))",
0895         "seq(tensor(complex128))", "seq(tensor(float8e4m3fn))",  "seq(tensor(float8e4m3fnuz))",
0896         "seq(tensor(float8e5m2))", "seq(tensor(float8e5m2fnuz))"};
0897     return all_tensor_sequence_types_ir9;
0898   }
0899 
0900   static const std::vector<std::string>& all_tensor_sequence_types_ir10() {
0901     static const std::vector<std::string> all_tensor_sequence_types_ir10 = {
0902         "seq(tensor(uint8))",      "seq(tensor(uint16))",         "seq(tensor(uint32))",
0903         "seq(tensor(uint64))",     "seq(tensor(int8))",           "seq(tensor(int16))",
0904         "seq(tensor(int32))",      "seq(tensor(int64))",          "seq(tensor(bfloat16))",
0905         "seq(tensor(float16))",    "seq(tensor(float))",          "seq(tensor(double))",
0906         "seq(tensor(string))",     "seq(tensor(bool))",           "seq(tensor(complex64))",
0907         "seq(tensor(complex128))", "seq(tensor(float8e4m3fn))",   "seq(tensor(float8e4m3fnuz))",
0908         "seq(tensor(float8e5m2))", "seq(tensor(float8e5m2fnuz))", "seq(tensor(uint4))",
0909         "seq(tensor(int4))"};
0910     return all_tensor_sequence_types_ir10;
0911   }
0912 
0913   static const std::vector<std::string>& all_optional_types() {
0914     static const std::vector<std::string> all_optional_types = {
0915         "optional(seq(tensor(uint8)))",  "optional(seq(tensor(uint16)))",    "optional(seq(tensor(uint32)))",
0916         "optional(seq(tensor(uint64)))", "optional(seq(tensor(int8)))",      "optional(seq(tensor(int16)))",
0917         "optional(seq(tensor(int32)))",  "optional(seq(tensor(int64)))",     "optional(seq(tensor(float16)))",
0918         "optional(seq(tensor(float)))",  "optional(seq(tensor(double)))",    "optional(seq(tensor(string)))",
0919         "optional(seq(tensor(bool)))",   "optional(seq(tensor(complex64)))", "optional(seq(tensor(complex128)))",
0920         "optional(tensor(uint8))",       "optional(tensor(uint16))",         "optional(tensor(uint32))",
0921         "optional(tensor(uint64))",      "optional(tensor(int8))",           "optional(tensor(int16))",
0922         "optional(tensor(int32))",       "optional(tensor(int64))",          "optional(tensor(float16))",
0923         "optional(tensor(float))",       "optional(tensor(double))",         "optional(tensor(string))",
0924         "optional(tensor(bool))",        "optional(tensor(complex64))",      "optional(tensor(complex128))"};
0925     return all_optional_types;
0926   }
0927 
0928   static const std::vector<std::string>& all_optional_types_ir4() {
0929     static const std::vector<std::string> all_optional_types = {
0930         "optional(seq(tensor(uint8)))",      "optional(seq(tensor(uint16)))", "optional(seq(tensor(uint32)))",
0931         "optional(seq(tensor(uint64)))",     "optional(seq(tensor(int8)))",   "optional(seq(tensor(int16)))",
0932         "optional(seq(tensor(int32)))",      "optional(seq(tensor(int64)))",  "optional(seq(tensor(bfloat16)))",
0933         "optional(seq(tensor(float16)))",    "optional(seq(tensor(float)))",  "optional(seq(tensor(double)))",
0934         "optional(seq(tensor(string)))",     "optional(seq(tensor(bool)))",   "optional(seq(tensor(complex64)))",
0935         "optional(seq(tensor(complex128)))", "optional(tensor(uint8))",       "optional(tensor(uint16))",
0936         "optional(tensor(uint32))",          "optional(tensor(uint64))",      "optional(tensor(int8))",
0937         "optional(tensor(int16))",           "optional(tensor(int32))",       "optional(tensor(int64))",
0938         "optional(tensor(bfloat16))",        "optional(tensor(float16))",     "optional(tensor(float))",
0939         "optional(tensor(double))",          "optional(tensor(string))",      "optional(tensor(bool))",
0940         "optional(tensor(complex64))",       "optional(tensor(complex128))"};
0941     return all_optional_types;
0942   }
0943 
0944   static const std::vector<std::string>& all_optional_types_ir9() {
0945     static const std::vector<std::string> all_optional_types = {
0946         "optional(seq(tensor(uint8)))",      "optional(seq(tensor(uint16)))", "optional(seq(tensor(uint32)))",
0947         "optional(seq(tensor(uint64)))",     "optional(seq(tensor(int8)))",   "optional(seq(tensor(int16)))",
0948         "optional(seq(tensor(int32)))",      "optional(seq(tensor(int64)))",  "optional(seq(tensor(bfloat16)))",
0949         "optional(seq(tensor(float16)))",    "optional(seq(tensor(float)))",  "optional(seq(tensor(double)))",
0950         "optional(seq(tensor(string)))",     "optional(seq(tensor(bool)))",   "optional(seq(tensor(complex64)))",
0951         "optional(seq(tensor(complex128)))", "optional(tensor(uint8))",       "optional(tensor(uint16))",
0952         "optional(tensor(uint32))",          "optional(tensor(uint64))",      "optional(tensor(int8))",
0953         "optional(tensor(int16))",           "optional(tensor(int32))",       "optional(tensor(int64))",
0954         "optional(tensor(bfloat16))",        "optional(tensor(float16))",     "optional(tensor(float))",
0955         "optional(tensor(double))",          "optional(tensor(string))",      "optional(tensor(bool))",
0956         "optional(tensor(complex64))",       "optional(tensor(complex128))",  "optional(tensor(float8e4m3fn))",
0957         "optional(tensor(float8e4m3fnuz))",  "optional(tensor(float8e5m2))",  "optional(tensor(float8e5m2fnuz))"};
0958     return all_optional_types;
0959   }
0960 
0961   static const std::vector<std::string>& all_optional_types_ir10() {
0962     static const std::vector<std::string> all_optional_types = {
0963         "optional(seq(tensor(uint8)))",      "optional(seq(tensor(uint16)))", "optional(seq(tensor(uint32)))",
0964         "optional(seq(tensor(uint64)))",     "optional(seq(tensor(int8)))",   "optional(seq(tensor(int16)))",
0965         "optional(seq(tensor(int32)))",      "optional(seq(tensor(int64)))",  "optional(seq(tensor(bfloat16)))",
0966         "optional(seq(tensor(float16)))",    "optional(seq(tensor(float)))",  "optional(seq(tensor(double)))",
0967         "optional(seq(tensor(string)))",     "optional(seq(tensor(bool)))",   "optional(seq(tensor(complex64)))",
0968         "optional(seq(tensor(complex128)))", "optional(tensor(uint8))",       "optional(tensor(uint16))",
0969         "optional(tensor(uint32))",          "optional(tensor(uint64))",      "optional(tensor(int8))",
0970         "optional(tensor(int16))",           "optional(tensor(int32))",       "optional(tensor(int64))",
0971         "optional(tensor(bfloat16))",        "optional(tensor(float16))",     "optional(tensor(float))",
0972         "optional(tensor(double))",          "optional(tensor(string))",      "optional(tensor(bool))",
0973         "optional(tensor(complex64))",       "optional(tensor(complex128))",  "optional(tensor(float8e4m3fn))",
0974         "optional(tensor(float8e4m3fnuz))",  "optional(tensor(float8e5m2))",  "optional(tensor(float8e5m2fnuz))",
0975         "optional(tensor(uint4))",           "optional(tensor(int4))"};
0976     return all_optional_types;
0977   }
0978 
0979   // Calls the passed function with `this` as an argument. Useful for
0980   // adding docs for temlated/macro ops.
0981   OpSchema& FillUsing(const std::function<void(OpSchema&)>& populator);
0982 
0983   friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema);
0984 
0985   const std::string& domain() const {
0986     return domain_;
0987   }
0988 
0989   const std::map<std::string, Attribute>& attributes() const {
0990     return attributes_;
0991   }
0992 
0993   // Get input formal parameters.
0994   const std::vector<FormalParameter>& inputs() const {
0995     return inputs_;
0996   }
0997 
0998   // Get output formal parameters.
0999   const std::vector<FormalParameter>& outputs() const {
1000     return outputs_;
1001   }
1002 
1003   const std::vector<TypeConstraintParam>& typeConstraintParams() const {
1004     return type_constraint_params_;
1005   }
1006 
1007   const TypeConstraintMap& typeConstraintMap() const {
1008     return type_constraints_;
1009   }
1010 
1011   const std::string& Name() const {
1012     return name_;
1013   }
1014 
1015   OperatorSetVersion SinceVersion() const {
1016     return since_version_;
1017   }
1018 
1019   int since_version() const {
1020     return since_version_;
1021   }
1022 
1023   bool deprecated() const {
1024     return deprecated_;
1025   }
1026 
1027   int min_input() const {
1028     return min_input_;
1029   }
1030   int max_input() const {
1031     return max_input_;
1032   }
1033   int min_output() const {
1034     return min_output_;
1035   }
1036   int max_output() const {
1037     return max_output_;
1038   }
1039 
1040   bool has_type_and_shape_inference_function() const {
1041     return tensor_inference_function_ ? true : false;
1042   }
1043 
1044   bool has_data_propagation_function() const {
1045     return data_propagation_function_ ? true : false;
1046   }
1047 
1048   std::vector<int> function_opset_versions() const {
1049     std::vector<int> opset_versions;
1050     std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it = opset_version_to_function_body_.cbegin();
1051     for (; it != opset_version_to_function_body_.cend(); ++it) {
1052       opset_versions.push_back(it->first);
1053     }
1054     return opset_versions;
1055   }
1056 
1057   bool HasFunction() const {
1058     return !opset_version_to_function_body_.empty();
1059   }
1060 
1061   OpSchema& FunctionBody(const std::vector<NodeProto>& func_nodes, int opset_version = kUninitializedSinceVersion);
1062 
1063   OpSchema& FunctionBody(
1064       const std::vector<NodeProto>& func_nodes,
1065       const std::vector<OperatorSetIdProto>& opsets,
1066       int opset_version = kUninitializedSinceVersion);
1067 
1068   OpSchema& FunctionBody(const char* func_body, int opset_version = kUninitializedSinceVersion);
1069 
1070   // since_version_ of an OpSchema tells the last opset version when an op is defined.
1071   // When the op's definition is changed, a new OpSchema (of the same op_type) is created
1072   // with a newer since_version_, reflecting the opset version at the time of change.
1073   // For a function op, operators used to define its function body may change
1074   // while there is no change to the function op definition itself.
1075   // When this happens, mutiple function bodies are provided, each for a specific opset version.
1076   //
1077   // Take LogSoftmax for example. Its latest opset version is 13.
1078   // In LogSoftmax's function body, ReduceMax (with since_version_ 1, 11, 12, 18) is used.
1079   // When a model containing LogSoftmax with opset_import version within 13 to 17 is loaded, function body
1080   // with opset_version 13 is used for inlining.
1081   // When the same model but opset_import version 18 is loaded, function body
1082   // with opset_version 18 is used for inlining.
1083   // Clearly function body for opset_import version 13 will not work
1084   // in a model with opset_import version 18 because the function body make worng use of ReduceMax(18).
1085   // Inside GetFunction we ensure that ops being used to construct a function body do not endure such
1086   // issue.
1087   const FunctionProto* GetFunction(
1088       int requested_opset_version = OpSchema::kUninitializedSinceVersion,
1089       bool validate = false) const;
1090 
1091   std::vector<int> context_dependent_function_opset_versions() const {
1092     std::vector<int> opset_versions;
1093     std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it = opset_version_to_function_builder_.cbegin();
1094     for (; it != opset_version_to_function_builder_.cend(); ++it) {
1095       opset_versions.push_back(it->first);
1096     }
1097     return opset_versions;
1098   }
1099 
1100   bool HasContextDependentFunction() const {
1101     return !opset_version_to_function_builder_.empty();
1102   }
1103 
1104   bool HasContextDependentFunctionWithOpsetVersion(int opset_version) const {
1105     return opset_version_to_function_builder_.find(opset_version) != opset_version_to_function_builder_.end();
1106   }
1107 
1108   OpSchema& SetContextDependentFunctionBodyBuilder(
1109       ContextDependentFunctionBodyBuilder,
1110       int opset_version = kUninitializedSinceVersion);
1111 
1112   bool BuildContextDependentFunction(
1113       const FunctionBodyBuildContext& ctx,
1114       FunctionProto& function_proto,
1115       int requested_opset_version = OpSchema::kUninitializedSinceVersion) const;
1116 
1117   // Verifies that the schema is valid and all specifications are compatible.
1118   // It will also parse all type strings specified for inputs/outputs into valid
1119   // TypeProto and create global unique string pointer as the DataType for
1120   // efficiency.
1121   void Finalize();
1122 
1123   // Build function with information stored in opschema
1124   void BuildFunction(FunctionProto& function_body) const;
1125 
1126  private:
1127   void ParseAndSetTypes(
1128       /*out*/ std::vector<OpSchema::FormalParameter>* formalParameters);
1129   bool ValidateReferencedOpsInFuncton(
1130       const FunctionProto* function,
1131       int requested_opset_version,
1132       int function_since_version,
1133       std::set<std::string>* updated_ops = nullptr) const;
1134   void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const;
1135 
1136   /**
1137    * @brief A common function to generate a prefix string for use in fail_check during the verify function.
1138    * @param  node_name If empty, the returned string will not include the node name.
1139    * @return std::string The prefix string.
1140    */
1141   std::string VerifyFailPrefix(std::string_view node_name) const;
1142 
1143   /**
1144    * @brief Verifies if the input number matches the pattern specified in the schema.
1145    * @param input_num The number of inputs to be verified against the schema.
1146    * @param node_info The prefix string used if the check fails.
1147    */
1148   void VerifyInputNum(int input_num, std::string_view node_name = "") const;
1149 
1150   /**
1151    * @brief Verifies if the output number matches the pattern specified in the schema.
1152    * @param output_num The number of outputs to be verified against the schema.
1153    * @param node_info The prefix string used if the check fails.
1154    */
1155   void VerifyOutputNum(int output_num, std::string_view node_name = "") const;
1156 
  // Operator type name; returned by Name().
  std::string name_;
  // Registration source location, presumably backing file()/line(), which are
  // used in duplicate-registration error messages.
  std::string file_;
  std::string doc_;
  // Default domain value ("") means it's ONNX domain.
  std::string domain_ = ONNX_DOMAIN;
  // Declared attributes keyed by name; returned by attributes().
  std::map<std::string, Attribute> attributes_{};
  bool allows_unchecked_attributes_ = false;
  // Formal parameters; returned by inputs() / outputs().
  std::vector<FormalParameter> inputs_;
  std::vector<FormalParameter> outputs_;
  // Type constraints as a declared list and as a lookup map.
  std::vector<TypeConstraintParam> type_constraint_params_;
  TypeConstraintMap type_constraints_;
  int line_ = 0;
  // NOTE(review): no default member initializer; presumably every constructor
  // sets it. Confirm.
  SupportType support_;
  // Input/output count bounds; returned by min_input() ... max_output().
  int min_input_ = 0;
  int max_input_ = 0;
  int min_output_ = 0;
  int max_output_ = 0;
  // The default is a little goofy, since it is never what you want
  // (the registry substitutes 1 when a schema is registered with it unset).
  OperatorSetVersion since_version_ = kUninitializedSinceVersion;
  bool deprecated_{};
  // Predicates on the actual input/output count; by default any count is accepted.
  std::function<bool(int)> num_inputs_allowed_ = [](int) { return true; };
  std::function<bool(int)> num_outputs_allowed_ = [](int) { return true; };
  // Optional callbacks; empty unless set (see has_type_and_shape_inference_function()
  // and has_data_propagation_function()).
  InferenceFunction tensor_inference_function_;
  DataPropagationFunction data_propagation_function_;

  // Static function bodies and context-dependent body builders, each keyed by
  // opset version (std::map keeps them sorted by version).
  std::map<int, std::shared_ptr<FunctionProto>> opset_version_to_function_body_;
  std::map<int, ContextDependentFunctionBodyBuilder> opset_version_to_function_builder_;
1184 };
1185 
1186 // Map type to store operator schemas. The format is,
1187 // <OpName, <Domain, <OperatorSetVersion, OpSchema>>>.
1188 using OpName_Domain_Version_Schema_Map =
1189     std::unordered_map<std::string, std::unordered_map<std::string, std::map<OperatorSetVersion, OpSchema>>>;
1190 
1191 class ISchemaRegistry {
1192  public:
1193   virtual ~ISchemaRegistry() = default;
1194 
1195   virtual const OpSchema*
1196   GetSchema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) const = 0;
1197 };
1198 
1199 /**
1200  * @brief A registry to hold all the operator schemas.
1201  */
1202 class OpSchemaRegistry final : public ISchemaRegistry {
1203  public:
1204   // A singleton class to store domain to min/max op_set version map, as well as
1205   // domain to last-release op_set version map.
1206   class DomainToVersionRange final {
1207    public:
1208     DomainToVersionRange() {
1209       // Increase the highest version when you make BC-breaking changes to the
1210       // operator schema on specific domain. Update the lowest version when it's
1211       // determined to remove too old version history.
1212       map_[ONNX_DOMAIN] = std::make_pair(1, 22);
1213       map_[AI_ONNX_ML_DOMAIN] = std::make_pair(1, 5);
1214       map_[AI_ONNX_TRAINING_DOMAIN] = std::make_pair(1, 1);
1215       // ONNX's preview domain contains operators subject to change, so
1216       // versining is not meaningful and that domain should have only one
1217       // version.
1218       map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 1);
1219       // Version corresponding last release of ONNX. Update this to match with
1220       // the max version above in a *release* version of ONNX. But in other
1221       // versions, the max version may be ahead of the last-release-version.
1222       last_release_version_map_[ONNX_DOMAIN] = 22;
1223       last_release_version_map_[AI_ONNX_ML_DOMAIN] = 5;
1224       last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1;
1225       last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1;
1226     }
1227 
1228     const std::unordered_map<std::string, std::pair<int, int>>& Map() const {
1229       return map_;
1230     }
1231 
1232     const std::unordered_map<std::string, int>& LastReleaseVersionMap() const {
1233       return last_release_version_map_;
1234     }
1235 
1236     // Add customized domain to min/max version.
1237     // Onnx partners are able to use onnx operator schema api to
1238     // register customized op in their own domain.
1239     // Can optionally specify last_release_version (to make it similar to
1240     // standard ONNX domains as above). Custom-domains are free to interpret
1241     // this as appropriate (that is, as relative to releases of custom-domain
1242     // as opposed to ONNX releases).
1243     void
1244     AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) {
1245       std::lock_guard<std::mutex> lock(mutex_);
1246       if (map_.count(domain) != 0) {
1247         std::stringstream err;
1248         err << "Trying to add a domain to DomainToVersion map, but the domain is already exist with version range ("
1249             << map_.at(domain).first << ", " << map_.at(domain).second << "). domain: \"" << domain << "\""
1250             << std::endl;
1251         fail_schema(err.str());
1252       }
1253       if (last_release_version_map_.count(domain) != 0) {
1254         std::stringstream err;
1255         err << "Trying to add a domain to LastReleaseVersion map, but the domain is already exist with last version: "
1256             << last_release_version_map_.at(domain) << ", domain: \"" << domain << "\"" << std::endl;
1257         fail_schema(err.str());
1258       }
1259       map_[domain] = std::make_pair(min_version, max_version);
1260       // If a last-release-version is not explicitly specified, use max as
1261       // last-release-version.
1262       if (last_release_version == -1) {
1263         last_release_version = max_version;
1264       }
1265       last_release_version_map_[domain] = last_release_version;
1266     }
1267 
1268     void
1269     UpdateDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) {
1270       std::lock_guard<std::mutex> lock(mutex_);
1271       if (map_.count(domain) == 0) {
1272         std::stringstream err;
1273         err << "Trying to update a domain in DomainToVersion map, but the domain has not been add. domain: \"" << domain
1274             << "\"" << std::endl;
1275         fail_schema(err.str());
1276       }
1277       if (last_release_version_map_.count(domain) == 0) {
1278         std::stringstream err;
1279         err << "Trying to update a domain in LastReleaseVersion map, but the domain has not been add. domain: \""
1280             << domain << "\"" << std::endl;
1281         fail_schema(err.str());
1282       }
1283       map_.at(domain).first = min_version;
1284       map_.at(domain).second = max_version;
1285       // Correspond to `AddDomainToVersion`
1286       if (last_release_version == -1) {
1287         last_release_version = max_version;
1288       }
1289       last_release_version_map_.at(domain) = last_release_version;
1290     }
1291 
1292     static DomainToVersionRange& Instance();
1293 
1294    private:
1295     // Key: domain. Value: <lowest version, highest version> pair.
1296     std::unordered_map<std::string, std::pair<int, int>> map_;
1297 
1298     // Key: domain. Value: most recent release opset version. Note that
1299     // the highest opset version may be ahead of the most recent release's opset
1300     // version.
1301     std::unordered_map<std::string, int> last_release_version_map_;
1302 
1303     std::mutex mutex_;
1304   };
1305 
1306   class OpSchemaRegisterOnce final {
1307    public:
1308     // Export to cpp custom register macro
1309     OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1310       OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema);
1311     }
1312     static void
1313     OpSchemaRegisterNoExcept(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1314       ONNX_TRY {
1315         OpSchemaRegisterImpl(std::move(op_schema), opset_version_to_load, fail_duplicate_schema);
1316       }
1317       ONNX_CATCH(const std::exception& e) {
1318         ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; });
1319       }
1320     }
1321     static void
1322     OpSchemaRegisterImpl(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1323       op_schema.Finalize();
1324       auto& m = GetMapWithoutEnsuringRegistration();
1325       auto& op_name = op_schema.Name();
1326       auto& op_domain = op_schema.domain();
1327       auto& schema_ver_map = m[op_name][op_domain];
1328       auto ver = op_schema.SinceVersion();
1329       if (OpSchema::kUninitializedSinceVersion == ver) {
1330         op_schema.SinceVersion(1);
1331         ver = op_schema.SinceVersion();
1332       }
1333 
1334       // Stops because the exact opset_version is registered
1335       if (schema_ver_map.count(ver)) {
1336         if (fail_duplicate_schema) {
1337           const auto& schema = schema_ver_map[ver];
1338           std::stringstream err;
1339           err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1340               << ") from file " << op_schema.file() << " line " << op_schema.line()
1341               << ", but it is already registered from file " << schema.file() << " line " << schema.line() << std::endl;
1342           fail_schema(err.str());
1343         }
1344         return;
1345       }
1346 
1347       if (opset_version_to_load != 0) {
1348         // Stops because the opset_version is higher than opset_version_to_load
1349         if (ver > opset_version_to_load) {
1350           return;
1351         }
1352 
1353         // Stops because a later version is registered within target opset version
1354         if (!schema_ver_map.empty()) {
1355           int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load);
1356           if (max_registered_ver_le_target >= ver) {
1357             return;
1358           }
1359         }
1360       }
1361 
1362       CheckDomainAndVersionToRegister(op_schema, op_name, op_domain);
1363       schema_ver_map.insert(std::pair<int, OpSchema&&>(ver, std::move(op_schema)));
1364     }
1365 
1366    private:
1367     // Gets the maximum version from given map that is less or equal to target version
1368     static int GetMaxRegisteredVerWithinTarget(const std::map<OperatorSetVersion, OpSchema>& m, int target_ver) {
1369       // std::map is sorted on key
1370       // reverse iterator returns the largest element keyed on the integer version
1371       for (auto&& it = m.rbegin(); it != m.rend(); it++) {
1372         const auto& registered_ver = it->first;
1373         if (registered_ver <= target_ver) {
1374           return registered_ver;
1375         }
1376       }
1377       return -1;
1378     }
1379 
1380     static void CheckDomainAndVersionToRegister(
1381         const OpSchema& op_schema,
1382         const std::string& op_name,
1383         const std::string& op_domain) {
1384       auto ver_range_map = DomainToVersionRange::Instance().Map();
1385       auto ver_range_it = ver_range_map.find(op_domain);
1386       auto ver = op_schema.SinceVersion();
1387 
1388       if (ver_range_it == ver_range_map.end()) {
1389         std::stringstream err;
1390         err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1391             << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its domain is not"
1392             << " known by the checker." << std::endl;
1393 
1394         fail_schema(err.str());
1395       }
1396       auto lower_bound_incl = ver_range_it->second.first;
1397       auto upper_bound_incl = ver_range_it->second.second;
1398       if (!(lower_bound_incl <= ver && upper_bound_incl >= ver)) {
1399         std::stringstream err;
1400         err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1401             << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its version is not "
1402             << "in the inclusive range [" << lower_bound_incl << ", " << upper_bound_incl
1403             << "] (usually, this means you "
1404             << "bumped the operator version but "
1405             << "forgot to update the version range in DomainToVersionRange "
1406             << "in onnx/defs/schema.h)." << std::endl;
1407         fail_schema(err.str());
1408       }
1409     }
1410   };
1411 
1412   static void
1413   OpSchemaDeregister(const std::string& op_type, const int version, const std::string& domain = ONNX_DOMAIN) {
1414     auto& schema_map = GetMapWithoutEnsuringRegistration();
1415     if (schema_map.count(op_type) && schema_map[op_type].count(domain) && schema_map[op_type][domain].count(version)) {
1416       schema_map[op_type][domain].erase(version);
1417     } else {
1418       std::stringstream err;
1419       err << "Attempting to deregister an unregistered schema with name: " << op_type << " domain: " << domain
1420           << " version: " << version << std::endl;
1421       fail_schema(err.str());
1422     }
1423   }
1424 
1425   // Deregister all ONNX opset schemas from domain
1426   // Domain with default value ONNX_DOMAIN means ONNX.
1427   static void OpSchemaDeregisterAll(const std::string& domain = ONNX_DOMAIN) {
1428     auto& schema_map = GetMapWithoutEnsuringRegistration();
1429     // schema_map stores operator schemas in the format of
1430     // <OpName, <Domain, <OperatorSetVersion, OpSchema>>>
1431     for (auto&& schema_map_pair : schema_map) {
1432       auto& domain_map = schema_map_pair.second;
1433       if (domain_map.count(domain)) {
1434         auto& opset_version_schema_map = domain_map[domain];
1435         // Invalidates ver-schema pairs and frees memory, leaving m[op_name][op_domain] empty
1436         opset_version_schema_map.clear();
1437         domain_map.erase(domain);
1438       }
1439     }
1440   }
1441 
1442   // Return the latest schema for an operator in specified domain.
1443   // Domain with default value ONNX_DOMAIN means ONNX.
1444   static const OpSchema* Schema(const std::string& key, const std::string& domain = ONNX_DOMAIN) {
1445     auto& m = map();
1446     if (m.count(key) && m[key].count(domain)) {
1447       const auto& schema_ver_map = m[key][domain];
1448       if (!schema_ver_map.empty()) {
1449         return &m[key][domain].rbegin()->second;
1450       }
1451     }
1452     return nullptr;
1453   }
1454 
1455   // Return the schema with biggest version, which is not greater than specified
1456   // <maxInclusiveVersion> in specified domain. Domain with default value
1457   // ONNX_DOMAIN means ONNX.
1458   static const OpSchema*
1459   Schema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) {
1460     auto& m = map();
1461     if (m.count(key) && m[key].count(domain)) {
1462       const auto& schema_ver_map = m[key][domain];
1463       if (!schema_ver_map.empty()) {
1464         auto pos = m[key][domain].lower_bound(maxInclusiveVersion);
1465         if (m[key][domain].begin() == pos && pos->first > maxInclusiveVersion) {
1466           // All versions are greater than specified version.
1467           return nullptr;
1468         }
1469         if (m[key][domain].end() == pos || pos->first > maxInclusiveVersion) {
1470           // All versions are less than specified version, or,
1471           // The <pos> version is greater than specified version.
1472           pos--;
1473         }
1474 
1475         // Schema with exact version as specified one exists.
1476         return &(pos->second);
1477       }
1478     }
1479     return nullptr;
1480   }
1481 
1482   static OpSchemaRegistry* Instance();
1483 
1484   const OpSchema* GetSchema(
1485       const std::string& key,
1486       const int maxInclusiveVersion,
1487       const std::string& domain = ONNX_DOMAIN) const override {
1488     return Schema(key, maxInclusiveVersion, domain);
1489   }
1490   static void SetLoadedSchemaVersion(int target_version) {
1491     loaded_schema_version = target_version;
1492   }
1493   static int GetLoadedSchemaVersion() {
1494     return loaded_schema_version;
1495   }
1496 
 private:
  // OpSchemaRegistry should not need to be instantiated except statically
  // within this class (presumably via Instance()).
  OpSchemaRegistry() = default;

  /**
   * @brief Returns the underlying string to OpSchema map.
   *
   * The map is nested as <OpName, <Domain, <OperatorSetVersion, OpSchema>>>.
   *
   * You should not manually manipulate the map object returned. Instead, use
   * the macros defined such as ONNX_OPERATOR_SET_SCHEMA to register your
   * operator schema.
   *
   * We wrap it inside a function to avoid the static initialization order
   * fiasco.
   *
   * This accessor is the one used by the deregistration helpers.
   * NOTE(review): judging by the names, this variant skips ensuring that the
   * built-in schemas are registered while map() below does not; confirm
   * against the out-of-line definitions.
   */
  static OpName_Domain_Version_Schema_Map& GetMapWithoutEnsuringRegistration();
  // Map accessor used by the lookup/enumeration helpers (Schema, get_all_schemas*).
  static OpName_Domain_Version_Schema_Map& map();
  // Backing storage for SetLoadedSchemaVersion / GetLoadedSchemaVersion.
  static int loaded_schema_version;
1515 
1516  public:
1517   static const std::vector<OpSchema> get_all_schemas_with_history() {
1518     std::vector<OpSchema> r;
1519     for (auto& x : map()) {
1520       for (auto& y : x.second) {
1521         for (auto& z : y.second) {
1522           r.emplace_back(z.second);
1523         }
1524       }
1525     }
1526     return r;
1527   }
1528 
1529   static const std::vector<OpSchema> get_all_schemas() {
1530     std::vector<OpSchema> r;
1531     for (auto& x : map()) {
1532       for (auto& y : x.second) {
1533         auto& version2schema = y.second;
1534         if (!version2schema.empty()) {
1535           r.emplace_back(version2schema.rbegin()->second);
1536         }
1537       }
1538     }
1539     return r;
1540   }
1541 };
1542 
// Registers `schema` with the schema registry (definition is out of line).
// opset_version_to_load: 0 (the default) means every version is registered,
//   as noted for RegisterOpSetSchema below, which forwards this value here;
//   see the definition for how a non-zero value filters schemas.
// fail_duplicate_schema / fail_with_exception: NOTE(review): presumably they
//   control whether a duplicate (name, domain, version) is an error and
//   whether errors are reported by throwing; confirm against the definition.
void RegisterSchema(
    const OpSchema& schema,
    int opset_version_to_load = 0,
    bool fail_duplicate_schema = true,
    bool fail_with_exception = false);
// Same as above, but takes the schema by rvalue reference (presumably so it can
// be moved into the registry); RegisterOpSetSchema calls this overload.
void RegisterSchema(
    OpSchema&& schema,
    int opset_version_to_load = 0,
    bool fail_duplicate_schema = true,
    bool fail_with_exception = false);
// Removes the schema registered for (op_type, domain, version).
// NOTE(review): presumably the free-function counterpart of
// OpSchemaRegistry::OpSchemaDeregister; note `domain` has no default here.
void DeregisterSchema(const std::string& op_type, int version, const std::string& domain);
1554 
1555 // Registers the latest opset schema before opset_version_to_load
1556 // By default opset_version_to_load=0 means it will register all versions
1557 template <class T>
1558 void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1559   T::ForEachSchema([opset_version_to_load, fail_duplicate_schema](OpSchema&& schema) {
1560     RegisterSchema(std::move(schema), opset_version_to_load, fail_duplicate_schema);
1561   });
1562 };
1563 
// Forward declaration for the non-specialized GetOpSchema method.  This
// enforces a consistent signature on functions that query individual schema,
// which are defined as specializations of this function.
// The primary template is only declared here; each schema is supplied by an
// explicit specialization (generated by ONNX_OPERATOR_SET_SCHEMA_EX). Calling
// it for a type with no specialization presumably fails at link time.
template <typename T>
OpSchema GetOpSchema();
1569 
// Per-domain shorthands for ONNX_OPERATOR_SET_SCHEMA_EX. Each one supplies the
// class-name domain tag (Onnx, OnnxML, OnnxTraining, OnnxPreview), the matching
// domain string constant, and dbg_included_in_static_opset = true.
// Note that the preview-training variant uses the OnnxPreview tag together with
// AI_ONNX_PREVIEW_TRAINING_DOMAIN.
#define ONNX_OPERATOR_SET_SCHEMA(name, ver, impl) ONNX_OPERATOR_SET_SCHEMA_EX(name, Onnx, ONNX_DOMAIN, ver, true, impl)

#define ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxML, AI_ONNX_ML_DOMAIN, ver, true, impl)

#define ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxTraining, AI_ONNX_TRAINING_DOMAIN, ver, true, impl)

#define ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
  ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxPreview, AI_ONNX_PREVIEW_TRAINING_DOMAIN, ver, true, impl)
1580 
// Defines specialization of GetOpSchema for a class whose name is determined
// based on a convention using name, domain, and version.  Operator schema are
// normally included in operator sets and registered in OpSchemaRegistry::map().
// In this case, callers should set dbg_included_in_static_opset to true.  This
// assists with runtime validation in DEBUG builds ensuring the intended set
// of operator schema is registered.
//
// The expansion, in order:
//   1. forward-declares the tag class
//      ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name);
//   2. specializes GetOpSchema for that class, returning `impl` with its name,
//      domain, since-version and source location (__FILE__/__LINE__) set;
//   3. defines the variable dbg_count_check_<name>_<domain>_ver<ver>, whose
//      initializer bumps the debug opset counter when
//      dbg_included_in_static_opset is true. The hook macro is defined below and
//      is the constant 0 under NDEBUG. Macros expand at the use site, so the
//      definition order does not matter.
// Because it defines a non-inline variable and a full function specialization,
// the macro presumably belongs in a single .cc file rather than in a header.
#define ONNX_OPERATOR_SET_SCHEMA_EX(name, domain, domain_str, ver, dbg_included_in_static_opset, impl)  \
  class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name);                                         \
  template <>                                                                                           \
  OpSchema GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)>() {                      \
    return impl.SetName(#name).SetDomain(domain_str).SinceVersion(ver).SetLocation(__FILE__, __LINE__); \
  }                                                                                                     \
  size_t dbg_count_check_##name##_##domain##_ver##ver =                                                 \
      (dbg_included_in_static_opset) ? ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() : 0;
// Debug-only bookkeeping used by ONNX_OPERATOR_SET_SCHEMA_EX. With NDEBUG the
// increment hook is the constant 0 and no tracker class is declared.
#ifdef NDEBUG
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() 0
#else
#define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().IncrementCount()
#define ONNX_DBG_GET_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().GetCount()

// Plain counter of how many schema definitions asked to be tracked.
class DbgOperatorSetTracker {
 public:
  // Accessor for the shared tracker; defined out of line.
  static DbgOperatorSetTracker& Instance();

  // Adds one to the counter and returns the updated value.
  size_t IncrementCount() {
    count_ += 1;
    return count_;
  }

  // Current value of the counter.
  size_t GetCount() const {
    return count_;
  }

 private:
  size_t count_{0};
};
#endif
1617 
// Naming convention for operator schema classes
// Token-pastes the parts, e.g.
// ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 13, Add) -> Add_Onnx_ver13
#define ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name) name##_##domain##_ver##ver

// Naming convention for preview operator schema classes
// (the convention above with the OnnxPreview domain tag).
#define ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(ver, name) \
  ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxPreview, ver, name)
1624 
// Helper function
// Text substitution helper used when building doc strings (see the
// POPULATE_OP_DOC_STR example). NOTE(review): presumably replaces every
// occurrence of `from` in `s` with `to`, in place, and returns the number of
// replacements; confirm against the out-of-line definition.
size_t ReplaceAll(std::string& s, const char* from, const char* to);
1627 
// Marks a variable as intentionally unused to silence unused-variable warnings.
// Only compilers that define __GNUC__ (GCC, and Clang, which also defines it)
// get the attribute; elsewhere the macro expands to nothing.
#ifdef __GNUC__
#define ONNX_UNUSED __attribute__((__unused__))
#else
#define ONNX_UNUSED
#endif
1633 
// Legacy macros to register schema at static initialization
// ONNX_OPERATOR_SCHEMA(name) defines a file-static OpSchemaRegisterOnce object
// initialized from OpSchema(#name, __FILE__, __LINE__). The expansion has no
// trailing semicolon, so the caller finishes the statement (typically by
// chaining builder calls on the OpSchema).
// __COUNTER__ keeps the variable name unique. The extra UNIQ_HELPER level is
// needed because arguments adjacent to ## are not macro-expanded, so
// __COUNTER__ must be expanded one level before it is pasted.
#define ONNX_OPERATOR_SCHEMA(name) ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)
#define ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)                                                                      \
  static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce(op_schema_register_once##name##Counter) ONNX_UNUSED = \
      OpSchema(#name, __FILE__, __LINE__)
1640 
// Note: ReplaceAll(std::string&, const char*, const char*), which the
// POPULATE_OP_DOC_STR example below relies on, is declared earlier in this
// header, right after the schema class-name macros.
1643 
// Boilerplate paragraph appended to docs of operators that have optional
// inputs/outputs; it ends with a newline.
inline std::string GenerateOptionalArgumentsDoc() {
  std::string doc;
  doc += "This operator has **optional** inputs/outputs. ";
  doc += "See [the doc](IR.md) for more details about the representation of ";
  doc += "optional arguments. An empty string may be used in the place of ";
  doc += "an actual argument's name to indicate a missing argument. ";
  doc += "Trailing optional arguments (those not followed by an argument ";
  doc += "that is present) may also be simply omitted.\n";
  return doc;
}
1652 
// Boilerplate sentence for operators that support multidirectional
// (Numpy-style) broadcasting, pointing at Broadcasting.md.
inline std::string GenerateBroadcastingDocMul() {
  std::string doc("This operator supports **multidirectional (i.e., Numpy-style) broadcasting**;");
  doc.append(" for more details please check [the doc](Broadcasting.md).");
  return doc;
}
1657 
// Boilerplate sentence stating that input `from` is unidirectionally
// broadcastable to input `to`. Both must be valid C strings.
inline std::string GenerateBroadcastingDocUni(const char* from, const char* to) {
  std::string doc = "This operator supports **unidirectional broadcasting** (";
  doc += from;
  doc += " should be unidirectional broadcastable to ";
  doc += to;
  doc +=
      ");"
      " for more details please check [the doc](Broadcasting.md).";
  return doc;
}
1665 
/*
 * Macros for setting operator documentation
 * Use this macro for simple SetDoc() calls that generate documentation
 * directly. This is the macro to use in almost all cases.
 * When __ONNX_NO_DOC_STRINGS is defined, the argument is discarded without
 * being evaluated and the macro yields the empty string literal "".
 * Sample usage:
 *   const char* doc_str = "foo";
 *   SetDoc(GET_OP_DOC_STR(doc_str))
 *
 *   SetDoc(GET_OP_DOC_STR(
 *       std::string(BitShift_ver11_doc) + GenerateBroadcastingDocMul()))
 */
#ifndef __ONNX_NO_DOC_STRINGS
#define GET_OP_DOC_STR(doc_str) (doc_str)
#else
#define GET_OP_DOC_STR(doc_str) ("")
#endif
1682 
/*
 * Use this macro when the documentation needs to be populated in some
 * complicated way like string substitutions, etc before calling SetDoc.
 *
 * DocPopulatorCode is pasted verbatim inside a do { ... } while (0) block, so
 * anything declared in it is local to that block and the invocation must be
 * followed by a semicolon. When __ONNX_NO_DOC_STRINGS is defined the macro
 * expands to nothing, so the code is dropped entirely. In the sample below,
 * `doc` then stays an empty string.
 *
 * DocPopulatorCode is a single macro argument. A comma that is not enclosed
 * in parentheses (for example in a braced initializer or a multi-variable
 * declaration) would split it into several arguments and fail to compile.
 *
 * Sample usage guidelines:
    std::string doc;
    POPULATE_OP_DOC_STR(
        doc = R"DOC(
Returns the tensor resulted from performing the `{name}` logical operation
elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting
support).

{broadcast_doc}
)DOC";
        ReplaceAll(doc, "{name}", name);
        ReplaceAll(
            doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
    schema.SetDoc(doc);
 *
 */
#ifndef __ONNX_NO_DOC_STRINGS
#define POPULATE_OP_DOC_STR(DocPopulatorCode) \
  do {                                        \
    DocPopulatorCode                          \
  } while (0)
#else
#define POPULATE_OP_DOC_STR(DocPopulatorCode)
#endif
1710 
1711 } // namespace ONNX_NAMESPACE