File indexing completed on 2025-02-23 09:22:36
0001 """
0002 ** generate **
0003 generate showers using a saved VAE model
0004 """
0005 import argparse
0006
0007 import numpy as np
0008 import tensorflow as tf
0009 from tensorflow.python.data import Dataset
0010
0011 from core.constants import GLOBAL_CHECKPOINT_DIR, GEN_DIR, BATCH_SIZE_PER_REPLICA, MAX_GPU_MEMORY_ALLOCATION, GPU_IDS
0012 from utils.gpu_limiter import GPULimiter
0013 from utils.preprocess import get_condition_arrays
0014
0015
def parse_args():
    """Parse the command-line options of the shower-generation script.

    ``--energy`` and ``--angle`` are integer options with no usable default,
    so they are required. A ``default=""`` cannot be converted by ``int`` and
    would only make argparse fail with "invalid int value: ''".

    Returns:
        argparse.Namespace with attributes ``geometry``, ``energy``, ``angle``,
        ``events``, ``epoch``, ``study_name``, ``max_gpu_memory_allocation``
        and ``gpu_ids``.
    """
    argument_parser = argparse.ArgumentParser(description="Generate showers using a saved VAE model.")
    argument_parser.add_argument("--geometry", type=str, default="",
                                 help="Geometry name passed to get_condition_arrays and used in the output file name.")
    argument_parser.add_argument("--energy", type=int, required=True,
                                 help="Energy condition; also scales the generated output and names the output file.")
    argument_parser.add_argument("--angle", type=int, required=True,
                                 help="Angle value used in the output file name.")
    argument_parser.add_argument("--events", type=int, default=10000,
                                 help="Number of showers to generate.")
    argument_parser.add_argument("--epoch", type=int, default=None,
                                 help="Checkpoint epoch to load; the 'VAE_best' checkpoint is used when omitted.")
    argument_parser.add_argument("--study-name", type=str, default="default_study_name",
                                 help="Sub-directory of the checkpoint directory that holds the model weights.")
    # Defaults come from core.constants and are looked up when the function is called.
    argument_parser.add_argument("--max-gpu-memory-allocation", type=int, default=MAX_GPU_MEMORY_ALLOCATION,
                                 help="Passed to GPULimiter as _max_gpu_memory_allocation.")
    argument_parser.add_argument("--gpu-ids", type=str, default=GPU_IDS,
                                 help="Passed to GPULimiter as _gpu_ids.")
    return argument_parser.parse_args()
0028
0029
0030
def main():
    """Generate showers with the decoder of a saved VAE and save them as a .npy file.

    Steps:
      0. parse the command-line options;
      1. configure GPU limits via GPULimiter;
      2. build the VAE and load its saved weights;
      3. build conditioning arrays, sample latent vectors from a standard
         normal prior and wrap everything in a batched tf.data pipeline;
      4. decode with the VAE decoder and rescale by ``energy * 1000``;
      5. save the result under ``GEN_DIR``.
    """
    import os

    # 0. Parse arguments.
    args = parse_args()
    energy = args.energy
    angle = args.angle
    geometry = args.geometry
    events = args.events
    epoch = args.epoch
    study_name = args.study_name

    # 1. Set GPU memory limits before any model is built.
    GPULimiter(_gpu_ids=args.gpu_ids, _max_gpu_memory_allocation=args.max_gpu_memory_allocation)()

    # 2. Build the model and load the saved weights. The import is local so that
    # core.model is imported only after GPULimiter has run (presumably because
    # importing it initializes the GPU — confirm before moving it to the top).
    from core.model import VAEHandler
    vae = VAEHandler()

    # Without --epoch, fall back to the "best" checkpoint directory.
    weights_dir = f"VAE_epoch_{epoch:03}" if epoch is not None else "VAE_best"
    # expect_partial() silences warnings about checkpoint values that are not
    # restored (only the model weights are needed for inference).
    vae.model.load_weights(f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{weights_dir}/model_weights").expect_partial()

    # Only the decoder is needed to generate new samples.
    generator = vae.model.decoder

    # 3. Conditions and latent samples.
    # NOTE(review): `angle` is not passed to get_condition_arrays; here it is only
    # used in the output file name. Confirm that angle_cond matches --angle.
    e_cond, angle_cond, geo_cond = get_condition_arrays(geometry, energy, events)
    # One latent vector per event, drawn from the standard normal prior.
    z_r = np.random.normal(loc=0, scale=1, size=(events, vae.latent_dim))

    # The four inputs are wrapped in a 1-tuple so that Keras predict() unpacks
    # each element as `x` only (no targets / sample weights).
    data = tf.data.Dataset.from_tensor_slices(((z_r, e_cond, angle_cond, geo_cond),))
    data = data.batch(BATCH_SIZE_PER_REPLICA)

    # Disable tf.data auto-sharding for this pipeline.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    data = data.with_options(options)

    # 4. Decode and rescale.
    # NOTE(review): presumably the decoder output is normalized by the particle
    # energy and the factor 1000 is a unit conversion (e.g. GeV -> MeV) — confirm.
    generated_events = generator.predict(data) * (energy * 1000)

    # 5. Save. Create the output directory first so np.save does not fail when it is missing.
    os.makedirs(GEN_DIR, exist_ok=True)
    np.save(f"{GEN_DIR}/VAE_Generated_Geo_{geometry}_E_{energy}_Angle_{angle}.npy", generated_events)
0084
0085
if __name__ == "__main__":
    # Script entry point. SystemExit is used instead of the interactive-only
    # `exit()` helper from the site module; main() returns None, so the exit
    # status is 0 on success.
    raise SystemExit(main())