Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-03-29 08:37:26

0001 '''
0002     Utility functions for epic tracking analysis. 
0003     See also epic_analysis_podio.py
0004     Shujie Li, Aug 2025
0005 '''
0006 
0007 # %load_ext memory_profiler
0008 import numpy as np
0009 import pandas as pd
0010 import seaborn as sns
0011 
0012 import awkward as ak
0013 import uproot as ur
0014 # import podio
0015 # from podio import root_io
0016 
0017 import time
0018 from fnmatch import fnmatch
0019 import types
0020 from particle import Particle
0021 
0022 from matplotlib.backends.backend_pdf import PdfPages
0023 from matplotlib.colors import LogNorm
0024 from matplotlib import pyplot as plt
0025 from matplotlib.gridspec import GridSpec
0026 import matplotlib.ticker as ticker
0027 import matplotlib.cm as cm
0028 import matplotlib as mpl
0029 
def configure_analysis_environment(
    apply_pandas=True,
    apply_matplotlib=True,
    apply_sns=True
):
    """Apply display/plotting defaults for interactive analysis.

    Each flag enables one group of settings:

    apply_pandas
        pandas display limits: max_rows=200, min_rows=20, max_columns=100.
    apply_matplotlib
        8x6 inch figures, inward ticks, axis labels at the right/top ends and
        a 10/12/20 pt (small/medium/big) font scheme.
    apply_sns
        seaborn "whitegrid" style, "notebook" context, "bright" palette and
        6x4 inch figures.
    """
    if apply_pandas:
        for option, value in (
            ("display.max_rows", 200),
            ("display.min_rows", 20),
            ("display.max_columns", 100),
        ):
            pd.set_option(option, value)

    if apply_matplotlib:
        small, medium, big = 10, 12, 20
        plt.rcParams.update({
            'figure.figsize': [8.0, 6.0],
            'ytick.direction': 'in',
            'xtick.direction': 'in',
            'xaxis.labellocation': 'right',
            'yaxis.labellocation': 'top',
            'font.size': small,
            'axes.titlesize': medium,
            'axes.labelsize': medium,
            'xtick.labelsize': medium,
            'ytick.labelsize': medium,
            'legend.fontsize': small,
            'figure.titlesize': big,
        })

    if apply_sns:
        # NOTE(review): set_theme runs after the matplotlib block, so seaborn's
        # style/context (and the 6x4 figsize below) likely override some of
        # the settings above (e.g. tick direction, font sizes) -- confirm this
        # ordering is intended.
        sns.set_theme(
            style='whitegrid',
            context='notebook',
            palette='bright',
            font_scale=1.0,
            rc={'figure.figsize': (6, 4)},
        )
0064 
# Constants
deg2rad = np.pi / 180.0  # multiply an angle in degrees by this to get radians

## event source
# generatorStatus code -> name of the source that produced the particle.
# Every source owns exactly two codes, <base>+1 and <base>+2 (DIS uses base 0).
status_to_source = {
    base + offset: source
    for base, source in (
        (0, "DIS"),
        (2000, "SR"),
        (3000, "Bremstrahlung"),
        (4000, "Coulomb"),
        (5000, "Touschek"),
        (6000, "Proton beam gas"),
    )
    for offset in (1, 2)
}

## track quality cuts
TRACK_HIT_COUNT_MIN_MIN   = 3    ## absolute min to form a track with CKF
TRACK_HIT_COUNT_MIN       = 4
TRACK_MOM_MIN             = 0.3
TRACK_PT_MIN              = 0.2
TRACK_HIT_FRACTION_MIN    = 0.5
TRACK_HIT_COUNT_GHOST_MAX = 2
VERTEX_CUT_R_MAX          = 2    # NOTE(review): unit not stated; presumably mm like VERTEX_CUT_Z_MAX
VERTEX_CUT_Z_MAX          = 200  # mm
0087 
# Detector geometry definitions.
# The four barrel lists are index-aligned: entry i of barrel_range,
# barrel_name, name_sim_barrel and name_rec_barrel describe the same layer.
# Several layers can share one hit collection (e.g. L0-L2 all map to
# VertexBarrelHits / SiBarrelVertexRecHits).
# NOTE(review): barrel_range entries look like (min, max) radial windows in
# mm -- confirm against the geometry before relying on the unit.
barrel_range = [(30,42),(46,60),(115,130),(250,290),(410,450),(540,600),(620,655),(700,760)]
barrel_name = ["L0","L1","L2","L3","L4","inner MPGD","TOF","outer MPGD"]
name_sim_barrel = ["VertexBarrelHits","VertexBarrelHits","VertexBarrelHits","SiBarrelHits","SiBarrelHits","MPGDBarrelHits","TOFBarrelHits","OuterMPGDBarrelHits"]
name_rec_barrel = ["SiBarrelVertexRecHits","SiBarrelVertexRecHits","SiBarrelVertexRecHits","SiBarrelTrackerRecHits","SiBarrelTrackerRecHits","MPGDBarrelRecHits","TOFBarrelRecHits","OuterMPGDBarrelRecHits"]

# Endcap disks (16 entries), ordered from the most negative to the most
# positive position. The four disk lists are index-aligned like the barrel
# lists. Disks at negative positions carry "E-" names, positive ones "H-".
# NOTE(review): disk_range entries look like (min, max) z windows in mm --
# confirm. disk_name spacing is inconsistent ("Disk2" vs "Disk 1"); the
# strings are left unchanged because callers may match on them.
disk_range = [(-1210.0, -1190.0), (-1110.0, -1090.0),(-1055.0, -1000.0), (-860.0, -840.0),
 (-660.0, -640.0), (-460.0, -440.0), (-260.0, -240.0), (240.0, 260.0),
 (440.0, 460.0), (690.0, 710.0), (990.0, 1010.0), (1340.0, 1360.0),
 (1480.0, 1500.0), (1600.0, 1620.0), (1840.0, 1860.0), (1865.0, 1885.0)]
disk_name = ["E-MPGD Disk2","E-MPGD Disk 1","E-Si Disk 4","E-Si Disk 3","E-Si Disk 2","E-Si Disk 1","E-Si Disk 0",
                "H-Si Disk 0","H-Si Disk 1","H-Si Disk 2","H-Si Disk 3","H-Si Disk 4","H-MPGD Disk 1","H-MPGD Disk 2", "H-TOF Disk1","H-TOF Disk2"]
