Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-03-29 08:37:26

0001 import argparse
0002 import os
0003 import sys
0004 from pathlib import Path
0005 
0006 import ast
0007 import numpy as np
0008 import pandas as pd
0009 from matplotlib import pyplot as plt
0010 from matplotlib.backends.backend_pdf import PdfPages
0011 
# Make the repository root (the parent of this script's directory) importable.
repo_root = Path(__file__).resolve().parents[1]
if str(repo_root) not in sys.path:
    sys.path[:0] = [str(repo_root)]

# Shared matplotlib look: inward ticks, axis labels at the axis ends, font sizes.
SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 16
plt.rcParams.update(
    {
        "figure.figsize": [8.0, 6.0],
        "ytick.direction": "in",
        "xtick.direction": "in",
        "xaxis.labellocation": "right",
        "yaxis.labellocation": "top",
        "font.size": SMALL_SIZE,
        "axes.titlesize": MEDIUM_SIZE,
        "axes.labelsize": MEDIUM_SIZE,
        "xtick.labelsize": MEDIUM_SIZE,
        "ytick.labelsize": MEDIUM_SIZE,
        "legend.fontsize": SMALL_SIZE,
        "figure.titlesize": BIGGER_SIZE,
    }
)

# Degrees -> radians conversion factor.
deg2rad = np.pi / 180.0
0033 
0034 
def theta2eta(xx, inverse=0):
    """Convert polar angle theta [rad] to pseudorapidity eta, or the reverse.

    Args:
        xx: scalar or array-like of angles (forward) or eta values (inverse).
        inverse: when equal to 1, interpret ``xx`` as eta and return theta
            via ``theta = 2 * arctan(e**(-eta))``; otherwise return
            ``eta = -ln(tan(theta / 2))``.

    Returns:
        ndarray (0-d for scalar input) with the converted values.
    """
    arr = np.array(xx)
    if inverse != 1:
        return -np.log(np.tan(arr / 2.0))
    return 2 * np.arctan(np.e ** (-arr))
0040 
0041 
0042 def _load_pwg_table(pwg_file):
0043     candidates = []
0044     if pwg_file:
0045         candidates.append(pwg_file)
0046     candidates.append(str(Path.cwd() / "pwg_requirements.txt"))
0047     candidates.append(str(Path(__file__).with_name("pwg_requirements.txt")))
0048     candidates.append(
0049         "/global/cfs/cdirs/m3763/shujie/worksim/snippets/Tracking/PerformanceStudy/SinglePion/7etabins/"
0050         "pwg_requirements.txt"
0051     )
0052     for path in candidates:
0053         if path and os.path.exists(path):
0054             return pd.read_csv(path, sep=r"\s+", skiprows=1)
0055     return None
0056 
0057 
def pwg_value(pwg, varname, eta, mom):
    """Evaluate the PWG resolution requirement for the eta bin containing ``eta``.

    The requirement row is the first row of ``pwg`` with
    ``eta_lo <= eta < eta_hi``; its ``<varname>_par1`` (a) and
    ``<varname>_par2`` (b) columns parameterize the curve:

    * ``"dca"``: ``x = mom * sin(theta(eta))`` (pT), result
      ``sqrt((a / 1000 / x)**2 + (b / 1000)**2)``. The /1000 presumably
      converts the table's units to the mm used on the plot axis -- confirm
      against the requirements file.
    * ``"dp"``: ``sqrt((a * mom)**2 + b**2)``.

    Args:
        pwg: requirements table (DataFrame with ``eta_lo``, ``eta_hi`` and
            ``<varname>_par1`` / ``<varname>_par2`` columns), or None.
        varname: resolution variable; only "dca" and "dp" have requirements.
        eta: pseudorapidity used to select the row and, for "dca", sin(theta).
        mom: momentum, scalar or ndarray (evaluated elementwise).

    Returns:
        Requirement value(s) broadcast like ``mom``, or None when ``pwg`` is
        None or ``varname`` is not "dca"/"dp".

    Raises:
        IndexError: if no row of ``pwg`` contains ``eta``.

    NOTE(review): for "dca", ``x`` is the same object as ``mom``, and
    ``x *= ...`` scales a float ndarray argument IN PLACE. The caller's array
    therefore holds pT afterwards. An int ndarray instead raises numpy's
    in-place casting error. Callers that reuse the array they passed in see
    the scaled values, so check them before changing this to ``x = x * ...``.
    """
    if pwg is None:
        return None
    if varname not in ("dca", "dp"):
        return None
    # Row(s) whose half-open [eta_lo, eta_hi) interval contains eta.
    cond = (pwg.eta_lo <= eta) & (pwg.eta_hi > eta)
    # First matching row; indexing [0] raises IndexError when nothing matches.
    a = pwg[varname + "_par1"].values[cond][0]
    b = pwg[varname + "_par2"].values[cond][0]
    x = mom
    if varname == "dca":
        # Momentum -> pT (in place for float ndarrays; see NOTE above).
        x *= np.sin(theta2eta(eta, 1))
        return np.sqrt((a / 1000.0 / x) ** 2 + (b / 1000.0) ** 2)
    if varname == "dp":
        return np.sqrt((a * x) ** 2 + b ** 2)
    # Not reached: varname was restricted to "dca"/"dp" above.
    return None
