File indexing completed on 2026-03-28 07:46:06
0001 import numpy as np
0002 import matplotlib.pyplot as plt
0003 import argparse
0004 import json
0005 import math
0006
0007
0008
class Axis:
    """Base description of a binned axis: a label, a bin count and a value range."""

    def __init__(self, name, bins, range):
        # `range` shadows the builtin but is kept so keyword callers keep working.
        self.name, self.bins, self.range = name, bins, range
0014
0015
class RegularAxis(Axis):
    """Axis with ``bins`` equally wide bins spanning ``range[0]`` .. ``range[1]``.

    The constructor is inherited unchanged from ``Axis``.
    """

    def get_edges(self):
        """Return the ``bins + 1`` evenly spaced bin edges, both range ends included."""
        lower, upper = self.range[0], self.range[1]
        return np.linspace(lower, upper, num=self.bins + 1)
0022
0023
class VariableAxis(Axis):
    """Axis defined by an explicit sequence of bin edges."""

    def __init__(self, name, edges: np.ndarray):
        # The number of bins and the range are derived from the edges as given.
        super().__init__(name, len(edges) - 1, (edges[0], edges[-1]))
        # Keep the edges as an ndarray: callers do element-wise arithmetic such as
        # (edges[:-1] + edges[1:]) / 2, which would concatenate for a plain list.
        # np.asarray returns an existing ndarray unchanged (no copy).
        self.edges = np.asarray(edges)

    def get_edges(self):
        """Return the bin edges as a numpy array."""
        return self.edges
0032
0033
def sort_vertices(policy_type: str, vertices: list):
    """Sort polygon vertices counter-clockwise around their centroid.

    Each vertex's angle around the centroid is computed with ``math.atan2``,
    and the vertices are returned in increasing angle (counter-clockwise).

    Args:
        policy_type: grid type string. For ``"Disc"`` the vertices are
            returned unchanged, without re-ordering.
        vertices: sequence of 2-component ``[x, y]`` points.

    Returns:
        ``vertices`` itself for ``"Disc"``. Otherwise a new sorted list; an
        empty input yields an empty list.
    """
    if policy_type == "Disc":
        return vertices

    count = len(vertices)
    if count == 0:
        # No centroid exists; avoid dividing by zero below.
        return []

    centroid_x = sum(v[0] for v in vertices) / count
    centroid_y = sum(v[1] for v in vertices) / count

    return sorted(
        vertices,
        key=lambda v: math.atan2(v[1] - centroid_y, v[0] - centroid_x),
    )
0051
0052
def correct_vertices(
    policy_type: str,
    vertices: list,
    wrap_threshold=1.5 * math.pi,
    overlap_threshold=0.1,
):
    """Prepare projected surface vertices for plotting.

    Behaviour depends on ``policy_type``:

    * ``"Plane"``: the vertices are returned unchanged.
    * ``"Disc"`` / ``"Ring"``: each vertex is read as ``(r, phi)`` and
      converted to cartesian ``[x, y]``.
    * anything else (e.g. ``"Cylinder"``): the second component of each
      vertex is treated as ``phi``. If the phi spread exceeds
      ``wrap_threshold``, the surface is taken to straddle the -pi/pi
      boundary and is unwrapped.

    Args:
        policy_type: grid type string ("Plane", "Disc", "Ring", "Cylinder").
        vertices: sequence of 2-component vertices.
        wrap_threshold: phi spread (max - min) above which the surface is
            treated as wrapping around the boundary.
        overlap_threshold: the two sides of the boundary are separated at
            ``phi = -overlap_threshold * half_phi``, where ``half_phi`` is
            half the absolute sum of the smallest and largest phi.

    Returns:
        A tuple ``(corrected, mirror)``. ``mirror`` is empty unless the
        surface wraps. When it wraps, ``corrected`` has the vertices below
        the cut shifted by +2*pi, and ``mirror`` is the same polygon shifted
        by -2*pi (vertices above the cut moved by -2*pi).
    """
    if policy_type == "Plane":
        return vertices, []
    if policy_type in ("Disc", "Ring"):
        cartesian = [[v[0] * math.cos(v[1]), v[0] * math.sin(v[1])] for v in vertices]
        return cartesian, []

    if len(vertices) == 0:
        # Nothing to unwrap; also avoids indexing into an empty sequence.
        return vertices, []

    phis = [v[1] for v in vertices]
    phi_min, phi_max = min(phis), max(phis)
    diff_phi = phi_max - phi_min
    half_phi = 0.5 * abs(phi_min + phi_max)

    if diff_phi > wrap_threshold:
        cut = -overlap_threshold * half_phi
        corrected_vertices = []
        mirror_vertices = []
        for v in vertices:
            if v[1] < cut:
                corrected_vertices.append([v[0], v[1] + 2 * math.pi])
                mirror_vertices.append(v)
            else:
                corrected_vertices.append(v)
                mirror_vertices.append([v[0], v[1] - 2 * math.pi])
        return corrected_vertices, mirror_vertices

    return vertices, []
