Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-23 09:22:36

0001 from dataclasses import dataclass
0002 from typing import Tuple
0003 
0004 import numpy as np
0005 from matplotlib import pyplot as plt
0006 from scipy.optimize import curve_fit
0007 
0008 from core.constants import N_CELLS_Z, N_CELLS_R, VALID_DIR, SIZE_Z, SIZE_R, HISTOGRAM_TYPE, FULL_SIM_HISTOGRAM_COLOR, \
0009     ML_SIM_HISTOGRAM_COLOR, FULL_SIM_GAUSSIAN_COLOR, ML_SIM_GAUSSIAN_COLOR
0010 from utils.observables import LongitudinalProfile, ProfileType, Profile, Energy
0011 
# Module-level side effect executed at import time: sets matplotlib's global "font.size" option to 22, so it applies
# to figures created after this module has been imported.
plt.rcParams.update({"font.size": 22})
0013 
0014 
@dataclass
class Plotter:
    """ Common base of all plotters: stores the description of the simulated event sample.

    This class only defines the interface and is not meant to be used on its own. Instantiate ProfilePlotter or
    EnergyPlotter instead.

    Attributes:
        _particle_energy: An integer. Energy of the primary particle, in GeV.
        _particle_angle: An integer. Angle of the primary particle, in degrees.
        _geometry: A string. Name of the calorimeter geometry (e.g. SiW, SciPb).

    """
    _particle_energy: int
    _particle_angle: int
    _geometry: str

    def plot_and_save(self):
        """ Interface placeholder that subclasses override; the base implementation does nothing.

        Returns:
            None.

        """
        return None
0033 
0034 
def _gaussian(x: np.ndarray, a: float, mu: float, sigma: float) -> np.ndarray:
    """ Evaluates a scaled (not normalised) Gaussian: a * exp(-(x - mu)^2 / (2 * sigma^2)).

    Args:
        x: Point(s) at which the function is evaluated.
        a: Scaling parameter, i.e. the value of the function at x == mu.
        mu: Position of the maximum (mean).
        sigma: Width parameter. It enters the formula as 2 * sigma^2, i.e. it plays the role of a standard
            deviation.

    Returns:
        Values of the function at x.

    """
    exponent = -((x - mu)**2 / (2 * sigma**2))
    return a * np.exp(exponent)
