Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-06-21 08:09:40

0001 from pathlib import Path
0002 import argparse
0003 
0004 import acts
0005 import acts.examples
0006 from acts.examples.reconstruction import addTrackSelection, TrackSelectorConfig
0007 
0008 u = acts.UnitConstants
0009 
0010 
def runGNN4ITk(
    inputRootDump: Path,
    moduleMapPath: str,
    gnnModel: Path,
    events: int = 1,
    logLevel=acts.logging.INFO,
):
    """Run the GNN4ITk track-finding chain and write a performance file.

    Sets up a single-threaded sequencer that reads an Athena ROOT dump, runs
    the ExaTrkX track-finding algorithm (module-map graph construction, edge
    classification, CUDA track building), converts the proto tracks into
    tracks, selects particles and tracks, matches tracks to particles and
    writes the result to ``performance_gnn4itk.root``.

    Args:
        inputRootDump: ROOT dump file handed to the reader; must exist.
        moduleMapPath: Module-map path prefix. Both
            ``<moduleMapPath>.doublets.root`` and
            ``<moduleMapPath>.triplets.root`` must exist.
        gnnModel: Model file; must exist. Its suffix selects the edge
            classifier: ``.pt`` (Torch), ``.onnx`` (ONNX) or ``.engine``
            (TensorRT).
        events: Number of events given to the sequencer.
        logLevel: Log level passed to every configured component.

    Raises:
        AssertionError: If one of the required input files is missing.
        ValueError: If ``gnnModel`` has an unsupported suffix.
    """
    assert inputRootDump.exists(), f"input ROOT dump not found: {inputRootDump}"
    assert Path(
        moduleMapPath + ".doublets.root"
    ).exists(), f"module map file not found: {moduleMapPath}.doublets.root"
    assert Path(
        moduleMapPath + ".triplets.root"
    ).exists(), f"module map file not found: {moduleMapPath}.triplets.root"
    assert gnnModel.exists(), f"GNN model not found: {gnnModel}"

    # Reject unknown model formats before anything is constructed. Without
    # this check an unsupported suffix left `edgeClassifier` unbound and the
    # function failed later with an unrelated-looking UnboundLocalError.
    supportedSuffixes = (".pt", ".onnx", ".engine")
    if gnnModel.suffix not in supportedSuffixes:
        raise ValueError(
            f"Unsupported GNN model suffix '{gnnModel.suffix}' for {gnnModel}; "
            f"expected one of {', '.join(supportedSuffixes)}"
        )

    moduleMapConfig = {
        "level": logLevel,
        "moduleMapPath": moduleMapPath,
        "rScale": 1000.0,
        "phiScale": 3.141592654,
        "zScale": 1000.0,
        "gpuDevice": 0,
        "gpuBlocks": 512,
        "moreParallel": True,
    }

    gnnConfig = {
        "level": logLevel,
        "cut": 0.5,
        "modelPath": str(gnnModel),
        "useEdgeFeatures": True,
    }

    builderCfg = {
        "level": logLevel,
        "useOneBlockImplementation": False,
        "doJunctionRemoval": True,
    }

    graphConstructor = acts.examples.ModuleMapCuda(**moduleMapConfig)
    if gnnModel.suffix == ".pt":
        edgeClassifier = acts.examples.TorchEdgeClassifier(**gnnConfig)
    elif gnnModel.suffix == ".onnx":
        # The ONNX classifier is configured without "useEdgeFeatures".
        del gnnConfig["useEdgeFeatures"]
        edgeClassifier = acts.examples.OnnxEdgeClassifier(**gnnConfig)
    else:
        # Only ".engine" can reach this branch thanks to the check above.
        edgeClassifier = acts.examples.TensorRTEdgeClassifier(**gnnConfig)
    trackBuilder = acts.examples.CudaTrackBuilding(**builderCfg)

    s = acts.examples.Sequencer(
        events=events,
        numThreads=1,
    )

    s.addReader(
        acts.examples.RootAthenaDumpReader(
            level=logLevel,
            treename="GNN4ITk",
            inputfiles=[str(inputRootDump)],
            outputSpacePoints="spacepoints",
            outputClusters="clusters",
            outputMeasurements="measurements",
            outputMeasurementParticlesMap="measurement_particles_map",
            outputParticleMeasurementsMap="particle_measurements_map",
            outputParticles="particles",
            skipOverlapSPsPhi=True,
            skipOverlapSPsEta=False,
            absBoundaryTolerance=0.01 * u.mm,
        )
    )

    e = acts.examples.NodeFeature
    s.addAlgorithm(
        acts.examples.TrackFindingAlgorithmExaTrkX(
            level=logLevel,
            graphConstructor=graphConstructor,
            edgeClassifiers=[edgeClassifier],
            trackBuilder=trackBuilder,
            nodeFeatures=[
                e.R,
                e.Phi,
                e.Z,
                e.Eta,
                e.Cluster1R,
                e.Cluster1Phi,
                e.Cluster1Z,
                e.Cluster1Eta,
                e.Cluster2R,
                e.Cluster2Phi,
                e.Cluster2Z,
                e.Cluster2Eta,
            ],
            # One (R, Phi, Z, Eta) scale quadruple per feature group above:
            # 4 scales * 3 groups = 12 entries, matching the 12 node features.
            featureScales=[1000.0, 3.14159265359, 1000.0, 1.0] * 3,
            inputSpacePoints="spacepoints",
            inputClusters="clusters",
            outputProtoTracks="prototracks",
        )
    )

    s.addAlgorithm(
        acts.examples.PrototracksToTracks(
            level=logLevel,
            inputProtoTracks="prototracks",
            inputMeasurements="measurements",
            outputTracks="gnn_only_tracks",
        )
    )

    s.addAlgorithm(
        acts.examples.ParticleSelector(
            level=logLevel,
            ptMin=1 * u.GeV,
            rhoMax=26 * u.cm,
            measurementsMin=7,
            removeSecondaries=True,
            removeNeutral=True,
            excludeAbsPdgs=[
                11,
            ],
            inputParticles="particles",
            outputParticles="particles_selected",
            inputParticleMeasurementsMap="particle_measurements_map",
        )
    )

    addTrackSelection(
        s,
        TrackSelectorConfig(nMeasurementsMin=7, requireReferenceSurface=False),
        inputTracks="gnn_only_tracks",
        outputTracks="gnn_only_tracks_selected",
        logLevel=logLevel,
    )

    # NOTE: This is not standard ATLAS matching
    s.addAlgorithm(
        acts.examples.TrackTruthMatcher(
            level=logLevel,
            inputTracks="gnn_only_tracks_selected",
            inputParticles="particles_selected",
            inputMeasurementParticlesMap="measurement_particles_map",
            outputTrackParticleMatching="tpm",
            outputParticleTrackMatching="ptm",
            doubleMatching=True,
        )
    )

    s.addWriter(
        acts.examples.TrackFinderPerformanceWriter(
            level=logLevel,
            inputParticles="particles_selected",
            inputParticleMeasurementsMap="particle_measurements_map",
            inputTrackParticleMatching="tpm",
            inputParticleTrackMatching="ptm",
            inputTracks="gnn_only_tracks_selected",
            filePath="performance_gnn4itk.root",
        )
    )

    s.run()
0167 
0168 
# Command-line entry point: parse the arguments and forward them to
# runGNN4ITk. The log level is not exposed here, so the function default is used.
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description="Run the GNN4ITk example")

    argparser.add_argument(
        "--inputRootDump",
        type=Path,
        required=True,
        help="Path to the input ROOT dump file",
    )
    argparser.add_argument(
        "--moduleMapPath",
        type=str,
        required=True,
        help="Path to the module map file (without .doublets.root or .triplets.root suffixes)",
    )
    argparser.add_argument(
        "--gnnModel",
        type=Path,
        required=True,
        # runGNN4ITk picks the edge classifier from the file suffix and also
        # accepts TensorRT engines, which the previous help text omitted.
        help="Path to the GNN model file (.pt for Torch, .onnx for ONNX or .engine for TensorRT)",
    )
    argparser.add_argument(
        "--events",
        type=int,
        default=1,
        help="Number of events to process",
    )

    args = argparser.parse_args()

    runGNN4ITk(
        inputRootDump=args.inputRootDump,
        moduleMapPath=args.moduleMapPath,
        gnnModel=args.gnnModel,
        events=args.events,
    )
0204     )