Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-08-28 08:12:42

0001 #!/usr/bin/env python3
0002 
0003 from pathlib import Path
0004 import os
0005 import sys
0006 
0007 import acts.examples
0008 import acts
0009 from acts.examples.reconstruction import addGnn, GnnBackend
0010 from acts import UnitConstants as u
0011 
0012 from digitization import runDigitization
0013 
0014 
def runGNNTrackFinding(
    trackingGeometry,
    field,
    outputDir,
    digiConfigFile,
    geometrySelection,
    backend,
    modelDir,
    outputRoot=False,
    outputCsv=False,
    s=None,
):
    """Digitize hits, attach GNN track finding, and execute the sequence.

    The sequencer returned by ``runDigitization`` is reused for the GNN stage
    and then run. Nothing is returned.

    Args:
        trackingGeometry: detector geometry, passed to digitization and to ``addGnn``.
        field: magnetic field, passed to digitization only.
        outputDir: directory handed to digitization; also used as the GNN ROOT
            output directory when ``outputRoot`` is true.
        digiConfigFile: digitization configuration file.
        geometrySelection: geometry-selection configuration for ``addGnn``.
        backend: GNN inference backend (a ``GnnBackend`` value).
        modelDir: directory holding the trained model files for ``backend``.
        outputRoot: enable ROOT output for both digitization and the GNN stage.
        outputCsv: enable CSV output; forwarded to digitization only.
        s: sequencer to extend. ``None`` is forwarded as-is to
            ``runDigitization`` (presumably it then creates one; verify).
    """
    # The GNN stage writes ROOT output only when requested; None disables it.
    gnnRootDir = outputDir if outputRoot else None

    # particlesInput=None: digitization is not fed an external particle collection here.
    sequencer = runDigitization(
        trackingGeometry,
        field,
        outputDir,
        digiConfigFile=digiConfigFile,
        particlesInput=None,
        outputRoot=outputRoot,
        outputCsv=outputCsv,
        s=s,
    )

    addGnn(
        sequencer,
        trackingGeometry,
        geometrySelection,
        modelDir,
        backend=backend,
        outputDirRoot=gnnRootDir,
    )

    sequencer.run()
0048 
0049 
0050 if "__main__" == __name__:
0051 
0052     backend = GnnBackend.Torch
0053 
0054     if "onnx" in sys.argv:
0055         backend = GnnBackend.Onnx
0056     if "torch" in sys.argv:
0057         backend = GnnBackend.Torch
0058 
0059     detector = acts.examples.GenericDetector()
0060     trackingGeometry = detector.trackingGeometry()
0061 
0062     field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
0063 
0064     srcdir = Path(__file__).resolve().parent.parent.parent.parent
0065 
0066     geometrySelection = srcdir / "Examples/Configs/generic-seeding-config.json"
0067     assert geometrySelection.exists()
0068 
0069     digiConfigFile = srcdir / "Examples/Configs/generic-digi-smearing-config.json"
0070     assert digiConfigFile.exists()
0071 
0072     if backend == GnnBackend.Torch:
0073         modelDir = Path.cwd() / "torchscript_models"
0074         assert (modelDir / "embed.pt").exists()
0075         assert (modelDir / "filter.pt").exists()
0076         assert (modelDir / "gnn.pt").exists()
0077     else:
0078         modelDir = Path.cwd() / "onnx_models"
0079         assert (modelDir / "embedding.onnx").exists()
0080         assert (modelDir / "filtering.onnx").exists()
0081         assert (modelDir / "gnn.onnx").exists()
0082 
0083     s = acts.examples.Sequencer(events=2, numThreads=1)
0084     s.config.logLevel = acts.logging.INFO
0085 
0086     rnd = acts.examples.RandomNumbers()
0087     outputDir = Path(os.getcwd())
0088 
0089     runGNNTrackFinding(
0090         trackingGeometry,
0091         field,
0092         outputDir,
0093         digiConfigFile,
0094         geometrySelection,
0095         backend,
0096         modelDir,
0097         outputRoot=True,
0098         outputCsv=False,
0099         s=s,
0100     )