Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-15 09:03:11

0001 // Copyright (c) Microsoft Corporation. All rights reserved.
0002 // Licensed under the MIT License.
0003 
0004 #pragma once
0005 
0006 #include <stdint.h>
0007 #include <cmath>
0008 #include <cstring>
0009 #include <limits>
0010 
0011 namespace onnxruntime_float16 {
0012 
namespace detail {

// Byte order of the compilation target, modeled on C++20 std::endian (same
// enumerator names) for toolchains that do not provide it.
enum class endian {
#if defined(_WIN32)
  // Checked first: on every _WIN32 target `native` is hard-wired to `little`.
  // The numeric values on this branch are arbitrary but distinct.
  little = 0,
  big = 1,
  native = little,
#elif defined(__GNUC__) || defined(__clang__)
  // GCC and Clang predefine these byte-order macros.
  little = __ORDER_LITTLE_ENDIAN__,
  big = __ORDER_BIG_ENDIAN__,
  native = __BYTE_ORDER__,
#else
#error onnxruntime_float16::detail::endian is not implemented in this environment.
#endif
};

// Reject any other byte order (e.g. a __BYTE_ORDER__ that is neither the
// little- nor the big-endian macro value) at compile time.
static_assert(
    endian::native == endian::little || endian::native == endian::big,
    "Only little-endian or big-endian native byte orders are supported.");

}  // namespace detail
0034 
/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// Derived must provide FromBits(uint16_t), callable as Derived::FromBits; Abs() and
/// Negate() use it to wrap the computed bit patterns.
/// </summary>
template <class Derived>
struct Float16Impl {
 protected:
  /// <summary>
  /// Converts from float to uint16_t float16 (IEEE 754 binary16) representation
  /// </summary>
  /// <param name="v">value to convert</param>
  /// <returns>binary16 bit pattern of v</returns>
  constexpr static uint16_t ToUint16Impl(float v) noexcept;

  /// <summary>
  /// Converts float16 to float
  /// </summary>
  /// <returns>float representation of float16 value</returns>
  float ToFloatImpl() const noexcept;

  /// <summary>
  /// Computes the bits of the absolute value (the sign bit cleared).
  /// </summary>
  /// <returns>Absolute value bits</returns>
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// <summary>
  /// Computes the bits with the sign flipped. NaN bits are returned unchanged.
  /// </summary>
  /// <returns>Flipped sign bits</returns>
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // uint16_t special values
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
  static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // Largest finite (normal) number, 65504
  static constexpr uint16_t kOneBits = 0x3C00U;
  static constexpr uint16_t kMinusOneBits = 0xBC00U;

  // Raw binary16 bit pattern.
  uint16_t val{0};

  Float16Impl() = default;

  /// <summary>
  /// Checks if the value is negative (sign bit set; also true for -0 and negative NaN)
  /// </summary>
  /// <returns>true if negative</returns>
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// <summary>
  /// Tests if the value is NaN
  /// </summary>
  /// <returns>true if NaN</returns>
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is finite
  /// </summary>
  /// <returns>true if finite</returns>
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents positive infinity.
  /// </summary>
  /// <returns>true if positive infinity</returns>
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents negative infinity
  /// </summary>
  /// <returns>true if negative infinity</returns>
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// <summary>
  /// Tests if the value is either positive or negative infinity.
  /// </summary>
  /// <returns>True if absolute value is infinity</returns>
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is NaN or zero. Useful for comparisons.
  /// </summary>
  /// <returns>True if NaN or zero.</returns>
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// <summary>
  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
  }

  /// <summary>
  /// Tests if the value is subnormal (denormal).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
  }

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// <summary>
  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  /// for two values by or'ing the raw bits together and stripping the sign. They are both zero,
  /// and therefore equivalent, if the resulting value is still zero.
  /// </summary>
  /// <param name="lhs">first value</param>
  /// <param name="rhs">second value</param>
  /// <returns>True if both arguments represent zero</returns>
  static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }

  /// <summary>
  /// IEEE equality: NaN equals nothing (not even itself) and +0 equals -0.
  /// </summary>
  bool operator==(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is not equal to anything, including itself.
      return false;
    }
    // Identical bit patterns are equal. +0 and -0 differ only in the sign bit but
    // must also compare equal; this keeps == consistent with operator<, which never
    // reports either zero as less than the other.
    return val == rhs.val || AreZero(*this, rhs);
  }

  /// <summary>
  /// Negation of operator==; in particular true whenever either operand is NaN.
  /// </summary>
  bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }

  /// <summary>
  /// IEEE less-than: false if either operand is NaN; -0 and +0 are not less than each other.
  /// </summary>
  bool operator<(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is unordered with respect to everything, including itself.
      return false;
    }

    const bool left_is_negative = IsNegative();
    if (left_is_negative != rhs.IsNegative()) {
      // When the signs of left and right differ, we know that left is less than right if it is
      // the negative value. The exception to this is if both values are zero, in which case IEEE
      // says they should be equal, even if the signs differ.
      return left_is_negative && !AreZero(*this, rhs);
    }
    // Same sign: sign-magnitude encoding orders positive values by increasing bits and
    // negative values by decreasing bits, hence the XOR with the sign.
    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
  }
};
0215 
0216 // The following Float16_t conversions are based on the code from
0217 // Eigen library.
0218 
0219 // The conversion routines are Copyright (c) Fabian Giesen, 2016.
0220 // The original license follows:
0221 //
0222 // Copyright (c) Fabian Giesen, 2016
0223 // All rights reserved.
0224 // Redistribution and use in source and binary forms, with or without
0225 // modification, are permitted.
0226 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
0227 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
0228 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
0229 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
0230 // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
0231 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
0232 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
0233 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
0234 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
0235 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
0236 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
0237 
namespace detail {
// Overlays a float and an unsigned int so the binary16 conversion routines can
// switch between the float value and its raw bit pattern. Aggregate
// initialization (e.g. `{255 << 23}`) sets the first member, `u`.
union float32_bits {
  unsigned int u;
  float f;
};

// The bit manipulation done through this union only works if both members are
// 4 bytes and therefore overlay each other completely.
static_assert(sizeof(unsigned int) == 4 && sizeof(float) == 4,
              "float32_bits requires 4-byte unsigned int and float");
}  // namespace detail
0244 
template <class Derived>
inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
  // Converts a float to IEEE 754 binary16 bits. It works on the magnitude |v| and
  // ORs the float's sign bit back in at the end, so -0, -Inf and negative NaN keep
  // their sign (negative NaN becomes 0xfe00).
  //   * |v| >= 65536.0f, Inf, NaN -> 0x7c00 (Inf), or the quiet NaN 0x7e00 for NaN.
  //   * |v| < 2^-14               -> binary16 subnormal or zero.
  //   * otherwise                 -> normal binary16, mantissa rounded to nearest,
  //                                  ties to even (may still round up to 0x7c00).
  //
  // NOTE(review): `f` and `denorm_magic` are written through one union member and
  // read through the other. Reading a union's inactive member is undefined in ISO
  // C++, so this relies on compiler support for union type punning; confirm for
  // every supported toolchain. For the same reason, no call can be evaluated in a
  // constant expression despite the `constexpr` specifier.
  detail::float32_bits f{};
  f.f = v;

  constexpr detail::float32_bits f32infty = {255 << 23};                              // +Inf
  constexpr detail::float32_bits f16max = {(127 + 16) << 23};                         // 2^16 = 65536.0f
  constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};  // 2^-1 = 0.5f
  constexpr unsigned int sign_mask = 0x80000000u;
  uint16_t val = static_cast<uint16_t>(0x0u);

  // Split off the sign; from here on f holds |v|.
  unsigned int sign = f.u & sign_mask;
  f.u ^= sign;

  // NOTE all the integer compares in this function can be safely
  // compiled into signed compares since all operands are below
  // 0x80000000. Important if you want fast straight SSE2 code
  // (since there's no unsigned PCMPGTD).

  if (f.u >= f16max.u) {                         // |v| >= 65536, Inf or NaN: result is Inf or NaN
    val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00;  // NaN->qNaN and Inf/overflow->Inf
  } else {                                       // (De)normalized number or zero
    if (f.u < (113 << 23)) {                     // |v| < 2^-14: resulting FP16 is subnormal or zero
      // use a magic value to align our 10 mantissa bits at the bottom of
      // the float. as long as FP addition is round-to-nearest-even this
      // just works.
      // (Why: the sum 0.5f + |v| has an ulp of 2^-24, exactly the binary16
      // subnormal spacing, so the addition itself performs the rounding.)
      f.f += denorm_magic.f;

      // and one integer subtract of the bias later, we have our final float!
      val = static_cast<uint16_t>(f.u - denorm_magic.u);
    } else {
      unsigned int mant_odd = (f.u >> 13) & 1;  // resulting mantissa is odd

      // update exponent, rounding bias part 1
      // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
      // without arithmetic overflow.
      // (0xfff is one less than half of the 13 bits dropped below.)
      f.u += 0xc8000fffU;
      // rounding bias part 2: an exact tie rounds up only if the kept mantissa is odd
      f.u += mant_odd;
      // take the bits! (a mantissa carry may propagate into the exponent, up to 0x7c00)
      val = static_cast<uint16_t>(f.u >> 13);
    }
  }

  // Move the float sign bit (bit 31) into the binary16 sign bit (bit 15).
  val |= static_cast<uint16_t>(sign >> 16);
  return val;
}
0292 
template <class Derived>
inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
  // Widens the stored binary16 bits (`val`) to a float. The 10 mantissa bits are
  // shifted into the top of float's 23-bit mantissa and the exponent is re-biased
  // from 15 to 127. Zero/subnormal and Inf/NaN inputs get the extra handling below.
  // Mantissa bits (including NaN payload bits) are carried over, shifted left by 13.
  //
  // NOTE(review): `o` and `magic` are read through a different union member than
  // the one written (union type punning). That is undefined in ISO C++; the code
  // relies on compiler support for it. Confirm for every supported toolchain.
  constexpr detail::float32_bits magic = {113 << 23};  // 2^-14, the smallest normal binary16
  constexpr unsigned int shifted_exp = 0x7c00 << 13;  // exponent mask after shift
  detail::float32_bits o{};

  o.u = (val & 0x7fff) << 13;            // exponent/mantissa bits
  unsigned int exp = shifted_exp & o.u;  // just the exponent
  o.u += (127 - 15) << 23;               // exponent adjust

  // handle exponent special cases
  if (exp == shifted_exp) {   // Inf/NaN?
    o.u += (128 - 16) << 23;  // extra exp adjust: 31 + 112 + 112 = 255, float's Inf/NaN exponent
  } else if (exp == 0) {      // Zero/Denormal?
    o.u += 1 << 23;           // extra exp adjust: value is now 2^-14 * (1 + mantissa / 1024)
    o.f -= magic.f;           // re-normalize: subtracting 2^-14 leaves mantissa * 2^-24
  }

  // Attempt to workaround the Internal Compiler Error on ARM64
  // for bitwise | operator, including std::bitset