# Reconstructed-hit collection per disk: 2 backward MPGD, 10 silicon
# (5 "E-" + 5 "H-"), 2 forward MPGD, 2 TOF.
name_rec_disk = ["BackwardMPGDEndcapRecHits","BackwardMPGDEndcapRecHits",
               "SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits","SiEndcapTrackerRecHits",
               "ForwardMPGDEndcapRecHits","ForwardMPGDEndcapRecHits",
               "TOFEndcapRecHits","TOFEndcapRecHits"]
# Simulated-hit collection per disk, same ordering as name_rec_disk.
name_sim_disk = ["BackwardMPGDEndcapHits","BackwardMPGDEndcapHits",
               "TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","TrackerEndcapHits","ForwardMPGDEndcapHits","ForwardMPGDEndcapHits",
               "TOFEndcapHits","TOFEndcapHits"]
0107 
# ACTS geometry ID masks: each entry selects one bit field of a 64-bit
# geometry identifier (extracted by get_geoID()).
geo_mask_dict = dict(
    approach=0x0000000FF0000000,
    boundary=0x00FF000000000000,
    extra=0x00000000000000FF,
    layer=0x0000FFF000000000,
    sensitive=0x000000000FFFFF00,
    volume=0xFF00000000000000,
)
# Read-only view of the masks for lookups
geo_mask_values = types.MappingProxyType(geo_mask_dict)

# Global variables for caching
COL_TABLE = {}    # collection ID -> collection name, filled by get_col_table()
CACHED_DATA = {}  # NOTE(review): not referenced in the visible part of this file
0122 
def ak_flat(ak_array):
    """Convert an (awkward) array into a NumPy array, e.g. for histogramming.

    Missing (None) entries at the outermost level are dropped. Input that
    NumPy can represent (flat, or nested with regular lengths) is returned
    with its shape unchanged, as before. Jagged input, such as the
    per-event lists returned by get_branch_ak(), cannot be converted as-is.
    Such input is flattened across all list levels into a 1-D array;
    previously it raised inside ``ak.to_numpy``.
    """
    # axis=0 only removes top-level None; it does not flatten nested lists.
    flat = ak.flatten(ak_array, axis=0)
    try:
        return ak.to_numpy(flat)
    except ValueError:
        # Variable-length sublists: concatenate every level into one array.
        return ak.to_numpy(ak.flatten(flat, axis=None))
0125 
def ak_df(ak_array):
    """Return ``ak.to_dataframe(ak_array)`` (nesting becomes a MultiIndex)."""
    frame = ak.to_dataframe(ak_array)
    return frame
0128 
def ak_hist(ak_array, **kwargs):
    """Histogram an awkward array with matplotlib.

    The array is converted with ak_flat() and passed to ``plt.hist`` together
    with all keyword arguments; ``plt.hist``'s return value is returned.
    """
    values = ak_flat(ak_array)
    return plt.hist(values, **kwargs)
0131 
def ak_filter(br, cond, field=None):
    """Return ``br[cond]``, or ``br[cond][field]`` when *field* is truthy.

    Works with anything supporting boolean-mask indexing (awkward arrays,
    pandas DataFrames, ...).
    """
    selected = br[cond]
    if not field:
        return selected
    return selected[field]
0135 
0136 
def ak_sns(ak_array, **kwargs):
    """Histogram helper for awkward arrays using seaborn.

    Parameters
    ----------
    ak_array : array, or a 2-element tuple/list ``(x, y)``
        A single array gives a 1-D histogram. A pair gives a 2-D ``x`` vs
        ``y`` histogram. Every array is converted with ak_flat().
    **kwargs
        Forwarded to ``sns.histplot``. For 1-D plots ``element`` defaults to
        ``"step"`` and ``fill`` to ``False``; both can now be overridden
        (previously passing them raised a duplicate-keyword TypeError). For
        2-D plots they are discarded.

    Returns
    -------
    The value returned by ``sns.histplot``.
    """
    if isinstance(ak_array, (tuple, list)) and len(ak_array) == 2:
        x_data = ak_flat(ak_array[0])
        y_data = ak_flat(ak_array[1])
        # element/fill only make sense for 1-D outlines
        kwargs.pop('element', None)
        kwargs.pop('fill', None)
        return sns.histplot(x=x_data, y=y_data, **kwargs)
    kwargs.setdefault("element", "step")
    kwargs.setdefault("fill", False)
    return sns.histplot(ak_flat(ak_array), **kwargs)
0146 
def get_pdg_info(PDG):
    """Look up the ``particle.Particle`` for PDG code *PDG*.

    If the lookup fails, code 9902210 is mapped to the proton (2212). This
    is presumably a generator-specific proton code -- confirm. Any other
    unknown code prints an error and returns ``Particle.empty()``.
    """
    try:
        return Particle.from_pdgid(PDG)
    except Exception:
        pass
    if PDG == 9902210:
        return Particle.from_pdgid(2212)
    print(f"ERROR (get_pdg_info): unknown PDG ID {PDG}")
    return Particle.empty()
0156 
def get_geoID(geoID, name="layer"):
    """Extract one bit field from an ACTS geometry identifier.

    Parameters
    ----------
    geoID : int (or an integer array supporting ``&`` and ``>>``)
        Encoded geometry identifier(s).
    name : str
        Key into geo_mask_values ("approach", "boundary", "extra", "layer",
        "sensitive", "volume"); an unknown key raises KeyError.

    Returns
    -------
    The bits selected by the mask, shifted down so the field starts at bit 0.
    """
    mask = geo_mask_values[name]
    # The field's lowest set bit tells how far the masked value must be shifted.
    shift = 0
    while not (mask >> shift) & 1:
        shift += 1
    return (geoID & mask) >> shift
0166 
0167 
def theta2eta(xx, inverse=0):
    """Convert polar angle theta (radians) to pseudorapidity eta, or back.

    eta = -ln(tan(theta / 2)),   theta = 2 * arctan(e**(-eta))

    Parameters
    ----------
    xx : scalar, list, tuple or array
        Values to convert. Lists and tuples (also subclasses) are turned into
        NumPy arrays; other inputs (NumPy/awkward arrays, scalars) are used
        directly.
    inverse : int or bool
        If equal to 1 (or True), treat ``xx`` as eta and return theta.
    """
    if isinstance(xx, (list, tuple)):
        xx = np.asarray(xx)
    if inverse == 1:
        return np.arctan(np.e ** (-xx)) * 2
    return -np.log(np.tan(xx / 2.))