0100
0101
def plot_rectangular_grid(
    x: Axis, y: Axis, grid_data: np.ndarray, add_text=True, add_lines=True
):
    """Plot a rectangular 2D grid as a colour map on a new figure.

    Args:
        x: horizontal axis; provides ``get_edges()``, ``bins`` and ``name``.
        y: vertical axis; same requirements as ``x``.
        grid_data: bin values indexed ``[i_x, i_y]`` (``x.bins`` by ``y.bins``),
            used as histogram weights. NaN cells are drawn white. ``None``
            gives an unweighted histogram of the bin centres.
        add_text: write the integer value into every bin whose value is > 0.
        add_lines: draw dashed lines at all bin edges.

    Returns:
        ``(fig, ax)`` of the newly created 8x6 inch figure.
    """
    import copy

    x_edges = x.get_edges()
    y_edges = y.get_edges()
    x_centers = (x_edges[:-1] + x_edges[1:]) / 2
    y_centers = (y_edges[:-1] + y_edges[1:]) / 2

    fig, ax = plt.subplots(figsize=(8, 6))

    # Use a private copy: set_bad() on plt.cm.magma_r itself would modify the
    # shared colormap object for every later user in the process.
    cmap = copy.copy(plt.cm.magma_r)
    cmap.set_bad("white")

    if add_lines:
        for edge in x_edges:
            ax.axvline(edge, color="black", linestyle="--", linewidth=0.5)
        for edge in y_edges:
            ax.axhline(edge, color="black", linestyle="--", linewidth=0.5)

    # One sample per bin centre, weighted by the bin value. flatten() is
    # row-major, so x is repeated and y tiled to line up with grid_data[i_x, i_y].
    # pyplot.hist2d (unlike Axes.hist2d) also makes the resulting image the
    # current mappable, so plt.clim / plt.colorbar can be used afterwards.
    plt.hist2d(
        x=np.repeat(x_centers, y.bins),
        y=np.tile(y_centers, x.bins),
        bins=[x_edges, y_edges],
        weights=grid_data.flatten() if grid_data is not None else None,
        cmap=cmap,
    )

    if add_text and grid_data is not None:
        for i, center_x in enumerate(x_centers):
            for j, center_y in enumerate(y_centers):
                count = grid_data[i, j]
                # NaN (empty) cells fail this comparison and stay unlabelled.
                if count > 0:
                    ax.text(
                        center_x,
                        center_y,
                        int(count),
                        # Fixed threshold for switching to white text.
                        color="white" if count > 2 else "black",
                        ha="center",
                        va="center",
                        fontsize=8,
                        fontweight="bold",
                    )

    ax.set_xlabel(f"{x.name}")
    ax.set_ylabel(f"{y.name}")

    return fig, ax
