File indexing completed on 2025-06-21 08:09:40
0001 from pathlib import Path
0002 import argparse
0003
0004 import acts
0005 import acts.examples
0006 from acts.examples.reconstruction import addTrackSelection, TrackSelectorConfig
0007
# Short alias for the ACTS unit constants; runGNN4ITk uses u.mm, u.GeV and u.cm.
u = acts.UnitConstants
0009
0010
# Model-file suffixes for which _makeEdgeClassifier knows an edge-classifier type.
_SUPPORTED_MODEL_SUFFIXES = (".pt", ".onnx", ".engine")


def _requireFile(path: Path, what: str) -> None:
    """Raise ``FileNotFoundError`` naming *what* if *path* does not exist.

    Used instead of ``assert`` so the checks still run under ``python -O``.
    """
    if not path.exists():
        raise FileNotFoundError(f"{what} not found: {path}")


def _makeEdgeClassifier(gnnModel: Path, gnnConfig: dict):
    """Create the edge-classifier stage matching the model file's suffix.

    ``.pt`` -> ``TorchEdgeClassifier``, ``.onnx`` -> ``OnnxEdgeClassifier``,
    ``.engine`` -> ``TensorRTEdgeClassifier``. *gnnConfig* is not modified.

    Raises:
        ValueError: if the suffix is not one of ``_SUPPORTED_MODEL_SUFFIXES``.
    """
    suffix = gnnModel.suffix
    if suffix == ".pt":
        return acts.examples.TorchEdgeClassifier(**gnnConfig)
    if suffix == ".onnx":
        # The ONNX classifier is constructed without the useEdgeFeatures option;
        # build a filtered copy rather than mutating the caller's dict.
        onnxConfig = {k: v for k, v in gnnConfig.items() if k != "useEdgeFeatures"}
        return acts.examples.OnnxEdgeClassifier(**onnxConfig)
    if suffix == ".engine":
        return acts.examples.TensorRTEdgeClassifier(**gnnConfig)
    raise ValueError(
        f"Unsupported GNN model file '{gnnModel}': expected one of "
        f"{', '.join(_SUPPORTED_MODEL_SUFFIXES)}"
    )