#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
  // Apply the sign by negating the float instead of OR-ing in the sign bit.
  if (IsNegative()) {
    return -o.f;
  }
#else
  // original code:
  o.u |= (val & 0x8000U) << 16U;  // sign bit
#endif
  return o.f;
}
0323 
/// Shared implementation between public and internal classes. CRTP pattern.
/// Derived must provide FromBits(uint16_t), callable as Derived::FromBits; Abs() and
/// Negate() use it to wrap the computed bit patterns.
template <class Derived>
struct BFloat16Impl {
 protected:
  /// <summary>
  /// Converts from float to uint16_t bfloat16 representation
  /// </summary>
  /// <param name="v">value to convert</param>
  /// <returns>bfloat16 bit pattern of v</returns>
  static uint16_t ToUint16Impl(float v) noexcept;

  /// <summary>
  /// Converts bfloat16 to float
  /// </summary>
  /// <returns>float representation of bfloat16 value</returns>
  float ToFloatImpl() const noexcept;

  /// <summary>
  /// Computes the bits of the absolute value (the sign bit cleared).
  /// </summary>
  /// <returns>Absolute value bits</returns>
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// <summary>
  /// Computes the bits with the sign flipped. NaN bits are returned unchanged.
  /// </summary>
  /// <returns>Flipped sign bits</returns>
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // uint16_t special values
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
  // NOTE(review): mantissa bit 0 is set as well (0x7FC1, not 0x7FC0); presumably
  // deliberate, so confirm before changing since conversions produce this value.
  static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
  static constexpr uint16_t kMaxValueBits = 0x7F7FU;  // Largest finite (normal) number
  // Bias added to a float's bit pattern before its low 16 bits are discarded,
  // so the truncation rounds to nearest.
  static constexpr uint16_t kRoundToNearest = 0x7FFFU;
  static constexpr uint16_t kOneBits = 0x3F80U;
  static constexpr uint16_t kMinusOneBits = 0xBF80U;

