Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-07 09:01:27

0001 import os
0002 import inspect
0003 from pathlib import Path
0004 import shutil
0005 import math
0006 import sys
0007 import tempfile
0008 
0009 import pytest
0010 
0011 from helpers import (
0012     dd4hepEnabled,
0013     hepmc3Enabled,
0014     geant4Enabled,
0015     AssertCollectionExistsAlg,
0016 )
0017 
0018 import acts
0019 from acts import UnitConstants as u
0020 from acts.examples import (
0021     ObjPropagationStepsWriter,
0022     CsvParticleWriter,
0023     CsvSimHitWriter,
0024     CsvTrackParameterWriter,
0025     CsvTrackWriter,
0026     CsvTrackingGeometryWriter,
0027     CsvMeasurementWriter,
0028     Sequencer,
0029     GenericDetector,
0030 )
0031 from acts.examples.json import (
0032     JsonMaterialWriter,
0033     JsonFormat,
0034 )
0035 
0036 from acts.examples.root import (
0037     RootMaterialTrackWriter,
0038     RootMaterialWriter,
0039     RootMeasurementWriter,
0040     RootPropagationStepsWriter,
0041     RootParticleWriter,
0042     RootSimHitWriter,
0043     RootTrackParameterWriter,
0044     RootTrackStatesWriter,
0045     RootTrackSummaryWriter,
0046     RootTrackFinderNTupleWriter,
0047     RootVertexNTupleWriter,
0048 )
0049 
0050 from acts.examples.odd import getOpenDataDetectorDirectory
0051 
0052 
def assert_csv_output(csv_path, stem, num_files, size_threshold=100):
    """Assert that *csv_path* holds exactly *num_files* CSV outputs for *stem*.

    A directory entry counts as an output when its name ends with
    ``<stem>.csv``. Every matching entry must be strictly larger than
    *size_threshold* bytes.

    :param csv_path: ``pathlib.Path`` of the directory to inspect.
    :param stem: output stem the writer was configured with.
    :param num_files: expected number of matching files (usually one per event).
    :param size_threshold: exclusive lower bound on each file's size in bytes.
    :raises AssertionError: if the count is wrong or any file is too small.
    """
    # Hide this helper's frame so pytest reports failures at the caller.
    __tracebackhide__ = True
    suffix = stem + ".csv"

    # Scan the directory once and reuse the result for both checks.
    matches = [f for f in csv_path.iterdir() if f.name.endswith(suffix)]
    assert len(matches) == num_files, (
        f"expected {num_files} '*{suffix}' files in {csv_path}, "
        f"found {len(matches)}"
    )

    too_small = [
        (f.name, f.stat().st_size)
        for f in matches
        if f.stat().st_size <= size_threshold
    ]
    assert not too_small, (
        f"files not larger than {size_threshold} bytes: {too_small}"
    )
0066 
0067 
@pytest.mark.obj
def test_obj_propagation_step_writer(tmp_path, trk_geo, conf_const, basic_prop_seq):
    """Run the OBJ propagation-step writer and check one sizeable file per event."""
    # Constructing the writer without any arguments must raise TypeError.
    with pytest.raises(TypeError):
        ObjPropagationStepsWriter()

    obj_dir = tmp_path / "obj"
    obj_dir.mkdir()

    seq, prop_alg = basic_prop_seq(trk_geo)
    seq.addWriter(
        conf_const(
            ObjPropagationStepsWriter,
            acts.logging.INFO,
            collection=prop_alg.config.outputSummaryCollection,
            outputDir=str(obj_dir),
        )
    )
    seq.run()

    written = [entry for entry in obj_dir.iterdir() if entry.is_file()]
    assert len(written) == seq.config.events
    # Every entry in the output directory must exceed 1 KiB.
    for entry in obj_dir.iterdir():
        assert entry.stat().st_size > 1024
0091 
0092 
@pytest.mark.csv
def test_csv_particle_writer(tmp_path, conf_const, ptcl_gun):
    """Write particle-gun particles to CSV and check one file per event."""
    seq = Sequencer(numThreads=1, events=10)
    _, particle_alg = ptcl_gun(seq)

    csv_dir = tmp_path / "csv"
    csv_dir.mkdir()

    writer = conf_const(
        CsvParticleWriter,
        acts.logging.INFO,
        inputParticles=particle_alg.config.outputParticles,
        outputStem="particle",
        outputDir=str(csv_dir),
    )
    seq.addWriter(writer)
    seq.run()

    assert_csv_output(csv_dir, "particle", seq.config.events, size_threshold=200)
