import gc
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import numpy as np
import tensorflow as tf
import wandb
from sklearn.model_selection import KFold
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, History, Callback
from tensorflow.keras.layers import BatchNormalization, Input, Dense, Layer, concatenate
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
from tensorflow.keras.models import Model
from tensorflow.python.data import Dataset
from tensorflow.python.distribute.distribute_lib import Strategy
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
from wandb.keras import WandbCallback

from core.constants import ORIGINAL_DIM, LATENT_DIM, BATCH_SIZE_PER_REPLICA, EPOCHS, LEARNING_RATE, ACTIVATION, \
    OUT_ACTIVATION, OPTIMIZER_TYPE, KERNEL_INITIALIZER, GLOBAL_CHECKPOINT_DIR, EARLY_STOP, BIAS_INITIALIZER, \
    INTERMEDIATE_DIMS, SAVE_MODEL_EVERY_EPOCH, SAVE_BEST_MODEL, PATIENCE, MIN_DELTA, BEST_MODEL_FILENAME, \
    NUMBER_OF_K_FOLD_SPLITS, VALIDATION_SPLIT, WANDB_ENTITY
from utils.optimizer import OptimizerFactory, OptimizerType


class _Sampling(Layer):
0027 """ Custom layer to do the reparameterization trick: sample random latent vectors z from the latent Gaussian
0028 distribution.
0029
0030 The sampled vector z is given by sampled_z = mean + std * epsilon
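
    Example (illustrative only, assuming eager execution): with z_log_var = 0 the standard deviation is
    exp(0) = 1, so the sampled vector reduces to z_mean + epsilon.

    >>> z_mean, z_log_var = tf.zeros((1, 2)), tf.zeros((1, 2))
    >>> epsilon = tf.ones((1, 2))
    >>> _Sampling()([z_mean, z_log_var, epsilon]).numpy()
    array([[1., 1.]], dtype=float32)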
0031 """
0032
0033 def __call__(self, inputs, **kwargs):
0034 z_mean, z_log_var, epsilon = inputs
0035 z_sigma = K.exp(0.5 * z_log_var)
0036 return z_mean + z_sigma * epsilon
0037
0038
0039
0040 class _KLDivergenceLayer(Layer):
0041
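
    Example (illustrative only, assuming eager execution): when the posterior equals the prior, the added
    loss is zero.

    >>> layer = _KLDivergenceLayer()
    >>> _ = layer([tf.zeros((1, 3)), tf.zeros((1, 3))])
    >>> float(layer.losses[0])
    0.0
    """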

    def call(self, inputs, **kwargs):
        mu, log_var = inputs
        kl_loss = -0.5 * (1 + log_var - K.square(mu) - K.exp(log_var))
        kl_loss = K.mean(K.sum(kl_loss, axis=-1))
        self.add_loss(kl_loss)
        return inputs


class VAE(Model):
    """Conditional variational autoencoder assembled from an encoder and a decoder built by VAEHandler."""

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self._set_inputs(inputs=self.encoder.inputs, outputs=self(self.encoder.inputs))

    def call(self, inputs, training=None, mask=None):
        # Inputs: (shower, energy, angle, geometry, epsilon). The conditional values are forwarded
        # to the decoder together with the latent vector produced by the encoder.
        _, e_input, angle_input, geo_input, _ = inputs
        z = self.encoder(inputs)
        return self.decoder([z, e_input, angle_input, geo_input])

    def get_config(self):
        config = super().get_config()
        config["encoder"] = self.encoder
        config["decoder"] = self.decoder
        return config


@dataclass
class VAEHandler:
    """
    Class to handle building and training VAE models.
    """
    _wandb_project_name: Optional[str] = None
    _wandb_tags: List[str] = field(default_factory=list)
    _original_dim: int = ORIGINAL_DIM
    latent_dim: int = LATENT_DIM
    _batch_size_per_replica: int = BATCH_SIZE_PER_REPLICA
    _intermediate_dims: List[int] = field(default_factory=lambda: INTERMEDIATE_DIMS)
    _learning_rate: float = LEARNING_RATE
    _epochs: int = EPOCHS
    _activation: str = ACTIVATION
    _out_activation: str = OUT_ACTIVATION
    _number_of_k_fold_splits: int = NUMBER_OF_K_FOLD_SPLITS
    _optimizer_type: OptimizerType = OPTIMIZER_TYPE
    _kernel_initializer: str = KERNEL_INITIALIZER
    _bias_initializer: str = BIAS_INITIALIZER
    _checkpoint_dir: str = GLOBAL_CHECKPOINT_DIR
    _early_stop: bool = EARLY_STOP
    _save_model_every_epoch: bool = SAVE_MODEL_EVERY_EPOCH
    _save_best_model: bool = SAVE_BEST_MODEL
    _patience: int = PATIENCE
    _min_delta: float = MIN_DELTA
    _best_model_filename: str = BEST_MODEL_FILENAME
    _validation_split: float = VALIDATION_SPLIT
    _strategy: Strategy = field(default_factory=MirroredStrategy)

    def __post_init__(self) -> None:
        # The global batch size is the per-replica batch size scaled by the number of replicas in sync.
        self._batch_size = self._batch_size_per_replica * self._strategy.num_replicas_in_sync
        self._build_and_compile_new_model()

        if self._wandb_project_name is not None:
            self._setup_wandb()

    def _setup_wandb(self) -> None:
        config = {
            "learning_rate": self._learning_rate,
            "batch_size": self._batch_size,
            "epochs": self._epochs,
            "optimizer_type": self._optimizer_type,
            "intermediate_dims": self._intermediate_dims,
            "latent_dim": self.latent_dim
        }
        wandb.init(project=self._wandb_project_name, entity=WANDB_ENTITY, reinit=True, config=config,
                   tags=self._wandb_tags)

    def _build_and_compile_new_model(self) -> None:
        """Builds and compiles a new model.

        While k-fold cross validation is performed, each fold requires a new, untrained instance of the model,
        so this method replaces `self.model` with a freshly built and compiled VAE.

        Returns: None

        """
        encoder = self._build_encoder()
        decoder = self._build_decoder()

        with self._strategy.scope():
            self.model = VAE(encoder, decoder)
            optimizer = OptimizerFactory.create_optimizer(self._optimizer_type, self._learning_rate)
            reconstruction_loss = BinaryCrossentropy(reduction=Reduction.SUM)
            self.model.compile(optimizer=optimizer, loss=[reconstruction_loss], loss_weights=[self._original_dim])

    def _prepare_input_layers(self, for_encoder: bool) -> List[Input]:
        """
        Creates the Input layers: a batch of showers (encoder) or latent vectors (decoder), plus batches of
        energies, angles and geometries; the encoder additionally takes a batch of epsilon noise vectors.

        Args:
            for_encoder: Boolean which decides whether the first input is a full-dimensional shower or a latent
                vector.

        Returns:
            List of Input layers (five for the encoder, four for the decoder).

        """
        e_input = Input(shape=(1,))
        angle_input = Input(shape=(1,))
        geo_input = Input(shape=(2,))
        if for_encoder:
            x_input = Input(shape=(self._original_dim,))
            eps_input = Input(shape=(self.latent_dim,))
            return [x_input, e_input, angle_input, geo_input, eps_input]
        else:
            x_input = Input(shape=(self.latent_dim,))
            return [x_input, e_input, angle_input, geo_input]

    def _build_encoder(self) -> Model:
        """Builds the encoder from the configured intermediate dimensions, activation function and kernel and
        bias initializers.

        Returns:
            The encoder as a keras.Model.

        """
        with self._strategy.scope():
            x_input, e_input, angle_input, geo_input, eps_input = self._prepare_input_layers(for_encoder=True)
            x = concatenate([x_input, e_input, angle_input, geo_input])

            for intermediate_dim in self._intermediate_dims:
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)

            z_mean = Dense(self.latent_dim, name="z_mean")(x)
            z_log_var = Dense(self.latent_dim, name="z_log_var")(x)

            # Register the KL loss and sample z with the externally supplied epsilon noise.
            z_mean, z_log_var = _KLDivergenceLayer()([z_mean, z_log_var])
            encoder_output = _Sampling()([z_mean, z_log_var, eps_input])

            encoder = Model(inputs=[x_input, e_input, angle_input, geo_input, eps_input], outputs=encoder_output,
                            name="encoder")
        return encoder

    def _build_decoder(self) -> Model:
        """Builds the decoder from the configured intermediate dimensions (in reverse order), activation function
        and kernel and bias initializers.

        Returns:
            The decoder as a keras.Model.

        """
        with self._strategy.scope():
            latent_input, e_input, angle_input, geo_input = self._prepare_input_layers(for_encoder=False)
            x = concatenate([latent_input, e_input, angle_input, geo_input])

            for intermediate_dim in reversed(self._intermediate_dims):
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)

            decoder_outputs = Dense(units=self._original_dim, activation=self._out_activation)(x)

            decoder = Model(inputs=[latent_input, e_input, angle_input, geo_input], outputs=decoder_outputs,
                            name="decoder")
        return decoder

    def _manufacture_callbacks(self) -> List[Callback]:
        """
        Based on parameters set by the user, manufactures the callbacks required for training.

        Returns:
            A list of `Callback` objects.

        """
        callbacks = []

        if self._early_stop:
            callbacks.append(
                EarlyStopping(monitor="val_loss",
                              min_delta=self._min_delta,
                              patience=self._patience,
                              verbose=True,
                              restore_best_weights=True))

        if self._save_model_every_epoch:
            callbacks.append(ModelCheckpoint(filepath=f"{self._checkpoint_dir}/VAE_epoch_{{epoch:03}}/model_weights",
                                             monitor="val_loss",
                                             verbose=True,
                                             save_weights_only=True,
                                             mode="min",
                                             save_freq="epoch"))

        # Log metrics to W&B only when a run has been initialized in _setup_wandb().
        if self._wandb_project_name is not None:
            callbacks.append(WandbCallback(monitor="val_loss", verbose=0, mode="auto", save_model=False))

        return callbacks

    def _get_train_and_val_data(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray,
                                geo_cond: np.ndarray, noise: np.ndarray, train_indexes: np.ndarray,
                                validation_indexes: np.ndarray) -> Tuple[Dataset, Dataset]:
        """
        Splits the data into a training and a validation set based on the given lists of indexes.

        """
        # Slice out the training samples.
        train_dataset = dataset[train_indexes, :]
        train_e_cond = e_cond[train_indexes]
        train_angle_cond = angle_cond[train_indexes]
        train_geo_cond = geo_cond[train_indexes, :]
        train_noise = noise[train_indexes, :]

        # Slice out the validation samples.
        val_dataset = dataset[validation_indexes, :]
        val_e_cond = e_cond[validation_indexes]
        val_angle_cond = angle_cond[validation_indexes]
        val_geo_cond = geo_cond[validation_indexes, :]
        val_noise = noise[validation_indexes, :]

        # The model inputs are (showers, energies, angles, geometries, noise); the target is the shower itself.
        train_x = (train_dataset, train_e_cond, train_angle_cond, train_geo_cond, train_noise)
        train_y = train_dataset
        val_x = (val_dataset, val_e_cond, val_angle_cond, val_geo_cond, val_noise)
        val_y = val_dataset

        # Wrap the numpy arrays in tf.data datasets and batch them with the global batch size.
        train_data = Dataset.from_tensor_slices((train_x, train_y))
        val_data = Dataset.from_tensor_slices((val_x, val_y))

        train_data = train_data.batch(self._batch_size)
        val_data = val_data.batch(self._batch_size)

        # Shard by data so that each batch can be split across the replicas of the distribution strategy.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
        train_data = train_data.with_options(options)
        val_data = val_data.with_options(options)

        return train_data, val_data

    def _k_fold_training(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray,
                         geo_cond: np.ndarray, noise: np.ndarray, callbacks: List[Callback],
                         verbose: bool = True) -> List[History]:
        """
        Performs K-fold cross validation training.

        The number of folds is defined by `self._number_of_k_fold_splits`. The dataset is always shuffled.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in a verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values
            and metrics values at successive epochs, as well as validation loss values and validation metrics
            values (if applicable).

        """
        k_fold = KFold(n_splits=self._number_of_k_fold_splits, shuffle=True)
        histories = []

        for i, (train_indexes, validation_indexes) in enumerate(k_fold.split(dataset)):
            print(f"K-fold: {i + 1}/{self._number_of_k_fold_splits}...")
            train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise,
                                                                train_indexes, validation_indexes)

            # Each fold starts from a fresh, untrained model.
            self._build_and_compile_new_model()

            history = self.model.fit(x=train_data,
                                     shuffle=True,
                                     epochs=self._epochs,
                                     verbose=verbose,
                                     validation_data=val_data,
                                     callbacks=callbacks)
            histories.append(history)

            if self._save_best_model:
                self.model.save_weights(f"{self._checkpoint_dir}/VAE_fold_{i + 1}/model_weights")
                print(f"Best model from fold {i + 1} was saved.")

            # Release the model and data to reclaim memory before the next fold.
            del self.model
            del train_data
            del val_data
            tf.keras.backend.clear_session()
            gc.collect()

        return histories

    def _single_training(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray,
                         geo_cond: np.ndarray, noise: np.ndarray, callbacks: List[Callback],
                         verbose: bool = True) -> List[History]:
        """
        Performs a single training.

        A fraction of the dataset (self._validation_split) is used as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in a verbose mode or not.

        Returns: A one-element list of `History` objects. The `History.history` attribute is a record of training
            loss values and metrics values at successive epochs, as well as validation loss values and validation
            metrics values (if applicable).

        """
        # Randomly split the sample indexes into a training and a validation part.
        dataset_size, _ = dataset.shape
        permutation = np.random.permutation(dataset_size)
        split = int(dataset_size * self._validation_split)
        train_indexes, validation_indexes = permutation[split:], permutation[:split]

        train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise,
                                                            train_indexes, validation_indexes)

        history = self.model.fit(x=train_data,
                                 shuffle=True,
                                 epochs=self._epochs,
                                 verbose=verbose,
                                 validation_data=val_data,
                                 callbacks=callbacks)
        if self._save_best_model:
            self.model.save_weights(f"{self._checkpoint_dir}/VAE_best/model_weights")
            print("Best model was saved.")

        return [history]

    def train(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray, geo_cond: np.ndarray,
              verbose: bool = True) -> List[History]:
        """
        Trains and validates the model on the given input data.

        If the number of K-fold splits is greater than 1, K-fold cross validation is run; otherwise a single
        training is run, which uses (self._validation_split * 100) % of the dataset as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            verbose: A boolean which says whether the training should be performed in a verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values
            and metrics values at successive epochs, as well as validation loss values and validation metrics
            values (if applicable).

        """
        callbacks = self._manufacture_callbacks()

        # Epsilon noise for the reparameterization trick: one standard normal vector per sample.
        noise = np.random.normal(0, 1, size=(dataset.shape[0], self.latent_dim))

        if self._number_of_k_fold_splits > 1:
            return self._k_fold_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
        else:
            return self._single_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
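

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The random stand-in data below is
# hypothetical: real showers, conditional energies/angles and geometry
# encodings come from the experiment's preprocessing pipeline, while the
# dimensions and training settings come from core.constants.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_samples = 256  # hypothetical number of showers
    showers = np.random.rand(n_samples, ORIGINAL_DIM).astype(np.float32)
    energies = np.random.rand(n_samples).astype(np.float32)
    angles = np.random.rand(n_samples).astype(np.float32)
    geometries = np.random.rand(n_samples, 2).astype(np.float32)

    # W&B logging stays disabled as long as _wandb_project_name is None (the default).
    handler = VAEHandler()
    histories = handler.train(showers, energies, angles, geometries, verbose=True)
    print(f"Completed {len(histories)} training run(s).")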