  // Raw bfloat16 bit pattern (the upper 16 bits of a binary32 float).
  uint16_t val{0};

  BFloat16Impl() = default;

  /// <summary>
  /// Checks if the value is negative (sign bit set; also true for -0 and negative NaN)
  /// </summary>
  /// <returns>true if negative</returns>
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// <summary>
  /// Tests if the value is NaN
  /// </summary>
  /// <returns>true if NaN</returns>
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is finite
  /// </summary>
  /// <returns>true if finite</returns>
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents positive infinity.
  /// </summary>
  /// <returns>true if positive infinity</returns>
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents negative infinity
  /// </summary>
  /// <returns>true if negative infinity</returns>
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// <summary>
  /// Tests if the value is either positive or negative infinity.
  /// </summary>
  /// <returns>True if absolute value is infinity</returns>
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is NaN or zero. Useful for comparisons.
  /// </summary>
  /// <returns>True if NaN or zero.</returns>
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// <summary>
  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
  }

  /// <summary>
  /// Tests if the value is subnormal (denormal).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
  }

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// <summary>
  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  /// for two values by or'ing the raw bits together and stripping the sign. They are both zero,
  /// and therefore equivalent, if the resulting value is still zero.
  /// </summary>
  /// <param name="lhs">first value</param>
  /// <param name="rhs">second value</param>
  /// <returns>True if both arguments represent zero</returns>
  static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }
};
0480 
0481 template <class Derived>
0482 inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
0483   uint16_t result;
0484   if (std::isnan(v)) {
0485     result = kPositiveQNaNBits;
0486   } else {
0487     auto get_msb_half = [](float fl) {
0488       uint16_t result;
0489 #ifdef __cpp_if_constexpr
0490       if constexpr (detail::endian::native == detail::endian::little) {
0491 #else
0492       if (detail::endian::native == detail::endian::little) {
0493 #endif
0494         std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
0495       } else {
0496         std::memcpy(&result, &fl, sizeof(uint16_t));
0497       }
0498       return result;
0499     };
0500 
0501     uint16_t upper_bits = get_msb_half(v);
0502     union {
0503       uint32_t U32;
0504       float F32;
0505     };
0506     F32 = v;
0507     U32 += (upper_bits & 1) + kRoundToNearest;
0508     result = get_msb_half(F32);
0509   }
0510   return result;
0511 }
0512 
0513 template <class Derived>
0514 inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
0515   if (IsNaN()) {
0516     return std::numeric_limits<float>::quiet_NaN();
0517   }
0518   float result;
0519   char* const first = reinterpret_cast<char*>(&result);
0520   char* const second = first + sizeof(uint16_t);
0521 #ifdef __cpp_if_constexpr
0522   if constexpr (detail::endian::native == detail::endian::little) {
0523 #else
0524   if (detail::endian::native == detail::endian::little) {
0525 #endif
0526     std::memset(first, 0, sizeof(uint16_t));
0527     std::memcpy(second, &val, sizeof(uint16_t));
0528   } else {
0529     std::memcpy(first, &val, sizeof(uint16_t));
0530     std::memset(second, 0, sizeof(uint16_t));
0531   }
0532   return result;
0533 }
0534 
0535 }  // namespace onnxruntime_float16