0073 
0074 
0075 def _eta_center(df):
0076     return 0.5 * (df["eta_lo"] + df["eta_hi"])
0077 
0078 
0079 def _parse_list(value):
0080     # Parse CSV list columns and coerce single NaN entries to 0.
0081     if value is None:
0082         return None
0083     parsed = value
0084     if not isinstance(value, (list, tuple, np.ndarray)):
0085         try:
0086             text = str(value)
0087             text = text.replace("nan", "None").replace("NaN", "None")
0088             parsed = ast.literal_eval(text)
0089         except Exception:
0090             return None
0091     if not isinstance(parsed, (list, tuple, np.ndarray)):
0092         return None
0093     cleaned = [0.0 if v is None or (isinstance(v, float) and np.isnan(v)) else v for v in parsed]
0094     arr = np.array(cleaned, dtype=float)
0095     return np.nan_to_num(arr, nan=0.0)
0096 
0097 
def _eff_x_positions(eta_bins, n_points):
    """Return x positions for ``n_points`` per-bin efficiency values, or None.

    ``eta_bins`` is used as-is when it has ``n_points`` entries. With
    ``n_points + 1`` entries it is treated as bin edges and converted to bin
    centres. Any other length cannot be matched and gives None.
    """
    bins = np.asarray(eta_bins, dtype=float)
    if len(bins) == n_points:
        return bins
    if len(bins) == n_points + 1:
        return 0.5 * (bins[:-1] + bins[1:])
    return None


def _average_eff_lists(dft):
    """Average the per-bin ``eff_values`` lists over all rows of ``dft``.

    Zero entries count as "no measurement" and are excluded from each bin's
    average; a bin that is zero in every row stays 0.

    Rows are skipped when ``eff_values`` or ``eff_errors`` cannot be parsed,
    or when their length differs from the first usable row. Skipping avoids a
    broadcasting error.

    Returns:
        The averaged ndarray, or None when no row is usable.
    """
    total = None
    counts = None
    for values, errors in zip(dft["eff_values"], dft["eff_errors"]):
        yy = _parse_list(values)
        # eff_errors is not drawn; it only has to parse for the row to count.
        if yy is None or _parse_list(errors) is None:
            continue
        if total is None:
            total = np.zeros(yy.shape, dtype=float)
            counts = np.zeros(yy.shape, dtype=int)
        elif yy.shape != total.shape:
            continue
        total += yy
        counts += yy != 0
    if total is None:
        return None
    return total / np.maximum(counts, 1)


