File indexing completed on 2026-03-29 08:37:26
0001 '''
0002 Utility functions for epic tracking analysis.
0003 See also epic_analysis_podio.py
0004 Shujie Li, Aug 2025
0005 '''
0006
0007
0008 import numpy as np
0009 import pandas as pd
0010 import seaborn as sns
0011
0012 import awkward as ak
0013 import uproot as ur
0014
0015
0016
0017 import time
0018 from fnmatch import fnmatch
0019 import types
0020 from particle import Particle
0021
0022 from matplotlib.backends.backend_pdf import PdfPages
0023 from matplotlib.colors import LogNorm
0024 from matplotlib import pyplot as plt
0025 from matplotlib.gridspec import GridSpec
0026 import matplotlib.ticker as ticker
0027 import matplotlib.cm as cm
0028 import matplotlib as mpl
0029
def configure_analysis_environment(
    apply_pandas=True,
    apply_matplotlib=True,
    apply_sns=True
):
    """Install pandas/matplotlib/seaborn defaults for interactive analysis.

    Each flag independently enables one library's configuration so callers
    can opt out of the pieces they do not use.
    """
    if apply_pandas:
        pd.options.display.max_rows = 200
        pd.options.display.min_rows = 20
        pd.options.display.max_columns = 100
    if apply_matplotlib:
        # Global figure and tick appearance.
        rc = plt.rcParams
        rc['figure.figsize'] = [8.0, 6.0]
        rc['ytick.direction'] = 'in'
        rc['xtick.direction'] = 'in'
        rc['xaxis.labellocation'] = 'right'
        rc['yaxis.labellocation'] = 'top'
        # Font sizes: small for base/legend, medium for labels/ticks,
        # big for figure titles.
        small_size, medium_size, bigger_size = 10, 12, 20
        plt.rc('font', size=small_size)
        plt.rc('axes', titlesize=medium_size, labelsize=medium_size)
        plt.rc('xtick', labelsize=medium_size)
        plt.rc('ytick', labelsize=medium_size)
        plt.rc('legend', fontsize=small_size)
        plt.rc('figure', titlesize=bigger_size)
    if apply_sns:
        sns.set_theme(
            style='whitegrid',
            context='notebook',
            palette='bright',
            font_scale=1.0,
            rc={'figure.figsize': (6, 4)},
        )
0064
0065
# Degrees-to-radians conversion factor.
deg2rad = np.pi/180.0


# Map a generator-status code to a human-readable source label.  Codes come
# in pairs (x001/x002) belonging to the same source; compare get_part(),
# which selects particles with generatorStatus % 1000 in (1, 2).
status_to_source = {
    1: "DIS", 2: "DIS",
    2001: "SR", 2002: "SR",
    3001: "Bremstrahlung", 3002: "Bremstrahlung",
    4001: "Coulomb", 4002: "Coulomb",
    5001: "Touschek", 5002: "Touschek",
    6001: "Proton beam gas", 6002: "Proton beam gas"
}


# Track/vertex selection thresholds used by downstream analysis.
# NOTE(review): units are not recorded here; momentum cuts look like GeV and
# vertex cuts like mm/cm -- confirm against the scripts that apply them.
TRACK_HIT_COUNT_MIN_MIN = 3     # loosest accepted hit count per track
TRACK_HIT_COUNT_MIN = 4         # nominal minimum hit count per track
TRACK_MOM_MIN = 0.3             # minimum total momentum
TRACK_PT_MIN = 0.2              # minimum transverse momentum
TRACK_HIT_FRACTION_MIN = 0.5    # minimum matched-hit fraction (presumably for truth matching -- confirm)
TRACK_HIT_COUNT_GHOST_MAX = 2   # hit count at/below which a track is treated as a ghost (presumed)
VERTEX_CUT_R_MAX = 2            # maximum transverse vertex position
VERTEX_CUT_Z_MAX = 200          # maximum longitudinal vertex position
0087
0088
# Barrel tracker layers.  The four lists are paired index-wise: a radial
# range, a human-readable layer name, and the corresponding sim/rec hit
# collection branch names.
# NOTE(review): ranges are presumably in mm -- confirm against the geometry.
barrel_range = [(30,42),(46,60),(115,130),(250,290),(410,450),(540,600),(620,655),(700,760)]
barrel_name = ["L0","L1","L2","L3","L4","inner MPGD","TOF","outer MPGD"]
name_sim_barrel = ["VertexBarrelHits","VertexBarrelHits","VertexBarrelHits","SiBarrelHits","SiBarrelHits","MPGDBarrelHits","TOFBarrelHits","OuterMPGDBarrelHits"]
name_rec_barrel = ["SiBarrelVertexRecHits","SiBarrelVertexRecHits","SiBarrelVertexRecHits","SiBarrelTrackerRecHits","SiBarrelTrackerRecHits","MPGDBarrelRecHits","TOFBarrelRecHits","OuterMPGDBarrelRecHits"]

# Endcap disks.  Index-wise pairing as above, but the ranges are along z:
# negative z = backward/electron side ("E-" names), positive z = forward/
# hadron side ("H-" names).
disk_range = [(-1210.0, -1190.0), (-1110.0, -1090.0),(-1055.0, -1000.0), (-860.0, -840.0),
              (-660.0, -640.0), (-460.0, -440.0), (-260.0, -240.0), (240.0, 260.0),
              (440.0, 460.0), (690.0, 710.0), (990.0, 1010.0), (1340.0, 1360.0),
              (1480.0, 1500.0), (1600.0, 1620.0), (1840.0, 1860.0), (1865.0, 1885.0)]
