File indexing completed on 2025-02-23 09:22:36
0001 from enum import IntEnum
0002
0003 from tensorflow.keras.optimizers import Optimizer, Adadelta, Adagrad, Adam, Adamax, Ftrl, SGD, Nadam, RMSprop
0004
0005
class OptimizerType(IntEnum):
    """Identifiers for the optimizer kinds that the factory can build.

    The members subclass ``IntEnum``, so every value is an ordinary integer
    and can be written to JSON without a custom encoder. That matters when an
    Optuna study is stored in a relational database, because every object
    persisted there has to be JSON serializable.
    """

    # NOTE: the integer values may end up persisted (e.g. in Optuna storage),
    # so existing members should keep their numbers; append new ones at the end.
    SGD = 0
    RMSPROP = 1
    ADAM = 2
    ADADELTA = 3
    ADAGRAD = 4
    ADAMAX = 5
    NADAM = 6
    FTRL = 7
0021
0022
class OptimizerFactory:
    """Factory of optimizers such as Stochastic Gradient Descent, RMSProp, Adam, etc."""

    # Maps each OptimizerType member to the Keras optimizer class it selects.
    # IntEnum members hash and compare like their integer values, so lookups
    # also succeed for plain ints (e.g. values deserialized from JSON storage).
    _OPTIMIZER_CLASSES = {
        OptimizerType.SGD: SGD,
        OptimizerType.RMSPROP: RMSprop,
        OptimizerType.ADAM: Adam,
        OptimizerType.ADADELTA: Adadelta,
        OptimizerType.ADAGRAD: Adagrad,
        OptimizerType.ADAMAX: Adamax,
        OptimizerType.NADAM: Nadam,
        OptimizerType.FTRL: Ftrl,
    }

    @staticmethod
    def create_optimizer(optimizer_type: OptimizerType, learning_rate: float) -> Optimizer:
        """For a given type and a learning rate, create an instance of an optimizer.

        Args:
            optimizer_type: the kind of optimizer to create. A plain ``int`` equal
                to one of the ``OptimizerType`` values is accepted as well.
            learning_rate: the learning rate passed to the optimizer.

        Returns:
            A new instance of the selected Keras optimizer.

        Raises:
            ValueError: if ``optimizer_type`` does not correspond to a known
                ``OptimizerType``.
        """
        try:
            optimizer_class = OptimizerFactory._OPTIMIZER_CLASSES[optimizer_type]
        except (KeyError, TypeError):
            # Reject unknown (or unhashable) values explicitly instead of
            # silently falling back to some default optimizer.
            raise ValueError(f"Unknown optimizer type: {optimizer_type!r}") from None
        # Keyword argument so the call does not depend on positional order.
        return optimizer_class(learning_rate=learning_rate)
0055 return Ftrl(learning_rate)