Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-15 08:14:45

0001 #!/usr/bin/env python3
0002 from datetime import datetime
0003 import uproot
0004 import pandas as pd
0005 import numpy as np
0006 import logging
0007 import argparse
0008 import matplotlib as mpl
0009 import matplotlib.pyplot as plt
0010 import json
0011 import math
0012 import os
0013 
0014 from pathlib import Path
0015 
0016 
_VOLUME_PAGE_TEMPLATE = """<!DOCTYPE html>
<html>
<head>
    <title>Error Parameterisation</title>
    <style>
    .wrapper {{
        max-width: 1500px;
        margin: 0 auto;
    }}
    .grid {{
        display: grid;
        grid-template-columns: repeat(2, 50%);
    }}
    .grid svg {{
        width:100%;
        height:auto;
    }}
    </style>
</head>
<body>
<div class="wrapper">
    <a href="{previous}">Previous volume</a> |
    <a href="../index.html">Back to index</a> |
    <a href="{next}">Next volume</a><br>
    <h1>Error Parameterisation : volume {vid} </h1>
    Generated: {date}<br>
    <div class="grid">
    {content}
    </div>
</div>
</body>
</html>
"""

_INDEX_PAGE_TEMPLATE = """<!DOCTYPE html>
<html>
<body>
<div class="wrapper">
<h1>Error Parameterisation</h1>
Generated: {date}<br>
<div class="grid">
{overall_content}
</div>
<div class="grid">
{volume_content}
</div>
</div>
</body>
</html>
"""


def _save_scatter(x, y, title, path, **scatter_kwargs):
    """Save an r-z scatter overview to *path* and clear the current figure."""
    plt.scatter(x=x, y=y, s=1, alpha=0.1, **scatter_kwargs)
    plt.xlabel("z [mm]")
    plt.ylabel("r [mm]")
    plt.title(title)
    plt.savefig(path)
    plt.clf()
    return path


def _save_histogram(
    values, path, color, xlabel, bins, value_range=None, rms=None, ylabel=None
):
    """Save a filled step histogram to *path* and clear the current figure."""
    plt.hist(
        values,
        bins=bins,
        range=value_range,
        histtype="step",
        fill=True,
        color=color,
    )
    if rms is not None:
        plt.text(
            0.05,
            0.95,
            "rms = " + str(round(rms, 3)),
            transform=plt.gca().transAxes,
            fontsize=14,
            verticalalignment="top",
        )
    plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.savefig(path)
    plt.clf()
    return path


def _save_matrix_plot(matrix, title, cmap, label, path):
    """Save a (local coordinate x cluster size) matrix plot and close its figure."""
    fig, ax = plt.subplots(ncols=1, nrows=1)
    pos = ax.matshow(matrix, cmap=cmap)
    ax.set_xlabel("Cluster size")
    ax.set_ylabel("Local coordinate")
    ax.set_title(title)
    fig.colorbar(pos, ax=ax, label=label)
    fig.savefig(path)
    # Close (not just clear) the figure: one is created per call and would
    # otherwise stay registered with pyplot for the lifetime of the process.
    plt.close(fig)
    return path


