File indexing completed on 2026-04-09 07:49:13
0001 #pragma once
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 #include <curand_kernel.h>
0018 #include <curand_normal.h>
0019
0020 namespace curanddr {
0021 template <int Arity, typename num_t = float>
0022 struct alignas(8) vector_t {
0023 num_t values[Arity];
0024 __device__ num_t operator[] (size_t n) const {
0025 return values[n];
0026 }
0027 };
0028
0029 template<typename num_t>
0030 struct alignas(8) vector_t<1, num_t> {
0031 num_t values[1];
0032 __device__ num_t operator[] (size_t n) const {
0033 return values[n];
0034 }
0035 __device__ operator num_t() const {
0036 return values[0];
0037 }
0038 };
0039
0040
0041 template<int i, int count, bool valid = (i < count)>
0042 struct iterate_t {
0043 template<typename func_t>
0044 __device__ static void eval(func_t f) {
0045 f(i);
0046 iterate_t<i+1, count>::eval(f);
0047 }
0048 };
0049
0050 template<int i, int count>
0051 struct iterate_t<i, count, false> {
0052 template<typename func_t>
0053 __device__ static void eval(func_t f) { }
0054 };
0055
0056 template<int count, typename func_t>
0057 __device__ void iterate(func_t f) {
0058 iterate_t<0, count>::eval(f);
0059 }
0060
0061 template<int Arity>
0062 __device__ vector_t<Arity> gaussians(uint4 counter, uint key) {
0063 enum { n_blocks = (Arity + 4 - 1)/4 };
0064
0065 float scratch[n_blocks * 4];
0066
0067 iterate<n_blocks>([&](uint index) {
0068 uint2 local_key{key, index};
0069 uint4 result = curand_Philox4x32_10(counter, local_key);
0070
0071 float2 hi = _curand_box_muller(result.x, result.y);
0072 float2 lo = _curand_box_muller(result.z, result.w);
0073
0074 uint ii = index*4;
0075 scratch[ii] = hi.x;
0076 scratch[ii+1] = hi.y;
0077 scratch[ii+2] = lo.x;
0078 scratch[ii+3] = lo.y;
0079 });
0080
0081 vector_t<Arity> answer;
0082
0083 iterate<Arity>([&](uint index) {
0084 answer.values[index] = scratch[index];
0085 });
0086
0087 return answer;
0088 }
0089
0090 template<int Arity>
0091 __device__ vector_t<Arity> uniforms(uint4 counter, uint key) {
0092 enum { n_blocks = (Arity + 4 - 1)/4 };
0093
0094 float scratch[n_blocks * 4];
0095
0096 iterate<n_blocks>([&](uint index) {
0097 uint2 local_key{key, index};
0098 uint4 result = curand_Philox4x32_10(counter, local_key);
0099
0100 uint ii = index*4;
0101 scratch[ii] = _curand_uniform(result.x);
0102 scratch[ii+1] = _curand_uniform(result.y);
0103 scratch[ii+2] = _curand_uniform(result.z);
0104 scratch[ii+3] = _curand_uniform(result.w);
0105 });
0106
0107 vector_t<Arity> answer;
0108
0109 iterate<Arity>([&](uint index) {
0110 answer.values[index] = scratch[index];
0111 });
0112
0113 return answer;
0114 }
0115
0116 template<int Arity>
0117 __device__ void uniforms_into_buffer(float* answer, uint4 counter, uint key)
0118 {
0119 enum { n_blocks = (Arity + 4 - 1)/4 };
0120 float scratch[n_blocks * 4];
0121
0122 iterate<n_blocks>([&](uint index) {
0123 uint2 local_key{key, index};
0124 uint4 result = curand_Philox4x32_10(counter, local_key);
0125
0126 uint ii = index*4;
0127 scratch[ii] = _curand_uniform(result.x);
0128 scratch[ii+1] = _curand_uniform(result.y);
0129 scratch[ii+2] = _curand_uniform(result.z);
0130 scratch[ii+3] = _curand_uniform(result.w);
0131 });
0132
0133 iterate<Arity>([&](uint index) {
0134 answer[index] = scratch[index];
0135 });
0136 }
0137
0138 template<int Arity>
0139 __device__ vector_t<Arity, uint> uniform_uints(uint4 counter, uint key) {
0140 enum { n_blocks = (Arity + 4 - 1)/4 };
0141
0142 uint scratch[n_blocks * 4];
0143
0144 iterate<n_blocks>([&](uint index) {
0145 uint2 local_key{key, index};
0146 uint4 result = curand_Philox4x32_10(counter, local_key);
0147
0148 uint ii = index*4;
0149 scratch[ii] = result.x;
0150 scratch[ii+1] = result.y;
0151 scratch[ii+2] = result.z;
0152 scratch[ii+3] = result.w;
0153 });
0154
0155 vector_t<Arity, uint> answer;
0156
0157 iterate<Arity>([&](uint index) {
0158 answer.values[index] = scratch[index];
0159 });
0160
0161 return answer;
0162 }
0163 }