disk_name = ["E-MPGD Disk2","E-MPGD Disk 1","E-Si Disk 4","E-Si Disk 3","E-Si Disk 2","E-Si Disk 1","E-Si Disk 0",
             "H-Si Disk 0","H-Si Disk 1","H-Si Disk 2","H-Si Disk 3","H-Si Disk 4","H-MPGD Disk 1","H-MPGD Disk 2", "H-TOF Disk1","H-TOF Disk2"]
name_rec_disk = ["BackwardMPGDEndcapRecHits","BackwardMPGDEndcapRecHits",
                 "SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits",
                 "ForwardMPGDEndcapRecHits","ForwardMPGDEndcapRecHits",
                 "TOFEndcapRecHits","TOFEndcapRecHits"]
name_sim_disk = ["BackwardMPGDEndcapHits","BackwardMPGDEndcapHits",
                 "TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","ForwardMPGDEndcapHits","ForwardMPGDEndcapHits",
                 "TOFEndcapHits","TOFEndcapHits"]
0107
0108
# Bit masks for unpacking the packed 64-bit geometry/cell ID, keyed by field
# name; used by get_geoID().
# NOTE(review): field layout inferred from the mask names (looks like a
# DD4hep-style readout encoding) -- confirm against the detector description.
geo_mask_dict = {
    "approach": 0x0000000ff0000000,
    "boundary": 0x00ff000000000000,
    "extra": 0x00000000000000ff,
    "layer": 0x0000fff000000000,
    "sensitive": 0x000000000fffff00,
    "volume": 0xff00000000000000
}
# Read-only view so callers cannot mutate the mask table.
geo_mask_values = types.MappingProxyType(geo_mask_dict)


# Module-level caches populated at runtime:
# COL_TABLE maps podio collection ID -> collection name (see get_col_table).
COL_TABLE = {}
# CACHED_DATA: scratch cache for loaded data (not populated in this file).
CACHED_DATA = {}
0122
def ak_flat(ak_array):
    """Flatten an awkward array along axis 0 and return a numpy array."""
    flattened = ak.flatten(ak_array, axis=0)
    return ak.to_numpy(flattened)
0125
def ak_df(ak_array):
    """Convert an awkward array into a pandas DataFrame."""
    frame = ak.to_dataframe(ak_array)
    return frame
0128
def ak_hist(ak_array, **kwargs):
    """Plot a matplotlib histogram of a flattened awkward array."""
    values = ak_flat(ak_array)
    return plt.hist(values, **kwargs)
0131
def ak_filter(br, cond, field=None):
    """Apply a boolean mask to *br*; optionally project out a single field."""
    selected = br[cond]
    if field:
        return selected[field]
    return selected
0135
0136
def ak_sns(ak_array, **kwargs):
    """Histogram helper for awkward arrays using seaborn.

    A (x, y) pair of arrays yields a 2D histplot; a single array yields an
    unfilled step histogram.
    """
    is_pair = isinstance(ak_array, (tuple, list)) and len(ak_array) == 2
    if is_pair:
        xs = ak_flat(ak_array[0])
        ys = ak_flat(ak_array[1])
        # Drop 1D-only styling options that the 2D histplot does not take.
        for key in ('element', 'fill'):
            kwargs.pop(key, None)
        return sns.histplot(x=xs, y=ys, **kwargs)
    return sns.histplot(ak_flat(ak_array), element="step", fill=False, **kwargs)
0146
def get_pdg_info(PDG):
    """Look up particle metadata for a PDG code.

    Unknown codes fall back to the proton (2212) for the special code
    9902210, otherwise to an empty Particle (with an error printout).
    """
    try:
        return Particle.from_pdgid(PDG)
    except Exception:
        pass
    if PDG == 9902210:
        return Particle.from_pdgid(2212)
    print(f"ERROR (get_pdg_info): unknown PDG ID {PDG}")
    return Particle.empty()
0156
0157 def get_geoID(geoID, name="layer"):
0158 """Extract geometry ID components"""
0159 kMask = geo_mask_values[name]
0160 shift = 0
0161 mask_temp = kMask
0162 while (mask_temp & 1) == 0:
0163 mask_temp >>= 1
0164 shift += 1
0165 return (geoID & kMask) >> shift
0166
0167
def theta2eta(xx, inverse=0):
    """Convert polar angle theta to pseudorapidity eta, or back.

    Parameters
    ----------
    xx : scalar, list, tuple or numpy array
        Theta value(s) in radians (or eta values when *inverse* is truthy).
        Lists/tuples are converted to numpy arrays; scalars stay scalar.
    inverse : int, optional
        If truthy, treat *xx* as eta and return theta = 2*arctan(exp(-eta)).

    Returns
    -------
    eta = -ln(tan(theta/2)), or theta when *inverse* is set; same shape as
    the (converted) input.
    """
    # isinstance instead of type()==list; also accept tuples (generalization).
    if isinstance(xx, (list, tuple)):
        xx = np.asarray(xx)
    if inverse:
        # np.exp is the idiomatic (and faster) spelling of e**(-x).
        return np.arctan(np.exp(-xx)) * 2
    return -np.log(np.tan(xx / 2.))
0176
def select_string(strings, patterns):
    """Return the entries of *strings* matching any glob pattern.

    Matching is case-insensitive; *patterns* must be a list of fnmatch-style
    wildcard patterns, otherwise a ValueError is raised.
    """
    if not isinstance(patterns, list):
        raise ValueError("The 'patterns' argument must be a list.")

    lowered = [p.lower() for p in patterns]
    selected = []
    for s in strings:
        if any(fnmatch(s.lower(), p) for p in lowered):
            selected.append(s)
    return selected