0162
0163
def plot_polar_grid(
    r: Axis,
    phi: Axis,
    grid_data: np.ndarray,
    add_text=True,
    add_lines=True,
    *,
    ax=None,
):
    """Draw a polar (r, phi) grid as filled sectors on a matplotlib axes.

    Each bin ``[i, j]`` is drawn as a quadrilateral whose corners are the
    bin's (r, phi) edge corners converted to cartesian x/y. The arcs between
    the corners are approximated by straight segments.

    Args:
        r: radial axis; provides ``get_edges()`` and ``range``.
        phi: azimuthal axis; its edges are passed to ``math.cos``/``math.sin``
            (radians).
        grid_data: per-bin values indexed ``[i_r, i_phi]``, or ``None`` to draw
            every sector white. Sectors are coloured with ``magma_r`` of
            value / nan-max, and NaN cells are drawn white.
        add_text: write the integer value into every bin whose value is > 0.
        add_lines: draw dashed circles at the r edges and dashed rays at the
            phi edges.
        ax: axes to draw on. Defaults to the current pyplot axes, so callers
            should create or select the target axes first.

    Returns:
        ``(fig, ax)``: the axes drawn on and the figure it belongs to.
    """
    import copy

    if ax is None:
        ax = plt.gca()
    fig = ax.figure

    r_edges = r.get_edges()
    phi_edges = phi.get_edges()

    # Private copy so the shared plt.cm.magma_r object is not modified.
    cmap = copy.copy(plt.cm.magma_r)
    cmap.set_bad("white")

    # Colour normalisation, computed once rather than once per sector. For an
    # all-NaN grid this stays NaN, so every sector gets the "bad" colour,
    # without numpy's All-NaN warning.
    if grid_data is not None and not np.isnan(grid_data).all():
        vmax = np.nanmax(grid_data)
    else:
        vmax = np.nan

    for i in range(len(r_edges) - 1):
        r_lo, r_hi = r_edges[i], r_edges[i + 1]
        for j in range(len(phi_edges) - 1):
            phi_lo, phi_hi = phi_edges[j], phi_edges[j + 1]
            sector_vertices = [
                [r_lo * math.cos(phi_lo), r_lo * math.sin(phi_lo)],
                [r_hi * math.cos(phi_lo), r_hi * math.sin(phi_lo)],
                [r_hi * math.cos(phi_hi), r_hi * math.sin(phi_hi)],
                [r_lo * math.cos(phi_hi), r_lo * math.sin(phi_hi)],
            ]
            facecolor = (
                cmap(grid_data[i, j] / vmax) if grid_data is not None else "white"
            )
            ax.add_patch(
                plt.Polygon(
                    sector_vertices,
                    closed=True,
                    facecolor=facecolor,
                    edgecolor=None,
                    alpha=1.0,
                )
            )

    if add_lines:
        for edge in r_edges:
            circle = plt.Circle(
                (0, 0), edge, color="black", fill=False, linestyle="--", linewidth=0.5
            )
            ax.add_artist(circle)
        for edge in phi_edges:
            x = [r.range[0] * math.cos(edge), r.range[1] * math.cos(edge)]
            y = [r.range[0] * math.sin(edge), r.range[1] * math.sin(edge)]
            ax.plot(x, y, color="black", linestyle="--", linewidth=0.5)

    if add_text and grid_data is not None:
        for i in range(len(r_edges) - 1):
            r_mid = (r_edges[i] + r_edges[i + 1]) / 2
            for j in range(len(phi_edges) - 1):
                count = grid_data[i, j]
                # NaN (empty) cells fail this comparison and stay unlabelled.
                if count > 0:
                    phi_mid = (phi_edges[j] + phi_edges[j + 1]) / 2
                    ax.text(
                        r_mid * math.cos(phi_mid),
                        r_mid * math.sin(phi_mid),
                        int(count),
                        color="white" if count > 2 else "black",
                        ha="center",
                        va="center",
                        fontsize=8,
                        fontweight="bold",
                    )

    ax.set_xlabel("x [mm]")
    ax.set_ylabel("y [mm]")
    ax.set_aspect("equal", adjustable="box")

    return fig, ax
0283
0284
0285
def plot_surface(vertices: np.ndarray, ax, fill_color):
    """Add one projected surface to ``ax`` as a translucent filled polygon.

    Args:
        vertices: polygon corner points, one row per vertex, in the plot's
            coordinate system.
        ax: matplotlib axes that receives the patch.
        fill_color: colour used for both face and edge (drawn with alpha 0.25).
    """
    from matplotlib.patches import Polygon

    patch_style = {
        "facecolor": fill_color,
        "edgecolor": fill_color,
        "alpha": 0.25,
    }
    ax.add_patch(Polygon(vertices, closed=True, **patch_style))