0176 
def select_string(strings, patterns):
    """Return the items of *strings* that match any wildcard pattern.

    Patterns use shell-style wildcards (``fnmatch``). Both sides are
    lower-cased, so matching is case-insensitive. The order of *strings* is
    preserved.

    Raises ValueError if *patterns* is not a list.
    """
    if not isinstance(patterns, list):
        raise ValueError("The 'patterns' argument must be a list.")
    lowered = [p.lower() for p in patterns]

    def _matches(text):
        return any(fnmatch(text.lower(), p) for p in lowered)

    return list(filter(_matches, strings))
0184 
def read_ur(fname, tname, s3_dir="", entry_start=0, entry_stop=None, return_range=False):
    """Open a ROOT tree with uproot and record the entry range to use.

    fname: path to file (relative to *s3_dir* when that is given)
    tname: tree name
    s3_dir: if provided (non-empty), read from the corresponding
        simulation-campaign directory on the JLab XRootD server. A missing
        trailing "/" is added, so "EPIC/RECO/x" and "EPIC/RECO/x/" are
        equivalent (previously the former produced a broken path).
    entry_start, entry_stop: half-open range [entry_start, entry_stop).
        entry_stop None or -1 means all entries; larger values are clamped
        to the tree size.
    return_range: if true, return (tree, entry_start, entry_stop).

    The resolved range is stored on the tree as ``_entry_start`` and
    ``_entry_stop``. get_branch_ak() uses it as its default range. The file
    is opened without a context manager so the returned tree stays usable.

    Raises ValueError if entry_start is negative or the range is empty.
    """
    if s3_dir:
        server = 'root://dtn-eic.jlab.org//volatile/eic/'
        # plain concatenation below needs a separator after the directory
        if not s3_dir.endswith('/'):
            s3_dir += '/'
        fname = server + s3_dir + fname
    tree = ur.open(fname)[tname]
    n_total = tree.num_entries
    if entry_stop is None or entry_stop == -1:
        entry_stop = n_total
    entry_stop = min(entry_stop, n_total)
    if entry_start < 0 or entry_start >= entry_stop:
        raise ValueError(f"read_ur: invalid entry range {entry_start}:{entry_stop}")
    print(
        f"read_ur: read {fname}:{tname}. "
        f"{n_total} events total; using [{entry_start}, {entry_stop})"
    )
    tree._entry_start = entry_start
    tree._entry_stop = entry_stop
    if return_range:
        return tree, entry_start, entry_stop
    return tree
0208 
def get_col_table(fname, s3_dir="", verb=0):
    """Build the podio collection-ID -> collection-name table of a file.

    Reads the ``podio_metadata`` tree (see read_ur() for *s3_dir*). Two
    layouts are supported: the legacy ``events___idTable`` one
    (eic-shell < 25.09) and the newer ``events___CollectionTypeInfo`` one.
    The result replaces the module-level COL_TABLE and is also returned.
    verb: print each "id name" pair while filling the table.
    """
    global COL_TABLE
    meta = read_ur(fname, "podio_metadata", s3_dir)
    legacy_layout = "events___idTable" in meta.keys()  ## < eic-shell 25.09
    if legacy_layout:
        names = np.array(meta["m_names"].array()[0])
        ids = np.array(meta["m_collectionIDs"].array()[0])
    else:
        ids = get_branch_df(meta, "events___CollectionTypeInfo/events___CollectionTypeInfo.collectionID")["values"].tolist()
        names = get_branch_df(meta, "events___CollectionTypeInfo/events___CollectionTypeInfo.name")["values"].tolist()

    COL_TABLE = {}
    for cid, cname in zip(ids, names):
        if verb:
            print(cid, cname)
        COL_TABLE[cid] = cname
    return COL_TABLE
0226 
0227 
0228 # ============= BRANCH READING with ak or df =============
0229 
def get_branch_ak(tree, bname="", entry_start=0, entry_stop=-1,
                           fields_subset=None, chunk_size=1000, verb=0):
    """Read branch *bname* of an uproot tree into an awkward array.

    Parameters
    ----------
    tree : uproot tree (e.g. from read_ur())
    bname : branch name; KeyError if ``tree.keys()`` does not contain it.
    entry_start, entry_stop : half-open entry range. ``entry_start == 0``
        and ``entry_stop == -1`` fall back to the range that read_ur() stored
        on the tree (``_entry_start``/``_entry_stop``). Without a stored range
        they mean the whole tree.
    fields_subset : optional list of (renamed) field names to keep; names
        that are not present are ignored.
    chunk_size : ranges longer than this are read chunk by chunk and
        concatenated.
    verb : print progress, size and timing information.

    Returns
    -------
    ak.Array. A branch without record fields is returned as read. Otherwise
    each field's leading "<bname>." is removed, and fields whose name
    contains "[" (fixed-size arrays such as ``covariance[21]``) are dropped.
    """
    if bname not in tree.keys():
        raise KeyError(f"get_branch_ak: can't find branch {bname}")
    if verb:
        print(f"Reading branch: {bname}")
    start_time = time.time()

    # Determine actual entry range
    if entry_start == 0 and hasattr(tree, "_entry_start"):
        entry_start = tree._entry_start
    if entry_stop == -1:
        entry_stop = tree._entry_stop if hasattr(tree, "_entry_stop") else tree.num_entries

    total_entries = entry_stop - entry_start
    if verb:
        print(f"Reading {total_entries} entries")

    if total_entries > chunk_size:
        if verb:
            print(f"Using chunked reading with chunk_size={chunk_size}")
        chunks = []
        for chunk_start in range(entry_start, entry_stop, chunk_size):
            chunk_end = min(chunk_start + chunk_size, entry_stop)
            if verb:
                print(f"  Reading chunk: {chunk_start} to {chunk_end}")
            chunks.append(tree[bname].array(
                library="ak",
                entry_start=chunk_start,
                entry_stop=chunk_end
            ))
        ak_data = ak.concatenate(chunks)
    else:
        ak_data = tree[bname].array(
            library="ak",
            entry_start=entry_start,
            entry_stop=entry_stop
        )

    read_time = time.time()
    if verb:
        try:
            size_bytes = ak.nbytes(ak_data)
        except Exception:
            size_bytes = None
        if size_bytes is not None:
            print(f"Awkward read: {read_time - start_time:.2f}s ({size_bytes/1e6:.1f} MB)")
        else:
            print(f"Awkward read: {read_time - start_time:.2f}s")

    # Rename fields by dropping the branch prefix
    if hasattr(ak_data, 'fields'):
        if not ak_data.fields:
            ## return a single array
            return ak_data
        prefix = f'{bname}.'
        renamed_fields = {}
        for field in ak_data.fields:
            if "[" in field:  ## drop nested array such as CentralCKFTrackParameters.covariance.covariance[21]
                continue
            # Remove only the leading "<bname>."; str.replace would also strip
            # later occurrences inside the field name.
            new_name = field[len(prefix):] if field.startswith(prefix) else field
            renamed_fields[new_name] = ak_data[field]

        ak_data = ak.zip(renamed_fields)
        if verb:
            print(f"Renamed {len(renamed_fields)} fields")

    # Extract subset of fields if specified
    if fields_subset and hasattr(ak_data, 'fields'):
        subset_data = {field: ak_data[field] for field in fields_subset
                       if field in ak_data.fields}
        ak_data = ak.zip(subset_data)
        if verb:
            print(f"Extracted subset: {fields_subset}")

    total_time = time.time()
    if verb:
        print(f"Total time: {total_time - start_time:.2f}s")

    return ak_data