0184
def read_ur(fname, tname, s3_dir="", entry_start=0, entry_stop=None, return_range=False):
    """Open a ROOT tree with uproot and record the requested entry range.

    Parameters
    ----------
    fname : str
        Path to the file (relative to the campaign dir when s3_dir is given).
    tname : str
        Tree name inside the file.
    s3_dir : str, optional
        When non-empty, prepend the JLab xrootd server plus this campaign path.
    entry_start, entry_stop : int, optional
        Entry range; entry_stop of None or -1 means "to the end".
    return_range : bool, optional
        When True, also return the resolved (entry_start, entry_stop).

    Raises
    ------
    ValueError for an invalid entry range.
    """
    if len(s3_dir) > 0:
        fname = 'root://dtn-eic.jlab.org//volatile/eic/' + s3_dir + fname
    tree = ur.open(fname)[tname]

    n_total = tree.num_entries
    if entry_stop is None or entry_stop == -1:
        entry_stop = n_total
    entry_stop = min(entry_stop, n_total)
    if not (0 <= entry_start < entry_stop):
        raise ValueError(f"read_ur: invalid entry range {entry_start}:{entry_stop}")

    print(
        f"read_ur: read {fname}:{tname}. "
        f"{n_total} events total; using [{entry_start}, {entry_stop})"
    )
    # Stash the resolved range on the tree so downstream readers
    # (get_branch_ak) pick it up as their default.
    tree._entry_start = entry_start
    tree._entry_stop = entry_stop
    if return_range:
        return tree, entry_start, entry_stop
    return tree
0208
def get_col_table(fname, s3_dir="", verb=0):
    """Build the podio collectionID -> name table and cache it in COL_TABLE.

    Two metadata layouts are handled: a direct ID table
    ("events___idTable") or the events___CollectionTypeInfo vectors.
    """
    global COL_TABLE
    meta = read_ur(fname, "podio_metadata", s3_dir)
    if "events___idTable" in meta.keys():
        col_name = np.array(meta["m_names"].array()[0])
        col_id = np.array(meta["m_collectionIDs"].array()[0])
    else:
        col_id = get_branch_df(meta, "events___CollectionTypeInfo/events___CollectionTypeInfo.collectionID")["values"].tolist()
        col_name = get_branch_df(meta, "events___CollectionTypeInfo/events___CollectionTypeInfo.name")["values"].tolist()

    if verb:
        for ii, nn in zip(col_id, col_name):
            print(ii, nn)
    COL_TABLE = dict(zip(col_id, col_name))
    return COL_TABLE
0226
0227
0228
0229
def get_branch_ak(tree, bname="", entry_start=0, entry_stop=-1,
                  fields_subset=None, chunk_size=1000, verb=0):
    """Read one branch of an uproot tree as an awkward array.

    Parameters
    ----------
    tree : uproot tree, possibly annotated by read_ur() with
        _entry_start/_entry_stop attributes.
    bname : str
        Branch name; raises KeyError when absent from the tree.
    entry_start, entry_stop : int
        Entry range.  The defaults (0, -1) fall back to the range stashed
        by read_ur(), or else to the whole tree.
    fields_subset : list of str, optional
        If given, keep only these fields of a record branch.
    chunk_size : int
        Ranges larger than this are read in chunks and concatenated.
    verb : int
        Print progress and timing information when truthy.

    Returns
    -------
    An awkward array; record fields have the leading "<bname>." prefix
    stripped, and subfields whose names contain "[" are dropped.
    """
    if bname not in tree.keys():
        raise KeyError(f"get_branch_ak: can't find branch {bname}")
    if verb:
        print(f"Reading branch: {bname}")
    start_time = time.time()

    # Resolve defaults from the range read_ur() recorded on the tree.
    if entry_start == 0 and hasattr(tree, "_entry_start"):
        entry_start = tree._entry_start
    if entry_stop == -1:
        entry_stop = tree._entry_stop if hasattr(tree, "_entry_stop") else tree.num_entries

    total_entries = entry_stop - entry_start
    if verb:
        print(f"Reading {total_entries} entries")

    # Large ranges are read chunk by chunk to bound peak memory per read.
    if total_entries > chunk_size:
        if verb:
            print(f"Using chunked reading with chunk_size={chunk_size}")
        all_data = []

        for chunk_start in range(entry_start, entry_stop, chunk_size):
            chunk_end = min(chunk_start + chunk_size, entry_stop)
            if verb:
                print(f"  Reading chunk: {chunk_start} to {chunk_end}")

            chunk_data = tree[bname].array(
                library="ak",
                entry_start=chunk_start,
                entry_stop=chunk_end
            )
            all_data.append(chunk_data)

        ak_data = ak.concatenate(all_data)
    else:
        ak_data = tree[bname].array(
            library="ak",
            entry_start=entry_start,
            entry_stop=entry_stop
        )

    read_time = time.time()
    if verb:
        # Best-effort size report.  NOTE(review): ak.nbytes may not exist as
        # a function in all awkward versions (arrays expose .nbytes) -- the
        # try/except deliberately degrades to a time-only printout.
        try:
            size_bytes = ak.nbytes(ak_data)
        except Exception:
            size_bytes = None
        if size_bytes is not None:
            print(f"Awkward read: {read_time - start_time:.2f}s ({size_bytes/1e6:.1f} MB)")
        else:
            print(f"Awkward read: {read_time - start_time:.2f}s")

    # Strip the "<bname>." prefix podio puts on record fields.
    if hasattr(ak_data, 'fields'):
        if not ak_data.fields:
            # Non-record branch: nothing to rename.
            return ak_data
        renamed_fields = {}
        for field in ak_data.fields:
            # Skip array-like subfields (names containing "[") -- they are
            # dropped from the zipped result.
            if "[" in field:
                continue
            if field.startswith(f'{bname}.'):
                new_name = field.replace(f'{bname}.', '')
                renamed_fields[new_name] = ak_data[field]
            else:
                renamed_fields[field] = ak_data[field]

        ak_data = ak.zip(renamed_fields)
        if verb:
            print(f"Renamed {len(renamed_fields)} fields")

    # Optionally keep only the requested fields (silently ignores unknown ones).
    if fields_subset and hasattr(ak_data, 'fields'):
        subset_data = {}
        for field in fields_subset:
            if field in ak_data.fields:
                subset_data[field] = ak_data[field]
        ak_data = ak.zip(subset_data)
        if verb:
            print(f"Extracted subset: {fields_subset}")

    total_time = time.time()
    if verb:
        print(f"Total time: {total_time - start_time:.2f}s")

    return ak_data