def _timestamp():
    """Return the current local time formatted for the generated pages."""
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def run_error_parametriation(
    rfile,
    digi_cfg,
    volumes,
    output_dir=Path.cwd(),
    json_out="rms_out.json",
    break_min_stat=5000,
    break_rms_change=0.05,
    break_cluster_size=5,
    view_colors=("deepskyblue", "gold"),
    view_rms_range=5,
    plot_pulls=False,
    pull_view_colors=("steelblue", "goldenrod"),
):
    """Derive per-volume hit resolution (RMS) per cluster size and report it.

    Reads the ``"measurements"`` tree from *rfile* into a pandas DataFrame,
    computes the standard deviation of ``rec_loc<i> - true_loc<i>`` per volume,
    local coordinate and cluster size, writes the result as JSON and produces
    an HTML report (``index.html`` plus one page per volume) with the plots.

    Args:
        rfile: Mapping-like object where ``rfile["measurements"]`` offers
            ``arrays(library="pd")`` (e.g. an opened uproot file).
        digi_cfg: Optional dict with an ``"entries"`` list. When given, the
            entry whose ``"volume"`` matches is updated in place at
            ``["value"]["geometric"]["variances"]`` and the whole dict is
            written to *json_out*; when ``None`` a standalone dict is written.
        volumes: Sequence of ``(volume_id, volume_name)`` pairs.
        output_dir: Directory receiving ``index.html``, ``html/`` and
            ``html/plots/``. ``str`` or ``Path``.
        json_out: Path of the JSON output file; ``None`` selects
            ``"rms_out.json"``.
        break_min_stat: If the next cluster size has fewer rows than this,
            the current and all larger sizes are merged and the scan stops.
        break_rms_change: If the relative RMS change to the next cluster size
            is below this, the current and all larger sizes are merged and
            the scan stops.
        break_cluster_size: Exclusive upper bound on scanned cluster sizes;
            also the number of columns of the summary matrices.
        view_colors: Histogram colours, indexed by local coordinate (0, 1).
        view_rms_range: Half-width of the resolution histograms in units of
            the RMS (and of the pull histograms in absolute units).
        plot_pulls: Also compute and plot ``residual / sqrt(var_loc<i>)``.
        pull_view_colors: Pull histogram colours, indexed by local coordinate.
    """
    output_dir = Path(output_dir)
    if json_out is None:
        # The CLI passes None when --json-out is omitted.
        json_out = "rms_out.json"
    volumes = list(volumes)

    # Creating the deepest directory with parents=True creates all of them.
    output_html_dir = output_dir / "html"
    output_fig_dir = output_html_dir / "plots"
    output_fig_dir.mkdir(parents=True, exist_ok=True)

    logging.info(f"Hit error parameterisation for {len(volumes)} volumes")

    measurements = rfile["measurements"].arrays(library="pd")
    # Transverse radius g_r = sqrt(x^2 + y^2) for reconstructed and true hits
    measurements["rec_gr"] = np.sqrt(
        measurements["rec_gx"] ** 2 + measurements["rec_gy"] ** 2
    )
    measurements["true_r"] = np.sqrt(
        measurements["true_x"] ** 2 + measurements["true_y"] ** 2
    )

    rec_png = _save_scatter(
        measurements["rec_gz"],
        measurements["rec_gr"],
        "Reconstructed hit positions",
        output_fig_dir / "overview_rec_hit_positions.png",
    )
    true_png = _save_scatter(
        measurements["true_z"],
        measurements["true_r"],
        "True hit positions",
        output_fig_dir / "overview_true_hit_positions.png",
        c="orange",
    )
    # index.html lives in output_dir, so reference the images relative to it;
    # joining output_dir into the URL breaks for relative output directories
    # and makes the report non-relocatable for absolute ones.
    volume_overview = (
        f'<div><img src="{rec_png.relative_to(output_dir).as_posix()}"'
        ' alt="Rec hit positions">'
        f'<img src="{true_png.relative_to(output_dir).as_posix()}"'
        ' alt="True hit positions"></div>'
    )

    volume_links = []
    var_entries = []

    # loop over the volumes
    for iv, (v_id, v_name) in enumerate(volumes):
        logging.info(f"Processing volume {v_name} with ID: {v_id}")

        # previous and next volume, wrapping around at both ends
        prev_id = volumes[iv - 1][0]
        next_id = volumes[(iv + 1) % len(volumes)][0]

        vol = measurements[measurements["volume_id"] == v_id]
        v_id_str = "volume_" + str(v_id)

        # A local coordinate is used only if none of its values is NaN
        local_values = []
        if vol.empty:
            logging.warning(f" - no measurements found for volume {v_id}")
        else:
            if not np.isnan(vol["rec_loc0"]).any():
                logging.info(" - local 0 coorindate found")
                local_values.append(0)
            if not np.isnan(vol["rec_loc1"]).any():
                logging.info(" - local 1 coorindate found")
                local_values.append(1)

        # Summary matrices, indexed [local coordinate, cluster size]
        rms_matrix = np.zeros((2, break_cluster_size))
        pull_matrix = np.zeros((2, break_cluster_size))

        var_data = []
        # One list of SVG paths per local coordinate (one html grid column each)
        plots = []

        for loc in local_values:
            size_col = f"clus_size_loc{loc}"
            rec_col = f"rec_loc{loc}"
            true_col = f"true_loc{loc}"
            largest_size = int(vol[size_col].max())

            # Overview: cluster size distribution
            lplots = [
                _save_histogram(
                    vol[size_col],
                    output_fig_dir / f"{v_id_str}_clus_size_loc{loc}.svg",
                    color=view_colors[loc],
                    xlabel="Cluster size local " + str(loc),
                    ylabel="Entries",
                    bins=range(1, largest_size + 3),
                )
            ]

            rms_local_data = []
            max_clus_size = min(largest_size + 1, break_cluster_size)
            for c_size in range(1, max_clus_size):
                vol_sel = vol[vol[size_col] == c_size]
                res = vol_sel[rec_col] - vol_sel[true_col]
                rms = np.std(res)
                c_size_flag = str(c_size)
                break_condition = False

                # Peek into the next cluster size: stop if it has too little
                # statistics or the resolution barely changes.
                next_sel = vol[vol[size_col] == c_size + 1]
                if not next_sel.empty:
                    next_rms = np.std(next_sel[rec_col] - next_sel[true_col])
                    too_few = len(next_sel) < break_min_stat
                    # rms > 0 guard avoids a division by zero; a zero or NaN
                    # rms never satisfied the original comparison either.
                    small_change = (
                        rms > 0 and abs(rms - next_rms) / rms < break_rms_change
                    )
                    if too_few or small_change:
                        # Merge this and all larger cluster sizes. The RMS is
                        # recomputed on the merged sample so that the stored
                        # value, the plot label and the histogram range all
                        # describe what is plotted as cluster size "N".
                        vol_sel = vol[vol[size_col] >= c_size]
                        res = vol_sel[rec_col] - vol_sel[true_col]
                        rms = np.std(res)
                        c_size_flag = "N"
                        break_condition = True

                rms = float(rms)
                rms_matrix[loc, c_size] = rms
                rms_local_data.append(rms)

                # Resolution within +/- n rms; let matplotlib pick the range
                # if the rms is not finite (e.g. an empty selection).
                value_range = (
                    (-view_rms_range * rms, view_rms_range * rms)
                    if np.isfinite(rms)
                    else None
                )
                lplots.append(
                    _save_histogram(
                        res,
                        output_fig_dir
                        / f"{v_id_str}_res_loc{loc}_clus_size{c_size_flag}.svg",
                        color=view_colors[loc],
                        xlabel="Resolution - local "
                        + str(loc)
                        + ", cluster size "
                        + c_size_flag,
                        bins=100,
                        value_range=value_range,
                        rms=rms,
                    )
                )

                if plot_pulls:
                    # Computed after the merge decision so it matches `res`
                    pull = res / np.sqrt(vol_sel[f"var_loc{loc}"])
                    rms_pull = float(np.std(pull))
                    pull_matrix[loc, c_size] = rms_pull
                    lplots.append(
                        _save_histogram(
                            pull,
                            output_fig_dir
                            / f"{v_id_str}_pull_loc{loc}_clus_size{c_size_flag}.svg",
                            color=pull_view_colors[loc],
                            xlabel="Pull - local "
                            + str(loc)
                            + ", cluster size "
                            + c_size_flag,
                            bins=100,
                            value_range=(-view_rms_range, view_rms_range),
                            rms=rms_pull,
                        )
                    )

                if break_condition:
                    break

            var_data.append({"index": loc, "rms": rms_local_data})
            plots.append(lplots)

        # Only record volumes that produced data, so an empty volume does not
        # wipe existing variances in the digitization configuration.
        if var_data:
            var_entries.append({"volume": v_id, "value": var_data})
            if digi_cfg is not None:
                for entry in digi_cfg["entries"]:
                    if entry["volume"] == v_id:
                        entry["value"]["geometric"]["variances"] = var_data

        summary_svg = _save_matrix_plot(
            rms_matrix,
            v_name,
            "Blues",
            "RMS",
            output_fig_dir / f"{v_id_str}_summary.svg",
        )
        pull_summary_svg = _save_matrix_plot(
            pull_matrix,
            v_name,
            "Reds",
            "RMS (pull)",
            output_fig_dir / f"{v_id_str}_pull_summary.svg",
        )

        # Lay the plots out row by row, one cell per local coordinate; the
        # SVG sources are inlined into the page.
        n_rows = max((len(column) for column in plots), default=0)
        cells = []
        for ip in range(n_rows):
            for column in plots:
                if ip < len(column):
                    svg_text = column[ip].read_text(encoding="utf-8")
                    cells.append(f"<div>{svg_text}</div>")
                else:
                    cells.append("<div></div>")

        link = (
            f'<div><a href="html/volume_{v_id}.html">'
            f'<div>{summary_svg.read_text(encoding="utf-8")}'
        )
        if plot_pulls:
            link += pull_summary_svg.read_text(encoding="utf-8")
        link += "</div></a></div>"
        volume_links.append(link)

        # Volume pages are siblings inside html/, so link them by file name.
        volume_file = output_html_dir / f"volume_{v_id}.html"
        volume_file.write_text(
            _VOLUME_PAGE_TEMPLATE.format(
                vid=v_id,
                previous=f"volume_{prev_id}.html",
                next=f"volume_{next_id}.html",
                content="".join(cells),
                date=_timestamp(),
            ),
            encoding="utf-8",
        )

    # Write the JSON once all volumes are processed
    if digi_cfg is not None:
        json_payload = digi_cfg
    else:
        json_payload = {
            "acts-geometry-hierarchy-map": {
                "format-version": 0,
                "value-identifier": "hit-error-parametrisation",
            },
            "entries": var_entries,
        }
    with open(json_out, "w", encoding="utf-8") as outfile:
        json.dump(json_payload, outfile, indent=4)

    # Write the index file
    index_file = output_dir / "index.html"
    index_file.write_text(
        _INDEX_PAGE_TEMPLATE.format(
            date=_timestamp(),
            overall_content=volume_overview,
            volume_content="".join(volume_links),
        ),
        encoding="utf-8",
    )
