Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:02:52

0001 // Copyright (c) Microsoft Corporation. All rights reserved.
0002 // Licensed under the MIT License.
0003 
0004 #pragma once
0005 
0006 #include <stdint.h>
0007 #include <cmath>
0008 #include <cstring>
0009 #include <limits>
0010 
0011 namespace onnxruntime_float16 {
0012 
namespace detail {

// Byte-order detection is only available for Windows targets and for
// GCC/Clang-compatible compilers; reject everything else up front.
#if !defined(_WIN32) && !defined(__GNUC__) && !defined(__clang__)
#error onnxruntime_float16::detail::endian is not implemented in this environment.
#endif

/// <summary>
/// Compile-time description of the target byte order. `native` compares equal to
/// either `little` or `big` (enforced by the static_assert below).
/// </summary>
enum class endian {
#if defined(_WIN32)
  // Windows targets are treated as little-endian.
  little = 0,
  big = 1,
  native = little,
#else
  // GCC/Clang publish the byte order through predefined macros.
  little = __ORDER_LITTLE_ENDIAN__,
  big = __ORDER_BIG_ENDIAN__,
  native = __BYTE_ORDER__,
#endif
};

// Any other native value (for example a mixed/PDP byte order) is unsupported.
static_assert(
    endian::native == endian::big || endian::native == endian::little,
    "Only little-endian or big-endian native byte orders are supported.");

}  // namespace detail
0034 
/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// The value is stored as a raw 16-bit pattern in <c>val</c>: one sign bit (kSignMask),
/// five exponent bits (kBiasedExponentMask) and ten mantissa bits.
/// <c>Derived</c> must provide a static <c>FromBits(uint16_t)</c> factory; Abs() and Negate()
/// use it to build their results.
/// </summary>
template <class Derived>
struct Float16Impl {
 protected:
  /// <summary>
  /// Converts from float to uint16_t float16 representation
  /// </summary>
  /// <param name="v">value to convert</param>
  /// <returns>float16 bit pattern for <paramref name="v"/></returns>
  constexpr static uint16_t ToUint16Impl(float v) noexcept;

  /// <summary>
  /// Converts float16 to float
  /// </summary>
  /// <returns>float representation of float16 value</returns>
  float ToFloatImpl() const noexcept;

  /// <summary>
  /// Computes the bit pattern of the absolute value (sign bit cleared).
  /// </summary>
  /// <returns>Bits of the absolute value</returns>
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// <summary>
  /// Computes the bit pattern with the sign flipped. NaN inputs are returned unchanged.
  /// </summary>
  /// <returns>Bits of the negated value</returns>
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // uint16_t special values
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
  // NOTE(review): 0x4170 decodes to 2.71875 (exponent field 16, mantissa 0x170), which is not a
  // machine epsilon for this format. Value kept as-is; confirm the intended constant.
  static constexpr uint16_t kEpsilonBits = 0x4170U;
  static constexpr uint16_t kMinValueBits = 0xFBFFU;  // Lowest (most negative) finite value, -65504
  static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // Largest finite value, 65504
  static constexpr uint16_t kOneBits = 0x3C00U;
  static constexpr uint16_t kMinusOneBits = 0xBC00U;

  /// Raw 16-bit representation; zero-initialized (positive zero) by default.
  uint16_t val{0};

  Float16Impl() = default;

  /// <summary>
  /// Checks if the sign bit is set. This is true for negative zero and negative NaN as well.
  /// </summary>
  /// <returns>true if negative</returns>
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// <summary>
  /// Tests if the value is NaN (magnitude bits above the infinity pattern).
  /// </summary>
  /// <returns>true if NaN</returns>
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is finite (neither infinity nor NaN).
  /// </summary>
  /// <returns>true if finite</returns>
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents positive infinity.
  /// </summary>
  /// <returns>true if positive infinity</returns>
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents negative infinity
  /// </summary>
  /// <returns>true if negative infinity</returns>
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// <summary>
  /// Tests if the value is either positive or negative infinity.
  /// </summary>
  /// <returns>True if absolute value is infinity</returns>
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is NaN or zero. Useful for comparisons.
  /// </summary>
  /// <returns>True if NaN or zero.</returns>
  bool IsNaNOrZero() const noexcept {
    const auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// <summary>
  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsNormal() const noexcept {
    const auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
  }

  /// <summary>
  /// Tests if the value is subnormal (denormal).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsSubnormal() const noexcept {
    const auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
  }

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// <summary>
  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
  /// and therefore equivalent, if the resulting value is still zero.
  /// </summary>
  /// <param name="lhs">first value</param>
  /// <param name="rhs">second value</param>
  /// <returns>True if both arguments represent zero</returns>
  static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }

  /// <summary>
  /// IEEE-style equality: NaN never compares equal (not even to itself), and positive and
  /// negative zero compare equal although their bit patterns differ.
  /// </summary>
  /// <param name="rhs">value to compare with</param>
  /// <returns>True if both operands represent the same non-NaN value</returns>
  bool operator==(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is not equal to anything, including itself.
      return false;
    }
    // Identical bits are equal; +0 and -0 differ only in the sign bit and are equal as well,
    // which keeps this operator consistent with operator<.
    return val == rhs.val || AreZero(*this, rhs);
  }

  /// <summary>
  /// Negation of operator==; consequently true whenever either operand is NaN.
  /// </summary>
  bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }

  /// <summary>
  /// IEEE-style ordering: false if either operand is NaN, and the two zeros are not ordered
  /// relative to each other.
  /// </summary>
  /// <param name="rhs">value to compare with</param>
  /// <returns>True if this value is strictly less than <paramref name="rhs"/></returns>
  bool operator<(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is unordered with respect to everything, including itself.
      return false;
    }

    const bool left_is_negative = IsNegative();
    if (left_is_negative != rhs.IsNegative()) {
      // When the signs of left and right differ, we know that left is less than right if it is
      // the negative value. The exception to this is if both values are zero, in which case IEEE
      // says they should be equal, even if the signs differ.
      return left_is_negative && !AreZero(*this, rhs);
    }
    // Same sign: for non-negative values a smaller bit pattern is a smaller value; for negative
    // values the relation is reversed.
    return (val != rhs.val) && ((val < rhs.val) != left_is_negative);
  }
};
0217 
0218 // The following Float16_t conversions are based on the code from
0219 // Eigen library.
0220 
0221 // The conversion routines are Copyright (c) Fabian Giesen, 2016.
0222 // The original license follows:
0223 //
0224 // Copyright (c) Fabian Giesen, 2016
0225 // All rights reserved.
0226 // Redistribution and use in source and binary forms, with or without
0227 // modification, are permitted.
0228 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
0229 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
0230 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
0231 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
0232 // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
0233 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
0234 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
0235 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
0236 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
0237 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
0238 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
0239 
namespace detail {
/// <summary>
/// Overlays a float with an unsigned integer so the conversion routines can work on the
/// raw bit pattern. `u` is declared first, so brace-initialization `{x}` sets the bits.
/// </summary>
union float32_bits {
  unsigned u;  // raw bit pattern
  float f;     // the same storage viewed as a float
};
}  // namespace detail
0246 
0247 template <class Derived>
0248 inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
0249   detail::float32_bits f{};
0250   f.f = v;
0251 
0252   constexpr detail::float32_bits f32infty = {255 << 23};
0253   constexpr detail::float32_bits f16max = {(127 + 16) << 23};
0254   constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
0255   constexpr unsigned int sign_mask = 0x80000000u;
0256   uint16_t val = static_cast<uint16_t>(0x0u);
0257 
0258   unsigned int sign = f.u & sign_mask;
0259   f.u ^= sign;
0260 
0261   // NOTE all the integer compares in this function can be safely
0262   // compiled into signed compares since all operands are below
0263   // 0x80000000. Important if you want fast straight SSE2 code
0264   // (since there's no unsigned PCMPGTD).
0265 
0266   if (f.u >= f16max.u) {                         // result is Inf or NaN (all exponent bits set)
0267     val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00;  // NaN->qNaN and Inf->Inf
0268   } else {                                       // (De)normalized number or zero
0269     if (f.u < (113 << 23)) {                     // resulting FP16 is subnormal or zero
0270       // use a magic value to align our 10 mantissa bits at the bottom of
0271       // the float. as long as FP addition is round-to-nearest-even this
0272       // just works.
0273       f.f += denorm_magic.f;
0274 
0275       // and one integer subtract of the bias later, we have our final float!
0276       val = static_cast<uint16_t>(f.u - denorm_magic.u);
0277     } else {
0278       unsigned int mant_odd = (f.u >> 13) & 1;  // resulting mantissa is odd
0279 
0280       // update exponent, rounding bias part 1
0281       // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
0282       // without arithmetic overflow.
0283       f.u += 0xc8000fffU;
0284       // rounding bias part 2
0285       f.u += mant_odd;
0286       // take the bits!
0287       val = static_cast<uint16_t>(f.u >> 13);
0288     }
0289   }
0290 
0291   val |= static_cast<uint16_t>(sign >> 16);
0292   return val;
0293 }
0294 
0295 template <class Derived>
0296 inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
0297   constexpr detail::float32_bits magic = {113 << 23};
0298   constexpr unsigned int shifted_exp = 0x7c00 << 13;  // exponent mask after shift
0299   detail::float32_bits o{};
0300 
0301   o.u = (val & 0x7fff) << 13;            // exponent/mantissa bits
0302   unsigned int exp = shifted_exp & o.u;  // just the exponent
0303   o.u += (127 - 15) << 23;               // exponent adjust
0304 
0305   // handle exponent special cases
0306   if (exp == shifted_exp) {   // Inf/NaN?
0307     o.u += (128 - 16) << 23;  // extra exp adjust
0308   } else if (exp == 0) {      // Zero/Denormal?
0309     o.u += 1 << 23;           // extra exp adjust
0310     o.f -= magic.f;           // re-normalize
0311   }
0312 
0313   // Attempt to workaround the Internal Compiler Error on ARM64
0314   // for bitwise | operator, including std::bitset
0315 #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
0316   if (IsNegative()) {
0317     return -o.f;
0318   }
0319 #else
0320   // original code:
0321   o.u |= (val & 0x8000U) << 16U;  // sign bit
0322 #endif
0323   return o.f;
0324 }
0325 
/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// The value is stored as a raw 16-bit pattern in <c>val</c>: one sign bit (kSignMask),
/// eight exponent bits (kBiasedExponentMask) and seven mantissa bits.
/// <c>Derived</c> must provide a static <c>FromBits(uint16_t)</c> factory; Abs() and Negate()
/// use it to build their results.
/// </summary>
template <class Derived>
struct BFloat16Impl {
 protected:
  /// <summary>
  /// Converts from float to uint16_t bfloat16 representation
  /// </summary>
  /// <param name="v">value to convert</param>
  /// <returns>bfloat16 bit pattern for <paramref name="v"/></returns>
  static uint16_t ToUint16Impl(float v) noexcept;