0322
0323
def get_branch_df(tree, bname="", entry_start=0, entry_stop=-1,
                  fields_subset=None, chunk_size=1000, verb=0):
    """Read a branch via get_branch_ak and convert it to a pandas DataFrame.

    Falls back to returning the awkward array itself when the conversion
    fails (some irregular structures cannot be tabulated).
    """
    ak_data = get_branch_ak(tree, bname, entry_start, entry_stop,
                            fields_subset, chunk_size)

    if verb:
        print("Converting to DataFrame...")
    t0 = time.time()

    try:
        df = ak.to_dataframe(ak_data).reset_index()
    except Exception as e:
        # Best effort: keep going with the raw awkward array.
        print(f"DataFrame conversion failed: {e}")
        print("Returning awkward array instead")
        return ak_data

    if verb:
        print(f"DataFrame conversion: {time.time() - t0:.2f}s")
        print(f"DataFrame shape: {df.shape}")
    return df
0349
def get_part(tree, entry_start=0, entry_stop=-1, chunk_size=1000, kprimary=1):
    """Read the MCParticles branch and append derived kinematic fields.

    Parameters
    ----------
    tree : uproot tree (see get_branch_ak for range handling).
    entry_start, entry_stop, chunk_size : passed through to get_branch_ak.
    kprimary : int
        When truthy, keep only particles with generatorStatus % 1000 in
        (1, 2), i.e. status codes of the form x001/x002.

    Returns
    -------
    Awkward array with the original particle fields plus mom, theta, phi,
    eta, pt, vertex_r, vertex_dist, endpoint_r and pdg_name.
    """
    ak_data = get_branch_ak(tree, "MCParticles", entry_start, entry_stop,
                            chunk_size=chunk_size)
    if kprimary:
        print("Select all primary particles with generatorStatus==x001 or x002")
        ak_data=ak_data[(ak_data.generatorStatus%1000==1)|(ak_data.generatorStatus%1000==2)]

    # Momentum components and derived kinematics.
    px = ak_data["momentum.x"]
    py = ak_data["momentum.y"]
    pz = ak_data["momentum.z"]

    p_mag = np.sqrt(px**2 + py**2 + pz**2)
    # Guard against division by zero for zero-momentum entries (theta -> NaN).
    safe_p_mag = ak.where(p_mag != 0, p_mag, np.nan)
    theta = np.arccos(pz / safe_p_mag)
    phi = np.arctan2(py, px)
    eta = -np.log(np.tan(theta / 2.0))
    pt = p_mag * np.sin(theta)

    # Production vertex position: transverse radius and 3D distance.
    vx = ak_data["vertex.x"]
    vy = ak_data["vertex.y"]
    vz = ak_data["vertex.z"]
    vertex_r = np.sqrt(vx**2 + vy**2)
    vertex_dist = np.sqrt(vertex_r**2 + vz**2)

    # Transverse radius of the particle endpoint.
    ex = ak_data["endpoint.x"]
    ey = ak_data["endpoint.y"]
    endpoint_r = np.sqrt(ex**2 + ey**2)

    # Resolve PDG codes to particle names once per unique code, then
    # broadcast back to the jagged event structure.
    pdg_codes = ak.to_numpy(ak.flatten(ak_data.PDG))
    unique_pdgs = np.unique(pdg_codes)

    pdg_name_map = {}
    for pdg in unique_pdgs:
        pdg_name_map[pdg] = get_pdg_info(pdg).name

    pdg_names = ak.unflatten(
        np.array([pdg_name_map[pdg] for pdg in pdg_codes]),
        ak.num(ak_data.PDG)
    )

    # Original fields to carry over into the zipped result.
    my_field=['PDG', 'generatorStatus', 'charge', 'time', 'mass',
              'vertex.x', 'vertex.y', 'vertex.z', 'endpoint.x', 'endpoint.y',
              'endpoint.z', 'momentum.x', 'momentum.y', 'momentum.z']
    enhanced_data = ak.zip({
        **{field: ak_data[field] for field in my_field},
        'mom': p_mag,
        'theta': theta,
        'phi': phi,
        'eta': eta,
        'pt': pt,
        'vertex_r': vertex_r,
        'vertex_dist': vertex_dist,
        'endpoint_r': endpoint_r,
        'pdg_name': pdg_names
    })

    return enhanced_data