0115 
0116 
@pytest.mark.root
def test_root_prop_step_writer(
    tmp_path, trk_geo, conf_const, basic_prop_seq, assert_root_hash
):
    """Write propagation steps to ROOT and check file size and content hash."""
    # Constructing the writer without any arguments must raise TypeError.
    with pytest.raises(TypeError):
        RootPropagationStepsWriter()

    root_file = tmp_path / "prop_steps.root"
    assert not root_file.exists()

    seq, prop_alg = basic_prop_seq(trk_geo)
    seq.addWriter(
        conf_const(
            RootPropagationStepsWriter,
            acts.logging.INFO,
            collection=prop_alg.config.outputSummaryCollection,
            filePath=str(root_file),
        )
    )
    seq.run()

    assert root_file.exists()
    assert root_file.stat().st_size > 50 * 1024  # more than 50 KiB
    assert_root_hash(root_file.name, root_file)
0142 
0143 
@pytest.mark.root
def test_root_particle_writer(tmp_path, conf_const, ptcl_gun, assert_root_hash):
    """Write particle-gun particles to ROOT and check file size and content hash."""
    seq = Sequencer(numThreads=1, events=10)
    _, particle_alg = ptcl_gun(seq)

    root_file = tmp_path / "particles.root"
    assert not root_file.exists()

    writer = conf_const(
        RootParticleWriter,
        acts.logging.INFO,
        inputParticles=particle_alg.config.outputParticles,
        filePath=str(root_file),
    )
    seq.addWriter(writer)
    seq.run()

    assert root_file.exists()
    assert root_file.stat().st_size > 10 * 1024  # more than 10 KiB
    assert_root_hash(root_file.name, root_file)
0167 
0168 
@pytest.mark.root
def test_root_meas_writer(tmp_path, fatras, trk_geo, assert_root_hash):
    """Write digitized measurements to ROOT and check file size and content hash."""
    seq = Sequencer(numThreads=1, events=10)
    _, sim_alg, digi_alg = fatras(seq)

    root_file = tmp_path / "meas.root"
    assert not root_file.exists()

    # This writer is built from an explicit Config object rather than conf_const.
    writer_cfg = RootMeasurementWriter.Config(
        inputMeasurements=digi_alg.config.outputMeasurements,
        inputClusters=digi_alg.config.outputClusters,
        inputSimHits=sim_alg.config.outputSimHits,
        inputMeasurementSimHitsMap=digi_alg.config.outputMeasurementSimHitsMap,
        filePath=str(root_file),
        surfaceByIdentifier=trk_geo.geoIdSurfaceMap(),
    )
    seq.addWriter(RootMeasurementWriter(level=acts.logging.INFO, config=writer_cfg))
    seq.run()

    assert root_file.exists()
    assert root_file.stat().st_size > 40000
    assert_root_hash(root_file.name, root_file)
0192 
0193 
@pytest.mark.root
def test_root_simhits_writer(tmp_path, fatras, conf_const, assert_root_hash):
    """Write Fatras simulated hits to ROOT and check file size and content hash."""
    seq = Sequencer(numThreads=1, events=10)
    _, sim_alg, _ = fatras(seq)

    # NOTE(review): the file is named "meas.root" even though it holds sim hits.
    # The name is passed to assert_root_hash, presumably as its reference key,
    # so it is left unchanged.
    root_file = tmp_path / "meas.root"
    assert not root_file.exists()

    seq.addWriter(
        conf_const(
            RootSimHitWriter,
            level=acts.logging.INFO,
            inputSimHits=sim_alg.config.outputSimHits,
            filePath=str(root_file),
        )
    )
    seq.run()

    assert root_file.exists()
    assert root_file.stat().st_size > 20_000
    assert_root_hash(root_file.name, root_file)
0216 
0217 
@pytest.mark.root
def test_root_tracksummary_writer(tmp_path, fatras, conf_const):
    """Write Kalman-fit track summaries to ROOT, with and without truth info.

    ``runTruthTrackingKalman`` already writes ``tracksummary_kf.root``, which
    includes truth information. A second ``RootTrackSummaryWriter`` is added
    that only reads the ``"tracks"`` collection. Both output files must exist
    after the run.
    """
    # Keep a reference to the detector for as long as its geometry is in use.
    detector = GenericDetector()
    tracking_geometry = detector.trackingGeometry()
    b_field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
    seq = Sequencer(numThreads=1, events=10)

    from truth_tracking_kalman import runTruthTrackingKalman

    digi_config = (
        Path(__file__).parent.parent.parent.parent
        / "Examples/Configs/generic-digi-smearing-config.json"
    )
    runTruthTrackingKalman(
        tracking_geometry,
        b_field,
        digiConfigFile=digi_config,
        outputDir=tmp_path,
        s=seq,
    )

    no_truth_file = tmp_path / "track_summary_kf_no_truth.root"
    seq.addWriter(
        conf_const(
            RootTrackSummaryWriter,
            level=acts.logging.INFO,
            inputTracks="tracks",
            filePath=str(no_truth_file),
        )
    )
    seq.run()

    assert (tmp_path / "tracksummary_kf.root").exists()
    assert no_truth_file.exists()
