Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:49:13

0001 #pragma once
0002 /**
0003 curanddr.hxx
0004 =============
0005 
0006 Googling for curand with less state revealed CBPRNG (counter based PRNG)
0007 and Philox. See:: 
0008 
0009 https://github.com/kshitijl/curand-done-right
0010 https://github.com/kshitijl/curand-done-right/blob/master/src/curand-done-right/curanddr.hxx
0011 
0012 
0013 
0014 
0015 **/
0016 
0017 #include <curand_kernel.h>
0018 #include <curand_normal.h>
0019 
namespace curanddr {

  // Fixed-size bundle of Arity values returned by the generators below.
  // num_t is the element type (float for gaussians/uniforms,
  // unsigned int for uniform_uints).
  template <int Arity, typename num_t = float>
  struct alignas(8) vector_t {
    num_t values[Arity];

    // Unchecked element access; n must be < Arity.
    __device__ num_t operator[] (size_t n) const {
      return values[n];
    }
  };

  // Single-value specialization: additionally converts implicitly to num_t,
  // so e.g. `float u = uniforms<1>(counter, key);` works.
  template <typename num_t>
  struct alignas(8) vector_t<1, num_t> {
    num_t values[1];

    // Unchecked element access; n must be 0.
    __device__ num_t operator[] (size_t n) const {
      return values[n];
    }

    __device__ operator num_t() const {
      return values[0];
    }
  };

  // Compile-time loop (from moderngpu meta.hxx): eval(f) calls f(i) for
  // i, i+1, ..., count-1 by template recursion, so after inlining every call
  // sees a constant index (presumably so the scratch arrays below can be
  // kept in registers rather than local memory -- confirm with -Xptxas -v).
  template <int i, int count, bool valid = (i < count)>
  struct iterate_t {
    template <typename func_t>
    __device__ static void eval(func_t f) {
      f(i);
      iterate_t<i + 1, count>::eval(f);
    }
  };

  // Recursion terminator (i >= count): does nothing.
  template <int i, int count>
  struct iterate_t<i, count, false> {
    template <typename func_t>
    __device__ static void eval(func_t) { }
  };

  // Calls f(0), f(1), ..., f(count-1) with compile-time-unrolled indices.
  template <int count, typename func_t>
  __device__ void iterate(func_t f) {
    iterate_t<0, count>::eval(f);
  }

  namespace detail {
    // One Philox4x32-10 draw: four 32-bit random words for sub-block `index`.
    // The 32-bit user key is widened to uint2{key, index}, so each group of
    // four outputs within a single call uses a distinct Philox key while
    // sharing the caller's counter.
    __device__ __forceinline__ uint4 philox_block(uint4 counter, unsigned int key,
                                                  unsigned int index) {
      uint2 local_key{key, index};
      return curand_Philox4x32_10(counter, local_key);
    }
  }

  // gaussians<Arity>(counter, key)
  //
  // Returns Arity normally distributed floats (mean 0, stddev 1) computed as
  // a pure function of (counter, key): identical inputs always give identical
  // values and no generator state is kept between calls.
  //
  // ceil(Arity/4) Philox blocks are drawn; each block's words (x,y) and (z,w)
  // feed two Box-Muller transforms. Surplus values past Arity are discarded,
  // so gaussians<N> is a prefix of gaussians<M> for N <= M.
  //
  // NOTE: gaussians, uniforms and uniform_uints all consume the same raw
  // Philox words for a given (counter, key); vary counter or key between
  // them to avoid correlated outputs.
  // NOTE(review): _curand_box_muller is an internal, undocumented cuRAND
  // helper and may change between CUDA releases.
  template <int Arity>
  __device__ vector_t<Arity> gaussians(uint4 counter, unsigned int key) {
    static_assert(Arity > 0, "curanddr::gaussians: Arity must be positive");
    enum { n_blocks = (Arity + 4 - 1) / 4 };

    float scratch[n_blocks * 4];

    iterate<n_blocks>([&](unsigned int index) {
        uint4 result = detail::philox_block(counter, key, index);

        float2 hi = _curand_box_muller(result.x, result.y);
        float2 lo = _curand_box_muller(result.z, result.w);

        unsigned int ii = index * 4;
        scratch[ii]     = hi.x;
        scratch[ii + 1] = hi.y;
        scratch[ii + 2] = lo.x;
        scratch[ii + 3] = lo.y;
      });

    vector_t<Arity> answer;
    iterate<Arity>([&](unsigned int index) {
        answer.values[index] = scratch[index];
      });
    return answer;
  }

  // uniforms_into_buffer<Arity>(answer, counter, key)
  //
  // Writes Arity uniformly distributed floats to answer[0 .. Arity-1];
  // `answer` must point to at least Arity writable floats. Values are a pure
  // function of (counter, key) and are produced by cuRAND's _curand_uniform
  // mapping (the same (0, 1] convention as curand_uniform). Surplus values
  // from the last Philox block are discarded.
  // NOTE(review): _curand_uniform is an internal, undocumented cuRAND helper.
  template <int Arity>
  __device__ void uniforms_into_buffer(float* answer, uint4 counter, unsigned int key)
  {
    static_assert(Arity > 0, "curanddr::uniforms_into_buffer: Arity must be positive");
    enum { n_blocks = (Arity + 4 - 1) / 4 };

    float scratch[n_blocks * 4];

    iterate<n_blocks>([&](unsigned int index) {
        uint4 result = detail::philox_block(counter, key, index);

        unsigned int ii = index * 4;
        scratch[ii]     = _curand_uniform(result.x);
        scratch[ii + 1] = _curand_uniform(result.y);
        scratch[ii + 2] = _curand_uniform(result.z);
        scratch[ii + 3] = _curand_uniform(result.w);
      });

    iterate<Arity>([&](unsigned int index) {
        answer[index] = scratch[index];
      });
  }

  // uniforms<Arity>(counter, key)
  //
  // Returns the same Arity uniform floats that uniforms_into_buffer<Arity>
  // would write for (counter, key), packaged as a vector_t.
  template <int Arity>
  __device__ vector_t<Arity> uniforms(uint4 counter, unsigned int key) {
    static_assert(Arity > 0, "curanddr::uniforms: Arity must be positive");
    vector_t<Arity> answer;
    uniforms_into_buffer<Arity>(answer.values, counter, key);
    return answer;
  }

  // uniform_uints<Arity>(counter, key)
  //
  // Returns the raw Philox4x32-10 output words (each uniform over the full
  // 32-bit range) as a pure function of (counter, key). Surplus words from
  // the last block are discarded.
  template <int Arity>
  __device__ vector_t<Arity, unsigned int> uniform_uints(uint4 counter, unsigned int key) {
    static_assert(Arity > 0, "curanddr::uniform_uints: Arity must be positive");
    enum { n_blocks = (Arity + 4 - 1) / 4 };

    unsigned int scratch[n_blocks * 4];

    iterate<n_blocks>([&](unsigned int index) {
        uint4 result = detail::philox_block(counter, key, index);

        unsigned int ii = index * 4;
        scratch[ii]     = result.x;
        scratch[ii + 1] = result.y;
        scratch[ii + 2] = result.z;
        scratch[ii + 3] = result.w;
      });

    vector_t<Arity, unsigned int> answer;
    iterate<Arity>([&](unsigned int index) {
        answer.values[index] = scratch[index];
      });
    return answer;
  }
}