Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-03-28 07:46:06

0001 import numpy as np
0002 import matplotlib.pyplot as plt
0003 import argparse
0004 import json
0005 import math
0006 
0007 
# Define an axis base class plus its regular (equidistant) and variable-width subclasses
class Axis:
    """Common description of a single grid axis.

    Attributes:
        name: label used when the axis is drawn.
        bins: number of bins along the axis.
        range: indexable ``(low, high)`` pair bounding the axis; subclasses
            read it as ``range[0]`` / ``range[1]``.
    """

    def __init__(self, name, bins, range):
        # ``range`` shadows the builtin, but it is part of the keyword
        # interface callers may use, so the parameter name is kept.
        self.name, self.bins, self.range = name, bins, range
0014 
0015 
class RegularAxis(Axis):
    """Axis made of ``bins`` equally wide bins spanning ``range``.

    The constructor ``(name, bins, range)`` is inherited unchanged from Axis.
    """

    def get_edges(self):
        """Return the ``bins + 1`` equally spaced bin edges as a NumPy array."""
        low, high = self.range[0], self.range[1]
        return np.linspace(low, high, num=self.bins + 1)
0022 
0023 
class VariableAxis(Axis):
    """Axis with arbitrary (variable-width) bins defined by explicit edges."""

    def __init__(self, name, edges: np.ndarray):
        """Create the axis from its bin edges.

        Args:
            name: label used when the axis is drawn.
            edges: bin edges, ``len(edges) - 1`` bins are defined. Any 1D
                sequence is accepted and converted to a NumPy array, so that
                element-wise arithmetic on ``get_edges()`` (e.g. computing
                bin centres) works even when a plain list is passed.

        Raises:
            ValueError: if ``edges`` is not 1D or holds fewer than two values,
                i.e. does not define at least one bin.
        """
        # np.asarray returns an existing ndarray unchanged (no copy).
        edges = np.asarray(edges)
        if edges.ndim != 1 or edges.size < 2:
            raise ValueError(
                f"VariableAxis needs a 1D sequence of at least two edges, got {edges!r}"
            )
        super().__init__(name, edges.size - 1, (edges[0], edges[-1]))
        self.edges = edges

    def get_edges(self):
        """Return the bin edges as a NumPy array."""
        return self.edges
0032 
0033 
0034 def sort_vertices(policy_type: str, vertices: list):
0035     """Sorts the vertices in counter-clockwise order around their centroid."""
0036     if policy_descr == "Disc":
0037         return vertices
0038 
0039     # Calculate the centroid of the vertices
0040     centroid = [
0041         sum(v[0] for v in vertices) / len(vertices),
0042         sum(v[1] for v in vertices) / len(vertices),
0043     ]
0044 
0045     # Sort the vertices based on the angle from the centroid
0046     sorted_vertices = sorted(
0047         vertices, key=lambda v: math.atan2(v[1] - centroid[1], v[0] - centroid[0])
0048     )
0049 
0050     return sorted_vertices
0051 
0052 
def correct_vertices(
    policy_type: str,
    vertices: list,
    wrap_threshold=1.5 * math.pi,
    overlap_threshold=0.1,
):
    """Prepare projected surface vertices for drawing.

    Behaviour by ``policy_type``:

    * "Plane": vertices are returned unchanged.
    * "Disc" / "Ring": vertices are read as ``(r, phi)`` and converted to
      cartesian ``(x, y)``.
    * anything else (e.g. "Cylinder"): the second coordinate is treated as
      phi. If the phi extent (max - min) exceeds ``wrap_threshold``, the
      surface is taken to straddle the -pi/pi boundary and is unwrapped:
      points with phi below ``-overlap_threshold * half_phi`` get +2pi,
      where ``half_phi`` is half the absolute sum of the smallest and
      largest phi.

    Args:
        policy_type: grid type string.
        vertices: list of 2-element points.
        wrap_threshold: phi extent above which wrapping is assumed.
        overlap_threshold: scale factor for the phi cut below which points
            are shifted by +2pi.

    Returns:
        ``(corrected, mirror)``: the vertices to draw, and — only when
        unwrapping happened — the same polygon shifted by -2pi in phi;
        otherwise ``mirror`` is an empty list.
    """
    if policy_type == "Plane":
        return vertices, []
    if policy_type in ("Disc", "Ring"):
        # bring vertices from (r, phi) to (x, y) for drawing
        cartesian = [[v[0] * math.cos(v[1]), v[0] * math.sin(v[1])] for v in vertices]
        return cartesian, []
    if len(vertices) == 0:
        # Nothing to unwrap (taking the phi extremes below would fail).
        return vertices, []

    phi_values = [v[1] for v in vertices]
    phi_min, phi_max = min(phi_values), max(phi_values)
    diff_phi = phi_max - phi_min
    if diff_phi <= wrap_threshold:
        return vertices, []

    # Wrap situation: shift the negative-phi side up by 2pi and keep a mirror
    # copy of the whole polygon shifted down by 2pi.
    half_phi = 0.5 * abs(phi_min + phi_max)
    phi_cut = -overlap_threshold * half_phi
    two_pi = 2 * math.pi
    corrected, mirror = [], []
    for v in vertices:
        if v[1] < phi_cut:
            corrected.append([v[0], v[1] + two_pi])
            mirror.append(v)
        else:
            corrected.append(v)
            mirror.append([v[0], v[1] - two_pi])
    return corrected, mirror