def plot_eff(df, out_path=None, setting="", title=None):
    """Plot tracking efficiency vs eta, one curve per generated momentum.

    If ``eff_values``/``eff_errors`` columns exist and an ``eff_bins`` entry
    parses, per-bin efficiency lists are averaged over rows (see
    ``_average_eff_lists``) and drawn at the ``eff_bins`` positions.
    Otherwise the scalar ``eff`` column is averaged per eta-bin centre.

    Args:
        df: performance table.
        out_path: if given, save the figure there and close it.
        setting: if non-empty, keep only rows whose ``setting`` equals it.
        title: optional figure title.

    Returns:
        The matplotlib Figure when ``out_path`` is falsy, otherwise None.
        Also None (with a message) when no rows remain after filtering.
    """
    if setting:
        df = df[df["setting"] == setting]

    df = df.dropna(subset=["eta_lo", "eta_hi", "mom_gev", "eff"]).copy()
    if df.empty:
        print("No data to plot for efficiency.")
        return None

    mom_list = [0.5, 1, 2, 5, 10, 15]
    line_styles = [(0, (3, 3, 1, 2)), '--', '-.', ':', '-', (0, (3, 1, 1, 1))]

    # Shared eta positions for the per-bin lists: first parsable eff_bins entry.
    eta_bins = None
    if "eff_bins" in df.columns:
        eta_bins = next(
            (parsed for parsed in map(_parse_list, df["eff_bins"]) if parsed is not None),
            None,
        )
    use_lists = eta_bins is not None and {"eff_values", "eff_errors"} <= set(df.columns)

    plt.figure()
    for mom, style in zip(mom_list, line_styles):
        dft = df[np.isclose(df["mom_gev"], mom)].copy()
        if dft.empty:
            continue
        label = f"{mom} GeV"
        if use_lists:
            ys = _average_eff_lists(dft)
            if ys is None:
                continue
            xs = _eff_x_positions(eta_bins, len(ys))
            if xs is None:
                print(
                    f"Skipping {label}: {len(ys)} efficiency values do not match "
                    f"{len(eta_bins)} eta bins."
                )
                continue
            plt.plot(xs, ys, ls=style, label=label)
        else:
            dft["eta_center"] = _eta_center(dft)
            curve = (
                dft.groupby("eta_center")["eff"].mean().reset_index().sort_values("eta_center")
            )
            plt.plot(curve["eta_center"], curve["eff"], ls=style, label=label)

    plt.legend(frameon=0, loc="upper left", ncol=2, fontsize=13)
    plt.ylim(0.0, 1.4)
    plt.xlim(-4, 4)
    plt.xlabel("$\\eta$")
    plt.ylabel("efficiency")
    plt.grid()
    if title:
        plt.title(title)
    if out_path:
        plt.savefig(out_path)
        plt.close()
        return None
    return plt.gcf()
0163 
0164 
# Per-variable settings for plot_resol, keyed by ``varname``:
#   y_hi          y-axis upper limit
#   yname, xname  axis labels
#   x_hi          per-panel x-axis upper limit, one entry per PWG eta bin
#   sig_col, err_col  sigma and sigma-error columns
#   scale         factor applied to sigma and its error
_RESOL_CONFIG = {
    "th": {
        "y_hi": 0.01,
        "yname": r"$\theta$ [rad]",
        "xname": "momentum [GeV]",
        "x_hi": [20, 20, 20, 20, 20, 20, 20],
        "sig_col": "resol_theta_sigma",
        "err_col": "resol_theta_sigma_err",
        "scale": 1.0 / 1000.0,  # presumably mrad in the table -> rad on the axis
    },
    "ph": {
        "y_hi": 0.025,
        "yname": r"$\phi$ [rad]",
        "xname": "momentum [GeV]",
        "x_hi": [20, 20, 20, 20, 20, 20, 20],
        "sig_col": "resol_phi_sigma",
        "err_col": "resol_phi_sigma_err",
        "scale": 1.0 / 1000.0,  # presumably mrad in the table -> rad on the axis
    },
    "dp": {
        "y_hi": 12,
        "yname": r"$\delta p/p$ [%]",
        "xname": "momentum [GeV/c]",
        "x_hi": [20, 20, 20, 20, 20, 20, 20],
        "sig_col": "resol_dp_sigma",
        "err_col": "resol_dp_sigma_err",
        "scale": 1.0,
    },
    "dca": {
        "y_hi": 1,
        "yname": "DCA$_r$ [mm]",
        "xname": "pT [GeV]",
        "x_hi": [1.5, 2.5, 5, 10, 5, 2.5, 1.5],
        "sig_col": "resol_dca_sigma",
        "err_col": "resol_dca_sigma_err",
        "scale": 1.0,
    },
}


