File indexing completed on 2025-12-16 09:27:47
0001 import torch
0002 import argparse
0003 from LoadData import create_arrays
0004 from torch.utils.data import DataLoader, TensorDataset
0005 from RegressionModel import trainModel
0006
0007
# Command-line interface for the Tagger regression training script.
parser = argparse.ArgumentParser(description='Train a regression model for the Tagger.')
# FIX: --dataFiles is mandatory — without required=True a missing flag leaves
# args.dataFiles as None and the script crashes later inside create_arrays
# with an unhelpful traceback instead of a clear argparse usage error.
parser.add_argument('--dataFiles', type=str, nargs='+', required=True,
                    help='Path to the data files')
parser.add_argument('--outModelFile', type=str, default="regression_model.onnx",
                    help='Output file for the trained model')
parser.add_argument('--batchSize', type=int, default=1024,
                    help='Batch size for training')
parser.add_argument('--epochs', type=int, default=1000,
                    help='Number of epochs for training')
parser.add_argument('--entries', type=int, default=None,
                    help='Number of entries to process from the data files')
parser.add_argument('--treeName', type=str, default="events",
                    help='Name of the tree in the ROOT files')
parser.add_argument('--featureName', type=str, default="_TaggerTrackerFeatureTensor_floatData",
                    help='Name of the feature tensor')
parser.add_argument('--targetName', type=str, default="_TaggerTrackerTargetTensor_floatData",
                    help='Name of the target tensor')

args = parser.parse_args()
0019
# Pull the feature/target arrays out of the ROOT files named on the command line.
features, targets = create_arrays(
    args.dataFiles,
    entries=args.entries,
    treeName=args.treeName,
    featureName=args.featureName,
    targetName=args.targetName,
)

# Prefer a GPU whenever torch can see one.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    print("Device:", torch.cuda.get_device_name(0))

# Promote the raw arrays to float32 tensors for training.
torch_input_data = torch.tensor(features, dtype=torch.float32)
torch_target_data = torch.tensor(targets, dtype=torch.float32)

print(f"Input data shape: {torch_input_data.shape}")
print(f"Target data shape: {torch_target_data.shape}")
0032
# Hold out the last 25% of the samples for validation.
# NOTE(review): this is a sequential tail split — it assumes the sample
# ordering carries no structure; confirm upstream shuffling if that matters.
VAL_FRACTION = 0.25
n_train = int(len(torch_input_data) * (1 - VAL_FRACTION))

train_input_data, val_input_data = torch_input_data[:n_train], torch_input_data[n_train:]
train_target_data, val_target_data = torch_target_data[:n_train], torch_target_data[n_train:]

train_dataset = TensorDataset(train_input_data, train_target_data)
val_dataset = TensorDataset(val_input_data, val_target_data)

# Shuffle only the training batches; keep validation order deterministic.
train_loader = DataLoader(train_dataset, batch_size=args.batchSize, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.batchSize, shuffle=False)
0049
0050
print(f"Training data: {len(train_input_data)} samples")
model = trainModel(args.epochs, train_loader, val_loader, device)

# Export via tracing: a single sample serves as the example input, and the
# batch axis is declared dynamic below so any batch size works at inference.
model.to(device)
example = torch_input_data[0].unsqueeze(0).to(device)

torch.onnx.export(
    model,
    example,
    args.outModelFile,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    opset_version=11,
)

print(f"Model has been saved to {args.outModelFile}")