File indexing completed on 2025-01-19 09:51:39
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016 #ifndef EIGEN_BFLOAT16_H
0017 #define EIGEN_BFLOAT16_H
0018
0019 #define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
0020 template <> \
0021 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
0022 PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& _x) { \
0023 return F32ToBf16(METHOD<PACKET_F>(Bf16ToF32(_x))); \
0024 }
0025
0026 namespace Eigen {
0027
0028 struct bfloat16;
0029
0030 namespace bfloat16_impl {
0031
0032
0033 struct __bfloat16_raw {
0034 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
0035 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
0036 unsigned short value;
0037 };
0038
0039 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
0040 template <bool AssumeArgumentIsNormalOrInfinityOrZero>
0041 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
0042
0043
0044 template <>
0045 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
0046 template <>
0047 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
0048 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
0049
0050 struct bfloat16_base : public __bfloat16_raw {
0051 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
0052 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
0053 };
0054
0055 }
0056
0057
0058 struct bfloat16 : public bfloat16_impl::bfloat16_base {
0059
0060 typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
0061
0062 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
0063
0064 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
0065
0066 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
0067 : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
0068
0069 template<class T>
0070 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
0071 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
0072
0073 explicit EIGEN_DEVICE_FUNC bfloat16(float f)
0074 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
0075
0076
0077
0078 template<typename RealScalar>
0079 explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
0080 : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
0081
0082 EIGEN_DEVICE_FUNC operator float() const {
0083 return bfloat16_impl::bfloat16_to_float(*this);
0084 }
0085 };
0086 }
0087
0088 namespace std {
0089 template<>
0090 struct numeric_limits<Eigen::bfloat16> {
0091 static const bool is_specialized = true;
0092 static const bool is_signed = true;
0093 static const bool is_integer = false;
0094 static const bool is_exact = false;
0095 static const bool has_infinity = true;
0096 static const bool has_quiet_NaN = true;
0097 static const bool has_signaling_NaN = true;
0098 static const float_denorm_style has_denorm = std::denorm_absent;
0099 static const bool has_denorm_loss = false;
0100 static const std::float_round_style round_style = numeric_limits<float>::round_style;
0101 static const bool is_iec559 = false;
0102 static const bool is_bounded = true;
0103 static const bool is_modulo = false;
0104 static const int digits = 8;
0105 static const int digits10 = 2;
0106 static const int max_digits10 = 4;
0107 static const int radix = 2;
0108 static const int min_exponent = numeric_limits<float>::min_exponent;
0109 static const int min_exponent10 = numeric_limits<float>::min_exponent10;
0110 static const int max_exponent = numeric_limits<float>::max_exponent;
0111 static const int max_exponent10 = numeric_limits<float>::max_exponent10;
0112 static const bool traps = numeric_limits<float>::traps;
0113 static const bool tinyness_before = numeric_limits<float>::tinyness_before;
0114
0115 static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
0116 static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
0117 static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
0118 static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
0119 static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); }
0120 static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
0121 static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
0122 static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
0123 static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
0124 };
0125
0126
0127
0128
0129
0130 template<>
0131 struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
0132 template<>
0133 struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
0134 template<>
0135 struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
0136 }
0137
0138 namespace Eigen {
0139
0140 namespace bfloat16_impl {
0141
0142
0143
0144
0145 #if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
0146
0147 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
0148
0149 #pragma push_macro("EIGEN_DEVICE_FUNC")
0150 #undef EIGEN_DEVICE_FUNC
0151 #if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
0152 #define EIGEN_DEVICE_FUNC __host__
0153 #else
0154 #define EIGEN_DEVICE_FUNC __host__ __device__
0155 #endif
0156 #endif
0157
0158
0159
0160
0161 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
0162 return bfloat16(float(a) + float(b));
0163 }
0164 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
0165 return bfloat16(float(a) + static_cast<float>(b));
0166 }
0167 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
0168 return bfloat16(static_cast<float>(a) + float(b));
0169 }
0170 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
0171 return bfloat16(float(a) * float(b));
0172 }
0173 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
0174 return bfloat16(float(a) - float(b));
0175 }
0176 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
0177 return bfloat16(float(a) / float(b));
0178 }
0179 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
0180 bfloat16 result;
0181 result.value = a.value ^ 0x8000;
0182 return result;
0183 }
0184 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
0185 a = bfloat16(float(a) + float(b));
0186 return a;
0187 }
0188 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
0189 a = bfloat16(float(a) * float(b));
0190 return a;
0191 }
0192 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
0193 a = bfloat16(float(a) - float(b));
0194 return a;
0195 }
0196 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
0197 a = bfloat16(float(a) / float(b));
0198 return a;
0199 }
0200 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
0201 a += bfloat16(1);
0202 return a;
0203 }
0204 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
0205 a -= bfloat16(1);
0206 return a;
0207 }
0208 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
0209 bfloat16 original_value = a;
0210 ++a;
0211 return original_value;
0212 }
0213 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
0214 bfloat16 original_value = a;
0215 --a;
0216 return original_value;
0217 }
0218 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
0219 return numext::equal_strict(float(a),float(b));
0220 }
0221 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
0222 return numext::not_equal_strict(float(a), float(b));
0223 }
0224 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
0225 return float(a) < float(b);
0226 }
0227 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
0228 return float(a) <= float(b);
0229 }
0230 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
0231 return float(a) > float(b);
0232 }
0233 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
0234 return float(a) >= float(b);
0235 }
0236
0237 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
0238 #pragma pop_macro("EIGEN_DEVICE_FUNC")
0239 #endif
0240 #endif
0241
0242
0243
0244 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
0245 return bfloat16(static_cast<float>(a) / static_cast<float>(b));
0246 }
0247
0248 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
0249 __bfloat16_raw output;
0250 if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
0251 output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
0252 return output;
0253 }
0254 const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
0255 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
0256 output.value = p[0];
0257 #else
0258 output.value = p[1];
0259 #endif
0260 return output;
0261 }
0262
0263 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
0264 return __bfloat16_raw(value);
0265 }
0266
0267 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) {
0268 return bf.value;
0269 }
0270
0271
0272
0273 template <>
0274 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
0275 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
0276
0277 #else
0278 __bfloat16_raw output;
0279
0280 if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
0281
0282
0283
0284
0285
0286 output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
0287 } else {
0288
0289
0290
0291
0292
0293
0294
0295
0296
0297
0298
0299
0300
0301
0302
0303
0304
0305
0306
0307
0308
0309
0310
0311
0312
0313
0314
0315
0316
0317
0318
0319
0320
0321
0322
0323
0324
0325
0326
0327
0328
0329
0330
0331
0332
0333
0334
0335
0336
0337
0338
0339
0340
0341
0342
0343
0344
0345
0346
0347
0348
0349
0350
0351
0352
0353
0354
0355
0356
0357
0358
0359
0360
0361
0362
0363
0364
0365
0366
0367
0368
0369
0370
0371
0372
0373
0374
0375
0376
0377
0378
0379
0380
0381
0382
0383
0384
0385
0386
0387
0388
0389
0390
0391
0392
0393
0394
0395
0396
0397
0398
0399
0400
0401
0402
0403
0404
0405
0406
0407
0408
0409
0410
0411
0412
0413
0414
0415
0416
0417
0418
0419
0420
0421
0422
0423
0424
0425
0426
0427
0428
0429
0430
0431
0432
0433
0434
0435
0436
0437 output = float_to_bfloat16_rtne<true>(ff);
0438 }
0439 return output;
0440 #endif
0441 }
0442
0443
0444
0445
0446
0447 template <>
0448 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
0449 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
0450
0451 #else
0452 numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
0453 __bfloat16_raw output;
0454
0455
0456 numext::uint32_t lsb = (input >> 16) & 1;
0457 numext::uint32_t rounding_bias = 0x7fff + lsb;
0458 input += rounding_bias;
0459 output.value = static_cast<numext::uint16_t>(input >> 16);
0460 return output;
0461 #endif
0462 }
0463
0464 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
0465 float result = 0;
0466 unsigned short* q = reinterpret_cast<unsigned short*>(&result);
0467 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
0468 q[0] = h.value;
0469 #else
0470 q[1] = h.value;
0471 #endif
0472 return result;
0473 }
0474
0475
0476 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
0477 EIGEN_USING_STD(isinf);
0478 return (isinf)(float(a));
0479 }
0480 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
0481 EIGEN_USING_STD(isnan);
0482 return (isnan)(float(a));
0483 }
0484 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
0485 return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
0486 }
0487
0488 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
0489 bfloat16 result;
0490 result.value = a.value & 0x7FFF;
0491 return result;
0492 }
0493 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
0494 return bfloat16(::expf(float(a)));
0495 }
0496 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
0497 return bfloat16(numext::expm1(float(a)));
0498 }
0499 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
0500 return bfloat16(::logf(float(a)));
0501 }
0502 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
0503 return bfloat16(numext::log1p(float(a)));
0504 }
0505 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
0506 return bfloat16(::log10f(float(a)));
0507 }
0508 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
0509 return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
0510 }
0511 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
0512 return bfloat16(::sqrtf(float(a)));
0513 }
0514 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
0515 return bfloat16(::powf(float(a), float(b)));
0516 }
0517 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
0518 return bfloat16(::sinf(float(a)));
0519 }
0520 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
0521 return bfloat16(::cosf(float(a)));
0522 }
0523 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
0524 return bfloat16(::tanf(float(a)));
0525 }
0526 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
0527 return bfloat16(::asinf(float(a)));
0528 }
0529 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
0530 return bfloat16(::acosf(float(a)));
0531 }
0532 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
0533 return bfloat16(::atanf(float(a)));
0534 }
0535 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
0536 return bfloat16(::sinhf(float(a)));
0537 }
0538 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
0539 return bfloat16(::coshf(float(a)));
0540 }
0541 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
0542 return bfloat16(::tanhf(float(a)));
0543 }
0544 #if EIGEN_HAS_CXX11_MATH
0545 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
0546 return bfloat16(::asinhf(float(a)));
0547 }
0548 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
0549 return bfloat16(::acoshf(float(a)));
0550 }
0551 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
0552 return bfloat16(::atanhf(float(a)));
0553 }
0554 #endif
0555 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
0556 return bfloat16(::floorf(float(a)));
0557 }
0558 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
0559 return bfloat16(::ceilf(float(a)));
0560 }
0561 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) {
0562 return bfloat16(::rintf(float(a)));
0563 }
0564 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) {
0565 return bfloat16(::roundf(float(a)));
0566 }
0567 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
0568 return bfloat16(::fmodf(float(a), float(b)));
0569 }
0570
0571 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
0572 const float f1 = static_cast<float>(a);
0573 const float f2 = static_cast<float>(b);
0574 return f2 < f1 ? b : a;
0575 }
0576 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
0577 const float f1 = static_cast<float>(a);
0578 const float f2 = static_cast<float>(b);
0579 return f1 < f2 ? b : a;
0580 }
0581
0582 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
0583 const float f1 = static_cast<float>(a);
0584 const float f2 = static_cast<float>(b);
0585 return bfloat16(::fminf(f1, f2));
0586 }
0587 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
0588 const float f1 = static_cast<float>(a);
0589 const float f2 = static_cast<float>(b);
0590 return bfloat16(::fmaxf(f1, f2));
0591 }
0592
0593 #ifndef EIGEN_NO_IO
0594 EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
0595 os << static_cast<float>(v);
0596 return os;
0597 }
0598 #endif
0599
0600 }
0601
0602 namespace internal {
0603
0604 template<>
0605 struct random_default_impl<bfloat16, false, false>
0606 {
0607 static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
0608 {
0609 return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
0610 }
0611 static inline bfloat16 run()
0612 {
0613 return run(bfloat16(-1.f), bfloat16(1.f));
0614 }
0615 };
0616
0617 template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
0618
0619 }
0620
0621 template<> struct NumTraits<Eigen::bfloat16>
0622 : GenericNumTraits<Eigen::bfloat16>
0623 {
0624 enum {
0625 IsSigned = true,
0626 IsInteger = false,
0627 IsComplex = false,
0628 RequireInitialization = false
0629 };
0630
0631 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
0632 return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
0633 }
0634 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
0635 return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);
0636
0637 }
0638 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
0639 return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
0640 }
0641 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
0642 return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
0643 }
0644 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
0645 return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
0646 }
0647 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
0648 return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
0649 }
0650 };
0651
0652 }
0653
0654 namespace Eigen {
0655 namespace numext {
0656
0657 template<>
0658 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
0659 bool (isnan)(const Eigen::bfloat16& h) {
0660 return (bfloat16_impl::isnan)(h);
0661 }
0662
0663 template<>
0664 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
0665 bool (isinf)(const Eigen::bfloat16& h) {
0666 return (bfloat16_impl::isinf)(h);
0667 }
0668
0669 template<>
0670 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
0671 bool (isfinite)(const Eigen::bfloat16& h) {
0672 return (bfloat16_impl::isfinite)(h);
0673 }
0674
0675 template <>
0676 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
0677 return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
0678 }
0679
0680 template <>
0681 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
0682 return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
0683 }
0684
0685 }
0686 }
0687
0688 #if EIGEN_HAS_STD_HASH
0689 namespace std {
0690 template <>
0691 struct hash<Eigen::bfloat16> {
0692 EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
0693 return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
0694 }
0695 };
0696 }
0697 #endif
0698
0699
0700 #endif