File indexing completed on 2025-01-18 10:00:12
0001
0002
0003
0004
0005 #pragma once
0006
#include <cstring>
#include <iostream>
0008
0009 #ifdef __CUDA_ARCH__
0010 #include <cuda.h>
0011
0012 #if CUDA_VERSION >= 9000
0013 #ifdef __GNUC__
0014 #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
0015 #pragma GCC diagnostic push
0016 #endif
0017 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
0018 #endif
0019 #endif
0020 #include <cuda_fp16.h>
0021 #if CUDA_VERSION >= 9000
0022 #ifdef __GNUC__
0023 #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
0024 #pragma GCC diagnostic pop
0025 #endif
0026 #endif
0027 #endif
0028 #endif
0029
0030 #include "gloo/common/common.h"
0031
#ifdef _WIN32
// Windows headers provide the signed size type as SSIZE_T (from BaseTsd.h);
// alias it so code can use the POSIX spelling ssize_t.
#include <BaseTsd.h>
using ssize_t = SSIZE_T;
#endif
0036
namespace gloo {

// One prefix byte per collective. These are presumably meant to be passed as
// the `prefix` argument of Slot::build() so that slots of different
// collectives stay distinct (Slot::build is defined out of line -- confirm).
constexpr uint8_t kGatherSlotPrefix = 0x01;
constexpr uint8_t kAllgatherSlotPrefix = 0x02;
constexpr uint8_t kReduceSlotPrefix = 0x03;
constexpr uint8_t kAllreduceSlotPrefix = 0x04;
constexpr uint8_t kScatterSlotPrefix = 0x05;
constexpr uint8_t kBroadcastSlotPrefix = 0x06;
constexpr uint8_t kBarrierSlotPrefix = 0x07;
constexpr uint8_t kAlltoallSlotPrefix = 0x08;

// A slot identifier represented as a 64-bit base plus a 64-bit delta.
// Instances come from Slot::build() and operator+ (both defined out of line);
// the constructor is protected, so outside code cannot assemble arbitrary
// base/delta pairs directly.
class Slot {
 public:
  // Builds a slot from a prefix byte (see the k*SlotPrefix constants above)
  // and a tag. The encoding lives in the out-of-line definition.
  static Slot build(uint8_t prefix, uint32_t tag);

  // Numeric slot value: base + delta (unsigned 64-bit, wraps on overflow).
  operator uint64_t() const {
    return base_ + delta_;
  }

  // Returns a slot derived from this one with offset `i`; presumably adds `i`
  // to the delta -- defined out of line, confirm there.
  Slot operator+(uint8_t i) const;

 protected:
  explicit Slot(uint64_t base, uint64_t delta) : base_(base), delta_(delta) {}

  // Both members are const, so a Slot cannot be assigned after construction.
  const uint64_t base_;
  const uint64_t delta_;
};

// Forward declarations: float16's inline members call these conversion
// helpers before their definitions further down in this header. They are
// declared `inline` to match those header-local definitions.
struct float16;
inline float16 cpu_float2half_rn(float f);
inline float cpu_half2float(float16 h);

// IEEE 754 binary16 ("half") value stored as its raw 16-bit pattern in `x`.
//
// Compound arithmetic widens both operands to float with cpu_half2float(),
// computes in single precision and rounds back with cpu_float2half_rn()
// (round to nearest, ties to even). Equality between two float16 values
// compares raw bit patterns, so +0 and -0 compare unequal and identical NaN
// patterns compare equal.
struct alignas(2) float16 {
  // Raw bits: 1 sign bit, 5 exponent bits, 10 mantissa bits.
  uint16_t x;

  // +0.0.
  float16() : x(0) {}

  float16(const float16 &) = default;

  // Numeric constructors: the argument is converted to float first (which can
  // itself round large integers or doubles) and then rounded to half.
  explicit float16(int val) {
    float16 res = cpu_float2half_rn(static_cast<float>(val));
    x = res.x;
  }

  explicit float16(unsigned long val) {
    float16 res = cpu_float2half_rn(static_cast<float>(val));
    x = res.x;
  }

  explicit float16(unsigned long long val) {
    float16 res = cpu_float2half_rn(static_cast<float>(val));
    x = res.x;
  }

  explicit float16(double val) {
    float16 res = cpu_float2half_rn(static_cast<float>(val));
    x = res.x;
  }

