Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-03-29 08:37:26

0001 import argparse
0002 import csv
0003 import os
0004 import sys
0005 from datetime import datetime
0006 from pathlib import Path
0007 import glob
0008 
0009 import matplotlib.pyplot as plt
0010 from matplotlib.backends.backend_pdf import PdfPages
0011 import numpy as np
0012 
# Make the repository root importable so the sibling modules below
# (epic_analysis_base, single_particle_performance) resolve when this file
# is executed directly as a script rather than as an installed package.
repo_root = Path(__file__).resolve().parents[1]
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))
0016 
0017 import epic_analysis_base as ana
0018 from single_particle_performance.io import build_track_mc_df, load_tree, parse_filename_metadata
0019 
0020 
def _ensure_dir(path):
    """Create *path* (including missing parents); a no-op if it already exists."""
    target = Path(path)
    target.mkdir(parents=True, exist_ok=True)
0023 
0024 
def _plot_efficiency(pdf, mc_primary, df_matched, tag, eta_range=None):
    """Add a per-file diagnostic page: generated/matched eta spectra on top,
    per-bin efficiency below.  (Not consumed by run_plot.py.)
    """
    bins = np.arange(-4, 4, 0.1)
    centers, eff, err = compute_binned_efficiency(mc_primary, df_matched, bins)

    fig, (ax_top, ax_bot) = plt.subplots(
        2, 1, figsize=(7, 6), sharex=True, gridspec_kw={"height_ratios": [3, 2]}
    )

    # Top panel: raw eta spectra for denominator (Gen) and numerator (matched).
    for frame, color, label in (
        (mc_primary, "black", "Gen"),
        (df_matched, "tab:blue", "Reco matched"),
    ):
        ax_top.hist(frame["eta_mc"].to_numpy(), bins=bins, histtype="step", color=color, label=label)
    ax_top.set_ylabel("Entries")
    ax_top.legend(frameon=False)
    ax_top.grid(True, alpha=0.3)

    # Bottom panel: binned efficiency with binomial errors.
    ax_bot.errorbar(centers, eff, yerr=err, fmt="o", ms=3, lw=1, color="tab:blue", ecolor="gray", capsize=2)
    ax_bot.axhline(1.0, color="gray", lw=1, ls="--")
    ax_bot.set_xlabel(r"$\eta$")
    ax_bot.set_ylabel("Efficiency")
    ax_bot.set_xlim(bins[0], bins[-1])
    ax_bot.set_ylim(0, 1.05)
    ax_bot.grid(True, alpha=0.3)
    if eta_range:
        # Shade the analysed eta window for orientation.
        ax_bot.axvspan(eta_range[0], eta_range[1], color="tab:blue", alpha=0.1)

    fig.suptitle(tag)
    plt.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)
0051 
0052 
def compute_binned_efficiency(mc_primary, df_matched, bins):
    """Histogram generated and matched tracks in eta and form per-bin efficiency.

    Returns (centers, eff, err): bin centers, efficiency per bin, and the
    binomial uncertainty per bin.  Bins with no generated entries get
    eff = err = 0 instead of a division-by-zero result.
    """
    gen_counts, edges = np.histogram(mc_primary["eta_mc"].to_numpy(), bins=bins)
    rec_counts, _ = np.histogram(df_matched["eta_mc"].to_numpy(), bins=bins)
    centers = 0.5 * (edges[1:] + edges[:-1])

    has_gen = gen_counts > 0
    eff = np.zeros(gen_counts.shape, dtype=float)
    np.divide(rec_counts, gen_counts, out=eff, where=has_gen)
    err = np.zeros_like(eff)
    err[has_gen] = np.sqrt(eff[has_gen] * (1.0 - eff[has_gen]) / gen_counts[has_gen])
    return centers, eff, err