0419
def get_params(tree, bname="CentralCKFTrackParameters", entry_start=0, entry_stop=-1, chunk_size=1000):
    """Read track parameters and append derived eta, mom and pt fields."""
    params = get_branch_ak(tree, bname, entry_start, entry_stop,
                           chunk_size=chunk_size)
    # Momentum magnitude from the signed curvature q/p.
    mom = abs(1.0 / params.qOverP)
    derived = {
        "eta": theta2eta(params.theta),
        "mom": mom,
        "pt": abs(mom * np.sin(params.theta)),
    }
    for field_name, values in derived.items():
        params = ak.with_field(params, values, field_name)
    return params
0433
def get_branches(trees, bname):
    """Concatenate one branch across several trees into a single DataFrame.

    Adds an 'event_id' column that increments every time the 'entry' value
    changes between consecutive rows, numbering events sequentially across
    all input trees.
    """
    frames = [pd.DataFrame()]
    for tree in trees:
        if bname == "MCParticles":
            frames.append(get_part(tree))
        else:
            frames.append(get_branch_df(tree, bname))
    df = pd.concat(frames, ignore_index=True)

    # A new event starts wherever 'entry' differs from the previous row;
    # the first row always starts event 0.
    changed = df["entry"] != df["entry"].shift()
    df["event_id"] = changed.cumsum() - 1

    return df
0457
0458
def get_collections(tree, bname='', kflatten=1):
    """Load every collection a relation branch points to.

    Parameters
    ----------
    tree : uproot tree.
    bname : str
        Name of the relation branch (must carry a collectionID field/column).
    kflatten : int
        1 -> load collections as awkward arrays (get_branch_ak),
        0 -> load them as DataFrames (get_branch_df).

    Returns
    -------
    dict mapping collectionID -> loaded collection, or 0 when *bname* is not
    a relation.  Requires COL_TABLE to be populated (get_col_table).
    """
    if not COL_TABLE:
        raise RuntimeError("COL_TABLE not populated. Call get_col_table() first.")

    br = get_branch_ak(tree, bname) if kflatten else get_branch_df(tree, bname, chunk_size=1000)

    # Awkward path: record array with a collectionID field.
    if hasattr(br, 'fields') and 'collectionID' in br.fields:
        colID = np.unique(ak.to_numpy(ak.flatten(br.collectionID)))
        collections = {}

        print(f"Loading {len(colID)} collections...")
        for ii in colID:
            if ii in COL_TABLE:
                collections[ii] = get_branch_ak(tree, COL_TABLE[ii]) if kflatten else get_branch_df(tree, COL_TABLE[ii], chunk_size=1000)
            else:
                print(f"Warning: Collection ID {ii} not found in COL_TABLE")

        return collections
    else:
        # DataFrame path (or awkward record without collectionID field):
        # normalize to a DataFrame first.
        if hasattr(br, 'fields'):
            br_df = ak.to_dataframe(br).reset_index()
        else:
            br_df = br

        if "collectionID" in br_df.columns:
            colID = br_df.collectionID.unique()
            collections = {}
            for ii in colID:
                # Unknown IDs are silently skipped on this path.
                if ii in COL_TABLE:
                    collections[ii] = get_branch_df(tree, COL_TABLE[ii], chunk_size=1000)
            return collections
        else:
            print("ERROR(get_collections):", bname, "is not a relation.")
            return 0
0499
def get_relation(tree, b_name, v_name):
    """Resolve a podio vector member or relation of a branch.

    The parent branch stores per-row <v_name>_begin/<v_name>_end index pairs
    that slice into a companion branch "_<b_name>_<v_name>".  Two companion
    layouts are handled:

    * vector member (columns ending in "values") -> returns a DataFrame with
      columns "entry" and "values" (one numpy array per parent row);
    * relation (columns "index" and "collectionID") -> returns a DataFrame
      with columns "index" and "collectionID" (one numpy array per row).

    Returns 0 on any structural error (missing columns, unreadable branch).
    """
    print(f"Processing relation: {b_name}.{v_name}")

    br = get_branch_df(tree, b_name, chunk_size=1000)

    begin_col = v_name + "_begin"
    end_col = v_name + "_end"

    if begin_col not in br.columns or end_col not in br.columns:
        print(f"ERROR(get_relation): {begin_col} or {end_col} not found in {b_name}")
        return 0

    # Positional column indices for fast itertuples access below.
    loc1 = br.columns.get_loc(begin_col)
    loc2 = br.columns.get_loc(end_col)
    in_name = "_" + b_name + "_" + v_name

    try:
        app = get_branch_df(tree, in_name, chunk_size=1000)
    except:
        # NOTE(review): bare except also swallows KeyboardInterrupt.
        print(f"ERROR(get_relation): Cannot read branch {in_name}")
        return 0

    if not isinstance(app, pd.DataFrame):
        # get_branch_df may fall back to an awkward array on conversion failure.
        print(f"ERROR(get_relation): {in_name} is not a valid DataFrame")
        return 0

    # Vector-member layout: (entry, subentry, values).
    if len(app.columns) == 3 and app.columns[2] == "values":
        print("Processing vector values...")
        l_val = []
        l_ind = []

        for row in br.itertuples(index=False):
            # row[0] is assumed to be the 'entry' column of the parent branch.
            l_ind.append(row[0])
            i1, i2 = row[loc1], row[loc2]
            if i1 == i2:
                # Empty slice -> no members for this parent row.
                l_val.append(np.array([]))
            else:
                # Slice this event's values by the begin/end indices.
                event_data = app[app['entry'] == row.entry]
                if not event_data.empty:
                    v_row = np.array(event_data['values'])[i1:i2]
                    l_val.append(v_row)
                else:
                    l_val.append(np.array([]))

        return pd.DataFrame({"entry": l_ind, "values": l_val})

    # Relation layout: rows carry target (index, collectionID) pairs.
    elif len(app.columns) > 1 and 'index' in app.columns and 'collectionID' in app.columns:
        print("Processing relations...")
        l_index = []
        l_id = []

        # Group once for O(1) per-event access inside the loop.
        app_grouped = app.groupby('entry')

        for row in br.itertuples(index=False):
            i1, i2 = row[loc1], row[loc2]
            if i1 == i2:
                l_index.append(np.array([]))
                l_id.append(np.array([]))
            else:
                if row.entry in app_grouped.groups:
                    v_row = app_grouped.get_group(row.entry)
                    # Guard against an end index past this event's data.
                    if len(v_row) >= i2:
                        l1 = np.array(v_row["index"])[i1:i2]
                        l2 = np.array(v_row["collectionID"])[i1:i2]
                        l_index.append(l1)
                        l_id.append(l2)
                    else:
                        l_index.append(np.array([]))
                        l_id.append(np.array([]))
                else:
                    l_index.append(np.array([]))
                    l_id.append(np.array([]))

        return pd.DataFrame({"index": l_index, "collectionID": l_id})

    else:
        print("ERROR(get_relation): Invalid vector or relation member structure")
        print(f"Columns found: {app.columns.tolist()}")
        return 0
