# File indexing completed on 2025-11-03 09:01:32
0001 import torch
0002 import argparse
0003 from LoadData import create_arrays
0004 from torch.utils.data import DataLoader, TensorDataset
0005 from RegressionModel  import trainModel
0006 
0007 
# Command-line interface for the Tagger regression training script.
parser = argparse.ArgumentParser(description='Train a regression model for the Tagger.')
# Input data files are mandatory: without required=True argparse silently
# defaults to None, which only fails later inside create_arrays with a
# confusing error instead of a clear usage message here.
parser.add_argument('--dataFiles', type=str, nargs='+', required=True, help='Path to the data files')
parser.add_argument('--outModelFile', type=str, default="regression_model.onnx", help='Output file for the trained model')
parser.add_argument('--batchSize', type=int, default=1024, help='Batch size for training')
parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs for training')
# entries=None means "read every entry" — presumably handled inside create_arrays; verify there.
parser.add_argument('--entries', type=int, default=None, help='Number of entries to process from the data files')
parser.add_argument('--treeName', type=str, default="events", help='Name of the tree in the ROOT files')
parser.add_argument('--featureName', type=str, default="_TaggerTrackerFeatureTensor_floatData", help='Name of the feature tensor')
parser.add_argument('--targetName', type=str, default="_TaggerTrackerTargetTensor_floatData", help='Name of the target tensor')

args   = parser.parse_args()
0019 
# Load feature/target arrays from the ROOT files; presumably returns
# numpy-compatible arrays (they are fed to torch.tensor below) — confirm in LoadData.
input_data, target_data = create_arrays(args.dataFiles, entries=args.entries, treeName=args.treeName, featureName=args.featureName, targetName=args.targetName)
0021 
# Prefer a CUDA device when one is present; otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report which GPU was picked (device index 0).
    print("Device:", torch.cuda.get_device_name(0))
0026     
# Copy the loaded arrays into float32 tensors (the dtype the model trains in).
torch_input_data, torch_target_data = (
    torch.tensor(arr, dtype=torch.float32) for arr in (input_data, target_data)
)

print(f"Input data shape: {torch_input_data.shape}")
print(f"Target data shape: {torch_target_data.shape}")
0032 
0033 
# Hold out the last 25% of the samples for validation (no shuffling here:
# the split is positional, so whatever ordering create_arrays produced is kept).
validation_fraction = 0.25
split_index         = int(len(torch_input_data) * (1 - validation_fraction))

train_input_data, val_input_data = (
    torch_input_data[:split_index],
    torch_input_data[split_index:],
)
train_target_data, val_target_data = (
    torch_target_data[:split_index],
    torch_target_data[split_index:],
)

# Wrap the tensors so DataLoader can batch (feature, target) pairs together.
train_dataset = TensorDataset(train_input_data, train_target_data)
val_dataset   = TensorDataset(val_input_data, val_target_data)

# Shuffle only the training batches; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=args.batchSize, shuffle=True )
val_loader   = DataLoader(val_dataset,   batch_size=args.batchSize, shuffle=False)
0049 
0050 
print(f"Training data: {len(train_input_data)} samples")
# Run the training loop; trainModel is expected to return the trained
# network (it is moved to `device` and ONNX-exported below) — see RegressionModel.
model  = trainModel(args.epochs, train_loader, val_loader, device)
0053 
0054 
# Export the trained model to ONNX with a dynamic batch dimension.
model.to(device)
# Switch to inference mode before tracing: without eval(), any dropout or
# batch-norm layers would be frozen into the exported graph with their
# *training* behavior.
model.eval()
# A single sample (batch of 1) is enough to trace the graph.
dummy_input = torch_input_data[0].unsqueeze(0).to(device)

torch.onnx.export(
    model, dummy_input, args.outModelFile,
    input_names=['input'], output_names=['output'],
    # Let ONNX runtimes feed arbitrary batch sizes despite the batch-of-1 trace input.
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    opset_version=11,
)

print(f"Model has been saved to {args.outModelFile}")