Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:12:06

0001 #!/usr/bin/env python3
0002 from datetime import datetime
0003 import uproot
0004 import pandas as pd
0005 import numpy as np
0006 import logging
0007 import argparse
0008 import matplotlib as mpl
0009 import matplotlib.pyplot as plt
0010 import json
0011 import os
0012 
0013 from pathlib import Path
0014 
0015 
# Shared stylesheet for every generated page: a centred wrapper and a
# two-column grid whose inline SVG plots scale to the column width.
_PAGE_STYLE = """    <style>
    .wrapper {
        max-width: 1500px;
        margin: 0 auto;
    }
    .grid {
        display: grid;
        grid-template-columns: repeat(2, 50%);
    }
    .grid svg {
        width:100%;
        height:auto;
    }
    </style>"""


def _timestamp():
    """Return the current local time formatted for the page headers."""
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def _save_figure(fig, svg_path):
    """Save *fig* to *svg_path*, close it so figures do not pile up, return the path."""
    fig.savefig(svg_path)
    plt.close(fig)
    return svg_path


def _detect_local_coordinates(vol):
    """Return the local coordinate indices (0 and/or 1) carrying cluster sizes.

    A coordinate counts as present when its ``clus_size_loc<i>`` column exists
    and holds at least one truthy (non-zero) value.
    """
    local_values = []
    for loc in (0, 1):
        column = f"clus_size_loc{loc}"
        if column in vol.columns and vol[column].any():
            logging.info(f" - local {loc} coordinate found")
            local_values.append(loc)
    return local_values


def _parametrise_local(
    vol,
    loc,
    v_id_str,
    output_fig_dir,
    var_matrix,
    break_min_stat,
    break_rms_change,
    break_cluster_size,
    color,
    view_rms_range,
):
    """Compute cluster-size binned residual variances for local coordinate *loc*.

    Writes an overview histogram of the cluster sizes plus one residual
    histogram per cluster-size bin into *output_fig_dir*. Each variance is
    also stored in ``var_matrix[loc, c_size]``; *var_matrix* is updated in place.

    Returns:
        (variances, plot_paths): the list of variances (rms**2), one per
        cluster-size bin starting at size 1, and the SVG paths written.
    """
    size_col = f"clus_size_loc{loc}"
    rec_col = f"rec_loc{loc}"
    true_col = f"true_loc{loc}"
    clus_size = vol[size_col]
    largest = clus_size.max()

    plots = []

    # Overview plot: distribution of cluster sizes
    fig, ax = plt.subplots()
    ax.hist(
        clus_size,
        bins=range(1, largest + 3),
        histtype="step",
        fill=True,
        color=color,
    )
    ax.set_xlabel("Cluster size local " + str(loc))
    ax.set_ylabel("Entries")
    plots.append(
        _save_figure(fig, output_fig_dir / f"{v_id_str}_clus_size_loc{loc}.svg")
    )

    variances = []
    # Cluster sizes 1 .. min(largest, break_cluster_size - 1); the cap keeps
    # c_size a valid column index of var_matrix.
    for c_size in range(1, min(largest + 1, break_cluster_size)):
        vol_sel = vol[clus_size == c_size]
        res = vol_sel[rec_col] - vol_sel[true_col]
        rms = np.std(res)
        c_size_flag = str(c_size)
        break_condition = False

        # Peek into the next cluster size: stop binning when it has too few
        # entries or its rms differs by less than break_rms_change (relative).
        next_sel = vol[clus_size == c_size + 1]
        if not next_sel.empty:
            next_res = next_sel[rec_col] - next_sel[true_col]
            next_rms = np.std(next_res)
            # rms == 0 (or NaN) gives no meaningful relative change; treat it
            # as "large change" instead of dividing by zero.
            rel_change = abs(rms - next_rms) / rms if rms > 0 else np.inf
            if len(next_sel) < break_min_stat or rel_change < break_rms_change:
                # The final ("N") bin stands for every cluster of size
                # >= c_size. Recompute the residuals AND the rms from that
                # combined sample so the stored variance and the plot agree.
                vol_sel = vol[clus_size >= c_size]
                res = vol_sel[rec_col] - vol_sel[true_col]
                rms = np.std(res)
                c_size_flag = "N"
                break_condition = True

        var_matrix[loc, c_size] = rms * rms
        variances.append(float(rms * rms))

        if np.isfinite(rms):
            hist_range = (-view_rms_range * rms, view_rms_range * rms)
        else:
            # An empty bin (gap in the cluster-size distribution) gives NaN;
            # a NaN histogram range would make matplotlib raise.
            logging.warning(
                f"{v_id_str}: no clusters of size {c_size_flag} in local {loc}, "
                "variance is not finite"
            )
            hist_range = None

        # Plot the resolution within +/- view_rms_range * rms
        fig, ax = plt.subplots()
        ax.hist(
            res,
            bins=100,
            range=hist_range,
            histtype="step",
            fill=True,
            color=color,
        )
        ax.text(
            0.05,
            0.95,
            "rms = " + str(round(rms, 3)),
            transform=ax.transAxes,
            fontsize=14,
            verticalalignment="top",
        )
        ax.set_xlabel(
            "Resolution - local " + str(loc) + ", cluster size " + c_size_flag
        )
        plots.append(
            _save_figure(
                fig,
                output_fig_dir / f"{v_id_str}_res_loc{loc}_clus_size{c_size_flag}.svg",
            )
        )
        if break_condition:
            break

    return variances, plots


