Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-06-26 07:49:48

0001 #!/usr/bin/env python3
0002 """
0003 compare_ab.py : pass/fail comparison of A/B event records
0004 =========================================================
0005 
0006 Validates persisted `record.npy` outputs from paired A/B event directories by
0007 comparing aligned Opticks/G4 records and checking the known
0008 Geant4-version-dependent mismatch indices.
0009 """
0010 
0011 import argparse
0012 import sys
0013 from pathlib import Path
0014 
0015 import numpy as np
0016 
0017 
# Directory two levels above the one containing this script (parents[2] of the
# resolved file path); assumed to be the repository root that holds the
# `optiphy` package -- TODO confirm if this script is ever relocated.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Prepend it to the import path so `optiphy` is importable when the script is
# run directly. This must execute before the import below, which is why that
# import cannot live with the other imports at the top of the file.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from optiphy.geant4_version import detect_geant4_version, geant4_series


# Expected result of compare_records(), keyed by the series label returned by
# geant4_series() (presumably "11.3" for 11.3.x and "11.4+" for later releases;
# verify against optiphy.geant4_version). Each value lists the first-axis record
# indices known to mismatch for that Geant4 series. Keep every list in ascending
# order: compare_records() yields indices in ascending order and main() compares
# the two lists directly, so element order is significant.
EXPECTED_DIFF = {
    "11.3": [14, 22, 32, 34, 40, 81, 85],
    "11.4+": [0, 30, 32, 34, 42, 69, 78, 85, 86],
}
0029 
0030 
def expected_diff_for_version(version):
    """Return the record indices expected to mismatch for a Geant4 version.

    ``version`` is mapped onto a release-series label with ``geant4_series``,
    and that label is looked up in ``EXPECTED_DIFF``.

    Args:
        version: Geant4 version value as produced by ``detect_geant4_version``
            (its exact type is defined by ``optiphy.geant4_version``).

    Returns:
        A new list of first-axis record indices that ``compare_records`` is
        expected to report. It is a copy, so callers may modify it without
        altering the module-level ``EXPECTED_DIFF`` table.

    Raises:
        KeyError: if the series has no entry in ``EXPECTED_DIFF``. The message
            names the version, its series and the series that are known.
    """
    series = geant4_series(version)
    try:
        indices = EXPECTED_DIFF[series]
    except KeyError:
        # Keep KeyError for compatibility, but make the failure self-explanatory
        # instead of the bare "KeyError: '<series>'" the plain lookup produced.
        known = ", ".join(sorted(EXPECTED_DIFF))
        raise KeyError(
            f"no expected mismatch indices for Geant4 version {version!r} "
            f"(series {series!r}); known series: {known}"
        ) from None
    return list(indices)
0033 
0034 
def load_records(base, a_record, b_record):
    """Load the A-side (Opticks) and B-side (Geant4) record arrays.

    Both paths are resolved relative to ``base`` and must point at existing
    regular files; the A side is checked first.

    Args:
        base: directory ``Path`` that the record paths are joined onto.
        a_record: path of the Opticks ``record.npy``, relative to ``base``.
        b_record: path of the Geant4 ``record.npy``, relative to ``base``.

    Returns:
        Tuple ``(a, b)`` of the arrays loaded with ``np.load``.

    Raises:
        FileNotFoundError: if either record path is not an existing file.
    """
    opticks_path = base / a_record
    geant4_path = base / b_record

    required = (
        (opticks_path, f"Missing Opticks record file: {opticks_path}"),
        (geant4_path, f"Missing Geant4 record file: {geant4_path}"),
    )
    for candidate, complaint in required:
        if not candidate.is_file():
            raise FileNotFoundError(complaint)

    return np.load(opticks_path), np.load(geant4_path)
0045 
0046 
def compare_records(a, b, *, atol=1e-5, rtol=0.0):
    """Return the first-axis indices whose step-aligned records differ.

    For this geometry the two records are offset by one step along axis 1, so
    steps ``1:`` of ``a`` are compared with steps ``:-1`` of ``b`` (every field,
    including time).

    Args:
        a: A-side (Opticks) record array with at least two dimensions.
        b: B-side (Geant4) record array with the same shape as ``a``.
        atol: absolute tolerance for ``np.isclose`` (default ``1e-5``).
        rtol: relative tolerance for ``np.isclose`` (default ``0.0``).

    Returns:
        List of Python ints, in ascending order, of the indices along axis 0
        for which any aligned element differs beyond tolerance. NaN never
        compares equal, so an entry containing NaN is always reported.

    Raises:
        AssertionError: if ``a`` and ``b`` have different shapes.
    """
    if a.shape != b.shape:
        raise AssertionError(f"Shape mismatch: {a.shape} != {b.shape}")

    # Geant4 and Opticks record one-step-shifted sequences for this geometry,
    # so compare aligned slices directly, including time.
    a_cmp = a[:, 1:]
    b_cmp = b[:, :-1]

    # Vectorised form of a per-entry np.allclose: test closeness elementwise,
    # then require every element to match across all non-leading axes. Reducing
    # over an explicit axis tuple (rather than reshaping) also handles arrays
    # with zero entries or zero steps.
    close = np.isclose(a_cmp, b_cmp, rtol=rtol, atol=atol)
    entry_ok = close.all(axis=tuple(range(1, close.ndim)))
    return np.flatnonzero(~entry_ok).tolist()
0061 
0062 
def _build_parser():
    """Return the command-line parser for this comparison script."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--base", default=".", help="directory containing the A/B event outputs")
    cli.add_argument(
        "--a-record",
        default="ALL0_no_opticks_event_name/A000/record.npy",
        help="path to the A-side record.npy relative to --base",
    )
    cli.add_argument(
        "--b-record",
        default="ALL0_no_opticks_event_name/B000/f000/record.npy",
        help="path to the B-side record.npy relative to --base",
    )
    return cli


def main():
    """Compare the A/B records and fail if the mismatch set is unexpected.

    Parses ``--base``/``--a-record``/``--b-record``, detects the Geant4
    version, loads both ``record.npy`` files, prints a ``KEY=value`` summary,
    and raises ``AssertionError`` when the mismatching indices differ from
    the expected list for the detected Geant4 version.
    """
    opts = _build_parser().parse_args()

    base_dir = Path(opts.base).resolve()
    g4_version = detect_geant4_version()
    wanted = expected_diff_for_version(g4_version)

    rec_a, rec_b = load_records(base_dir, Path(opts.a_record), Path(opts.b_record))
    found = compare_records(rec_a, rec_b)

    # The summary is printed before the verdict so a failing run still shows
    # both the expected and the actual index lists.
    print(f"BASE={base_dir}")
    print(f"A_SHAPE={rec_a.shape}")
    print(f"B_SHAPE={rec_b.shape}")
    print(f"GEANT4_VERSION={g4_version}")
    print(f"EXPECTED_DIFF={wanted}")
    print(f"ACTUAL_DIFF={found}")

    if found == wanted:
        return
    raise AssertionError(f"Mismatch indices differ: expected {wanted}, got {found}")


if __name__ == "__main__":
    main()