File indexing completed on 2025-01-18 10:02:52
0001
0002
0003
0004 #pragma once
0005
0006 #include <stdint.h>
0007 #include <cmath>
0008 #include <cstring>
0009 #include <limits>
0010
0011 namespace onnxruntime_float16 {
0012
namespace detail {

/// Byte-order tags, modelled on C++20 `std::endian`.
///
/// `native` equals `little` or `big` depending on the target. On Windows the
/// target is taken to be little-endian; on GCC/Clang the compiler's predefined
/// byte-order macros supply the values. Any other toolchain is rejected at
/// preprocessing time.
enum class endian {
#ifdef _WIN32
  little = 0,
  big = 1,
  native = little,
#elif defined(__GNUC__) || defined(__clang__)
  little = __ORDER_LITTLE_ENDIAN__,
  big = __ORDER_BIG_ENDIAN__,
  native = __BYTE_ORDER__,
#else
#error onnxruntime_float16::detail::endian is not implemented in this environment.
#endif
};

// Mixed-endian (e.g. PDP) layouts would make `native` match neither tag.
static_assert(
    endian::native == endian::big || endian::native == endian::little,
    "Only little-endian or big-endian native byte orders are supported.");

}  // namespace detail
0034
0035
0036
0037
/// Shared implementation of an IEEE 754 binary16 ("half") value stored as raw
/// bits in `val`.
///
/// This is a CRTP base: `Derived` must expose a static `FromBits(uint16_t)`
/// factory, which Abs() and Negate() use to build their results.
template <class Derived>
struct Float16Impl {
 protected:
  /// Converts a float to its binary16 bit pattern.
  /// @param v value to convert
  /// @return the 16 bits of the converted value
  constexpr static uint16_t ToUint16Impl(float v) noexcept;

  /// Converts the stored binary16 bits to a float.
  /// @return the stored value widened to float
  float ToFloatImpl() const noexcept;

  /// @return the stored bits with the sign bit cleared.
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// @return the stored bits with the sign bit flipped; NaN bits are returned unchanged.
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // Bit-pattern constants for the binary16 layout (1 sign, 5 exponent, 10 mantissa bits).
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
  // NOTE(review): 0x4170 decodes to 2.71875, which is not an epsilon in either the
  // "machine epsilon" or "smallest positive value" sense - confirm the intended value.
  static constexpr uint16_t kEpsilonBits = 0x4170U;
  static constexpr uint16_t kMinValueBits = 0xFBFFU;  // most negative finite value (-65504)
  static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // largest finite value (65504)
  static constexpr uint16_t kOneBits = 0x3C00U;
  static constexpr uint16_t kMinusOneBits = 0xBC00U;

  /// Raw binary16 bits. Zero-initialised, i.e. +0.0.
  uint16_t val{0};

  Float16Impl() = default;

  /// @return true when the sign bit is set (this includes -0 and negative NaNs).
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// @return true when the value is a NaN (magnitude bits above the infinity pattern).
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// @return true when the value is neither infinite nor NaN.
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// @return true when the value is exactly +infinity.
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// @return true when the value is exactly -infinity.
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// @return true when the value is +infinity or -infinity.
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// @return true when the value is NaN or (positive or negative) zero.
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// @return true when the value is finite, non-zero and has a non-zero exponent field.
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // finite
           && (abs != 0)                           // not zero
           && ((abs & kBiasedExponentMask) != 0);  // exponent field set => not subnormal
  }

  /// @return true when the value is finite, non-zero and has a zero exponent field.
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // finite
           && (abs != 0)                           // not zero
           && ((abs & kBiasedExponentMask) == 0);  // exponent field clear => subnormal
  }

  /// @return a `Derived` holding the absolute value (sign bit cleared).
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// @return a `Derived` with the sign flipped; NaN is returned unchanged.
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// Tests whether both operands are zero, ignoring sign.
  ///
  /// OR-ing the two bit patterns and stripping the sign bit yields zero only
  /// when every magnitude bit of both operands is clear.
  /// @return true when `lhs` and `rhs` are each +0 or -0
  static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }

  /// IEEE-style equality: NaN is unequal to everything (itself included) and
  /// +0 compares equal to -0.
  bool operator==(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // NaN never compares equal, not even to an identical bit pattern.
      return false;
    }
    // Identical bits are equal; the two zeros differ only in the sign bit but
    // must still compare equal, consistent with operator< below.
    return val == rhs.val || AreZero(*this, rhs);
  }

  /// Negation of operator==; consequently true whenever either operand is NaN.
  bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }

  /// IEEE-style ordering: false whenever either operand is NaN, and false for
  /// any pairing of +0 and -0.
  bool operator<(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // Comparisons involving NaN are unordered.
      return false;
    }

    const bool left_is_negative = IsNegative();
    if (left_is_negative != rhs.IsNegative()) {
      // Signs differ: the negative side is smaller, unless both sides are
      // zero, in which case they are equal.
      return left_is_negative && !AreZero(*this, rhs);
    }
    // Same sign: the raw bits order like the magnitudes, so the unsigned
    // comparison is inverted when both values are negative.
    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
  }
};
0217
0218
0219
0220
0221
0222
0223
0224
0225
0226
0227
0228
0229
0230
0231
0232
0233
0234
0235
0236
0237
0238
0239
namespace detail {
/// Overlay of a 32-bit float and its raw bit pattern, used by the float16
/// conversion routines to manipulate sign/exponent/mantissa fields directly.
/// `u` is declared first so that brace-initialisation (`float32_bits{bits}`)
/// sets the integer view.
union float32_bits {
  unsigned u;  // raw bit pattern
  float f;     // the same storage viewed as a float
};
}  // namespace detail
0246
0247 template <class Derived>
0248 inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
0249 detail::float32_bits f{};
0250 f.f = v;
0251
0252 constexpr detail::float32_bits f32infty = {255 << 23};
0253 constexpr detail::float32_bits f16max = {(127 + 16) << 23};
0254 constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
0255 constexpr unsigned int sign_mask = 0x80000000u;
0256 uint16_t val = static_cast<uint16_t>(0x0u);
0257
0258 unsigned int sign = f.u & sign_mask;
0259 f.u ^= sign;
0260
0261
0262
0263
0264
0265
0266 if (f.u >= f16max.u) {
0267 val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00;
0268 } else {
0269 if (f.u < (113 << 23)) {
0270
0271
0272
0273 f.f += denorm_magic.f;
0274
0275
0276 val = static_cast<uint16_t>(f.u - denorm_magic.u);
0277 } else {
0278 unsigned int mant_odd = (f.u >> 13) & 1;
0279
0280
0281
0282
0283 f.u += 0xc8000fffU;
0284
0285 f.u += mant_odd;
0286
0287 val = static_cast<uint16_t>(f.u >> 13);
0288 }
0289 }
0290
0291 val |= static_cast<uint16_t>(sign >> 16);
0292 return val;
0293 }
0294
0295 template <class Derived>
0296 inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
0297 constexpr detail::float32_bits magic = {113 << 23};
0298 constexpr unsigned int shifted_exp = 0x7c00 << 13;
0299 detail::float32_bits o{};
0300
0301 o.u = (val & 0x7fff) << 13;
0302 unsigned int exp = shifted_exp & o.u;
0303 o.u += (127 - 15) << 23;
0304
0305
0306 if (exp == shifted_exp) {
0307 o.u += (128 - 16) << 23;
0308 } else if (exp == 0) {
0309 o.u += 1 << 23;
0310 o.f -= magic.f;
0311 }
0312
0313
0314
0315 #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
0316 if (IsNegative()) {
0317 return -o.f;
0318 }
0319 #else
0320
0321 o.u |= (val & 0x8000U) << 16U;
0322 #endif
0323 return o.f;
0324 }
0325
0326
/// Shared implementation of a bfloat16 value (1 sign, 8 exponent, 7 mantissa
/// bits) stored as raw bits in `val`.
///
/// This is a CRTP base: `Derived` must expose a static `FromBits(uint16_t)`
/// factory, which Abs() and Negate() use to build their results.
template <class Derived>
struct BFloat16Impl {
 protected:
  /// Converts a float to its bfloat16 bit pattern.
  /// @param v value to convert
  /// @return the 16 bits of the converted value
  static uint16_t ToUint16Impl(float v) noexcept;