0067 
0068 
def _plot_distributions(pdf, df, tag):
    """Add one page with Gaussian-fitted pull and resolution histograms."""
    panels = [
        ("pull_theta", "Pull distribution(theta)"),
        ("pull_phi", "Pull distribution(phi)"),
        ("pull_qoverp", "Pull distribution(q/p)"),
        ("resol_theta", "Resolution (theta [mrad])"),
        ("resol_phi", "Resolution (phi [mrad])"),
        ("resol_dp", "Resolution (dp/p [%])"),
    ]
    if "resol_dca" in df.columns:
        panels.append(("resol_dca", "Resolution (DCA$_r$ [mm])"))

    n_cols = 3
    n_rows = int(np.ceil(len(panels) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    flat = axes.flatten()

    for ax, (column, title) in zip(flat, panels):
        mean, sigma, sigma_err = ana.hist_gaus(df[column], ax=ax, bins=101)
        # Annotate the fitted width only when the fit produced finite numbers.
        if np.isfinite(sigma) and np.isfinite(sigma_err):
            ax.set_title(f"{title}\n$\\sigma$={sigma:.3g} ± {sigma_err:.3g}")
        else:
            ax.set_title(title)
        ax.set_ylabel("Entries")
        ax.grid(True, alpha=0.3)

    # Blank out any unused cells in the grid.
    for ax in flat[len(panels):]:
        ax.axis("off")

    fig.suptitle(tag)
    plt.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)
0102 
0103 
def _write_row(table_path, row):
    """Append one summary row to the cumulative CSV table at *table_path*.

    The header is (re)written when the file is missing OR exists but is
    empty — the original `os.path.exists` check alone left a truncated or
    freshly-touched table headerless, misaligning later DictReader passes.
    """
    columns = _ordered_columns(row)
    needs_header = not os.path.exists(table_path) or os.path.getsize(table_path) == 0
    with open(table_path, "a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        if needs_header:
            writer.writeheader()
        writer.writerow(_format_row(row))
0113 
0114 
def _format_row(row, sig_digits=5):
    """Render float values (and lists/tuples of numbers) with *sig_digits*
    significant digits; everything else passes through unchanged.
    """
    def _fmt(value):
        if isinstance(value, float):
            return f"{value:.{sig_digits}g}"
        if isinstance(value, (list, tuple)):
            return "[" + ", ".join(f"{item:.{sig_digits}g}" for item in value) + "]"
        return value

    return {key: _fmt(value) for key, value in row.items()}
0125 
0126 
def _ordered_columns(row):
    """Return the row's keys with "timestamp" moved to the second slot and
    "source_file" moved to the end (when present); all other keys keep their
    insertion order.
    """
    ordered = list(row)
    if "timestamp" in ordered:
        ordered.remove("timestamp")
        ordered.insert(1, "timestamp")
    if "source_file" in ordered:
        ordered.remove("source_file")
        ordered.append("source_file")
    return ordered
0136 
0137 
def select_by_eta_mom(df, eta_range=None, mom_gev=None, mom_tol=0.02):
    """Select rows whose MC-truth eta/momentum fall inside the analysis window.

    The eta cut is half-open [lo, hi); the momentum window is the closed
    interval mom_gev * (1 ± mom_tol).  Either cut may be omitted by passing
    None, in which case all rows pass that cut.
    """
    keep = np.ones(len(df), dtype=bool)
    if eta_range is not None:
        lo_eta, hi_eta = eta_range
        keep &= df["eta_mc"].between(lo_eta, hi_eta, inclusive="left")
    if mom_gev is not None:
        p_lo = mom_gev * (1.0 - mom_tol)
        p_hi = mom_gev * (1.0 + mom_tol)
        keep &= (df["p_mc"] >= p_lo) & (df["p_mc"] <= p_hi)
    return df[keep]
0149 
0150 
def compute_metrics(df):
    """Derive pull and resolution columns, in the same units as
    plot_single_resol.py (angles in mrad, dp/p in percent).

    Rows with non-positive covariance diagonals or non-finite reconstructed /
    truth kinematics are dropped before anything is derived.
    """
    positive_cov = (df["cov_5"] > 0) & (df["cov_9"] > 0) & (df["cov_14"] > 0)
    finite_rec = np.isfinite(df["theta_rec"]) & np.isfinite(df["phi_rec"]) & np.isfinite(df["qOverP_rec"])
    finite_true = np.isfinite(df["theta_mc"]) & np.isfinite(df["phi_mc"]) & np.isfinite(df["qOverP_true"])
    df = df[positive_cov & finite_rec & finite_true].copy()

    d_theta = df["theta_rec"] - df["theta_mc"]
    d_phi = df["phi_rec"] - df["phi_mc"]
    df["pull_theta"] = d_theta / np.sqrt(df["cov_9"])
    df["pull_phi"] = d_phi / np.sqrt(df["cov_5"])
    df["pull_qoverp"] = (np.abs(df["qOverP_rec"]) - np.abs(df["qOverP_true"])) / np.sqrt(df["cov_14"])

    # Angular resolutions in mrad, momentum resolution in percent.
    df["resol_theta"] = d_theta * 1000.0
    df["resol_phi"] = d_phi * 1000.0
    df["p_rec"] = 1.0 / np.abs(df["qOverP_rec"])
    df["resol_dp"] = (df["p_rec"] - df["p_mc"]) / df["p_mc"] * 100.0

    # DCA is only available when the input tree carried "loc.a".
    if "loc.a" in df.columns:
        df["resol_dca"] = df["loc.a"]

    return df
0171 
0172 
def _apply_kinematic_cuts(df, eta_range, mom_gev, mom_tol):
    """Apply the shared eta ([lo, hi)) and momentum (closed ±mom_tol) cuts.

    Factored out of compute_efficiency, where the identical selection was
    duplicated for the generated and the matched frames.
    """
    if eta_range is not None:
        eta_lo, eta_hi = eta_range
        df = df[df["eta_mc"].between(eta_lo, eta_hi, inclusive="left")]
    if mom_gev is not None:
        lo = mom_gev * (1.0 - mom_tol)
        hi = mom_gev * (1.0 + mom_tol)
        df = df[(df["p_mc"] >= lo) & (df["p_mc"] <= hi)]
    return df


def compute_efficiency(mc, df_matched, eta_range=None, mom_gev=None, mom_tol=0.02, pid=211):
    """Compute inclusive efficiency for the selected eta/momentum region.

    Parameters
    ----------
    mc : DataFrame of MC particles; needs generatorStatus, eta_mc, p_mc and
        optionally PDG (the pid cut is skipped when that column is absent).
    df_matched : DataFrame of reco-to-MC matched tracks; needs eta_mc, p_mc,
        entry and mc_index.
    eta_range, mom_gev, mom_tol : kinematic selection (either cut optional).
    pid : PDG id of the generated particle; pass None to disable the cut.

    Returns
    -------
    (eff, eff_err, n_gen, n_rec); eff and eff_err are 0.0 when no generated
    particles survive the selection.
    """
    # NOTE(review): generatorStatus == 1 is taken to mark primaries here —
    # confirm against the MC producer's convention.
    mc_primary = mc[mc["generatorStatus"] == 1].copy()
    if pid is not None and "PDG" in mc_primary.columns:
        mc_primary = mc_primary[mc_primary["PDG"] == pid]
    mc_primary = _apply_kinematic_cuts(mc_primary, eta_range, mom_gev, mom_tol)
    n_gen = len(mc_primary)

    matched = _apply_kinematic_cuts(df_matched.copy(), eta_range, mom_gev, mom_tol)
    # A single MC particle can be matched by several tracks; count it once.
    matched = matched.drop_duplicates(subset=["entry", "mc_index"])
    n_rec = len(matched)

    if n_gen <= 0:
        return 0.0, 0.0, n_gen, n_rec

    eff = n_rec / n_gen
    eff_err = np.sqrt(eff * (1.0 - eff) / n_gen)
    return eff, eff_err, n_gen, n_rec
0207 
0208 
def summarize_metrics(df, columns, bins=101):
    """Gaussian-fit each requested column for the CSV summary.

    Missing columns yield NaN mean/sigma/sigma_err entries so every row of
    the table carries the same fields.
    """
    summary = {}
    for name in columns:
        if name not in df.columns:
            summary[name] = {"mean": np.nan, "sigma": np.nan, "sigma_err": np.nan}
            continue
        fitted_mean, fitted_sigma, fitted_sigma_err = ana.hist_gaus(df[name], ax=None, bins=bins)
        summary[name] = {"mean": fitted_mean, "sigma": fitted_sigma, "sigma_err": fitted_sigma_err}
    return summary
0219 
0220 
def run_single_file(
    fname,
    s3_dir="",
    eta_range=None,
    mom_gev=None,
    mom_tol=0.02,
    pid=211,
    track_params="CentralCKFTrackParameters",
    assoc_name="CentralCKFTrackAssociations",
    track_collection="CentralCKFTracks",
    out_dir=".",
    plots_dir="./plots",
    table_path="./performance_table.csv",
    entry_stop=None,
):
    """Process one reconstructed ROOT file end-to-end: derive metrics, append
    one summary row to the CSV table, and write a two-page diagnostic PDF.

    Parameters
    ----------
    fname : path to a "rec_*" ROOT file; files without that prefix are
        skipped (returns None after a message).
    s3_dir : optional remote directory prefix forwarded to load_tree.
    eta_range, mom_gev : selection window; when None they are taken from the
        filename metadata, and ValueError is raised if still unset.
    mom_tol : fractional half-width of the momentum window.
    pid : PDG id used to select primary MC particles (skipped when the MC
        frame has no "PDG" column).
    track_params, assoc_name, assoc/track_collection : branch names forwarded
        to build_track_mc_df.
    out_dir, plots_dir, table_path : output locations (directories created).
    entry_stop : optional limit on the number of tree entries read.

    Returns
    -------
    dict — the row appended to the CSV, or None when the file was skipped.
    """
    # Enforce the rec_ naming convention so metadata parsing is consistent.
    if not os.path.basename(fname).startswith("rec_"):
        print(f"Skipping non-rec file: {fname}")
        return None
    meta = parse_filename_metadata(fname)
    tag = meta["tag"]

    # Fall back to the eta/momentum encoded in the filename when no explicit
    # override was supplied.
    if eta_range is None and meta["eta_lo"] is not None and meta["eta_hi"] is not None:
        eta_range = (meta["eta_lo"], meta["eta_hi"])
    if mom_gev is None and meta["mom_gev"] is not None:
        mom_gev = meta["mom_gev"]

    if eta_range is None or mom_gev is None:
        raise ValueError("eta_range and mom_gev must be provided or found in the filename tag")

    tree = load_tree(fname, s3_dir=s3_dir, entry_stop=entry_stop)
    df, mc = build_track_mc_df(
        tree,
        track_params=track_params,
        assoc_name=assoc_name,
        track_collection=track_collection,
    )

    # Derive pull/resolution columns, then apply the kinematic selection.
    df = compute_metrics(df)
    df_sel = select_by_eta_mom(df, eta_range=eta_range, mom_gev=mom_gev, mom_tol=mom_tol)

    # MC denominator for the *binned* efficiency: momentum-cut but NOT
    # eta-cut, so the efficiency curve spans the full eta axis.
    # NOTE(review): generatorStatus == 1 is assumed to mark primaries —
    # confirm against the producer's convention.
    mc_primary = mc[mc["generatorStatus"] == 1].copy()
    if pid is not None and "PDG" in mc_primary.columns:
        mc_primary = mc_primary[mc_primary["PDG"] == pid]
    if mom_gev is not None:
        lo = mom_gev * (1.0 - mom_tol)
        hi = mom_gev * (1.0 + mom_tol)
        mc_primary = mc_primary[(mc_primary["p_mc"] >= lo) & (mc_primary["p_mc"] <= hi)]

    # Inclusive efficiency applies eta *and* momentum cuts internally
    # (unlike mc_primary above), on the unselected df/mc frames.
    eff, eff_err, n_gen, n_rec = compute_efficiency(
        mc,
        df,
        eta_range=eta_range,
        mom_gev=mom_gev,
        mom_tol=mom_tol,
        pid=pid,
    )

    # Bin-by-bin efficiency is stored for run_plot.py aggregation.
    eff_bins, eff_vals, eff_errs = compute_binned_efficiency(mc_primary, df_sel, np.arange(-4, 4, 0.1))

    # Columns to Gaussian-fit for the summary row (resol_dca may be absent).
    metric_columns = [
        "pull_theta",
        "pull_phi",
        "pull_qoverp",
        "resol_theta",
        "resol_phi",
        "resol_dp",
        "resol_dca",
    ]
    summary = summarize_metrics(df_sel, metric_columns)

    # Flat row appended to the cumulative CSV table.
    row = {
        "tag": tag,
        "source_file": os.path.abspath(fname),
        "setting": meta.get("setting"),
        "eta_lo": eta_range[0],
        "eta_hi": eta_range[1],
        "mom_gev": mom_gev,
        "pid": pid,
        "mom_tol": mom_tol,
        "n_gen": n_gen,
        "n_rec": n_rec,
        "eff": eff,
        "eff_err": eff_err,
        "eff_bins": eff_bins.tolist(),
        "eff_values": eff_vals.tolist(),
        "eff_errors": eff_errs.tolist(),
        "timestamp": datetime.now().isoformat(timespec="seconds"),
    }

    # One mean/sigma/sigma_err triple per fitted metric column.
    for col, stats in summary.items():
        row[f"{col}_mean"] = stats["mean"]
        row[f"{col}_sigma"] = stats["sigma"]
        row[f"{col}_sigma_err"] = stats["sigma_err"]

    _ensure_dir(out_dir)
    _ensure_dir(plots_dir)
    _write_row(table_path, row)

    # Two-page PDF: efficiency panel, then the pull/resolution grid.
    pdf_path = os.path.join(plots_dir, f"{tag}_performance.pdf")
    with PdfPages(pdf_path) as pdf:
        _plot_efficiency(pdf, mc_primary, df_sel, tag, eta_range=eta_range)
        _plot_distributions(pdf, df_sel, tag)

    return row
0327 
0328 
def main():
    """Command-line driver: gather input files, optionally skip already-tabled
    ones, and run the per-file study on the rest.
    """
    parser = argparse.ArgumentParser(description="Run single-particle performance study on ROOT files.")
    parser.add_argument("inputs", nargs="*", help="Input ROOT files")
    parser.add_argument("--pattern", default=None, help="Glob pattern for input ROOT files")
    parser.add_argument("--s3-dir", default="", help="Remote s3/xrootd directory")
    parser.add_argument("--eta", nargs=2, type=float, metavar=("ETA_LO", "ETA_HI"), help="Override eta range")
    parser.add_argument("--mom", type=float, help="Override momentum in GeV")
    parser.add_argument("--mom-tol", type=float, default=0.02, help="Momentum tolerance as fraction")
    parser.add_argument("--pid", type=int, default=211, help="PDG id (default: pi+ 211)")
    parser.add_argument("--track-params", default="CentralCKFTrackParameters")
    parser.add_argument("--assoc-name", default="CentralCKFTrackAssociations")
    parser.add_argument("--track-collection", default="CentralCKFTracks")
    parser.add_argument("--out-dir", default=".")
    parser.add_argument("--plots-dir", default="./plots")
    parser.add_argument("--table-path", default="./performance_table.csv")
    parser.add_argument("--entry-stop", type=int, default=None)
    parser.add_argument("--skip-existing", action="store_true", help="Skip files already in the CSV table")
    args = parser.parse_args()

    ana.configure_analysis_environment()

    # Positional files first, then anything the glob pattern adds (sorted).
    files = list(args.inputs)
    if args.pattern:
        files += sorted(glob.glob(args.pattern))
    if not files:
        raise SystemExit("No input files provided. Use positional files or --pattern.")

    # Drop inputs whose tag already appears in the existing CSV table.
    if args.skip_existing and os.path.exists(args.table_path):
        with open(args.table_path, newline="") as handle:
            seen_tags = {row.get("tag") for row in csv.DictReader(handle) if row.get("tag")}
        if seen_tags:
            files = [f for f in files if parse_filename_metadata(f)["tag"] not in seen_tags]

    if not files:
        print("No new files to process after applying --skip-existing.")
        return

    eta_override = tuple(args.eta) if args.eta else None
    for fname in files:
        run_single_file(
            fname,
            s3_dir=args.s3_dir,
            eta_range=eta_override,
            mom_gev=args.mom,
            mom_tol=args.mom_tol,
            pid=args.pid,
            track_params=args.track_params,
            assoc_name=args.assoc_name,
            track_collection=args.track_collection,
            out_dir=args.out_dir,
            plots_dir=args.plots_dir,
            table_path=args.table_path,
            entry_stop=args.entry_stop,
        )
0389 
0390 
# Run only when executed as a script; importing this module has no side effects.
if __name__ == "__main__":
    main()