Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:46

0001 /*
0002  * SPDX-License-Identifier: Apache-2.0
0003  */
0004 
0005 #pragma once
0006 
0007 #include <climits>
0008 #include <cstring>
0009 #include <functional>
0010 #include <initializer_list>
0011 #include <iostream>
0012 #include <limits>
0013 #include <map>
0014 #include <memory>
0015 #include <ostream>
0016 #include <set>
0017 #include <string>
0018 #include <tuple>
0019 #include <unordered_map>
0020 #include <unordered_set>
0021 #include <utility>
0022 #include <vector>
0023 
0024 #include "onnx/common/common.h"
0025 #include "onnx/common/constants.h"
0026 #include "onnx/defs/shape_inference.h"
0027 
0028 namespace ONNX_NAMESPACE {
0029 
0030 struct FunctionBodyBuildContext {
0031   virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
0032   virtual bool hasInput(int inputIndex) const = 0;
0033   virtual bool hasOutput(int inputIndex) const = 0;
0034   // getInputType(i) should return null for missing optional inputs, or if
0035   // type-inference could not infer the input-type (erroneous model).
0036   virtual const TypeProto* getInputType(int inputIndex) const = 0;
0037   virtual ~FunctionBodyBuildContext() {}
0038 };
0039 
0040 struct FunctionBodyBuildContextImpl : public FunctionBodyBuildContext {
0041   // Input_types: use a default TypeProto for missing types. We use a different convention
0042   // here (from FunctionBodyBuildContext) to simplify python interoperability.
0043   // The default value for input_types is included only for backward compatibility.
0044   // It can be used for functions that do not depend on the type-context, but
0045   // will not be sufficient for functions that do use the type-context.
0046   FunctionBodyBuildContextImpl(const NodeProto& node_proto, const std::vector<TypeProto>& input_types = {})
0047       : node_proto_(node_proto), input_types_(input_types) {
0048     for (auto& attr : node_proto.attribute()) {
0049       attributesByName_[attr.name()] = &attr;
0050     }
0051   }
0052 
0053   const AttributeProto* getAttribute(const std::string& name) const override {
0054     auto iter = attributesByName_.find(name);
0055     if (iter == attributesByName_.end()) {
0056       return nullptr;
0057     } else {
0058       return iter->second;
0059     }
0060   }
0061 
0062   bool hasInput(int inputIndex) const override {
0063     if (inputIndex >= node_proto_.input_size())
0064       return false;
0065     return node_proto_.input(inputIndex) != "";
0066   }
0067 
0068   bool hasOutput(int inputIndex) const override {
0069     if (inputIndex >= node_proto_.output_size())
0070       return false;
0071     return node_proto_.output(inputIndex) != "";
0072   }
0073 
0074   const TypeProto* getInputType(int inputIndex) const override {
0075     if (inputIndex < 0)
0076       return nullptr;
0077     size_t j = static_cast<size_t>(inputIndex);
0078     if (j >= input_types_.size())
0079       return nullptr;
0080     // Convert default value (no variant set) into null.
0081     if (input_types_[j].value_case() == TypeProto::ValueCase::VALUE_NOT_SET)
0082       return nullptr;
0083     return &input_types_[j];
0084   }
0085 
0086   std::unordered_map<std::string, const AttributeProto*> attributesByName_;
0087 
0088   NodeProto node_proto_;
0089   std::vector<TypeProto> input_types_;
0090 };
0091 
0092 using FunctionBodyQueryFunction = std::function<bool(FunctionBodyBuildContext&)>;
0093 
0094 class OpSchema;
0095 using ContextDependentFunctionBodyBuilder =
0096     std::function<bool(const FunctionBodyBuildContext&, const OpSchema&, FunctionProto&)>;
0097 
0098 class SchemaError final : public std::runtime_error {
0099  public:
0100   using std::runtime_error::runtime_error;
0101 
0102   SchemaError(const std::string& message) : std::runtime_error(message) {}
0103 
0104   const char* what() const noexcept override {
0105     if (!expanded_message_.empty()) {
0106       return expanded_message_.c_str();
0107     }
0108     return std::runtime_error::what();
0109   }
0110 
0111   void AppendContext(const std::string& context) {
0112     expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
0113   }
0114 
0115  private:
0116   std::string expanded_message_;
0117 };
0118 
0119 #define fail_schema(...) ONNX_THROW_EX(ONNX_NAMESPACE::SchemaError(ONNX_NAMESPACE::MakeString(__VA_ARGS__)));
0120 
0121 using OperatorSetVersion = int;
0122 
0123 using DataTypeSet = std::unordered_set<DataType>;
0124 
0125 // Type constraint map. Key is type string. Value is data type set and
0126 // description.
0127 using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>;
0128 
0129 /**
0130  * @brief A class to record the schema of an op.
0131  *
0132  * OpSchema records the common interface of an op specified by its name.
0133  *
0134  * To register an OpSchema, one can use the macro ONNX_OPERATOR_SCHEMA(name) and
0135  * then append the various functions in the class. For example, for an op
0136  * that takes in two inputs, one output, and the first input and output
0137  * could be in-place, can be written as
0138  *
0139  *     ONNX_OPERATOR_SCHEMA(name)
0140  *         .NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}});
0141  *
0142  * To manufacture methods that may be used to register an OpSchema
0143  * non-statically, the following may be used:
0144  *
0145  *     ONNX_OPERATOR_SET_SCHEMA(name, version, OpSchema()
0146  *         .NumInputs(2).NumOutputs(1).AllowConsumed({{0, 0}}));
0147  */
0148 class OpSchema final {
0149  public:
  // Sentinel value: SinceVersion(...) has not been set for this schema yet.
  static constexpr int kUninitializedSinceVersion = -1;
  // Formal parameter options.
  enum FormalParameterOption : uint8_t {
    // The formal parameter is single and not optional.
    // Number of supplied actual parameters must be 1.
    Single = 0,
    // The formal parameter is single and optional.
    // Number of supplied actual parameters may be 0 or 1.
    Optional = 1,
    // The formal parameter is variadic.
    // Number of supplied actual parameters must be N or more, where
    // the minimum value N is indicated separately (default value 1).
    Variadic = 2,
  };
  // Differentiability classification of a formal parameter with respect to
  // the Gradient operator.
  enum DifferentiationCategory : uint8_t {
    // Whether this formal parameter is differentiable or not cannot
    // be statically determined. It also covers variadic formal
    // parameters which contain both of differentiable and
    // non-differentiable variables.
    Unknown = 0,
    // This formal parameter is differentiable. That is, this formal
    // parameter can be differentiable input of Gradient operator.
    Differentiable = 1,
    // This formal parameter is not differentiable. That is, this formal
    // parameter can not be differentiable input of Gradient operator.
    NonDifferentiable = 2
  };
0177 
0178   // Formal parameter represenation, including input/output name, typeStr,
0179   // description, and type constraints.
0180   class FormalParameter final {
0181    public:
0182     // Constructor.
0183     FormalParameter() = default;
0184 
0185     explicit FormalParameter(
0186         std::string name,
0187         DataTypeSet allowed_type_set,
0188         std::string type_str,
0189         const std::string& description,
0190         FormalParameterOption param_option = Single,
0191         bool is_homogeneous = true,
0192         int min_arity = 1,
0193         DifferentiationCategory differentiation_category = Unknown)
0194         : name_(std::move(name)),
0195           type_set_(std::move(allowed_type_set)),
0196           type_str_(std::move(type_str)),
0197 #ifndef __ONNX_NO_DOC_STRINGS
0198           description_(description),
0199 #endif
0200           param_option_(param_option),
0201           is_homogeneous_(is_homogeneous),
0202           min_arity_(min_arity),
0203           differentiation_category_(differentiation_category) {
0204 #ifdef __ONNX_NO_DOC_STRINGS
0205       ONNX_UNUSED_PARAMETER(description);
0206 #endif
0207     }
0208 
0209     explicit FormalParameter(
0210         std::string name,
0211         const std::string& description,
0212         std::string type_str,
0213         FormalParameterOption param_option = Single,
0214         bool is_homogeneous = true,
0215         int min_arity = 1,
0216         DifferentiationCategory differentiation_category = Unknown)
0217         : name_(std::move(name)),
0218           type_str_(std::move(type_str)),
0219 #ifndef __ONNX_NO_DOC_STRINGS
0220           description_(description),
0221 #endif
0222           param_option_(param_option),
0223           is_homogeneous_(is_homogeneous),
0224           min_arity_(min_arity),
0225           differentiation_category_(differentiation_category) {
0226 #ifdef __ONNX_NO_DOC_STRINGS
0227       ONNX_UNUSED_PARAMETER(description);
0228 #endif
0229     }
0230 
0231     // Get formal parameter name.
0232     const std::string& GetName() const;
0233 
0234     // Get allowed data types.
0235     const DataTypeSet& GetTypes() const;
0236 
0237     // Get formal parameter type string.
0238     const std::string& GetTypeStr() const;
0239 
0240     // Get formal parameter description.
0241     const std::string& GetDescription() const;
0242 
0243     // Get the parameter option, it could be Single, Optional or Variadic.
0244     FormalParameterOption GetOption() const;
0245 
0246     // Get whether a variadic parameter requires all to be of same type
0247     bool GetIsHomogeneous() const;
0248 
0249     // Get minimum arity. Applicable only in the Variadic case.
0250     int GetMinArity() const;
0251 
0252     // Get the differentiation property of this formal parameter.
0253     DifferentiationCategory GetDifferentiationCategory() const;
0254 
0255    private:
0256     friend class OpSchema;
0257 
0258     DataTypeSet& MutableTypes();
0259 
0260     // Formal parameter name.
0261     std::string name_;
0262 
0263     // A set of data types supported for <*this> formal parameter.
0264     // It should contain at least one element if this formal parameter is good.
0265     DataTypeSet type_set_;
0266 
0267     // The <parameter type> string specified when registring an op.
0268     // It could be a supported data type or a type constraint key, which
0269     // maps to a set of supported data types.
0270     std::string type_str_;
0271 
0272     // Formal parameter description.
0273     std::string description_;
0274 
0275     // Formal parameter option.
0276     FormalParameterOption param_option_;
0277 
0278     // For variadic parameters, a flag indicating if all parameters must be of
0279     // same type
0280     bool is_homogeneous_;
0281 
0282     // Minimum number of parameters expected. Applicable only for Variadic.
0283     int min_arity_;
0284 
0285     // True if this parameter can be an differentiable inputs of Gradient.
0286     // Otherwise, using this parameter as an differentiable inputs of Gradient
0287     // is prohibited.
0288     DifferentiationCategory differentiation_category_;
0289   };
0290 
  // Maturity level of an operator schema.
  enum class SupportType : uint8_t {
    COMMON, // Supported by all frameworks that support this IR.
    EXPERIMENTAL, // This OP is experimental and can be changed or removed in
                  // the future.
  };

  // Default schema: placeholder name/file until SetName/SetLocation run.
  OpSchema() : OpSchema("unknown", "unknown", 0) {}
  // name: operator name; file/line: registration site (for diagnostics).
  OpSchema(std::string name, std::string file, int line)
      : name_(std::move(name)), file_(std::move(file)), line_(line), support_(SupportType::COMMON) {}
0300 
0301   /**
0302    * @brief Returns the file that the op schema is registered from.
0303    */
0304   const std::string& file() const {
0305     return file_;
0306   }
0307 
0308   /**
0309    * @brief Returns the line in file that the op schema is registered from.
0310    */
0311   int line() const {
0312     return line_;
0313   }
0314 
0315   /**
0316    * @brief Returns the support level of the op schema.
0317    */
0318   SupportType support_level() const {
0319     return support_;
0320   }
0321 
0322   /**
0323    * @brief Returns the docstring of the op schema.
0324    */
0325   const char* doc() const {
0326     return doc_.empty() ? nullptr : doc_.c_str();
0327   }
0328 
0329   // Check if input and output types fall into valid set and match each other
0330   void CheckInputOutputType(struct InferenceContext&) const;
0331 
0332   /**
0333    * @brief Verifies if a NodeProto matches the pattern specified in
0334    * the schema.
0335    */
0336   void Verify(const NodeProto& node) const;
0337 
0338   // Functions to set the property of the operator schemas.
0339   // Sets the number of inputs, either a fixed number or a min and a max.
0340 
0341   /**
0342    * The earliest operator set version which this operator was
0343    * present in.  If an operator has had no BC-breaking changes,
0344    * this is simply the first operator set the operator was a member
0345    * of; if it has had BC-breaking changes, then for the semantics
0346    * /as described/ in the OpSchema entry, this version describes
0347    * the operator set which introduced the BC-breaking change.
0348    *
0349    * For example, suppose op Foo was added in v3, and had a BC-breaking
0350    * change in v6.  Then there will be an op schema entry for Foo with
0351    * SinceVersion(3), and another, updated op schema entry for Foo
0352    * with SinceVersion(6).
0353    */
0354   OpSchema& SinceVersion(OperatorSetVersion n); // aka int
0355 
0356   /**
0357    * Marks this op as deprecated as of it's since_version. This will cause the
0358    * Schema() lookup functions to return nullptr when the version is in the
0359    * deprecated range.
0360    */
0361   OpSchema& Deprecate();
0362 
  // True iff Deprecate() has been called on this schema.
  bool Deprecated() const {
    return deprecated_;
  }
0366 
0367   /**
0368    * @brief Input could be one of the values specified in allowed_input_nums.
0369    */
0370   OpSchema& NumInputs(std::set<int> allowed_input_nums);
0371 
0372   /**
0373    * @brief Output could be one of the values specified in allowed_output_nums.
0374    */
0375   OpSchema& NumOutputs(std::set<int> allowed_output_nums);
0376 
0377   // Shape Inference
0378   //
0379   // Note that signatures are defined to allow for forward-declaring
0380   // any structs used from ir.h
0381   OpSchema& TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
0382   InferenceFunction GetTypeAndShapeInferenceFunction() const {
0383     return tensor_inference_function_ ? tensor_inference_function_ : dummyInferenceFunction;
0384   }
0385 
0386   OpSchema& PartialDataPropagationFunction(DataPropagationFunction dataProgationFunction);
0387   DataPropagationFunction GetDataPropagationFunction() const {
0388     return data_propagation_function_ ? data_propagation_function_ : dummyDataPropagationFunction;
0389   }
0390 
0391   // Set the support level for the op schema.
0392   OpSchema& SetSupportLevel(SupportType supportType);
0393 
0394   // Functions to do documentation for the operator schema.
0395   // This may be disabled to save memory.
0396   OpSchema& SetDoc(const char* doc) {
0397 #ifndef __ONNX_NO_DOC_STRINGS
0398     SetDoc(std::string(doc));
0399 #else
0400     ONNX_UNUSED_PARAMETER(doc);
0401 #endif
0402 
0403     return *this;
0404   }
0405 
0406   OpSchema& SetDoc(const std::string& doc) {
0407 #ifndef __ONNX_NO_DOC_STRINGS
0408     doc_ = doc;
0409 #else
0410     ONNX_UNUSED_PARAMETER(doc);
0411 #endif
0412     return *this;
0413   }
0414 
0415   // Functions to specify name for the operator schema.
0416   OpSchema& SetName(const char* name);
0417   OpSchema& SetName(std::string name);
0418 
0419   // Functions to specify code location for the operator schema.
0420   OpSchema& SetLocation(const char* file, int line);
0421   OpSchema& SetLocation(std::string file, int line);
0422 
0423   // Functions to specify domain for the operator schema.
0424   // Default domain value (ONNX_DOMAIN) means it's ONNX domain.
0425   OpSchema& SetDomain(const char* domain);
0426   OpSchema& SetDomain(std::string domain);
0427 
  // Metadata describing a single attribute of an operator schema.
  struct Attribute final {
    // Attribute without a default value; required_ says whether the
    // attribute must be present on a node using this op.
    Attribute(std::string name_, std::string description_, AttributeProto::AttributeType type_, bool required_)
        : name(std::move(name_)),
          description(std::move(description_)),
          type(type_),
          required(required_),
          default_value() {}

    // Optional attribute with a default value; the type is taken from the
    // default value and required is forced to false.
    Attribute(std::string name_, std::string description_, AttributeProto default_value_)
        : name(std::move(name_)),
          description(std::move(description_)),
          type(default_value_.type()),
          required(false),
          default_value(std::move(default_value_)) {}

    // Attribute name.
    const std::string name;
    // Human-readable description of the attribute.
    const std::string description;
    // Attribute value type.
    AttributeProto::AttributeType type;
    // Whether the attribute must be present on the node.
    bool required;
    // Value used when the attribute is absent (optional attributes only).
    AttributeProto default_value;
  };
0449 
0450   OpSchema& Attr(Attribute attr);
0451 
0452 // Register "optional" attribute with default value.
0453 #define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName)                                                                    \
0454   OpSchema& Attr(                                                                                                   \
0455       std::string name, std::string description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
0456   /* non-STL wrapper to reduce binary size */                                                                       \
0457   OpSchema& Attr(                                                                                                   \
0458       const char* name, const char* description, AttributeProto::AttributeType type, const TypeName& defaultValue); \
0459   OpSchema& Attr(                                                                                                   \
0460       std::string name,                                                                                             \
0461       std::string description,                                                                                      \
0462       AttributeProto::AttributeType type,                                                                           \
0463       const std::vector<TypeName>& defaultValue);
0464 
0465   ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t)
0466   ATTR_SETTER_WITH_DEFAULT_VALUE(float)
0467   ATTR_SETTER_WITH_DEFAULT_VALUE(std::string)
0468   ATTR_SETTER_WITH_DEFAULT_VALUE(TensorProto)
0469   ATTR_SETTER_WITH_DEFAULT_VALUE(GraphProto)
0470   ATTR_SETTER_WITH_DEFAULT_VALUE(TypeProto)
0471 
0472   OpSchema& Attr(
0473       std::string name,
0474       std::string description,
0475       std::string conditionExplanation,
0476       AttributeProto::AttributeType attr_type);
0477 
0478   // Register "required" attribute without default value.
0479   OpSchema& Attr(std::string name, std::string description, AttributeProto::AttributeType type, bool required = true);
0480 
0481   // Non-STL wrapper to reduce binary size
0482   OpSchema& Attr(const char* name, const char* description, AttributeProto::AttributeType type, bool required = true);
0483 
0484   OpSchema& AllowUncheckedAttributes();
0485 
  // Type constraint.
  struct TypeConstraintParam final {
    // type_param_str_: type-parameter name (e.g. "T");
    // allowed_type_strs_: permitted type strings (e.g. "tensor(float)");
    // description_: human-readable description of the constraint.
    TypeConstraintParam(
        std::string type_param_str_,
        std::vector<std::string> allowed_type_strs_,
        std::string description_)
        : type_param_str(std::move(type_param_str_)),
          allowed_type_strs(std::move(allowed_type_strs_)),
          description(std::move(description_)) {}

    // Type parameter string, for example, "T", "T1", etc.
    std::string type_param_str;
    // Allowed type strings for <*this> type parameter, for example,
    // "tensor(float)".
    std::vector<std::string> allowed_type_strs;
    // Type parameter description.
    std::string description;
  };
0504 
0505   // Grammar for type strings used in Input(), Output().
0506   // <type> ::= <data_type> |
0507   //            tensor(<data_type>) |
0508   //            seq(<type>) |
0509   //            map(<data_type>, <type>) |
0510   //            <type_parameter>
0511   // <data_type> :: = float | int32 | string | bool | uint8
0512   //                | int8 | uint16 | int16 | int64 | float16 | double
0513   // <type_parameter> ::= any type parameter string, say "T".
0514   //
0515   // NOTE: 1) <type_parameter> will always be together with a type constraints
0516   // specification.
0517   //       2) <type> ::= <data_type> means the data is scalar (zero dimension).
0518   //
0519   // Example:
0520   // ONNX_OPERATOR_SET_SCHEMA(Sum, 1, OpSchema()
0521   // .Input(0, "input_a", "the first input", "T")
0522   // .Input(1, "input_b", "the second input", "T")
0523   // .Output(0, "sum", "the sum of two numbers", "T")
0524   // .TypeConstraint("T", {"float", "double", "int32"}, "allowed data types for
0525   // sum."))
0526   //
0527   // Optional = true means that the input might have empty input value
0528   // (represented as "") in the graph even though the later inputs have values.
0529   // It's useful for complex situation when there are several independent
0530   // optional inputs.
0531   OpSchema& Input(int n, FormalParameter formal_parameter);
0532 
0533   OpSchema& Input(
0534       int n,
0535       std::string name,
0536       const std::string& description,
0537       std::string type_str,
0538       FormalParameterOption param_option = Single,
0539       bool is_homogeneous = true,
0540       int min_arity = 1,
0541       DifferentiationCategory differentiation_category = Unknown);
0542 
0543   // Non-STL wrapper to reduce binary size
0544   OpSchema& Input(
0545       int n,
0546       const char* name,
0547       const char* description,
0548       const char* type_str,
0549       FormalParameterOption param_option = Single,
0550       bool is_homogeneous = true,
0551       int min_arity = 1,
0552       DifferentiationCategory differentiation_category = Unknown);
0553 
0554   OpSchema& Output(int n, FormalParameter formal_parameter);
0555 
0556   OpSchema& Output(
0557       int n,
0558       std::string name,
0559       const std::string& description,
0560       std::string type_str,
0561       FormalParameterOption param_option = Single,
0562       bool is_homogeneous = true,
0563       int min_arity = 1,
0564       DifferentiationCategory differentiation_category = Unknown);
0565 
0566   // Non-STL wrapper to reduce binary size
0567   OpSchema& Output(
0568       int n,
0569       const char* name,
0570       const char* description,
0571       const char* type_str,
0572       FormalParameterOption param_option = Single,
0573       bool is_homogeneous = true,
0574       int min_arity = 1,
0575       DifferentiationCategory differentiation_category = Unknown);
0576 
0577   OpSchema& TypeConstraint(std::string type_str, std::vector<std::string> constraints, std::string description);
0578 
0579   // Non-STL wrapper to reduce binary size
0580   OpSchema&
0581   TypeConstraint(const char* type_str, std::initializer_list<const char*> constraints, const char* description);
0582 
0583   // Convenience members for types
0584 
0585   // All high-precision numeric types.
0586   static const std::vector<std::string>& numeric_types_for_math_reduction_ir9() {
0587     static const std::vector<std::string> numeric_types_for_math_reduction_ir9 = {
0588         "tensor(uint32)",
0589         "tensor(uint64)",
0590         "tensor(int32)",
0591         "tensor(int64)",
0592         "tensor(float16)",
0593         "tensor(float)",
0594         "tensor(double)",
0595         "tensor(bfloat16)",
0596         "tensor(float8e4m3fn)",
0597         "tensor(float8e4m3fnuz)",
0598         "tensor(float8e5m2)",
0599         "tensor(float8e5m2fnuz)"};
0600     return numeric_types_for_math_reduction_ir9;
0601   }
0602 
0603   static const std::vector<std::string>& numeric_types_for_math_reduction_ir4() {
0604     static const std::vector<std::string> numeric_types_for_math_reduction_ir4 = {
0605         "tensor(uint32)",
0606         "tensor(uint64)",
0607         "tensor(int32)",
0608         "tensor(int64)",
0609         "tensor(float16)",
0610         "tensor(float)",
0611         "tensor(double)",
0612         "tensor(bfloat16)"};
0613     return numeric_types_for_math_reduction_ir4;
0614   }
0615 
0616   static const std::vector<std::string>& numeric_types_for_math_reduction() {
0617     static const std::vector<std::string> numeric_types_for_math_reduction = {
0618         "tensor(uint32)",
0619         "tensor(uint64)",
0620         "tensor(int32)",
0621         "tensor(int64)",
0622         "tensor(float16)",
0623         "tensor(float)",
0624         "tensor(double)"};
0625     return numeric_types_for_math_reduction;
0626   }
0627 
0628   static const std::vector<std::string>& all_numeric_types_ir9() {
0629     static const std::vector<std::string> all_numeric_types_ir9 = {
0630         "tensor(uint8)",
0631         "tensor(uint16)",
0632         "tensor(uint32)",
0633         "tensor(uint64)",
0634         "tensor(int8)",
0635         "tensor(int16)",
0636         "tensor(int32)",
0637         "tensor(int64)",
0638         "tensor(float16)",
0639         "tensor(float)",
0640         "tensor(double)",
0641         "tensor(bfloat16)",
0642         "tensor(float8e4m3fn)",
0643         "tensor(float8e4m3fnuz)",
0644         "tensor(float8e5m2)",
0645         "tensor(float8e5m2fnuz)"};
0646     return all_numeric_types_ir9;
0647   }
0648 
0649   static const std::vector<std::string>& all_numeric_types_ir4() {
0650     static const std::vector<std::string> all_numeric_types_ir4 = {
0651         "tensor(uint8)",
0652         "tensor(uint16)",
0653         "tensor(uint32)",
0654         "tensor(uint64)",
0655         "tensor(int8)",
0656         "tensor(int16)",
0657         "tensor(int32)",
0658         "tensor(int64)",
0659         "tensor(float16)",
0660         "tensor(float)",
0661         "tensor(double)",
0662         "tensor(bfloat16)"};
0663     return all_numeric_types_ir4;
0664   }
0665 
0666   static const std::vector<std::string>& all_numeric_types() {
0667     static const std::vector<std::string> all_numeric_types = {
0668         "tensor(uint8)",
0669         "tensor(uint16)",
0670         "tensor(uint32)",
0671         "tensor(uint64)",
0672         "tensor(int8)",
0673         "tensor(int16)",
0674         "tensor(int32)",
0675         "tensor(int64)",
0676         "tensor(float16)",
0677         "tensor(float)",
0678         "tensor(double)"};
0679     return all_numeric_types;
0680   }
0681 
0682   static const std::vector<std::string>& all_numeric_sequence_types() {
0683     static const std::vector<std::string> all_numeric_sequence_types = {
0684         "seq(tensor(uint8))",
0685         "seq(tensor(uint16))",
0686         "seq(tensor(uint32))",
0687         "seq(tensor(uint64))",
0688         "seq(tensor(int8))",
0689         "seq(tensor(int16))",
0690         "seq(tensor(int32))",
0691         "seq(tensor(int64))",
0692         "seq(tensor(float16))",
0693         "seq(tensor(float))",
0694         "seq(tensor(double))"};
0695     return all_numeric_sequence_types;
0696   }
0697 
0698   static const std::vector<std::string>& all_tensor_types() {
0699     static const std::vector<std::string> all_tensor_types = {
0700         "tensor(uint8)",
0701         "tensor(uint16)",
0702         "tensor(uint32)",
0703         "tensor(uint64)",
0704         "tensor(int8)",
0705         "tensor(int16)",
0706         "tensor(int32)",
0707         "tensor(int64)",
0708         "tensor(float16)",
0709         "tensor(float)",
0710         "tensor(double)",
0711         "tensor(string)",
0712         "tensor(bool)",
0713         "tensor(complex64)",
0714         "tensor(complex128)"};
0715     return all_tensor_types;
0716   }
0717 
0718   static const std::vector<std::string>& all_tensor_types_ir4() {
0719     static const std::vector<std::string> all_tensor_types_ir4 = {
0720         "tensor(uint8)",
0721         "tensor(uint16)",
0722         "tensor(uint32)",
0723         "tensor(uint64)",
0724         "tensor(int8)",
0725         "tensor(int16)",
0726         "tensor(int32)",
0727         "tensor(int64)",
0728         "tensor(bfloat16)",
0729         "tensor(float16)",
0730         "tensor(float)",
0731         "tensor(double)",
0732         "tensor(string)",
0733         "tensor(bool)",
0734         "tensor(complex64)",
0735         "tensor(complex128)"};
0736     return all_tensor_types_ir4;
0737   }
0738 
0739   static const std::vector<std::string>& all_float_types_ir4() {
0740     static const std::vector<std::string> all_float_types_ir4 = {
0741         "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)"};
0742     return all_float_types_ir4;
0743   }
0744 
0745   static const std::vector<std::string>& all_float_types_ir9() {
0746     static const std::vector<std::string> all_float_types_ir9 = {
0747         "tensor(bfloat16)",
0748         "tensor(float16)",
0749         "tensor(float)",
0750         "tensor(double)",
0751         "tensor(float8e4m3fn)",
0752         "tensor(float8e4m3fnuz)",
0753         "tensor(float8e5m2)",
0754         "tensor(float8e5m2fnuz)"};
0755     return all_float_types_ir9;
0756   }
0757 
0758   static const std::vector<std::string>& all_tensor_types_ir9() {
0759     static const std::vector<std::string> all_tensor_types_ir9 = {
0760         "tensor(uint8)",        "tensor(uint16)",         "tensor(uint32)",     "tensor(uint64)",
0761         "tensor(int8)",         "tensor(int16)",          "tensor(int32)",      "tensor(int64)",
0762         "tensor(bfloat16)",     "tensor(float16)",        "tensor(float)",      "tensor(double)",
0763         "tensor(string)",       "tensor(bool)",           "tensor(complex64)",  "tensor(complex128)",
0764         "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)", "tensor(float8e5m2)", "tensor(float8e5m2fnuz)"};
0765     return all_tensor_types_ir9;
0766   }
0767 
0768   static const std::vector<std::string>& all_tensor_sequence_types() {
0769     static const std::vector<std::string> all_tensor_sequence_types = {
0770         "seq(tensor(uint8))",
0771         "seq(tensor(uint16))",
0772         "seq(tensor(uint32))",
0773         "seq(tensor(uint64))",
0774         "seq(tensor(int8))",
0775         "seq(tensor(int16))",
0776         "seq(tensor(int32))",
0777         "seq(tensor(int64))",
0778         "seq(tensor(float16))",
0779         "seq(tensor(float))",
0780         "seq(tensor(double))",
0781         "seq(tensor(string))",
0782         "seq(tensor(bool))",
0783         "seq(tensor(complex64))",
0784         "seq(tensor(complex128))"};
0785     return all_tensor_sequence_types;
0786   }
0787 
0788   static const std::vector<std::string>& all_tensor_sequence_types_ir4() {
0789     static const std::vector<std::string> all_tensor_sequence_types_ir4 = {
0790         "seq(tensor(uint8))",
0791         "seq(tensor(uint16))",
0792         "seq(tensor(uint32))",
0793         "seq(tensor(uint64))",
0794         "seq(tensor(int8))",
0795         "seq(tensor(int16))",
0796         "seq(tensor(int32))",
0797         "seq(tensor(int64))",
0798         "seq(tensor(bfloat16))",
0799         "seq(tensor(float16))",
0800         "seq(tensor(float))",
0801         "seq(tensor(double))",
0802         "seq(tensor(string))",
0803         "seq(tensor(bool))",
0804         "seq(tensor(complex64))",
0805         "seq(tensor(complex128))"};
0806     return all_tensor_sequence_types_ir4;
0807   }
0808 
0809   static const std::vector<std::string>& all_tensor_sequence_types_ir9() {
0810     static const std::vector<std::string> all_tensor_sequence_types_ir4 = {
0811         "seq(tensor(uint8))",      "seq(tensor(uint16))",        "seq(tensor(uint32))",
0812         "seq(tensor(uint64))",     "seq(tensor(int8))",          "seq(tensor(int16))",
0813         "seq(tensor(int32))",      "seq(tensor(int64))",         "seq(tensor(bfloat16))",
0814         "seq(tensor(float16))",    "seq(tensor(float))",         "seq(tensor(double))",
0815         "seq(tensor(string))",     "seq(tensor(bool))",          "seq(tensor(complex64))",
0816         "seq(tensor(complex128))", "seq(tensor(float8e4m3fn))",  "seq(tensor(float8e4m3fnuz))",
0817         "seq(tensor(float8e5m2))", "seq(tensor(float8e5m2fnuz))"};
0818     return all_tensor_sequence_types_ir4;
0819   }
0820 
0821   static const std::vector<std::string>& all_optional_types() {
0822     static const std::vector<std::string> all_optional_types = {
0823         "optional(seq(tensor(uint8)))",  "optional(seq(tensor(uint16)))",    "optional(seq(tensor(uint32)))",
0824         "optional(seq(tensor(uint64)))", "optional(seq(tensor(int8)))",      "optional(seq(tensor(int16)))",
0825         "optional(seq(tensor(int32)))",  "optional(seq(tensor(int64)))",     "optional(seq(tensor(float16)))",
0826         "optional(seq(tensor(float)))",  "optional(seq(tensor(double)))",    "optional(seq(tensor(string)))",
0827         "optional(seq(tensor(bool)))",   "optional(seq(tensor(complex64)))", "optional(seq(tensor(complex128)))",
0828         "optional(tensor(uint8))",       "optional(tensor(uint16))",         "optional(tensor(uint32))",
0829         "optional(tensor(uint64))",      "optional(tensor(int8))",           "optional(tensor(int16))",
0830         "optional(tensor(int32))",       "optional(tensor(int64))",          "optional(tensor(float16))",
0831         "optional(tensor(float))",       "optional(tensor(double))",         "optional(tensor(string))",
0832         "optional(tensor(bool))",        "optional(tensor(complex64))",      "optional(tensor(complex128))"};
0833     return all_optional_types;
0834   }
0835 
0836   static const std::vector<std::string>& all_optional_types_ir4() {
0837     static const std::vector<std::string> all_optional_types = {
0838         "optional(seq(tensor(uint8)))",      "optional(seq(tensor(uint16)))", "optional(seq(tensor(uint32)))",
0839         "optional(seq(tensor(uint64)))",     "optional(seq(tensor(int8)))",   "optional(seq(tensor(int16)))",
0840         "optional(seq(tensor(int32)))",      "optional(seq(tensor(int64)))",  "optional(seq(tensor(bfloat16)))",
0841         "optional(seq(tensor(float16)))",    "optional(seq(tensor(float)))",  "optional(seq(tensor(double)))",
0842         "optional(seq(tensor(string)))",     "optional(seq(tensor(bool)))",   "optional(seq(tensor(complex64)))",
0843         "optional(seq(tensor(complex128)))", "optional(tensor(uint8))",       "optional(tensor(uint16))",
0844         "optional(tensor(uint32))",          "optional(tensor(uint64))",      "optional(tensor(int8))",
0845         "optional(tensor(int16))",           "optional(tensor(int32))",       "optional(tensor(int64))",
0846         "optional(tensor(bfloat16))",        "optional(tensor(float16))",     "optional(tensor(float))",
0847         "optional(tensor(double))",          "optional(tensor(string))",      "optional(tensor(bool))",
0848         "optional(tensor(complex64))",       "optional(tensor(complex128))"};
0849     return all_optional_types;
0850   }
0851 
0852   static const std::vector<std::string>& all_optional_types_ir9() {
0853     static const std::vector<std::string> all_optional_types = {
0854         "optional(seq(tensor(uint8)))",      "optional(seq(tensor(uint16)))", "optional(seq(tensor(uint32)))",
0855         "optional(seq(tensor(uint64)))",     "optional(seq(tensor(int8)))",   "optional(seq(tensor(int16)))",
0856         "optional(seq(tensor(int32)))",      "optional(seq(tensor(int64)))",  "optional(seq(tensor(bfloat16)))",
0857         "optional(seq(tensor(float16)))",    "optional(seq(tensor(float)))",  "optional(seq(tensor(double)))",
0858         "optional(seq(tensor(string)))",     "optional(seq(tensor(bool)))",   "optional(seq(tensor(complex64)))",
0859         "optional(seq(tensor(complex128)))", "optional(tensor(uint8))",       "optional(tensor(uint16))",
0860         "optional(tensor(uint32))",          "optional(tensor(uint64))",      "optional(tensor(int8))",
0861         "optional(tensor(int16))",           "optional(tensor(int32))",       "optional(tensor(int64))",
0862         "optional(tensor(bfloat16))",        "optional(tensor(float16))",     "optional(tensor(float))",
0863         "optional(tensor(double))",          "optional(tensor(string))",      "optional(tensor(bool))",
0864         "optional(tensor(complex64))",       "optional(tensor(complex128))",  "optional(tensor(float8e4m3fn))",
0865         "optional(tensor(float8e4m3fnuz))",  "optional(tensor(float8e5m2))",  "optional(tensor(float8e5m2fnuz))"};
0866     return all_optional_types;
0867   }
0868 
  // Calls the passed function with `this` as an argument. Useful for
  // adding docs for templated/macro ops.
0871   OpSchema& FillUsing(const std::function<void(OpSchema&)>& populator);
0872 
0873   friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema);
0874 
  // Returns the operator-set domain of this schema ("" is the ONNX domain).
  const std::string& domain() const {
    return domain_;
  }

  // Returns the map from attribute name to its Attribute schema entry.
  const std::map<std::string, Attribute>& attributes() const {
    return attributes_;
  }

  // Get input formal parameters.
  const std::vector<FormalParameter>& inputs() const {
    return inputs_;
  }

  // Get output formal parameters.
  const std::vector<FormalParameter>& outputs() const {
    return outputs_;
  }

  // Returns the declared type-constraint parameters (e.g. "T" with its
  // allowed type strings).
  const std::vector<TypeConstraintParam>& typeConstraintParams() const {
    return type_constraint_params_;
  }

  // Returns the type-constraint lookup map (populated when the schema is
  // finalized; see Finalize()).
  const TypeConstraintMap& typeConstraintMap() const {
    return type_constraints_;
  }

  // Returns the operator type name this schema describes.
  const std::string& Name() const {
    return name_;
  }
0904 
  // Opset version in which this schema was introduced or last changed.
  OperatorSetVersion SinceVersion() const {
    return since_version_;
  }

  // Alias of SinceVersion() returning a plain int.
  int since_version() const {
    return since_version_;
  }

  // True if this operator version is marked deprecated.
  bool deprecated() const {
    return deprecated_;
  }

  // Inclusive bounds on the number of inputs/outputs the operator accepts.
  int min_input() const {
    return min_input_;
  }
  int max_input() const {
    return max_input_;
  }
  int min_output() const {
    return min_output_;
  }
  int max_output() const {
    return max_output_;
  }
0929 
0930   bool has_type_and_shape_inference_function() const {
0931     return tensor_inference_function_ ? true : false;
0932   }
0933 
0934   bool has_data_propagation_function() const {
0935     return data_propagation_function_ ? true : false;
0936   }
0937 
0938   std::vector<int> function_opset_versions() const {
0939     std::vector<int> opset_versions;
0940     std::map<int, std::shared_ptr<FunctionProto>>::const_iterator it = opset_version_to_function_body_.cbegin();
0941     for (; it != opset_version_to_function_body_.cend(); ++it) {
0942       opset_versions.push_back(it->first);
0943     }
0944     return opset_versions;
0945   }
0946 
  // True if at least one fixed (non-context-dependent) function body is registered.
  bool HasFunction() const {
    return !opset_version_to_function_body_.empty();
  }
0950 
0951   OpSchema& FunctionBody(const std::vector<NodeProto>& func_nodes, int opset_version = kUninitializedSinceVersion);
0952 
0953   OpSchema& FunctionBody(
0954       const std::vector<NodeProto>& func_nodes,
0955       const std::vector<OperatorSetIdProto>& opsets,
0956       int opset_version = kUninitializedSinceVersion);
0957 
0958   OpSchema& FunctionBody(const char* func_body, int opset_version = kUninitializedSinceVersion);
0959 
0960   // since_version_ of an OpSchema tells the last opset version when an op is defined.
0961   // When the op's definition is changed, a new OpSchema (of the same op_type) is created
0962   // with a newer since_version_, reflecting the opset version at the time of change.
0963   // For a function op, operators used to define its function body may change
0964   // while there is no change to the function op definition itself.
  // When this happens, multiple function bodies are provided, each for a specific opset version.
0966   //
0967   // Take LogSoftmax for example. Its latest opset version is 13.
0968   // In LogSoftmax's function body, ReduceMax (with since_version_ 1, 11, 12, 18) is used.
0969   // When a model containing LogSoftmax with opset_import version within 13 to 17 is loaded, function body
0970   // with opset_version 13 is used for inlining.
0971   // When the same model but opset_import version 18 is loaded, function body
0972   // with opset_version 18 is used for inlining.
0973   // Clearly function body for opset_import version 13 will not work
  // in a model with opset_import version 18 because the function body makes wrong use of ReduceMax(18).
0975   // Inside GetFunction we ensure that ops being used to construct a function body do not endure such
0976   // issue.
0977   const FunctionProto* GetFunction(
0978       int requested_opset_version = OpSchema::kUninitializedSinceVersion,
0979       bool validate = false) const;
0980 
0981   std::vector<int> context_dependent_function_opset_versions() const {
0982     std::vector<int> opset_versions;
0983     std::map<int, ContextDependentFunctionBodyBuilder>::const_iterator it = opset_version_to_function_builder_.cbegin();
0984     for (; it != opset_version_to_function_builder_.cend(); ++it) {
0985       opset_versions.push_back(it->first);
0986     }
0987     return opset_versions;
0988   }
0989 
  // True if at least one context-dependent function-body builder is registered.
  bool HasContextDependentFunction() const {
    return !opset_version_to_function_builder_.empty();
  }

  // True if a context-dependent builder is registered for exactly this opset version.
  bool HasContextDependentFunctionWithOpsetVersion(int opset_version) const {
    return opset_version_to_function_builder_.find(opset_version) != opset_version_to_function_builder_.end();
  }
0997 
0998   OpSchema& SetContextDependentFunctionBodyBuilder(
0999       ContextDependentFunctionBodyBuilder,
1000       int opset_version = kUninitializedSinceVersion);
1001 
1002   bool BuildContextDependentFunction(
1003       const FunctionBodyBuildContext& ctx,
1004       FunctionProto& function_proto,
1005       int requested_opset_version = OpSchema::kUninitializedSinceVersion) const;
1006 
1007   // Verifies that the schema is valid and all specifications are compatible.
1008   // It will also parse all type strings specified for inputs/outputs into valid
1009   // TypeProto and create global unique string pointer as the DataType for
1010   // efficiency.
1011   void Finalize();
1012 
1013   // Build function with information stored in opschema
1014   void BuildFunction(FunctionProto& function_body) const;
1015 
1016  private:
1017   void ParseAndSetTypes(
1018       /*out*/ std::vector<OpSchema::FormalParameter>* formalParameters);
1019   bool ValidateReferencedOpsInFuncton(
1020       const FunctionProto* function,
1021       int requested_opset_version,
1022       int function_since_version,
1023       std::set<std::string>* updated_ops = nullptr) const;
1024   void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const;
1025 
  std::string name_;
  // Source location of the schema definition; surfaced in duplicate-registration
  // and version-range error messages (see OpSchemaRegisterOnce).
  std::string file_;
  std::string doc_;
  // Default domain value ("") means it's ONNX domain.
  std::string domain_ = ONNX_DOMAIN;
  std::map<std::string, Attribute> attributes_{};
  bool allows_unchecked_attributes_ = false;
  std::vector<FormalParameter> inputs_;
  std::vector<FormalParameter> outputs_;
  std::vector<TypeConstraintParam> type_constraint_params_;
  TypeConstraintMap type_constraints_;
  int line_ = 0;
  SupportType support_;
  // Inclusive bounds on the number of inputs/outputs (see min_input() etc.).
  int min_input_ = 0;
  int max_input_ = 0;
  int min_output_ = 0;
  int max_output_ = 0;
  // The default is a little goofy, since it is never what you want
  OperatorSetVersion since_version_ = kUninitializedSinceVersion;
  bool deprecated_{};
  // Predicates validating the input/output counts; the defaults accept any count.
  std::function<bool(int)> num_inputs_allowed_ = [](int) { return true; };
  std::function<bool(int)> num_outputs_allowed_ = [](int) { return true; };
  InferenceFunction tensor_inference_function_;
  DataPropagationFunction data_propagation_function_;

  // Fixed function bodies keyed by opset version (see GetFunction / FunctionBody).
  std::map<int, std::shared_ptr<FunctionProto>> opset_version_to_function_body_;
  // Context-dependent function-body builders keyed by opset version.
  std::map<int, ContextDependentFunctionBodyBuilder> opset_version_to_function_builder_;
1053 };
1054 
1055 // Map type to store operator schemas. The format is,
1056 // <OpName, <Domain, <OperatorSetVersion, OpSchema>>>.
1057 using OpName_Domain_Version_Schema_Map =
1058     std::unordered_map<std::string, std::unordered_map<std::string, std::map<OperatorSetVersion, OpSchema>>>;
1059 
// Interface for schema registries: resolves an operator name (and domain) to
// an OpSchema with version <= maxInclusiveVersion.
class ISchemaRegistry {
 public:
  virtual ~ISchemaRegistry() = default;