0100 
0101 
def plot_rectangular_grid(
    x: Axis, y: Axis, grid_data: np.ndarray, add_text=True, add_lines=True
):
    """Plot a binned 2D grid on a new figure using rectangular bins.

    Args:
        x: axis drawn horizontally (``get_edges()``, ``bins``, ``name`` used).
        y: axis drawn vertically.
        grid_data: per-bin values indexed ``grid_data[i_x, i_y]``; NaN bins
            are drawn white. If None, every bin gets a weight of one.
        add_text: annotate each bin whose value is > 0 with that value (int).
        add_lines: draw dashed lines at every bin edge.

    Returns:
        ``(fig, ax)`` of the newly created figure.
    """
    x_edges = x.get_edges()
    y_edges = y.get_edges()
    x_centers = (x_edges[:-1] + x_edges[1:]) / 2
    y_centers = (y_edges[:-1] + y_edges[1:]) / 2

    fig, ax = plt.subplots(figsize=(8, 6))

    # Use a private copy with white for bad (NaN) entries, so the globally
    # shared colormap object is not modified as a side effect.
    cmap = plt.cm.magma_r.with_extremes(bad="white")

    if add_lines:
        for edge in x_edges:
            ax.axvline(edge, color="black", linestyle="--", linewidth=0.5)
        for edge in y_edges:
            ax.axhline(edge, color="black", linestyle="--", linewidth=0.5)

    # One weighted sample per bin centre reproduces grid_data as a histogram
    # (row-major flatten matches repeat(x) / tile(y)). plt.hist2d, rather than
    # ax.hist2d, also registers the result as the current image, so a later
    # plt.clim()/plt.colorbar() applies to it.
    plt.hist2d(
        x=np.repeat(x_centers, y.bins),
        y=np.tile(y_centers, x.bins),
        bins=[x_edges, y_edges],
        weights=grid_data.flatten() if grid_data is not None else None,
        cmap=cmap,
    )

    if add_text and grid_data is not None:
        for i, center_x in enumerate(x_centers):
            for j, center_y in enumerate(y_centers):
                count = grid_data[i, j]
                # NaN compares False, so empty bins get no label.
                if count > 0:
                    ax.text(
                        center_x,
                        center_y,
                        int(count),
                        # light text on the darker (higher) part of the map
                        color="white" if count > 2 else "black",
                        ha="center",
                        va="center",
                        fontsize=8,
                        fontweight="bold",
                    )

    ax.set_xlabel(f"{x.name}")
    ax.set_ylabel(f"{y.name}")

    return fig, ax