  // Assigns the numeric value of rhs, rounded to half.
  float16& operator=(const int& rhs) {
    float16 res = cpu_float2half_rn(static_cast<float>(rhs));
    x = res.x;
    return *this;
  }

  // Plain bitwise copy, defaulted like the copy constructor. A conditional
  // "copy only if different" adds nothing and must never skip the store.
  float16& operator=(const float16&) = default;

  // Bitwise equality of the stored patterns.
  bool operator==(const float16& rhs) const {
    return x == rhs.x;
  }

  bool operator!=(const float16& rhs) const {
    // Compare with the float16 itself: passing the raw `rhs.x` would promote
    // to int and select operator==(const int&), comparing against the numeric
    // value of the bit pattern instead of the half it encodes.
    return !(*this == rhs);
  }

  // Mixed comparisons: rhs is rounded to half (via float) and the resulting
  // bit pattern is compared with `x`.
  bool operator==(const int& rhs) const {
    float16 res = cpu_float2half_rn(static_cast<float>(rhs));
    return x == res.x;
  }

  bool operator==(const unsigned long& rhs) const {
    float16 res = cpu_float2half_rn(static_cast<float>(rhs));
    return x == res.x;
  }

  bool operator==(const double& rhs) const {
    float16 res = cpu_float2half_rn(static_cast<float>(rhs));
    return x == res.x;
  }

#ifdef __CUDA_ARCH__
  // Bit-preserving conversion from CUDA's half type.
  float16(half h) {
#if CUDA_VERSION >= 9000
    // __half keeps its storage private; read it through the documented
    // __half_raw conversion rather than type-punning the pointer.
    x = static_cast<__half_raw>(h).x;
#else
    x = h.x;
#endif
  }

  // Bit-preserving conversion to CUDA's half type.
  operator half() const {
#if CUDA_VERSION >= 9000
    __half_raw hr;
    hr.x = this->x;
    return half(hr);
#else
    // Pre-9.0 `half` is a plain struct with a public `x`. Fill it field-wise:
    // a cast of *this to half would re-enter this operator and recurse.
    half h;
    h.x = this->x;
    return h;
#endif
  }
#endif

  float16& operator+=(const float16& rhs) {
    float r = cpu_half2float(*this) + cpu_half2float(rhs);
    *this = cpu_float2half_rn(r);
    return *this;
  }

  float16& operator-=(const float16& rhs) {
    float r = cpu_half2float(*this) - cpu_half2float(rhs);
    *this = cpu_float2half_rn(r);
    return *this;
  }

  float16& operator*=(const float16& rhs) {
    float r = cpu_half2float(*this) * cpu_half2float(rhs);
    *this = cpu_float2half_rn(r);
    return *this;
  }

