Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:12:08

0001 #!/usr/bin/env python3
0002 
0003 from pathlib import Path
0004 import os
0005 import sys
0006 
0007 import acts.examples
0008 import acts
0009 from acts.examples.reconstruction import addExaTrkX, ExaTrkXBackend
0010 from acts import UnitConstants as u
0011 
0012 from digitization import runDigitization
0013 
0014 
def runGNNTrackFinding(
    trackingGeometry,
    field,
    outputDir,
    digiConfigFile,
    geometrySelection,
    backend,
    modelDir,
    outputRoot=False,
    outputCsv=False,
    s=None,
):
    """Set up digitization plus ExaTrkX GNN track finding, then run it.

    The sequencer returned by ``runDigitization`` is extended with the
    ExaTrkX algorithms via ``addExaTrkX`` and executed immediately.

    Args:
        trackingGeometry: Tracking geometry forwarded to both stages.
        field: Magnetic field forwarded to digitization.
        outputDir: Output directory. It is also used for the ExaTrkX ROOT
            output when ``outputRoot`` is true.
        digiConfigFile: Digitization configuration file.
        geometrySelection: Geometry-selection file passed to ``addExaTrkX``.
        backend: ``ExaTrkXBackend`` value selecting the inference backend.
        modelDir: Directory containing the trained model files.
        outputRoot: Request ROOT output from digitization and ExaTrkX.
        outputCsv: Request CSV output from digitization.
        s: Optional existing sequencer. It is handed to ``runDigitization``,
            which presumably creates one when this is None (confirm in
            ``digitization.runDigitization``).

    Returns:
        None. The sequencer is run as a side effect.
    """
    sequencer = runDigitization(
        trackingGeometry,
        field,
        outputDir,
        s=s,
        digiConfigFile=digiConfigFile,
        particlesInput=None,
        outputRoot=outputRoot,
        outputCsv=outputCsv,
    )

    # The GNN stage writes ROOT output only when digitization does too.
    rootOutputDir = outputDir if outputRoot else None
    addExaTrkX(
        sequencer,
        trackingGeometry,
        geometrySelection,
        modelDir,
        outputDirRoot=rootOutputDir,
        backend=backend,
    )

    sequencer.run()
0048 
0049 
def _require_files(*paths) -> None:
    """Raise ``FileNotFoundError`` if any of *paths* does not exist.

    This replaces bare ``assert`` checks, which ``python -O`` strips, so the
    validation always runs. Every missing path is reported in one message.
    """
    missing = [str(p) for p in paths if not Path(p).exists()]
    if missing:
        raise FileNotFoundError(
            "required input file(s) not found: " + ", ".join(missing)
        )


if __name__ == "__main__":
    # The backend is picked from bare command-line words. Torch is the
    # default, and "torch" takes precedence when "onnx" is also given.
    if "onnx" in sys.argv and "torch" not in sys.argv:
        backend = ExaTrkXBackend.Onnx
    else:
        backend = ExaTrkXBackend.Torch

    detector = acts.examples.GenericDetector()
    trackingGeometry = detector.trackingGeometry()

    field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))

    # Fourth ancestor of this script. Presumably the repository root, given
    # the "Examples/..." paths resolved against it below.
    srcdir = Path(__file__).resolve().parent.parent.parent.parent

    geometrySelection = (
        srcdir
        / "Examples/Algorithms/TrackFinding/share/geoSelection-genericDetector.json"
    )
    digiConfigFile = (
        srcdir
        / "Examples/Algorithms/Digitization/share/default-smearing-config-generic.json"
    )
    _require_files(geometrySelection, digiConfigFile)

    # Model files are expected in a backend-specific directory under the
    # current working directory, which also receives the outputs.
    cwd = Path.cwd()
    if backend == ExaTrkXBackend.Torch:
        modelDir = cwd / "torchscript_models"
        modelFiles = ("embed.pt", "filter.pt", "gnn.pt")
    else:
        modelDir = cwd / "onnx_models"
        modelFiles = ("embedding.onnx", "filtering.onnx", "gnn.onnx")
    _require_files(*(modelDir / name for name in modelFiles))

    s = acts.examples.Sequencer(events=2, numThreads=1)
    s.config.logLevel = acts.logging.INFO

    outputDir = cwd

    runGNNTrackFinding(
        trackingGeometry,
        field,
        outputDir,
        digiConfigFile,
        geometrySelection,
        backend,
        modelDir,
        outputRoot=True,
        outputCsv=False,
        s=s,
    )