0378 
0379 
# Main function: command line entry point
if __name__ == "__main__":
    # Configure logging first so warnings raised during setup are visible
    logging.basicConfig(encoding="utf-8", level=logging.INFO)

    # Parse the command line arguments
    p = argparse.ArgumentParser(description="Hit parameterisation")
    # Without a ROOT file nothing can be done; fail with a clear usage error
    p.add_argument("--root", required=True)
    p.add_argument("--json-in")
    # Default mirrors run_error_parametriation so the output path is never None
    p.add_argument("--json-out", default="rms_out.json")
    p.add_argument("--plot-pulls", action="store_true")
    p.add_argument(
        "--volumes-ids",
        nargs="+",
        type=int,
        default=[16, 17, 18, 23, 24, 25, 28, 29, 30],
    )
    p.add_argument(
        "--volume-names",
        nargs="+",
        type=str,
        default=[
            "Pixel NEC",
            "Pixel Barrel",
            "Pixel PEC",
            "SStrips NEC",
            "SStrips Barrel",
            "SStrips PEC",
            "LStrips NEC",
            "LStrips Barrel",
            "LStrips PEC",
        ],
    )
    args = p.parse_args()

    # Validate before touching any file
    if len(args.volumes_ids) != len(args.volume_names):
        raise ValueError("Volume IDs and names must have the same length")

    volumes = list(zip(args.volumes_ids, args.volume_names))

    # Open the json to be updated (optional)
    digi_cfg = None
    if args.json_in is not None:
        if os.path.isfile(args.json_in) and os.access(args.json_in, os.R_OK):
            with open(args.json_in, "r", encoding="utf-8") as jfile:
                digi_cfg = json.load(jfile)
        else:
            # Same fallback as before (standalone output), but no longer silent
            logging.warning(
                f"Cannot read --json-in '{args.json_in}', "
                "writing a standalone parametrisation file instead"
            )

    # Open the root file; the context manager closes it when done
    with uproot.open(args.root) as rfile:
        run_error_parametriation(
            rfile,
            digi_cfg,
            volumes,
            Path.cwd() / "output",
            args.json_out,
            plot_pulls=args.plot_pulls,
        )