0322 
0323 
def get_branch_df(tree, bname="", entry_start=0, entry_stop=-1,
                           fields_subset=None, chunk_size=1000, verb=0):
    """Read a branch with get_branch_ak() and convert it to a pandas DataFrame.

    All arguments are forwarded to get_branch_ak(). This includes *verb*;
    previously it was dropped, so read progress was never printed. The
    awkward nesting becomes the DataFrame index (``entry``/``subentry``
    levels elsewhere in this module). ``reset_index()`` turns that index into
    regular columns.

    Returns the DataFrame. If ``ak.to_dataframe`` fails, the error is printed
    and the awkward array is returned instead; callers check with
    ``isinstance(result, pd.DataFrame)``.
    """
    ak_data = get_branch_ak(tree, bname, entry_start, entry_stop,
                            fields_subset, chunk_size, verb)

    if verb:
        print("Converting to DataFrame...")
    start_time = time.time()

    try:
        df = ak.to_dataframe(ak_data)
        df = df.reset_index()  # Flatten the multi-index

        convert_time = time.time() - start_time
        if verb:
            print(f"DataFrame conversion: {convert_time:.2f}s")
            print(f"DataFrame shape: {df.shape}")

        return df

    except Exception as e:
        # Deliberate best-effort fallback: hand back the awkward array
        print(f"DataFrame conversion failed: {e}")
        print("Returning awkward array instead")
        return ak_data
0349 
def get_part(tree, entry_start=0, entry_stop=-1, chunk_size=1000, kprimary=1):
    """Read MCParticles and add derived kinematic quantities.

    Parameters
    ----------
    tree : uproot tree containing the ``MCParticles`` branch.
    entry_start, entry_stop, chunk_size : forwarded to get_branch_ak().
    kprimary : if true, keep only particles whose generatorStatus % 1000 is
        1 or 2 (the x001/x002 codes of status_to_source).

    Returns
    -------
    ak.Array of records with the original fields PDG, generatorStatus,
    charge, time, mass, vertex.*, endpoint.* and momentum.*, plus:
      mom, theta, phi, eta, pt -- from the momentum vector (theta, eta and pt
          are NaN for zero momentum);
      vertex_r, vertex_dist -- transverse and 3-D distance of the vertex from
          the origin;
      endpoint_r -- transverse distance of the endpoint;
      pdg_name -- ``get_pdg_info(PDG).name``.
    """
    ak_data = get_branch_ak(tree, "MCParticles", entry_start, entry_stop,
                            chunk_size=chunk_size)
    if kprimary:
        print("Select all primary particles with generatorStatus==x001 or x002")
        status = ak_data.generatorStatus % 1000
        ak_data = ak_data[(status == 1) | (status == 2)]

    # Momentum quantities
    px = ak_data["momentum.x"]
    py = ak_data["momentum.y"]
    pz = ak_data["momentum.z"]
    p_mag = np.sqrt(px**2 + py**2 + pz**2)
    safe_p_mag = ak.where(p_mag != 0, p_mag, np.nan)  # avoid division by zero
    theta = np.arccos(pz / safe_p_mag)
    phi = np.arctan2(py, px)
    eta = -np.log(np.tan(theta / 2.0))
    pt = p_mag * np.sin(theta)

    # Vertex quantities
    vx = ak_data["vertex.x"]
    vy = ak_data["vertex.y"]
    vz = ak_data["vertex.z"]
    vertex_r = np.sqrt(vx**2 + vy**2)
    vertex_dist = np.sqrt(vertex_r**2 + vz**2)

    # Endpoint quantities
    ex = ak_data["endpoint.x"]
    ey = ak_data["endpoint.y"]
    endpoint_r = np.sqrt(ex**2 + ey**2)

    # PDG names: Particle lookups are slow, so look up each distinct code once
    # and broadcast the names back with the inverse index from np.unique
    # (replaces a per-particle Python dict lookup loop).
    pdg_codes = ak.to_numpy(ak.flatten(ak_data.PDG))
    unique_pdgs, inverse = np.unique(pdg_codes, return_inverse=True)
    unique_names = np.array([get_pdg_info(pdg).name for pdg in unique_pdgs])
    pdg_names = ak.unflatten(unique_names[inverse], ak.num(ak_data.PDG))

    my_field = ['PDG', 'generatorStatus', 'charge', 'time', 'mass',
                'vertex.x', 'vertex.y', 'vertex.z', 'endpoint.x', 'endpoint.y',
                'endpoint.z', 'momentum.x', 'momentum.y', 'momentum.z']
    enhanced_data = ak.zip({
        # Original fields
        **{field: ak_data[field] for field in my_field},
        # Derived quantities
        'mom': p_mag,
        'theta': theta,
        'phi': phi,
        'eta': eta,
        'pt': pt,
        'vertex_r': vertex_r,
        'vertex_dist': vertex_dist,
        'endpoint_r': endpoint_r,
        'pdg_name': pdg_names
    })

    return enhanced_data