0049 
0050 
def _best_fit(data: np.ndarray,
              bins: np.ndarray,
              hist: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """ Estimates the shape of a Gaussian describing a histogram, using non-linear least squares.

    Errors raised by numpy or scipy (for instance when every bin is empty, or when the fit does not converge) are
    not handled here and propagate to the caller.

    Args:
        data: A numpy array. Values of an observable from multiple events when hist is False; already binned
            frequencies (one entry per bin) when hist is True.
        bins: A numpy array specifying histogram bin edges, so len(bins) == number of bins + 1.
        hist: A boolean. Set to True if the histogram is already calculated, i.e. data holds the frequencies.

    Returns:
        A tuple of two numpy arrays: the centres of the non-empty bins and the values of the fitted Gaussian at
        those centres.

    """
    # Obtain the frequencies. A separate local is used so that the boolean flag `hist` is not rebound to an array.
    if hist:
        counts = data
    else:
        counts, _ = np.histogram(data, bins)

    # Choose only those bins which are nonzero. nonzero() returns a tuple of arrays. In this case it has a
    # length = 1, hence we are interested in its first element.
    indices = counts.nonzero()[0]

    # Based on the previously chosen nonzero bins, calculate xs and ys_bar (true values) which will be used in the
    # fitting procedure. len(bins) == len(counts) + 1, so we choose middles of bins as xs.
    bins_middles = (bins[:-1] + bins[1:]) / 2
    xs = bins_middles[indices]
    ys_bar = counts[indices]

    # Set initial parameters for the curve fitter. _gaussian() uses sigma as a standard deviation (it appears as
    # 2 * sigma**2), so the starting value must be a standard deviation as well; a variance would be expressed in
    # squared x-units and can be far away from the right scale.
    a0 = np.max(ys_bar)
    mu0 = np.mean(xs)
    sigma0 = np.std(xs)

    # Fit a Gaussian to the prepared data.
    (a, mu, sigma), _ = curve_fit(f=_gaussian,
                                  xdata=xs,
                                  ydata=ys_bar,
                                  p0=[a0, mu0, sigma0],
                                  method="trf",
                                  maxfev=1000)

    # Calculate values of the approximation in the given points and return them.
    ys = _gaussian(xs, a, mu, sigma)
    return xs, ys
0097 
0098 
@dataclass
class ProfilePlotter(Plotter):
    """ Plotter responsible for preparing plots of profiles and their first and second moments.

    Attributes:
        _full_simulation: A Profile object holding data generated by full simulation (Geant4).
        _ml_simulation: A Profile object holding data generated by ML model.
        _plot_gaussian: A boolean. If True, a fitted Gaussian is drawn on top of every histogram.
        _profile_type: An enum. A profile can be either lateral or longitudinal. It is not a constructor argument;
            it is derived from the type of _full_simulation in __post_init__.

    """
    _full_simulation: Profile
    _ml_simulation: Profile
    _plot_gaussian: bool = False

    def __post_init__(self):
        """ Validates that both profiles have the same type and derives _profile_type from it. """
        # Check if profiles are either both longitudinal or lateral.
        # NOTE(review): an assert is skipped when Python runs with -O; it is kept so that the raised exception type
        # (AssertionError) stays unchanged for existing callers.
        full_simulation_type = type(self._full_simulation)
        ml_generation_type = type(self._ml_simulation)
        assert full_simulation_type == ml_generation_type, "Both profiles within a ProfilePlotter must be the same " \
                                                           "type."

        # Set an attribute with profile type. Anything that is not exactly LongitudinalProfile is treated as lateral.
        if full_simulation_type == LongitudinalProfile:
            self._profile_type = ProfileType.LONGITUDINAL
        else:
            self._profile_type = ProfileType.LATERAL

    def _plot_and_save_customizable_histogram(
            self,
            full_simulation: np.ndarray,
            ml_simulation: np.ndarray,
            bins: np.ndarray,
            xlabel: str,
            observable_name: str,
            plot_profile: bool = False,
            y_log_scale: bool = False) -> None:
        """ Prepares and saves a histogram for a given pair of observables.

        The figure consists of two panels sharing the x-axis: the upper one shows both distributions (and optionally
        fitted Gaussians), the lower one shows the bin-by-bin ratio MLSim/FullSim. The figure is written to VALID_DIR
        and closed afterwards.

        Args:
            full_simulation: A numpy array of observables coming from full simulation.
            ml_simulation: A numpy array of observables coming from ML simulation.
            bins: A numpy array specifying histogram bins (edges).
            xlabel: A string. Name of x-axis on the plot.
            observable_name: A string. Name of plotted observable, used as a prefix of the output file name.
            plot_profile: A boolean. If set to True, full_simulation and ml_simulation are treated as already binned
                values (one per bin) and are drawn directly as a step plot instead of being histogrammed. Should be
                set to True only while plotting profiles, not first or second moments.
            y_log_scale: A boolean. Log scale is used on the y-axis if set to True.

        Returns:
            None.

        """
        fig, axes = plt.subplots(2,
                                 1,
                                 figsize=(15, 10),
                                 clear=True,
                                 sharex="all")

        # Plot histograms.
        if plot_profile:
            # We already have the bins (layers) and frequencies (energies),
            # therefore directly plotting a step plot + lines instead of a hist plot.
            axes[0].step(bins[:-1],
                         full_simulation,
                         label="FullSim",
                         color=FULL_SIM_HISTOGRAM_COLOR)
            axes[0].step(bins[:-1],
                         ml_simulation,
                         label="MLSim",
                         color=ML_SIM_HISTOGRAM_COLOR)
            # Vertical lines closing the step plots at the first and the last plotted position.
            for values, color in ((full_simulation, FULL_SIM_HISTOGRAM_COLOR),
                                  (ml_simulation, ML_SIM_HISTOGRAM_COLOR)):
                axes[0].vlines(x=bins[0], ymin=0, ymax=values[0], color=color)
                axes[0].vlines(x=bins[-2], ymin=0, ymax=values[-1], color=color)
            axes[0].set_ylim(0, None)

            # For using it later for the ratios.
            energy_full_sim, energy_ml_sim = full_simulation, ml_simulation
        else:
            # hist() returns the bin contents as its first element; they are reused for the ratio panel.
            energy_full_sim, _, _ = axes[0].hist(
                x=full_simulation,
                bins=bins,
                label="FullSim",
                histtype=HISTOGRAM_TYPE,
                color=FULL_SIM_HISTOGRAM_COLOR)
            energy_ml_sim, _, _ = axes[0].hist(x=ml_simulation,
                                               bins=bins,
                                               label="MLSim",
                                               histtype=HISTOGRAM_TYPE,
                                               color=ML_SIM_HISTOGRAM_COLOR)

        # Plot Gaussians if needed. For profiles the inputs are already binned, hence hist=plot_profile.
        if self._plot_gaussian:
            (xs_full_sim, ys_full_sim) = _best_fit(full_simulation,
                                                   bins,
                                                   hist=plot_profile)
            (xs_ml_sim, ys_ml_sim) = _best_fit(ml_simulation,
                                               bins,
                                               hist=plot_profile)
            axes[0].plot(xs_full_sim,
                         ys_full_sim,
                         color=FULL_SIM_GAUSSIAN_COLOR,
                         label="FullSim")
            axes[0].plot(xs_ml_sim,
                         ys_ml_sim,
                         color=ML_SIM_GAUSSIAN_COLOR,
                         label="MLSim")

        if y_log_scale:
            axes[0].set_yscale("log")
        axes[0].legend(loc="best")
        axes[0].set_xlabel(xlabel)
        axes[0].set_ylabel("Energy [MeV]")
        # Raw f-string: "\c" is not a valid Python escape sequence, the backslash is meant for mathtext.
        axes[0].set_title(
            rf" $e^-$, {self._particle_energy} [GeV], {self._particle_angle}$^{{\circ}}$, {self._geometry}"
        )

        # Calculate ratios. Bins where full simulation is zero keep the pre-filled value of 1.
        ratio = np.divide(energy_ml_sim,
                          energy_full_sim,
                          out=np.ones_like(energy_ml_sim),
                          where=(energy_full_sim != 0))
        # Since len(bins) == 1 + len(data), we calculate middles of bins as xs.
        bins_middles = (bins[:-1] + bins[1:]) / 2
        axes[1].plot(bins_middles, ratio, "-o")
        axes[1].set_xlabel(xlabel)
        axes[1].set_ylabel("MLSim/FullSim")
        axes[1].axhline(y=1, color="black")
        fig.savefig(
            f"{VALID_DIR}/{observable_name}_Geo_{self._geometry}_E_{self._particle_energy}_"
            + f"Angle_{self._particle_angle}.png")
        # Close the figure instead of only clearing it; otherwise every call leaves one more open figure behind.
        plt.close(fig)

    def _plot_profile(self) -> None:
        """ Plots profile of an observable.

        Returns:
            None.

        """
        full_simulation_profile = self._full_simulation.calc_profile()
        ml_simulation_profile = self._ml_simulation.calc_profile()
        if self._profile_type == ProfileType.LONGITUDINAL:
            # One edge more than the number of layers, so that the last bin has its right limit.
            bins = np.linspace(0, N_CELLS_Z, N_CELLS_Z + 1)
            observable_name = "LongProf"
            xlabel = "Layer index"
        else:
            bins = np.linspace(0, N_CELLS_R, N_CELLS_R + 1)
            observable_name = "LatProf"
            xlabel = "R index"
        self._plot_and_save_customizable_histogram(full_simulation_profile,
                                                   ml_simulation_profile,
                                                   bins,
                                                   xlabel,
                                                   observable_name,
                                                   plot_profile=True)

    def _plot_first_moment(self) -> None:
        """ Plots and saves a first moment of an observable's profile.

        Returns:
            None.

        """
        full_simulation_first_moment = self._full_simulation.calc_first_moment(
        )
        ml_simulation_first_moment = self._ml_simulation.calc_first_moment()
        if self._profile_type == ProfileType.LONGITUDINAL:
            # Raw string: "\l" is not a valid Python escape sequence, the backslash is meant for mathtext.
            xlabel = r"$<\lambda> [mm]$"
            observable_name = "LongFirstMoment"
            bins = np.linspace(0, 0.4 * N_CELLS_Z * SIZE_Z, 128)
        else:
            xlabel = "$<r> [mm]$"
            observable_name = "LatFirstMoment"
            bins = np.linspace(0, 0.75 * N_CELLS_R * SIZE_R, 128)

        self._plot_and_save_customizable_histogram(
            full_simulation_first_moment, ml_simulation_first_moment, bins,
            xlabel, observable_name)

    def _plot_second_moment(self) -> None:
        """ Plots and saves a second moment of an observable's profile.

        Returns:
            None.

        """
        full_simulation_second_moment = self._full_simulation.calc_second_moment(
        )
        ml_simulation_second_moment = self._ml_simulation.calc_second_moment()
        if self._profile_type == ProfileType.LONGITUDINAL:
            # Raw string: "\l" is not a valid Python escape sequence, the backslash is meant for mathtext.
            xlabel = r"$<\lambda^{2}> [mm^{2}]$"
            observable_name = "LongSecondMoment"
            bins = np.linspace(0, pow(N_CELLS_Z * SIZE_Z, 2) / 35., 128)
        else:
            xlabel = "$<r^{2}> [mm^{2}]$"
            observable_name = "LatSecondMoment"
            bins = np.linspace(0, pow(N_CELLS_R * SIZE_R, 2) / 8., 128)

        self._plot_and_save_customizable_histogram(
            full_simulation_second_moment, ml_simulation_second_moment, bins,
            xlabel, observable_name)

    def plot_and_save(self) -> None:
        """ Main plotting function.

        Calls private methods and prints the information about progress.

        Returns:
            None.

        """
        if self._profile_type == ProfileType.LONGITUDINAL:
            profile_type_name = "longitudinal"
        else:
            profile_type_name = "lateral"
        print(f"Plotting the {profile_type_name} profile...")
        self._plot_profile()
        print(f"Plotting the first moment of {profile_type_name} profile...")
        self._plot_first_moment()
        print(f"Plotting the second moment of {profile_type_name} profile...")
        self._plot_second_moment()
0343 
0344 
@dataclass
class EnergyPlotter(Plotter):
    """ Plotter responsible for preparing energy plots: total energy, cell energy and energy per layer.

    Attributes:
        _full_simulation: An Energy object holding data generated by full simulation (Geant4).
        _ml_simulation: An Energy object holding data generated by ML model.

    """
    _full_simulation: Energy
    _ml_simulation: Energy

    def _plot_total_energy(self, y_log_scale=True) -> None:
        """ Plots and saves a histogram with total energy detected in an event.

        The binning (50 edges) spans the range of the full simulation values, widened by 5% of the minimum on the
        left and by 5% of the maximum on the right.

        Args:
            y_log_scale: A boolean. Log scale is used on the y-axis if set to True.

        Returns:
            None.

        """
        full_simulation_total_energy = self._full_simulation.calc_total_energy(
        )
        ml_simulation_total_energy = self._ml_simulation.calc_total_energy()

        fig = plt.figure(figsize=(12, 8))
        lowest = np.min(full_simulation_total_energy)
        highest = np.max(full_simulation_total_energy)
        bins = np.linspace(lowest - lowest * 0.05, highest + highest * 0.05,
                           50)
        plt.hist(x=full_simulation_total_energy,
                 histtype=HISTOGRAM_TYPE,
                 label="FullSim",
                 bins=bins,
                 color=FULL_SIM_HISTOGRAM_COLOR)
        plt.hist(x=ml_simulation_total_energy,
                 histtype=HISTOGRAM_TYPE,
                 label="MLSim",
                 bins=bins,
                 color=ML_SIM_HISTOGRAM_COLOR)
        plt.legend(loc="upper left")
        if y_log_scale:
            plt.yscale("log")
        plt.xlabel("Energy [MeV]")
        plt.ylabel("# events")
        # Raw f-string: "\c" is not a valid Python escape sequence, the backslash is meant for mathtext.
        plt.title(
            rf" $e^-$, {self._particle_energy} [GeV], {self._particle_angle}$^{{\circ}}$, {self._geometry} "
        )
        fig.savefig(
            f"{VALID_DIR}/E_tot_Geo_{self._geometry}_E_{self._particle_energy}_Angle_{self._particle_angle}.png"
        )
        # Close the figure instead of only clearing it; otherwise every call leaves one more open figure behind.
        plt.close(fig)

    def _plot_cell_energy(self) -> None:
        """ Plots and saves a histogram with number of detector's cells across whole
        calorimeter with particular energy detected.

        Energies are histogrammed as log10 values in the range [-4, 1].

        Returns:
            None.

        """
        full_simulation_cell_energy = self._full_simulation.calc_cell_energy()
        ml_simulation_cell_energy = self._ml_simulation.calc_cell_energy()

        # log10 is evaluated only for non-zero entries; zero entries keep the pre-filled value 0.
        # NOTE(review): this means cells with zero energy end up in the bin containing log10(E) == 0 - confirm that
        # this is intended.
        log_full_simulation_cell_energy = np.log10(
            full_simulation_cell_energy,
            out=np.zeros_like(full_simulation_cell_energy),
            where=(full_simulation_cell_energy != 0))
        log_ml_simulation_cell_energy = np.log10(
            ml_simulation_cell_energy,
            out=np.zeros_like(ml_simulation_cell_energy),
            where=(ml_simulation_cell_energy != 0))
        fig = plt.figure(figsize=(12, 8))
        bins = np.linspace(-4, 1, 1000)
        plt.hist(x=log_full_simulation_cell_energy,
                 bins=bins,
                 histtype=HISTOGRAM_TYPE,
                 label="FullSim",
                 color=FULL_SIM_HISTOGRAM_COLOR)
        plt.hist(x=log_ml_simulation_cell_energy,
                 bins=bins,
                 histtype=HISTOGRAM_TYPE,
                 label="MLSim",
                 color=ML_SIM_HISTOGRAM_COLOR)
        plt.xlabel("log10(E/MeV)")
        # The lower limit is set both before and after switching to the log scale; both calls are kept so that the
        # resulting axis limits stay exactly as they were.
        plt.ylim(bottom=1)
        plt.yscale("log")
        plt.ylim(bottom=1)
        plt.ylabel("# entries")
        # Raw f-string: "\c" is not a valid Python escape sequence, the backslash is meant for mathtext.
        plt.title(
            rf" $e^-$, {self._particle_energy} [GeV], {self._particle_angle}$^{{\circ}}$, {self._geometry} "
        )
        plt.grid(True)
        plt.legend(loc="upper left")
        fig.savefig(
            f"{VALID_DIR}/E_cell_Geo_{self._geometry}_E_{self._particle_energy}_Angle_{self._particle_angle}.png"
        )
        # Close the figure instead of only clearing it; otherwise every call leaves one more open figure behind.
        plt.close(fig)

    def _plot_energy_per_layer(self) -> None:
        """ Plots and saves N_CELLS_Z histograms with total energy detected in particular layers.

        All histograms are placed on one figure, 9 per row, and share both axes and the binning.

        Returns:
            None.

        """
        full_simulation_energy_per_layer = self._full_simulation.calc_energy_per_layer(
        )
        ml_simulation_energy_per_layer = self._ml_simulation.calc_energy_per_layer(
        )

        number_of_plots_in_row = 9
        # At least 5 rows (the layout used so far); more rows are added when N_CELLS_Z does not fit into 5 x 9
        # panels, so that indexing the grid below cannot go out of range. -(-a // b) is ceil(a / b) for integers.
        number_of_plots_in_column = max(
            5, -(-N_CELLS_Z // number_of_plots_in_row))

        bins = np.linspace(np.min(full_simulation_energy_per_layer - 10),
                           np.max(full_simulation_energy_per_layer + 10), 25)

        fig, ax = plt.subplots(number_of_plots_in_column,
                               number_of_plots_in_row,
                               figsize=(20, 15),
                               sharex="all",
                               sharey="all",
                               constrained_layout=True)

        for layer_nb in range(N_CELLS_Z):
            # Row and column of the panel that belongs to this layer.
            i, j = divmod(layer_nb, number_of_plots_in_row)
            panel = ax[i][j]

            # Column layer_nb of the 2D arrays is histogrammed for the given layer.
            panel.hist(full_simulation_energy_per_layer[:, layer_nb],
                       histtype=HISTOGRAM_TYPE,
                       label="FullSim",
                       bins=bins,
                       color=FULL_SIM_HISTOGRAM_COLOR)
            panel.hist(ml_simulation_energy_per_layer[:, layer_nb],
                       histtype=HISTOGRAM_TYPE,
                       label="MLSim",
                       bins=bins,
                       color=ML_SIM_HISTOGRAM_COLOR)
            panel.set_title(f"Layer {layer_nb}", fontsize=13)
            panel.set_yscale("log")
            panel.tick_params(axis='both', which='major', labelsize=10)

        fig.supxlabel("Energy [MeV]", fontsize=14)
        fig.supylabel("# entries", fontsize=14)
        # Raw f-string: "\c" is not a valid Python escape sequence, the backslash is meant for mathtext.
        fig.suptitle(
            rf" $e^-$, {self._particle_energy} [GeV], {self._particle_angle}$^{{\circ}}$, {self._geometry} "
        )

        # Take legend from one plot and make it a global legend.
        handles, labels = ax[0][0].get_legend_handles_labels()
        fig.legend(handles, labels, bbox_to_anchor=(1.15, 0.5))

        fig.savefig(
            f"{VALID_DIR}/E_layer_Geo_{self._geometry}_E_{self._particle_energy}_Angle_{self._particle_angle}.png",
            bbox_inches="tight")
        # Close the figure instead of only clearing it; otherwise every call leaves one more open figure behind.
        plt.close(fig)

    def plot_and_save(self) -> None:
        """ Main plotting function.

        Calls private methods and prints the information about progress.

        Returns:
            None.

        """
        print("Plotting total energy...")
        self._plot_total_energy()
        print("Plotting cell energy...")
        self._plot_cell_energy()
        print("Plotting energy per layer...")
        self._plot_energy_per_layer()