Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:24:18

0001 # This file is part of the ACTS project.
0002 #
0003 # Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 #
0005 # This Source Code Form is subject to the terms of the Mozilla Public
0006 # License, v. 2.0. If a copy of the MPL was not distributed with this
0007 # file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 # python includes
0010 from collections import namedtuple
0011 import math
0012 import numpy as np
0013 
0014 # python based plotting
0015 import matplotlib.pyplot as plt
0016 import matplotlib.colors as mcolors
0017 from matplotlib import ticker
0018 import matplotlib.style as style
0019 from mpl_toolkits.axes_grid1 import make_axes_locatable
0020 
0021 # detray imports
0022 from .plot_helpers import plt_data, axis_options, legend_options
0023 
# Use matplotlib's colorblind-friendly "tableau-colorblind10" style
style.use("tableau-colorblind10")

# Global text settings: LaTeX text rendering with a large serif font
_RC_OVERRIDES = {
    "text.usetex": True,
    "font.size": 25,
    "font.family": "serif",
}
plt.rcParams.update(_RC_OVERRIDES)
0034 
0035 
# See: https://stackoverflow.com/questions/42142144/displaying-first-decimal-digit-in-scientific-notation-in-matplotlib
class ScalarFormatterForceFormat(ticker.ScalarFormatter):
    """ScalarFormatter that prints tick labels with a fixed "%3.1f" format.

    matplotlib normally derives the tick label format from the tick values;
    this subclass always uses one digit after the decimal point instead.
    """

    # printf-style format used for every tick label
    _FORCED_FORMAT = "%3.1f"

    def _set_format(self):
        # Replace the format that ScalarFormatter would compute
        self.format = self._FORCED_FORMAT