0162 
0163 
def plot_polar_grid(
    r: Axis,
    phi: Axis,
    grid_data: np.ndarray,
    add_text=True,
    add_lines=True,
    ax=None,
):
    """Draw an (r, phi) binned grid as filled annular sectors in x/y.

    Args:
        r: radial axis (edges used as sector radii).
        phi: azimuthal axis (edges used as sector angles, in radians).
        grid_data: per-bin values indexed ``grid_data[i_r, i_phi]``. Each
            sector is coloured by its value divided by the largest non-NaN
            value; NaN bins are drawn white. If None, all sectors are white.
        add_text: annotate each bin whose value is > 0 with that value (int).
        add_lines: draw dashed circles at the r edges and dashed rays at the
            phi edges.
        ax: matplotlib Axes to draw into; defaults to the current axes.

    Returns:
        ``(fig, ax)`` the grid was drawn on.
    """
    # Resolve the target axes explicitly instead of relying on script globals.
    if ax is None:
        ax = plt.gca()
    fig = ax.figure

    r_edges = r.get_edges()
    phi_edges = phi.get_edges()

    # Private colormap copy (white for NaN) so the shared colormap is untouched.
    cmap = plt.cm.magma_r.with_extremes(bad="white")
    # Colour normalisation, computed once rather than for every sector.
    vmax = np.nanmax(grid_data) if grid_data is not None else None

    for i in range(len(r_edges) - 1):
        r_lo, r_hi = r_edges[i], r_edges[i + 1]
        for j in range(len(phi_edges) - 1):
            phi_lo, phi_hi = phi_edges[j], phi_edges[j + 1]
            # Sector corners, converted from (r, phi) to (x, y).
            corners = ((r_lo, phi_lo), (r_hi, phi_lo), (r_hi, phi_hi), (r_lo, phi_hi))
            sector_vertices = [
                [rad * math.cos(ang), rad * math.sin(ang)] for rad, ang in corners
            ]
            facecolor = (
                cmap(grid_data[i, j] / vmax) if grid_data is not None else "white"
            )
            ax.add_patch(
                plt.Polygon(
                    sector_vertices,
                    closed=True,
                    facecolor=facecolor,
                    edgecolor=None,
                    alpha=1.0,
                )
            )

    if add_lines:
        for edge in r_edges:
            ax.add_artist(
                plt.Circle(
                    (0, 0), edge, color="black", fill=False, linestyle="--", linewidth=0.5
                )
            )
        r_min, r_max = r.range[0], r.range[1]
        for edge in phi_edges:
            cos_e, sin_e = math.cos(edge), math.sin(edge)
            ax.plot(
                [r_min * cos_e, r_max * cos_e],
                [r_min * sin_e, r_max * sin_e],
                color="black",
                linestyle="--",
                linewidth=0.5,
            )

    if add_text and grid_data is not None:
        for i in range(len(r_edges) - 1):
            r_mid = (r_edges[i] + r_edges[i + 1]) / 2
            for j in range(len(phi_edges) - 1):
                count = grid_data[i, j]
                # NaN compares False, so empty bins get no label.
                if count > 0:
                    phi_mid = (phi_edges[j] + phi_edges[j + 1]) / 2
                    ax.text(
                        r_mid * math.cos(phi_mid),
                        r_mid * math.sin(phi_mid),
                        int(count),
                        # light text on the darker (higher) part of the map
                        color="white" if count > 2 else "black",
                        ha="center",
                        va="center",
                        fontsize=8,
                        fontweight="bold",
                    )

    ax.set_xlabel("x [mm]")
    ax.set_ylabel("y [mm]")
    ax.set_aspect("equal", adjustable="box")

    return fig, ax
0283 
0284 
def plot_surface(vertices: np.ndarray, ax, fill_color):
    """Add one projected surface to ``ax`` as a closed, translucent polygon.

    Args:
        vertices: polygon corner points, in drawing order.
        ax: matplotlib Axes that receives the patch.
        fill_color: matplotlib colour used for both the fill and the outline.
    """
    from matplotlib.patches import Polygon

    patch_style = {
        "closed": True,
        "facecolor": fill_color,
        "edgecolor": fill_color,
        "alpha": 0.25,
    }
    surface_patch = Polygon(vertices, **patch_style)
    ax.add_patch(surface_patch)
