File indexing completed on 2025-01-18 09:12:08
0001
0002
0003 from pathlib import Path
0004 import os
0005 import sys
0006
0007 import acts.examples
0008 import acts
0009 from acts.examples.reconstruction import addExaTrkX, ExaTrkXBackend
0010 from acts import UnitConstants as u
0011
0012 from digitization import runDigitization
0013
0014
def runGNNTrackFinding(
    trackingGeometry,
    field,
    outputDir,
    digiConfigFile,
    geometrySelection,
    backend,
    modelDir,
    outputRoot=False,
    outputCsv=False,
    s=None,
):
    """Configure digitization plus ExaTrkX GNN track finding, then run it.

    Parameters
    ----------
    trackingGeometry
        Tracking geometry forwarded to both ``runDigitization`` and
        ``addExaTrkX``.
    field
        Magnetic field forwarded to ``runDigitization``.
    outputDir
        Output directory for digitization. It is also used for the ExaTrkX
        ROOT output when ``outputRoot`` is true.
    digiConfigFile
        Digitization configuration file passed to ``runDigitization``.
    geometrySelection
        Geometry-selection file passed to ``addExaTrkX``.
    backend
        ``ExaTrkXBackend`` value selecting the inference backend.
    modelDir
        Directory containing the trained models, passed to ``addExaTrkX``.
    outputRoot, outputCsv
        Output toggles forwarded to ``runDigitization``. ``outputRoot`` also
        decides whether ``addExaTrkX`` receives an output directory.
    s
        Optional sequencer handed to ``runDigitization``. The sequencer it
        returns is the one that is extended and run. How ``None`` is
        handled is up to ``runDigitization`` (not visible here).

    Returns
    -------
    None
    """
    # ExaTrkX only writes ROOT output when explicitly requested.
    exaTrkXRootDir = outputDir if outputRoot else None

    sequencer = runDigitization(
        trackingGeometry,
        field,
        outputDir,
        digiConfigFile=digiConfigFile,
        particlesInput=None,
        outputRoot=outputRoot,
        outputCsv=outputCsv,
        s=s,
    )

    addExaTrkX(
        sequencer,
        trackingGeometry,
        geometrySelection,
        modelDir,
        backend=backend,
        outputDirRoot=exaTrkXRootDir,
    )

    sequencer.run()
0048
0049
if __name__ == "__main__":

    def _requireExists(path):
        """Return *path* unchanged if it exists, else raise FileNotFoundError.

        Used instead of ``assert`` so the check survives ``python -O``.
        """
        if not path.exists():
            raise FileNotFoundError(f"Required input not found: {path}")
        return path

    # Backend selection from the command line: Torch by default, "onnx"
    # switches to ONNX, and "torch" (checked last) wins if both are given.
    backend = ExaTrkXBackend.Torch

    if "onnx" in sys.argv:
        backend = ExaTrkXBackend.Onnx
    if "torch" in sys.argv:
        backend = ExaTrkXBackend.Torch

    detector = acts.examples.GenericDetector()
    trackingGeometry = detector.trackingGeometry()

    field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))

    # Source tree root: four directory levels above this script's resolved path.
    srcdir = Path(__file__).resolve().parent.parent.parent.parent

    geometrySelection = _requireExists(
        srcdir
        / "Examples/Algorithms/TrackFinding/share/geoSelection-genericDetector.json"
    )

    digiConfigFile = _requireExists(
        srcdir
        / "Examples/Algorithms/Digitization/share/default-smearing-config-generic.json"
    )

    # Trained models are looked up relative to the current working directory;
    # the directory and file names depend on the selected backend.
    if backend == ExaTrkXBackend.Torch:
        modelDir = Path.cwd() / "torchscript_models"
        modelFiles = ("embed.pt", "filter.pt", "gnn.pt")
    else:
        modelDir = Path.cwd() / "onnx_models"
        modelFiles = ("embedding.onnx", "filtering.onnx", "gnn.onnx")
    for modelFile in modelFiles:
        _requireExists(modelDir / modelFile)

    s = acts.examples.Sequencer(events=2, numThreads=1)
    s.config.logLevel = acts.logging.INFO

    # Outputs go to the current working directory.
    outputDir = Path.cwd()

    runGNNTrackFinding(
        trackingGeometry,
        field,
        outputDir,
        digiConfigFile,
        geometrySelection,
        backend,
        modelDir,
        outputRoot=True,
        outputCsv=False,
        s=s,
    )