Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:44

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // ATTENTION: The code in this file is highly EXPERIMENTAL.
0008 // Adventurous users should note that the APIs will probably change.
0009 
0010 #pragma once
0011 #include <stdint.h>
0012 
0013 #include <string>
0014 #include <unordered_map>
0015 #include <vector>
0016 
0017 namespace ONNX_NAMESPACE {
0018 
// X-macro enumerating the name of every built-in symbol.
//
// FORALL_BUILTIN_SYMBOLS(APPLY) expands to APPLY(name) once for each name
// below, in exactly the order written here. The expansion contains no
// separators, so the macro passed in must supply any trailing ',' or ';'
// it needs.
#define FORALL_BUILTIN_SYMBOLS(APPLY) \
  APPLY(spatial) \
  APPLY(select_last_index) \
  APPLY(coordinate_transformation_mode) \
  APPLY(PythonOp) \
  APPLY(CppOp) \
  APPLY(Param) \
  APPLY(Select) \
  APPLY(Return) \
  APPLY(Eval) \
  APPLY(add) \
  APPLY(Add) \
  APPLY(Div) \
  APPLY(Mul) \
  APPLY(Neg) \
  APPLY(Sub) \
  APPLY(Pow) \
  APPLY(Sigmoid) \
  APPLY(ArgMax) \
  APPLY(Concat) \
  APPLY(Softmax) \
  APPLY(LogSoftmax) \
  APPLY(Dropout) \
  APPLY(Tanh) \
  APPLY(mul) \
  APPLY(neg) \
  APPLY(sigmoid) \
  APPLY(tanh) \
  APPLY(Constant) \
  APPLY(cat) \
  APPLY(Slice) \
  APPLY(Squeeze) \
  APPLY(Undefined) \
  APPLY(FusionGroup) \
  APPLY(MatMul) \
  APPLY(Gemm) \
  APPLY(Tile) \
  APPLY(SubConstant) \
  APPLY(Scale) \
  APPLY(Transpose) \
  APPLY(Pad) \
  APPLY(Reshape) \
  APPLY(split) \
  APPLY(chunk) \
  APPLY(Offset) \
  APPLY(value) \
  APPLY(Subgraph) \
  APPLY(BatchNormalization) \
  APPLY(Conv) \
  APPLY(ConvTranspose) \
  APPLY(is_test) \
  APPLY(epsilon) \
  APPLY(expand) \
  APPLY(Expand) \
  APPLY(order) \
  APPLY(momentum) \
  APPLY(consumed_inputs) \
  APPLY(kernels) \
  APPLY(kernel_shape) \
  APPLY(kernel) \
  APPLY(scale) \
  APPLY(strides) \
  APPLY(stride) \
  APPLY(pads) \
  APPLY(pad) \
  APPLY(beta) \
  APPLY(alpha) \
  APPLY(dilations) \
  APPLY(dilation) \
  APPLY(broadcast) \
  APPLY(axis) \
  APPLY(ratio) \
  APPLY(size) \
  APPLY(dim) \
  APPLY(keepdims) \
  APPLY(perm) \
  APPLY(shape) \
  APPLY(axes) \
  APPLY(group) \
  APPLY(inplace) \
  APPLY(transA) \
  APPLY(transB) \
  APPLY(other) \
  APPLY(__and__) \
  APPLY(__lshift__) \
  APPLY(__or__) \
  APPLY(__rshift__) \
  APPLY(__xor__) \
  APPLY(abs) \
  APPLY(acos) \
  APPLY(asin) \
  APPLY(atan) \
  APPLY(atan2) \
  APPLY(ceil) \
  APPLY(clamp) \
  APPLY(cos) \
  APPLY(cosh) \
  APPLY(div) \
  APPLY(eq) \
  APPLY(equal) \
  APPLY(Exp) \
  APPLY(ends) \
  APPLY(expm1) \
  APPLY(floor) \
  APPLY(fmod) \
  APPLY(frac) \
  APPLY(ge) \
  APPLY(gt) \
  APPLY(le) \
  APPLY(lerp) \
  APPLY(lgamma) \
  APPLY(Log) \
  APPLY(log1p) \
  APPLY(lt) \
  APPLY(max) \
  APPLY(min) \
  APPLY(ne) \
  APPLY(ones) \
  APPLY(pow) \
  APPLY(reciprocal) \
  APPLY(remainder) \
  APPLY(round) \
  APPLY(rsqrt) \
  APPLY(sin) \
  APPLY(sinh) \
  APPLY(Sqrt) \
  APPLY(sub) \
  APPLY(starts) \
  APPLY(tan) \
  APPLY(trunc) \
  APPLY(zeros) \
  APPLY(exponent) \
  APPLY(device) \
  APPLY(mode) \
  APPLY(Identity) \
  APPLY(Loop) \
  APPLY(If) \
  APPLY(body) \
  APPLY(then_branch) \
  APPLY(else_branch) \
  APPLY(Captured) \
  APPLY(__control_inputs) \
  APPLY(count_include_pad) \
  APPLY(storage_order) \
  APPLY(Unsqueeze) \
  APPLY(ReduceL1) \
  APPLY(ReduceL2) \
  APPLY(ReduceLogSum) \
  APPLY(ReduceLogSumExp) \
  APPLY(ReduceMax) \
  APPLY(ReduceMean) \
  APPLY(ReduceMin) \
  APPLY(ReduceProd) \
  APPLY(ReduceSum) \
  APPLY(ReduceSumSquare) \
  APPLY(Cast) \
  APPLY(to) \
  APPLY(PRelu) \
  APPLY(Greater) \
  APPLY(Less) \
  APPLY(scales) \
  APPLY(Upsample) \
  APPLY(RNN) \
  APPLY(layout) \
  APPLY(k) \
  APPLY(Flatten) \
  APPLY(ScatterElements) \
  APPLY(Resize) \
  APPLY(ceil_mode) \
  APPLY(num_outputs)
0189 
0190 enum BuiltinSymbol {
0191 #define DEFINE_SYMBOL(s) k##s,
0192   FORALL_BUILTIN_SYMBOLS(DEFINE_SYMBOL)
0193 #undef DEFINE_SYMBOL
0194       kLastSymbol, // where we start counting for new symbols
0195 };
0196 
0197 struct Symbol {
0198   Symbol() {}
0199   /*implicit*/ Symbol(BuiltinSymbol value) : value(value) {}
0200   explicit Symbol(const std::string& s);
0201   explicit Symbol(uint32_t value) : value(value) {}
0202 
0203   operator uint32_t() const {
0204     return value;
0205   }
0206   const char* toString() const;
0207 
0208  private:
0209   uint32_t value;
0210 };
0211 
0212 static inline bool operator==(Symbol lhs, Symbol rhs) {
0213   return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
0214 }
0215 // necessary to prevent ambiguous overload resolutions
0216 static inline bool operator==(BuiltinSymbol lhs, Symbol rhs) {
0217   return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
0218 }
0219 static inline bool operator==(Symbol lhs, BuiltinSymbol rhs) {
0220   return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
0221 }
0222 
0223 inline Symbol operator"" _sym(const char* s, size_t) {
0224   return Symbol(s);
0225 }
0226 
0227 } // namespace ONNX_NAMESPACE
0228 
0229 // make symbol behave like an integer in hash tables
0230 namespace std {
0231 template <>
0232 struct hash<ONNX_NAMESPACE::Symbol> {
0233   std::size_t operator()(ONNX_NAMESPACE::Symbol s) const {
0234     return std::hash<uint32_t>()(static_cast<uint32_t>(s));
0235   }
0236 };
0237 
0238 } // namespace std