0254 
0255 
@pytest.mark.csv
def test_csv_meas_writer(tmp_path, fatras, trk_geo, conf_const):
    """Write digitized measurements to CSV and check every per-event output stem."""
    seq = Sequencer(numThreads=1, events=10)
    _, _, digi_alg = fatras(seq)

    csv_dir = tmp_path / "csv"
    csv_dir.mkdir()

    writer = conf_const(
        CsvMeasurementWriter,
        level=acts.logging.INFO,
        inputMeasurements=digi_alg.config.outputMeasurements,
        inputClusters=digi_alg.config.outputClusters,
        inputMeasurementSimHitsMap=digi_alg.config.outputMeasurementSimHitsMap,
        outputDir=str(csv_dir),
    )
    seq.addWriter(writer)
    seq.run()

    # The writer emits three CSV files per event, one for each stem.
    for stem in ("measurements", "measurement-simhit-map", "cells"):
        assert_csv_output(csv_dir, stem, seq.config.events, size_threshold=10)
0279 
0280 
@pytest.mark.csv
def test_csv_simhits_writer(tmp_path, fatras, conf_const):
    """Write Fatras simulated hits to CSV and check one file per event."""
    seq = Sequencer(numThreads=1, events=10)
    _, sim_alg, _ = fatras(seq)

    csv_dir = tmp_path / "csv"
    csv_dir.mkdir()

    writer = conf_const(
        CsvSimHitWriter,
        level=acts.logging.INFO,
        inputSimHits=sim_alg.config.outputSimHits,
        outputDir=str(csv_dir),
        outputStem="hits",
    )
    seq.addWriter(writer)
    seq.run()

    assert_csv_output(csv_dir, "hits", seq.config.events, size_threshold=200)
0301 
0302 
@pytest.mark.parametrize(
    "writer",
    [
        RootPropagationStepsWriter,
        RootParticleWriter,
        RootTrackFinderNTupleWriter,
        RootTrackParameterWriter,
        RootMaterialTrackWriter,
        RootMeasurementWriter,
        RootMaterialWriter,
        RootSimHitWriter,
        RootTrackStatesWriter,
        RootTrackSummaryWriter,
        RootVertexNTupleWriter,
    ],
)
@pytest.mark.root
def test_root_writer_interface(writer, conf_const, tmp_path, trk_geo):
    """Check the common ROOT-writer config interface and that construction creates the file.

    Every ``input*`` config member gets a dummy collection name. A
    ``surfaceByIdentifier`` member, when present, gets the geometry's surface map.
    """
    assert hasattr(writer, "Config")
    cfg_type = writer.Config
    for required in ("filePath", "fileMode"):
        assert hasattr(cfg_type, required)

    target = tmp_path / "target.root"
    assert not target.exists()

    kwargs = {"level": acts.logging.INFO, "filePath": str(target)}
    for member_name, _ in inspect.getmembers(cfg_type):
        if member_name.startswith("input"):
            kwargs[member_name] = "collection"
        elif member_name == "surfaceByIdentifier":
            kwargs[member_name] = trk_geo.geoIdSurfaceMap()

    assert conf_const(writer, **kwargs)

    # The output file must exist immediately after construction.
    assert target.exists()
0342 
0343 
@pytest.mark.parametrize(
    "writer",
    [
        CsvParticleWriter,
        CsvMeasurementWriter,
        CsvSimHitWriter,
        CsvTrackWriter,
        CsvTrackingGeometryWriter,
    ],
)
@pytest.mark.csv
def test_csv_writer_interface(writer, conf_const, tmp_path, trk_geo):
    """Check the common CSV-writer config interface and that each writer can be constructed."""
    assert hasattr(writer, "Config")
    cfg_type = writer.Config
    assert hasattr(cfg_type, "outputDir")

    kwargs = {"level": acts.logging.INFO, "outputDir": str(tmp_path)}
    # Fixed values for the non-input members some writers require.
    fixed_values = {"trackingGeometry": trk_geo, "outputStem": "stem"}
    for member_name, _ in inspect.getmembers(cfg_type):
        if member_name.startswith("input"):
            kwargs[member_name] = "collection"
        elif member_name in fixed_values:
            kwargs[member_name] = fixed_values[member_name]

    assert conf_const(writer, **kwargs)
