File indexing completed on 2025-10-13 08:28:38
0001 from argparse import ArgumentParser
0002
0003 from core.constants import GPU_IDS, MAX_GPU_MEMORY_ALLOCATION, GLOBAL_CHECKPOINT_DIR
0004 from utils.gpu_limiter import GPULimiter
0005 from utils.preprocess import preprocess
0006
0007
def parse_args():
    """Parse the command-line options of this training script.

    Recognised options (all optional):
        --max-gpu-memory-allocation  int, default MAX_GPU_MEMORY_ALLOCATION
        --gpu-ids                    str, default GPU_IDS
        --study-name                 str, default "default_study_name"
        --run-name                   str, default None

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``, with the
        attributes ``max_gpu_memory_allocation``, ``gpu_ids``, ``study_name``
        and ``run_name``.
    """
    parser = ArgumentParser()

    # (flag, value type, default) -- registered in this order.
    option_table = (
        ("--max-gpu-memory-allocation", int, MAX_GPU_MEMORY_ALLOCATION),
        ("--gpu-ids", str, GPU_IDS),
        ("--study-name", str, "default_study_name"),
        ("--run-name", str, None),
    )
    for flag, value_type, fallback in option_table:
        parser.add_argument(flag, type=value_type, default=fallback)

    return parser.parse_args()
0016
0017
def main():
    """Run a single VAE training configured from the command line.

    Steps, in order: parse the CLI options, apply the GPU limits, load the
    preprocessed training arrays, build a ``VAEHandler`` and train it.

    Returns:
        None. The value returned by ``VAEHandler.train`` is not used.
    """
    # 1. Read the run configuration from the command line.
    args = parse_args()
    max_gpu_memory_allocation = args.max_gpu_memory_allocation
    gpu_ids = args.gpu_ids
    print(f"Running on GPU ID {gpu_ids}")
    study_name = args.study_name
    run_name = args.run_name
    # NOTE(review): --run-name defaults to None, in which case the last path
    # component below is the literal string "None" -- confirm this is intended.
    checkpoint_dir = f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{run_name}"

    # 2. Apply the GPU limits: the limiter object is constructed and then
    #    immediately called.
    GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memory_allocation=max_gpu_memory_allocation)()

    # 3. Load the training inputs (four arrays returned by preprocess()).
    energies_train, cond_e_train, cond_angle_train, cond_geo_train = preprocess()

    # 4. Build the model handler. The import is deliberately kept inside the
    #    function, after the GPU limiter has run -- presumably so the limits
    #    are in place before the model module is loaded; TODO confirm.
    from core.model import VAEHandler
    vae = VAEHandler(
        _wandb_project_name=study_name,
        _wandb_run_name=run_name,
        _wandb_tags=["single training"],
        _checkpoint_dir=checkpoint_dir,
    )

    # 5. Train. The result was previously bound to an unused local
    #    (`histories`); nothing reads it, so it is simply discarded.
    vae.train(
        energies_train,
        cond_e_train,
        cond_angle_train,
        cond_geo_train,
    )
0049
0050
0051
0052
0053
if __name__ == "__main__":
    # The builtin `exit` is injected by the `site` module for interactive
    # sessions and is absent when Python runs with `-S`. Raising SystemExit
    # directly gives the same exit status (main() returns None -> status 0)
    # without relying on it.
    raise SystemExit(main())