0419 
def get_params(tree, bname="CentralCKFTrackParameters", entry_start=0, entry_stop=-1, chunk_size=1000):
    """Read a track-parameter branch and append eta, mom and pt fields.

    eta comes from theta2eta(theta), mom is |1 / qOverP|, and pt is
    |mom * sin(theta)|. qOverP == 0 presumably yields an infinite momentum
    -- not guarded here. The remaining arguments are forwarded to
    get_branch_ak().
    """
    tracks = get_branch_ak(tree, bname, entry_start, entry_stop,
                           chunk_size=chunk_size)
    theta = tracks.theta
    eta = theta2eta(theta)
    mom = abs(1.0 / tracks.qOverP)
    pt = abs(mom * np.sin(theta))
    for field_name, values in (("eta", eta), ("mom", mom), ("pt", pt)):
        tracks = ak.with_field(tracks, values, field_name)
    return tracks
0433 
def get_branches(trees, bname):
    """Read branch *bname* from several trees into one concatenated DataFrame.

    Parameters
    ----------
    trees : iterable of uproot trees (e.g. one per file).
    bname : branch name. "MCParticles" is read through get_part() (which
        prints a message and keeps primary particles only) and converted to a
        DataFrame. Any other branch is read with get_branch_df().

    Returns
    -------
    pandas.DataFrame with an added ``event_id`` column. Within a tree,
    consecutive rows with the same ``entry`` share an id. Ids keep increasing
    across trees, so rows from different trees never share an id. An empty
    DataFrame is returned when *trees* is empty.
    """
    frames = []
    next_event_id = 0
    for tree in trees:
        if bname == "MCParticles":
            # get_part() returns an awkward array; pd.concat needs a DataFrame
            dff = ak.to_dataframe(get_part(tree)).reset_index()
        else:
            dff = get_branch_df(tree, bname)
        # Number events per tree and offset by the ids already used. This
        # keeps two trees whose boundary rows share an entry value apart
        # (the old single scan over the concatenated table merged them). It
        # also avoids rescanning the growing table once per tree.
        new_event = dff["entry"].ne(dff["entry"].shift())
        event_id = new_event.cumsum() - 1 + next_event_id
        dff["event_id"] = event_id
        if len(event_id):
            next_event_id = int(event_id.iloc[-1]) + 1
        frames.append(dff)

    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
0457 
0458 
def get_collections(tree, bname='', kflatten=1):
    """Load every collection referenced by relation branch *bname*.

    Parameters
    ----------
    tree : uproot tree.
    bname : branch holding a ``collectionID`` field (one-to-one/many relation).
    kflatten : truthy -> read branches as awkward arrays (get_branch_ak);
        falsy -> read them as DataFrames (get_branch_df).

    Returns
    -------
    dict collectionID -> collection data. IDs missing from COL_TABLE are
    skipped (with a warning on the awkward path). Returns 0, after printing
    an error, if *bname* has no ``collectionID``.

    Raises RuntimeError if get_col_table() has not filled COL_TABLE.
    """
    if not COL_TABLE:
        raise RuntimeError("COL_TABLE not populated. Call get_col_table() first.")

    def load(name, as_awkward):
        if as_awkward:
            return get_branch_ak(tree, name)
        return get_branch_df(tree, name, chunk_size=1000)

    br = load(bname, kflatten)
    is_awkward = hasattr(br, 'fields')

    if is_awkward and 'collectionID' in br.fields:
        ids = np.unique(ak.to_numpy(ak.flatten(br.collectionID)))
        collections = {}
        print(f"Loading {len(ids)} collections...")
        for cid in ids:
            if cid not in COL_TABLE:
                print(f"Warning: Collection ID {cid} not found in COL_TABLE")
                continue
            collections[cid] = load(COL_TABLE[cid], kflatten)
        return collections

    # DataFrame path (or an awkward array without a top-level collectionID)
    br_df = ak.to_dataframe(br).reset_index() if is_awkward else br
    if "collectionID" not in br_df.columns:
        print("ERROR(get_collections):", bname, "is not a relation.")
        return 0
    return {
        cid: get_branch_df(tree, COL_TABLE[cid], chunk_size=1000)
        for cid in br_df.collectionID.unique()
        if cid in COL_TABLE
    }
0499 
def get_relation(tree, b_name, v_name):
    """Resolve a podio vector member or relation of branch *b_name*.

    For each row of *b_name*, the half-open slice
    ``[<v_name>_begin, <v_name>_end)`` selects items of the same event
    (matched on ``entry``) from the companion branch ``_<b_name>_<v_name>``.

    Returns
    -------
    pandas.DataFrame with one row per row of *b_name*:
      * vector member (companion columns entry, subentry, values):
        columns ``entry`` and ``values`` (numpy array per row);
      * relation (companion has ``index`` and ``collectionID``):
        columns ``index`` and ``collectionID`` (numpy arrays per row).
    Returns 0 after printing an ERROR message when the begin/end columns or
    the companion branch are missing, or the companion has another layout.
    Empty slices, events without companion items and (relations only)
    out-of-range slices give empty arrays.
    """
    print(f"Processing relation: {b_name}.{v_name}")

    # Read main branch as DataFrame for index operations
    br = get_branch_df(tree, b_name, chunk_size=1000)

    begin_col = v_name + "_begin"
    end_col = v_name + "_end"
    if begin_col not in br.columns or end_col not in br.columns:
        print(f"ERROR(get_relation): {begin_col} or {end_col} not found in {b_name}")
        return 0

    loc1 = br.columns.get_loc(begin_col)
    loc2 = br.columns.get_loc(end_col)
    in_name = "_" + b_name + "_" + v_name

    # Read the relation/vector data
    try:
        app = get_branch_df(tree, in_name, chunk_size=1000)
    except Exception:
        print(f"ERROR(get_relation): Cannot read branch {in_name}")
        return 0

    if not isinstance(app, pd.DataFrame):
        print(f"ERROR(get_relation): {in_name} is not a valid DataFrame")
        return 0

    # Vector of float values
    if len(app.columns) == 3 and app.columns[2] == "values":
        print("Processing vector values...")
        # Split the values by event once instead of filtering the whole
        # table for every row (was O(rows x values)).
        values_by_entry = {
            entry: np.array(grp["values"]) for entry, grp in app.groupby('entry')
        }
        l_ind = []
        l_val = []
        for row in br.itertuples(index=False):
            l_ind.append(row[0])
            i1, i2 = row[loc1], row[loc2]
            event_values = values_by_entry.get(row.entry)
            if i1 == i2 or event_values is None:  # empty slice or no data
                l_val.append(np.array([]))
            else:
                l_val.append(event_values[i1:i2])
        return pd.DataFrame({"entry": l_ind, "values": l_val})

    # Relations with collectionID and index
    elif len(app.columns) > 1 and 'index' in app.columns and 'collectionID' in app.columns:
        print("Processing relations...")
        # (index array, collectionID array) per event, built once
        items_by_entry = {
            entry: (np.array(grp["index"]), np.array(grp["collectionID"]))
            for entry, grp in app.groupby('entry')
        }
        l_index = []
        l_id = []
        for row in br.itertuples(index=False):
            i1, i2 = row[loc1], row[loc2]
            event = items_by_entry.get(row.entry)
            # empty slice, event without items, or slice past the end -> empty
            if i1 == i2 or event is None or len(event[0]) < i2:
                l_index.append(np.array([]))
                l_id.append(np.array([]))
            else:
                l_index.append(event[0][i1:i2])
                l_id.append(event[1][i1:i2])
        return pd.DataFrame({"index": l_index, "collectionID": l_id})

    else:
        print("ERROR(get_relation): Invalid vector or relation member structure")
        print(f"Columns found: {app.columns.tolist()}")
        return 0