def _write_volume_page(volume_file, v_id, previous_name, next_name, content):
    """Write the per-volume HTML page with navigation links and plot grid.

    *previous_name*/*next_name* are file names relative to the page itself
    (all volume pages live in the same directory).
    """
    volume_file.write_text(
        f"""<!DOCTYPE html>
<html>
<head>
    <title>Error Parameterisation</title>
{_PAGE_STYLE}
</head>
<body>
<div class="wrapper">
    <a href="{previous_name}">Previous volume</a> |
    <a href="../index.html">Back to index</a> |
    <a href="{next_name}">Next volume</a><br>
    <h1>Error Parameterisation : volume {v_id} </h1>
    Generated: {_timestamp()}<br>
    <div class="grid">
    {content}
    </div>
</div>
</body>
</html>
    """
    )


def _write_index_page(index_file, volume_links):
    """Write the top-level index page linking every volume summary plot."""
    index_file.write_text(
        f"""<!DOCTYPE html>
<html>
<head>
    <title>Error Parameterisation</title>
{_PAGE_STYLE}
</head>
<body>
<div class="wrapper">
<h1>Error Parameterisation</h1>
Generated: {_timestamp()}<br>
<div class="grid">
{volume_links}
</div>
</div>
</body>
</html>
        """
    )


