EIC code displayed by LXR

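The script below is driven entirely by the command-line options defined in its parse_args() function; --storage points at the database that holds the study (see the warning at the end of main()). A purely illustrative invocation might look like the following, where the script name tune_model.py and the SQLite URL are assumptions, not taken from this page:

    python tune_model.py --study-name my_study --storage sqlite:///my_study.db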
from argparse import ArgumentParser

from core.constants import MAX_GPU_MEMORY_ALLOCATION, GPU_IDS
from utils.gpu_limiter import GPULimiter
from utils.optimizer import OptimizerType

# Hyperparameters to be optimized.
discrete_parameters = {"nb_hidden_layers": (1, 6), "latent_dim": (15, 100)}
continuous_parameters = {"learning_rate": (0.0001, 0.005)}
categorical_parameters = {"optimizer_type": [OptimizerType.ADAM, OptimizerType.RMSPROP]}


def parse_args():
    argument_parser = ArgumentParser()
    argument_parser.add_argument("--study-name", type=str, default="default_study_name")
    argument_parser.add_argument("--storage", type=str)
    argument_parser.add_argument("--max-gpu-memory-allocation", type=int, default=MAX_GPU_MEMORY_ALLOCATION)
    argument_parser.add_argument("--gpu-ids", type=str, default=GPU_IDS)
    args = argument_parser.parse_args()
    return args


def main():
    # 0. Parse arguments.
    args = parse_args()
    study_name = args.study_name
    storage = args.storage
    max_gpu_memory_allocation = args.max_gpu_memory_allocation
    gpu_ids = args.gpu_ids

    # 1. Set GPU memory limits.
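    # Note that the limiter is applied immediately: GPULimiter(...) constructs the helper object and
    # the trailing () invokes it, putting the requested GPU-ID and memory-allocation limits in place.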
    GPULimiter(_gpu_ids=gpu_ids, _max_gpu_memory_allocation=max_gpu_memory_allocation)()

    # 2. Manufacture hyperparameter tuner.

    # This import must be local because otherwise it would be impossible to apply GPULimiter first.
    from utils.hyperparameter_tuner import HyperparameterTuner
    hyperparameter_tuner = HyperparameterTuner(discrete_parameters, continuous_parameters, categorical_parameters,
                                               storage, study_name)

    # 3. Run main tuning function.
    hyperparameter_tuner.tune()
    # Watch out! This script deletes neither the study in the database nor the database itself.
    # If you are running parallelized optimization, you have to delete the study from the database yourself.


if __name__ == "__main__":
    exit(main())
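HyperparameterTuner itself lives in utils/hyperparameter_tuner.py and is not shown on this page, so how the three parameter dictionaries are consumed is not visible here. The study/storage vocabulary is reminiscent of Optuna, and the sketch below is only a guess at how such dictionaries could be mapped onto an Optuna objective; the use of Optuna, the build_and_train() placeholder, the trial count, and the storage URL are all assumptions, and the three dictionaries from the top of the script are taken as already in scope.

# Illustrative sketch only, not part of the file above: a possible Optuna-style mapping of the
# discrete / continuous / categorical dictionaries onto trial suggestions.
import optuna


def build_and_train(**params) -> float:
    # Placeholder: a real objective would build the network described by `params`, train it,
    # and return a validation metric to be minimized.
    return 0.0


def objective(trial: optuna.Trial) -> float:
    params = {}
    # Discrete parameters become integer suggestions over the given (low, high) range.
    for name, (low, high) in discrete_parameters.items():
        params[name] = trial.suggest_int(name, low, high)
    # Continuous parameters become float suggestions over the given (low, high) range.
    for name, (low, high) in continuous_parameters.items():
        params[name] = trial.suggest_float(name, low, high)
    # Categorical parameters become a choice from the given list (Optuna prefers primitive
    # choice values; enum members such as OptimizerType.ADAM work but may emit a warning).
    for name, choices in categorical_parameters.items():
        params[name] = trial.suggest_categorical(name, choices)
    return build_and_train(**params)


study = optuna.create_study(study_name="default_study_name",
                            storage="sqlite:///my_study.db",  # illustrative storage URL
                            load_if_exists=True)
study.optimize(objective, n_trials=50)  # trial budget chosen arbitrarily for the example

# The script above leaves the study in the database (see its final warning); cleanup would be:
optuna.delete_study(study_name="default_study_name", storage="sqlite:///my_study.db")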