0293
0294
if __name__ == "__main__":
    # Command-line entry point: read a JSON policy description, draw its grid
    # and (optionally) the projected surfaces, then show the figure.
    # Everything runs at module level, so policy_descr, fig, ax, grid_data etc.
    # are module globals.

    p = argparse.ArgumentParser()
    p.add_argument(
        "-p", "--policy", type=str, default="", help="Input JSON policy file"
    )
    p.add_argument("--no-text", action="store_true", help="Switch off bin text display")
    p.add_argument("--no-grid", action="store_true", help="Switch off grid display")
    p.add_argument(
        "--no-lines", action="store_true", help="Switch off grid lines display"
    )
    p.add_argument(
        "--no-surfaces", action="store_true", help="Switch off surface display"
    )
    p.add_argument(
        "--random-surface-color", action="store_true", help="Use random surface colors"
    )

    args = p.parse_args()

    if args.policy != "":
        with open(args.policy, "r", encoding="utf-8") as f:
            policy_descr = json.load(f)
        type_descr = policy_descr["type"]
        grid_descr = policy_descr["grid"]
        axes_descr = grid_descr["axes"]

        # Deduce the geometry from the type string. If several names occur,
        # the last match in this list wins.
        policy_type = None
        for candidate in ("Plane", "Disc", "Ring", "Cylinder"):
            if candidate in type_descr:
                policy_type = candidate
        if policy_type is None:
            raise ValueError(f"Unknown grid type: {type_descr}")

        # Only equidistant axes are handled. Skipping any other axis would
        # silently change the grid dimension and misplace the data.
        axes = []
        for axis_descr in axes_descr:
            axis_type = axis_descr["type"]
            if axis_type != "Equidistant":
                raise ValueError(f"Unsupported axis type: {axis_type}")
            axes.append(
                RegularAxis("Axis_name", axis_descr["bins"], axis_descr["range"])
            )

        # Bin matrix holding the number of entries per bin. JSON bin indices
        # are shifted by one to index the matrix; bins without entries stay
        # NaN and are drawn white.
        if len(axes) == 1:
            # A 1D grid is only handled for Ring policies: a single radial bin
            # spanning the projected reference range is prepended as axis 0.
            if policy_type != "Ring":
                raise ValueError(
                    f"1D grids are only supported for Ring policies, got: {policy_type}"
                )
            reference_range = policy_descr["projectedReferenceRange"]
            axes.insert(0, RegularAxis("r [mm]", 1, reference_range))
            grid_data = np.full((axes[0].bins, axes[1].bins), np.nan)
            if not args.no_grid:
                for entry_descr in grid_descr["data"]:
                    bin_descr = entry_descr[0]
                    grid_data[0, bin_descr[0] - 1] = len(entry_descr[1])
        elif len(axes) == 2:
            grid_data = np.full((axes[0].bins, axes[1].bins), np.nan)
            if not args.no_grid:
                for entry_descr in grid_descr["data"]:
                    bin_descr = entry_descr[0]
                    grid_data[bin_descr[0] - 1, bin_descr[1] - 1] = len(
                        entry_descr[1]
                    )
        else:
            raise ValueError("Only 1D and 2D grids are supported in this example.")

        if policy_type in ("Plane", "Cylinder"):
            x_label, y_label = {
                "Plane": ("x [mm]", "y [mm]"),
                "Cylinder": ("phi [rad]", "z [mm]"),
            }[policy_type]
            axes[0].name = x_label
            axes[1].name = y_label
            fig, ax = plot_rectangular_grid(
                axes[0],
                axes[1],
                grid_data,
                add_text=not args.no_text,
                add_lines=not args.no_lines,
            )
        elif policy_type in ("Disc", "Ring"):
            # An empty cartesian histogram spanning [-r_max, r_max] in x and y
            # sets up the figure; the polar sectors are drawn on top of it.
            r_max = axes[0].range[1]
            cart_axes = [
                RegularAxis("x [mm]", axes[0].bins, (-r_max, r_max)),
                RegularAxis("y [mm]", axes[1].bins, (-r_max, r_max)),
            ]
            empty_data = np.full((axes[0].bins, axes[1].bins), np.nan)
            fig, ax = plot_rectangular_grid(
                cart_axes[0], cart_axes[1], empty_data, False, False
            )
            plot_polar_grid(
                axes[0],
                axes[1],
                grid_data,
                add_text=not args.no_text,
                add_lines=not args.no_lines,
            )

        if not args.no_surfaces and "projectedSurfaces" in policy_descr:
            for surface_vertices in policy_descr["projectedSurfaces"]:
                surface_vertices, wrap_vertices = correct_vertices(
                    policy_type,
                    surface_vertices,
                )
                color = np.random.rand(3) if args.random_surface_color else "blue"
                plot_surface(
                    np.array(sort_vertices(policy_type, surface_vertices)),
                    ax,
                    fill_color=color,
                )
                # correct_vertices returns a non-empty second list only for
                # surfaces wrapping around the phi = -pi/pi boundary: the same
                # polygon shifted by -2*pi. Draw it too so both sides show.
                if wrap_vertices:
                    plot_surface(
                        np.array(sort_vertices(policy_type, wrap_vertices)),
                        ax,
                        fill_color=color,
                    )

        if not args.no_grid:
            plt.clim(0, np.nanmax(grid_data))
            plt.colorbar(label="Counts", ax=ax)
        plt.show()