0294 
if __name__ == "__main__":

    p = argparse.ArgumentParser()
    p.add_argument(
        "-p", "--policy", type=str, default="", help="Input JSON policy file"
    )
    p.add_argument("--no-text", action="store_true", help="Switch off bin text display")
    p.add_argument("--no-grid", action="store_true", help="Switch off grid display")
    p.add_argument(
        "--no-lines", action="store_true", help="Switch off grid lines display"
    )
    p.add_argument(
        "--no-surfaces", action="store_true", help="Switch off surface display"
    )
    p.add_argument(
        "--random-surface-color", action="store_true", help="Use random surface colors"
    )

    args = p.parse_args()

    if args.policy != "":
        # Only parsing needs the file: close it before plotting, so it is not
        # held open while plt.show() blocks.
        with open(args.policy, "r", encoding="utf-8") as f:
            policy_descr = json.load(f)

        type_descr = policy_descr["type"]
        grid_descr = policy_descr["grid"]
        axes_descr = grid_descr["axes"]

        # Determine the grid type (Plane, Disc, Ring or Cylinder) by substring
        # match on the policy type; the last matching candidate wins.
        policy_type = None
        for candidate in ("Plane", "Disc", "Ring", "Cylinder"):
            if candidate in type_descr:
                policy_type = candidate
        if policy_type is None:
            raise ValueError(f"Unknown grid type: {type_descr}")

        # Build the axes. Only equidistant axes are handled; silently skipping
        # other types would shift the axis count and break the set-up below.
        axes = []
        for axis_descr in axes_descr:
            axis_type = axis_descr["type"]
            if axis_type != "Equidistant":
                raise ValueError(f"Unsupported axis type: {axis_type}")
            axes.append(
                RegularAxis("Axis_name", axis_descr["bins"], axis_descr["range"])
            )

        one_dimensional = len(axes) == 1
        if one_dimensional:
            # A 1D grid can only be drawn for rings: its single axis becomes
            # the phi axis and a one-bin radial axis spanning the projected
            # reference range is prepended.
            if policy_type != "Ring":
                raise ValueError(
                    f"1D grids are only supported for Ring policies, got {policy_type}."
                )
            reference_range = policy_descr["projectedReferenceRange"]
            axes.insert(0, RegularAxis("r [mm]", 1, reference_range))
        elif len(axes) != 2:
            raise ValueError("Only 1D and 2D grids are supported in this example.")

        # Bins without an entry stay NaN and are therefore drawn blank.
        grid_data = np.full((axes[0].bins, axes[1].bins), np.nan)
        if not args.no_grid:
            for entry_descr in grid_descr["data"]:
                bin_descr = entry_descr[0]
                # The displayed value is the length of the entry's content
                # list; JSON bin indices are offset by one from array indices.
                value = len(entry_descr[1])
                if one_dimensional:
                    grid_data[0, bin_descr[0] - 1] = value
                else:
                    grid_data[bin_descr[0] - 1, bin_descr[1] - 1] = value

        # Now plot the grid according to its type
        rect_axis_names = {
            "Plane": ("x [mm]", "y [mm]"),
            "Cylinder": ("phi [rad]", "z [mm]"),
        }
        if policy_type in rect_axis_names:
            axes[0].name, axes[1].name = rect_axis_names[policy_type]
            fig, ax = plot_rectangular_grid(
                axes[0],
                axes[1],
                grid_data,
                add_text=not args.no_text,
                add_lines=not args.no_lines,
            )
        else:  # "Disc" or "Ring"
            # Empty cartesian frame sized by the outer radius; the polar grid
            # is drawn on top of it.
            r_max = axes[0].range[1]
            cart_axes = [
                RegularAxis("x [mm]", axes[0].bins, (-r_max, r_max)),
                RegularAxis("y [mm]", axes[1].bins, (-r_max, r_max)),
            ]
            empty_data = np.full((axes[0].bins, axes[1].bins), np.nan)
            fig, ax = plot_rectangular_grid(
                cart_axes[0], cart_axes[1], empty_data, False, False
            )
            plot_polar_grid(
                axes[0],
                axes[1],
                grid_data,
                add_text=not args.no_text,
                add_lines=not args.no_lines,
            )

        # Draw the projected surface points
        if not args.no_surfaces and "projectedSurfaces" in policy_descr:
            for surface_vertices in policy_descr["projectedSurfaces"]:
                # Convert/unwrap the vertices according to the grid type.
                surface_vertices, wrap_vertices = correct_vertices(
                    policy_type,
                    surface_vertices,
                )
                # TODO: the mirror vertices (wrap_vertices) returned for
                # wrapping surfaces are not drawn yet.
                color = "blue"
                # Take a random color for each surface
                if args.random_surface_color:
                    color = np.random.rand(3)
                plot_surface(
                    np.array(sort_vertices(policy_type, surface_vertices)),
                    ax,
                    fill_color=color,
                )

        # Add a colorbar to the plot
        if not args.no_grid:
            # Color scale runs from 0 to the largest bin value.
            plt.clim(0, np.nanmax(grid_data))
            plt.colorbar(label="Counts", ax=ax)
        plt.show()