def _pwg_curve(pwg, varname, eta, xgrid):
    """Evaluate the PWG requirement at the plotted x values ``xgrid``.

    For "dca" the plotted x axis is pT. ``xgrid`` is therefore converted to
    momentum, p = pT / sin(theta(eta)), before calling ``pwg_value``, which
    multiplies by sin(theta(eta)) again. The curve then covers the whole pT
    grid.

    ``pwg_value`` may scale its ``mom`` argument in place, so it always gets a
    fresh array and ``xgrid`` is never modified.

    Returns:
        The curve values, or None when there is no table or no requirement
        for ``varname``.
    """
    if pwg is None:
        return None
    if varname == "dca":
        mom = xgrid / np.sin(theta2eta(eta, 1))
    else:
        mom = np.array(xgrid, dtype=float)
    return pwg_value(pwg, varname, eta, mom)


def _draw_resol_points(ax, dft, mask, cfg, x_factor, color, marker):
    """Draw the rows of ``dft`` selected by ``mask`` as error-bar points on ``ax``."""
    ax.errorbar(
        dft.mom[mask] * x_factor,
        dft[cfg["sig_col"]][mask] * cfg["scale"],
        yerr=dft[cfg["err_col"]][mask] * cfg["scale"],
        color=color,
        ls="none",
        marker=marker,
    )


def plot_resol(df, varname, out_path, setting1="", setting2="", pwg=None, save=True):
    """Plot a resolution vs momentum (pT for "dca") in the legacy 7-eta-bin layout.

    Each of the first seven panels shows one PWG eta bin:
      - blue points: ``setting1``, or all settings when ``setting1`` is empty;
      - red points: ``setting2``, if given;
      - dashed PWG requirement curve, when ``pwg`` provides one for ``varname``.

    The eighth panel holds the legend. Only table rows with a positive sigma
    are drawn. Rows are averaged per (setting, momentum) within an eta bin.

    Args:
        df: performance table.
        varname: one of "th", "ph", "dp", "dca".
        out_path: file to save to when ``save`` is true.
        setting1, setting2: setting names to compare.
        pwg: PWG requirements table or None.
        save: save to ``out_path`` and close (True), or return the figure.

    Returns:
        The Figure when ``save`` is false, else None. Also None, after
        printing an error, for an unknown ``varname``.
    """
    cfg = _RESOL_CONFIG.get(varname)
    if cfg is None:
        print("ERROR(plot_resol): please use a valid varname: th, ph, dp, dca")
        return None

    df = df.dropna(subset=["eta_lo", "eta_hi", "mom_gev"]).copy()
    df["name"] = df["setting"].fillna("")
    df["mom"] = df["mom_gev"]

    # Legacy 7-eta-bin layout matching the PWG requirement bins.
    eta_lo_pwg = [-3.5, -3.0, -2.5, -1.0, 1.0, 2.5, 3.0]
    eta_hi_pwg = [-3.0, -2.5, -1.0, 1.0, 2.5, 3.0, 3.5]
    sig_col, err_col = cfg["sig_col"], cfg["err_col"]
    y_hi = cfg["y_hi"]
    xline = np.arange(0.001, 20, 0.001)

    fig, axs = plt.subplots(2, 4, figsize=(16, 8))
    axs = axs.flat

    for ii, (e_lo, e_hi) in enumerate(zip(eta_lo_pwg, eta_hi_pwg)):
        ax = axs[ii]
        x_hi = cfg["x_hi"][ii]

        # Table rows whose eta range lies inside this PWG bin (0.01 tolerance).
        in_bin = (df.eta_lo >= e_lo - 0.01) & (df.eta_hi <= e_hi + 0.01)
        dft = df[in_bin]
        if dft.empty:
            continue

        dft = dft[["name", "mom", sig_col, err_col]].groupby(["name", "mom"]).mean().reset_index()
        positive = dft[sig_col] > 0

        # DCA is plotted vs pT, using sin(theta) at the bin centre.
        x_factor = np.sin(theta2eta((e_lo + e_hi) / 2, 1)) if varname == "dca" else 1.0

        mask1 = ((dft.name == setting1) & positive) if setting1 else positive
        _draw_resol_points(ax, dft, mask1, cfg, x_factor, "b", "o")
        if setting2:
            _draw_resol_points(ax, dft, (dft.name == setting2) & positive, cfg, x_factor, "r", "x")

        y_pwg = _pwg_curve(pwg, varname, e_lo, xline)
        if y_pwg is not None:
            ax.plot(xline, y_pwg, "k--", zorder=10)

        ax.set_ylim(-y_hi * 0.05, y_hi)
        ax.set_xlim(0, x_hi * 1.05)
        ax.text(x_hi * 0.1, y_hi * 0.9, f"{e_lo}<$\\eta$<{e_hi}", fontsize=14)

    # Eighth panel: legend only. Dummy handles sit far below the 0..1 y range.
    legend_ax = axs[7]
    legend_ax.axis("off")
    dummy_x, dummy_y = [0.0], [-100000.0]
    if _pwg_curve(pwg, varname, eta_lo_pwg[0], np.ones(1)) is not None:
        legend_ax.plot(dummy_x, dummy_y, "k--", label="PWG Requirements")
    legend_ax.errorbar(
        dummy_x, dummy_y, ls="none", marker="o", color="blue", label=setting1 or "Simulation"
    )
    if setting2:
        legend_ax.errorbar(dummy_x, dummy_y, ls="none", marker="x", color="r", label=setting2)
    legend_ax.set_ylim(0, 1)
    legend_ax.legend(frameon=0, loc="upper left", fontsize=16)

    for idx in (4, 5, 6):
        axs[idx].set_xlabel(cfg["xname"])
    for idx in (0, 4):
        axs[idx].set_ylabel(cfg["yname"])

    plt.tight_layout()
    if save:
        plt.savefig(out_path)
        plt.close()
        return None
    return fig
