Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-19 09:51:39

0001 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
0002 
0003 Licensed under the Apache License, Version 2.0 (the "License");
0004 you may not use this file except in compliance with the License.
0005 You may obtain a copy of the License at
0006 
0007     http://www.apache.org/licenses/LICENSE-2.0
0008 
0009 Unless required by applicable law or agreed to in writing, software
0010 distributed under the License is distributed on an "AS IS" BASIS,
0011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0012 See the License for the specific language governing permissions and
0013 limitations under the License.
0014 ==============================================================================*/
0015 
0016 #ifndef EIGEN_BFLOAT16_H
0017 #define EIGEN_BFLOAT16_H
0018 
// BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) defines the bfloat16
// packet specialization METHOD<PACKET_BF16> in terms of the float packet
// version: the argument is widened with Bf16ToF32, METHOD<PACKET_F> is
// evaluated on it, and the result is narrowed back with F32ToBf16.
// Bf16ToF32 / F32ToBf16 are not defined in this header; they must be visible
// wherever the macro is expanded.
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD)            \
  template <>                                                          \
  EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED     \
  PACKET_BF16 METHOD<PACKET_BF16>(const PACKET_BF16& bf16_in) {        \
    const PACKET_F as_float = Bf16ToF32(bf16_in);                      \
    return F32ToBf16(METHOD<PACKET_F>(as_float));                      \
  }
