Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-10-13 08:28:38

import gc
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import numpy as np
import tensorflow as tf
import wandb
from sklearn.model_selection import KFold
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, History, Callback
from tensorflow.keras.layers import BatchNormalization, Input, Dense, Layer, concatenate
from tensorflow.keras.losses import BinaryCrossentropy, Reduction
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Model
from tensorflow.python.data import Dataset
from tensorflow.python.distribute.distribute_lib import Strategy
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
from wandb.keras import WandbCallback

from core.constants import ORIGINAL_DIM, LATENT_DIM, BATCH_SIZE_PER_REPLICA, EPOCHS, LEARNING_RATE, ACTIVATION, \
    OUT_ACTIVATION, OPTIMIZER_TYPE, KERNEL_INITIALIZER, GLOBAL_CHECKPOINT_DIR, EARLY_STOP, BIAS_INITIALIZER, \
    INTERMEDIATE_DIMS, SAVE_MODEL_EVERY_EPOCH, SAVE_BEST_MODEL, PATIENCE, MIN_DELTA, BEST_MODEL_FILENAME, \
    NUMBER_OF_K_FOLD_SPLITS, VALIDATION_SPLIT, WANDB_ENTITY
from utils.optimizer import OptimizerFactory, OptimizerType
0025 
0026 
class _Sampling(Layer):
    """ Custom layer to do the reparameterization trick: sample random latent vectors z from the latent Gaussian
    distribution described by its mean and log(variance).

    The sampled vector z is given by sampled_z = mean + std * epsilon, where std = exp(0.5 * log(variance)) and
    epsilon is noise supplied by the caller as the third input.

    The computation lives in `call` (not `__call__`) so that Keras' `Layer.__call__` machinery (building, name
    scoping, functional-graph node tracking) still runs, matching how `_KLDivergenceLayer` is written.
    """

    def call(self, inputs, **kwargs):
        """ Compute z = z_mean + exp(0.5 * z_log_var) * epsilon.

        Args:
            inputs: Sequence of three tensors (z_mean, z_log_var, epsilon). In this file's encoder all three have
                `latent_dim` features.

        Returns:
            Tensor of sampled latent vectors.
        """
        z_mean, z_log_var, epsilon = inputs
        # std = sqrt(variance) = exp(0.5 * log(variance)).
        z_sigma = K.exp(0.5 * z_log_var)
        return z_mean + z_sigma * epsilon
0038 
0039 
# KL divergence computation
class _KLDivergenceLayer(Layer):
    """ Identity layer that registers the KL-divergence loss of the latent Gaussian.

    The divergence of N(mu, exp(log_var)) from a standard normal is computed per latent dimension, summed over the
    last axis, averaged over the batch and added to the model losses via `add_loss`. The inputs are returned
    unchanged so the layer can sit inline between the encoder heads and the sampling layer.
    """

    def call(self, inputs, **kwargs):
        means, log_variances = inputs
        # Closed-form per-dimension KL term expressed with log(variance).
        per_dim_kl = -0.5 * (1 + log_variances - K.square(means) - K.exp(log_variances))
        per_sample_kl = K.sum(per_dim_kl, axis=-1)
        self.add_loss(K.mean(per_sample_kl))
        return inputs