  /// Converts the stored bfloat16 bits to a float.
  /// @return the stored value widened to float
  float ToFloatImpl() const noexcept;

  /// @return the stored bits with the sign bit cleared.
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & static_cast<uint16_t>(~kSignMask));
  }

  /// @return the stored bits with the sign bit flipped; NaN bits are returned unchanged.
  uint16_t NegateImpl() const noexcept {
    if (IsNaN()) {
      return val;
    }
    return static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // Bit-pattern constants for the bfloat16 layout.
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
  // NOTE(review): identical to kPositiveInfinityBits, so IsNaN() is false for
  // this pattern - confirm whether that is intended.
  static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
  static constexpr uint16_t kEpsilonBits = 0x0080U;
  static constexpr uint16_t kMinValueBits = 0xFF7FU;  // most negative finite value
  static constexpr uint16_t kMaxValueBits = 0x7F7FU;  // largest finite value
  // Bias added to the float bits before truncation by ToUint16Impl.
  static constexpr uint16_t kRoundToNearest = 0x7FFFU;
  static constexpr uint16_t kOneBits = 0x3F80U;
  static constexpr uint16_t kMinusOneBits = 0xBF80U;

  /// Raw bfloat16 bits. Zero-initialised, i.e. +0.0.
  uint16_t val{0};

  BFloat16Impl() = default;

  /// @return true when the sign bit is set (this includes -0 and negative NaNs).
  bool IsNegative() const noexcept {
    return (val & kSignMask) != 0;
  }

  /// @return true when the value is a NaN (magnitude bits above the infinity pattern).
  bool IsNaN() const noexcept {
    return kPositiveInfinityBits < AbsImpl();
  }

  /// @return true when the value is neither infinite nor NaN.
  bool IsFinite() const noexcept {
    return kPositiveInfinityBits > AbsImpl();
  }

  /// @return true when the value is exactly +infinity.
  bool IsPositiveInfinity() const noexcept {
    return kPositiveInfinityBits == val;
  }

  /// @return true when the value is exactly -infinity.
  bool IsNegativeInfinity() const noexcept {
    return kNegativeInfinityBits == val;
  }

  /// @return true when the value is +infinity or -infinity.
  bool IsInfinity() const noexcept {
    return kPositiveInfinityBits == AbsImpl();
  }

  /// @return true when the value is NaN or (positive or negative) zero.
  bool IsNaNOrZero() const noexcept {
    const uint16_t magnitude = AbsImpl();
    if (magnitude == 0) {
      return true;
    }
    return magnitude > kPositiveInfinityBits;
  }

  /// @return true when the value is finite, non-zero and has a non-zero exponent field.
  bool IsNormal() const noexcept {
    const uint16_t magnitude = AbsImpl();
    if (magnitude == 0 || magnitude >= kPositiveInfinityBits) {
      return false;  // zero, infinity or NaN
    }
    return (magnitude & kBiasedExponentMask) != 0;
  }

  /// @return true when the value is finite, non-zero and has a zero exponent field.
  bool IsSubnormal() const noexcept {
    const uint16_t magnitude = AbsImpl();
    if (magnitude == 0 || magnitude >= kPositiveInfinityBits) {
      return false;  // zero, infinity or NaN
    }
    return (magnitude & kBiasedExponentMask) == 0;
  }

  /// @return a `Derived` holding the absolute value (sign bit cleared).
  Derived Abs() const noexcept {
    const uint16_t magnitude = AbsImpl();
    return Derived::FromBits(magnitude);
  }

  /// @return a `Derived` with the sign flipped; NaN is returned unchanged.
  Derived Negate() const noexcept {
    const uint16_t negated = NegateImpl();
    return Derived::FromBits(negated);
  }

  /// Tests whether both operands are zero, ignoring sign.
  /// @return true when `lhs` and `rhs` are each +0 or -0
  static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
    // Each operand is a zero exactly when no bit other than the sign is set.
    return lhs.AbsImpl() == 0 && rhs.AbsImpl() == 0;
  }
};
0485
0486 template <class Derived>
0487 inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
0488 uint16_t result;
0489 if (std::isnan(v)) {
0490 result = kPositiveQNaNBits;
0491 } else {
0492 auto get_msb_half = [](float fl) {
0493 uint16_t result;
0494 #ifdef __cpp_if_constexpr
0495 if constexpr (detail::endian::native == detail::endian::little) {
0496 #else
0497 if (detail::endian::native == detail::endian::little) {
0498 #endif
0499 std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
0500 } else {
0501 std::memcpy(&result, &fl, sizeof(uint16_t));
0502 }
0503 return result;
0504 };
0505
0506 uint16_t upper_bits = get_msb_half(v);
0507 union {
0508 uint32_t U32;
0509 float F32;
0510 };
0511 F32 = v;
0512 U32 += (upper_bits & 1) + kRoundToNearest;
0513 result = get_msb_half(F32);
0514 }
0515 return result;
0516 }
0517
0518 template <class Derived>
0519 inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
0520 if (IsNaN()) {
0521 return std::numeric_limits<float>::quiet_NaN();
0522 }
0523 float result;
0524 char* const first = reinterpret_cast<char*>(&result);
0525 char* const second = first + sizeof(uint16_t);
0526 #ifdef __cpp_if_constexpr
0527 if constexpr (detail::endian::native == detail::endian::little) {
0528 #else
0529 if (detail::endian::native == detail::endian::little) {
0530 #endif
0531 std::memset(first, 0, sizeof(uint16_t));
0532 std::memcpy(second, &val, sizeof(uint16_t));
0533 } else {
0534 std::memcpy(first, &val, sizeof(uint16_t));
0535 std::memset(second, 0, sizeof(uint16_t));
0536 }
0537 return result;
0538 }
0539
0540 }