Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-10-13 08:28:38

0001 from argparse import ArgumentParser
0002 
0003 from core.constants import GPU_IDS, MAX_GPU_MEMORY_ALLOCATION, GLOBAL_CHECKPOINT_DIR
0004 from utils.gpu_limiter import GPULimiter
0005 from utils.preprocess import preprocess
0006 
0007 
def parse_args():
    """Parse the command-line options of a single training run.

    Recognised options (values given on the command line are converted with the
    listed type; otherwise the listed default is used):

    * ``--max-gpu-memory-allocation`` (int, default ``MAX_GPU_MEMORY_ALLOCATION``)
    * ``--gpu-ids`` (str, default ``GPU_IDS``)
    * ``--study-name`` (str, default ``"default_study_name"``)
    * ``--run-name`` (str, default ``None``)

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    cli = ArgumentParser()

    # One row per option: (flag, value type, default).
    option_specs = (
        ("--max-gpu-memory-allocation", int, MAX_GPU_MEMORY_ALLOCATION),
        ("--gpu-ids", str, GPU_IDS),
        ("--study-name", str, "default_study_name"),
        # A run name left as None is randomly chosen by wandb.
        ("--run-name", str, None),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)

    return cli.parse_args()
0016 
0017 
def main():
    """Run one VAE training: parse options, cap GPU memory, preprocess data, train.

    Checkpoints are written below ``GLOBAL_CHECKPOINT_DIR/<study name>/<run name>``.
    When no run name was supplied, a timestamped directory name is used instead so
    that unnamed runs do not all share (and overwrite) a literal ``None`` directory.

    Returns:
        None.
    """
    # Local import: only needed for the fallback checkpoint directory name.
    from datetime import datetime

    # 0. Parse arguments.
    args = parse_args()
    max_gpu_memory_allocation = args.max_gpu_memory_allocation
    gpu_ids = args.gpu_ids
    print(f"Running on GPU ID {gpu_ids}")
    study_name = args.study_name
    run_name = args.run_name
    # run_name is None unless --run-name was given. Formatting None into the path
    # would yield ".../None" for every unnamed run, so fall back to a timestamp.
    # run_name itself stays None so that wandb still picks the run name.
    if run_name is not None:
        checkpoint_subdir = run_name
    else:
        checkpoint_subdir = datetime.now().strftime("run_%Y%m%d_%H%M%S")
    checkpoint_dir = f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{checkpoint_subdir}"

    # 1. Set GPU memory limits.
    GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memory_allocation=max_gpu_memory_allocation)()

    # 2. Data loading/preprocessing

    # The preprocess function reads the data and performs preprocessing and encoding for the values of energy,
    # angle and geometry
    energies_train, cond_e_train, cond_angle_train, cond_geo_train = preprocess()

    # 3. Manufacture model handler.

    # This import must be local because otherwise it is impossible to call GPULimiter.
    from core.model import VAEHandler
    vae = VAEHandler(
        _wandb_project_name=study_name,
        _wandb_run_name=run_name,
        _wandb_tags=["single training"],
        _checkpoint_dir=checkpoint_dir,
    )

    # 4. Train model.
    histories = vae.train(energies_train,
                          cond_e_train,
                          cond_angle_train,
                          cond_geo_train
                          )

    # Note: one history object can be used to plot the loss evolution as a function of the epochs. Remember that
    # train() returns a list of those objects, each of them representing a different fold of cross validation.
0052 
0053 
if __name__ == "__main__":
    # `exit` is a helper injected by the `site` module for interactive sessions and
    # is not guaranteed to exist (e.g. under `python -S`). Raising SystemExit gives
    # the same exit behaviour without relying on it.
    raise SystemExit(main())