0580                         l_index.append(np.array([]))
0581                         l_id.append(np.array([]))
0582                 else:
0583                     l_index.append(np.array([]))
0584                     l_id.append(np.array([]))
0585         
0586         return pd.DataFrame({"index": l_index, "collectionID": l_id})
0587     
0588     else:
0589         print("ERROR(get_relation): Invalid vector or relation member structure")
0590         print(f"Columns found: {app.columns.tolist()}")
0591         return 0
0592 
0593 def get_branch_relation(tree, branch_name="CentralCKFTrajectories", relation_name="measurements_deprecated", relation_variables=["*"]):
0594     """Get relation or vector members appended to the original branch with optimization"""
0595     print(f"Processing branch relation: {branch_name}.{relation_name}")
0596     if not COL_TABLE:
0597         raise RuntimeError("COL_TABLE not populated. Call get_col_table() first.")    
0598     # Read main branch
0599     br = get_branch_df(tree, branch_name, chunk_size=1000)
0600     df = get_relation(tree, branch_name, relation_name)
0601     
0602     if not isinstance(df, pd.DataFrame):
0603         print("ERROR (get_branch_relation): please provide a valid relation name.")
0604         return br, None
0605     
0606     # Handle vector values case
0607     if len(df.columns) == 1 and df.columns[0] == "values":
0608         return br, df["values"]
0609     
0610     # Handle relations case
0611     br[relation_name + "_index"] = df["index"]
0612     br[relation_name + "_colID"] = df["collectionID"]
0613     
0614     # Return early if no relation variables requested
0615     if len(relation_variables) == 0:
0616         br = br.explode([relation_name + "_index", relation_name + "_colID"]).reset_index(drop=True)
0617         return br, None
0618     
0619     # Prepare collections
0620     in_name = "_" + branch_name + "_" + relation_name
0621     print("Loading collections...")
0622     collections = get_collections(tree, in_name, 0)
0623     
0624     if not collections:
0625         print(f"ERROR: No collections {in_name} found")
0626         return br, None
0627     
0628     # Process relations efficiently
0629     print("Processing relation data...")
0630     l_relations = []
0631     l_collection_names = []
0632     loc1 = br.columns.get_loc(relation_name + "_index")
0633     loc2 = br.columns.get_loc(relation_name + "_colID")
0634     
0635     # Pre-convert collections to grouped DataFrames for faster access
0636     collections_grouped = {}
0637     sample_columns = None
0638     
0639     for col_id, col_data in collections.items():
0640         if hasattr(col_data, 'fields'):
0641             col_df = ak.to_dataframe(col_data).reset_index()
0642         else:
0643             col_df = col_data
0644         collections_grouped[col_id] = col_df.groupby('entry')
0645         if sample_columns is None:
0646             sample_columns = col_df.columns
0647     
0648     # Process each row
0649     for row in br.itertuples(index=False):
0650         ind = row[loc1]
0651         col = row[loc2]
0652         if len(ind) == 0:  # empty relation
0653             l_collection_names.append(None)
0654             # Add NaN row with correct number of columns
0655             if sample_columns is not None:
0656                 l_relations.append([np.nan] * (len(sample_columns) - 2))  # -2 for entry, subentry
0657             else:
0658                 l_relations.append([np.nan])
0659         else:
0660             for ii, cc in zip(ind, col):
0661                 if cc in COL_TABLE:
0662                     l_collection_names.append(COL_TABLE[cc])
0663                     
0664                     # Get data from pre-grouped collections
0665                     if cc in collections_grouped and row.entry in collections_grouped[cc].groups:
0666                         event_data = collections_grouped[cc].get_group(row.entry)
0667                         if ii < len(event_data):
0668                             # Get the row data, excluding entry and subentry columns
0669                             row_data = event_data.iloc[ii]
0670                             filtered_data = [row_data[col] for col in row_data.index if col not in ['entry', 'subentry']]
0671                             l_relations.append(filtered_data)
0672                         else:
0673                             l_relations.append([np.nan] * (len(sample_columns) - 2))
0674                     else:
0675                         l_relations.append([np.nan] * (len(sample_columns) - 2))
0676                 else:
0677                     l_collection_names.append(f"Unknown_{cc}")
0678                     l_relations.append([np.nan] * (len(sample_columns) - 2))
0679     
0680     # Explode the main branch
0681     br = br.explode([relation_name + "_index", relation_name + "_colID"]).reset_index(drop=True)
0682     
0683     # Create the additional DataFrame
0684     if sample_columns is not None:
0685         column_names = [col for col in sample_columns if col not in ['entry', 'subentry']]
0686     else:
0687         column_names = ['value']  # fallback
0688     
0689     df_add = pd.DataFrame(l_relations, columns=column_names)
0690     
0691     # Filter columns based on relation_variables
0692     if relation_variables != ["*"]:
0693         available_columns = select_string(column_names, relation_variables)
0694         df_add = df_add[available_columns]
0695     
0696     # Add metadata
0697     df_add[relation_name + "_colName"] = l_collection_names
0698     
0699     # Add entry and subentry columns at the beginning
0700     df_add.insert(0, 'entry', br['entry'])
0701     df_add.insert(1, 'subentry', br['subentry'])
0702     
0703     print(f"Completed processing. Result shape: {df_add.shape}")
0704     return br, df_add
0705 
0706 def get_traj_relations(tree,bname="CentralCKFTrajectories",l_var=["measurementChi2", "outlierChi2", "trackParameters", "measurements_deprecated", "outliers_deprecated"]): 
0707     # l_var   = ["measurementChi2", "outlierChi2", "trackParameters", "measurements_deprecated", "outliers_deprecated"]#,"seed"]
0708     br      = get_branch_df(tree,bname)
0709     print("get_traj_relations: accessing the following vector members:")
0710     for vv in l_var:
0711         print(vv)
0712         a     = get_relation(tree,bname,vv)
0713         if not isinstance(a, pd.DataFrame):
0714             print(f"WARNING(get_traj_relations): {bname}.{vv} returned no data")
0715             continue
0716         for cc in a.columns:
0717             if cc=='values':
0718                 br[vv]=a[cc]
0719                 break
0720             elif cc=='index':
0721                 br[vv+'_index']=a[cc]
0722             elif cc=='collectionID':
0723                 br[vv+'_colID']=a[cc]
0724             else:
0725                 print('WARNING: invalid column ',cc,' in ',bname,'_',vv)
0726     return br
0727 
0728 
def get_traj_hits_particle(tree, traj_name = 'CentralCKFTrajectories', measurement_name='measurements_deprecated'):
    '''
    for each trajectory, find corresponding rec hits (measurements_deprecated or outliers_deprecated) and particle index

    Steps
    -----
    1. Read the CentralTrackerMeasurements -> hits relation via get_branch_relation
       (second return value, `df_hits`; its `hits_colName` column names the rec-hit collection).
    2. For every rec-hit collection named there, load the paired sim-hit collection and its
       `_<sim collection>_particle` branch, and give each row of `df_hits` the particle index
       of the FIRST sim hit with the same (entry, cellID). Rows whose collection could not be
       loaded, or that have no match, get NaN.
    3. Read the `traj_name` -> `measurement_name` relation and, for each of its rows, attach the
       FIRST `df_hits` row with the same (entry, subentry == <measurement_name>_index);
       unmatched rows get NaN columns.

    Parameters
    ----------
    tree : passed through to get_branch_relation / get_branch_df.
    traj_name : trajectory collection name.
    measurement_name : relation member of `traj_name` to follow.

    Returns
    -------
    DataFrame: trajectory relation rows + matched measurement columns (without their
    'entry'/'subentry'), including 'particle_index', plus derived 'position.r' and 'position.phi'.
    '''
    ## for each measurement, get corresponding rec hit
    _, df_hits = get_branch_relation(tree, branch_name="CentralTrackerMeasurements", relation_name="hits")

    ## prepare all relevant sim hits and find corresponding particle index
    # bidirectional rec <-> sim collection-name mapping (sim->rec entries written last, as before)
    name_sim_tracker = name_sim_barrel + name_sim_disk + ['B0TrackerHits']
    name_rec_tracker = name_rec_barrel + name_rec_disk + ["B0TrackerRecHits"]
    HitName_dict = {**dict(zip(name_rec_tracker, name_sim_tracker)),
                    **dict(zip(name_sim_tracker, name_rec_tracker))}

    # sim collection name -> {(entry, cellID): particle index of the first matching sim hit}.
    # Only collections whose hits AND particle relation both loaded are registered, so a
    # skipped collection can never be dereferenced later.
    particle_lookup = {}
    for rec_name in set(df_hits.hits_colName.values):
        sim_name = HitName_dict.get(rec_name)
        if sim_name is None:
            print(f"WARNING: {rec_name} is not a valid hit collection")
            continue
        try:
            sim_hits = get_branch_df(tree, sim_name)
        except Exception:
            print(f"WARNING: {sim_name} is not a valid hit collection")
            continue
        try:
            relation = get_branch_df(tree, f"_{sim_name}_particle")
        except Exception:
            print(f"WARNING: _{sim_name}_particle is not a branch")
            continue
        sim_hits["particle_index"] = relation["index"]
        lookup = {}
        for key, pidx in zip(zip(sim_hits["entry"].values, sim_hits["cellID"].values),
                             sim_hits["particle_index"].values):
            lookup.setdefault(key, pidx)  # keep the first match, like the old `.values[0]`
        particle_lookup[sim_name] = lookup

    ## for each rec hit in measurement, match it with sim hit by cell ID (b/c no relation is available), then append particle index
    ## FIXME: this works only b/c all trackers now use the simple si tracker digi algorithm which is one-to-one b/w sim-raw-rec hits
    def get_particle_index(col, evtID, cellID):
        lookup = particle_lookup.get(HitName_dict.get(col))
        if lookup is None:  # unknown collection or one skipped above
            return np.nan
        return lookup.get((evtID, cellID), np.nan)

    df_hits['particle_index'] = [
        get_particle_index(col, evt, cid)
        for col, evt, cid in zip(df_hits['hits_colName'].values,
                                 df_hits['entry'].values,
                                 df_hits['cellID'].values)
    ]

    ## for each trajectory, append all measurements, then add corresponding rec hit + particle index
    br0, _ = get_branch_relation(tree, branch_name=traj_name, relation_name=measurement_name, relation_variables=[])
    # (entry, subentry) -> row position of the first matching measurement in df_hits
    hit_position = {}
    for pos, key in enumerate(zip(df_hits['entry'].values, df_hits['subentry'].values)):
        hit_position.setdefault(key, pos)
    positions = [hit_position.get(key, -1)
                 for key in zip(br0['entry'].values, br0[measurement_name+'_index'].values)]
    # after reset_index the labels are 0..n-1, so -1 is absent and reindex yields an all-NaN row
    temp_hits = df_hits.reset_index(drop=True).reindex(positions)
    temp_hits.index = br0.index  # align row-for-row with br0 in the concat below
    traj_hits = pd.concat([br0, temp_hits.drop(['entry', 'subentry'], axis=1)], axis=1)
    traj_hits['position.r'] = np.sqrt(traj_hits['position.x']**2+traj_hits['position.y']**2)
    traj_hits['position.phi'] = np.arctan2(traj_hits['position.y'], traj_hits['position.x'])

    return traj_hits
