File indexing completed on 2025-02-22 10:42:44
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #pragma once
0011 #include <stdint.h>
0012
0013 #include <string>
0014 #include <unordered_map>
0015 #include <vector>
0016
namespace ONNX_NAMESPACE {

// X-macro listing every builtin symbol name. Each entry is written as
// `_(name)`, so a caller passes a one-argument macro to generate code per
// symbol (the BuiltinSymbol enum below expands it into `k<name>`
// enumerators). Entry position determines the enumerator's numeric value:
// inserting or reordering entries renumbers every later `k<name>` and
// kLastSymbol. Prefer appending new entries at the end.
#define FORALL_BUILTIN_SYMBOLS(_) \
  _(spatial) \
  _(select_last_index) \
  _(coordinate_transformation_mode) \
  _(PythonOp) \
  _(CppOp) \
  _(Param) \
  _(Select) \
  _(Return) \
  _(Eval) \
  _(add) \
  _(Add) \
  _(Div) \
  _(Mul) \
  _(Neg) \
  _(Sub) \
  _(Pow) \
  _(Sigmoid) \
  _(ArgMax) \
  _(Concat) \
  _(Softmax) \
  _(LogSoftmax) \
  _(Dropout) \
  _(Tanh) \
  _(mul) \
  _(neg) \
  _(sigmoid) \
  _(tanh) \
  _(Constant) \
  _(cat) \
  _(Slice) \
  _(Squeeze) \
  _(Undefined) \
  _(FusionGroup) \
  _(MatMul) \
  _(Gemm) \
  _(Tile) \
  _(SubConstant) \
  _(Scale) \
  _(Transpose) \
  _(Pad) \
  _(Reshape) \
  _(split) \
  _(chunk) \
  _(Offset) \
  _(value) \
  _(Subgraph) \
  _(BatchNormalization) \
  _(Conv) \
  _(ConvTranspose) \
  _(is_test) \
  _(epsilon) \
  _(expand) \
  _(Expand) \
  _(order) \
  _(momentum) \
  _(consumed_inputs) \
  _(kernels) \
  _(kernel_shape) \
  _(kernel) \
  _(scale) \
  _(strides) \
  _(stride) \
  _(pads) \
  _(pad) \
  _(beta) \
  _(alpha) \
  _(dilations) \
  _(dilation) \
  _(broadcast) \
  _(axis) \
  _(ratio) \
  _(size) \
  _(dim) \
  _(keepdims) \
  _(perm) \
  _(shape) \
  _(axes) \
  _(group) \
  _(inplace) \
  _(transA) \
  _(transB) \
  _(other) \
  _(__and__) \
  _(__lshift__) \
  _(__or__) \
  _(__rshift__) \
  _(__xor__) \
  _(abs) \
  _(acos) \
  _(asin) \
  _(atan) \
  _(atan2) \
  _(ceil) \
  _(clamp) \
  _(cos) \
  _(cosh) \
  _(div) \
  _(eq) \
  _(equal) \
  _(Exp) \
  _(ends) \
  _(expm1) \
  _(floor) \
  _(fmod) \
  _(frac) \
  _(ge) \
  _(gt) \
  _(le) \
  _(lerp) \
  _(lgamma) \
  _(Log) \
  _(log1p) \
  _(lt) \
  _(max) \
  _(min) \
  _(ne) \
  _(ones) \
  _(pow) \
  _(reciprocal) \
  _(remainder) \
  _(round) \
  _(rsqrt) \
  _(sin) \
  _(sinh) \
  _(Sqrt) \
  _(sub) \
  _(starts) \
  _(tan) \
  _(trunc) \
  _(zeros) \
  _(exponent) \
  _(device) \
  _(mode) \
  _(Identity) \
  _(Loop) \
  _(If) \
  _(body) \
  _(then_branch) \
  _(else_branch) \
  _(Captured) \
  _(__control_inputs) \
  _(count_include_pad) \
  _(storage_order) \
  _(Unsqueeze) \
  _(ReduceL1) \
  _(ReduceL2) \
  _(ReduceLogSum) \
  _(ReduceLogSumExp) \
  _(ReduceMax) \
  _(ReduceMean) \
  _(ReduceMin) \
  _(ReduceProd) \
  _(ReduceSum) \
  _(ReduceSumSquare) \
  _(Cast) \
  _(to) \
  _(PRelu) \
  _(Greater) \
  _(Less) \
  _(scales) \
  _(Upsample) \
  _(RNN) \
  _(layout) \
  _(k) \
  _(Flatten) \
  _(ScatterElements) \
  _(Resize) \
  _(ceil_mode) \
  _(num_outputs)

// One enumerator `k<name>` per FORALL_BUILTIN_SYMBOLS entry, numbered from 0
// in list order. kLastSymbol therefore equals the number of builtin symbols.
enum BuiltinSymbol {
#define DEFINE_SYMBOL(s) k##s,
  FORALL_BUILTIN_SYMBOLS(DEFINE_SYMBOL)
#undef DEFINE_SYMBOL
  kLastSymbol,
};

// A symbol is a lightweight handle wrapping a uint32_t. Values obtained from a
// BuiltinSymbol equal that enumerator's value.
struct Symbol {
  // Default-constructed symbols hold 0 (the value of the first builtin,
  // kspatial). Before this change the member was left uninitialized, and
  // reading it was undefined behavior.
  Symbol() : value(0) {}
  // Implicit (unlike the other converting constructors), so a BuiltinSymbol
  // can be passed wherever a Symbol is expected.
  Symbol(BuiltinSymbol value) : value(value) {}
  // Defined out of line; presumably looks up / interns `s` — confirm in the
  // implementation file.
  explicit Symbol(const std::string& s);
  // Wraps a raw value as-is; no range check is performed.
  explicit Symbol(uint32_t value) : value(value) {}

