File indexing completed on 2026-05-27 07:24:18
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 from collections import namedtuple
0011 import math
0012 import numpy as np
0013
0014
0015 import matplotlib.pyplot as plt
0016 import matplotlib.colors as mcolors
0017 from matplotlib import ticker
0018 import matplotlib.style as style
0019 from mpl_toolkits.axes_grid1 import make_axes_locatable
0020
0021
0022 from .plot_helpers import plt_data, axis_options, legend_options
0023
# Use the colour-blind friendly palette for everything drawn by this module.
style.use("tableau-colorblind10")

# Global text settings: typeset through LaTeX with a large serif font.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.size"] = 25
plt.rcParams["font.family"] = "serif"
0034
0035
0036
class ScalarFormatterForceFormat(ticker.ScalarFormatter):
    """Scalar tick formatter whose number format is pinned to ``%3.1f``."""

    # printf-style format applied to every tick label.
    _FIXED_FORMAT = "%3.1f"

    def _set_format(self):
        # Override the base-class hook so the format is always the fixed one.
        self.format = self._FIXED_FORMAT
0040
0041
0042
0043
0044
0045
# Default color for graphs and histograms.
default_color = "tab:blue"
0048
0049
0050
0051
0052
0053 """ Plotter interface that uses pyplot/matplotlib. """
0054
0055
class pyplot_factory:
    """Plotter interface that uses pyplot/matplotlib.

    Every plotting method returns a ``plt_data`` record that bundles the
    figure and axes together with, where available, the legend, the plotted
    data, the bin edges and the errors. Follow-up methods (``add_graph``,
    ``add_hist``, ``add_ratio``, ``fit_gaussian``, ...) take such a record
    and extend the existing plot.
    """

    def __init__(self, out_dir, logger, atlas_badge=""):
        """Set up the factory.

        Args:
            out_dir: Output location used by ``write_plot`` when no explicit
                ``out_prefix`` is given.
            logger: Logger used for diagnostics (``debug``/``warning``).
            atlas_badge: Stored on the instance; not used by this class.
        """
        # NOTE(review): the trailing comma makes this a 1-tuple, not a string.
        # Left unchanged because callers may already index it - confirm.
        self.name = ("Pyplot",)
        self.output_prefix = out_dir
        self.logger = logger
        self.atlas_badge = atlas_badge
        self.badge_scale = 1.1
        # Template for the "default" tick label format. It is never attached
        # to an axis directly: every axis receives its own copy.
        self.axis_formatter = ScalarFormatterForceFormat()
        self.axis_formatter.set_powerlimits((-2, 2))

    def __add_legend(self, ax, options=legend_options()):
        """Create the legend of ``ax`` from ``options`` and return it."""
        anchor = (options.horiz_anchor, options.vert_anchor)
        return ax.legend(
            title=options.title,
            loc=options.loc,
            bbox_to_anchor=anchor,
            ncol=options.ncol,
            borderpad=0.3,
            columnspacing=options.colspacing,
            handletextpad=options.handletextpad,
        )

    def __adjust_lgd_label_spacing(self, lgd):
        """Hide the first legend handle and resize the remaining markers.

        The first legend entry of a scatter plot only carries text (it is
        drawn from an empty data set), so its handle is hidden and the width
        reserved for handles in the first legend column is collapsed.
        """
        handles = lgd.legend_handles
        handles[0].set_visible(False)
        for handle in handles[1:]:
            handle.set_sizes([40])

        # Relies on matplotlib's private legend layout boxes.
        for vpack in lgd._legend_handle_box.get_children()[:1]:
            for hpack in vpack.get_children():
                hpack.get_children()[0].set_width(0)

    def __update_legend(self, lgd):
        """Rebuild ``lgd`` from the current handles and labels of its axes."""
        handles, labels = lgd.axes.get_legend_handles_labels()
        # Relies on matplotlib's private legend API to refresh in place.
        lgd._legend_box = None
        lgd._init_legend_box(handles, labels)
        lgd._set_loc(lgd._loc)
        lgd.set_title(lgd.get_title().get_text())

    def __get_axis_boundaries(self, data, axis_opts):
        """Return ``(min, max)`` from ``axis_opts`` or, if unset, from ``data``.

        The data range is used unless *both* limits are configured. ``data``
        must not be empty in that case.
        """
        if axis_opts.min is not None and axis_opts.max is not None:
            return axis_opts.min, axis_opts.max
        return np.min(data), np.max(data)

    def __range_mask(self, data, min_v, max_v):
        """Return a boolean mask for ``min_v <= data <= max_v``.

        Returns ``None`` (meaning "keep everything") unless both limits are
        given.
        """
        if min_v is None or max_v is None:
            return None
        values = np.asarray(data)
        return (values >= min_v) & (values <= max_v)

    def __joint_mask(self, *columns):
        """Combine the range masks of several ``(data, axis_options)`` pairs.

        All data sets must have the same length. Returns ``None`` if no axis
        restricts its range.
        """
        joint = None
        for data, opts in columns:
            mask = self.__range_mask(data, opts.min, opts.max)
            if mask is not None:
                joint = mask if joint is None else joint & mask
        return joint

    def __select_entries(self, values, mask):
        """Apply ``mask`` to optional per-entry ``values`` (weights/errors).

        Scalars, ``None`` and arrays whose last axis does not match the mask
        are returned unchanged.
        """
        if values is None or mask is None:
            return values
        arr = np.asarray(values)
        if arr.ndim == 0 or arr.shape[-1] != mask.shape[0]:
            return values
        return arr[..., mask]

    def __apply_boundary(self, data, min_v, max_v):
        """Return the entries of ``data`` inside ``[min_v, max_v]``.

        ``data`` is returned untouched unless both limits are given.
        """
        mask = self.__range_mask(data, min_v, max_v)
        if mask is None:
            return data
        return np.asarray(data)[mask]

    def __set_label_format(self, label_format, axis):
        """Configure the tick label format of ``axis``.

        Args:
            label_format: ``None`` (leave untouched), ``"default"`` (copy of
                ``self.axis_formatter``) or a ``str.format`` style string.
            axis: The matplotlib axis (e.g. ``ax.xaxis``) to configure.
        """
        if label_format is None:
            return

        if label_format == "default":
            import copy

            # A formatter is bound to the axis it is installed on, so the
            # template must not be shared between axes or figures.
            axis.set_major_formatter(copy.copy(self.axis_formatter))
        else:
            axis.set_major_formatter(ticker.StrMethodFormatter(label_format))
            axis.set_minor_formatter(ticker.StrMethodFormatter(label_format))

    def graph(
        self,
        x,
        y,
        y_errors=None,
        title="",
        label="",
        x_axis=axis_options(label="x"),
        y_axis=axis_options(label="y"),
        color=None,
        marker=".",
        lgd_ops=legend_options(),
        figsize=(8, 8),
        layout="constrained",
    ):
        """Create a graph from given input data.

        Points outside the configured axis ranges are removed pairwise, i.e.
        a point is dropped from ``x``, ``y`` and ``y_errors`` together.

        Args:
            x, y: Coordinates of the points (same length).
            y_errors: Optional errors passed to ``errorbar`` as ``yerr``.
            title, label: Plot title and legend label.
            x_axis, y_axis: Axis options (label, range, scale, ticks, format).
            color, marker: Forwarded to ``errorbar``.
            lgd_ops: Legend options.
            figsize, layout: Forwarded to ``plt.figure``.

        Returns:
            ``plt_data`` for the new figure. Without usable data only the
            figure and axes (and possibly the errors) are set.
        """
        fig = plt.figure(figsize=figsize, layout=layout)
        ax = fig.add_subplot(1, 1, 1)

        ax.set_title(title)
        ax.set_xlabel(x_axis.label)
        ax.set_ylabel(y_axis.label)
        ax.grid(True, alpha=0.25)

        if x_axis.log_scale is not None:
            ax.set_xscale("log", base=x_axis.log_scale)
        if y_axis.log_scale is not None:
            ax.set_yscale("log", base=y_axis.log_scale)

        if x_axis.tick_positions is not None:
            ax.set_xticks(x_axis.tick_positions)
            ax.tick_params(axis="x", which="major", pad=7)

        if y_axis.tick_positions is not None:
            ax.set_yticks(y_axis.tick_positions)
            ax.tick_params(axis="y", which="major", pad=7)

        self.__set_label_format(x_axis.label_format, ax.xaxis)
        self.__set_label_format(y_axis.label_format, ax.yaxis)

        if len(x) == 0:
            self.logger.debug(rf" create graph: empty data {label}")
            return plt_data(fig=fig, axes=ax)

        if len(x) != len(y):
            self.logger.debug(
                rf" create graph: x range does not match y range {label}"
            )
            return plt_data(fig=fig, axes=ax, errors=y_errors)

        # Restrict both coordinates with one mask so the pairs stay aligned.
        keep = self.__joint_mask((x, x_axis), (y, y_axis))
        if keep is not None:
            x = np.asarray(x)[keep]
            y = np.asarray(y)[keep]
            y_errors = self.__select_entries(y_errors, keep)
            if len(x) == 0:
                self.logger.debug(rf" create graph: empty data {label}")
                return plt_data(fig=fig, axes=ax)

        data = ax.errorbar(
            x=x, y=y, label=label, yerr=y_errors, marker=marker, color=color
        )

        lgd = self.__add_legend(ax, lgd_ops)

        return plt_data(fig=fig, axes=ax, lgd=lgd, data=data, errors=y_errors)

    def add_graph(
        self,
        plot,
        x,
        y,
        y_errors=None,
        label="",
        marker="+",
        color=None,
    ):
        """Add new graph to an existing plot.

        Args:
            plot: ``plt_data`` returned by ``graph``.
            x, y, y_errors, label, marker, color: Forwarded to ``errorbar``.

        Returns:
            ``plt_data`` for the added graph, or ``plot`` unchanged if there
            is nothing to add or ``plot`` holds no data.
        """
        if len(y) == 0 or plot.data is None:
            self.logger.debug(rf" add graph: empty data {label}")
            return plot

        data = plot.axes.errorbar(
            x=x,
            y=y,
            label=label,
            yerr=y_errors,
            color=color,
            marker=marker,
        )

        self.__update_legend(plot.lgd)

        # Make sure the new points are inside the visible range.
        plot.axes.relim()
        plot.axes.autoscale_view()

        return plt_data(
            fig=plot.fig, axes=plot.axes, lgd=plot.lgd, data=data, errors=y_errors
        )

    def hist1D(
        self,
        x,
        bins=1,
        errors=None,
        w=None,
        title="",
        label="",
        x_axis=axis_options(label="x"),
        y_axis=axis_options(label=""),
        color=default_color,
        alpha=0.75,
        normalize=False,
        show_error=False,
        show_stats=True,
        u_outlier=-1,
        o_outlier=-1,
        lgd_ops=legend_options(),
        figsize=(8, 8),
        layout="compressed",
    ):
        """Create a histogram from given input data.

        With ``normalize`` the histogram is drawn with ``density=True``.
        Unless ``errors`` are given, the per-bin error is the square root of
        the (scaled) bin content.

        Args:
            x: Data to fill into the histogram.
            bins: Number of bins or bin edges (forwarded to ``Axes.hist``).
            errors: Optional per-bin errors; implies drawing error bars.
            w: Optional per-entry weights.
            title, label: Plot title and legend label.
            x_axis, y_axis: Axis options. Entries outside a fully specified
                x range are dropped (together with their weights).
            color, alpha: Appearance of the histogram.
            normalize: Draw a density instead of counts.
            show_error: Draw error bars.
            show_stats: Add mean and standard deviation to the legend.
            u_outlier, o_outlier: Externally counted under-/overflow entries.
                A negative value means "not provided"; the under-/overflow
                is only shown in the legend if at least one is provided.
            lgd_ops: Legend options.
            figsize, layout: Forwarded to ``plt.figure``.

        Returns:
            ``plt_data`` with figure, axes, legend, bin contents, bin edges,
            mean, standard deviation and errors; only figure and axes if
            there is no data.
        """
        fig = plt.figure(figsize=figsize, layout=layout)
        ax = fig.add_subplot(1, 1, 1)

        ax.set_title(title)
        ax.set_xlabel(x_axis.label)
        ax.set_ylabel(y_axis.label)
        ax.grid(True, alpha=0.25)

        if x_axis.log_scale is not None:
            ax.set_xscale("log", base=x_axis.log_scale)
        if y_axis.log_scale is not None:
            ax.set_yscale("log", base=y_axis.log_scale)

        ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
        self.__set_label_format(y_axis.label_format, ax.yaxis)

        # The data range cannot be determined for empty input.
        if len(x) == 0:
            self.logger.debug(rf" create hist: empty data {label}")
            return plt_data(fig=fig, axes=ax)

        x_min, x_max = self.__get_axis_boundaries(x, x_axis)

        # Count the outliers before they are removed from the data.
        values = np.asarray(x)
        underflow = int(np.count_nonzero(values < x_min))
        overflow = int(np.count_nonzero(values > x_max))
        # Negative outlier counts mean "not provided" and must not be added.
        underflow += max(u_outlier, 0)
        overflow += max(o_outlier, 0)

        keep = self.__range_mask(x, x_axis.min, x_axis.max)
        if keep is not None:
            x = values[keep]
            w = self.__select_entries(w, keep)
            if len(x) == 0:
                self.logger.debug(rf" create hist: empty data {label}")
                return plt_data(fig=fig, axes=ax)

        scale = 1.0 / len(x) if normalize else 1.0

        newline = "\n"

        label_str = f"{label} ({len(x)} entries)"
        if u_outlier >= 0 or o_outlier >= 0:
            label_str = (
                label_str
                + f"{newline} underflow: {underflow}"
                + f"{newline} overflow: {overflow}"
            )

        data, bin_edges, _ = ax.hist(
            x,
            weights=w,
            range=(x_min, x_max),
            bins=bins,
            label=label_str,
            histtype="stepfilled",
            density=normalize,
            alpha=alpha,
            facecolor=color,
            edgecolor=color,
        )

        if show_stats:
            mean = np.mean(x, axis=0)
            stdev = np.std(x, axis=0)

            # Empty data set: only contributes its text to the legend.
            ax.plot(
                [],
                [],
                " ",
                label=rf"data:"
                rf"{newline}mean = {mean:.2e}"
                rf"{newline}stddev = {stdev:.2e}",
            )
        else:
            mean = None
            stdev = None

        bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
        # NOTE(review): with density=True the bin content is also divided by
        # the bin width, which this estimate does not account for - confirm.
        err = np.sqrt(scale * data) if errors is None else errors
        if show_error or errors is not None:
            ax.errorbar(
                bin_centers,
                data,
                yerr=err,
                fmt=".",
                linestyle="",
                linewidth=0.4,
                color="black",
                capsize=2.5,
            )

        lgd = self.__add_legend(ax, lgd_ops)

        # Text-only legend: hide the handles and collapse their column.
        lgd.legend_handles[0].set_visible(False)
        if show_stats:
            lgd.legend_handles[1].set_visible(False)
        for vpack in lgd._legend_handle_box.get_children():
            for hpack in vpack.get_children():
                hpack.get_children()[0].set_width(0)

        return plt_data(
            fig=fig,
            axes=ax,
            lgd=lgd,
            data=data,
            bins=bin_edges,
            mu=mean,
            rms=stdev,
            errors=err,
        )

    def add_hist(
        self,
        old_hist,
        x,
        errors=None,
        w=None,
        label="",
        color="tab:orange",
        alpha=0.75,
        normalize=False,
        show_error=False,
    ):
        """Add new histogram to an existing plot.

        The new histogram reuses the bin edges of ``old_hist``; entries
        outside of them are dropped (together with their weights).

        Args:
            old_hist: ``plt_data`` returned by ``hist1D``.
            x: Data to fill into the histogram.
            errors: Optional per-bin errors; implies drawing error bars.
            w: Optional per-entry weights.
            label, color, alpha: Legend label and appearance.
            normalize: Draw a density instead of counts (as in ``hist1D``).
            show_error: Draw error bars.

        Returns:
            ``plt_data`` for the added histogram, or ``old_hist`` unchanged
            if there is nothing to add or ``old_hist`` holds no histogram.
        """
        # Check the existing plot first: an empty one has no bin edges.
        if old_hist.data is None or old_hist.bins is None or len(x) == 0:
            self.logger.debug(rf" add hist: empty data {label}")
            return old_hist

        keep = self.__range_mask(x, np.min(old_hist.bins), np.max(old_hist.bins))
        x = np.asarray(x)[keep]
        w = self.__select_entries(w, keep)

        if len(x) == 0:
            self.logger.debug(rf" add hist: empty data {label}")
            return old_hist

        scale = 1.0 / len(x) if normalize else 1.0
        data, bin_edges, _ = old_hist.axes.hist(
            x=x,
            bins=old_hist.bins,
            label=f"{label} ({len(x)} entries)",
            weights=w,
            histtype="stepfilled",
            density=normalize,
            facecolor=color,
            alpha=alpha,
            edgecolor=color,
        )

        bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
        err = np.sqrt(scale * data) if errors is None else errors
        if show_error or errors is not None:
            old_hist.axes.errorbar(
                bin_centers,
                data,
                yerr=err,
                fmt=".",
                linestyle="",
                linewidth=0.4,
                color="black",
                capsize=2.5,
            )

        self.__update_legend(old_hist.lgd)

        return plt_data(
            fig=old_hist.fig,
            axes=old_hist.axes,
            lgd=old_hist.lgd,
            data=data,
            bins=bin_edges,
            errors=err,
        )

    def add_ratio(
        self, nom, denom, label, color="tab:red", set_log=False, show_error=False
    ):
        """Plot the ratio of two histograms below the numerator histogram.

        The data is assumed to be uncorrelated for the error propagation.
        Bins with an undefined ratio (empty denominator) are set to zero.

        Args:
            nom, denom: ``plt_data`` of numerator and denominator histogram.
            label, color: Legend label and color of the ratio points.
            set_log: Use a logarithmic y axis for the ratio.
            show_error: Draw the propagated errors.

        Returns:
            ``plt_data`` holding the figure, the ratio axes and the errors;
            only figure and numerator axes if the histograms are unusable.
        """
        nom.fig.set_figheight(7)
        nom.fig.set_figwidth(8)

        if nom.bins is None or denom.bins is None:
            return plt_data(fig=nom.fig, axes=nom.axes)

        if len(nom.bins) != len(denom.bins):
            return plt_data(fig=nom.fig, axes=nom.axes)

        # The x axis label moves from the main plot to the ratio plot.
        x_label = nom.axes.xaxis.get_label().get_text()
        nom.axes.tick_params(
            axis="x", which="both", bottom=True, top=False, labelbottom=False
        )
        nom.axes.set_xlabel("")

        bin_centers = 0.5 * (nom.bins[1:] + nom.bins[:-1])
        n_data, d_data = (nom.data, denom.data)
        # Missing errors are treated as zero.
        n_err = 0.0 if nom.errors is None else nom.errors
        d_err = 0.0 if denom.errors is None else denom.errors

        # Empty denominator bins are expected: silence the warnings for both
        # the ratio and its error and replace the undefined results by zero.
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = np.nan_to_num(n_data / d_data, nan=0, posinf=0, neginf=0)
            errors = np.nan_to_num(
                np.sqrt(
                    np.square(n_err / d_data)
                    + np.square(n_data * d_err / np.square(d_data))
                ),
                nan=0,
                posinf=0,
            )

        divider = make_axes_locatable(nom.axes)
        ratio_plot = divider.append_axes("bottom", 1.2, pad=0.2, sharex=nom.axes)

        ratio_plot.axes.yaxis.set_major_formatter(
            ticker.ScalarFormatter(useOffset=False)
        )
        if show_error:
            ratio_plot.errorbar(
                bin_centers, ratio, yerr=errors, label=label, color=color, fmt="."
            )
        else:
            ratio_plot.plot(
                bin_centers,
                ratio,
                label=label,
                color=color,
                marker=".",
                linestyle="",
            )

        ratio_plot.set_xlabel(x_label)
        ratio_plot.set_ylabel("ratio")
        ratio_plot.grid(True, alpha=0.25)

        if set_log:
            ratio_plot.set_yscale("log")

        # Reference line at ratio == 1.
        ratio_plot.axline((nom.bins[0], 1), (nom.bins[-1], 1), linewidth=1, color="b")

        nom.fig.set_size_inches((9, 9))

        return plt_data(fig=nom.fig, axes=ratio_plot, errors=errors)

    def hist2D(
        self,
        x,
        y,
        z=None,
        x_bins=1,
        y_bins=1,
        x_axis=axis_options(label="x"),
        y_axis=axis_options(label="y"),
        z_axis=axis_options(label=""),
        title="",
        label="",
        color=default_color,
        alpha=0.75,
        show_stats=True,
        figsize=(8, 6),
    ):
        """Create a 2D histogram from given input data.

        If z values are given they will be used as weights per entry. Entries
        outside the configured axis ranges are removed from ``x``, ``y`` and
        ``z`` together.

        Args:
            x, y: Coordinates of the entries (same length).
            z: Optional weights.
            x_bins, y_bins: Binning along x and y.
            x_axis, y_axis, z_axis: Axis options (``z_axis`` labels the
                color bar).
            title, label: Plot title and label of the histogram.
            color, alpha: Combined into the face color of the histogram.
            show_stats: Compute mean and RMS of x and y for the label.
            figsize: Forwarded to ``plt.figure``.

        Returns:
            ``plt_data`` with figure, axes and bin contents; only figure and
            axes if there is no usable data.
        """
        fig = plt.figure(figsize=figsize, layout="constrained")
        ax = fig.add_subplot(1, 1, 1)

        ax.set_title(title)
        ax.set_xlabel(x_axis.label)
        ax.set_ylabel(y_axis.label)

        if len(x) == 0 or len(y) == 0:
            self.logger.debug(rf" create hist: empty data {label}")
            return plt_data(fig=fig, axes=ax)

        if len(x) != len(y):
            self.logger.debug(
                rf" create hist: x range does not match y range {label}"
            )
            return plt_data(fig=fig, axes=ax)

        # Restrict all inputs with one mask so the entries stay aligned.
        keep = self.__joint_mask((x, x_axis), (y, y_axis))
        if keep is not None:
            x = np.asarray(x)[keep]
            y = np.asarray(y)[keep]
            z = self.__select_entries(z, keep)
            if len(x) == 0:
                self.logger.debug(rf" create hist: empty data {label}")
                return plt_data(fig=fig, axes=ax)

        x_min, x_max = self.__get_axis_boundaries(x, x_axis)
        y_min, y_max = self.__get_axis_boundaries(y, y_axis)

        data, _, _, hist = ax.hist2d(
            x,
            y,
            weights=z,
            range=[(x_min, x_max), (y_min, y_max)],
            bins=(x_bins, y_bins),
            label=f"{label} ({len(x)} entries)",
            facecolor=mcolors.to_rgba(color, alpha),
            edgecolor=None,
            rasterized=True,
        )

        if show_stats:
            x_mean = np.mean(x, axis=0)
            x_rms = np.sqrt(np.mean(np.square(x)))
            y_mean = np.mean(y, axis=0)
            y_rms = np.sqrt(np.mean(np.square(y)))

            # Empty data set that only carries the statistics as its label.
            newline = "\n"
            ax.plot(
                [],
                [],
                " ",
                label=rf"xMean = {x_mean:.2e}"
                rf"{newline}xRMS = {x_rms:.2e}"
                rf"{newline}yMean = {y_mean:.2e}"
                rf"{newline}yRMS = {y_rms:.2e}",
            )

        fig.colorbar(hist, label=z_axis.label)

        return plt_data(fig=fig, axes=ax, data=data)

    def scatter(
        self,
        x,
        y,
        x_axis=axis_options(label=""),
        y_axis=axis_options(label=""),
        title="",
        label="",
        color=default_color,
        alpha=1,
        show_stats=lambda x, _: f"{len(x)} entries",
        lgd_ops=legend_options(),
        figsize=(8, 6),
    ):
        """Create a 2D scatter plot.

        Args:
            x, y: Coordinates of the points.
            x_axis, y_axis: Axis options (only the labels are used).
            title, label: Plot title and legend label.
            color, alpha: Appearance of the points.
            show_stats: Callable ``(x, y) -> str`` producing the text of the
                first legend entry.
            lgd_ops: Legend options.
            figsize: Forwarded to ``plt.figure``.

        Returns:
            ``plt_data`` with figure, axes, legend and the scatter artist.
        """
        fig = plt.figure(figsize=figsize, layout="constrained")
        ax = fig.add_subplot(1, 1, 1)

        ax.set_title(title)
        ax.set_xlabel(x_axis.label)
        ax.set_ylabel(y_axis.label)
        ax.grid(True, alpha=0.25)

        # Empty data set that only carries the statistics as its label.
        ax.plot([], [], " ", label=show_stats(x, y))
        points = ax.scatter(
            x, y, label=label, c=color, s=0.1, alpha=alpha, rasterized=True
        )

        lgd = self.__add_legend(ax, lgd_ops)
        self.__adjust_lgd_label_spacing(lgd)

        return plt_data(fig=fig, axes=ax, lgd=lgd, data=points)

    def highlight_region(self, plot_data, x, y, color, label=""):
        """Add new data in a different color to a scatter plot.

        Args:
            plot_data: ``plt_data`` returned by ``scatter``.
            x, y: Coordinates of the points to highlight.
            color: Color of the highlighted points.
            label: Optional legend label; no label is set if empty.
        """
        scatter_kwargs = {"c": color, "alpha": 1, "s": 0.1, "rasterized": True}
        if label != "":
            scatter_kwargs["label"] = label
        plot_data.axes.scatter(x, y, **scatter_kwargs)

        self.__update_legend(plot_data.lgd)
        self.__adjust_lgd_label_spacing(plot_data.lgd)

    def fit_gaussian(self, dist, color="tab:orange"):
        """Fit a Gaussian to a 1D distribution and plot in the same figure.

        Args:
            dist: ``plt_data`` returned by ``hist1D``.
            color: Color of the fitted curve.

        Returns:
            ``(mu, sigma)`` of the fit, or ``(None, None)`` if the histogram
            is missing, has fewer bins than fit parameters, scipy is not
            available or the fit fails.
        """
        if dist.bins is None or dist.data is None:
            return None, None

        edges = np.asarray(dist.bins)
        bin_centers = 0.5 * (edges[1:] + edges[:-1])

        # Three parameters need at least three points to be fitted.
        if len(bin_centers) < 3:
            return None, None

        def gaussian(x, a, mean, sigma):
            return (
                a
                / (math.sqrt(2 * math.pi) * sigma)
                * np.exp(-((x - mean) ** 2 / (2 * sigma**2)))
            )

        try:
            from scipy.optimize import curve_fit
        except ImportError:
            self.logger.warning("Could not find scipy: Skipping fit")
            return None, None

        # Initial estimators taken from the histogram itself.
        mean = np.mean(bin_centers, axis=0)
        sigma = np.std(bin_centers, axis=0)
        a = np.max(dist.data) * (math.sqrt(2 * math.pi) * sigma)

        try:
            popt, _ = curve_fit(
                gaussian, bin_centers, dist.data, p0=[a, mean, sigma]
            )
        except (RuntimeError, ValueError):
            # Fit did not converge or the input contains invalid values.
            return None, None

        mu = popt[1]
        sig = abs(popt[2])
        newline = "\n"
        plot_label = (
            rf"gaussian fit:{newline}$\mu$ = {mu:.2e}"
            + rf"{newline}$\sigma$ = {sig:.2e}"
        )

        # Sample the fitted curve densely across the histogram range.
        samples = np.linspace(bin_centers.min(), bin_centers.max(), 1001)
        dist.axes.plot(
            samples,
            gaussian(samples, *popt),
            label=plot_label,
            color=color,
        )

        self.__update_legend(dist.lgd)

        # Keep the handle of the fit (last entry), collapse all others.
        dist.lgd.legend_handles[0].set_visible(False)
        for vpack in dist.lgd._legend_handle_box.get_children()[:-1]:
            for hpack in vpack.get_children():
                hpack.get_children()[0].set_width(0)

        return mu, sig

    def vertical_line(self, plot_data, x, y=None, color="b", label=""):
        """Draw a vertical line in a given plot.

        Args:
            plot_data: ``plt_data`` of the plot to draw into.
            x: Position of the line.
            y: Vertical position of the label; defaults to the middle of the
                current y range.
            color: Color of the line.
            label: Text placed on the line.
        """
        plot_data.axes.axvline(x=x, color=color, linestyle="--")

        ymin, ymax = plot_data.axes.get_ylim()
        y_pos = ymin + (ymax - ymin) / 2 if y is None else y
        plot_data.axes.text(
            x,
            y_pos,
            label,
            ha="center",
            va="center",
            backgroundcolor="white",
        )

    def write_plot(
        self, plot_data, name="plot", file_format="svg", out_prefix="", dpi=450
    ):
        """Write a plot to disk and close its figure.

        Args:
            plot_data: ``plt_data`` of the plot to write.
            name: File name without extension.
            file_format: File extension, with or without a leading dot.
            out_prefix: Prefix put directly in front of ``name``; if empty,
                ``self.output_prefix`` followed by ``/`` is used.
            dpi: Resolution forwarded to ``savefig``.
        """
        # Accept both "svg" and ".svg" without producing a double dot.
        extension = file_format.lstrip(".")
        if out_prefix == "":
            file_name = self.output_prefix + "/" + name + "." + extension
        else:
            file_name = out_prefix + name + "." + extension

        plot_data.fig.savefig(file_name, dpi=dpi)
        plt.close(plot_data.fig)

    def write_svg(self, plot_data, name, out_prefix=""):
        """Write a plot as svg."""
        self.write_plot(plot_data, name, "svg", out_prefix)

    def write_pdf(self, plot_data, name, out_prefix=""):
        """Write a plot as pdf."""
        self.write_plot(plot_data, name, "pdf", out_prefix)

    def write_png(self, plot_data, name, out_prefix=""):
        """Write a plot as png."""
        self.write_plot(plot_data, name, "png", out_prefix)