Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:18

0001 # This file is part of the ACTS project.
0002 #
0003 # Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 #
0005 # This Source Code Form is subject to the terms of the Mozilla Public
0006 # License, v. 2.0. If a copy of the MPL was not distributed with this
0007 # file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 # detray imports
0010 from impl import (
0011     read_scan_data,
0012     read_navigation_intersection_data,
0013     read_navigation_track_data,
0014 )
0015 from impl import (
0016     plot_detector_scan_data,
0017     plot_navigation_intersection_data,
0018     plot_navigation_track_data,
0019 )
0020 from impl import plot_track_params
0021 from options import (
0022     common_options,
0023     detector_io_options,
0024     random_track_generator_options,
0025     propagation_options,
0026     plotting_options,
0027 )
0028 from options import (
0029     parse_common_options,
0030     parse_detector_io_options,
0031     parse_plotting_options,
0032 )
0033 from plotting import pyplot_factory as plt_factory
0034 from utils import read_detector_name, get_p_range
0035 from utils import add_track_generator_args, add_propagation_args, add_detector_io_args
0036 
0037 # python imports
0038 import argparse
0039 import os
0040 import subprocess
0041 import sys
0042 import json
0043 
0044 
def _build_arg_parser(descr):
    """Create the command-line parser for the navigation validation script.

    The parser inherits the shared detray option groups (common, detector IO,
    random track generator, propagation and plotting options) and adds the
    options specific to this script.

    Args:
        descr: program description shown in the ``--help`` output.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parent_parsers = [
        common_options(descr),
        detector_io_options(),
        random_track_generator_options(),
        propagation_options(),
        plotting_options(),
    ]

    parser = argparse.ArgumentParser(description=descr, parents=parent_parsers)

    parser.add_argument(
        "--bindir",
        "-bin",
        help=("Directory containing the validation executables"),
        default="./bin",
        type=str,
    )
    parser.add_argument(
        "--datadir",
        "-data",
        help=("Directory containing the data files"),
        default="./validation_data",
        type=str,
    )
    parser.add_argument(
        "--cuda",
        help=("Run the CUDA navigation validation."),
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--sycl",
        help=("Run the SYCL navigation validation."),
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--overlaps_tol",
        "-ot",
        help=("Tolerance for considering surfaces to be overlapping [mm]"),
        default=0.0001,
        type=float,
    )
    parser.add_argument(
        "--z_range",
        "-zrng",
        nargs=2,
        help=("z range for the xy-view [mm]."),
        default=[-50, 50],
        type=float,
    )
    parser.add_argument(
        "--hide_portals",
        help=("Hide portal surfaces in plots."),
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--hide_passives",
        help=("Hide passive surfaces in plots."),
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--outlier",
        "-out",
        help=("Threshold for outliers in residual plots [mm]."),
        default=1,
        type=float,
    )

    return parser


def _normalize_dir(path):
    """Return *path* normalized with ``os.path.normpath``.

    Trailing slashes and redundant separators are removed, but - unlike
    ``str.strip("/")`` - the leading slash of an absolute path is kept, so
    ``/opt/detray/bin/`` stays ``/opt/detray/bin`` instead of silently turning
    into the relative path ``opt/detray/bin``. An empty string becomes ``.``.
    """
    return os.path.normpath(path)


def _run_executable(cmd, label, logging):
    """Run one validation executable (without a shell) and report success.

    A non-zero exit status is logged as an error but is deliberately not
    fatal, so the plots can still be produced to diagnose the failure.

    Args:
        cmd: command line as a list, executable first.
        label: name used in log messages, e.g. "CPU" or "CUDA".
        logging: logger object returned by ``parse_common_options``.

    Returns:
        True if the process exited with status 0, False otherwise.
    """
    logging.debug(f"Running {label} validation")
    result = subprocess.run(cmd)
    if result.returncode != 0:
        logging.error(
            f"{label} navigation validation exited with code {result.returncode}"
        )
        return False
    return True


def _plot_validation_data(args, logging, datadir, out_dir, out_format, use_cuda):
    """Read the validation output from *datadir* and produce all plots.

    Args:
        args: parsed command-line arguments (forwarded to the plot helpers).
        logging: logger object returned by ``parse_common_options``.
        datadir: directory the validation executables wrote their data to.
        out_dir: output directory passed to the plot factory.
        out_format: plot file format forwarded to the plot helpers.
        use_cuda: also read and plot the CUDA navigation data.

    Exits the process with status 1 if *datadir* does not exist.
    """
    logging.info("Generating data plots...\n")

    det_name = read_detector_name(args.geometry_file, logging)
    logging.debug("Detector: " + det_name)

    # Check the data path (should have been created when running the validation)
    if not os.path.isdir(datadir):
        logging.error(f"Data directory was not found! ({args.datadir})")
        sys.exit(1)

    plot_factory = plt_factory(out_dir, logging)

    label_cpu = "navigation (CPU)"
    label_cuda = "navigation (CUDA)"

    # Read and plot the truth (detector scan) data; ray first, then helix
    p_min, p_max = get_p_range(args, logging)
    ray_scan_df, helix_scan_df = read_scan_data(
        logging, datadir, det_name, p_min, p_max
    )
    scan_dfs = {"ray": ray_scan_df, "helix": helix_scan_df}

    for trk_type, scan_df in scan_dfs.items():
        plot_detector_scan_data(
            args, det_name, plot_factory, trk_type, scan_df, out_format
        )

    # Read the recorded intersection data
    (
        ray_nav_intr_df,
        ray_nav_intr_truth_df,
        ray_nav_intr_cuda_df,
        helix_nav_intr_df,
        helix_nav_intr_truth_df,
        helix_nav_intr_cuda_df,
    ) = read_navigation_intersection_data(
        logging, datadir, det_name, p_min, p_max, use_cuda
    )

    # (truth, host, device) intersection data per track type
    intersection_dfs = {
        "ray": (ray_nav_intr_truth_df, ray_nav_intr_df, ray_nav_intr_cuda_df),
        "helix": (
            helix_nav_intr_truth_df,
            helix_nav_intr_df,
            helix_nav_intr_cuda_df,
        ),
    }

    # Truth vs. Host
    for trk_type, (truth_df, cpu_df, _) in intersection_dfs.items():
        plot_navigation_intersection_data(
            args,
            det_name,
            plot_factory,
            trk_type,
            truth_df,
            cpu_df,
            label_cpu,
            out_format,
        )

    # Truth vs. Device
    if use_cuda:
        for trk_type, (truth_df, _, cuda_df) in intersection_dfs.items():
            plot_navigation_intersection_data(
                args,
                det_name,
                plot_factory,
                trk_type,
                truth_df,
                cuda_df,
                label_cuda,
                out_format,
            )

    # Plot distributions of track parameter values. Only take the initial
    # track parameters from the generator: keep one row per track_id.
    for trk_type in ("helix", "ray"):
        initial_trk_df = scan_dfs[trk_type].drop_duplicates(subset=["track_id"])
        plot_track_params(
            args, det_name, trk_type, plot_factory, out_format, initial_trk_df
        )

    # Read the recorded track data
    (
        ray_nav_df,
        ray_truth_df,
        ray_nav_cuda_df,
        helix_nav_df,
        helix_truth_df,
        helix_nav_cuda_df,
    ) = read_navigation_track_data(logging, datadir, det_name, p_min, p_max, use_cuda)

    # (truth, host, device) track data per track type
    track_dfs = {
        "ray": (ray_truth_df, ray_nav_df, ray_nav_cuda_df),
        "helix": (helix_truth_df, helix_nav_df, helix_nav_cuda_df),
    }

    # Truth vs. Host
    for trk_type, (truth_df, cpu_df, _) in track_dfs.items():
        plot_navigation_track_data(
            args,
            det_name,
            plot_factory,
            trk_type,
            truth_df,
            "truth",
            cpu_df,
            label_cpu,
            out_format,
        )

    if use_cuda:
        # Truth vs. Device
        for trk_type, (truth_df, _, cuda_df) in track_dfs.items():
            plot_navigation_track_data(
                args,
                det_name,
                plot_factory,
                trk_type,
                truth_df,
                "truth",
                cuda_df,
                label_cuda,
                out_format,
            )

        # Host vs. Device
        for trk_type, (_, cpu_df, cuda_df) in track_dfs.items():
            plot_navigation_track_data(
                args,
                det_name,
                plot_factory,
                trk_type,
                cpu_df,
                label_cpu,
                cuda_df,
                label_cuda,
                out_format,
            )


def __main__():
    """Run the detray navigation validation and plot its results.

    1. Parse the command line (shared detray options plus the options defined
       in ``_build_arg_parser``).
    2. Run the CPU validation executable with ``--write_scan_data`` and, if
       ``--cuda`` is given and the executable exists, the CUDA validation.
    3. Read the data written to ``--datadir`` and produce the detector scan,
       intersection, track parameter and track plots.

    Exits with status 1 if the validation binaries or the data directory
    cannot be found.
    """
    # ---------------------------------------------------------------arg parsing

    descr = "Detray Navigation Validation"
    args = _build_arg_parser(descr).parse_args()

    logging = parse_common_options(args, descr)
    parse_detector_io_options(args, logging)
    _, out_dir, out_format = parse_plotting_options(args, logging)

    # IO path for data files (normalized, so absolute paths stay absolute)
    datadir = _normalize_dir(args.datadir)

    # Check bin path
    bindir = _normalize_dir(args.bindir)
    cpu_validation = os.path.join(bindir, "detray_navigation_validation")
    cuda_validation = os.path.join(bindir, "detray_navigation_validation_cuda")

    if not os.path.isdir(bindir) or not os.path.isfile(cpu_validation):
        logging.error(f"Navigation validation binaries were not found! ({args.bindir})")
        sys.exit(1)

    # -----------------------------------------------------------------------run

    # Pass on the options for the validation tools
    args_list = [
        "--data_dir",
        datadir,
        "--overlaps_tol",
        str(args.overlaps_tol),
    ]

    # Add parsed options to argument list
    add_detector_io_args(args_list, args)
    add_track_generator_args(args_list, args)
    add_propagation_args(args_list, args)

    logging.debug(args_list)

    # Run the host validation and produce the truth data
    _run_executable([cpu_validation, "--write_scan_data"] + args_list, "CPU", logging)

    # Run the device validation (if it has been built). CUDA data is only read
    # and plotted when the CUDA executable actually ran in this invocation, so
    # missing or stale output from an earlier run is never used.
    use_cuda = False
    if args.cuda:
        if os.path.isfile(cuda_validation):
            _run_executable([cuda_validation] + args_list, "CUDA", logging)
            use_cuda = True
        else:
            logging.error("Could not find CUDA navigation validation executable")

    if args.sycl:
        logging.error("SYCL validation is not implemented")

    # ----------------------------------------------------------------------plot

    _plot_validation_data(args, logging, datadir, out_dir, out_format, use_cuda)
0360 
0361 # ------------------------------------------------------------------------------
0362 
if __name__ == "__main__":
    # Run the validation and plotting only when executed as a script, not on import.
    __main__()
0365 
0366 # ------------------------------------------------------------------------------