0040 
0041 
0042 # ------------------------------------------------------------------------------
0043 # Global identifiers
0044 # ------------------------------------------------------------------------------
0045 
# Default color for graphs and histograms
default_color = "tab:blue"
0048 
0049 # ------------------------------------------------------------------------------
0050 # Data Plotting
0051 # ------------------------------------------------------------------------------
0052 
class pyplot_factory:
    """Plotter interface that uses pyplot/matplotlib.

    The plotting methods return ``plt_data`` records that bundle the
    matplotlib figure, axes, legend and plotted data. These records can be
    passed back into the ``add_*`` / ``fit_*`` methods to extend a plot and
    into ``write_plot`` (and its format-specific wrappers) to save it.
    """

    def __init__(self, out_dir, logger, atlas_badge=""):
        """Set up the factory.

        Args:
            out_dir: directory that ``write_plot`` writes into when no
                explicit ``out_prefix`` is given.
            logger: logger that receives the debug messages of this class.
            atlas_badge: badge text; stored (together with ``badge_scale``)
                but not used by the methods of this class.
        """
        # NOTE(review): the trailing comma makes this a one-element tuple;
        # kept unchanged in case other code relies on it.
        self.name = ("Pyplot",)
        self.output_prefix = out_dir
        self.logger = logger
        self.atlas_badge = atlas_badge
        self.badge_scale = 1.1
        # Shared tick formatter ("%3.1f" labels); scientific notation is
        # controlled by the power limits (see ScalarFormatter.set_powerlimits)
        self.axis_formatter = ScalarFormatterForceFormat()
        self.axis_formatter.set_powerlimits((-2, 2))

    def __add_legend(self, ax, options=legend_options()):
        """Add a legend to ``ax``; the plotted artists must carry labels.

        Title, position, number of columns and spacing come from ``options``.
        """
        return ax.legend(
            title=options.title,
            loc=options.loc,
            bbox_to_anchor=(options.horiz_anchor, options.vert_anchor),
            ncol=options.ncol,
            borderpad=0.3,
            columnspacing=options.colspacing,
            handletextpad=options.handletextpad,
        )

    def __adjust_lgd_label_spacing(self, lgd):
        """Hide the first legend handle and tighten the legend layout.

        In ``scatter`` the first legend entry is a text-only entry (an empty
        plot with a blank marker), so its handle is hidden. The remaining
        handles get a fixed marker size, so they must support ``set_sizes``
        (e.g. scatter collections).
        """
        lgd.legend_handles[0].set_visible(False)
        for handle in lgd.legend_handles[1:]:
            handle.set_sizes([40])

        # Collapse the handle area of the first legend column so that its
        # labels move to the left (uses private matplotlib internals)
        for vpack in lgd._legend_handle_box.get_children()[:1]:
            for hpack in vpack.get_children():
                hpack.get_children()[0].set_width(0)

    def __update_legend(self, lgd):
        """Rebuild an existing legend so that it includes new artists.

        Relies on private matplotlib ``Legend`` internals.
        """
        handles, labels = lgd.axes.get_legend_handles_labels()
        lgd._legend_box = None
        lgd._init_legend_box(handles, labels)
        lgd._set_loc(lgd._loc)
        lgd.set_title(lgd.get_title().get_text())

    def __get_axis_boundaries(self, data, axis_opts):
        """Return the axis range: the custom ``min``/``max`` if both are
        given, otherwise the minimum and maximum of ``data``."""
        if axis_opts.min is not None and axis_opts.max is not None:
            return axis_opts.min, axis_opts.max
        else:
            return np.min(data), np.max(data)

    def __boundary_mask(self, data, min_v, max_v):
        """Boolean mask of the entries of ``data`` inside ``[min_v, max_v]``.

        Returns ``None`` when not both boundaries are given, i.e. when no
        filtering should take place.
        """
        if min_v is None or max_v is None:
            return None
        values = np.asarray(data)
        return (values >= min_v) & (values <= max_v)

    def __apply_boundary(self, data, min_v, max_v):
        """Return the entries of ``data`` inside ``[min_v, max_v]``.

        ``data`` is returned unchanged when not both boundaries are given.
        """
        mask = self.__boundary_mask(data, min_v, max_v)
        if mask is None:
            return data
        return np.asarray(data)[mask]

    def __pair_mask(self, x, y, x_axis, y_axis):
        """Common mask of the (x, y) pairs inside both axis ranges.

        ``x`` and ``y`` must have the same length. An axis without a complete
        ``min``/``max`` range does not restrict the selection.
        """
        mask = np.ones(len(x), dtype=bool)
        for values, opts in ((x, x_axis), (y, y_axis)):
            axis_mask = self.__boundary_mask(values, opts.min, opts.max)
            if axis_mask is not None:
                mask &= axis_mask
        return mask

    def __mask_per_entry(self, values, mask):
        """Apply an entry mask to optional per-entry ``values``.

        Keeps weights/errors aligned with filtered data. The mask is applied
        along the last axis, so ``(N,)`` and ``(2, N)`` shaped values are
        supported. ``None``, scalars and values whose last axis does not
        match the length of ``mask`` are returned unchanged.
        """
        if values is None or np.ndim(values) == 0:
            return values
        array = np.asarray(values)
        if array.shape[-1] != len(mask):
            return values
        return array[..., mask]

    def __set_label_format(self, label_format, axis):
        """Set the tick label formatting of ``axis``.

        ``None`` leaves the axis untouched, ``"default"`` installs the shared
        ``axis_formatter`` and any other value is used as a
        ``str.format``-style pattern for major and minor ticks.
        """
        if label_format is None:
            return

        if label_format == "default":
            axis.set_major_formatter(self.axis_formatter)
        else:
            tick_formatter = ticker.StrMethodFormatter(label_format)
            axis.set_major_formatter(tick_formatter)
            axis.set_minor_formatter(tick_formatter)

    def graph(
        self,
        x,
        y,
        y_errors=None,
        title="",
        label="",
        x_axis=axis_options(label="x"),
        y_axis=axis_options(label="y"),
        color=None,
        marker=".",
        lgd_ops=legend_options(),
        figsize=(8, 8),
        layout="constrained",
    ):
        """Create a graph from given input data.

        Args:
            x, y: coordinates of the data points.
            y_errors: optional errors on ``y`` (forwarded to ``errorbar``).
            title, label: plot title and legend label of the data.
            x_axis, y_axis: axis options (label, range, log scale, tick
                positions, tick label format). If both ``min`` and ``max``
                are given for an axis, points outside of that range are
                removed together with their partner coordinate and error.
            color, marker: style of the data points.
            lgd_ops: legend options.
            figsize, layout: forwarded to ``plt.figure``.

        Returns:
            plt_data with figure and axes; legend and plotted data are only
            set when there was data to plot.
        """
        # Create fresh plot
        fig = plt.figure(figsize=figsize, layout=layout)
        ax = fig.add_subplot(1, 1, 1)

        # Refine plot
        ax.set_title(title)
        ax.set_xlabel(x_axis.label)
        ax.set_ylabel(y_axis.label)
        ax.grid(True, alpha=0.25)

        # Plot log scale
        if x_axis.log_scale is not None:
            ax.set_xscale("log", base=x_axis.log_scale)
        if y_axis.log_scale is not None:
            ax.set_yscale("log", base=y_axis.log_scale)

        if x_axis.tick_positions is not None:
            ax.set_xticks(x_axis.tick_positions)
            ax.tick_params(axis="x", which="major", pad=7)

        if y_axis.tick_positions is not None:
            ax.set_yticks(y_axis.tick_positions)
            ax.tick_params(axis="y", which="major", pad=7)

        # Restrict x and y ranges
        if len(x) == len(y):
            # One common mask keeps the (x, y) pairs and their errors aligned
            mask = self.__pair_mask(x, y, x_axis, y_axis)
            x = np.asarray(x)[mask]
            y = np.asarray(y)[mask]
            y_errors = self.__mask_per_entry(y_errors, mask)
        else:
            x = self.__apply_boundary(x, x_axis.min, x_axis.max)
            y = self.__apply_boundary(y, y_axis.min, y_axis.max)

        # Format of tick labels
        self.__set_label_format(x_axis.label_format, ax.xaxis)
        self.__set_label_format(y_axis.label_format, ax.yaxis)

        # Nothing left to do
        if len(x) == 0:
            self.logger.debug(rf" create graph: empty data {label}")
            return plt_data(fig=fig, axes=ax)

        if len(x) != len(y):
            self.logger.debug(
                rf" create graph: x range does not match y range {label}"
            )
            return plt_data(fig=fig, axes=ax, errors=y_errors)

        data = ax.errorbar(
            x=x, y=y, label=label, yerr=y_errors, marker=marker, color=color
        )

        # Add legend
        lgd = self.__add_legend(ax, lgd_ops)

        return plt_data(fig=fig, axes=ax, lgd=lgd, data=data, errors=y_errors)

    def add_graph(
        self,
        plot,
        x,
        y,
        y_errors=None,
        label="",
        marker="+",
        color=None,
    ):
        """Add a new graph to an existing plot.

        Nothing is added (and ``plot`` is returned) if ``y`` is empty or
        ``plot`` holds no data. Otherwise the legend is rebuilt and the view
        is rescaled to the new data.

        Returns:
            plt_data with the figure, axes and legend of ``plot`` and the
            newly plotted data.
        """
        # Nothing left to do
        if len(y) == 0 or plot.data is None:
            self.logger.debug(rf" add graph: empty data {label}")
            return plot

        # Add new data to old plot axis
        data = plot.axes.errorbar(
            x=x,
            y=y,
            label=label,
            yerr=y_errors,
            color=color,
            marker=marker,
        )

        self.__update_legend(plot.lgd)

        # Rescale the plot
        plot.axes.relim()
        plot.axes.autoscale_view()

        return plt_data(
            fig=plot.fig, axes=plot.axes, lgd=plot.lgd, data=data, errors=y_errors
        )

    def hist1D(
        self,
        x,
        bins=1,
        errors=None,
        w=None,
        title="",
        label="",
        x_axis=axis_options(label="x"),
        y_axis=axis_options(label=""),
        color=default_color,
        alpha=0.75,
        normalize=False,
        show_error=False,
        show_stats=True,
        u_outlier=-1,
        o_outlier=-1,
        lgd_ops=legend_options(),
        figsize=(8, 8),
        layout="compressed",
    ):
        """Create a histogram from given input data.

        Args:
            x: data to be histogrammed.
            bins: number of bins or bin edges (forwarded to ``Axes.hist``).
            errors: bin errors; if ``None`` they are computed as
                ``sqrt(scale * bin content)`` with ``scale = 1/len(x)`` when
                ``normalize`` is set and ``1`` otherwise.
            w: optional per-entry weights (filtered together with ``x``).
            title, label: plot title and legend label of the data.
            x_axis, y_axis: axis options. If both ``x_axis.min`` and
                ``x_axis.max`` are given, they define the histogram range and
                entries outside of it are dropped; otherwise the data range
                is used.
            color, alpha: style of the histogram.
            normalize: draw the histogram as a density (matplotlib
                ``density=True``: counts divided by the number of entries
                and the bin width).
            show_error: draw error bars (always drawn if ``errors`` is given).
            show_stats: add mean and standard deviation of the data in range
                to the legend.
            u_outlier, o_outlier: additional, externally counted under-/
                overflow entries. If either is non-negative, the under- and
                overflow counts are added to the legend label.
            lgd_ops: legend options.
            figsize, layout: forwarded to ``plt.figure``.

        Returns:
            plt_data with figure, axes, legend, bin contents, bin edges, mean,
            standard deviation (as ``rms``) and bin errors; only figure and
            axes if there is no data in range.
        """
        # Create fresh plot
        fig = plt.figure(figsize=figsize, layout=layout)
        ax = fig.add_subplot(1, 1, 1)

        # Refine plot
        ax.set_title(title)
        ax.set_xlabel(x_axis.label)
        ax.set_ylabel(y_axis.label)
        ax.grid(True, alpha=0.25)

        # Plot log scale
        if x_axis.log_scale is not None:
            ax.set_xscale("log", base=x_axis.log_scale)
        if y_axis.log_scale is not None:
            ax.set_yscale("log", base=y_axis.log_scale)

        # Leave x-axis with default formatter for 1D histograms
        ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
        self.__set_label_format(y_axis.label_format, ax.yaxis)

        x = np.asarray(x)

        # Nothing to do (also avoids taking the min/max of empty data)
        if len(x) == 0:
            self.logger.debug(rf" create hist: empty data {label}")
            return plt_data(fig=fig, axes=ax)

        # Histogram range: from the axis options or from the data
        x_min, x_max = self.__get_axis_boundaries(x, x_axis)

        # Count the entries outside of the histogram range *before* they are
        # removed below; afterwards these counts would always be zero
        underflow = int(np.count_nonzero(x < x_min))
        overflow = int(np.count_nonzero(x > x_max))
        # Add externally counted outliers (a negative value means "not given")
        if u_outlier >= 0 or o_outlier >= 0:
            underflow += max(u_outlier, 0)
            overflow += max(o_outlier, 0)

        # Do calculations on data in the range of the histogram; the weights
        # are filtered with the same mask so that they stay aligned with x
        mask = self.__boundary_mask(x, x_axis.min, x_axis.max)
        if mask is not None:
            x = x[mask]
            w = self.__mask_per_entry(w, mask)

        # Nothing left to do
        if len(x) == 0:
            self.logger.debug(rf" create hist: empty data {label}")
            return plt_data(fig=fig, axes=ax)

        # Histogram normalization
        scale = 1.0 / len(x) if normalize else 1.0

        # Format the 'newline'
        newline = "\n"

        # Name of the data collection
        label_str = f"{label} ({len(x)} entries)"
        if u_outlier >= 0 or o_outlier >= 0:
            label_str = (
                label_str
                + f"{newline} underflow: {underflow}"
                + f"{newline} overflow:  {overflow}"
            )

        # Fill data
        data, bins, _ = ax.hist(
            x,
            weights=w,
            range=(x_min, x_max),
            bins=bins,
            label=label_str,
            histtype="stepfilled",
            density=normalize,
            alpha=alpha,
            facecolor=color,
            edgecolor=color,
        )

        # Add some additional information
        if show_stats:
            mean = np.mean(x, axis=0)
            stdev = np.std(x, axis=0)

            # Create empty plot with blank marker containing the extra label
            ax.plot(
                [],
                [],
                " ",
                label=rf"data:"
                rf"{newline}mean    = {mean:.2e}"
                rf"{newline}stddev  = {stdev:.2e}",
            )
        else:
            mean = None
            stdev = None

        # Calculate the bin error
        # NOTE(review): for normalized histograms sqrt(data / N) does not
        # account for the bin width that the density normalization divides
        # by - confirm this is intended.
        bin_centers = 0.5 * (bins[1:] + bins[:-1])
        err = np.sqrt(scale * data) if errors is None else errors
        if show_error or errors is not None:
            ax.errorbar(
                bin_centers,
                data,
                yerr=err,
                fmt=".",
                linestyle="",
                linewidth=0.4,
                color="black",
                capsize=2.5,
            )

        # Add legend
        lgd = self.__add_legend(ax, lgd_ops)

        # Show only the text of the legend entries: hide the first handle
        # (and the second one when statistics are shown) and collapse the
        # handle area (uses private matplotlib internals)
        lgd.legend_handles[0].set_visible(False)
        if show_stats:
            lgd.legend_handles[1].set_visible(False)
        for vpack in lgd._legend_handle_box.get_children():
            for hpack in vpack.get_children():
                hpack.get_children()[0].set_width(0)

        return plt_data(
            fig=fig,
            axes=ax,
            lgd=lgd,
            data=data,
            bins=bins,
            mu=mean,
            rms=stdev,
            errors=err,
        )

    def add_hist(
        self,
        old_hist,
        x,
        errors=None,
        w=None,
        label="",
        color="tab:orange",
        alpha=0.75,
        normalize=False,
        show_error=False,
    ):
        """Add a new histogram to an existing histogram plot.

        The binning of ``old_hist`` is reused; entries (and their weights)
        outside of its bin range are dropped. ``normalize``, ``errors`` and
        ``show_error`` behave as in ``hist1D``.

        Returns:
            ``old_hist`` if there is nothing to add, otherwise plt_data with
            the figure, axes and legend of ``old_hist`` and the new bin
            contents, bin edges and bin errors.
        """
        # Nothing to add to
        if old_hist.data is None or old_hist.bins is None:
            self.logger.debug(rf" add hist: empty data {label}")
            return old_hist

        # Do calculations on data in the range of the histogram; the weights
        # are filtered with the same mask so that they stay aligned with x
        x = np.asarray(x)
        mask = self.__boundary_mask(x, np.min(old_hist.bins), np.max(old_hist.bins))
        x = x[mask]
        w = self.__mask_per_entry(w, mask)

        # Nothing left to do
        if len(x) == 0:
            self.logger.debug(rf" add hist: empty data {label}")
            return old_hist

        # Add new data to old hist axis (normalized the same way as hist1D)
        scale = 1.0 / len(x) if normalize else 1.0
        data, bins, _ = old_hist.axes.hist(
            x=x,
            bins=old_hist.bins,
            label=f"{label} ({len(x)} entries)",
            weights=w,
            histtype="stepfilled",
            density=normalize,
            facecolor=color,
            alpha=alpha,
            edgecolor=color,
        )

        # Calculate the bin error
        bin_centers = 0.5 * (bins[1:] + bins[:-1])
        err = np.sqrt(scale * data) if errors is None else errors
        if show_error or errors is not None:
            old_hist.axes.errorbar(
                bin_centers,
                data,
                yerr=err,
                fmt=".",
                linestyle="",
                linewidth=0.4,
                color="black",
                capsize=2.5,
            )

        # Update legend
        self.__update_legend(old_hist.lgd)

        return plt_data(
            fig=old_hist.fig,
            axes=old_hist.axes,
            lgd=old_hist.lgd,
            data=data,
            bins=bins,
            errors=err,
        )

    def add_ratio(
        self, nom, denom, label, color="tab:red", set_log=False, show_error=False
    ):
        """Plot the ratio of two histograms in a panel below ``nom``.

        The data is assumed to be uncorrelated: the ratio errors follow from
        Gaussian error propagation of the bin errors of both histograms.
        Bins with a zero denominator get a ratio (and error) of zero.

        Args:
            nom: numerator histogram (plt_data with ``bins``, ``data`` and
                ``errors``); the ratio panel is attached to its axes.
            denom: denominator histogram with the same number of bins.
            label, color: legend label and color of the ratio points.
            set_log: use a logarithmic y-axis for the ratio panel.
            show_error: draw error bars on the ratio points.

        Returns:
            plt_data with the figure, the ratio axes and the ratio errors, or
            only figure and axes of ``nom`` if the binnings are incompatible.
        """
        # Resize figure
        nom.fig.set_figheight(7)
        nom.fig.set_figwidth(8)

        if nom.bins is None or denom.bins is None:
            return plt_data(fig=nom.fig, axes=nom.axes)

        if len(nom.bins) != len(denom.bins):
            return plt_data(fig=nom.fig, axes=nom.axes)

        # Remove ticks/labels that are already visible on the ratio plot
        x_label = nom.axes.xaxis.get_label().get_text()
        nom.axes.tick_params(
            axis="x", which="both", bottom=True, top=False, labelbottom=False
        )
        nom.axes.set_xlabel("")

        # Don't print a warning when dividing by zero
        with np.errstate(divide="ignore", invalid="ignore"):
            # Filter out nan results from division by zero
            ratio = np.nan_to_num(nom.data / denom.data, nan=0, posinf=0)

            # Calculate errors by Gaussian propagation
            bin_centers = 0.5 * (nom.bins[1:] + nom.bins[:-1])
            n_data, d_data = (nom.data, denom.data)

            # Gaussian approximation for large number of events in bin
            # Note: Should be Clopper-Pearson
            n_err, d_err = (nom.errors, denom.errors)
            errors = np.nan_to_num(
                np.sqrt(
                    np.square(n_err / d_data)
                    + np.square(n_data * d_err / np.square(d_data))
                ),
                nan=0,
                posinf=0,
            )

        # Create new axes below the current axes; the second argument of
        # append_axes is the height of the new axes in inches
        divider = make_axes_locatable(nom.axes)
        ratio_plot = divider.append_axes("bottom", 1.2, pad=0.2, sharex=nom.axes)
        # Ratio should be around 1: Don't use scientific notation/offset
        ratio_plot.yaxis.set_major_formatter(ticker.ScalarFormatter(useOffset=False))
        if show_error:
            ratio_plot.errorbar(
                bin_centers, ratio, yerr=errors, label=label, color=color, fmt="."
            )
        else:
            ratio_plot.plot(
                bin_centers,
                ratio,
                label=label,
                color=color,
                marker=".",
                linestyle="",
            )

        # Refine plot
        ratio_plot.set_xlabel(x_label)
        ratio_plot.set_ylabel("ratio")
        ratio_plot.grid(True, alpha=0.25)

        # Plot log scale
        if set_log:
            ratio_plot.set_yscale("log")

        # Add a horizontal blue line at y = 1.
        ratio_plot.axline((nom.bins[0], 1), (nom.bins[-1], 1), linewidth=1, color="b")

        nom.fig.set_size_inches((9, 9))

        return plt_data(fig=nom.fig, axes=ratio_plot, errors=errors)

    def hist2D(
        self,
        x,
        y,
        z=None,
        x_bins=1,
        y_bins=1,
        x_axis=axis_options(label="x"),
        y_axis=axis_options(label="y"),
        z_axis=axis_options(label=""),
        title="",
        label="",
        color=default_color,
        alpha=0.75,
        show_stats=True,
        figsize=(8, 6),
    ):
        """Create a 2D histogram from given input data.

        If z values are given they are used as weights per entry. If both
        ``min`` and ``max`` are given for an axis, they define the histogram
        range and (x, y) pairs outside of it are dropped together with their
        weight; otherwise the data range is used.

        Returns:
            plt_data with figure, axes and the bin contents; only figure and
            axes if there is no data in range.
        """
        # Create fresh plot
        fig = plt.figure(figsize=figsize, layout="constrained")
        ax = fig.add_subplot(1, 1, 1)

        # Refine plot
        ax.set_title(title)
        ax.set_xlabel(x_axis.label)
        ax.set_ylabel(y_axis.label)

        x = np.asarray(x)
        y = np.asarray(y)

        # Nothing to do (also avoids taking the min/max of empty data)
        if len(x) == 0 or len(y) == 0:
            self.logger.debug(rf" create hist: empty data {label}")
            return plt_data(fig=fig, axes=ax)

        # Histogram range: from the axis options or from the data
        x_min, x_max = self.__get_axis_boundaries(x, x_axis)
        y_min, y_max = self.__get_axis_boundaries(y, y_axis)

        # Do calculations on data in the range of the histogram
        if len(x) == len(y):
            # One common mask keeps the (x, y) pairs and weights aligned
            mask = self.__pair_mask(x, y, x_axis, y_axis)
            x = x[mask]
            y = y[mask]
            z = self.__mask_per_entry(z, mask)
        else:
            x = self.__apply_boundary(x, x_axis.min, x_axis.max)
            y = self.__apply_boundary(y, y_axis.min, y_axis.max)

        # Nothing left to do
        if len(x) == 0 or len(y) == 0:
            self.logger.debug(rf" create hist: empty data {label}")
            return plt_data(fig=fig, axes=ax)

        # Fill data
        data, _, _, hist = ax.hist2d(
            x,
            y,
            weights=z,
            range=[(x_min, x_max), (y_min, y_max)],
            bins=(x_bins, y_bins),
            label=f"{label}  ({len(x)} entries)",
            facecolor=mcolors.to_rgba(color, alpha),
            edgecolor=None,
            rasterized=True,
        )

        # Add some additional information
        # NOTE(review): no legend is created in hist2D, so this label is not
        # displayed unless a caller adds a legend.
        if show_stats:
            x_mean = np.mean(x, axis=0)
            x_rms = np.sqrt(np.mean(np.square(x)))
            y_mean = np.mean(y, axis=0)
            y_rms = np.sqrt(np.mean(np.square(y)))

            # Create empty plot with blank marker containing the extra label
            newline = "\n"
            ax.plot(
                [],
                [],
                " ",
                label=rf"xMean = {x_mean:.2e}"
                rf"{newline}xRMS  = {x_rms:.2e}"
                rf"{newline}yMean = {y_mean:.2e}"
                rf"{newline}yRMS  = {y_rms:.2e}",
            )

        # Add the colorbar
        fig.colorbar(hist, label=z_axis.label)

        return plt_data(fig=fig, axes=ax, data=data)

    def scatter(
        self,
        x,
        y,
        x_axis=axis_options(label=""),
        y_axis=axis_options(label=""),
        title="",
        label="",
        color=default_color,
        alpha=1,
        show_stats=lambda x, _: f"{len(x)} entries",
        lgd_ops=legend_options(),
        figsize=(8, 6),
    ):
        """Create a 2D scatter plot.

        Args:
            show_stats: callable ``(x, y) -> str`` whose result is shown as a
                text-only first legend entry (default: number of entries).

        Returns:
            plt_data with figure, axes, legend and the scatter collection.
        """
        fig = plt.figure(figsize=figsize, layout="constrained")
        ax = fig.add_subplot(1, 1, 1)

        # Refine plot
        ax.set_title(title)
        ax.set_xlabel(x_axis.label)
        ax.set_ylabel(y_axis.label)
        ax.grid(True, alpha=0.25)

        # Create empty plot with blank marker containing the extra label
        ax.plot([], [], " ", label=show_stats(x, y))
        scatter = ax.scatter(
            x, y, label=label, c=color, s=0.1, alpha=alpha, rasterized=True
        )

        # Add legend
        lgd = self.__add_legend(ax, lgd_ops)

        # Refine legend
        self.__adjust_lgd_label_spacing(lgd)

        return plt_data(fig=fig, axes=ax, lgd=lgd, data=scatter)

    def highlight_region(self, plot_data, x, y, color, label=""):
        """Add new data in a different color to a scatter plot.

        The legend is only rebuilt when a ``label`` is given; in both cases
        the legend spacing is readjusted like in ``scatter``.
        """
        if label == "":
            plot_data.axes.scatter(x, y, c=color, alpha=1, s=0.1, rasterized=True)
        else:
            plot_data.axes.scatter(
                x, y, c=color, alpha=1, s=0.1, label=label, rasterized=True
            )

            # Update legend
            self.__update_legend(plot_data.lgd)

        # Refine legend
        self.__adjust_lgd_label_spacing(plot_data.lgd)

    def fit_gaussian(self, dist, color="tab:orange"):
        """Fit a Gaussian to a 1D distribution and plot it in the same figure.

        Args:
            dist: histogram plt_data (uses ``bins``, ``data``, ``axes`` and
                ``lgd``).
            color: color of the fitted curve.

        Returns:
            ``(mean, sigma)`` of the fitted Gaussian, or ``(None, None)`` if
            there is no histogram, too few bins, scipy is not available or the
            fit failed.
        """
        bins = dist.bins
        if bins is None:
            # No histogram to fit: return empty result
            return None, None

        # Calculate bin centers from bin edges
        bins = np.asarray(bins, dtype=float)
        bin_centers = 0.5 * (bins[1:] + bins[:-1])

        # The Gaussian has three free parameters: need at least three points
        if len(bin_centers) < 3:
            return None, None

        def _gaussian(x, a, mean, sigma):
            """Gaussian with integral ``a``, mean ``mean`` and width ``sigma``."""
            return (
                a
                / (math.sqrt(2 * math.pi) * sigma)
                * np.exp(-((x - mean) ** 2 / (2 * sigma**2)))
            )

        # Gaussian fit
        try:
            from scipy.optimize import curve_fit
        except ImportError:
            print("WARNING: Could not find scipy: Skipping fit")
            return None, None

        try:
            # Initial estimators
            mean = np.mean(bin_centers, axis=0)
            sigma = np.std(bin_centers, axis=0)
            a = np.max(dist.data) * (math.sqrt(2 * math.pi) * sigma)

            popt, _ = curve_fit(_gaussian, bin_centers, dist.data, p0=[a, mean, sigma])
        except (RuntimeError, ValueError):
            # Fit did not converge or got non-finite input: return empty result
            return None, None

        # If the fitting was successful, plot the curve
        mu = float(f"{popt[1]:.2e}")  # < formatting the sig. digits
        sig = float(f"{popt[2]:.2e}")
        newline = "\n"
        plot_label = (
            rf"gaussian fit:{newline}$\mu$ = {mu:.2e}"
            + rf"{newline}$\sigma$ = {abs(sig):.2e}"
        )

        # Generate equidistant points spanning the bin centers for the curve
        x = np.linspace(np.min(bin_centers), np.max(bin_centers), 1001)

        dist.axes.plot(
            x,
            _gaussian(x, *popt),
            label=plot_label,
            color=color,
        )

        # Update legend
        self.__update_legend(dist.lgd)

        # Adjust spacing in box (uses private matplotlib internals)
        dist.lgd.legend_handles[0].set_visible(False)
        for vpack in dist.lgd._legend_handle_box.get_children()[:-1]:
            for hpack in vpack.get_children():
                hpack.get_children()[0].set_width(0)

        return popt[1], abs(popt[2])

    def vertical_line(self, plot_data, x, y=None, color="b", label=""):
        """Draw a dashed vertical line at ``x`` with a text label.

        The label is placed at height ``y``, or at the vertical center of the
        current y-range if ``y`` is not given.
        """
        plot_data.axes.axvline(x=x, color=color, linestyle="--")

        ymin, ymax = plot_data.axes.get_ylim()
        plot_data.axes.text(
            x,
            ymin + (ymax - ymin) / 2 if y is None else y,
            label,
            ha="center",
            va="center",
            backgroundcolor="white",
        )

    def write_plot(
        self, plot_data, name="plot", file_format="svg", out_prefix="", dpi=450
    ):
        """Write a plot to disk and close its figure.

        Args:
            plot_data: plot to be written (its ``fig`` is saved and closed).
            name: file name without extension.
            file_format: file extension, e.g. ``"svg"``; a leading dot
                (``".svg"``) is accepted as well.
            out_prefix: path prefix that is prepended to ``name`` as-is; if
                empty, the file is written into the factory's output
                directory.
            dpi: resolution forwarded to ``savefig``.
        """
        # Avoid file names like 'plot..svg' when the format has a leading dot
        extension = file_format.lstrip(".")
        if out_prefix == "":
            file_name = self.output_prefix + "/" + name + "." + extension
        else:
            file_name = out_prefix + name + "." + extension

        plot_data.fig.savefig(file_name, dpi=dpi)
        plt.close(plot_data.fig)

    def write_svg(self, plot_data, name, out_prefix=""):
        """Write a plot as svg."""
        self.write_plot(plot_data, name, "svg", out_prefix)

    def write_pdf(self, plot_data, name, out_prefix=""):
        """Write a plot as pdf."""
        self.write_plot(plot_data, name, "pdf", out_prefix)

    def write_png(self, plot_data, name, out_prefix=""):
        """Write a plot as png."""
        self.write_plot(plot_data, name, "png", out_prefix)