0049 
0050 
class DataGenerator(Sequence):
    """ Keras `Sequence` that serves mini-batches from in-memory arrays.

    Args:
        x_set: Sequence of arrays forming the model inputs. All of them are cut with the same batch window; the
            first one (the showers) determines the number of batches.
        y_set: Array of targets, cut with the same batch window as `x_set`.
        batch_size: Number of samples per batch. The last batch may be smaller.
    """

    def __init__(self, x_set, y_set, batch_size):
        # Let the base class initialise itself; harmless when it defines no initialiser, and expected by newer
        # Keras versions of `Sequence`.
        super().__init__()
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        """ Return the number of batches per epoch, counting a trailing partial batch. """
        # Exact integer ceiling division; x[0] holds the actual showers.
        return -(-len(self.x[0]) // self.batch_size)

    def __getitem__(self, idx):
        """ Return batch number `idx` as ``(tuple_of_input_batches, target_batch)``. """
        window = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_x = tuple(array[window] for array in self.x)
        return batch_x, self.y[window]
0065 
0066 
class VAE(Model):
    """ Conditional variational autoencoder built from a separate encoder and decoder model.

    The encoder receives (showers, energies, angles, geometries, noise) and returns a sampled latent vector; the
    decoder receives (latent vector, energies, angles, geometries) and returns the reconstruction.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Run the model once on the encoder's symbolic inputs so Keras records the input/output signature.
        self._set_inputs(inputs=self.encoder.inputs, outputs=self(self.encoder.inputs))

    def call(self, inputs, training=None, mask=None):
        # Showers and noise are only consumed by the encoder; the conditions are forwarded to the decoder too.
        _showers, energy, angle, geometry, _noise = inputs
        latent = self.encoder(inputs)
        return self.decoder([latent, energy, angle, geometry])

    def get_config(self):
        # NOTE(review): the encoder/decoder are stored as model objects, not serialized configs, so this config is
        # not JSON-serializable.
        base_config = super().get_config()
        return {**base_config, "encoder": self.encoder, "decoder": self.decoder}
0084 
0085 
@dataclass
class VAEHandler:
    """
    Class to handle building and training VAE models.

    `__post_init__` builds and compiles a fresh `VAE` inside the configured distribution strategy, stores it in
    `self.model` and, when `_wandb_project_name` is given, starts a Wandb run. `train` then runs either a single
    training with a random validation split or K-fold cross validation.
    """
    _wandb_project_name: Optional[str] = None
    _wandb_run_name: Optional[str] = None
    _wandb_tags: List[str] = field(default_factory=list)
    _original_dim: int = ORIGINAL_DIM
    latent_dim: int = LATENT_DIM
    _batch_size_per_replica: int = BATCH_SIZE_PER_REPLICA
    # Copy the module-level constant so instances never share (and cannot mutate) the same list object.
    _intermediate_dims: List[int] = field(default_factory=lambda: list(INTERMEDIATE_DIMS))
    _learning_rate: float = LEARNING_RATE
    _epochs: int = EPOCHS
    _activation: str = ACTIVATION
    _out_activation: str = OUT_ACTIVATION
    # Passed to sklearn's KFold as `n_splits`, which must be an integer.
    _number_of_k_fold_splits: int = NUMBER_OF_K_FOLD_SPLITS
    _optimizer_type: OptimizerType = OPTIMIZER_TYPE
    _kernel_initializer: str = KERNEL_INITIALIZER
    _bias_initializer: str = BIAS_INITIALIZER
    _checkpoint_dir: str = GLOBAL_CHECKPOINT_DIR
    _early_stop: bool = EARLY_STOP
    _save_model_every_epoch: bool = SAVE_MODEL_EVERY_EPOCH
    _save_best_model: bool = SAVE_BEST_MODEL
    _patience: int = PATIENCE
    _min_delta: float = MIN_DELTA
    _best_model_filename: str = BEST_MODEL_FILENAME
    _validation_split: float = VALIDATION_SPLIT
    # Created per instance: a plain `MirroredStrategy()` default would be evaluated once at import time (setting up
    # the devices as a side effect of importing this module) and shared by every handler.
    _strategy: Strategy = field(default_factory=MirroredStrategy)

    def __post_init__(self) -> None:
        # The global batch is split evenly across all replicas of the strategy.
        self._batch_size = self._batch_size_per_replica * self._strategy.num_replicas_in_sync
        self._build_and_compile_new_model()
        # Setup Wandb.
        if self._wandb_project_name is not None:
            self._setup_wandb()

    def _setup_wandb(self) -> None:
        """ Starts a Wandb run that records the main hyperparameters of this handler. """
        config = {
            "learning_rate": self._learning_rate,
            "batch_size": self._batch_size,
            "epochs": self._epochs,
            "optimizer_type": self._optimizer_type,
            "intermediate_dims": self._intermediate_dims,
            "latent_dim": self.latent_dim
        }
        # Reinit flag is needed for hyperparameter tuning. Whenever new training is started, new Wandb run should be
        # created.
        wandb.init(name=self._wandb_run_name, project=self._wandb_project_name, entity=WANDB_ENTITY, reinit=True,
                   config=config, tags=self._wandb_tags)

    def _build_and_compile_new_model(self) -> None:
        """ Builds and compiles a new model and stores it in `self.model`, replacing any previous one.

        K-fold cross validation calls this once per fold so that every fold starts from a fresh, untrained model.

        Returns: None

        """
        # Build encoder and decoder.
        encoder = self._build_encoder()
        decoder = self._build_decoder()

        # Compile model within a distributed strategy.
        with self._strategy.scope():
            self.model = VAE(encoder, decoder)
            optimizer = OptimizerFactory.create_optimizer(self._optimizer_type, self._learning_rate)
            # Per-sample losses are summed over the batch instead of averaged.
            reconstruction_loss = BinaryCrossentropy(reduction=Reduction.SUM)
            # Weight by the configured shower dimensionality (not the global constant) so the reconstruction term
            # stays consistent with the decoder output size when `_original_dim` is overridden.
            self.model.compile(optimizer=optimizer, loss=[reconstruction_loss], loss_weights=[self._original_dim])

    def _prepare_input_layers(self, for_encoder: bool) -> List[Input]:
        """
        Create the Input layers of the encoder or the decoder.

        Both variants take a batch of energies (1 feature), a batch of angles (1 feature) and a batch of geometries
        (2 features). The encoder additionally takes a batch of showers (`_original_dim` features) in first position
        and a batch of noise vectors (`latent_dim` features) in last position; the decoder instead takes a batch of
        latent vectors (`latent_dim` features) in first position.

        Args:
            for_encoder: Boolean which decides whether an input is full dimensional shower or a latent vector.

        Returns:
            List of Input layers (five for encoder and four for decoder).

        """
        e_input = Input(shape=(1,))
        angle_input = Input(shape=(1,))
        geo_input = Input(shape=(2,))
        if for_encoder:
            x_input = Input(shape=(self._original_dim,))
            eps_input = Input(shape=(self.latent_dim,))
            return [x_input, e_input, angle_input, geo_input, eps_input]
        x_input = Input(shape=(self.latent_dim,))
        return [x_input, e_input, angle_input, geo_input]

    def _build_encoder(self) -> Model:
        """ Based on a list of intermediate dimensions, activation function and initializers for kernel and bias builds
        the encoder.

        Returns:
             Encoder is returned as a keras.Model.

        """

        with self._strategy.scope():
            # Prepare input layer.
            x_input, e_input, angle_input, geo_input, eps_input = self._prepare_input_layers(for_encoder=True)
            x = concatenate([x_input, e_input, angle_input, geo_input])
            # Construct hidden layers (Dense and Batch Normalization).
            for intermediate_dim in self._intermediate_dims:
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)
            # Add Dense layer to get description of multidimensional Gaussian distribution in terms of mean
            # and log(variance).
            z_mean = Dense(self.latent_dim, name="z_mean")(x)
            z_log_var = Dense(self.latent_dim, name="z_log_var")(x)
            # Add KLDivergenceLayer responsible for calculation of KL loss.
            z_mean, z_log_var = _KLDivergenceLayer()([z_mean, z_log_var])
            # Sample a probe from the distribution.
            encoder_output = _Sampling()([z_mean, z_log_var, eps_input])
            # Create model.
            encoder = Model(inputs=[x_input, e_input, angle_input, geo_input, eps_input], outputs=encoder_output,
                            name="encoder")
        return encoder

    def _build_decoder(self) -> Model:
        """ Based on a list of intermediate dimensions, activation function and initializers for kernel and bias builds
        the decoder.

        Returns:
             Decoder is returned as a keras.Model.

        """

        with self._strategy.scope():
            # Prepare input layer.
            latent_input, e_input, angle_input, geo_input = self._prepare_input_layers(for_encoder=False)
            x = concatenate([latent_input, e_input, angle_input, geo_input])
            # Construct hidden layers (Dense and Batch Normalization), mirroring the encoder.
            for intermediate_dim in reversed(self._intermediate_dims):
                x = Dense(units=intermediate_dim, activation=self._activation,
                          kernel_initializer=self._kernel_initializer,
                          bias_initializer=self._bias_initializer)(x)
                x = BatchNormalization()(x)
            # Add Dense layer to get output which shape is compatible in an input's shape.
            decoder_outputs = Dense(units=self._original_dim, activation=self._out_activation)(x)
            # Create model.
            decoder = Model(inputs=[latent_input, e_input, angle_input, geo_input], outputs=decoder_outputs,
                            name="decoder")
        return decoder

    def _manufacture_callbacks(self) -> List[Callback]:
        """
        Based on parameters set by the user, manufacture callbacks required for training.

        Returns:
            A list of `Callback` objects.

        """
        callbacks = []
        # If the early stopping flag is on then stop the training when a monitored metric (validation) has stopped
        # improving after (patience) number of epochs.
        if self._early_stop:
            callbacks.append(
                EarlyStopping(monitor="val_loss",
                              min_delta=self._min_delta,
                              patience=self._patience,
                              verbose=True,
                              restore_best_weights=True))
        # Save model after every epoch.
        if self._save_model_every_epoch:
            callbacks.append(ModelCheckpoint(filepath=f"{self._checkpoint_dir}/VAE_epoch_{{epoch:03}}/model_weights",
                                             monitor="val_loss",
                                             verbose=True,
                                             save_weights_only=True,
                                             mode="min",
                                             save_freq="epoch"))
        # Pass metadata to Wandb only when a run is active: WandbCallback cannot be created before `wandb.init()`,
        # and `_setup_wandb` is only called when a project name was given.
        if wandb.run is not None:
            callbacks.append(WandbCallback(
                monitor="val_loss", verbose=0, mode="auto", save_model=False))
        return callbacks

    def _get_train_and_val_data(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
                                noise: np.array, train_indexes: np.array, validation_indexes: np.array) \
            -> Tuple[DataGenerator, DataGenerator]:
        """
        Splits data into train and validation generators based on given lists of indexes.

        The generators serve batches of `self._batch_size` samples instead of the entire dataset at once. The target
        of every sample is the shower itself (reconstruction).

        Returns:
            A (train, validation) pair of `DataGenerator` objects.

        """

        def make_generator(indexes: np.ndarray) -> DataGenerator:
            # Select the same samples from every input array.
            showers = dataset[indexes, :]
            inputs = (showers, e_cond[indexes], angle_cond[indexes], geo_cond[indexes, :], noise[indexes, :])
            return DataGenerator(inputs, showers, self._batch_size)

        return make_generator(train_indexes), make_generator(validation_indexes)

    def _k_fold_training(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
                         noise: np.array, callbacks: List[Callback], verbose: bool = True) -> List[History]:
        """
        Performs K-fold cross validation training.

        Number of folds is defined by (self._number_of_k_fold_splits). Always shuffle the dataset. Every fold trains
        a freshly built model; the model is deleted at the end of each fold (including the last one), so
        `self.model` no longer exists after this method returns.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing an additional noise needed to perform a reparametrization trick.
            callbacks: A list of callback forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in a verbose mode or not.

        Returns: A list of `History` objects.`History.history` attribute is a record of training loss values and
        metrics values at successive epochs, as well as validation loss values and validation metrics values (if
        applicable).

        """
        # TODO(@mdragula): KFold cross validation can be parallelized. Each fold is independent from each the others.
        k_fold = KFold(n_splits=self._number_of_k_fold_splits, shuffle=True)
        histories = []

        for i, (train_indexes, validation_indexes) in enumerate(k_fold.split(dataset)):
            print(f"K-fold: {i + 1}/{self._number_of_k_fold_splits}...")
            train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise,
                                                                train_indexes, validation_indexes)

            self._build_and_compile_new_model()

            history = self.model.fit(x=train_data,
                                     shuffle=True,
                                     epochs=self._epochs,
                                     verbose=verbose,
                                     validation_data=val_data,
                                     callbacks=callbacks
                                     )
            histories.append(history)

            if self._save_best_model:
                # Saves the weights the model holds after `fit`; these are the best epoch's only if EarlyStopping
                # restored them.
                self.model.save_weights(f"{self._checkpoint_dir}/VAE_fold_{i + 1}/model_weights")
                print(f"Best model from fold {i + 1} was saved.")

            # Remove all unnecessary data from previous fold.
            del self.model
            del train_data
            del val_data
            tf.keras.backend.clear_session()
            gc.collect()

        return histories

    def _single_training(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
                         noise: np.ndarray, callbacks: List[Callback], verbose: bool = True) -> List[History]:
        """
        Performs a single training.

        A fraction of dataset (self._validation_split) is used as a validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            noise: A matrix representing an additional noise needed to perform a reparametrization trick.
            callbacks: A list of callback forwarded to the fitting function.
            verbose: A boolean which says whether the training should be performed in a verbose mode or not.

        Returns: A one-element list of `History` objects.`History.history` attribute is a record of training loss
        values and metrics values at successive epochs, as well as validation loss values and validation metrics
        values (if applicable).

        """
        dataset_size, _ = dataset.shape
        # Random split: the first `split` permuted indexes form the validation set, the rest the training set.
        permutation = np.random.permutation(dataset_size)
        split = int(dataset_size * self._validation_split)
        train_indexes, validation_indexes = permutation[split:], permutation[:split]

        train_data, val_data = self._get_train_and_val_data(dataset, e_cond, angle_cond, geo_cond, noise, train_indexes,
                                                            validation_indexes)

        history = self.model.fit(x=train_data,
                                 shuffle=True,
                                 epochs=self._epochs,
                                 verbose=verbose,
                                 validation_data=val_data,
                                 callbacks=callbacks
                                 )
        if self._save_best_model:
            # Saves the weights the model holds after `fit`; these are the best epoch's only if EarlyStopping
            # restored them.
            self.model.save_weights(f"{self._checkpoint_dir}/VAE_best/model_weights")
            print("Best model was saved.")

        return [history]

    def train(self, dataset: np.array, e_cond: np.array, angle_cond: np.array, geo_cond: np.array,
              verbose: bool = True) -> List[History]:
        """
        For a given input data trains and validates the model.

        If the number of K-fold splits > 1 then it runs K-fold cross validation, otherwise it runs a single training
        which uses (self._validation_split * 100) % of dataset as a validation data.

        Args:
            dataset: A matrix representing showers. Shape =
                (number of samples, ORIGINAL_DIM = N_CELLS_Z * N_CELLS_R * N_CELLS_PHI).
            e_cond: A matrix representing an energy for each sample. Shape = (number of samples, ).
            angle_cond: A matrix representing an angle for each sample. Shape = (number of samples, ).
            geo_cond: A matrix representing a geometry of the detector for each sample. Shape = (number of samples, 2).
            verbose: A boolean which says whether the training should be performed in a verbose mode or not.

        Returns: A list of `History` objects.`History.history` attribute is a record of training loss values and
        metrics values at successive epochs, as well as validation loss values and validation metrics values (if
        applicable).

        """

        callbacks = self._manufacture_callbacks()

        # NOTE(review): the reparameterization noise is drawn once per `train` call and sliced deterministically by
        # the data generators, so every sample sees the same epsilon in every epoch — confirm this is intended.
        noise = np.random.normal(0, 1, size=(dataset.shape[0], self.latent_dim))

        if self._number_of_k_fold_splits > 1:
            return self._k_fold_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)
        return self._single_training(dataset, e_cond, angle_cond, geo_cond, noise, callbacks, verbose)