  // Implicit conversion to the raw value (used by the comparisons below and
  // by the std::hash specialization).
  operator uint32_t() const {
    return value;
  }
  // Defined out of line; returns the symbol's name as a C string.
  const char* toString() const;

 private:
  uint32_t value;
};

// Equality and inequality compare raw values.
//
// The mixed BuiltinSymbol/Symbol overloads are needed because, without them,
// `sym == kFoo` is ambiguous between (a) Symbol's implicit constructor and
// (b) its implicit uint32_t conversion feeding the built-in comparison.
//
// These operators are `inline` (external linkage) rather than `static inline`.
// A `static` function in a header gives every translation unit its own copy,
// which violates the ODR when it is called from an inline function defined in
// another header.
inline bool operator==(Symbol lhs, Symbol rhs) {
  return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
}

inline bool operator==(BuiltinSymbol lhs, Symbol rhs) {
  return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
}
inline bool operator==(Symbol lhs, BuiltinSymbol rhs) {
  return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
}

// Matching `!=` overloads. Without them, `sym != kFoo` falls back to a built-in
// signed/unsigned comparison reached through implicit conversions.
inline bool operator!=(Symbol lhs, Symbol rhs) {
  return static_cast<uint32_t>(lhs) != static_cast<uint32_t>(rhs);
}
inline bool operator!=(BuiltinSymbol lhs, Symbol rhs) {
  return static_cast<uint32_t>(lhs) != static_cast<uint32_t>(rhs);
}
inline bool operator!=(Symbol lhs, BuiltinSymbol rhs) {
  return static_cast<uint32_t>(lhs) != static_cast<uint32_t>(rhs);
}

// User-defined literal: `"Relu"_sym` is Symbol(std::string("Relu")).
// The literal's length is forwarded so embedded '\0' characters are kept
// instead of truncating the name at the first NUL. The operator is spelled
// without a space before the suffix because the spaced form is deprecated
// in C++23.
inline Symbol operator""_sym(const char* s, std::size_t len) {
  return Symbol(std::string(s, len));
}

}  // namespace ONNX_NAMESPACE
0228
0229
0230 namespace std {
0231 template <>
0232 struct hash<ONNX_NAMESPACE::Symbol> {
0233 std::size_t operator()(ONNX_NAMESPACE::Symbol s) const {
0234 return std::hash<uint32_t>()(static_cast<uint32_t>(s));
0235 }
0236 };
0237
0238 }