0025 
0026 namespace Eigen {
0027 
0028 struct bfloat16;
0029 
0030 namespace bfloat16_impl {
0031 
0032 // Make our own __bfloat16_raw definition.
0033 struct __bfloat16_raw {
0034   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {}
0035   explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {}
0036   unsigned short value;
0037 };
0038 
0039 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
0040 template <bool AssumeArgumentIsNormalOrInfinityOrZero>
0041 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
0042 // Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
0043 // > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
0044 template <>
0045 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
0046 template <>
0047 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
0048 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
0049 
0050 struct bfloat16_base : public __bfloat16_raw {
0051   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {}
0052   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {}
0053 };
0054 
0055 } // namespace bfloat16_impl
0056 
0057 // Class definition.
0058 struct bfloat16 : public bfloat16_impl::bfloat16_base {
0059 
0060   typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw;
0061 
0062   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {}
0063 
0064   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {}
0065 
0066   explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b)
0067       : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
0068 
0069   template<class T>
0070   explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val)
0071       : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
0072 
0073   explicit EIGEN_DEVICE_FUNC bfloat16(float f)
0074       : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
0075 
0076   // Following the convention of numpy, converting between complex and
0077   // float will lead to loss of imag value.
0078   template<typename RealScalar>
0079   explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex<RealScalar>& val)
0080       : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
0081 
0082   EIGEN_DEVICE_FUNC operator float() const {  // NOLINT: Allow implicit conversion to float, because it is lossless.
0083     return bfloat16_impl::bfloat16_to_float(*this);
0084   }
0085 };
0086 } // namespace Eigen
0087 
0088 namespace std {
0089 template<>
0090 struct numeric_limits<Eigen::bfloat16> {
0091   static const bool is_specialized = true;
0092   static const bool is_signed = true;
0093   static const bool is_integer = false;
0094   static const bool is_exact = false;
0095   static const bool has_infinity = true;
0096   static const bool has_quiet_NaN = true;
0097   static const bool has_signaling_NaN = true;
0098   static const float_denorm_style has_denorm = std::denorm_absent;
0099   static const bool has_denorm_loss = false;
0100   static const std::float_round_style round_style = numeric_limits<float>::round_style;
0101   static const bool is_iec559 = false;
0102   static const bool is_bounded = true;
0103   static const bool is_modulo = false;
0104   static const int digits = 8;
0105   static const int digits10 = 2;
0106   static const int max_digits10 = 4;
0107   static const int radix = 2;
0108   static const int min_exponent = numeric_limits<float>::min_exponent;
0109   static const int min_exponent10 = numeric_limits<float>::min_exponent10;
0110   static const int max_exponent = numeric_limits<float>::max_exponent;
0111   static const int max_exponent10 = numeric_limits<float>::max_exponent10;
0112   static const bool traps = numeric_limits<float>::traps;
0113   static const bool tinyness_before = numeric_limits<float>::tinyness_before;
0114 
0115   static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); }
0116   static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); }
0117   static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); }
0118   static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); }
0119   static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); }
0120   static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); }
0121   static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); }
0122   static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); }
0123   static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); }
0124 };
0125 
0126 // If std::numeric_limits<T> is specialized, should also specialize
0127 // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
0128 // std::numeric_limits<const volatile T>
0129 // https://stackoverflow.com/a/16519653/
0130 template<>
0131 struct numeric_limits<const Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
0132 template<>
0133 struct numeric_limits<volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
0134 template<>
0135 struct numeric_limits<const volatile Eigen::bfloat16> : numeric_limits<Eigen::bfloat16> {};
0136 } // namespace std
0137 
0138 namespace Eigen {
0139 
0140 namespace bfloat16_impl {
0141 
// We need to distinguish 'clang as the CUDA compiler' from 'clang as the host
// compiler invoked by NVCC' (e.g. on macOS). The former needs to see both the
// host and the device implementations of these functions, while the latter can
// only deal with one of them.
0145 #if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats
0146 
0147 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
0148 // We need to provide emulated *host-side* BF16 operators for clang.
0149 #pragma push_macro("EIGEN_DEVICE_FUNC")
0150 #undef EIGEN_DEVICE_FUNC
0151 #if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16)
0152 #define EIGEN_DEVICE_FUNC __host__
0153 #else // both host and device need emulated ops.
0154 #define EIGEN_DEVICE_FUNC __host__ __device__
0155 #endif
0156 #endif
0157 
0158 // Definitions for CPUs, mostly working through conversion
0159 // to/from fp32.
0160 
0161 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) {
0162   return bfloat16(float(a) + float(b));
0163 }
0164 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) {
0165   return bfloat16(float(a) + static_cast<float>(b));
0166 }
0167 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) {
0168   return bfloat16(static_cast<float>(a) + float(b));
0169 }
0170 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) {
0171   return bfloat16(float(a) * float(b));
0172 }
0173 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) {
0174   return bfloat16(float(a) - float(b));
0175 }
0176 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) {
0177   return bfloat16(float(a) / float(b));
0178 }
0179 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) {
0180   bfloat16 result;
0181   result.value = a.value ^ 0x8000;
0182   return result;
0183 }
0184 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) {
0185   a = bfloat16(float(a) + float(b));
0186   return a;
0187 }
0188 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) {
0189   a = bfloat16(float(a) * float(b));
0190   return a;
0191 }
0192 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) {
0193   a = bfloat16(float(a) - float(b));
0194   return a;
0195 }
0196 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) {
0197   a = bfloat16(float(a) / float(b));
0198   return a;
0199 }
0200 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) {
0201   a += bfloat16(1);
0202   return a;
0203 }
0204 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) {
0205   a -= bfloat16(1);
0206   return a;
0207 }
0208 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) {
0209   bfloat16 original_value = a;
0210   ++a;
0211   return original_value;
0212 }
0213 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) {
0214   bfloat16 original_value = a;
0215   --a;
0216   return original_value;
0217 }
0218 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) {
0219   return numext::equal_strict(float(a),float(b));
0220 }
0221 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) {
0222   return numext::not_equal_strict(float(a), float(b));
0223 }
0224 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) {
0225   return float(a) < float(b);
0226 }
0227 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) {
0228   return float(a) <= float(b);
0229 }
0230 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) {
0231   return float(a) > float(b);
0232 }
0233 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) {
0234   return float(a) >= float(b);
0235 }
0236 
0237 #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
0238 #pragma pop_macro("EIGEN_DEVICE_FUNC")
0239 #endif
0240 #endif  // Emulate support for bfloat16 floats
0241 
0242 // Division by an index. Do it in full float precision to avoid accuracy
0243 // issues in converting the denominator to bfloat16.
0244 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) {
0245   return bfloat16(static_cast<float>(a) / static_cast<float>(b));
0246 }
0247 
0248 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) {
0249   __bfloat16_raw output;
0250   if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
0251     output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
0252     return output;
0253   }
0254   const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
0255 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
0256   output.value = p[0];
0257 #else
0258   output.value = p[1];
0259 #endif
0260   return output;
0261 }
0262 
0263 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) {
0264   return __bfloat16_raw(value);
0265 }
0266 
0267 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) {
0268   return bf.value;
0269 }
0270 
0271 // float_to_bfloat16_rtne template specialization that does not make any
0272 // assumption about the value of its function argument (ff).
0273 template <>
0274 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
0275 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
0276   // Nothing to do here
0277 #else
0278   __bfloat16_raw output;
0279 
0280   if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
0281     // If the value is a NaN, squash it to a qNaN with msb of fraction set,
0282     // this makes sure after truncation we don't end up with an inf.
0283     //
0284     // qNaN magic: All exponent bits set + most significant bit of fraction
0285     // set.
0286     output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
0287   } else {
0288     // Fast rounding algorithm that rounds a half value to nearest even. This
0289     // reduces expected error when we convert a large number of floats. Here
0290     // is how it works:
0291     //
0292     // Definitions:
0293     // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
0294     // with the following tags:
0295     //
0296     // Sign |  Exp (8 bits) | Frac (23 bits)
0297     //  S     EEEEEEEE         FFFFFFLRTTTTTTTTTTTTTTT
0298     //
0299     //  S: Sign bit.
0300     //  E: Exponent bits.
0301     //  F: First 6 bits of fraction.
0302     //  L: Least significant bit of resulting bfloat16 if we truncate away the
0303     //  rest of the float32. This is also the 7th bit of fraction
0304     //  R: Rounding bit, 8th bit of fraction.
0305     //  T: Sticky bits, rest of fraction, 15 bits.
0306     //
0307     // To round half to nearest even, there are 3 cases where we want to round
0308     // down (simply truncate the result of the bits away, which consists of
0309     // rounding bit and sticky bits) and two cases where we want to round up
0310     // (truncate then add one to the result).
0311     //
0312     // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
0313     // 1s) as the rounding bias, adds the rounding bias to the input, then
0314     // truncates the last 16 bits away.
0315     //
0316     // To understand how it works, we can analyze this algorithm case by case:
0317     //
0318     // 1. L = 0, R = 0:
0319     //   Expect: round down, this is less than half value.
0320     //
0321     //   Algorithm:
0322     //   - Rounding bias: 0x7fff + 0 = 0x7fff
0323     //   - Adding rounding bias to input may create any carry, depending on
0324     //   whether there is any value set to 1 in T bits.
0325     //   - R may be set to 1 if there is a carry.
0326     //   - L remains 0.
0327     //   - Note that this case also handles Inf and -Inf, where all fraction
0328     //   bits, including L, R and Ts are all 0. The output remains Inf after
0329     //   this algorithm.
0330     //
0331     // 2. L = 1, R = 0:
0332     //   Expect: round down, this is less than half value.
0333     //
0334     //   Algorithm:
0335     //   - Rounding bias: 0x7fff + 1 = 0x8000
0336     //   - Adding rounding bias to input doesn't change sticky bits but
0337     //   adds 1 to rounding bit.
0338     //   - L remains 1.
0339     //
0340     // 3. L = 0, R = 1, all of T are 0:
0341     //   Expect: round down, this is exactly at half, the result is already
0342     //   even (L=0).
0343     //
0344     //   Algorithm:
0345     //   - Rounding bias: 0x7fff + 0 = 0x7fff
0346     //   - Adding rounding bias to input sets all sticky bits to 1, but
0347     //   doesn't create a carry.
0348     //   - R remains 1.
0349     //   - L remains 0.
0350     //
0351     // 4. L = 1, R = 1:
0352     //   Expect: round up, this is exactly at half, the result needs to be
0353     //   round to the next even number.
0354     //
0355     //   Algorithm:
0356     //   - Rounding bias: 0x7fff + 1 = 0x8000
0357     //   - Adding rounding bias to input doesn't change sticky bits, but
0358     //   creates a carry from rounding bit.
0359     //   - The carry sets L to 0, creates another carry bit and propagate
0360     //   forward to F bits.
0361     //   - If all the F bits are 1, a carry then propagates to the exponent
0362     //   bits, which then creates the minimum value with the next exponent
0363     //   value. Note that we won't have the case where exponents are all 1,
0364     //   since that's either a NaN (handled in the other if condition) or inf
0365     //   (handled in case 1).
0366     //
0367     // 5. L = 0, R = 1, any of T is 1:
0368     //   Expect: round up, this is greater than half.
0369     //
0370     //   Algorithm:
0371     //   - Rounding bias: 0x7fff + 0 = 0x7fff
0372     //   - Adding rounding bias to input creates a carry from sticky bits,
0373     //   sets rounding bit to 0, then create another carry.
0374     //   - The second carry sets L to 1.
0375     //
0376     // Examples:
0377     //
0378     //  Exact half value that is already even:
0379     //    Input:
0380     //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
0381     //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
0382     //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0     1000000000000000
0383     //
0384     //     This falls into case 3. We truncate the rest of 16 bits and no
0385     //     carry is created into F and L:
0386     //
0387     //    Output:
0388     //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
0389     //     S     E E E E E E E E      F F F F F F L
0390     //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
0391     //
0392     //  Exact half value, round to next even number:
0393     //    Input:
0394     //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
0395     //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
0396     //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 0 1     1000000000000000
0397     //
0398     //     This falls into case 4. We create a carry from R and T,
0399     //     which then propagates into L and F:
0400     //
0401     //    Output:
0402     //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
0403     //     S     E E E E E E E E      F F F F F F L
0404     //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
0405     //
0406     //
0407     //  Max denormal value round to min normal value:
0408     //    Input:
0409     //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
0410     //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
0411     //     0     0 0 0 0 0 0 0 0      1 1 1 1 1 1 1     1111111111111111
0412     //
0413     //     This falls into case 4. We create a carry from R and T,
0414     //     propagate into L and F, which then propagates into exponent
0415     //     bits:
0416     //
0417     //    Output:
0418     //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
0419     //     S     E E E E E E E E      F F F F F F L
0420     //     0     0 0 0 0 0 0 0 1      0 0 0 0 0 0 0
0421     //
0422     //  Max normal value round to Inf:
0423     //    Input:
0424     //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
0425     //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
0426     //     0     1 1 1 1 1 1 1 0      1 1 1 1 1 1 1     1111111111111111
0427     //
0428     //     This falls into case 4. We create a carry from R and T,
0429     //     propagate into L and F, which then propagates into exponent
0430     //     bits:
0431     //
0432     //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
0433     //     S     E E E E E E E E      F F F F F F L
0434     //     0     1 1 1 1 1 1 1 1      0 0 0 0 0 0 0
0435 
0436     // At this point, ff must be either a normal float, or +/-infinity.
0437     output = float_to_bfloat16_rtne<true>(ff);
0438   }
0439   return output;
0440 #endif
0441 }
0442 
0443 // float_to_bfloat16_rtne template specialization that assumes that its function
0444 // argument (ff) is either a normal floating point number, or +/-infinity, or
0445 // zero. Used to improve the runtime performance of conversion from an integer
0446 // type to bfloat16.
0447 template <>
0448 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
0449 #if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
0450     // Nothing to do here
0451 #else
0452     numext::uint32_t input = numext::bit_cast<numext::uint32_t>(ff);
0453     __bfloat16_raw output;
0454 
0455     // Least significant bit of resulting bfloat.
0456     numext::uint32_t lsb = (input >> 16) & 1;
0457     numext::uint32_t rounding_bias = 0x7fff + lsb;
0458     input += rounding_bias;
0459     output.value = static_cast<numext::uint16_t>(input >> 16);
0460     return output;
0461 #endif
0462 }
0463 
0464 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) {
0465     float result = 0;
0466     unsigned short* q = reinterpret_cast<unsigned short*>(&result);
0467 #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
0468     q[0] = h.value;
0469 #else
0470     q[1] = h.value;
0471 #endif
0472     return result;
0473 }
0474 // --- standard functions ---
0475 
0476 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) {
0477   EIGEN_USING_STD(isinf);
0478   return (isinf)(float(a));
0479 }
0480 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) {
0481   EIGEN_USING_STD(isnan);
0482   return (isnan)(float(a));
0483 }
0484 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) {
0485   return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
0486 }
0487 
0488 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) {
0489   bfloat16 result;
0490   result.value = a.value & 0x7FFF;
0491   return result;
0492 }
0493 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) {
0494    return bfloat16(::expf(float(a)));
0495 }
0496 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) {
0497   return bfloat16(numext::expm1(float(a)));
0498 }
0499 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) {
0500   return bfloat16(::logf(float(a)));
0501 }
0502 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) {
0503   return bfloat16(numext::log1p(float(a)));
0504 }
0505 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) {
0506   return bfloat16(::log10f(float(a)));
0507 }
0508 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) {
0509   return bfloat16(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
0510 }
0511 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
0512     return bfloat16(::sqrtf(float(a)));
0513 }
0514 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
0515   return bfloat16(::powf(float(a), float(b)));
0516 }
0517 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
0518   return bfloat16(::sinf(float(a)));
0519 }
0520 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) {
0521   return bfloat16(::cosf(float(a)));
0522 }
0523 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) {
0524   return bfloat16(::tanf(float(a)));
0525 }
0526 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) {
0527   return bfloat16(::asinf(float(a)));
0528 }
0529 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& a) {
0530   return bfloat16(::acosf(float(a)));
0531 }
0532 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) {
0533   return bfloat16(::atanf(float(a)));
0534 }
0535 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) {
0536   return bfloat16(::sinhf(float(a)));
0537 }
0538 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) {
0539   return bfloat16(::coshf(float(a)));
0540 }
0541 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) {
0542   return bfloat16(::tanhf(float(a)));
0543 }
0544 #if EIGEN_HAS_CXX11_MATH
0545 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) {
0546   return bfloat16(::asinhf(float(a)));
0547 }
0548 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) {
0549   return bfloat16(::acoshf(float(a)));
0550 }
0551 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) {
0552   return bfloat16(::atanhf(float(a)));
0553 }
0554 #endif
0555 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) {
0556   return bfloat16(::floorf(float(a)));
0557 }
0558 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) {
0559   return bfloat16(::ceilf(float(a)));
0560 }
0561 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) {
0562   return bfloat16(::rintf(float(a)));
0563 }
0564 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) {
0565   return bfloat16(::roundf(float(a)));
0566 }
0567 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) {
0568   return bfloat16(::fmodf(float(a), float(b)));
0569 }
0570 
0571 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) {
0572   const float f1 = static_cast<float>(a);
0573   const float f2 = static_cast<float>(b);
0574   return f2 < f1 ? b : a;
0575 }
0576 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) {
0577   const float f1 = static_cast<float>(a);
0578   const float f2 = static_cast<float>(b);
0579   return f1 < f2 ? b : a;
0580 }
0581 
0582 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) {
0583   const float f1 = static_cast<float>(a);
0584   const float f2 = static_cast<float>(b);
0585   return bfloat16(::fminf(f1, f2));
0586 }
0587 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) {
0588   const float f1 = static_cast<float>(a);
0589   const float f2 = static_cast<float>(b);
0590   return bfloat16(::fmaxf(f1, f2));
0591 }
0592 
0593 #ifndef EIGEN_NO_IO
0594 EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) {
0595   os << static_cast<float>(v);
0596   return os;
0597 }
0598 #endif
0599 
0600 } // namespace bfloat16_impl
0601 
0602 namespace internal {
0603 
0604 template<>
0605 struct random_default_impl<bfloat16, false, false>
0606 {
0607   static inline bfloat16 run(const bfloat16& x, const bfloat16& y)
0608   {
0609     return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX));
0610   }
0611   static inline bfloat16 run()
0612   {
0613     return run(bfloat16(-1.f), bfloat16(1.f));
0614   }
0615 };
0616 
0617 template<> struct is_arithmetic<bfloat16> { enum { value = true }; };
0618 
0619 } // namespace internal
0620 
0621 template<> struct NumTraits<Eigen::bfloat16>
0622     : GenericNumTraits<Eigen::bfloat16>
0623 {
0624   enum {
0625     IsSigned = true,
0626     IsInteger = false,
0627     IsComplex = false,
0628     RequireInitialization = false
0629   };
0630 
0631   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() {
0632     return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00);
0633   }
0634   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() {
0635     return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D);  // bfloat16(5e-2f);
0636 
0637   }
0638   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() {
0639     return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F);
0640   }
0641   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() {
0642     return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F);
0643   }
0644   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() {
0645     return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80);
0646   }
0647   EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() {
0648     return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0);
0649   }
0650 };
0651 
0652 } // namespace Eigen
0653 
0654 namespace Eigen {
0655 namespace numext {
0656 
0657 template<>
0658 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
0659 bool (isnan)(const Eigen::bfloat16& h) {
0660   return (bfloat16_impl::isnan)(h);
0661 }
0662 
0663 template<>
0664 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
0665 bool (isinf)(const Eigen::bfloat16& h) {
0666   return (bfloat16_impl::isinf)(h);
0667 }
0668 
0669 template<>
0670 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
0671 bool (isfinite)(const Eigen::bfloat16& h) {
0672   return (bfloat16_impl::isfinite)(h);
0673 }
0674 
0675 template <>
0676 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast<Eigen::bfloat16, uint16_t>(const uint16_t& src) {
0677   return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src));
0678 }
0679 
0680 template <>
0681 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat16>(const Eigen::bfloat16& src) {
0682   return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
0683 }
0684 
0685 }  // namespace numext
0686 }  // namespace Eigen
0687 
#if EIGEN_HAS_STD_HASH
namespace std {
// Hashes a bfloat16 by its raw 16-bit pattern.
// +0.0 (0x0000) and -0.0 (0x8000) compare equal, because comparisons are done
// on float values. The Hash requirements say equal keys must hash equally, so
// both zeros are folded to hash 0. Otherwise unordered containers could hold
// two "equal" keys in different buckets. NaN never compares equal, so its bit
// patterns can hash freely.
template <>
struct hash<Eigen::bfloat16> {
  EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const {
    const Eigen::numext::uint16_t bits = Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a);
    if ((bits & 0x7FFF) == 0) {
      return static_cast<std::size_t>(0);
    }
    return static_cast<std::size_t>(bits);
  }
};
} // namespace std
#endif
0698 
0699 
0700 #endif // EIGEN_BFLOAT16_H