File indexing completed on 2025-09-15 09:03:11
0001
0002
0003
0004 #pragma once
0005
0006 #include <stdint.h>
0007 #include <cmath>
0008 #include <cstring>
0009 #include <limits>
0010
0011 namespace onnxruntime_float16 {
0012
namespace detail {

// Byte order can only be determined on Windows targets or on GCC/Clang
// compatible compilers; stop right away everywhere else.
#if !defined(_WIN32) && !defined(__GNUC__) && !defined(__clang__)
#error onnxruntime_float16::detail::endian is not implemented in this environment.
#endif

/// Byte-order tags with the same three enumerators as C++20 std::endian.
/// `native` equals `little` or `big` depending on the target.
enum class endian {
#if defined(_WIN32)
  // Windows targets are treated as little-endian unconditionally.
  little = 0,
  big = 1,
  native = little,
#else
  // GCC/Clang publish the byte order through predefined macros.
  little = __ORDER_LITTLE_ENDIAN__,
  big = __ORDER_BIG_ENDIAN__,
  native = __BYTE_ORDER__,
#endif
};

// Reject mixed/PDP style byte orders: the conversion code only knows two layouts.
static_assert(
    endian::little == endian::native || endian::big == endian::native,
    "Only little-endian or big-endian native byte orders are supported.");

}  // namespace detail
0034
0035
0036
0037
/// CRTP base that stores the raw 16-bit pattern of a half-precision value
/// (1 sign bit, 5 exponent bits, 10 mantissa bits) and implements
/// classification, sign manipulation and comparison on that pattern.
///
/// `Derived` must provide a static `FromBits` callable with a `uint16_t`
/// whose result is convertible to `Derived`; it is used by Abs() and Negate().
template <class Derived>
struct Float16Impl {
 protected:
  /// Converts a float to its 16-bit half-precision bit pattern.
  /// Declared here, defined out of line.
  constexpr static uint16_t ToUint16Impl(float v) noexcept;

  /// Converts the stored half-precision bit pattern to a float.
  /// Declared here, defined out of line.
  float ToFloatImpl() const noexcept;

  /// Returns the stored bits with the sign bit cleared.
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// Returns the stored bits with the sign bit flipped.
  /// NaN patterns are returned unchanged.
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // Bit patterns of notable half-precision values.
  static constexpr uint16_t kSignMask = 0x8000U;              // sign bit
  static constexpr uint16_t kBiasedExponentMask = 0x7C00U;    // the 5 exponent bits
  static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;  // +Inf
  static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;  // -Inf
  static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;      // quiet NaN, sign clear
  static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;      // quiet NaN, sign set
  static constexpr uint16_t kMaxValueBits = 0x7BFFU;          // largest finite magnitude
  static constexpr uint16_t kOneBits = 0x3C00U;               // +1.0
  static constexpr uint16_t kMinusOneBits = 0xBC00U;          // -1.0

  /// Raw bit pattern; zero-initialised, i.e. +0.0.
  uint16_t val{0};

  Float16Impl() = default;

  /// True when the sign bit is set (this includes -0.0 and sign-set NaNs).
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// True when the magnitude bits are above the infinity pattern.
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// True for every value that is neither an infinity nor a NaN.
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// True only for the exact +Inf pattern.
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// True only for the exact -Inf pattern.
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// True for +Inf and -Inf.
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// True when the value is a NaN or a zero of either sign.
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// True for finite, non-zero values with a non-zero exponent field.
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // finite
           && (abs != 0)                           // not zero
           && ((abs & kBiasedExponentMask) != 0);  // exponent field set
  }

  /// True for finite, non-zero values whose exponent field is zero.
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // finite
           && (abs != 0)                           // not zero
           && ((abs & kBiasedExponentMask) == 0);  // exponent field clear
  }

  /// Returns a `Derived` built from the bits with the sign cleared.
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// Returns a `Derived` built from the bits with the sign flipped
  /// (NaNs are passed through unchanged).
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// True when both operands are zero, regardless of their signs.
  /// OR-ing the patterns and dropping the sign bit leaves zero only if
  /// neither operand has any magnitude bit set.
  static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }

  /// Numeric equality: NaN never equals anything (itself included) and
  /// +0.0 equals -0.0. All other values are equal only when their bits match.
  bool operator==(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      return false;
    }
    // The two zeros differ in their sign bit but are the same number; this
    // keeps operator== consistent with operator<, which orders neither zero
    // before the other.
    return val == rhs.val || AreZero(*this, rhs);
  }

  bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }

  /// Numeric less-than. Any comparison involving a NaN is false.
  bool operator<(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      return false;
    }

    const bool left_is_negative = IsNegative();
    if (left_is_negative != rhs.IsNegative()) {
      // Signs differ: the negative side is the smaller one, unless both are
      // zeros, in which case neither is smaller.
      return left_is_negative && !AreZero(*this, rhs);
    }
    // Same sign: magnitude order matches bit order for positives and is
    // reversed for negatives.
    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
  }
};
0215
0216
0217
0218
0219
0220
0221
0222
0223
0224
0225
0226
0227
0228
0229
0230
0231
0232
0233
0234
0235
0236
0237
namespace detail {
/// Overlays a float with an unsigned integer so the conversion routines can
/// work on the raw binary32 bit pattern.
union float32_bits {
  unsigned int u;  // raw bit pattern
  float f;         // the same storage viewed as a float
};

// The half <-> float routines shift and mask `u` assuming a 32-bit pattern;
// fail the build instead of silently miscomputing where that does not hold.
static_assert(sizeof(unsigned int) == 4 && sizeof(float) == 4,
              "float32_bits requires 32-bit unsigned int and 32-bit float.");
static_assert(sizeof(float32_bits) == sizeof(float),
              "float32_bits members must fully overlap.");
}  // namespace detail
0244
0245 template <class Derived>
0246 inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
0247 detail::float32_bits f{};
0248 f.f = v;
0249
0250 constexpr detail::float32_bits f32infty = {255 << 23};
0251 constexpr detail::float32_bits f16max = {(127 + 16) << 23};
0252 constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
0253 constexpr unsigned int sign_mask = 0x80000000u;
0254 uint16_t val = static_cast<uint16_t>(0x0u);
0255
0256 unsigned int sign = f.u & sign_mask;
0257 f.u ^= sign;
0258
0259
0260
0261
0262
0263
0264 if (f.u >= f16max.u) {
0265 val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00;
0266 } else {
0267 if (f.u < (113 << 23)) {
0268
0269
0270
0271 f.f += denorm_magic.f;
0272
0273
0274 val = static_cast<uint16_t>(f.u - denorm_magic.u);
0275 } else {
0276 unsigned int mant_odd = (f.u >> 13) & 1;
0277
0278
0279
0280
0281 f.u += 0xc8000fffU;
0282
0283 f.u += mant_odd;
0284
0285 val = static_cast<uint16_t>(f.u >> 13);
0286 }
0287 }
0288
0289 val |= static_cast<uint16_t>(sign >> 16);
0290 return val;
0291 }
0292
0293 template <class Derived>
0294 inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
0295 constexpr detail::float32_bits magic = {113 << 23};
0296 constexpr unsigned int shifted_exp = 0x7c00 << 13;
0297 detail::float32_bits o{};
0298
0299 o.u = (val & 0x7fff) << 13;
0300 unsigned int exp = shifted_exp & o.u;
0301 o.u += (127 - 15) << 23;
0302
0303
0304 if (exp == shifted_exp) {
0305 o.u += (128 - 16) << 23;
0306 } else if (exp == 0) {
0307 o.u += 1 << 23;
0308 o.f -= magic.f;
0309 }
0310
0311
0312
0313 #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
0314 if (IsNegative()) {
0315 return -o.f;
0316 }
0317 #else
0318
0319 o.u |= (val & 0x8000U) << 16U;
0320 #endif
0321 return o.f;
0322 }
0323
0324
/// CRTP base that stores the raw 16-bit pattern of a bfloat16 value
/// (1 sign bit, 8 exponent bits, 7 mantissa bits) and implements
/// classification and sign manipulation on that pattern.
///
/// `Derived` must provide a static `FromBits` callable with a `uint16_t`
/// whose result is convertible to `Derived`; it is used by Abs() and Negate().
template <class Derived>
struct BFloat16Impl {
 protected:
  /// Converts a float to a bfloat16 bit pattern. Declared here, defined out of line.
  static uint16_t ToUint16Impl(float v) noexcept;