0592
def get_branch_relation(tree, branch_name="CentralCKFTrajectories", relation_name="measurements_deprecated", relation_variables=["*"]):
    """Resolve a relation of a branch and append the related rows.

    Parameters
    ----------
    tree : uproot tree.
    branch_name : str
        Parent branch holding the relation.
    relation_name : str
        Relation (or vector member) to resolve via get_relation().
    relation_variables : list of str
        Glob patterns selecting which columns of the target collections to
        keep; [] means "indices only" (no target data is loaded), ["*"]
        keeps everything.
        NOTE(review): a mutable default list -- safe only because it is
        never mutated inside the function.

    Returns
    -------
    (br, df_add): the parent DataFrame exploded per relation element, and a
    DataFrame of the related target rows (aligned by entry/subentry); df_add
    is None for index-only mode or on error, or a Series for vector members.
    Requires COL_TABLE to be populated (get_col_table).
    """
    print(f"Processing branch relation: {branch_name}.{relation_name}")
    if not COL_TABLE:
        raise RuntimeError("COL_TABLE not populated. Call get_col_table() first.")

    br = get_branch_df(tree, branch_name, chunk_size=1000)
    df = get_relation(tree, branch_name, relation_name)

    if not isinstance(df, pd.DataFrame):
        # get_relation returns 0 on failure.
        print("ERROR (get_branch_relation): please provide a valid relation name.")
        return br, None

    # Vector member: hand the values back directly.
    if len(df.columns) == 1 and df.columns[0] == "values":
        return br, df["values"]

    # Attach the per-row target indices/collection IDs to the parent rows.
    br[relation_name + "_index"] = df["index"]
    br[relation_name + "_colID"] = df["collectionID"]

    # Index-only mode: explode to one row per relation element and stop here.
    if len(relation_variables) == 0:
        br = br.explode([relation_name + "_index", relation_name + "_colID"]).reset_index(drop=True)
        return br, None

    # Load every target collection referenced by the companion branch
    # (as DataFrames: kflatten=0).
    in_name = "_" + branch_name + "_" + relation_name
    print("Loading collections...")
    collections = get_collections(tree, in_name, 0)

    if not collections:
        print(f"ERROR: No collections {in_name} found")
        return br, None

    print("Processing relation data...")
    l_relations = []
    l_collection_names = []
    loc1 = br.columns.get_loc(relation_name + "_index")
    loc2 = br.columns.get_loc(relation_name + "_colID")

    # Pre-group each collection by entry for fast per-event lookup; remember
    # one collection's columns as the template for NaN padding.
    collections_grouped = {}
    sample_columns = None

    for col_id, col_data in collections.items():
        if hasattr(col_data, 'fields'):
            col_df = ak.to_dataframe(col_data).reset_index()
        else:
            col_df = col_data
        collections_grouped[col_id] = col_df.groupby('entry')
        if sample_columns is None:
            # NOTE(review): assumes all target collections share a schema.
            sample_columns = col_df.columns

    # One output row per relation element (or one NaN row when a parent row
    # has no elements), matching the explode() below.
    for row in br.itertuples(index=False):
        ind = row[loc1]
        col = row[loc2]
        if len(ind) == 0:
            l_collection_names.append(None)
            if sample_columns is not None:
                # -2 drops the entry/subentry bookkeeping columns.
                l_relations.append([np.nan] * (len(sample_columns) - 2))
            else:
                l_relations.append([np.nan])
        else:
            for ii, cc in zip(ind, col):
                if cc in COL_TABLE:
                    l_collection_names.append(COL_TABLE[cc])
                    if cc in collections_grouped and row.entry in collections_grouped[cc].groups:
                        event_data = collections_grouped[cc].get_group(row.entry)
                        if ii < len(event_data):
                            # Pick the ii-th target row of this event,
                            # minus the bookkeeping columns.
                            row_data = event_data.iloc[ii]
                            filtered_data = [row_data[col] for col in row_data.index if col not in ['entry', 'subentry']]
                            l_relations.append(filtered_data)
                        else:
                            l_relations.append([np.nan] * (len(sample_columns) - 2))
                    else:
                        l_relations.append([np.nan] * (len(sample_columns) - 2))
                else:
                    l_collection_names.append(f"Unknown_{cc}")
                    l_relations.append([np.nan] * (len(sample_columns) - 2))

    # One parent row per relation element, aligned with l_relations above.
    br = br.explode([relation_name + "_index", relation_name + "_colID"]).reset_index(drop=True)

    if sample_columns is not None:
        column_names = [col for col in sample_columns if col not in ['entry', 'subentry']]
    else:
        column_names = ['value']

    df_add = pd.DataFrame(l_relations, columns=column_names)

    # Apply the user's column selection (glob patterns).
    if relation_variables != ["*"]:
        available_columns = select_string(column_names, relation_variables)
        df_add = df_add[available_columns]

    df_add[relation_name + "_colName"] = l_collection_names

    # Carry over the exploded parent's entry/subentry for joining.
    df_add.insert(0, 'entry', br['entry'])
    df_add.insert(1, 'subentry', br['subentry'])

    print(f"Completed processing. Result shape: {df_add.shape}")
    return br, df_add