  /// <summary>
  /// Converts bfloat16 to float
  /// </summary>
  /// <returns>float representation of bfloat16 value</returns>
  float ToFloatImpl() const noexcept;

  /// <summary>
  /// Computes the bit pattern of the absolute value (sign bit cleared).
  /// </summary>
  /// <returns>Bits of the absolute value</returns>
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// <summary>
  /// Computes the bit pattern with the sign flipped. NaN inputs are returned unchanged.
  /// </summary>
  /// <returns>Bits of the negated value</returns>
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // uint16_t special values
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
  // Exponent all ones, quiet bit (0x0040) clear, non-zero mantissa. The previous value 0x7F80
  // was the bit pattern of positive infinity, so IsNaN() returned false for it.
  static constexpr uint16_t kSignaling_NaNBits = 0x7FA0U;
  // NOTE(review): 0x0080 decodes to 2^-126 (exponent field 1, zero mantissa), which is not a
  // machine epsilon for this format. Value kept as-is; confirm the intended constant.
  static constexpr uint16_t kEpsilonBits = 0x0080U;
  static constexpr uint16_t kMinValueBits = 0xFF7FU;  // Lowest (most negative) finite value
  static constexpr uint16_t kMaxValueBits = 0x7F7FU;  // Largest finite value
  // Rounding bias added to the 32-bit float pattern before its low 16 bits are dropped.
  static constexpr uint16_t kRoundToNearest = 0x7FFFU;
  static constexpr uint16_t kOneBits = 0x3F80U;
  static constexpr uint16_t kMinusOneBits = 0xBF80U;

