File indexing completed on 2025-01-18 09:12:07
0001 import pandas as pd
0002
0003 import torch.nn as nn
0004 import torch.nn.functional as F
0005 import torch.utils
0006
0007 import ast
0008
0009
def prepareDataSet(data: pd.DataFrame) -> pd.DataFrame:
    """Format the track dataset that has been read from a CSV file.

    Processing steps:

    1. keep only tracks with more than 6 measurements (``nMeasurements > 6``);
    2. sort on the ``"good/duplicate/fake"`` column in descending order;
    3. drop rows that are identical on ``particleId``, ``Hits_ID``,
       ``nOutliers``, ``nHoles``, ``nSharedHits`` and ``chi2``, keeping the
       first row after the sort;
    4. index the frame by ``particleId``;
    5. parse each ``Hits_ID`` string with :func:`ast.literal_eval`.

    @param[in] data: input DataFrame containing 1 event
    @return: Formatted DataFrame (a new object; ``data`` itself is not modified)
    """
    # Discard tracks with too few measurements.
    data = data[data["nMeasurements"] > 6]

    # NOTE(review): presumably this column holds the string labels "good",
    # "duplicate" and "fake"; a descending lexicographic sort then puts "good"
    # rows first, so keep="first" below retains them. Confirm against the
    # producer of the CSV.
    # kind="stable" makes the choice among rows with the same label follow the
    # input order instead of depending on the (unstable) default quicksort.
    data = data.sort_values("good/duplicate/fake", ascending=False, kind="stable")

    data = data.drop_duplicates(
        subset=[
            "particleId",
            "Hits_ID",
            "nOutliers",
            "nHoles",
            "nSharedHits",
            "chi2",
        ],
        keep="first",
    )

    data = data.set_index("particleId")

    # Hits_ID is parsed only after deduplication: the raw strings are hashable,
    # whereas parsed containers such as lists are not.
    # literal_eval only accepts Python literals, so unlike eval it cannot run
    # arbitrary code from the file.
    data["Hits_ID"] = [ast.literal_eval(hits) for hits in data["Hits_ID"]]

    return data
0042
0043
class DuplicateClassifier(nn.Module):
    """MLP used to separate good tracks from duplicate tracks.

    Returns one score per track; the track with the higher score is taken to
    correspond to the good track.
    """

    def __init__(self, input_dim, n_layers):
        """Build three ReLU hidden layers followed by a single sigmoid output unit.

        No dropout layer is defined in this model.

        @param[in] input_dim: number of input features per track
        @param[in] n_layers: widths of the three hidden layers (a sequence of
            sizes, not a layer count); only the first three entries are used
        """
        super().__init__()
        self.linear1 = nn.Linear(input_dim, n_layers[0])
        self.linear2 = nn.Linear(n_layers[0], n_layers[1])
        self.linear3 = nn.Linear(n_layers[1], n_layers[2])
        self.output = nn.Linear(n_layers[2], 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        """Score each input row.

        @param[in] z: input features; its last dimension must equal ``input_dim``
        @return: tensor with the same leading dimensions as ``z`` and a last
            dimension of size 1, with values in [0, 1] from the sigmoid
        """
        z = F.relu(self.linear1(z))
        z = F.relu(self.linear2(z))
        z = F.relu(self.linear3(z))
        return self.sigmoid(self.output(z))
0061
0062
class Normalise(nn.Module):
    """Normalise the input as ``(z - mean) / std`` before the MLP model."""

    def __init__(self, mean, std):
        """Store the normalisation constants as float32 tensors.

        @param[in] mean: offsets subtracted from the input (anything accepted by
            ``torch.as_tensor``: list, array or tensor)
        @param[in] std: scales the centred input is divided by; zeros are not
            guarded against and will produce inf/nan
        """
        super().__init__()
        # Registered as buffers (not plain attributes) so that .to()/.cuda()
        # move them together with the module; persistent=False keeps them out
        # of state_dict, so existing checkpoints load unchanged.
        # as_tensor(...).detach().clone() copies like torch.tensor() did, but
        # without the copy-construct warning when a tensor is passed in.
        self.register_buffer(
            "mean",
            torch.as_tensor(mean, dtype=torch.float32).detach().clone(),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.as_tensor(std, dtype=torch.float32).detach().clone(),
            persistent=False,
        )

    def forward(self, z):
        """Return ``(z - mean) / std`` with broadcasting over ``z``."""
        z = z - self.mean
        z = z / self.std
        return z