def runGNN4ITk(
    inputRootDump: Path,
    moduleMapPath: str,
    gnnModel: Path,
    events: int = 1,
    logLevel=acts.logging.INFO,
):
    """Build and run the GNN4ITk track-finding and performance chain.

    The sequence is:
    1. Read space points, clusters, measurements and particles from an Athena
       ROOT dump.
    2. Build the graph with the CUDA module map.
    3. Classify edges with the GNN and build prototracks on CUDA.
    4. Convert the prototracks to tracks.
    5. Select particles and tracks.
    6. Truth-match them and write ``performance_gnn4itk.root`` (a relative
       path).

    Args:
        inputRootDump: Athena ROOT dump file (tree ``GNN4ITk``).
        moduleMapPath: module-map path prefix; ``<prefix>.doublets.root`` and
            ``<prefix>.triplets.root`` must both exist. ``os.PathLike`` is
            also accepted.
        gnnModel: edge-classifier model: ``.pt`` (Torch), ``.onnx`` (ONNX) or
            ``.engine`` (TensorRT). ``str`` is also accepted.
        events: number of events for the sequencer to process.
        logLevel: ACTS log level passed to every stage.

    Raises:
        FileNotFoundError: if the dump, either module-map file or the model
            is missing.
        ValueError: if the model suffix is not supported.
    """
    # Normalize so callers may pass str or Path for any of the paths.
    inputRootDump = Path(inputRootDump)
    gnnModel = Path(gnnModel)
    moduleMapPath = str(moduleMapPath)

    # Validate every input before constructing any (potentially GPU-backed) stage.
    _requireFile(inputRootDump, "Input ROOT dump")
    _requireFile(Path(moduleMapPath + ".doublets.root"), "Module map doublets file")
    _requireFile(Path(moduleMapPath + ".triplets.root"), "Module map triplets file")
    _requireFile(gnnModel, "GNN model")
    if gnnModel.suffix not in _SUPPORTED_MODEL_SUFFIXES:
        raise ValueError(
            f"Unsupported GNN model file '{gnnModel}': expected one of "
            f"{', '.join(_SUPPORTED_MODEL_SUFFIXES)}"
        )

    moduleMapConfig = {
        "level": logLevel,
        "moduleMapPath": moduleMapPath,
        "rScale": 1000.0,
        "phiScale": 3.141592654,
        "zScale": 1000.0,
        "gpuDevice": 0,
        "gpuBlocks": 512,
        "moreParallel": True,
    }

    gnnConfig = {
        "level": logLevel,
        "cut": 0.5,
        "modelPath": str(gnnModel),
        "useEdgeFeatures": True,
    }

    builderCfg = {
        "level": logLevel,
        "useOneBlockImplementation": False,
        "doJunctionRemoval": True,
    }

    graphConstructor = acts.examples.ModuleMapCuda(**moduleMapConfig)
    edgeClassifier = _makeEdgeClassifier(gnnModel, gnnConfig)
    trackBuilder = acts.examples.CudaTrackBuilding(**builderCfg)

    s = acts.examples.Sequencer(
        events=events,
        numThreads=1,
    )

    s.addReader(
        acts.examples.RootAthenaDumpReader(
            level=logLevel,
            treename="GNN4ITk",
            inputfiles=[str(inputRootDump)],
            outputSpacePoints="spacepoints",
            outputClusters="clusters",
            outputMeasurements="measurements",
            outputMeasurementParticlesMap="measurement_particles_map",
            outputParticleMeasurementsMap="particle_measurements_map",
            outputParticles="particles",
            skipOverlapSPsPhi=True,
            skipOverlapSPsEta=False,
            absBoundaryTolerance=0.01 * u.mm,
        )
    )

    e = acts.examples.NodeFeature
    s.addAlgorithm(
        acts.examples.TrackFindingAlgorithmExaTrkX(
            level=logLevel,
            graphConstructor=graphConstructor,
            edgeClassifiers=[edgeClassifier],
            trackBuilder=trackBuilder,
            nodeFeatures=[
                e.R,
                e.Phi,
                e.Z,
                e.Eta,
                e.Cluster1R,
                e.Cluster1Phi,
                e.Cluster1Z,
                e.Cluster1Eta,
                e.Cluster2R,
                e.Cluster2Phi,
                e.Cluster2Z,
                e.Cluster2Eta,
            ],
            # One scale per node feature: (r, phi, z, eta) repeated for the space
            # point and each of its two clusters. The r/z scales match the module
            # map's rScale/zScale. The phi literal differs from phiScale only in
            # trailing digits and is kept unchanged.
            featureScales=[1000.0, 3.14159265359, 1000.0, 1.0] * 3,
            inputSpacePoints="spacepoints",
            inputClusters="clusters",
            outputProtoTracks="prototracks",
        )
    )

    s.addAlgorithm(
        acts.examples.PrototracksToTracks(
            level=logLevel,
            inputProtoTracks="prototracks",
            inputMeasurements="measurements",
            outputTracks="gnn_only_tracks",
        )
    )

    # Truth reference: primary, charged, non-electron (|PDG id| 11) particles with
    # pT >= 1 GeV, rho <= 26 cm and at least 7 measurements.
    s.addAlgorithm(
        acts.examples.ParticleSelector(
            level=logLevel,
            ptMin=1 * u.GeV,
            rhoMax=26 * u.cm,
            measurementsMin=7,
            removeSecondaries=True,
            removeNeutral=True,
            excludeAbsPdgs=[
                11,
            ],
            inputParticles="particles",
            outputParticles="particles_selected",
            inputParticleMeasurementsMap="particle_measurements_map",
        )
    )

    addTrackSelection(
        s,
        TrackSelectorConfig(nMeasurementsMin=7, requireReferenceSurface=False),
        inputTracks="gnn_only_tracks",
        outputTracks="gnn_only_tracks_selected",
        logLevel=logLevel,
    )

    s.addAlgorithm(
        acts.examples.TrackTruthMatcher(
            level=logLevel,
            inputTracks="gnn_only_tracks_selected",
            inputParticles="particles_selected",
            inputMeasurementParticlesMap="measurement_particles_map",
            outputTrackParticleMatching="tpm",
            outputParticleTrackMatching="ptm",
            doubleMatching=True,
        )
    )

    s.addWriter(
        acts.examples.TrackFinderPerformanceWriter(
            level=logLevel,
            inputParticles="particles_selected",
            inputParticleMeasurementsMap="particle_measurements_map",
            inputTrackParticleMatching="tpm",
            inputParticleTrackMatching="ptm",
            inputTracks="gnn_only_tracks_selected",
            filePath="performance_gnn4itk.root",
        )
    )

    s.run()
0167
0168
if __name__ == "__main__":
    # Command-line entry point: parse the inputs and run the GNN4ITk chain.
    argparser = argparse.ArgumentParser(description="Run the GNN4ITk example")

    argparser.add_argument(
        "--inputRootDump",
        type=Path,
        required=True,
        help="Path to the input ROOT dump file",
    )
    argparser.add_argument(
        "--moduleMapPath",
        type=str,
        required=True,
        help="Path to the module map file (without .doublets.root or .triplets.root suffixes)",
    )
    argparser.add_argument(
        "--gnnModel",
        type=Path,
        required=True,
        # runGNN4ITk dispatches on the file suffix; all three suffixes are accepted.
        help="Path to the GNN model file (.pt Torch, .onnx ONNX or .engine TensorRT)",
    )
    argparser.add_argument(
        "--events",
        type=int,
        default=1,
        help="Number of events to process",
    )

    args = argparser.parse_args()

    runGNN4ITk(
        inputRootDump=args.inputRootDump,
        moduleMapPath=args.moduleMapPath,
        gnnModel=args.gnnModel,
        events=args.events,
    )