File indexing completed on 2025-11-03 09:01:32
0001 import torch
0002 import torch.nn as nn
0003 import torch.optim as optim
0004 import numpy as np
0005 
class ProjectToX0Plane(nn.Module):
    """Move straight-line tracks along their momentum to the x = 0 plane.

    Each input row is unbound into ``(x0, y0, z0, px, py, pz)``. The momentum
    vector is scaled to unit length and the start point is advanced along it
    until its x coordinate is zero.

    Output rows are ``(y_proj, z_proj, px_norm, py_norm)``. The normalised z
    component is used for the projection but is not returned. Rows with
    ``px == 0`` (parallel to the plane) or a zero momentum vector produce
    non-finite values because of the divisions below.
    """

    def forward(self, x):
        """Project ``x`` (N, 6) to a (N, 4) tensor; see the class docstring."""
        start_x, start_y, start_z, mom_x, mom_y, mom_z = x.unbind(dim=1)

        magnitude = torch.sqrt(mom_x**2 + mom_y**2 + mom_z**2)
        unit_x = mom_x / magnitude
        unit_y = mom_y / magnitude
        unit_z = mom_z / magnitude

        # Path length along the unit direction at which start_x + unit_x * s == 0.
        path = -start_x / unit_x

        columns = [
            start_y + unit_y * path,
            start_z + unit_z * path,
            unit_x,
            unit_y,
        ]
        return torch.stack(columns, dim=1)

    def project_numpy(self, arr):
        """
        Projects a numpy array of shape (N, 6) using the forward method,
        returns a numpy array of shape (N, 4).

        The input is converted to float32 and placed on the device of the
        module's first parameter (CPU when the module has none). The result
        is always returned as a CPU float32 numpy array. No autograd graph
        is recorded.
        """
        first_param = next(self.parameters(), None)
        target_device = 'cpu' if first_param is None else first_param.device

        batch = torch.from_numpy(arr).float().to(target_device)
        with torch.no_grad():
            result = self.forward(batch)
        return result.cpu().numpy()