def run_error_parametriation(
    rfile,
    digi_cfg,
    volumes,
    output_dir=None,
    json_out="rms_out.json",
    break_min_stat=5000,
    break_rms_change=0.05,
    break_cluster_size=5,
    view_colors=("deepskyblue", "gold"),
    view_rms_range=5,
):
    """Derive cluster-size dependent hit-error variances per volume and report them.

    For every ``(volume_id, name)`` in *volumes*, the tree ``vol<volume_id>`` is
    read from *rfile* as a pandas DataFrame. For each local coordinate with
    cluster-size information, the residual ``rec_loc<i> - true_loc<i>`` is
    binned by cluster size and its variance (rms**2) is recorded. Binning stops
    early, merging all larger clusters into a final "N" bin, when the next size
    has fewer than *break_min_stat* entries or changes the rms by less than
    *break_rms_change* (relative). At most ``break_cluster_size - 1`` bins are
    produced per coordinate.

    Args:
        rfile: ROOT-file-like mapping; ``rfile["vol<ID>"].arrays(library="pd")``
            must return a DataFrame with ``clus_size_loc<i>``, ``rec_loc<i>``
            and ``true_loc<i>`` columns.
        digi_cfg: parsed digitisation JSON whose matching entries get
            ``value.geometric.variances`` replaced (mutated in place), or
            None to write a standalone hierarchy map instead.
        volumes: sequence of ``(volume_id, display_name)`` pairs.
        output_dir: directory receiving ``index.html`` and ``html/``
            (str or Path); defaults to the current directory at call time.
        json_out: path of the JSON file written; None means "rms_out.json".
        break_min_stat, break_rms_change, break_cluster_size: binning stop
            criteria described above.
        view_colors: fill colours indexed by local coordinate (0, 1).
        view_rms_range: residual histograms span +/- this many rms.
    """
    output_dir = Path.cwd() if output_dir is None else Path(output_dir)
    if json_out is None:
        json_out = "rms_out.json"

    # Create output directories (parents=True creates output_dir and html/ too)
    output_html_dir = output_dir / "html"
    output_fig_dir = output_html_dir / "plots"
    output_fig_dir.mkdir(parents=True, exist_ok=True)

    volume_links = ""

    logging.info(f"Hit error parameterisation for {len(volumes)} volumes")

    var_entries = []
    var_dict = {
        "acts-geometry-hierarchy-map": {
            "format-version": 0,
            "value-identifier": "hit-error-parametrisation",
        },
        "entries": var_entries,
    }

    n_volumes = len(volumes)
    for iv, (v_id, v_name) in enumerate(volumes):
        logging.info(f"Processing volume {v_name} with ID: {v_id}")

        # Neighbouring volumes for the page navigation (wrapping around)
        prev_vid = volumes[iv - 1][0]
        next_vid = volumes[(iv + 1) % n_volumes][0]

        v_id_str = "vol" + str(v_id)
        vol = rfile[v_id_str].arrays(library="pd")

        # Rows: local coordinate, columns: cluster size (column 0 unused)
        var_matrix = np.zeros((2, break_cluster_size))
        var_data = []
        plots = []  # one list of SVG paths per local coordinate
        for loc in _detect_local_coordinates(vol):
            variances, lplots = _parametrise_local(
                vol,
                loc,
                v_id_str,
                output_fig_dir,
                var_matrix,
                break_min_stat,
                break_rms_change,
                break_cluster_size,
                view_colors[loc],
                view_rms_range,
            )
            # NOTE: the "rms" key holds variances (rms**2); the key is kept
            # because downstream consumers read it by this name.
            var_data.append({"index": loc, "rms": variances})
            plots.append(lplots)

        var_entries.append({"volume": v_id, "value": var_data})

        # The JSON is rewritten after every volume, so the file on disk always
        # reflects all volumes processed so far.
        if digi_cfg is not None:
            for entry in digi_cfg["entries"]:
                if entry["volume"] == v_id:
                    entry["value"]["geometric"]["variances"] = var_data
            json_payload = digi_cfg
        else:
            json_payload = var_dict
        with open(json_out, "w") as outfile:
            json.dump(json_payload, outfile, indent=4)

        # The matrix plot; closed afterwards so figures do not accumulate
        fig, ax = plt.subplots(ncols=1, nrows=1)
        pos = ax.matshow(var_matrix, cmap="Blues")
        ax.set_xlabel("Cluster size")
        ax.set_ylabel("Local coordinate")
        ax.set_title(v_name)
        fig.colorbar(pos, ax=ax)
        summary_path = _save_figure(fig, output_fig_dir / f"{v_id_str}_summary.svg")

        # Interleave the per-coordinate plot columns row by row into the
        # two-column grid; pad shorter columns with empty cells.
        n_rows = max((len(column) for column in plots), default=0)
        cells = []
        for ip in range(n_rows):
            for column in plots:
                if ip < len(column):
                    cells.append(f"<div>{column[ip].read_text()}</div>")
                else:
                    cells.append("<div></div>")
        plot_content = "".join(cells)

        volume_links += f'<div><a href="html/volume_{v_id}.html"><div>{summary_path.read_text()}</div></a></div>'

        # Sibling pages share a directory, so link by file name only; a full
        # path would break whenever output_dir is relative or moved.
        _write_volume_page(
            output_html_dir / f"volume_{v_id}.html",
            v_id,
            f"volume_{prev_vid}.html",
            f"volume_{next_vid}.html",
            plot_content,
        )

        _write_index_page(output_dir / "index.html", volume_links)
0280 
0281 
# Main function
if __name__ == "__main__":
    # Configure logging first so warnings during argument handling are shown
    logging.basicConfig(encoding="utf-8", level=logging.INFO)

    # Parse the command line arguments
    p = argparse.ArgumentParser(description="Hit parameterisation")
    p.add_argument("--root", required=True, help="input ROOT file")
    p.add_argument("--json-in", help="digitisation JSON to update (optional)")
    p.add_argument(
        "--json-out",
        default="rms_out.json",
        help="output JSON file (default: rms_out.json)",
    )
    args = p.parse_args()

    volumes = [
        (16, "Pixel NEC"),
        (17, "Pixel Barrel"),
        (18, "Pixel PEC"),
        (23, "SStrips NEC"),
        (24, "SStrips Barrel"),
        (25, "SStrips PEC"),
        (28, "LStrips NEC"),
        (29, "LStrips Barrel"),
        (30, "LStrips PEC"),
    ]

    # Open the json to be updated; the handle is closed right after parsing
    digi_cfg = None
    if args.json_in is not None:
        if os.path.isfile(args.json_in) and os.access(args.json_in, os.R_OK):
            with open(args.json_in, "r") as jfile:
                digi_cfg = json.load(jfile)
        else:
            logging.warning(
                f"Cannot read --json-in '{args.json_in}', "
                "writing a standalone parametrisation instead"
            )

    # Open the root file and keep it open only while it is processed
    with uproot.open(args.root) as rfile:
        run_error_parametriation(
            rfile, digi_cfg, volumes, Path.cwd() / "output", args.json_out
        )