0705
def get_traj_relations(tree,bname="CentralCKFTrajectories",l_var=["measurementChi2", "outlierChi2", "trackParameters", "measurements_deprecated", "outliers_deprecated"]):
    """Attach several vector members / relations of a trajectory branch.

    For each name in *l_var*, resolves it with get_relation() and stores the
    result as new columns on the branch DataFrame: vector members land in a
    column named after the member, relations in <name>_index/<name>_colID.

    NOTE(review): l_var is a mutable default (not mutated here).  Also, the
    'entry' column of a vector-member result hits the final else branch and
    prints a spurious "invalid column" warning before 'values' is reached --
    looks harmless but worth confirming.
    """
    br = get_branch_df(tree,bname)
    print("get_traj_relations: accessing the following vector members:")
    for vv in l_var:
        print(vv)
        a = get_relation(tree,bname,vv)
        if not isinstance(a, pd.DataFrame):
            # get_relation returns 0 on failure; skip this member.
            print(f"WARNING(get_traj_relations): {bname}.{vv} returned no data")
            continue
        for cc in a.columns:
            if cc=='values':
                # Vector member: single column, done with this member.
                br[vv]=a[cc]
                break
            elif cc=='index':
                br[vv+'_index']=a[cc]
            elif cc=='collectionID':
                br[vv+'_colID']=a[cc]
            else:
                print('WARNING: invalid column ',cc,' in ',bname,'_',vv)
    return br
0727
0728
def get_traj_hits_particle(tree, traj_name = 'CentralCKFTrajectories', measurement_name='measurements_deprecated'):
    '''
    For each trajectory, find the corresponding reconstructed hits
    (measurements_deprecated or outliers_deprecated) and the index of the
    MC particle that produced each hit.

    Returns a DataFrame of trajectory rows joined with their hit rows,
    including a 'particle_index' column and derived position.r/position.phi.
    Requires COL_TABLE to be populated (get_col_table).
    NOTE(review): the row-wise apply() lookups below are O(rows x hits) --
    fine for small samples, slow for large ones.
    '''
    # Measurements with their hit references, one row per hit.
    _,df_hits = get_branch_relation(tree,branch_name="CentralTrackerMeasurements",relation_name="hits")

    # Bidirectional map between sim and rec hit collection names
    # (barrel + disks + B0).
    name_sim_tracker = name_sim_barrel + name_sim_disk + ['B0TrackerHits']
    name_rec_tracker = name_rec_barrel + name_rec_disk + ["B0TrackerRecHits"]
    def rec_sim_trackerhit_mapping(rec_tracker, sim_tracker):
        # Merge rec->sim and sim->rec into one lookup table.
        return {**{rec: sim for rec, sim in zip(rec_tracker, sim_tracker)},
                **{sim: rec for rec, sim in zip(rec_tracker, sim_tracker)}}
    HitName_dict = rec_sim_trackerhit_mapping(name_rec_tracker, name_sim_tracker)

    # Load each referenced (sim) hit collection and attach its MC-particle
    # relation index.
    col_name = set(df_hits.hits_colName.values)
    tracker_hits = {}
    for cc in col_name:
        try:
            # Map the rec collection name to its sim counterpart.
            cc = HitName_dict[cc]
            tracker_hits[cc] = get_branch_df(tree,cc)
        except:
            print(f"WARNING: {cc} is not a valid hit collection")
            continue
        try:
            relation = get_branch_df(tree,f"_{cc}_particle")
        except:
            print(f"WARNING: _{cc}_particle is not a branch")
            continue
        # Rows of the _<cc>_particle branch align with the hit rows.
        tracker_hits[cc]["particle_index"] = relation["index"]

    # Match each measurement hit to its MC particle via (entry, cellID).
    def get_particle_index(col,evtID,cellID):
        hits = tracker_hits[HitName_dict[col]]
        matched = hits[(hits.entry==evtID)&(hits.cellID==cellID)]
        if matched.empty:
            return np.nan
        return matched.particle_index.values[0]
    df_hits['particle_index'] = df_hits.apply(
        lambda row: get_particle_index(row['hits_colName'], row['entry'], row['cellID']),
        axis=1,
    )

    # Trajectories exploded to one row per measurement index
    # (relation_variables=[] -> indices only).
    br0,_= get_branch_relation(tree,branch_name=traj_name,relation_name=measurement_name,relation_variables=[])
    def get_hits(evtID, hitID):
        # Look up the measurement row by (entry, subentry); NaN row if absent.
        matched = df_hits[(df_hits.entry==evtID)&(df_hits.subentry==hitID)]
        if matched.empty:
            return pd.Series({col: np.nan for col in df_hits.columns})
        return matched.iloc[0]
    temp_hits = br0.apply(lambda row: get_hits(row['entry'], row[measurement_name+'_index']), axis=1)
    traj_hits = pd.concat([br0,temp_hits.drop(['entry','subentry'], axis=1, inplace=False)],axis=1)
    # Derived cylindrical coordinates of the hit position.
    traj_hits['position.r'] = np.sqrt(traj_hits['position.x']**2+traj_hits['position.y']**2)
    traj_hits['position.phi'] = np.arctan2(traj_hits['position.y'], traj_hits['position.x'])

    return traj_hits