0038 
class RegressionModel(nn.Module):
    """MLP regressor on tracks projected to the x = 0 plane.

    ``forward`` takes rows of 6 values and projects them with
    ``ProjectToX0Plane`` to 4 features. It standardises those with
    ``input_mean`` / ``input_std``, runs the 4 -> 512 -> 64 -> 3 MLP in
    ``_core_forward``, and maps the 3 outputs back with ``output_std`` /
    ``output_mean``.

    The statistics start as the identity transform (mean 0, std 1) and are
    filled in by ``adapt``.
    """

    def __init__(self):
        super().__init__()
        self.project_to_x0 = ProjectToX0Plane()
        self.fc1 = nn.Linear(4, 512)
        self.fc2 = nn.Linear(512, 64)
        self.fc3 = nn.Linear(64, 3)

        # Normalisation statistics are stored as frozen Parameters. They are
        # saved in the state_dict and follow model.to(). They are also yielded
        # by parameters()/named_parameters(), but they never receive a gradient.
        self.input_mean = nn.Parameter(torch.zeros(4), requires_grad=False)
        self.input_std = nn.Parameter(torch.ones(4), requires_grad=False)
        self.output_mean = nn.Parameter(torch.zeros(3), requires_grad=False)
        self.output_std = nn.Parameter(torch.ones(3), requires_grad=False)

    def forward(self, x):
        """Predict targets in original units from raw rows: (N, 6) -> (N, 3)."""
        x = self.project_to_x0(x)
        x = (x - self.input_mean) / self.input_std
        x = self._core_forward(x)
        # Map from standardised target space back to target units.
        return x * self.output_std + self.output_mean

    def _core_forward(self, x):
        """Apply the MLP to already-standardised 4-feature inputs.

        Returns standardised predictions (no de-normalisation is applied).
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

    def adapt(self, input_data, output_data):
        """Fit the normalisation statistics from data.

        Args:
            input_data: projected inputs, one row per sample with 4 features
                (the output of ``project_to_x0``).
            output_data: targets, one row per sample with 3 columns.

        Means and standard deviations (``torch.std`` default, i.e. unbiased)
        are taken over dim 0. A standard deviation that is zero or NaN is
        replaced by 1, so later standardisation never divides by zero. A
        constant column gives zero; a single row gives NaN. The stored
        tensors take the device and dtype of the data they were computed
        from.
        """
        self.input_mean.data = input_data.mean(dim=0)
        self.input_std.data = self._safe_std(input_data)
        self.output_mean.data = output_data.mean(dim=0)
        self.output_std.data = self._safe_std(output_data)

    @staticmethod
    def _safe_std(data):
        """Column-wise std over dim 0 with zero/NaN entries replaced by 1."""
        std = data.std(dim=0)
        # `std > 0` is False for both 0 and NaN, so both fall back to 1.
        return torch.where(std > 0, std, torch.ones_like(std))
0080 
def preprocess_data(model, data_loader, adapt=True):
    """Project and standardise a loader's dataset in place.

    Reads ``inputs = dataset.tensors[0]`` and ``targets = dataset.tensors[1]``
    and projects the inputs with ``model.project_to_x0``. With ``adapt=True``
    it also fits the model's normalisation statistics on the result; use this
    for the training data only. It then replaces ``dataset.tensors`` with the
    standardised ``(inputs, targets)`` pair. Any further tensors in the
    dataset are dropped.

    The statistics are moved to the data's device before use, so this also
    works when the model lives on a different device than the dataset.
    Calling it twice on the same loader standardises the data twice.

    Args:
        model: object exposing ``project_to_x0``, ``adapt`` and the
            ``input_mean``/``input_std``/``output_mean``/``output_std``
            tensors (e.g. ``RegressionModel``).
        data_loader: loader whose ``dataset`` has a mutable ``tensors`` tuple
            (e.g. ``torch.utils.data.TensorDataset``).
        adapt: refit the statistics on this data before standardising.
    """
    dataset = data_loader.dataset
    inputs = dataset.tensors[0]
    targets = dataset.tensors[1]

    # Pure preprocessing: keep any autograd graph off the stored tensors.
    with torch.no_grad():
        projected_inputs = model.project_to_x0(inputs)

        if adapt:
            model.adapt(projected_inputs, targets)

        in_dev = projected_inputs.device
        out_dev = targets.device
        normalized_inputs = (
            (projected_inputs - model.input_mean.to(in_dev))
            / model.input_std.to(in_dev)
        )
        normalized_targets = (
            (targets - model.output_mean.to(out_dev))
            / model.output_std.to(out_dev)
        )

    dataset.tensors = (normalized_inputs, normalized_targets)
0098 
def makeModel(lr=0.0004, huber_delta=0.2):
    """Build a fresh ``RegressionModel`` together with its optimizer and loss.

    Args:
        lr: Adam learning rate (default 4e-4).
        huber_delta: ``delta`` of ``nn.HuberLoss`` (default 0.2). This is the
            residual size at which the loss switches from quadratic to linear.

    Returns:
        ``(model, optimizer, criterion)``.
    """
    model = RegressionModel()
    # parameters() also yields the frozen normalisation statistics
    # (requires_grad=False). Their .grad stays None, so Adam skips them.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.HuberLoss(delta=huber_delta)
    return model, optimizer, criterion
0108 
def trainModel(epochs, train_loader, val_loader, device):
    """Train a fresh ``RegressionModel`` and return it.

    Side effects: both loaders' datasets are replaced in place by projected,
    standardised tensors on ``device``. The statistics are fitted on the
    training set only.

    Training and validation call ``_core_forward`` directly because the data
    has already been projected and standardised. The reported losses are
    per-sample averages of the criterion in that standardised space.

    Args:
        epochs: number of passes over ``train_loader``.
        train_loader: loader over a TensorDataset-like ``(inputs, targets)``.
        val_loader: same, used for validation only.
        device: device the model and data are moved to.

    Returns:
        The trained model, with its layers and normalisation statistics on
        ``device``.
    """
    model, optimizer, criterion = makeModel()

    # Normalise first, then move the model. adapt() rebinds the statistics
    # to tensors on the *data's* device. Calling model.to() before it could
    # leave the statistics and the layers on different devices, which breaks
    # model.forward() at inference time.
    preprocess_data(model, train_loader, adapt=True)
    preprocess_data(model, val_loader, adapt=False)
    model.to(device)

    for dataset in (train_loader.dataset, val_loader.dataset):
        dataset.tensors = tuple(t.to(device) for t in dataset.tensors)

    for name, param in model.named_parameters():
        print(f"{name} is on {param.device}")

    for epoch in range(epochs):
        model.train()
        running_loss, n_train = 0.0, 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model._core_forward(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            batch = inputs.size(0)
            running_loss += loss.item() * batch
            n_train += batch
        # Average over the samples actually seen. This differs from
        # len(dataset) with drop_last or a subset sampler. The result is NaN
        # when the loader yields no batches.
        epoch_loss = running_loss / n_train if n_train else float("nan")

        model.eval()
        val_loss, n_val = 0.0, 0
        with torch.no_grad():
            for val_inputs, val_targets in val_loader:
                val_outputs = model._core_forward(val_inputs)
                batch = val_inputs.size(0)
                val_loss += criterion(val_outputs, val_targets).item() * batch
                n_val += batch
        val_loss = val_loss / n_val if n_val else float("nan")

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss}, Val Loss: {val_loss}")

    return model