Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-15 07:41:52

0001 #!/usr/bin/env python3
0002 """
0003 wls_test.py : pass/fail comparison of GPU vs G4 WLS hit distributions
0004 ======================================================================
0005 
0006 Loads two hit.npy arrays (shape Nx4x4) and runs four checks:
0007 
0008   1. Hit count agreement (Z-score < 5)
  2. WLS conversion fraction agrees within 3 percentage points (photons with wl > 380 nm)
  3. Two-sample KS on the shifted-wavelength spectrum (pass if p >= ALPHA)
  4. Two-sample KS on shifted-photon arrival times (pass if p >= ALPHA)
0012 
0013 Designed to be invoked from tests/test_wavelength_shifting.sh.
0014 
0015 Usage::
0016 
0017     python3 optiphy/ana/wls_test.py <gpu_hit.npy> <g4_hit.npy> [--alpha 0.001]
0018 
0019 Exits 0 on PASS, 1 on FAIL.
0020 """
0021 import argparse
0022 import math
0023 import os
0024 import sys
0025 
0026 import numpy as np
0027 
0028 # Reuse ks_test_2sample from the diagnostic script in the same directory.
0029 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
0030 from wls_diagnostic import ks_test_2sample  # noqa: E402
0031 
0032 
# Wavelength cut in nm: hits with wavelength above this value are counted as
# WLS-shifted photons; hits at or below it are treated as unshifted.
WLS_THRESHOLD_NM = 380.0
0034 
0035 
def load_hits(path):
    """Load a ``.npy`` hit array from *path* and view it as (N, 4, 4).

    The stored array may have any shape whose total size is a multiple of 16;
    it is reshaped so that each hit occupies one 4x4 record.
    """
    raw = np.load(path)
    return np.reshape(raw, (-1, 4, 4))
0038 
0039 
def hit_count_test(gpu, g4, threshold=5.0):
    """Check that GPU and G4 produced statistically compatible hit counts.

    The significance is |n_gpu - n_g4| / sqrt(n_gpu + n_g4), i.e. the count
    difference in units of the Poisson standard deviation of the difference
    of two independent counts.  When both samples are empty the score is 0.

    Args:
        gpu: GPU hit records (only ``len`` is used).
        g4: G4 hit records (only ``len`` is used).
        threshold: the test passes when |Z| is strictly below this many
            sigma (default 5, the value used by ``main``).

    Returns:
        True on PASS, False on FAIL.  A report is printed to stdout.
    """
    n_gpu, n_g4 = len(gpu), len(g4)
    sigma = math.sqrt(n_gpu + n_g4)
    # Guard against division by zero when neither sample has any hits.
    z = abs(n_gpu - n_g4) / sigma if sigma > 0 else 0.0
    print("=" * 55)
    print("  TEST 1: Hit Count")
    print("=" * 55)
    print(f"  GPU: {n_gpu}")
    print(f"  G4:  {n_g4}")
    print(f"  |Z| = {z:.1f} sigma")
    passed = z < threshold
    # ':g' prints the default 5.0 as "5", matching the historical output.
    print(f"  Result: {'PASS' if passed else 'FAIL'} (threshold: {threshold:g} sigma)")
    return passed
0053 
0054 
def wls_fraction_test(gpu_wl, g4_wl, tolerance=0.03):
    """Compare the fraction of WLS-shifted photons between GPU and G4.

    A hit counts as shifted when its wavelength exceeds ``WLS_THRESHOLD_NM``.
    The test passes when the absolute difference of the two shifted
    fractions is strictly below ``tolerance`` (expressed as a fraction, so
    the default 0.03 means 3 percentage points).

    Args:
        gpu_wl: GPU hit wavelengths in nm (array-like, 1-D).
        g4_wl: G4 hit wavelengths in nm (array-like, 1-D).
        tolerance: maximum allowed absolute difference of the fractions.

    Returns:
        True on PASS, False on FAIL.  If either sample is empty the fraction
        is undefined and the test FAILs explicitly.
    """
    gpu_wl = np.asarray(gpu_wl)
    g4_wl = np.asarray(g4_wl)
    print()
    print("=" * 55)
    print("  TEST 2: WLS Conversion Fraction")
    print("=" * 55)
    if gpu_wl.size == 0 or g4_wl.size == 0:
        # np.mean of an empty array yields NaN plus a RuntimeWarning; report
        # the degenerate case clearly instead of printing "nan%".
        print(f"  GPU hits: {gpu_wl.size}, G4 hits: {g4_wl.size}")
        print("  Conversion fraction undefined for an empty sample")
        print("  Result: FAIL (empty sample)")
        return False
    gpu_frac = float(np.mean(gpu_wl > WLS_THRESHOLD_NM))
    g4_frac = float(np.mean(g4_wl > WLS_THRESHOLD_NM))
    diff = abs(gpu_frac - g4_frac)
    print(f"  GPU shifted: {100 * gpu_frac:.1f}%")
    print(f"  G4  shifted: {100 * g4_frac:.1f}%")
    print(f"  |Difference|: {100 * diff:.2f}%")
    passed = diff < tolerance
    # ':g' shows non-integer tolerances (e.g. 2.5) correctly and still prints
    # the default as "3".
    print(f"  Result: {'PASS' if passed else 'FAIL'} (threshold: {100 * tolerance:g}%)")
    return passed