0296 
0297 
def _save_pdf_page(pdf, fig):
    """Append ``fig`` as a page of the open PdfPages ``pdf`` and close it; ignore None."""
    if fig is None:
        return
    pdf.savefig(fig)
    plt.close(fig)


def main(argv=None):
    """Command-line entry point: write efficiency and resolution plots into one PDF.

    The output file is ``<out-dir>/tracking_single_particle_perf_<tag>.pdf``,
    where tag is ``--setting`` (or "all"), suffixed with ``_<setting2>`` when
    ``--setting2`` is given.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]`` (default None).
    """
    resol_vars = ["dp", "th", "ph", "dca"]
    parser = argparse.ArgumentParser(description="Plot single-particle performance summaries.")
    parser.add_argument("--table", default="performance_table.csv")
    parser.add_argument("--out-dir", default="./plots")
    parser.add_argument("--setting", default="")
    parser.add_argument("--setting2", default="")
    # Efficiency is always plotted; this only selects the resolution pages.
    # choices makes argparse reject unknown names up front.
    parser.add_argument("--plot-resol", nargs="*", default=list(resol_vars), choices=resol_vars)
    parser.add_argument("--pwg-file", default="")

    args = parser.parse_args(argv)

    df = pd.read_csv(args.table)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    pwg = _load_pwg_table(args.pwg_file)

    tag = args.setting or "all"
    if args.setting2:
        tag = f"{tag}_{args.setting2}"

    combined_path = str(out_dir / f"tracking_single_particle_perf_{tag}.pdf")
    print(f"Plotting efficiency and resolutions -> {combined_path}")

    # One efficiency page per requested setting ("" selects all settings).
    eff_pages = [(args.setting, args.setting or "all")]
    if args.setting2:
        eff_pages.append((args.setting2, args.setting2))

    with PdfPages(combined_path) as pdf:
        for setting, title in eff_pages:
            _save_pdf_page(pdf, plot_eff(df, out_path=None, setting=setting, title=title))

        for varname in args.plot_resol:
            fig = plot_resol(
                df,
                varname,
                combined_path,
                setting1=args.setting,
                setting2=args.setting2,
                pwg=pwg,
                save=False,
            )
            _save_pdf_page(pdf, fig)
0348 
0349 
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()