Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-23 09:22:36

0001 import argparse
0002 
0003 import numpy as np
0004 
0005 from core.constants import INIT_DIR, GEN_DIR, N_CELLS_PHI, N_CELLS_R, N_CELLS_Z
0006 from utils.observables import LongitudinalProfile, LateralProfile, Energy
0007 from utils.plotters import ProfilePlotter, EnergyPlotter
0008 from utils.preprocess import load_showers
0009 
0010 
def parse_args(argv=None):
    """Parse the command-line options that select one shower configuration.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). Defaults to ``None``, which keeps
            the usual argparse behaviour; passing a list is mainly useful for
            testing.

    Returns:
        argparse.Namespace with attributes ``geometry`` (str, defaults to
        ``""``), ``energy`` (int) and ``angle`` (int).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--geometry", type=str, default="",
                   help="Geometry identifier used to select the input files.")
    # --energy / --angle have no meaningful default. The previous
    # ``default=""`` was passed through ``int`` by argparse whenever the option
    # was omitted, failing with a cryptic "invalid int value: ''" error, so the
    # options are declared required to give a clear message instead.
    p.add_argument("--energy", type=int, required=True,
                   help="Particle energy used to select the input files.")
    p.add_argument("--angle", type=int, required=True,
                   help="Particle angle used to select the input files.")
    return p.parse_args(argv)
0018 
0019 
def _load_vae_showers(geometry, particle_energy, particle_angle):
    """Load the VAE-generated showers for one configuration, reshaped to 3D.

    The file is read from ``GEN_DIR`` using the naming scheme
    ``VAE_Generated_Geo_{geometry}_E_{energy}_Angle_{angle}.npy``.

    Returns:
        numpy array of shape ``(n_events, N_CELLS_R, N_CELLS_PHI, N_CELLS_Z)``,
        where ``n_events`` is the length of the stored array's first axis.
    """
    vae_energies = np.load(f"{GEN_DIR}/VAE_Generated_Geo_{geometry}_E_{particle_energy}_Angle_{particle_angle}.npy")
    # Each event must contain exactly N_CELLS_R * N_CELLS_PHI * N_CELLS_Z
    # values; an explicit event count (rather than -1) keeps a size mismatch
    # as a loud reshape error instead of silently changing the event count.
    return vae_energies.reshape((len(vae_energies), N_CELLS_R, N_CELLS_PHI, N_CELLS_Z))


def main():
    """Compare full-simulation and VAE-generated showers for one configuration.

    Reads ``--geometry``, ``--energy`` and ``--angle`` from the command line,
    loads the matching full-simulation showers (``load_showers`` on
    ``INIT_DIR``) and VAE-generated showers (from ``GEN_DIR``), builds
    longitudinal-profile, lateral-profile and energy observables for both
    samples, and saves comparison plots via ``ProfilePlotter`` and
    ``EnergyPlotter``.

    Returns:
        0, intended as the process exit status.
    """
    args = parse_args()
    particle_energy = args.energy
    particle_angle = args.angle
    geometry = args.geometry

    # 1. Full simulation: load showers for a single geometry, energy and angle.
    e_layer_g4 = load_showers(INIT_DIR, geometry, particle_energy,
                              particle_angle)
    # 2. Fast (VAE) simulation: load and reshape into 3D events.
    # NOTE(review): an earlier comment mentioned scaling to the original energy
    # range, but no scaling is performed here — confirm the generated file is
    # already in the same units as the full-simulation sample.
    e_layer_vae = _load_vae_showers(geometry, particle_energy, particle_angle)

    print("Data has been loaded.")

    # 3. Build the same set of observables for both samples.
    full_sim_long = LongitudinalProfile(_input=e_layer_g4)
    full_sim_lat = LateralProfile(_input=e_layer_g4)
    full_sim_energy = Energy(_input=e_layer_g4)
    ml_sim_long = LongitudinalProfile(_input=e_layer_vae)
    ml_sim_lat = LateralProfile(_input=e_layer_vae)
    ml_sim_energy = Energy(_input=e_layer_vae)

    print("Created observables.")

    # 4. Plot full-simulation vs. VAE observables and save the figures.
    longitudinal_profile_plotter = ProfilePlotter(particle_energy, particle_angle, geometry,
                                                  full_sim_long, ml_sim_long,
                                                  _plot_gaussian=False)
    lateral_profile_plotter = ProfilePlotter(particle_energy, particle_angle, geometry,
                                             full_sim_lat, ml_sim_lat,
                                             _plot_gaussian=False)
    energy_plotter = EnergyPlotter(particle_energy, particle_angle, geometry,
                                   full_sim_energy, ml_sim_energy)

    longitudinal_profile_plotter.plot_and_save()
    lateral_profile_plotter.plot_and_save()
    energy_plotter.plot_and_save()
    print("Done.")
    return 0
0059 
0060 
if __name__ == "__main__":
    import sys

    # Use sys.exit rather than the ``exit`` builtin: ``exit`` is injected by
    # the ``site`` module for interactive use and is not guaranteed to exist
    # (e.g. under ``python -S`` or in frozen executables).
    sys.exit(main())