import gc
from dataclasses import dataclass, field
from typing import List, Tuple

import numpy as np
import tensorflow as tf
import wandb
from sklearn.model_selection import KFold
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, History, Callback
from tensorflow.keras.layers import BatchNormalization, Input, Dense, Layer, concatenate
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
from tensorflow.keras.models import Model
from tensorflow.python.data import Dataset
from tensorflow.python.distribute.distribute_lib import Strategy
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
from wandb.keras import WandbCallback

from core.constants import ORIGINAL_DIM, LATENT_DIM, BATCH_SIZE_PER_REPLICA, EPOCHS, LEARNING_RATE, ACTIVATION, \
    OUT_ACTIVATION, OPTIMIZER_TYPE, KERNEL_INITIALIZER, GLOBAL_CHECKPOINT_DIR, EARLY_STOP, BIAS_INITIALIZER, \
    INTERMEDIATE_DIMS, SAVE_MODEL_EVERY_EPOCH, SAVE_BEST_MODEL, PATIENCE, MIN_DELTA, BEST_MODEL_FILENAME, \
    NUMBER_OF_K_FOLD_SPLITS, VALIDATION_SPLIT, WANDB_ENTITY
from utils.optimizer import OptimizerFactory, OptimizerType


class _Sampling(Layer):
    """ Custom layer to do the reparameterization trick: sample random latent vectors z from the latent Gaussian
    distribution.

    The sampled vector z is given by sampled_z = mean + std * epsilon
    """

    def __call__(self, inputs, **kwargs):
        z_mean, z_log_var, epsilon = inputs
        z_sigma = K.exp(0.5 * z_log_var)
        return z_mean + z_sigma * epsilon

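# Illustrative check of the reparameterization trick above (not part of the original module; the shapes and
# values are assumptions made for the example): with z_mean = 0 and z_log_var = 0 the layer simply returns
# epsilon, i.e. a sample from the standard normal prior.
#
#     eps = tf.random.normal((1, LATENT_DIM))
#     z = _Sampling()([tf.zeros((1, LATENT_DIM)), tf.zeros((1, LATENT_DIM)), eps])
#     assert np.allclose(z.numpy(), eps.numpy())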

# KL divergence computation
class _KLDivergenceLayer(Layer):
    """ Identity layer which computes the KL divergence between the learned Gaussian N(mu, sigma^2) and the
    standard normal prior, KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), registers its batch average via
    add_loss, and passes the inputs through unchanged.
    """

    def call(self, inputs, **kwargs):
        mu, log_var = inputs
        kl_loss = -0.5 * (1 + log_var - K.square(mu) - K.exp(log_var))
        kl_loss = K.mean(K.sum(kl_loss, axis=-1))
        self.add_loss(kl_loss)
        return inputs


class VAE(Model):
    """ Conditional variational autoencoder assembled from an encoder and a decoder keras.Model. """

    def get_config(self):
        config = super().get_config()
        config["encoder"] = self.encoder
        config["decoder"] = self.decoder
        return config

    def call(self, inputs, training=None, mask=None):
        _, e_input, angle_input, geo_input, _ = inputs
        z = self.encoder(inputs)
        return self.decoder([z, e_input, angle_input, geo_input])

    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self._set_inputs(inputs=self.encoder.inputs, outputs=self(self.encoder.inputs))


@dataclass
class VAEHandler:
    """
    Class to handle building and training VAE models.
    """
    _wandb_project_name: str = None
    _wandb_tags: List[str] = field(default_factory=list)
    _original_dim: int = ORIGINAL_DIM
    latent_dim: int = LATENT_DIM
    _batch_size_per_replica: int = BATCH_SIZE_PER_REPLICA
    _intermediate_dims: List[int] = field(default_factory=lambda: INTERMEDIATE_DIMS)
    _learning_rate: float = LEARNING_RATE
    _epochs: int = EPOCHS
    _activation: str = ACTIVATION
    _out_activation: str = OUT_ACTIVATION
    _number_of_k_fold_splits: int = NUMBER_OF_K_FOLD_SPLITS
    _optimizer_type: OptimizerType = OPTIMIZER_TYPE
    _kernel_initializer: str = KERNEL_INITIALIZER
    _bias_initializer: str = BIAS_INITIALIZER
    _checkpoint_dir: str = GLOBAL_CHECKPOINT_DIR
    _early_stop: bool = EARLY_STOP
    _save_model_every_epoch: bool = SAVE_MODEL_EVERY_EPOCH
    _save_best_model: bool = SAVE_BEST_MODEL
    _patience: int = PATIENCE
    _min_delta: float = MIN_DELTA
    _best_model_filename: str = BEST_MODEL_FILENAME
    _validation_split: float = VALIDATION_SPLIT
    _strategy: Strategy = MirroredStrategy()

    def __post_init__(self) -> None:
        # Calculate the true (global) batch size.
        self._batch_size = self._batch_size_per_replica * self._strategy.num_replicas_in_sync
        self._build_and_compile_new_model()
        # Setup Wandb.
        if self._wandb_project_name is not None:
            self._setup_wandb()

    def _setup_wandb(self) -> None:
        config = {
            "learning_rate": self._learning_rate,
            "batch_size": self._batch_size,
            "epochs": self._epochs,
            "optimizer_type": self._optimizer_type,
            "intermediate_dims": self._intermediate_dims,
            "latent_dim": self.latent_dim
        }
        # The reinit flag is needed for hyperparameter tuning. Whenever a new training is started, a new Wandb run
        # should be created.
        wandb.init(project=self._wandb_project_name, entity=WANDB_ENTITY, reinit=True, config=config,
                   tags=self._wandb_tags)

    def _build_and_compile_new_model(self) -> None:
        """ Builds and compiles a new model.

        A fresh model is needed because, while k-fold cross validation is performed, each fold requires a new,
        clean instance of the model. Calling this method replaces the previously built model (if any).

        Returns: None

        """
        # Build encoder and decoder.
        encoder = self._build_encoder()
        decoder = self._build_decoder()

        # Compile the model within a distributed strategy.
        with self._strategy.scope():
            # Build VAE.
            self.model = VAE(encoder, decoder)
            # Manufacture an optimizer and compile the model with it.
            optimizer = OptimizerFactory.create_optimizer(self._optimizer_type, self._learning_rate)
            reconstruction_loss = BinaryCrossentropy(reduction=Reduction.SUM)
            self.model.compile(optimizer=optimizer, loss=[reconstruction_loss], loss_weights=[ORIGINAL_DIM])

    def _prepare_input_layers(self, for_encoder: bool) -> List[Input]:
        """
        Create the Input layers. They take, respectively: a batch of showers or a batch of latent vectors, a batch
        of energies, a batch of angles and a batch of geometries (plus a batch of epsilon noise for the encoder).

        Args:
            for_encoder: Boolean which decides whether the first input is a full-dimensional shower or a latent
                vector.

        Returns:
            List of Input layers (five for encoder and four for decoder).

        """
        e_input = Input(shape=(1,))
        angle_input = Input(shape=(1,))
        geo_input = Input(shape=(2,))
        if for_encoder:
            x_input = Input(shape=self._original_dim)
            eps_input = Input(shape=self.latent_dim)
            return [x_input, e_input, angle_input, geo_input, eps_input]
        else:
            x_input = Input(shape=self.latent_dim)
            return [x_input, e_input, angle_input, geo_input]

    def _build_encoder(self) -> Model:
        """ Builds the encoder based on the list of intermediate dimensions, the activation function and the
        kernel and bias initializers.

        Returns:
             The encoder as a keras.Model.

        """

        with self._strategy.scope():
            # Prepare input layers.
            x_input, e_input, angle_input, geo_input, eps_input = self._prepare_input_layers(for_encoder=True)
            x = concatenate([x_input, e_input, angle_input, geo_input])
            # Construct hidden layers (Dense and Batch Normalization).
            for intermediate_dim in self._intermediate_dims:
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)
            # Add Dense layers to get the description of the multidimensional Gaussian distribution in terms of mean
            # and log(variance).
            z_mean = Dense(self.latent_dim, name="z_mean")(x)
            z_log_var = Dense(self.latent_dim, name="z_log_var")(x)
            # Add the KLDivergenceLayer responsible for calculation of the KL loss.
            z_mean, z_log_var = _KLDivergenceLayer()([z_mean, z_log_var])
            # Sample a point z from the distribution.
            encoder_output = _Sampling()([z_mean, z_log_var, eps_input])
            # Create the model.
            encoder = Model(inputs=[x_input, e_input, angle_input, geo_input, eps_input], outputs=encoder_output,
                            name="encoder")
        return encoder

    def _build_decoder(self) -> Model:
        """ Builds the decoder based on the list of intermediate dimensions, the activation function and the
        kernel and bias initializers.

        Returns:
             The decoder as a keras.Model.

        """

        with self._strategy.scope():
            # Prepare input layers.
            latent_input, e_input, angle_input, geo_input = self._prepare_input_layers(for_encoder=False)
            x = concatenate([latent_input, e_input, angle_input, geo_input])
            # Construct hidden layers (Dense and Batch Normalization).
            for intermediate_dim in reversed(self._intermediate_dims):
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)
            # Add a Dense layer to get an output whose shape is compatible with the input's shape.
            decoder_outputs = Dense(units=self._original_dim, activation=self._out_activation)(x)
            # Create the model.
            decoder = Model(inputs=[latent_input, e_input, angle_input, geo_input], outputs=decoder_outputs,
                            name="decoder")
        return decoder

    def _manufacture_callbacks(self) -> List[Callback]:
        """
        Based on parameters set by the user, manufacture callbacks required for training.

        Returns:
            A list of `Callback` objects.

        """
        callbacks = []
        # If the early stopping flag is on, stop the training when the monitored metric (validation loss) has stopped
        # improving for (patience) epochs.
        if self._early_stop:
            callbacks.append(
                EarlyStopping(monitor="val_loss",
                              min_delta=self._min_delta,
                              patience=self._patience,
                              verbose=True,
                              restore_best_weights=True))
        # Save the model after every epoch.
        if self._save_model_every_epoch:
            callbacks.append(ModelCheckpoint(filepath=f"{self._checkpoint_dir}/VAE_epoch_{{epoch:03}}/model_weights",
                                             monitor="val_loss",
                                             verbose=True,
                                             save_weights_only=True,
                                             mode="min",
                                             save_freq="epoch"))
        # Pass metadata to Wandb (assumes a Wandb run has been initialized in _setup_wandb).
        callbacks.append(WandbCallback(
            monitor="val_loss", verbose=0, mode="auto", save_model=False))
        return callbacks

    def _get_train_and_val_data(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
                                noise: np.array, train_indexes: np.array, validation_indexes: np.array) \
            -> Tuple[Dataset, Dataset]:
        """
        Splits the data into a training and a validation set based on the given lists of indexes.

        """

        # Prepare training data.
        train_dataset = dataset[train_indexes, :]
        train_e_cond = e_cond[train_indexes]
        train_angle_cond = angle_cond[train_indexes]
        train_geo_cond = geo_cond[train_indexes, :]
        train_noise = noise[train_indexes, :]

        # Prepare validation data.
        val_dataset = dataset[validation_indexes, :]
        val_e_cond = e_cond[validation_indexes]
        val_angle_cond = angle_cond[validation_indexes]
        val_geo_cond = geo_cond[validation_indexes, :]
        val_noise = noise[validation_indexes, :]

        # Gather them into tuples.
        train_x = (train_dataset, train_e_cond, train_angle_cond, train_geo_cond, train_noise)
        train_y = train_dataset
        val_x = (val_dataset, val_e_cond, val_angle_cond, val_geo_cond, val_noise)
        val_y = val_dataset

        # Wrap data in Dataset objects.
        # TODO(@mdragula): This approach requires loading the whole data set to RAM. It
        #  would be better to read the data partially when needed. Also one should bear in mind that using tf.Dataset
        #  slows down the training process.
        train_data = Dataset.from_tensor_slices((train_x, train_y))
        val_data = Dataset.from_tensor_slices((val_x, val_y))

        # The batch size must now be set on the Dataset objects.
        train_data = train_data.batch(self._batch_size)
        val_data = val_data.batch(self._batch_size)

        # Disable AutoShard.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
        train_data = train_data.with_options(options)
        val_data = val_data.with_options(options)

        return train_data, val_data

    def _k_fold_training(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
                         noise: np.array, callbacks: List[Callback], verbose: bool = True) -> List[History]:
        """
        Performs K-fold cross validation training.

        The number of folds is defined by (self._number_of_k_fold_splits). The dataset is always shuffled.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values and
        metrics values at successive epochs, as well as validation loss values and validation metrics values (if
        applicable).

        """
        # TODO(@mdragula): KFold cross validation can be parallelized. Each fold is independent of the others.
        k_fold = KFold(n_splits=self._number_of_k_fold_splits, shuffle=True)
        histories = []

        for i, (train_indexes, validation_indexes) in enumerate(k_fold.split(dataset)):
            print(f"K-fold: {i + 1}/{self._number_of_k_fold_splits}...")
            train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise,
                                                                train_indexes, validation_indexes)

            self._build_and_compile_new_model()

            history = self.model.fit(x=train_data,
                                     shuffle=True,
                                     epochs=self._epochs,
                                     verbose=verbose,
                                     validation_data=val_data,
                                     callbacks=callbacks
                                     )
            histories.append(history)

            if self._save_best_model:
                self.model.save_weights(f"{self._checkpoint_dir}/VAE_fold_{i + 1}/model_weights")
                print(f"Best model from fold {i + 1} was saved.")

            # Remove all unnecessary data from the previous fold.
            del self.model
            del train_data
            del val_data
            tf.keras.backend.clear_session()
            gc.collect()

        return histories

    def _single_training(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
                         noise: np.ndarray, callbacks: List[Callback], verbose: bool = True) -> List[History]:
        """
        Performs a single training.

        A fraction of the dataset (self._validation_split) is used as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing the additional noise needed to perform the reparameterization trick.
            callbacks: A list of callbacks forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A one-element list of `History` objects. The `History.history` attribute is a record of training loss
        values and metrics values at successive epochs, as well as validation loss values and validation metrics
        values (if applicable).

        """
        dataset_size, _ = dataset.shape
        permutation = np.random.permutation(dataset_size)
        split = int(dataset_size * self._validation_split)
        train_indexes, validation_indexes = permutation[split:], permutation[:split]

        train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise, train_indexes,
                                                            validation_indexes)

        history = self.model.fit(x=train_data,
                                 shuffle=True,
                                 epochs=self._epochs,
                                 verbose=verbose,
                                 validation_data=val_data,
                                 callbacks=callbacks
                                 )
        if self._save_best_model:
            self.model.save_weights(f"{self._checkpoint_dir}/VAE_best/model_weights")
            print("Best model was saved.")

        return [history]

    def train(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
              verbose: bool = True) -> List[History]:
        """
        Trains and validates the model on the given input data.

        If the number of K-fold splits > 1 then it runs K-fold cross validation, otherwise it runs a single training
        which uses (self._validation_split * 100) % of the dataset as validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            verbose: A boolean which says whether the training should be performed in verbose mode or not.

        Returns: A list of `History` objects. The `History.history` attribute is a record of training loss values and
        metrics values at successive epochs, as well as validation loss values and validation metrics values (if
        applicable).

        """

        callbacks = self._manufacture_callbacks()

        noise = np.random.normal(0, 1, size=(dataset.shape[0], self.latent_dim))

        if self._number_of_k_fold_splits > 1:
            return self._k_fold_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
        else:
            return self._single_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
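

# Illustrative usage sketch (not part of the original module): trains the handler on random data to exercise the
# full pipeline. The sample count, the value ranges of the conditional variables and the Wandb project name below
# are assumptions made for this example; real showers, energies, angles and geometry conditions come from the
# simulation, and Wandb logging requires a valid WANDB_ENTITY and login.
if __name__ == "__main__":
    n_samples = 1000
    showers = np.random.random((n_samples, ORIGINAL_DIM)).astype(np.float32)
    energies = np.random.random((n_samples,)).astype(np.float32)
    angles = np.random.random((n_samples,)).astype(np.float32)
    geometries = np.random.randint(0, 2, size=(n_samples, 2)).astype(np.float32)

    # Builds and compiles the model in __post_init__ and initializes Wandb because a project name is given.
    handler = VAEHandler(_wandb_project_name="vae-sketch", _wandb_tags=["example"])

    # Runs K-fold cross validation or a single training, depending on NUMBER_OF_K_FOLD_SPLITS.
    histories = handler.train(showers, energies, angles, geometries, verbose=True)
    print(f"Final training loss: {histories[-1].history['loss'][-1]:.4f}")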