Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:18

0001 # This file is part of the ACTS project.
0002 #
0003 # Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 #
0005 # This Source Code Form is subject to the terms of the Mozilla Public
0006 # License, v. 2.0. If a copy of the MPL was not distributed with this
0007 # file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 # detray includes
0010 import plotting
0011 
0012 # python includes
0013 import numpy as np
0014 import pandas as pd
0015 import math
0016 import os
0017 import sys
0018 
# Default legend placement for plots that do not position their legend explicitly
lgd_loc: str = "upper right"
0021 
0022 """ Read the detector scan data from files and prepare data frames """
0023 
0024 
0025 def read_detector_scan_data(intersection_file, track_param_file, logging):
0026     if intersection_file:
0027         inters_df = pd.read_csv(intersection_file, float_precision="round_trip")
0028         trk_param_df = pd.read_csv(track_param_file, float_precision="round_trip")
0029         scan_df = pd.concat([inters_df, trk_param_df], axis=1)
0030 
0031         logging.debug(scan_df)
0032     else:
0033         logging.warning("Could not find ray scan data: " + intersection_file)
0034         scan_df = pd.DataFrame({})
0035 
0036     return scan_df
0037 
0038 
0039 """ Read intersection data """
0040 
0041 
0042 def read_intersection_data(file, logging):
0043     if file:
0044         # Preserve floating point precision
0045         df = pd.read_csv(file, float_precision="round_trip")
0046         logging.debug(df)
0047 
0048         return df
0049     else:
0050         logging.error("Could not find intersection data file: " + file)
0051         sys.exit(1)
0052 
0053 
0054 """ Plot the intersection points of the detector with the rays - xy view """
0055 
0056 
0057 def plot_intersection_points_xy(
0058     opts, df, detector, scan_type, plot_factory, out_format="png"
0059 ):
0060 
0061     n_rays = np.max(df["track_id"]) + 1
0062     tracks = "rays" if scan_type == "ray" else "helices"
0063 
0064     # Reduce data to the requested z-range (50mm tolerance)
0065     min_z = opts.z_range[0]
0066     max_z = opts.z_range[1]
0067     assert min_z < max_z, "xy plotting range: min z must be smaller that max z"
0068     sensitive_range = lambda data: (
0069         (data["z"] > min_z) & (data["z"] < max_z) & (data["type"] == 1)
0070     )
0071     portal_range = lambda data: (
0072         (data["z"] > min_z) & (data["z"] < max_z) & (data["type"] == 0)
0073     )
0074     passive_range = lambda data: (
0075         (data["z"] > min_z) & (data["z"] < max_z) & (data["type"] == 2)
0076     )
0077 
0078     senstive_x, senstive_y = plotting.filter_data(
0079         data=df, filter=sensitive_range, variables=["x", "y"]
0080     )
0081 
0082     # Plot the xy coordinates of the filtered intersections points
0083     lgd_ops = plotting.legend_options(
0084         loc="upper center",
0085         ncol=4,
0086         colspacing=0.4,
0087         handletextpad=0.005,
0088         horiz_anchor=0.5,
0089         vert_anchor=1.095,
0090     )
0091 
0092     hist_data = plot_factory.scatter(
0093         figsize=(10, 10),
0094         x=senstive_x,
0095         y=senstive_y,
0096         x_axis=plotting.axis_options(label=r"$x\,\mathrm{[mm]}$"),
0097         y_axis=plotting.axis_options(label=r"$y\,\mathrm{[mm]}$"),
0098         label="sensitives",
0099         color="C5",
0100         show_stats=lambda x, y: f"{n_rays} {tracks}",
0101         lgd_ops=lgd_ops,
0102     )
0103 
0104     # Portal surfaces
0105     if not opts.hide_portals:
0106         portal_x, portal_y = plotting.filter_data(
0107             data=df, filter=portal_range, variables=["x", "y"]
0108         )
0109 
0110         plot_factory.highlight_region(hist_data, portal_x, portal_y, "C0", "portals")
0111 
0112     # Passive surfaces
0113     if not opts.hide_passives:
0114         passive_x, passive_y = plotting.filter_data(
0115             data=df, filter=passive_range, variables=["x", "y"]
0116         )
0117 
0118         plot_factory.highlight_region(hist_data, passive_x, passive_y, "C2", "passives")
0119 
0120     # Set aspect ratio
0121     hist_data.axes.set_aspect("equal")
0122 
0123     detector_name = detector.replace(" ", "_")
0124     plot_factory.write_plot(
0125         hist_data, f"{detector_name}_{scan_type}_scan_xy", out_format
0126     )
0127 
0128 
0129 """ Plot the intersection points of the detector with the rays - rz view """
0130 
0131 
0132 def plot_intersection_points_rz(
0133     opts, df, detector, scan_type, plot_factory, out_format="png"
0134 ):
0135 
0136     n_rays = np.max(df["track_id"]) + 1
0137     tracks = "rays" if scan_type == "ray" else "helices"
0138 
0139     # Reduce data to the requested z-range
0140     sensitive_range = lambda data: (data["type"] == 1)
0141     portal_range = lambda data: (data["type"] == 0)
0142     passive_range = lambda data: (data["type"] == 2)
0143 
0144     sensitive_x, sensitive_y, sensitive_z = plotting.filter_data(
0145         data=df, filter=sensitive_range, variables=["x", "y", "z"]
0146     )
0147 
0148     # Plot the xy coordinates of the filtered intersections points
0149     lgd_ops = plotting.legend_options(
0150         loc="upper center",
0151         ncol=4,
0152         colspacing=0.8,
0153         handletextpad=0.005,
0154         horiz_anchor=0.5,
0155         vert_anchor=1.165,
0156     )
0157 
0158     hist_data = plot_factory.scatter(
0159         figsize=(12, 6),
0160         x=sensitive_z,
0161         y=np.hypot(sensitive_x, sensitive_y),
0162         x_axis=plotting.axis_options(label=r"$z\,\mathrm{[mm]}$"),
0163         y_axis=plotting.axis_options(label=r"$r\,\mathrm{[mm]}$"),
0164         label="sensitives",
0165         color="C5",
0166         show_stats=lambda x, y: f"{n_rays} {tracks}",
0167         lgd_ops=lgd_ops,
0168     )
0169 
0170     # Portal surfaces
0171     if not opts.hide_portals:
0172         portal_x, portal_y, portal_z = plotting.filter_data(
0173             data=df, filter=portal_range, variables=["x", "y", "z"]
0174         )
0175 
0176         plot_factory.highlight_region(
0177             hist_data, portal_z, np.hypot(portal_x, portal_y), "C0", "portals"
0178         )
0179 
0180     # Passive surfaces
0181     if not opts.hide_passives:
0182         passive_x, passive_y, passive_z = plotting.filter_data(
0183             data=df, filter=passive_range, variables=["x", "y", "z"]
0184         )
0185 
0186         plot_factory.highlight_region(
0187             hist_data, passive_z, np.hypot(passive_x, passive_y), "C2", "passives"
0188         )
0189 
0190     detector_name = detector.replace(" ", "_")
0191     plot_factory.write_plot(
0192         hist_data, f"{detector_name}_{scan_type}_scan_rz", out_format
0193     )
0194 
0195 
0196 """ Plot the intersection local position residual for the given variable """
0197 
0198 
0199 def plot_intersection_pos_res(
0200     opts, detector, plot_factory, scan_type, df1, label1, df2, label2, var, out_format
0201 ):
0202 
0203     tracks = "rays" if scan_type == "ray" else "helices"
0204 
0205     # Filter the relevant data from the frame (sensitive = 1, hole = 15)
0206     is_sensitive = lambda data_frame: (
0207         (data_frame["type"] == 1) | (data_frame["type"] == 15)
0208     )
0209 
0210     var_truth, track_ids = plotting.filter_data(
0211         data=df1, filter=is_sensitive, variables=[var, "track_id"]
0212     )
0213     var_nav = plotting.filter_data(data=df2, filter=is_sensitive, variables=[var])
0214 
0215     assert len(var_truth) == len(var_nav)
0216     res = var_truth - var_nav
0217 
0218     # Remove outliers (happens when comparing a hole with a valid intersection)
0219     filter_res = np.absolute(res) < opts.outlier
0220     filtered_res = res[filter_res]
0221 
0222     u_out = o_out = int(0)
0223     if not np.all(filter_res == True):
0224         print(f"\nRemoved outliers ({var}):")
0225         for i, r in enumerate(res):
0226             if math.fabs(r) > opts.outlier:
0227                 print(f"track {track_ids[i]}: {var_truth[i]} - {var_nav[i]} = {r}")
0228 
0229                 if r < 0.0:
0230                     u_out = u_out + 1
0231                 else:
0232                     o_out = o_out + 1
0233 
0234     lgd_ops = plotting.legend_options(
0235         loc=lgd_loc,
0236         ncol=4,
0237         colspacing=0.01,
0238         handletextpad=0.0005,
0239         horiz_anchor=1.02,
0240         vert_anchor=1.28,
0241     )
0242 
0243     # Plot the residuals as a histogram and fit a gaussian to it
0244     var_for_label = var.replace("_", "\,")
0245     hist_data = plot_factory.hist1D(
0246         x=filtered_res,
0247         figsize=(9, 9),
0248         bins=100,
0249         x_axis=plotting.axis_options(
0250             label=r"$\mathrm{res}" + rf"~{var_for_label}" + r"\,\mathrm{[mm]}$"
0251         ),
0252         lgd_ops=lgd_ops,
0253         u_outlier=u_out,
0254         o_outlier=o_out,
0255     )
0256 
0257     mu, sig = plot_factory.fit_gaussian(hist_data)
0258     if mu is None or sig is None:
0259         print(rf"WARNING: fit failed (res ({tracks}): {label1} - {label2} )")
0260 
0261     detector_name = detector.replace(" ", "_")
0262     l1 = label1.replace(" ", "_").replace("(", "").replace(")", "")
0263     l2 = label2.replace(" ", "_").replace("(", "").replace(")", "")
0264 
0265     plot_factory.write_plot(
0266         hist_data, f"{detector_name}_{scan_type}_intr_res_{var}_{l1}_{l2}", out_format
0267     )