import gc
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import numpy as np
import tensorflow as tf
import wandb
from sklearn.model_selection import KFold
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, History, Callback
from tensorflow.keras.layers import BatchNormalization, Input, Dense, Layer, concatenate
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Model
from wandb.keras import WandbCallback

from core.constants import ORIGINAL_DIM, LATENT_DIM, BATCH_SIZE_PER_REPLICA, EPOCHS, LEARNING_RATE, ACTIVATION, \
    OUT_ACTIVATION, OPTIMIZER_TYPE, KERNEL_INITIALIZER, GLOBAL_CHECKPOINT_DIR, EARLY_STOP, BIAS_INITIALIZER, \
    INTERMEDIATE_DIMS, SAVE_MODEL_EVERY_EPOCH, SAVE_BEST_MODEL, PATIENCE, MIN_DELTA, BEST_MODEL_FILENAME, \
    NUMBER_OF_K_FOLD_SPLITS, VALIDATION_SPLIT, WANDB_ENTITY
from utils.optimizer import OptimizerFactory, OptimizerType


class _Sampling(Layer):
    """ Custom layer implementing the reparameterization trick: sample a random latent vector z from the latent
    Gaussian distribution.

    The sampled vector is given by sampled_z = z_mean + std * epsilon, where std = exp(0.5 * z_log_var).
    """

    def call(self, inputs, **kwargs):
        z_mean, z_log_var, epsilon = inputs
        z_sigma = K.exp(0.5 * z_log_var)
        return z_mean + z_sigma * epsilon


class _KLDivergenceLayer(Layer):
    """ Identity layer that adds the KL divergence between N(mu, exp(log_var)) and the standard normal N(0, 1)
    to the model loss.
    """

    def call(self, inputs, **kwargs):
        mu, log_var = inputs
        kl_loss = -0.5 * (1 + log_var - K.square(mu) - K.exp(log_var))
        kl_loss = K.mean(K.sum(kl_loss, axis=-1))
        self.add_loss(kl_loss)
        return inputs
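

# Illustration (not part of the original pipeline): the KL term added above has the closed form
# KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log_var - mu^2 - exp(log_var)), summed over latent
# dimensions and averaged over the batch. A plain-numpy reference, handy for unit tests, could
# look like this (hypothetical helper; assumed input shapes are (batch, latent_dim)):
def _kl_reference(mu: np.ndarray, log_var: np.ndarray) -> float:
    """Numpy reference implementation of the loss added by _KLDivergenceLayer."""
    kl = -0.5 * (1.0 + log_var - np.square(mu) - np.exp(log_var))
    return float(np.mean(np.sum(kl, axis=-1)))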


class DataGenerator(Sequence):
    """ Keras Sequence serving batches of multi-input data: x_set is a tuple of input arrays sharing the same
    first dimension, y_set is the target array.
    """

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x[0]) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = []
        for i in range(len(self.x)):
            batch_x.append(self.x[i][idx * self.batch_size:(idx + 1) * self.batch_size])
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        return tuple(batch_x), batch_y
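

# Example usage (shapes and batch size are illustrative assumptions): with the five input arrays
# used by this module, the generator yields ((x, e, angle, geo, noise), y) one batch at a time:
#
#   gen = DataGenerator(x_set=(showers, e, angle, geo, noise), y_set=showers, batch_size=128)
#   batch_inputs, batch_targets = gen[0]  # a tuple of five arrays, and the matching targets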


class VAE(Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # _set_inputs is a private Keras API; it builds the model here so that it can be saved and
        # summarized before fit() is called.
        self._set_inputs(inputs=self.encoder.inputs, outputs=self(self.encoder.inputs))

    def call(self, inputs, training=None, mask=None):
        _, e_input, angle_input, geo_input, _ = inputs
        z = self.encoder(inputs)
        return self.decoder([z, e_input, angle_input, geo_input])

    def get_config(self):
        config = super().get_config()
        config["encoder"] = self.encoder
        config["decoder"] = self.decoder
        return config
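

# Forward-pass sketch (variable names and shapes are illustrative assumptions, matching the
# encoder inputs built below):
#
#   reconstruction = vae([shower, energy, angle, geometry, epsilon])
#   # shower: (batch, ORIGINAL_DIM), epsilon: (batch, LATENT_DIM); the epsilon input feeds the
#   # _Sampling layer, so the same noise can be reused deterministically across epochs.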


@dataclass
class VAEHandler:
    """
    Class to handle building and training VAE models.
    """
    _wandb_project_name: Optional[str] = None
    _wandb_run_name: Optional[str] = None
    _wandb_tags: List[str] = field(default_factory=list)
    _original_dim: int = ORIGINAL_DIM
    latent_dim: int = LATENT_DIM
    _batch_size_per_replica: int = BATCH_SIZE_PER_REPLICA
    _intermediate_dims: List[int] = field(default_factory=lambda: INTERMEDIATE_DIMS)
    _learning_rate: float = LEARNING_RATE
    _epochs: int = EPOCHS
    _activation: str = ACTIVATION
    _out_activation: str = OUT_ACTIVATION
    _number_of_k_fold_splits: int = NUMBER_OF_K_FOLD_SPLITS
    _optimizer_type: OptimizerType = OPTIMIZER_TYPE
    _kernel_initializer: str = KERNEL_INITIALIZER
    _bias_initializer: str = BIAS_INITIALIZER
    _checkpoint_dir: str = GLOBAL_CHECKPOINT_DIR
    _early_stop: bool = EARLY_STOP
    _save_model_every_epoch: bool = SAVE_MODEL_EVERY_EPOCH
    _save_best_model: bool = SAVE_BEST_MODEL
    _patience: int = PATIENCE
    _min_delta: float = MIN_DELTA
    _best_model_filename: str = BEST_MODEL_FILENAME
    _validation_split: float = VALIDATION_SPLIT
    # Use a default_factory so the strategy is created per instance, not once at import time as a
    # shared class-level default.
    _strategy: tf.distribute.Strategy = field(default_factory=tf.distribute.MirroredStrategy)

    def __post_init__(self) -> None:
        self._batch_size = self._batch_size_per_replica * self._strategy.num_replicas_in_sync
        self._build_and_compile_new_model()

        if self._wandb_project_name is not None:
            self._setup_wandb()

    def _setup_wandb(self) -> None:
        config = {
            "learning_rate": self._learning_rate,
            "batch_size": self._batch_size,
            "epochs": self._epochs,
            "optimizer_type": self._optimizer_type,
            "intermediate_dims": self._intermediate_dims,
            "latent_dim": self.latent_dim
        }

        wandb.init(name=self._wandb_run_name, project=self._wandb_project_name, entity=WANDB_ENTITY,
                   reinit=True, config=config, tags=self._wandb_tags)

    def _build_and_compile_new_model(self) -> None:
        """ Builds and compiles a new model.

        While k-fold cross validation is performed, each fold requires a fresh, untrained instance of the model,
        so this method replaces self.model with a newly built and compiled VAE.

        Returns: None
        """
        encoder = self._build_encoder()
        decoder = self._build_decoder()

        with self._strategy.scope():
            self.model = VAE(encoder, decoder)

            optimizer = OptimizerFactory.create_optimizer(self._optimizer_type, self._learning_rate)
            reconstruction_loss = BinaryCrossentropy(reduction=Reduction.SUM)
            self.model.compile(optimizer=optimizer, loss=[reconstruction_loss], loss_weights=[ORIGINAL_DIM])

    def _prepare_input_layers(self, for_encoder: bool) -> List[Input]:
        """
        Creates the Input layers, which respectively take: a batch of showers (or a batch of latent vectors),
        a batch of energies, a batch of angles, and a batch of geometries.

        Args:
            for_encoder: A boolean which decides whether the first input is a full-dimensional shower (encoder)
                or a latent vector (decoder).

        Returns:
            A list of Input layers (five for the encoder, four for the decoder).
        """
        e_input = Input(shape=(1,))
        angle_input = Input(shape=(1,))
        geo_input = Input(shape=(2,))
        if for_encoder:
            x_input = Input(shape=(self._original_dim,))
            eps_input = Input(shape=(self.latent_dim,))
            return [x_input, e_input, angle_input, geo_input, eps_input]
        else:
            x_input = Input(shape=(self.latent_dim,))
            return [x_input, e_input, angle_input, geo_input]

    def _build_encoder(self) -> Model:
        """ Builds the encoder from the list of intermediate dimensions, the activation function, and the kernel
        and bias initializers.

        Returns:
            The encoder as a keras.Model.
        """
        with self._strategy.scope():
            # Inputs: showers, energies, angles, geometries and the sampling noise.
            x_input, e_input, angle_input, geo_input, eps_input = self._prepare_input_layers(for_encoder=True)
            x = concatenate([x_input, e_input, angle_input, geo_input])

            # Hidden layers: Dense followed by Batch Normalization.
            for intermediate_dim in self._intermediate_dims:
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)

            # Latent Gaussian parameters, KL regularization and the reparameterization trick.
            z_mean = Dense(self.latent_dim, name="z_mean")(x)
            z_log_var = Dense(self.latent_dim, name="z_log_var")(x)

            z_mean, z_log_var = _KLDivergenceLayer()([z_mean, z_log_var])

            encoder_output = _Sampling()([z_mean, z_log_var, eps_input])

            encoder = Model(inputs=[x_input, e_input, angle_input, geo_input, eps_input], outputs=encoder_output,
                            name="encoder")
        return encoder

    def _build_decoder(self) -> Model:
        """ Builds the decoder from the (reversed) list of intermediate dimensions, the activation function, and
        the kernel and bias initializers.

        Returns:
            The decoder as a keras.Model.
        """
        with self._strategy.scope():
            # Inputs: latent vectors, energies, angles and geometries.
            latent_input, e_input, angle_input, geo_input = self._prepare_input_layers(for_encoder=False)
            x = concatenate([latent_input, e_input, angle_input, geo_input])

            # Hidden layers: Dense followed by Batch Normalization.
            for intermediate_dim in reversed(self._intermediate_dims):
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)

            # Output layer reconstructing the shower.
            decoder_outputs = Dense(units=self._original_dim, activation=self._out_activation)(x)

            decoder = Model(inputs=[latent_input, e_input, angle_input, geo_input], outputs=decoder_outputs,
                            name="decoder")
        return decoder

    def _manufacture_callbacks(self) -> List[Callback]:
        """
        Based on parameters set by the user, manufactures the callbacks required for training.

        Returns:
            A list of `Callback` objects.
        """
        callbacks = []

        # Stop training when the monitored metric stops improving, restoring the best weights seen so far.
        if self._early_stop:
            callbacks.append(
                EarlyStopping(monitor="val_loss",
                              min_delta=self._min_delta,
                              patience=self._patience,
                              verbose=True,
                              restore_best_weights=True))

        # Save the model weights after every epoch.
        if self._save_model_every_epoch:
            callbacks.append(ModelCheckpoint(filepath=f"{self._checkpoint_dir}/VAE_epoch_{{epoch:03}}/model_weights",
                                             monitor="val_loss",
                                             verbose=True,
                                             save_weights_only=True,
                                             mode="min",
                                             save_freq="epoch"))

        # Log metrics to Weights & Biases, but only when a run was initialized in _setup_wandb.
        if self._wandb_project_name is not None:
            callbacks.append(WandbCallback(
                monitor="val_loss", verbose=0, mode="auto", save_model=False))
        return callbacks

    def _get_train_and_val_data(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray,
                                geo_cond: np.ndarray, noise: np.ndarray, train_indexes: np.ndarray,
                                validation_indexes: np.ndarray) -> Tuple[DataGenerator, DataGenerator]:
        """
        Splits the data into a training and a validation set based on the given lists of indexes, and wraps both
        in generators so that batches are loaded onto the GPU one at a time instead of the entire dataset.
        """
        # Prepare the training data.
        train_dataset = dataset[train_indexes, :]
        train_e_cond = e_cond[train_indexes]
        train_angle_cond = angle_cond[train_indexes]
        train_geo_cond = geo_cond[train_indexes, :]
        train_noise = noise[train_indexes, :]

        # Prepare the validation data.
        val_dataset = dataset[validation_indexes, :]
        val_e_cond = e_cond[validation_indexes]
        val_angle_cond = angle_cond[validation_indexes]
        val_geo_cond = geo_cond[validation_indexes, :]
        val_noise = noise[validation_indexes, :]

        # Gather inputs and targets; the VAE reconstructs its input, so the target equals the shower matrix.
        train_x = (train_dataset, train_e_cond, train_angle_cond, train_geo_cond, train_noise)
        train_y = train_dataset
        val_x = (val_dataset, val_e_cond, val_angle_cond, val_geo_cond, val_noise)
        val_y = val_dataset

        train_gen = DataGenerator(train_x, train_y, self._batch_size)
        val_gen = DataGenerator(val_x, val_y, self._batch_size)
        return train_gen, val_gen

    def _k_fold_training(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray,
                         geo_cond: np.ndarray, noise: np.ndarray, callbacks: List[Callback],
                         verbose: bool = True) -> List[History]:
        """
        Performs K-fold cross validation training.

        The number of folds is defined by self._number_of_k_fold_splits. The dataset is always shuffled.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values
            and metrics values at successive epochs, as well as validation loss values and validation metrics
            values (if applicable).
        """
        k_fold = KFold(n_splits=self._number_of_k_fold_splits, shuffle=True)
        histories = []

        for i, (train_indexes, validation_indexes) in enumerate(k_fold.split(dataset)):
            print(f"K-fold: {i + 1}/{self._number_of_k_fold_splits}...")
            train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise,
                                                                train_indexes, validation_indexes)

            # Each fold starts from a fresh, untrained model.
            self._build_and_compile_new_model()

            history = self.model.fit(x=train_data,
                                     shuffle=True,
                                     epochs=self._epochs,
                                     verbose=verbose,
                                     validation_data=val_data,
                                     callbacks=callbacks)
            histories.append(history)

            if self._save_best_model:
                self.model.save_weights(f"{self._checkpoint_dir}/VAE_fold_{i + 1}/model_weights")
                print(f"Best model from fold {i + 1} was saved.")

            # Remove the fold's model and data from memory before the next iteration.
            del self.model
            del train_data
            del val_data
            tf.keras.backend.clear_session()
            gc.collect()

        return histories

    def _single_training(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray,
                         geo_cond: np.ndarray, noise: np.ndarray, callbacks: List[Callback],
                         verbose: bool = True) -> List[History]:
        """
        Performs a single training.

        A fraction of the dataset (self._validation_split) is used as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A one-element list of `History` objects. The `History.history` attribute is a record of training
            loss values and metrics values at successive epochs, as well as validation loss values and validation
            metrics values (if applicable).
        """
        # Hold out the first (self._validation_split * 100)% of a random permutation as the validation set.
        dataset_size, _ = dataset.shape
        permutation = np.random.permutation(dataset_size)
        split = int(dataset_size * self._validation_split)
        train_indexes, validation_indexes = permutation[split:], permutation[:split]

        train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise,
                                                            train_indexes, validation_indexes)

        history = self.model.fit(x=train_data,
                                 shuffle=True,
                                 epochs=self._epochs,
                                 verbose=verbose,
                                 validation_data=val_data,
                                 callbacks=callbacks)
        if self._save_best_model:
            self.model.save_weights(f"{self._checkpoint_dir}/VAE_best/model_weights")
            print("Best model was saved.")

        return [history]

    def train(self, dataset: np.ndarray, e_cond: np.ndarray, angle_cond: np.ndarray, geo_cond: np.ndarray,
              verbose: bool = True) -> List[History]:
        """
        Trains and validates the model on the given input data.

        If the number of K-fold splits is > 1 then K-fold cross validation is run, otherwise a single training is
        run which uses (self._validation_split * 100) % of the dataset as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values
            and metrics values at successive epochs, as well as validation loss values and validation metrics
            values (if applicable).
        """
        callbacks = self._manufacture_callbacks()

        # Sample the reparameterization noise once, up front, so the same epsilon is paired with a given sample
        # throughout training.
        noise = np.random.normal(0, 1, size=(dataset.shape[0], self.latent_dim))

        if self._number_of_k_fold_splits > 1:
            return self._k_fold_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
        else:
            return self._single_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
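

# Minimal usage sketch (illustrative only; the synthetic data below is an assumption, not part of the training
# pipeline). A handler built with the defaults from core.constants can be trained like this:
#
#   n_samples = 1024
#   dataset = np.random.rand(n_samples, ORIGINAL_DIM).astype(np.float32)
#   e_cond = np.random.rand(n_samples)
#   angle_cond = np.random.rand(n_samples)
#   geo_cond = np.random.rand(n_samples, 2)
#
#   handler = VAEHandler()
#   histories = handler.train(dataset, e_cond, angle_cond, geo_cond, verbose=True)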