EIC code displayed by LXR

from enum import IntEnum

from tensorflow.keras.optimizers import Optimizer, Adadelta, Adagrad, Adam, Adamax, Ftrl, SGD, Nadam, RMSprop


class OptimizerType(IntEnum):
0007     """ Enum class of various optimizer types.
0008 
0009     This class must be IntEnum to be JSON serializable. This feature is important because, when Optuna's study is
0010     saved in a relational DB, all objects must be JSON serializable.
0011     """

    SGD = 0
    RMSPROP = 1
    ADAM = 2
    ADADELTA = 3
    ADAGRAD = 4
    ADAMAX = 5
    NADAM = 6
    FTRL = 7


class OptimizerFactory:
0024     """Factory of optimizer like Stochastic Gradient Descent, RMSProp, Adam, etc.
0025     """

    @staticmethod
    def create_optimizer(optimizer_type: OptimizerType, learning_rate: float) -> Optimizer:
0029         """For a given type and a learning rate creates an instance of optimizer.
0030 
0031         Args:
0032             optimizer_type: a type of optimizer
0033             learning_rate: a learning rate that should be passed to an optimizer
0034 
0035         Returns:
0036             An instance of optimizer.
0037 
0038         """
        if optimizer_type == OptimizerType.SGD:
            return SGD(learning_rate)
        elif optimizer_type == OptimizerType.RMSPROP:
            return RMSprop(learning_rate)
        elif optimizer_type == OptimizerType.ADAM:
            return Adam(learning_rate)
        elif optimizer_type == OptimizerType.ADADELTA:
            return Adadelta(learning_rate)
        elif optimizer_type == OptimizerType.ADAGRAD:
            return Adagrad(learning_rate)
        elif optimizer_type == OptimizerType.ADAMAX:
            return Adamax(learning_rate)
        elif optimizer_type == OptimizerType.NADAM:
            return Nadam(learning_rate)
        else:
            # i.e. optimizer_type == OptimizerType.FTRL
            return Ftrl(learning_rate)
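
A minimal usage sketch (not part of the file above): it shows how the factory and the IntEnum are typically combined with Optuna, which the enum docstring alludes to. The objective function, parameter names, and search ranges below are illustrative assumptions, not taken from this repository.

import json

# IntEnum members serialize as plain integers, which is why OptimizerType is an IntEnum.
json.dumps({"optimizer_type": OptimizerType.ADAM})  # -> '{"optimizer_type": 2}'

def objective(trial):
    # Hypothetical Optuna objective: sample an optimizer type and a learning rate,
    # then build the optimizer through the factory.
    optimizer_type = OptimizerType(trial.suggest_int("optimizer_type", 0, len(OptimizerType) - 1))
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    optimizer = OptimizerFactory.create_optimizer(optimizer_type, learning_rate)
    # ... compile and train a Keras model with `optimizer` here, then return the validation metric ...
    return 0.0  # placeholder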