  // Returns the matching schema, or nullptr if none is registered.
  virtual const OpSchema*
  GetSchema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) const = 0;
};
1067 
1068 /**
1069  * @brief A registry to hold all the operator schemas.
1070  */
1071 class OpSchemaRegistry final : public ISchemaRegistry {
1072  public:
1073   // A singleton class to store domain to min/max op_set version map, as well as
1074   // domain to last-release op_set version map.
1075   class DomainToVersionRange final {
1076    public:
1077     DomainToVersionRange() {
1078       // Increase the highest version when you make BC-breaking changes to the
1079       // operator schema on specific domain. Update the lowest version when it's
1080       // determined to remove too old version history.
1081       map_[ONNX_DOMAIN] = std::make_pair(1, 20);
1082       map_[AI_ONNX_ML_DOMAIN] = std::make_pair(1, 4);
1083       map_[AI_ONNX_TRAINING_DOMAIN] = std::make_pair(1, 1);
1084       // ONNX's preview domain contains operators subject to change, so
1085       // versining is not meaningful and that domain should have only one
1086       // version.
1087       map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 1);
1088       // Version corresponding last release of ONNX. Update this to match with
1089       // the max version above in a *release* version of ONNX. But in other
1090       // versions, the max version may be ahead of the last-release-version.
1091       last_release_version_map_[ONNX_DOMAIN] = 20;
1092       last_release_version_map_[AI_ONNX_ML_DOMAIN] = 4;
1093       last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1;
1094       last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1;
1095     }
1096 
1097     const std::unordered_map<std::string, std::pair<int, int>>& Map() const {
1098       return map_;
1099     }
1100 
1101     const std::unordered_map<std::string, int>& LastReleaseVersionMap() const {
1102       return last_release_version_map_;
1103     }
1104 
1105     // Add customized domain to min/max version.
1106     // Onnx partners are able to use onnx operator schema api to
1107     // register customized op in their own domain.
1108     // Can optionally specify last_release_version (to make it similar to
1109     // standard ONNX domains as above). Custom-domains are free to interpret
1110     // this as appropriate (that is, as relative to releases of custom-domain
1111     // as opposed to ONNX releases).
1112     void
1113     AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) {
1114       std::lock_guard<std::mutex> lock(mutex_);
1115       assert(map_.end() == map_.find(domain));
1116       map_[domain] = std::make_pair(min_version, max_version);
1117       // If a last-release-version is not explicitly specified, use max as
1118       // last-release-version.
1119       if (last_release_version == -1)
1120         last_release_version = max_version;
1121       assert(last_release_version_map_.end() == last_release_version_map_.find(domain));
1122       last_release_version_map_[domain] = last_release_version;
1123     }
1124 
1125     static DomainToVersionRange& Instance();
1126 
1127    private:
1128     // Key: domain. Value: <lowest version, highest version> pair.
1129     std::unordered_map<std::string, std::pair<int, int>> map_;
1130 
1131     // Key: domain. Value: most recent release opset version. Note that
1132     // the highest opset version may be ahead of the most recent release's opset
1133     // version.
1134     std::unordered_map<std::string, int> last_release_version_map_;
1135 
1136     std::mutex mutex_;
1137   };
1138 
1139   class OpSchemaRegisterOnce final {
1140    public:
1141     OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1142       ONNX_TRY {
1143         op_schema.Finalize();
1144         auto& m = GetMapWithoutEnsuringRegistration();
1145         auto& op_name = op_schema.Name();
1146         auto& op_domain = op_schema.domain();
1147         auto& schema_ver_map = m[op_name][op_domain];
1148         auto ver = op_schema.SinceVersion();
1149         if (OpSchema::kUninitializedSinceVersion == ver) {
1150           op_schema.SinceVersion(1);
1151           ver = op_schema.SinceVersion();
1152         }
1153 
1154         // Stops because the exact opset_version is registered
1155         if (schema_ver_map.count(ver)) {
1156           if (fail_duplicate_schema) {
1157             const auto& schema = schema_ver_map[ver];
1158             std::stringstream err;
1159             err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1160                 << ") from file " << op_schema.file() << " line " << op_schema.line()
1161                 << ", but it is already registered from file " << schema.file() << " line " << schema.line()
1162                 << std::endl;
1163             fail_schema(err.str());
1164           }
1165           return;
1166         }
1167 
1168         if (opset_version_to_load != 0) {
1169           // Stops because the opset_version is higher than opset_version_to_load
1170           if (ver > opset_version_to_load)
1171             return;
1172 
1173           // Stops because a later version is registered within target opset version
1174           if (!schema_ver_map.empty()) {
1175             int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load);
1176             if (max_registered_ver_le_target >= ver)
1177               return;
1178           }
1179         }
1180 
1181         CheckDomainAndVersionToRegister(op_schema, op_name, op_domain);
1182         schema_ver_map.insert(std::pair<int, OpSchema&&>(ver, std::move(op_schema)));
1183       }
1184       ONNX_CATCH(const std::exception& e) {
1185         ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; });
1186       }
1187     }
1188 
1189    private:
1190     // Gets the maximum version from given map that is less or equal to target version
1191     static int GetMaxRegisteredVerWithinTarget(const std::map<OperatorSetVersion, OpSchema>& m, int target_ver) {
1192       // std::map is sorted on key
1193       // reverse iterator returns the largest element keyed on the integer version
1194       for (auto&& it = m.rbegin(); it != m.rend(); it++) {
1195         const auto& registered_ver = it->first;
1196         if (registered_ver <= target_ver) {
1197           return registered_ver;
1198         }
1199       }
1200       return -1;
1201     }
1202 
1203     static void CheckDomainAndVersionToRegister(
1204         const OpSchema& op_schema,
1205         const std::string& op_name,
1206         const std::string& op_domain) {
1207       auto ver_range_map = DomainToVersionRange::Instance().Map();
1208       auto ver_range_it = ver_range_map.find(op_domain);
1209       auto ver = op_schema.SinceVersion();
1210 
1211       if (ver_range_it == ver_range_map.end()) {
1212         std::stringstream err;
1213         err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1214             << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its domain is not"
1215             << " known by the checker." << std::endl;
1216 
1217         fail_schema(err.str());
1218       }
1219       auto lower_bound_incl = ver_range_it->second.first;
1220       auto upper_bound_incl = ver_range_it->second.second;
1221       if (!(lower_bound_incl <= ver && upper_bound_incl >= ver)) {
1222         std::stringstream err;
1223         err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver
1224             << ") from file " << op_schema.file() << " line " << op_schema.line() << ", but its version is not "
1225             << "in the inclusive range [" << lower_bound_incl << ", " << upper_bound_incl
1226             << "] (usually, this means you "
1227             << "bumped the operator version but "
1228             << "forgot to update the version range in DomainToVersionRange "
1229             << "in onnx/defs/schema.h)." << std::endl;
1230         fail_schema(err.str());
1231       }
1232     }
1233   };
1234 
1235   // Deregister all ONNX opset schemas from domain
1236   // Domain with default value ONNX_DOMAIN means ONNX.
1237   static void OpSchemaDeregisterAll(const std::string& domain = ONNX_DOMAIN) {
1238     auto& schema_map = GetMapWithoutEnsuringRegistration();
1239     // schema_map stores operator schemas in the format of
1240     // <OpName, <Domain, <OperatorSetVersion, OpSchema>>>
1241     for (auto&& schema_map_pair : schema_map) {
1242       auto& domain_map = schema_map_pair.second;
1243       if (domain_map.count(domain)) {
1244         auto& opset_version_schema_map = domain_map[domain];
1245         // Invalidates ver-schema pairs and frees memory, leaving m[op_name][op_domain] empty
1246         opset_version_schema_map.clear();
1247         domain_map.erase(domain);
1248       }
1249     }
1250   }
1251 
1252   // Return the latest schema for an operator in specified domain.
1253   // Domain with default value ONNX_DOMAIN means ONNX.
1254   static const OpSchema* Schema(const std::string& key, const std::string& domain = ONNX_DOMAIN) {
1255     auto& m = map();
1256     if (m.count(key) && m[key].count(domain)) {
1257       const auto& schema_ver_map = m[key][domain];
1258       if (!schema_ver_map.empty()) {
1259         return &m[key][domain].rbegin()->second;
1260       }
1261     }
1262     return nullptr;
1263   }
1264 
1265   // Return the schema with biggest version, which is not greater than specified
1266   // <maxInclusiveVersion> in specified domain. Domain with default value
1267   // ONNX_DOMAIN means ONNX.
1268   static const OpSchema*
1269   Schema(const std::string& key, const int maxInclusiveVersion, const std::string& domain = ONNX_DOMAIN) {
1270     auto& m = map();
1271     if (m.count(key) && m[key].count(domain)) {
1272       const auto& schema_ver_map = m[key][domain];
1273       if (!schema_ver_map.empty()) {
1274         auto pos = m[key][domain].lower_bound(maxInclusiveVersion);
1275         if (m[key][domain].begin() == pos && pos->first > maxInclusiveVersion) {
1276           // All versions are greater than specified version.
1277           return nullptr;
1278         }
1279         if (m[key][domain].end() == pos || pos->first > maxInclusiveVersion) {
1280           // All versions are less than specified version, or,
1281           // The <pos> version is greater than specified version.
1282           pos--;
1283         }
1284 
1285         // Schema with exact version as specified one exists.
1286         return &(pos->second);
1287       }
1288     }
1289     return nullptr;
1290   }
1291 
1292   static OpSchemaRegistry* Instance();
1293 
1294   const OpSchema* GetSchema(
1295       const std::string& key,
1296       const int maxInclusiveVersion,
1297       const std::string& domain = ONNX_DOMAIN) const override {
1298     return Schema(key, maxInclusiveVersion, domain);
1299   }
1300   static void SetLoadedSchemaVersion(int target_version) {
1301     loaded_schema_version = target_version;
1302   }
1303   static int GetLoadedSchemaVersion() {
1304     return loaded_schema_version;
1305   }
1306 
1307  private:
1308   // OpSchemaRegistry should not need to be instantiated except statically
1309   // within this class
1310   OpSchemaRegistry() = default;
1311 
1312   /**
1313    * @brief Returns the underlying string to OpSchema map.
1314    *
1315    * You should not manually manipulate the map object returned. Instead, use
1316    * the macros defined such as ONNX_OPERATOR_SET_SCHEMA to register your
1317    * operator schema.
1318    *
1319    * We wrap it inside a function to avoid the static initialization order
1320    * fiasco.
1321    */
1322   static OpName_Domain_Version_Schema_Map& GetMapWithoutEnsuringRegistration();
1323   static OpName_Domain_Version_Schema_Map& map();
1324   static int loaded_schema_version;
1325 
1326  public:
1327   static const std::vector<OpSchema> get_all_schemas_with_history() {
1328     std::vector<OpSchema> r;
1329     for (auto& x : map()) {
1330       for (auto& y : x.second) {
1331         for (auto& z : y.second) {
1332           r.emplace_back(z.second);
1333         }
1334       }
1335     }
1336     return r;
1337   }
1338 
1339   static const std::vector<OpSchema> get_all_schemas() {
1340     std::vector<OpSchema> r;
1341     for (auto& x : map()) {
1342       for (auto& y : x.second) {
1343         auto& version2schema = y.second;
1344         r.emplace_back(version2schema.rbegin()->second);
1345       }
1346     }
1347     return r;
1348   }
1349 };
1350 
1351 void RegisterSchema(OpSchema schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true);
1352 
1353 // Registers the latest opset schema before opset_version_to_load
1354 // By default opset_version_to_load=0 means it will register all versions
1355 template <class T>
1356 void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true) {
1357   T::ForEachSchema([opset_version_to_load, fail_duplicate_schema](OpSchema&& schema) {
1358     RegisterSchema(schema, opset_version_to_load, fail_duplicate_schema);
1359   });
1360 };
1361 
1362 // Forward declaration for the non-specialized GetOpSchema method.  This
1363 // enforces a consistent signature on functions that query individual schema,
1364 // which are defined as specializations of this function.
1365 template <typename T>
1366 OpSchema GetOpSchema();
1367 
1368 #define ONNX_OPERATOR_SET_SCHEMA(name, ver, impl) ONNX_OPERATOR_SET_SCHEMA_EX(name, Onnx, ONNX_DOMAIN, ver, true, impl)
1369 
1370 #define ONNX_ML_OPERATOR_SET_SCHEMA(name, ver, impl) \
1371   ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxML, AI_ONNX_ML_DOMAIN, ver, true, impl)
1372 
1373 #define ONNX_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
1374   ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxTraining, AI_ONNX_TRAINING_DOMAIN, ver, true, impl)
1375 
1376 #define ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(name, ver, impl) \
1377   ONNX_OPERATOR_SET_SCHEMA_EX(name, OnnxPreview, AI_ONNX_PREVIEW_TRAINING_DOMAIN, ver, true, impl)
1378 
1379 // Defines specialization of GetOpSchema for a class whose name is determined
1380 // based on a convention using name, domain, and version.  Operator schema are
1381 // normally included in operator sets and registered in OpSchemaRegistry::map().
1382 // In this case, callers should set dbg_included_in_static_opset to true.  This
1383 // assists with runtime validation in DEBUG builds ensuring the intended set
1384 // of operator schema is registered.
1385 #define ONNX_OPERATOR_SET_SCHEMA_EX(name, domain, domain_str, ver, dbg_included_in_static_opset, impl)  \
1386   class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name);                                         \
1387   template <>                                                                                           \
1388   OpSchema GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name)>() {                      \
1389     return impl.SetName(#name).SetDomain(domain_str).SinceVersion(ver).SetLocation(__FILE__, __LINE__); \
1390   }                                                                                                     \
1391   size_t dbg_count_check_##name##_##domain##_ver##ver =                                                 \
1392       (dbg_included_in_static_opset) ? ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() : 0;
1393 #ifdef NDEBUG
1394 #define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() 0
1395 #else
1396 #define ONNX_DBG_INCREMENT_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().IncrementCount()
1397 #define ONNX_DBG_GET_COUNT_IN_OPSETS() DbgOperatorSetTracker::Instance().GetCount()
1398 
// Debug-build counter used by ONNX_DBG_INCREMENT_COUNT_IN_OPSETS /
// ONNX_DBG_GET_COUNT_IN_OPSETS to tally operator-schema registrations.
class DbgOperatorSetTracker {
 public:
  static DbgOperatorSetTracker& Instance();

  // Bumps the running total and returns the new value.
  size_t IncrementCount() {
    return ++registration_count_;
  }

  // Returns the total recorded so far.
  size_t GetCount() const {
    return registration_count_;
  }

 private:
  size_t registration_count_ = 0;
};
1414 #endif
1415 
1416 // Naming convention for operator schema classes
1417 #define ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(domain, ver, name) name##_##domain##_ver##ver
1418 
1419 // Naming convention for preview operator schema classes
1420 #define ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(ver, name) \
1421   ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxPreview, ver, name)
1422 
1423 // Helper function
1424 size_t ReplaceAll(std::string& s, const char* from, const char* to);
1425 
1426 #ifdef __GNUC__
1427 #define ONNX_UNUSED __attribute__((__unused__))
1428 #else
1429 #define ONNX_UNUSED
1430 #endif
1431 
1432 // Legacy macros to register schema at static initialization
1433 #define ONNX_OPERATOR_SCHEMA(name) ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name)
1434 #define ONNX_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)
1435 #define ONNX_OPERATOR_SCHEMA_UNIQ(Counter, name)                                                                      \
1436   static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce(op_schema_register_once##name##Counter) ONNX_UNUSED = \
1437       OpSchema(#name, __FILE__, __LINE__)
1438 
1439 // Helper function
1440 size_t ReplaceAll(std::string& s, const char* from, const char* to);
1441 
// Returns the standard documentation blurb describing how optional
// inputs/outputs are represented (empty names, omitted trailing arguments).
inline std::string GenerateOptionalArgumentsDoc() {
  std::string doc =
      "This operator has **optional** inputs/outputs. "
      "See [the doc](IR.md) for more details about the representation of "
      "optional arguments. An empty string may be used in the place of "
      "an actual argument's name to indicate a missing argument. "
      "Trailing optional arguments (those not followed by an argument "
      "that is present) may also be simply omitted.\n";
  return doc;
}
1450 
// Returns the standard blurb for operators supporting Numpy-style
// multidirectional broadcasting.
inline std::string GenerateBroadcastingDocMul() {
  std::string doc = "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**;";
  doc += " for more details please check [the doc](Broadcasting.md).";
  return doc;
}
1455 
// Returns the standard blurb for operators where `from` must be
// unidirectionally broadcastable to `to`.
inline std::string GenerateBroadcastingDocUni(const char* from, const char* to) {
  std::string ret{"This operator supports **unidirectional broadcasting** ("};
  ret.append(from);
  ret.append(" should be unidirectional broadcastable to ");
  ret.append(to);
  ret.append(");"
             " for more details please check [the doc](Broadcasting.md).");
  return ret;
}
1463 
1464 /*
1465  * Macros for setting operator documentation
1466  * Use this macro for simple SetDoc() calls that generate documentation
1467  * directly. This is the macro to use in almost all cases.
1468  * Sample usage guidelines:
1469  * const char* doc_str = "foo";
1470  * SetDoc(GET_OP_DOC_STR(doc_str))
1471  *
1472  * SetDoc(GET_OP_DOC_STR(
1473             std::string(BitShift_ver11_doc) + GenerateBroadcastingDocMul()))
1474  */
1475 #ifndef __ONNX_NO_DOC_STRINGS
1476 #define GET_OP_DOC_STR(doc_str) (doc_str)
1477 #else
1478 #define GET_OP_DOC_STR(doc_str) ("")
1479 #endif
1480 
1481 /*
1482  * Use this macro when the documentation needs to be populated in some
1483  * complicated way like string substitutions, etc before calling SetDoc.
1484  * Sample usage guidelines:
1485     std::string doc;
1486     POPULATE_OP_DOC_STR(
1487         doc = R"DOC(
1488 Returns the tensor resulted from performing the `{name}` logical operation
1489 elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting
1490 support).
1491 
1492 {broadcast_doc}
1493 )DOC";
1494         ReplaceAll(doc, "{name}", name);
1495         ReplaceAll(
1496             doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
1497     schema.SetDoc(doc);
1498  *
1499  */
1500 #ifndef __ONNX_NO_DOC_STRINGS
1501 #define POPULATE_OP_DOC_STR(DocPopulatorCode) \
1502   do {                                        \
1503     DocPopulatorCode                          \
1504   } while (0)
1505 #else
1506 #define POPULATE_OP_DOC_STR(DocPopulatorCode)
1507 #endif
1508 
1509 } // namespace ONNX_NAMESPACE