Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-10 07:49:41

0001 /**
0002 QPMT.cu
0003 ==========
0004 
0005 _QPMT_lpmtcat_rindex
0006 _QPMT_lpmtcat_qeshape
0007 _QPMT_lpmtcat_stackspec
0008     kernel funcs taking (qpmt,lookup,domain,domain_width) args
0009 
0010 QPMT_pmtcat_scan
0011     CPU entry point to launch above kernels controlled by etype
0012 
0013 
0014 _QPMT_lpmtid_stackspec
0015     kernel funcs taking (qpmt,lookup,domain,domain_width,lpmtid,num_lpmtid) args
0016 
0017 _QPMT_mct_lpmtid
0018     payload size P templated kernel function with domain and lpmtid array inputs
0019 
0020     * within lpmtid loop calls qpmt.h method depending on etype
0021     * etype : (qpmt_SPEC qpmt_LL qpmt_COMP qpmt_ART qpmt_ARTE)
0022 
0023 QPMT_mct_lpmtid_scan
0024     CPU entry point to launch above kernel passing etype
0025 
0026 
0027 **/
0028 
0029 #include "QUDARAP_API_EXPORT.hh"
0030 #include <stdio.h>
0031 #include "qpmt_enum.h"
0032 #include "qpmt.h"
0033 #include "qprop.h"
0034 
0035 
0036 /**
0037 _QPMT_lpmtcat_rindex
0038 ---------------------------
0039 
0040 max_iprop::
0041 
0042    . (ni-1)*nj*nk + (nj-1)*nk + (nk-1)
0043    =  ni*nj*nk - nj*nk + nj*nk - nk + nk - 1
0044    =  ni*nj*nk - 1
0045 
0046 
0047 HMM: not so easy to generalize from rindex to also do qeshape
0048 because of the different array shapes
0049 
0050 Each thread does all pmtcat,layers and props for a single energy_eV.
0051 
0052 **/
0053 
0054 template <typename F>
0055 __global__ void _QPMT_lpmtcat_rindex( int etype, qpmt<F>* pmt, F* lookup , const F* domain, unsigned domain_width )
0056 {
0057     unsigned ix = blockIdx.x * blockDim.x + threadIdx.x;
0058     if (ix >= domain_width ) return;
0059     F domain_value = domain[ix] ;    // energy_eV
0060 
0061     //printf("//_QPMT_rindex domain_width %d ix %d domain_value %10.4f \n", domain_width, ix, domain_value );
0062     // wierd unsigned/int diff between qpmt.h and here ? to get it to compile for device
0063     // switching to enum rather than constexpr const avoids the wierdness
0064 
0065     const int& ni = s_pmt::NUM_CAT ;
0066     const int& nj = s_pmt::NUM_LAYR ;
0067     const int& nk = s_pmt::NUM_PROP ;
0068 
0069     //printf("//_QPMT_lpmtcat_rindex ni %d nj %d nk %d \n", ni, nj, nk );
0070     // cf the CPU equivalent NP::combined_interp_5
0071 
0072     for(int i=0 ; i < ni ; i++)
0073     for(int j=0 ; j < nj ; j++)
0074     for(int k=0 ; k < nk ; k++)
0075     {
0076         int iprop = i*nj*nk+j*nk+k ;            // linearized higher dimensions
0077         int index = iprop * domain_width + ix ; // output index into lookup
0078 
0079         F value = pmt->rindex_prop->interpolate(iprop, domain_value );
0080 
0081         //printf("//_QPMT_lpmtcat_rindex iprop %d index %d value %10.4f \n", iprop, index, value );
0082 
0083         lookup[index] = value ;
0084     }
0085 }
0086 
0087 
0088 
0089 template <typename F>
0090 __global__ void _QPMT_lpmtcat_stackspec( int etype, qpmt<F>* pmt, F* lookup , const F* domain, unsigned domain_width )
0091 {
0092     unsigned ix = blockIdx.x * blockDim.x + threadIdx.x;
0093     if (ix >= domain_width ) return;
0094     F domain_value = domain[ix] ;
0095 
0096     //printf("//_QPMT_lpmtcat_stackspec domain_width %d ix %d domain_value %10.4f \n", domain_width, ix, domain_value );
0097 
0098     const int& ni = s_pmt::NUM_CAT ;
0099     const int& nj = domain_width ;
0100     const int  nk = 16 ;
0101     const int&  j = ix ;
0102 
0103     F ss[nk] ;
0104 
0105     for(int i=0 ; i < ni ; i++)  // over pmtcat
0106     {
0107         int index = i*nj*nk + j*nk  ;
0108         pmt->get_lpmtcat_stackspec(ss, i, domain_value );
0109         for( int k=0 ; k < nk ; k++) lookup[index+k] = ss[k] ;
0110     }
0111 }
0112 
0113 
0114 
0115 template <typename F>
0116 __global__ void _QPMT_pmtcat_launch( int etype, qpmt<F>* pmt, F* lookup , const F* domain, unsigned domain_width )
0117 {
0118     unsigned ix = blockIdx.x * blockDim.x + threadIdx.x;
0119     if (ix >= domain_width ) return;
0120     F domain_value = domain[ix] ;
0121 
0122     //printf("//_QPMT_pmtcat_launch etype %d domain_width %d ix %d  \n", etype, domain_width, ix  );
0123 
0124     const int ni = ( etype == qpmt_S_QESHAPE ) ? 1 : s_pmt::NUM_CAT ;
0125 
0126     for(int i=0 ; i < ni ; i++)
0127     {
0128         int pmtcat = i ;
0129         F value = 0.f ;
0130 
0131         if( etype == qpmt_QESHAPE )
0132         {
0133             value = pmt->qeshape_prop->interpolate( pmtcat, domain_value );
0134         }
0135         else if( etype == qpmt_CETHETA )
0136         {
0137             //value = pmt->cetheta_prop->interpolate(lpmtcat, domain_value );
0138             value = pmt->get_lpmtcat_ce( pmtcat, domain_value );
0139         }
0140         else if ( etype == qpmt_CECOSTH )
0141         {
0142             value = pmt->cecosth_prop->interpolate( pmtcat, domain_value );
0143         }
0144         else if( etype == qpmt_S_QESHAPE )
0145         {
0146             value = pmt->s_qeshape_prop->interpolate( pmtcat, domain_value );
0147         }
0148 
0149 
0150         int index = i * domain_width + ix ; // output index into lookup
0151         lookup[index] = value ;
0152     }
0153 }
0154 
0155 
0156 
0157 
0158 /**
0159 QPMT_pmtcat_scan
0160 -------------------
0161 
0162 Performs CUDA launches, invoked from QPMT.cc QPMT<T>::pmtcat_scan
0163 
0164 **/
0165 
0166 
0167 template <typename F> extern void QPMT_pmtcat_scan(
0168     dim3 numBlocks,
0169     dim3 threadsPerBlock,
0170     qpmt<F>* pmt,
0171     int etype,
0172     F* lookup,
0173     const F* domain,
0174     unsigned domain_width
0175 )
0176 {
0177 
0178     switch(etype)
0179     {
0180         case qpmt_RINDEX     : _QPMT_lpmtcat_rindex<F><<<numBlocks,threadsPerBlock>>>(    etype, pmt, lookup, domain, domain_width )   ; break ;
0181         case qpmt_CATSPEC    : _QPMT_lpmtcat_stackspec<F><<<numBlocks,threadsPerBlock>>>( etype, pmt, lookup, domain, domain_width )   ; break ;
0182         case qpmt_QESHAPE    : _QPMT_pmtcat_launch<F><<<numBlocks,threadsPerBlock>>>(    etype, pmt, lookup, domain, domain_width )   ; break ;
0183         case qpmt_CETHETA    : _QPMT_pmtcat_launch<F><<<numBlocks,threadsPerBlock>>>(    etype, pmt, lookup, domain, domain_width )   ; break ;
0184         case qpmt_CECOSTH    : _QPMT_pmtcat_launch<F><<<numBlocks,threadsPerBlock>>>(    etype, pmt, lookup, domain, domain_width )   ; break ;
0185         case qpmt_S_QESHAPE  : _QPMT_pmtcat_launch<F><<<numBlocks,threadsPerBlock>>>(    etype, pmt, lookup, domain, domain_width )   ; break ;
0186     }
0187 }
0188 
0189 template void QPMT_pmtcat_scan(
0190    dim3,
0191    dim3,
0192    qpmt<float>*,
0193    int etype,
0194    float*,
0195    const float* ,
0196    unsigned
0197   );
0198 
0199 
0200 
0201 
0202 
0203 
0204 
0205 
0206 /**
0207 _QPMT_lpmtid_stackspec
0208 -------------------------
0209 
0210 **/
0211 
0212 
0213 template <typename F>
0214 __global__ void _QPMT_lpmtid_stackspec(
0215     qpmt<F>* pmt,
0216     F* lookup ,
0217     const F* domain,
0218     unsigned domain_width,
0219     const int* lpmtid,
0220     unsigned num_lpmtid )
0221 {
0222     unsigned ix = blockIdx.x * blockDim.x + threadIdx.x;
0223     if (ix >= domain_width ) return;
0224     F energy_eV = domain[ix] ;
0225 
0226     const int& ni = num_lpmtid ;
0227     const int& nj = domain_width ;
0228     const int  nk = 16 ;
0229     const int&  j = ix ;
0230 
0231     F ss[nk] ;
0232 
0233     for(int i=0 ; i < ni ; i++)  // over num_lpmtid
0234     {
0235         int pmtid = lpmtid[i] ;
0236         int index = i*nj*nk + j*nk  ;
0237         pmt->get_lpmtid_stackspec(ss, pmtid, energy_eV );
0238         for( int k=0 ; k < nk ; k++) lookup[index+k] = ss[k] ;
0239     }
0240 }
0241 
0242 
0243 
0244 
0245 
0246 /**
0247 _QPMT_mct_lpmtid
0248 -----------------
0249 
0250 * using templated payload size P as it needs to be a compile time constant
0251 * parallelism over mct domain only
0252 * loops over the provided list of pmtid
0253 
0254 
0255 **/
0256 
0257 #ifdef WITH_CUSTOM4
0258 template <typename F, int P>
0259 __global__ void _QPMT_mct_lpmtid(
0260     qpmt<F>* pmt,
0261     int etype,
0262     F* lookup ,
0263     const F* domain,
0264     unsigned domain_width,
0265     const int* lpmtid,
0266     unsigned num_lpmtid )
0267 {
0268     unsigned ix = blockIdx.x * blockDim.x + threadIdx.x;
0269     if (ix >= domain_width ) return;
0270 
0271     //printf("//_QPMT_mct_lpmtid etype %d ix %d num_lpmtid %d P %d \n", etype, ix, num_lpmtid, P );
0272 
0273     F minus_cos_theta = domain[ix] ;
0274     F wavelength_nm = 440.f ;
0275     F dot_pol_cross_mom_nrm = 0.f ; // SPOL zero is pure P polarized
0276     F lposcost = 0.5f ;  // np.acos(0.5) 1.047197
0277 
0278     const int& ni = num_lpmtid ;
0279     const int& nj = domain_width ;   // minus_cos_theta values "AOI"
0280     const int&  j = ix ;
0281 
0282 
0283 #if !defined(PRODUCTION) && defined(DEBUG_PIDX)
0284     unsigned pidx = 0u ;
0285     bool pidx_debug = false ;
0286 #endif
0287 
0288     F payload[P] ;
0289 
0290     for(int i=0 ; i < ni ; i++)  // over num_lpmtid
0291     {
0292         int pmtid = lpmtid[i] ;
0293 
0294         if( etype == qpmt_SPEC )
0295         {
0296             pmt->get_lpmtid_SPEC(payload, pmtid, wavelength_nm );
0297         }
0298         else if( etype == qpmt_SPEC_ce )
0299         {
0300 #if !defined(PRODUCTION) && defined(DEBUG_PIDX)
0301             pmt->get_lpmtid_SPEC_ce(payload, pmtid, wavelength_nm, lposcost, pidx, pidx_debug );
0302 #else
0303             pmt->get_lpmtid_SPEC_ce(payload, pmtid, wavelength_nm, lposcost );
0304 #endif
0305         }
0306         else if( etype == qpmt_LL )
0307         {
0308             pmt->get_lpmtid_LL(payload, pmtid, wavelength_nm, minus_cos_theta, dot_pol_cross_mom_nrm );
0309         }
0310         else if( etype == qpmt_COMP )
0311         {
0312             pmt->get_lpmtid_COMP(payload, pmtid, wavelength_nm, minus_cos_theta, dot_pol_cross_mom_nrm );
0313         }
0314         else if( etype == qpmt_ART )
0315         {
0316             pmt->get_lpmtid_ART(payload, pmtid, wavelength_nm, minus_cos_theta, dot_pol_cross_mom_nrm );
0317         }
0318         else if( etype == qpmt_ARTE )
0319         {
0320             pmt->get_lpmtid_ARTE(payload, pmtid, wavelength_nm, minus_cos_theta, dot_pol_cross_mom_nrm );
0321         }
0322         else if( etype == qpmt_ATQC )
0323         {
0324 #if !defined(PRODUCTION) && defined(DEBUG_PIDX)
0325             pmt->get_lpmtid_ATQC(payload, pmtid, wavelength_nm, minus_cos_theta, dot_pol_cross_mom_nrm, lposcost, pidx, pidx_debug );
0326 #else
0327             pmt->get_lpmtid_ATQC(payload, pmtid, wavelength_nm, minus_cos_theta, dot_pol_cross_mom_nrm, lposcost );
0328 #endif
0329         }
0330 
0331 
0332         int index = i*nj*P + j*P  ;  // output index
0333         for( int k=0 ; k < P ; k++) lookup[index+k] = payload[k] ;
0334     }
0335 }
0336 
0337 
0338 template <typename F> extern void QPMT_mct_lpmtid_scan(
0339     dim3 numBlocks,
0340     dim3 threadsPerBlock,
0341     qpmt<F>* pmt,
0342     int etype,
0343     F* lookup,
0344     const F* domain,
0345     unsigned domain_width,
0346     const int* lpmtid,
0347     unsigned num_lpmtid
0348 )
0349 {
0350     printf("//QPMT_mct_lpmtid_scan etype %d domain_width %d num_lpmtid %d \n", etype, domain_width, num_lpmtid);
0351 
0352     switch(etype)
0353     {
0354         case qpmt_SPEC:
0355            _QPMT_mct_lpmtid<F,16><<<numBlocks,threadsPerBlock>>>(
0356               pmt, etype, lookup, domain, domain_width, lpmtid, num_lpmtid ) ;  break ;
0357 
0358         case qpmt_SPEC_ce:
0359            _QPMT_mct_lpmtid<F,16><<<numBlocks,threadsPerBlock>>>(
0360               pmt, etype, lookup, domain, domain_width, lpmtid, num_lpmtid ) ;  break ;
0361 
0362         case qpmt_ART:
0363            _QPMT_mct_lpmtid<F,16><<<numBlocks,threadsPerBlock>>>(
0364               pmt, etype, lookup, domain, domain_width, lpmtid, num_lpmtid ) ;  break ;
0365 
0366         case qpmt_COMP:
0367            _QPMT_mct_lpmtid<F,32><<<numBlocks,threadsPerBlock>>>(
0368               pmt, etype, lookup, domain, domain_width, lpmtid, num_lpmtid ) ;  break ;
0369 
0370         case qpmt_LL:
0371            _QPMT_mct_lpmtid<F,128><<<numBlocks,threadsPerBlock>>>(
0372               pmt, etype, lookup, domain, domain_width, lpmtid, num_lpmtid ) ;  break ;
0373 
0374         case qpmt_ARTE:
0375            _QPMT_mct_lpmtid<F,4><<<numBlocks,threadsPerBlock>>>(
0376               pmt, etype, lookup, domain, domain_width, lpmtid, num_lpmtid ) ;  break ;
0377 
0378         case qpmt_ATQC:
0379            _QPMT_mct_lpmtid<F,4><<<numBlocks,threadsPerBlock>>>(
0380               pmt, etype, lookup, domain, domain_width, lpmtid, num_lpmtid ) ;  break ;
0381 
0382         default:
0383               printf("//PMT_mct_lpmtid_scan etype %d UNHANDLED \n", etype)   ; break ;
0384 
0385     }
0386 }
0387 
0388 template void QPMT_mct_lpmtid_scan<float>(   dim3, dim3, qpmt<float>*, int etype, float*,  const float* , unsigned, const int*, unsigned);
0389 // end WITH_CUSTOM4
0390 #endif
0391 
0392 
0393 
0394 
0395 
0396 
0397 
0398 
0399 template <typename F>
0400 __global__ void _QPMT_spmtid(
0401     qpmt<F>* pmt,
0402     int etype,
0403     F* lookup ,
0404     const int* spmtid,
0405     unsigned num_spmtid )
0406 {
0407     unsigned ix = blockIdx.x * blockDim.x + threadIdx.x;
0408     if (ix >= num_spmtid ) return;
0409     int _spmtid = spmtid[ix];
0410     //printf("//_QPMT_spmtid etype %d ix %d num_spmtid %d _spmtid %d \n", etype, ix, num_spmtid, _spmtid );
0411 
0412     F value = 0.f ;
0413     if( etype == qpmt_S_QESCALE )
0414     {
0415         value = pmt->get_s_qescale_from_spmtid( _spmtid );
0416     }
0417     lookup[ix] = value ;
0418 }
0419 
0420 
0421 
0422 
0423 template <typename F> extern void QPMT_spmtid_scan(
0424     dim3 numBlocks,
0425     dim3 threadsPerBlock,
0426     qpmt<F>* pmt,
0427     int etype,
0428     F* lookup,
0429     const int* spmtid,
0430     unsigned num_spmtid
0431 )
0432 {
0433     printf("//QPMT_spmtid_scan etype %d num_spmtid %d \n", etype, num_spmtid);
0434     switch(etype)
0435     {
0436         case qpmt_S_QESCALE:
0437            _QPMT_spmtid<F><<<numBlocks,threadsPerBlock>>>(pmt, etype, lookup, spmtid, num_spmtid ) ;  break ;
0438     }
0439 }
0440 
0441 template void QPMT_spmtid_scan<float>( dim3, dim3, qpmt<float>*, int, float*, const int*, unsigned );
0442 
0443