0789 
0790 from lmfit.models import GaussianModel
0791 
def gaussian(x, amplitude, mean, std):
    """Return a Gaussian whose integral over x equals *amplitude*.

    The peak height is amplitude / (std * sqrt(2*pi)); this is the same
    parameterisation lmfit's GaussianModel uses for `amplitude`/`center`/`sigma`.
    Works element-wise on NumPy arrays.
    """
    z = (x - mean) / std
    norm = std * np.sqrt(2 * np.pi)
    return amplitude * np.exp(-0.5 * z ** 2) / norm
0794 
def hist_gaus(
    data, ax,
    bins=100, klog=False, header=None,
    clip=3.0, max_iters=5, min_points=50,
    verbose=False,
):
    '''
    Fit a Gaussian to the core of a distribution and optionally draw it.

    The fit window comes from iterative clipping. Start from the median and
    1.4826 * MAD (falling back to the standard deviation when MAD is 0). Keep
    points within `clip` * scale of the centre and recompute. Repeat up to
    `max_iters` times, or until the estimate stops changing or too few points
    remain. The data are then histogrammed in [centre - clip*scale,
    centre + clip*scale] and fitted with lmfit's GaussianModel.

    Parameters
    ----------
    data : array-like; non-finite entries are dropped.
    ax : matplotlib Axes to draw on, or None for no plotting.
    bins : number of histogram bins inside the fit window.
    klog : log y-scale; otherwise y starts at 0 with headroom above the peak.
    header : optional axes title.
    clip : half-width of the window in units of the robust scale; a value
        <= 0 (or NaN) gives an empty window and the function returns NaNs.
    max_iters : maximum clipping iterations.
    min_points : minimum number of finite points required (and of clipped
        points needed to keep iterating).
    verbose : print the reason whenever NaNs are returned.

    Returns
    -------
    (mean, sigma, sigma_err) of the fitted Gaussian, or (nan, nan, nan) when
    there are too few points, the scale or window is degenerate, the fit raises,
    or the fit yields no sigma uncertainty / a non-positive sigma.
    '''
    data = np.asarray(data, dtype=float)
    data = data[np.isfinite(data)]
    if data.size < min_points:
        if verbose:
            print('hist_gaus: not enough finite points')
        return np.nan, np.nan, np.nan

    # 1.4826 * MAD estimates sigma for Gaussian data while ignoring tails
    center = np.median(data)
    mad = np.median(np.abs(data - center))
    scale = 1.4826 * mad if mad > 0 else np.std(data)
    if not np.isfinite(scale) or scale <= 0:
        if verbose:
            print('hist_gaus: invalid initial scale')
        return np.nan, np.nan, np.nan

    for _ in range(max_iters):
        mask = np.abs(data - center) <= clip * scale
        clipped = data[mask]
        if clipped.size < min_points:
            break
        new_center = np.median(clipped)
        mad = np.median(np.abs(clipped - new_center))
        new_scale = 1.4826 * mad if mad > 0 else np.std(clipped)
        if not np.isfinite(new_scale) or new_scale <= 0:
            break
        if np.isclose(new_center, center) and np.isclose(new_scale, scale):
            center, scale = new_center, new_scale
            break
        center, scale = new_center, new_scale

    lo, hi = center - clip * scale, center + clip * scale
    # `not lo < hi` also rejects clip < 0 and NaN clip; np.histogram would
    # raise ValueError for an inverted or non-finite range.
    if not lo < hi:
        if verbose:
            print('hist_gaus: degenerate range')
        return np.nan, np.nan, np.nan

    counts, edges = np.histogram(data, bins=bins, range=(lo, hi))
    mid = 0.5 * (edges[:-1] + edges[1:])

    if ax is not None:
        ax.hist(data, bins=bins, range=(lo, hi), histtype='stepfilled', alpha=0.3)

    model = GaussianModel()
    # GaussianModel's `amplitude` is the AREA (height * sigma * sqrt(2*pi)),
    # so seed it such that the initial peak height equals the tallest bin.
    amp_guess = np.max(counts) * scale * np.sqrt(2 * np.pi)
    params = model.make_params(center=center, sigma=scale, amplitude=amp_guess)
    try:
        result = model.fit(counts, params, x=mid)
    except Exception as exc:
        if verbose:
            print('hist_gaus: fit failed', exc)
        return np.nan, np.nan, np.nan

    sigma = float(result.params['sigma'])
    sigma_err = result.params['sigma'].stderr
    if sigma_err is None or not np.isfinite(sigma) or sigma <= 0:
        if verbose:
            print('hist_gaus: invalid fit result')
        return np.nan, np.nan, np.nan

    mean = float(result.params['center'])
    sigma_err = float(sigma_err)
    ampl = float(result.params['amplitude'])
    peak = ampl / (sigma * np.sqrt(2 * np.pi))  # fitted peak height, for the y-range

    if ax is not None:
        ax.plot(mid, gaussian(mid, ampl, mean, sigma), 'r-')
        if header:
            ax.set_title(header)
        ax.set_xlabel('value')
        ax.set_ylabel('entries')
        if klog:
            ax.set_yscale('log')
        else:
            ymax = max(np.max(counts), peak)
            ax.set_ylim(0, ymax / 0.7)

    return mean, sigma, sigma_err
0877 
0878 
# Public API: the names exported by `from <this module> import *`.
# Kept as a single list literal so static analysers can read it.
# NOTE(review): `gaussian` (defined above) is not listed; confirm whether it
# is meant to be public.
__all__ = [
    "configure_analysis_environment",
    # ak_* helpers
    "ak_flat",
    "ak_df",
    "ak_hist",
    "ak_filter",
    # small conversion / selection helpers
    "get_pdg_info",
    "get_geoID",
    "theta2eta",
    "select_string",
    # file / branch / relation readers
    "read_ur",
    "get_col_table",
    "get_branch_ak",
    "get_branch_df",
    "get_part",
    "get_params",
    "get_branches",
    "get_collections",
    "get_relation",
    "get_branch_relation",
    "get_traj_relations",
    "get_traj_hits_particle",
    # misc
    "deg2rad",
    "status_to_source",
    "geo_mask_values",
    "hist_gaus"
]