  /// Raw 16-bit representation; zero-initialized (positive zero) by default.
  uint16_t val{0};

  BFloat16Impl() = default;

  /// <summary>
  /// Checks if the sign bit is set. This is true for negative zero and negative NaN as well.
  /// </summary>
  /// <returns>true if negative</returns>
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// <summary>
  /// Tests if the value is NaN (magnitude bits above the infinity pattern).
  /// </summary>
  /// <returns>true if NaN</returns>
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is finite (neither infinity nor NaN).
  /// </summary>
  /// <returns>true if finite</returns>
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents positive infinity.
  /// </summary>
  /// <returns>true if positive infinity</returns>
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents negative infinity
  /// </summary>
  /// <returns>true if negative infinity</returns>
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// <summary>
  /// Tests if the value is either positive or negative infinity.
  /// </summary>
  /// <returns>True if absolute value is infinity</returns>
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is NaN or zero. Useful for comparisons.
  /// </summary>
  /// <returns>True if NaN or zero.</returns>
  bool IsNaNOrZero() const noexcept {
    const auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// <summary>
  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsNormal() const noexcept {
    const auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
  }

  /// <summary>
  /// Tests if the value is subnormal (denormal).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsSubnormal() const noexcept {
    const auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
  }

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// <summary>
  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
  /// and therefore equivalent, if the resulting value is still zero.
  /// </summary>
  /// <param name="lhs">first value</param>
  /// <param name="rhs">second value</param>
  /// <returns>True if both arguments represent zero</returns>
  static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }
};
0485 
0486 template <class Derived>
0487 inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
0488   uint16_t result;
0489   if (std::isnan(v)) {
0490     result = kPositiveQNaNBits;
0491   } else {
0492     auto get_msb_half = [](float fl) {
0493       uint16_t result;
0494 #ifdef __cpp_if_constexpr
0495       if constexpr (detail::endian::native == detail::endian::little) {
0496 #else
0497       if (detail::endian::native == detail::endian::little) {
0498 #endif
0499         std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
0500       } else {
0501         std::memcpy(&result, &fl, sizeof(uint16_t));
0502       }
0503       return result;
0504     };
0505 
0506     uint16_t upper_bits = get_msb_half(v);
0507     union {
0508       uint32_t U32;
0509       float F32;
0510     };
0511     F32 = v;
0512     U32 += (upper_bits & 1) + kRoundToNearest;
0513     result = get_msb_half(F32);
0514   }
0515   return result;
0516 }
0517 
0518 template <class Derived>
0519 inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
0520   if (IsNaN()) {
0521     return std::numeric_limits<float>::quiet_NaN();
0522   }
0523   float result;
0524   char* const first = reinterpret_cast<char*>(&result);
0525   char* const second = first + sizeof(uint16_t);
0526 #ifdef __cpp_if_constexpr
0527   if constexpr (detail::endian::native == detail::endian::little) {
0528 #else
0529   if (detail::endian::native == detail::endian::little) {
0530 #endif
0531     std::memset(first, 0, sizeof(uint16_t));
0532     std::memcpy(second, &val, sizeof(uint16_t));
0533   } else {
0534     std::memcpy(first, &val, sizeof(uint16_t));
0535     std::memset(second, 0, sizeof(uint16_t));
0536   }
0537   return result;
0538 }
0539 
0540 }  // namespace onnxruntime_float16