Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-23 09:22:36

0001 """
0002 ** generate **
0003 generate showers using a saved VAE model 
0004 """
0005 import argparse
0006 
0007 import numpy as np
0008 import tensorflow as tf
0009 from tensorflow.python.data import Dataset
0010 
0011 from core.constants import GLOBAL_CHECKPOINT_DIR, GEN_DIR, BATCH_SIZE_PER_REPLICA, MAX_GPU_MEMORY_ALLOCATION, GPU_IDS
0012 from utils.gpu_limiter import GPULimiter
0013 from utils.preprocess import get_condition_arrays
0014 
0015 
def parse_args(argv=None):
    """Parse the command-line options of the shower generation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None`` (use the real command line),
            so existing callers are unaffected.

    Returns:
        An ``argparse.Namespace`` with the attributes ``geometry``, ``energy``,
        ``angle``, ``events``, ``epoch``, ``study_name``,
        ``max_gpu_memory_allocation`` and ``gpu_ids``.
    """
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument("--geometry", type=str, default="",
                                 help="Geometry condition passed to get_condition_arrays.")
    # --energy and --angle have no usable default. Previously they used default="",
    # which argparse feeds through int() when the flag is omitted, aborting with the
    # misleading "invalid int value: ''". required=True keeps them mandatory but
    # reports what is actually wrong.
    argument_parser.add_argument("--energy", type=int, required=True,
                                 help="Energy condition; the decoder output is also scaled by energy * 1000.")
    argument_parser.add_argument("--angle", type=int, required=True,
                                 help="Angle value used in the output file name.")
    argument_parser.add_argument("--events", type=int, default=10000,
                                 help="Number of showers to generate.")
    argument_parser.add_argument("--epoch", type=int, default=None,
                                 help="Epoch of the checkpoint to load; the best checkpoint is used if omitted.")
    argument_parser.add_argument("--study-name", type=str, default="default_study_name")
    argument_parser.add_argument("--max-gpu-memory-allocation", type=int, default=MAX_GPU_MEMORY_ALLOCATION)
    argument_parser.add_argument("--gpu-ids", type=str, default=GPU_IDS)
    return argument_parser.parse_args(argv)
0028 
0029 
0030 # main function
def main():
    """Generate showers with the decoder of a saved VAE and save them as ``.npy``.

    Steps:
        0. Parse the command-line arguments.
        1. Apply the GPU memory limits.
        2. Build the VAE and load its weights. A specific epoch is loaded if
           ``--epoch`` is given, otherwise the ``VAE_best`` checkpoint.
        3. Sample ``events`` latent vectors from a standard normal prior and
           pair them with the condition arrays.
        4. Decode them and scale the output by ``energy * 1000``.
        5. Write the result to
           ``{GEN_DIR}/VAE_Generated_Geo_{geometry}_E_{energy}_Angle_{angle}.npy``,
           creating ``GEN_DIR`` if necessary.
    """
    import os

    # 0. Parse arguments.
    args = parse_args()
    energy = args.energy
    angle = args.angle
    geometry = args.geometry
    events = args.events
    epoch = args.epoch
    study_name = args.study_name
    max_gpu_memory_allocation = args.max_gpu_memory_allocation
    gpu_ids = args.gpu_ids

    # 1. Set GPU memory limits.
    GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memory_allocation=max_gpu_memory_allocation)()

    # 2. Load a saved model.

    # Create a handler and build model.
    # This import must be local because otherwise it is impossible to call GPULimiter.
    from core.model import VAEHandler
    vae = VAEHandler()

    # Load the saved weights: a specific epoch if requested, otherwise the best checkpoint.
    weights_dir = f"VAE_epoch_{epoch:03}" if epoch is not None else "VAE_best"
    # expect_partial() silences warnings about checkpoint values (e.g. optimizer
    # state) that are not restored into this model.
    vae.model.load_weights(f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{weights_dir}/model_weights").expect_partial()

    # The generator is defined as the decoder part only
    generator = vae.model.decoder

    # 3. Prepare data. Get condition values. Sample from the prior (normal distribution) in d dimension (d=latent_dim,
    # latent space dimension). Gather them into tuples. Wrap data in Dataset objects. The batch size must now be set
    # on the Dataset objects. Disable AutoShard.
    # NOTE(review): `angle` is parsed but not passed here; it only appears in the
    # output file name. Confirm that get_condition_arrays derives angle_cond itself.
    e_cond, angle_cond, geo_cond = get_condition_arrays(geometry, energy, events)

    z_r = np.random.normal(loc=0, scale=1, size=(events, vae.latent_dim))

    # Use the public tf.data API rather than the private tensorflow.python.data module.
    data = tf.data.Dataset.from_tensor_slices(((z_r, e_cond, angle_cond, geo_cond),))
    data = data.batch(BATCH_SIZE_PER_REPLICA)

    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    data = data.with_options(options)

    # 4. Generate showers using the VAE model.
    # NOTE(review): the factor 1000 presumably converts the energy unit (e.g. GeV -> MeV); confirm.
    generated_events = generator.predict(data) * (energy * 1000)

    # 5. Save the generated showers. np.save does not create missing directories.
    os.makedirs(GEN_DIR, exist_ok=True)
    np.save(f"{GEN_DIR}/VAE_Generated_Geo_{geometry}_E_{energy}_Angle_{angle}.npy", generated_events)
0084 
0085 
if __name__ == "__main__":
    # `exit` is a site-module helper meant for the interactive shell (missing under
    # `python -S`). Raising SystemExit is what sys.exit does and needs no import.
    raise SystemExit(main())