  /// Converts the stored bit pattern to a float. Declared here, defined out of line.
  float ToFloatImpl() const noexcept;

  /// Stored bits without the sign bit.
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & 0x7FFFU);
  }

  /// Stored bits with the sign bit toggled; NaN patterns pass through as is.
  uint16_t NegateImpl() const noexcept {
    if (IsNaN()) {
      return val;
    }
    return static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // Bit patterns of notable bfloat16 values.
  static constexpr uint16_t kSignMask = 0x8000U;              // sign bit
  static constexpr uint16_t kBiasedExponentMask = 0x7F80U;    // the 8 exponent bits
  static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;  // +Inf
  static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;  // -Inf
  static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;      // quiet NaN, sign clear
  static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;      // quiet NaN, sign set
  static constexpr uint16_t kMaxValueBits = 0x7F7FU;          // largest finite magnitude
  static constexpr uint16_t kRoundToNearest = 0x7FFFU;        // rounding bias for float -> bfloat16
  static constexpr uint16_t kOneBits = 0x3F80U;               // +1.0
  static constexpr uint16_t kMinusOneBits = 0xBF80U;          // -1.0

  /// Raw bit pattern; zero-initialised, i.e. +0.0.
  uint16_t val{0};

  BFloat16Impl() = default;

  /// Sign bit set? True for -0.0 and sign-set NaNs as well.
  bool IsNegative() const noexcept {
    return (val & kSignMask) != 0;
  }

  /// Magnitude above the infinity pattern means NaN.
  bool IsNaN() const noexcept {
    const uint16_t magnitude = AbsImpl();
    return kPositiveInfinityBits < magnitude;
  }

  /// Finite values are exactly those whose exponent field is not all ones.
  bool IsFinite() const noexcept {
    return (val & kBiasedExponentMask) != kBiasedExponentMask;
  }

  /// Exact match with the +Inf pattern.
  bool IsPositiveInfinity() const noexcept {
    return kPositiveInfinityBits == val;
  }

  /// Exact match with the -Inf pattern.
  bool IsNegativeInfinity() const noexcept {
    return kNegativeInfinityBits == val;
  }

  /// Either infinity.
  bool IsInfinity() const noexcept {
    const uint16_t magnitude = AbsImpl();
    return magnitude == kPositiveInfinityBits;
  }

  /// NaN, or a zero of either sign.
  bool IsNaNOrZero() const noexcept {
    const uint16_t magnitude = AbsImpl();
    if (magnitude == 0) {
      return true;
    }
    return magnitude > kPositiveInfinityBits;
  }

  /// Normal numbers have an exponent field that is neither all zeros
  /// (zero/subnormal) nor all ones (Inf/NaN).
  bool IsNormal() const noexcept {
    const uint16_t exponent = static_cast<uint16_t>(val & kBiasedExponentMask);
    return exponent != 0 && exponent != kBiasedExponentMask;
  }

  /// Subnormal numbers have an all-zero exponent field but a non-zero mantissa.
  bool IsSubnormal() const noexcept {
    const uint16_t exponent = static_cast<uint16_t>(val & kBiasedExponentMask);
    return exponent == 0 && AbsImpl() != 0;
  }

  /// `Derived` built from the magnitude bits.
  Derived Abs() const noexcept {
    const uint16_t bits = AbsImpl();
    return Derived::FromBits(bits);
  }

  /// `Derived` built from the sign-flipped bits (NaNs unchanged).
  Derived Negate() const noexcept {
    const uint16_t bits = NegateImpl();
    return Derived::FromBits(bits);
  }

  /// True when both operands are zeros, whatever their signs: neither may
  /// have a bit set outside the sign position.
  static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
    return lhs.AbsImpl() == 0 && rhs.AbsImpl() == 0;
  }
};
0480
0481 template <class Derived>
0482 inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
0483 uint16_t result;
0484 if (std::isnan(v)) {
0485 result = kPositiveQNaNBits;
0486 } else {
0487 auto get_msb_half = [](float fl) {
0488 uint16_t result;
0489 #ifdef __cpp_if_constexpr
0490 if constexpr (detail::endian::native == detail::endian::little) {
0491 #else
0492 if (detail::endian::native == detail::endian::little) {
0493 #endif
0494 std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
0495 } else {
0496 std::memcpy(&result, &fl, sizeof(uint16_t));
0497 }
0498 return result;
0499 };
0500
0501 uint16_t upper_bits = get_msb_half(v);
0502 union {
0503 uint32_t U32;
0504 float F32;
0505 };
0506 F32 = v;
0507 U32 += (upper_bits & 1) + kRoundToNearest;
0508 result = get_msb_half(F32);
0509 }
0510 return result;
0511 }
0512
0513 template <class Derived>
0514 inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
0515 if (IsNaN()) {
0516 return std::numeric_limits<float>::quiet_NaN();
0517 }
0518 float result;
0519 char* const first = reinterpret_cast<char*>(&result);
0520 char* const second = first + sizeof(uint16_t);
0521 #ifdef __cpp_if_constexpr
0522 if constexpr (detail::endian::native == detail::endian::little) {
0523 #else
0524 if (detail::endian::native == detail::endian::little) {
0525 #endif
0526 std::memset(first, 0, sizeof(uint16_t));
0527 std::memcpy(second, &val, sizeof(uint16_t));
0528 } else {
0529 std::memcpy(first, &val, sizeof(uint16_t));
0530 std::memset(second, 0, sizeof(uint16_t));
0531 }
0532 return result;
0533 }
0534
0535 }