0069 
0070 
def shifted_wavelength_ks(gpu_wl, g4_wl, alpha):
    """Two-sample KS comparison of the shifted-photon wavelength spectra.

    Only hits with wavelength above ``WLS_THRESHOLD_NM`` are compared.  When
    either selection holds 10 or fewer hits the test is skipped and reported
    as PASS.  Otherwise ``ks_test_2sample`` supplies the statistic D and the
    p-value, and the test passes when p >= *alpha*.

    Returns:
        True on PASS (or skip), False on FAIL.
    """
    gpu_sel = gpu_wl[gpu_wl > WLS_THRESHOLD_NM]
    g4_sel = g4_wl[g4_wl > WLS_THRESHOLD_NM]

    rule = "=" * 55
    for line in ("", rule, "  TEST 3: Shifted Wavelength Spectrum (KS Test)", rule):
        print(line)

    # Require more than 10 entries in *both* selections before running KS.
    if not min(len(gpu_sel), len(g4_sel)) > 10:
        print("  Too few shifted photons for KS test")
        print("  Result: PASS (insufficient stats, skipped)")
        return True

    d, p = ks_test_2sample(gpu_sel, g4_sel)
    for tag, sel in (("GPU", gpu_sel), ("G4 ", g4_sel)):
        print(f"  {tag} shifted: N={len(sel)}, mean={sel.mean():.1f}nm")
    print(f"  KS D={d:.6f}  p={p:.4f}")

    verdict = p >= alpha
    print(f"  Result: {'PASS' if verdict else 'FAIL'} (threshold: p > {alpha})")
    return verdict
0089 
0090 
def shifted_time_ks(gpu_wl, g4_wl, gpu_time, g4_time, alpha):
    """Two-sample KS comparison of arrival times for shifted photons.

    Hits are selected by wavelength > ``WLS_THRESHOLD_NM``.  Mean/std
    summaries are printed whenever both selections are non-empty.  With 10 or
    fewer selected hits on either side the test is skipped and reported as
    PASS.  Otherwise ``ks_test_2sample`` supplies D and the p-value, the mean
    times of unshifted hits are printed for information only, and the test
    passes when p >= *alpha*.

    Returns:
        True on PASS (or skip), False on FAIL.
    """
    gpu_t = gpu_time[gpu_wl > WLS_THRESHOLD_NM]
    g4_t = g4_time[g4_wl > WLS_THRESHOLD_NM]

    rule = "=" * 55
    for line in ("", rule, "  TEST 4: Shifted Photon Arrival Time (KS Test)", rule):
        print(line)

    if min(len(gpu_t), len(g4_t)) > 0:
        for tag, t in (("GPU", gpu_t), ("G4 ", g4_t)):
            print(f"  {tag} shifted: N={len(t)}, mean={t.mean():.3f}ns, std={t.std():.3f}ns")
        g4_sd = g4_t.std()
        if g4_sd > 0:
            print(f"  Std ratio: {gpu_t.std() / g4_sd:.3f} (expect ~1.0)")

    if min(len(gpu_t), len(g4_t)) <= 10:
        print("  Too few shifted photons for KS test")
        print("  Result: PASS (insufficient stats, skipped)")
        return True

    d, p = ks_test_2sample(gpu_t, g4_t)
    print(f"  KS D={d:.6f}  p={p:.4f}")

    # Explicit '<=' (not ~mask) so NaN wavelengths fall in neither group.
    gpu_u = gpu_time[gpu_wl <= WLS_THRESHOLD_NM]
    g4_u = g4_time[g4_wl <= WLS_THRESHOLD_NM]
    if min(len(gpu_u), len(g4_u)) > 0:
        print(f"  Unshifted time: GPU mean={gpu_u.mean():.3f}ns  G4 mean={g4_u.mean():.3f}ns")

    ok = p >= alpha
    print(f"  Result: {'PASS' if ok else 'FAIL'} (KS p > {alpha})")
    return ok
0119 
0120 
def main():
    """Command-line entry point: load both hit files, run all four checks.

    Prints a per-test report followed by a summary, then exits with status 0
    when every check passed and 1 otherwise.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("gpu_hit", help="GPU hit.npy")
    ap.add_argument("g4_hit", help="G4 hits npy file")
    ap.add_argument("--alpha", type=float, default=0.001,
                    help="KS p-value significance threshold (default: 0.001)")
    opts = ap.parse_args()

    gpu_hits = load_hits(opts.gpu_hit)
    g4_hits = load_hits(opts.g4_hit)

    # Element [2, 3] of each hit record is used as wavelength, [0, 3] as time.
    gpu_wl = gpu_hits[:, 2, 3]
    g4_wl = g4_hits[:, 2, 3]
    gpu_time = gpu_hits[:, 0, 3]
    g4_time = g4_hits[:, 0, 3]

    # Run every check in order (no short-circuit) so each report is printed.
    results = []
    results.append(("Hit count", hit_count_test(gpu_hits, g4_hits)))
    results.append(("WLS fraction", wls_fraction_test(gpu_wl, g4_wl)))
    results.append(("Shifted wavelength KS",
                    shifted_wavelength_ks(gpu_wl, g4_wl, opts.alpha)))
    results.append(("Shifted time KS",
                    shifted_time_ks(gpu_wl, g4_wl, gpu_time, g4_time, opts.alpha)))

    rule = "=" * 55
    print()
    print(rule)
    print("  SUMMARY")
    print(rule)
    for label, ok in results:
        print(f"  {label:>25s}: {'PASS' if ok else 'FAIL'}")

    print()
    everything_ok = all(ok for _, ok in results)
    print("  *** ALL TESTS PASSED ***" if everything_ok else "  *** SOME TESTS FAILED ***")
    sys.exit(0 if everything_ok else 1)
0157 
0158 
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()