  float16& operator/=(const float16& rhs) {
    float r = cpu_half2float(*this) / cpu_half2float(rhs);
    *this = cpu_float2half_rn(r);
    return *this;
  }
};

// Prints the value widened to float, using the stream's float formatting.
inline std::ostream& operator<<(std::ostream& stream, const float16& val) {
  stream << cpu_half2float(val);
  return stream;
}

// Binary arithmetic, implemented via the compound-assignment members: copy the
// left operand, update it in place, return the copy.
inline float16 operator+(const float16& lhs, const float16& rhs) {
  float16 result = lhs;
  result += rhs;
  return result;
}

inline float16 operator-(const float16& lhs, const float16& rhs) {
  float16 result = lhs;
  result -= rhs;
  return result;
}

inline float16 operator*(const float16& lhs, const float16& rhs) {
  float16 result = lhs;
  result *= rhs;
  return result;
}

inline float16 operator/(const float16& lhs, const float16& rhs) {
  float16 result = lhs;
  result /= rhs;
  return result;
}

// Ordering compares the values widened to float, so it follows float
// semantics (any comparison involving NaN is false; -0 and +0 are equivalent).
inline bool operator<(const float16& lhs, const float16& rhs) {
  return cpu_half2float(lhs) < cpu_half2float(rhs);
}

inline bool operator<=(const float16& lhs, const float16& rhs) {
  return cpu_half2float(lhs) <= cpu_half2float(rhs);
}

inline bool operator>(const float16& lhs, const float16& rhs) {
  return cpu_half2float(lhs) > cpu_half2float(rhs);
}

inline bool operator>=(const float16& lhs, const float16& rhs) {
  return cpu_half2float(lhs) >= cpu_half2float(rhs);
}

// Converts a float to half precision with round-to-nearest-even, using only
// integer operations on the bit pattern.
//
//  * NaN input                     -> 0x7fff (sign discarded)
//  * |f| >= 65520 (including inf)  -> signed infinity
//  * |f| <= 2^-25                  -> signed zero
//  * |f| <  2^-14                  -> half subnormal (rounding may carry into
//                                     the smallest normal)
inline float16 cpu_float2half_rn(float f) {
  float16 ret;

  static_assert(
      sizeof(unsigned int) == sizeof(float),
      "Programming error sizeof(unsigned int) != sizeof(float)");

  // Copy the bits out with memcpy; dereferencing a reinterpret_cast'ed
  // pointer to a different type violates strict aliasing.
  unsigned x;
  std::memcpy(&x, &f, sizeof(x));
  unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
  unsigned sign, exponent, mantissa;

  // NaN: exponent all ones with a non-zero mantissa.
  if (u > 0x7f800000) {
    ret.x = 0x7fffU;
    return ret;
  }

  sign = ((x >> 16) & 0x8000);

  // 0x477fefff is just below 65520, the first value that rounds past the
  // largest finite half (65504).
  if (u > 0x477fefff) {
    ret.x = sign | 0x7c00U;
    return ret;
  }
  // 0x33000000 is 2^-25, half of the smallest subnormal; at or below it the
  // result rounds (ties to even) to zero.
  if (u < 0x33000001) {
    ret.x = (sign | 0x0000);
    return ret;
  }

  exponent = ((u >> 23) & 0xff);
  mantissa = (u & 0x7fffff);

  if (exponent > 0x70) {
    // Normal half: rebias the exponent (127 -> 15) and drop 13 mantissa bits.
    shift = 13;
    exponent -= 0x70;
  } else {
    // Subnormal half: restore the implicit leading 1 and shift further right
    // the smaller the exponent.
    shift = 0x7e - exponent;
    exponent = 0;
    mantissa |= 0x800000;
  }
  lsb = (1u << shift);
  lsb_s1 = (lsb >> 1);
  lsb_m1 = (lsb - 1);

  // Round to nearest, ties to even: round up when the discarded bits exceed
  // half an ulp, or equal half an ulp and the kept lsb is odd. A carry out of
  // the 10-bit mantissa bumps the exponent.
  remainder = (mantissa & lsb_m1);
  mantissa >>= shift;
  if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
    ++mantissa;
    if (!(mantissa & 0x3ff)) {
      ++exponent;
      mantissa = 0;
    }
  }

  ret.x = (sign | (exponent << 10) | mantissa);

  return ret;
}

// Converts a half-precision value to float. Every half is exactly
// representable as a float, so no rounding occurs.
//
//  * Infinity keeps its sign; NaN becomes the positive NaN 0x7fffffff.
//  * Subnormal halves are renormalized into float's normal range.
inline float cpu_half2float(float16 h) {
  static_assert(
      sizeof(unsigned int) == sizeof(float),
      "Programming error sizeof(unsigned int) != sizeof(float)");

  unsigned sign = ((h.x >> 15) & 1);
  unsigned exponent = ((h.x >> 10) & 0x1f);
  unsigned mantissa = ((h.x & 0x3ff) << 13);

  if (exponent == 0x1f) {
    // Inf or NaN; for NaN the sign is cleared and the mantissa saturated.
    mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
    exponent = 0xff;
  } else if (!exponent) {
    if (mantissa) {
      // Subnormal: shift left until the leading 1 lands on the implicit-bit
      // position, lowering the exponent one step per shift, then drop it.
      unsigned int msb;
      exponent = 0x71;
      do {
        msb = (mantissa & 0x400000);
        mantissa <<= 1;
        --exponent;
      } while (!msb);
      mantissa &= 0x7fffff;
    }
    // Otherwise a signed zero: exponent and mantissa stay 0.
  } else {
    // Normal: rebias the exponent (15 -> 127).
    exponent += 0x70;
  }

  unsigned temp = ((sign << 31) | (exponent << 23) | mantissa);

  // Reinterpret the assembled bits as float without violating strict aliasing.
  float result;
  std::memcpy(&result, &temp, sizeof(result));
  return result;
}

}  // namespace gloo