0373 
0374 
@pytest.mark.root
@pytest.mark.odd
@pytest.mark.skipif(not dd4hepEnabled, reason="DD4hep not set up")
def test_root_material_writer(tmp_path, assert_root_hash):
    """Write Open Data Detector material to ROOT.

    The file must start as a small stub and grow once the geometry is written.
    """
    from acts.examples.odd import getOpenDataDetector

    with getOpenDataDetector() as detector:
        geometry = detector.trackingGeometry()

        root_file = tmp_path / "material.root"
        assert not root_file.exists()

        # Constructing the writer already creates a small (< 500 byte) file.
        rmw = RootMaterialWriter(level=acts.logging.WARNING, filePath=str(root_file))
        assert root_file.exists()
        initial_size = root_file.stat().st_size
        assert 0 < initial_size < 500

        rmw.write(geometry)

        assert root_file.stat().st_size > 1000
        assert_root_hash(root_file.name, root_file)
0395 
0396 
@pytest.mark.json
@pytest.mark.odd
@pytest.mark.parametrize("fmt", [JsonFormat.Json, JsonFormat.Cbor])
@pytest.mark.skipif(not dd4hepEnabled, reason="DD4hep not set up")
def test_json_material_writer(tmp_path, fmt):
    """Write Open Data Detector material in JSON or CBOR format.

    The writer is given a file name without an extension. The test expects
    the output to appear as ``material.<format name in lowercase>``, and only
    after ``write`` is called.
    """
    from acts.examples.dd4hep import DD4hepDetector

    odd_xml = getOpenDataDetectorDirectory() / "xml/OpenDataDetector.xml"
    detector = DD4hepDetector(xmlFileNames=[str(odd_xml)])
    geometry = detector.trackingGeometry()

    stem_path = tmp_path / "material"
    expected_file = stem_path.with_suffix("." + fmt.name.lower())
    assert not expected_file.exists()

    jmw = JsonMaterialWriter(
        level=acts.logging.WARNING, fileName=str(stem_path), writeFormat=fmt
    )
    # Unlike the ROOT writer, construction alone must not create the file.
    assert not expected_file.exists()

    jmw.write(geometry)

    assert expected_file.stat().st_size > 1000
0420 
0421 
@pytest.mark.csv
def test_csv_multitrajectory_writer(tmp_path):
    """Run Kalman truth tracking and write the fitted tracks to per-event CSV files."""
    # Keep a reference to the detector for as long as its geometry is in use.
    detector = GenericDetector()
    tracking_geometry = detector.trackingGeometry()
    b_field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))

    from truth_tracking_kalman import runTruthTrackingKalman

    seq = Sequencer(numThreads=1, events=10)
    digi_config = (
        Path(__file__).parent.parent.parent.parent
        / "Examples/Configs/generic-digi-smearing-config.json"
    )
    runTruthTrackingKalman(
        tracking_geometry,
        b_field,
        digiConfigFile=digi_config,
        outputDir=tmp_path,
        s=seq,
    )

    csv_dir = tmp_path / "csv"
    csv_dir.mkdir()
    track_writer = CsvTrackWriter(
        level=acts.logging.INFO,
        inputTracks="tracks",
        inputMeasurementParticlesMap="measurement_particles_map",
        outputDir=str(csv_dir),
    )
    seq.addWriter(track_writer)
    seq.run()

    assert_csv_output(csv_dir, "CKFtracks", seq.config.events, size_threshold=20)
0456 
0457 
@pytest.mark.csv
def test_csv_trackparameter_writer(tmp_path):
    """Run Kalman truth tracking and write the track parameters to per-event CSV files."""
    # Keep a reference to the detector for as long as its geometry is in use.
    detector = GenericDetector()
    tracking_geometry = detector.trackingGeometry()
    b_field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))

    from truth_tracking_kalman import runTruthTrackingKalman

    seq = Sequencer(numThreads=1, events=10)
    digi_config = (
        Path(__file__).parent.parent.parent.parent
        / "Examples/Configs/generic-digi-smearing-config.json"
    )
    runTruthTrackingKalman(
        tracking_geometry,
        b_field,
        digiConfigFile=digi_config,
        outputDir=tmp_path,
        s=seq,
    )

    csv_dir = tmp_path / "csv"
    csv_dir.mkdir()
    param_writer = CsvTrackParameterWriter(
        level=acts.logging.INFO,
        inputTracks="tracks",
        outputStem="track_parameters",
        outputDir=str(csv_dir),
    )
    seq.addWriter(param_writer)
    seq.run()

    assert_csv_output(csv_dir, "track_parameters", seq.config.events, size_threshold=20)