EIC code displayed by LXR
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class ProjectToX0Plane(nn.Module):
    def forward(self, x):
        # x shape: (batch, 6) -> [x, y, z, px, py, pz]
        x0, y0, z0, px, py, pz = x.unbind(dim=1)

        # Normalize momentum components to a unit direction vector
        # (clamped to avoid 0/0 for zero-momentum rows)
        momentum = torch.sqrt(px**2 + py**2 + pz**2).clamp_min(1e-12)
        px_norm = px / momentum
        py_norm = py / momentum
        pz_norm = pz / momentum

        # Avoid division by zero for rays (nearly) parallel to the x = 0
        # plane, where px ~ 0 makes the intersection ill defined
        eps = 1e-8
        px_safe = torch.where(px_norm.abs() < eps,
                              torch.full_like(px_norm, eps), px_norm)
        t = -x0 / px_safe

        y_proj = y0 + py_norm * t
        z_proj = z0 + pz_norm * t

        # Output: [y_proj, z_proj, px_norm, py_norm]
        return torch.stack([y_proj, z_proj, px_norm, py_norm], dim=1)

    def project_numpy(self, arr):
        """
        Projects a numpy array of shape (N, 6) using the forward method
        and returns a numpy array of shape (N, 4).
        """
        # This module defines no parameters of its own, so the generator below
        # is empty and the projection runs on the CPU; the check is kept as a
        # guard in case parameters are ever added.
        device = next(self.parameters()).device if any(p.device.type != 'cpu' for p in self.parameters()) else 'cpu'
        x = torch.from_numpy(arr).float().to(device)
        with torch.no_grad():
            projected = self.forward(x)
        return projected.cpu().numpy()

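
# Example (illustrative): a ray starting at (1, 0, 0) with direction
# (-1, 0, 1)/sqrt(2) hits the x = 0 plane at y = 0, z = 1:
#   >>> proj = ProjectToX0Plane()
#   >>> proj.project_numpy(np.array([[1.0, 0.0, 0.0, -1.0, 0.0, 1.0]]))
#   # -> approximately [[0.0, 1.0, -0.7071, 0.0]]
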
class RegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.project_to_x0 = ProjectToX0Plane()
        self.fc1  = nn.Linear(4, 512)
        self.fc2  = nn.Linear(512, 64)
        self.fc3  = nn.Linear(64, 3)  # Output layer for the 3 regression targets

        # Normalization parameters (frozen; set by adapt(), not by the optimizer)
        self.input_mean = nn.Parameter(torch.zeros(4), requires_grad=False)
        self.input_std = nn.Parameter(torch.ones(4), requires_grad=False)
        self.output_mean = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.output_std = nn.Parameter(torch.ones(3), requires_grad=False)

    def forward(self, x):
        # Project the 6D phase-space input onto the x = 0 plane
        x = self.project_to_x0(x)
        # Normalize inputs
        x = (x - self.input_mean) / self.input_std

        # Pass through the fully connected layers
        x = self._core_forward(x)

        # Denormalize outputs
        x = x * self.output_std + self.output_mean
        return x

    def _core_forward(self, x):
        # Core fully connected layers
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def adapt(self, input_data, output_data):
        # Compute feature-wise normalization statistics; clamp the standard
        # deviations so constant features cannot cause division by zero
        self.input_mean.data = input_data.mean(dim=0)
        self.input_std.data = input_data.std(dim=0).clamp_min(1e-8)
        self.output_mean.data = output_data.mean(dim=0)
        self.output_std.data = output_data.std(dim=0).clamp_min(1e-8)
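
# Shape note (illustrative): adapt() expects the projected 4-feature inputs
# and the 3-component targets, e.g.:
#   >>> m = RegressionModel()
#   >>> m.adapt(torch.randn(100, 4), torch.randn(100, 3))
#   >>> m.input_mean.shape, m.output_std.shape
#   (torch.Size([4]), torch.Size([3]))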

def preprocess_data(model, data_loader, adapt=True):
    inputs  = data_loader.dataset.tensors[0]
    targets = data_loader.dataset.tensors[1]

    # No gradients are needed while preprocessing
    with torch.no_grad():
        projected_inputs = model.project_to_x0(inputs)

        # Compute normalization parameters (from the training data only)
        if adapt:
            model.adapt(projected_inputs, targets)

        # Normalize inputs and targets
        normalized_inputs  = (projected_inputs - model.input_mean ) / model.input_std
        normalized_targets = (targets          - model.output_mean) / model.output_std

    # Replace the dataset tensors with the preprocessed data (in place)
    data_loader.dataset.tensors = (normalized_inputs, normalized_targets)
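
# Example (illustrative; assumes a TensorDataset-backed DataLoader, since
# preprocess_data accesses data_loader.dataset.tensors directly):
#   >>> from torch.utils.data import TensorDataset, DataLoader
#   >>> ds = TensorDataset(torch.randn(256, 6), torch.randn(256, 3))
#   >>> loader = DataLoader(ds, batch_size=32)
#   >>> model = RegressionModel()
#   >>> preprocess_data(model, loader, adapt=True)  # mutates ds.tensors in place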

def makeModel():
    # Create the model
    model = RegressionModel()
    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.0004)
    # Define the loss function: Huber loss is robust to outlier targets
    criterion = nn.HuberLoss(delta=0.2)

    return model, optimizer, criterion

def trainModel(epochs, train_loader, val_loader, device):
    model, optimizer, criterion = makeModel()

    model.to(device)

    # Preprocess training data and adapt the normalization statistics
    preprocess_data(model, train_loader, adapt=True)

    # Preprocess validation data without adapting
    preprocess_data(model, val_loader, adapt=False)

    # Move the preprocessed tensors to the target device
    train_loader.dataset.tensors = (train_loader.dataset.tensors[0].to(device),
                                    train_loader.dataset.tensors[1].to(device))
    val_loader.dataset.tensors = (val_loader.dataset.tensors[0].to(device),
                                  val_loader.dataset.tensors[1].to(device))

    # Verify that the model parameters are on the expected device
    for name, param in model.named_parameters():
        print(f"{name} is on {param.device}")

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            # The data is already projected and normalized by preprocess_data,
            # so bypass the projection/normalization steps in forward()
            outputs = model._core_forward(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)

        # Validation step
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for val_inputs, val_targets in val_loader:
                val_outputs = model._core_forward(val_inputs)
                val_loss += criterion(val_outputs, val_targets).item() * val_inputs.size(0)

        val_loss /= len(val_loader.dataset)

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.6f}, Val Loss: {val_loss:.6f}")

    return model
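
# Minimal smoke test (illustrative sketch, not part of the original training
# pipeline): the (N, 6) phase-space inputs and (N, 3) targets below are
# synthetic random data with assumed shapes only.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset, DataLoader

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    X = torch.randn(1024, 6)
    X[:, 3] = X[:, 3].abs() + 0.1  # keep px away from zero so the projection is well defined
    y = torch.randn(1024, 3)

    train_loader = DataLoader(TensorDataset(X[:768], y[:768]), batch_size=64)
    val_loader = DataLoader(TensorDataset(X[768:], y[768:]), batch_size=64)

    trained = trainModel(epochs=5, train_loader=train_loader,
                         val_loader=val_loader, device=device)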