Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:12

0001 /**
0002  * Copyright (c) Facebook, Inc. and its affiliates.
0003  */
0004 
0005 #pragma once
0006 
#include <cstdint>
#include <cstring>
#include <iostream>
0008 
0009 #ifdef __CUDA_ARCH__
0010 #include <cuda.h>
0011 // Disable strict aliasing errors for CUDA 9.
0012 #if CUDA_VERSION >= 9000
0013 #ifdef __GNUC__
0014 #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
0015 #pragma GCC diagnostic push
0016 #endif
0017 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
0018 #endif // __GNUC__
0019 #endif // CUDA_VERSION >= 9000
0020 #include <cuda_fp16.h>
0021 #if CUDA_VERSION >= 9000
0022 #ifdef __GNUC__
0023 #if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)
0024 #pragma GCC diagnostic pop
0025 #endif
0026 #endif // __GNUC__
0027 #endif // CUDA_VERSION >= 9000
0028 #endif
0029 
0030 #include "gloo/common/common.h"
0031 
0032 #ifdef _WIN32
0033 #include <BaseTsd.h>
0034 typedef SSIZE_T ssize_t;
0035 #endif
0036 
0037 namespace gloo {
0038 
0039 // Unlike old style collectives that are class instances that hold
0040 // some state, the new style collectives do not need initialization
0041 // before they can run. Instead of asking the context for a series of
0042 // slots and storing them for later use and reuse, the new style
0043 // collectives take a slot (or tag) argument that allows for
0044 // concurrent execution of multiple collectives on the same context.
0045 //
0046 // This tag is what determines the slot numbers for the send and recv
0047 // operations that the collectives end up executing. A single
0048 // collective may have many send and recv operations running in
0049 // parallel, so instead of using the specified tag verbatim, we use it
0050 // as a prefix. Also, to avoid conflicts between collectives with the
0051 // same tag, we have another tag prefix per collective type. Out of
0052 // the 64 bits we can use for a slot, we use 8 of them to identify a
0053 // collective, 32 to identify the collective tag, another 8 for use by
0054 // the collective operation itself (allowing for 256 independent send
0055 // and recv operations against the same point to point pair), and
0056 // leave 16 bits unused.
0057 //
0058 // Below, you find constexprs for the prefix per collective type, as
0059 // well as a way to compute slots when executing a collective. The
0060 // slot class below captures both a prefix and a delta on that prefix
0061 // to support addition with bounds checking. It is usable as an
0062 // uint64_t, but one that cannot overflow beyond the bits allocated
0063 // for use within a collective.
0064 //
0065 
0066 constexpr uint8_t kGatherSlotPrefix = 0x01;
0067 constexpr uint8_t kAllgatherSlotPrefix = 0x02;
0068 constexpr uint8_t kReduceSlotPrefix = 0x03;
0069 constexpr uint8_t kAllreduceSlotPrefix = 0x04;
0070 constexpr uint8_t kScatterSlotPrefix = 0x05;
0071 constexpr uint8_t kBroadcastSlotPrefix = 0x06;
0072 constexpr uint8_t kBarrierSlotPrefix = 0x07;
0073 constexpr uint8_t kAlltoallSlotPrefix = 0x08;
0074 
0075 class Slot {
0076  public:
0077   static Slot build(uint8_t prefix, uint32_t tag);
0078 
0079   operator uint64_t() const {
0080     return base_ + delta_;
0081   }
0082 
0083   Slot operator+(uint8_t i) const;
0084 
0085  protected:
0086   explicit Slot(uint64_t base, uint64_t delta) : base_(base), delta_(delta) {}
0087 
0088   const uint64_t base_;
0089   const uint64_t delta_;
0090 };
0091 
0092 struct float16;
0093 float16 cpu_float2half_rn(float f);
0094 float cpu_half2float(float16 h);
0095 
0096 struct alignas(2) float16 {
0097   uint16_t x;
0098 
0099   float16() : x(0) {}
0100 
0101   float16(const float16 &) = default;
0102 
0103   explicit float16(int val) {
0104     float16 res = cpu_float2half_rn(static_cast<float>(val));
0105     x = res.x;
0106   }
0107 
0108   explicit float16(unsigned long val) {
0109     float16 res = cpu_float2half_rn(static_cast<float>(val));
0110     x = res.x;
0111   }
0112 
0113   explicit float16(unsigned long long val) {
0114     float16 res = cpu_float2half_rn(static_cast<float>(val));
0115     x = res.x;
0116   }
0117 
0118   explicit float16(double val) {
0119     float16 res = cpu_float2half_rn(static_cast<float>(val));
0120     x = res.x;
0121   }
0122 
0123   float16& operator=(const int& rhs) {
0124     float16 res = cpu_float2half_rn(static_cast<float>(rhs));
0125     x = res.x;
0126     return *this;
0127   }
0128 
0129   float16& operator=(const float16& rhs) {
0130     if (rhs != *this) {
0131       x = rhs.x;
0132     }
0133     return *this;
0134   }
0135 
0136   bool operator==(const float16& rhs) const {
0137     return x == rhs.x;
0138   }
0139 
0140   bool operator!=(const float16& rhs) const {
0141     return !(*this == rhs.x);
0142   }
0143 
0144   bool operator==(const int& rhs) const {
0145     float16 res = cpu_float2half_rn(static_cast<float>(rhs));
0146     return x == res.x;
0147   }
0148 
0149   bool operator==(const unsigned long& rhs) const {
0150     float16 res = cpu_float2half_rn(static_cast<float>(rhs));
0151     return x == res.x;
0152   }
0153 
0154   bool operator==(const double& rhs) const {
0155     float16 res = cpu_float2half_rn(static_cast<float>(rhs));
0156     return x == res.x;
0157   }
0158 #ifdef __CUDA_ARCH__
0159   float16(half h) {
0160 #if CUDA_VERSION >= 9000
0161     x = reinterpret_cast<__half_raw*>(&h)->x;
0162 #else
0163     x = h.x;
0164 #endif // CUDA_VERSION
0165   }
0166 
0167   // half and float16 are supposed to have identical representation so implicit
0168   // conversion should be fine
0169   /* implicit */
0170   operator half() const {
0171 #if CUDA_VERSION >= 9000
0172     __half_raw hr;
0173     hr.x = this->x;
0174     return half(hr);
0175 #else
0176     return (half) * this;
0177 #endif // CUDA_VERSION
0178   }
0179 #endif // __CUDA_ARCH
0180 
0181   float16& operator+=(const float16& rhs) {
0182     float r = cpu_half2float(*this) + cpu_half2float(rhs);
0183     *this = cpu_float2half_rn(r);
0184     return *this;
0185   }
0186 
0187   float16& operator-=(const float16& rhs) {
0188     float r = cpu_half2float(*this) - cpu_half2float(rhs);
0189     *this = cpu_float2half_rn(r);
0190     return *this;
0191   }
0192 
0193   float16& operator*=(const float16& rhs) {
0194     float r = cpu_half2float(*this) * cpu_half2float(rhs);
0195     *this = cpu_float2half_rn(r);
0196     return *this;
0197   }
0198 
0199   float16& operator/=(const float16& rhs) {
0200     float r = cpu_half2float(*this) / cpu_half2float(rhs);
0201     *this = cpu_float2half_rn(r);
0202     return *this;
0203   }
0204 };
0205 
0206 inline std::ostream& operator<<(std::ostream& stream, const float16& val) {
0207   stream << cpu_half2float(val);
0208   return stream;
0209 }
0210 
0211 inline float16 operator+(const float16& lhs, const float16& rhs) {
0212   float16 result = lhs;
0213   result += rhs;
0214   return result;
0215 }
0216 
0217 inline float16 operator-(const float16& lhs, const float16& rhs) {
0218   float16 result = lhs;
0219   result -= rhs;
0220   return result;
0221 }
0222 
0223 inline float16 operator*(const float16& lhs, const float16& rhs) {
0224   float16 result = lhs;
0225   result *= rhs;
0226   return result;
0227 }
0228 
0229 inline float16 operator/(const float16& lhs, const float16& rhs) {
0230   float16 result = lhs;
0231   result /= rhs;
0232   return result;
0233 }
0234 
0235 inline bool operator<(const float16& lhs, const float16& rhs) {
0236   return cpu_half2float(lhs) < cpu_half2float(rhs);
0237 }
0238 
0239 inline bool operator<=(const float16& lhs, const float16& rhs) {
0240   return cpu_half2float(lhs) <= cpu_half2float(rhs);
0241 }
0242 
0243 inline bool operator>(const float16& lhs, const float16& rhs) {
0244   return cpu_half2float(lhs) > cpu_half2float(rhs);
0245 }
0246 
0247 inline bool operator>=(const float16& lhs, const float16& rhs) {
0248   return cpu_half2float(lhs) >= cpu_half2float(rhs);
0249 }
0250 
0251 inline float16 cpu_float2half_rn(float f) {
0252   float16 ret;
0253 
0254   static_assert(
0255       sizeof(unsigned int) == sizeof(float),
0256       "Programming error sizeof(unsigned int) != sizeof(float)");
0257 
0258   unsigned* xp = reinterpret_cast<unsigned int*>(&f);
0259   unsigned x = *xp;
0260   unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
0261   unsigned sign, exponent, mantissa;
0262 
0263   // Get rid of +NaN/-NaN case first.
0264   if (u > 0x7f800000) {
0265     ret.x = 0x7fffU;
0266     return ret;
0267   }
0268 
0269   sign = ((x >> 16) & 0x8000);
0270 
0271   // Get rid of +Inf/-Inf, +0/-0.
0272   if (u > 0x477fefff) {
0273     ret.x = sign | 0x7c00U;
0274     return ret;
0275   }
0276   if (u < 0x33000001) {
0277     ret.x = (sign | 0x0000);
0278     return ret;
0279   }
0280 
0281   exponent = ((u >> 23) & 0xff);
0282   mantissa = (u & 0x7fffff);
0283 
0284   if (exponent > 0x70) {
0285     shift = 13;
0286     exponent -= 0x70;
0287   } else {
0288     shift = 0x7e - exponent;
0289     exponent = 0;
0290     mantissa |= 0x800000;
0291   }
0292   lsb = (1 << shift);
0293   lsb_s1 = (lsb >> 1);
0294   lsb_m1 = (lsb - 1);
0295 
0296   // Round to nearest even.
0297   remainder = (mantissa & lsb_m1);
0298   mantissa >>= shift;
0299   if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
0300     ++mantissa;
0301     if (!(mantissa & 0x3ff)) {
0302       ++exponent;
0303       mantissa = 0;
0304     }
0305   }
0306 
0307   ret.x = (sign | (exponent << 10) | mantissa);
0308 
0309   return ret;
0310 }
0311 
0312 inline float cpu_half2float(float16 h) {
0313   unsigned sign = ((h.x >> 15) & 1);
0314   unsigned exponent = ((h.x >> 10) & 0x1f);
0315   unsigned mantissa = ((h.x & 0x3ff) << 13);
0316 
0317   if (exponent == 0x1f) { /* NaN or Inf */
0318     mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
0319     exponent = 0xff;
0320   } else if (!exponent) { /* Denorm or Zero */
0321     if (mantissa) {
0322       unsigned int msb;
0323       exponent = 0x71;
0324       do {
0325         msb = (mantissa & 0x400000);
0326         mantissa <<= 1; /* normalize */
0327         --exponent;
0328       } while (!msb);
0329       mantissa &= 0x7fffff; /* 1.mantissa is implicit */
0330     }
0331   } else {
0332     exponent += 0x70;
0333   }
0334 
0335   unsigned temp = ((sign << 31) | (exponent << 23) | mantissa);
0336 
0337   void* rp = &temp;
0338   return *(float*)rp;
0339 }
0340 
0341 } // namespace gloo