0789
0790 from lmfit.models import GaussianModel
0791
def gaussian(x, amplitude, mean, std):
    """Gaussian with total area *amplitude* (lmfit convention)."""
    z = (x - mean) / std
    norm = std * np.sqrt(2 * np.pi)
    return amplitude * np.exp(-0.5 * z ** 2) / norm
0794
def hist_gaus(
    data, ax,
    bins=100, klog=False, header=None,
    clip=3.0, max_iters=5, min_points=50,
    verbose=False,
):
    """Histogram *data* within a robustly-clipped range and fit a Gaussian.

    The range is found by iterative sigma-clipping around the median using
    the MAD as a robust scale estimate, then a Gaussian is fitted to the
    binned counts with lmfit.

    Parameters
    ----------
    data : array-like; non-finite values are discarded.
    ax : matplotlib axes or None (None skips all plotting).
    bins : number of histogram bins.
    klog : log-scale the y axis when True.
    header : optional plot title.
    clip : half-width of the fit window in robust sigmas.
    max_iters : maximum sigma-clipping iterations.
    min_points : minimum finite (or surviving) points required.
    verbose : print a reason when returning NaNs.

    Returns
    -------
    (mean, sigma, sigma_err); (nan, nan, nan) on any failure.
    """
    data = np.asarray(data, dtype=float)
    data = data[np.isfinite(data)]
    if data.size < min_points:
        if verbose:
            print('hist_gaus: not enough finite points')
        return np.nan, np.nan, np.nan

    # Robust initial location/scale: median and MAD (1.4826 is the standard
    # MAD -> sigma consistency factor for a normal distribution).
    center = np.median(data)
    mad = np.median(np.abs(data - center))
    scale = 1.4826 * mad if mad > 0 else np.std(data)
    if not np.isfinite(scale) or scale <= 0:
        if verbose:
            print('hist_gaus: invalid initial scale')
        return np.nan, np.nan, np.nan

    # Iteratively re-estimate center/scale from the clipped sample until
    # convergence or max_iters.
    for _ in range(max_iters):
        mask = np.abs(data - center) <= clip * scale
        clipped = data[mask]
        if clipped.size < min_points:
            break
        new_center = np.median(clipped)
        mad = np.median(np.abs(clipped - new_center))
        new_scale = 1.4826 * mad if mad > 0 else np.std(clipped)
        if not np.isfinite(new_scale) or new_scale <= 0:
            break
        if np.isclose(new_center, center) and np.isclose(new_scale, scale):
            center, scale = new_center, new_scale
            break
        center, scale = new_center, new_scale

    # Fit window of +/- clip sigmas around the robust center.
    lo, hi = center - clip * scale, center + clip * scale
    if lo == hi:
        if verbose:
            print('hist_gaus: degenerate range')
        return np.nan, np.nan, np.nan

    counts, edges = np.histogram(data, bins=bins, range=(lo, hi))
    mid = 0.5 * (edges[:-1] + edges[1:])

    if ax is not None:
        ax.hist(data, bins=bins, range=(lo, hi), histtype='stepfilled', alpha=0.3)

    # Gaussian fit to the binned counts, seeded with the robust estimates.
    model = GaussianModel()
    params = model.make_params(center=center, sigma=scale, amplitude=np.max(counts))
    try:
        result = model.fit(counts, params, x=mid)
    except Exception as exc:
        if verbose:
            print('hist_gaus: fit failed', exc)
        return np.nan, np.nan, np.nan

    # Reject fits without a finite, positive sigma and an error estimate.
    sigma = float(result.params['sigma'])
    sigma_err = result.params['sigma'].stderr
    if sigma_err is None or not np.isfinite(sigma) or sigma <= 0:
        if verbose:
            print('hist_gaus: invalid fit result')
        return np.nan, np.nan, np.nan

    mean = float(result.params['center'])
    sigma_err = float(sigma_err)
    ampl = float(result.params['amplitude'])
    # Peak height of the fitted (area-normalized) Gaussian.
    peak = ampl / (sigma * np.sqrt(2 * np.pi))

    if ax is not None:
        ax.plot(mid, gaussian(mid, ampl, mean, sigma), 'r-')
        if header:
            ax.set_title(header)
        ax.set_xlabel('value')
        ax.set_ylabel('entries')
        if klog:
            ax.set_yscale('log')
        else:
            # Leave ~30% headroom above the taller of data and fit peak.
            ymax = max(np.max(counts), peak)
            ax.set_ylim(0, ymax / 0.7)

    return mean, sigma, sigma_err
0877
0878
# Explicit public API: the names exported by a star-import of this module.
__all__ = [
    "configure_analysis_environment",
    "ak_flat",
    "ak_df",
    "ak_hist",
    "ak_filter",
    "get_pdg_info",
    "get_geoID",
    "theta2eta",
    "select_string",
    "read_ur",
    "get_col_table",
    "get_branch_ak",
    "get_branch_df",
    "get_part",
    "get_params",
    "get_branches",
    "get_collections",
    "get_relation",
    "get_branch_relation",
    "get_traj_relations",
    "get_traj_hits_particle",
    "deg2rad",
    "status_to_source",
    "geo_mask_values",
    "hist_gaus"
]