Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-16 08:13:58

0001 import os
0002 import inspect
0003 from pathlib import Path
0004 import shutil
0005 import math
0006 import sys
0007 import tempfile
0008 
0009 import pytest
0010 
0011 from helpers import (
0012     dd4hepEnabled,
0013     hepmc3Enabled,
0014     geant4Enabled,
0015     AssertCollectionExistsAlg,
0016 )
0017 
0018 import acts
0019 from acts import UnitConstants as u
0020 from acts.examples import (
0021     ObjPropagationStepsWriter,
0022     TrackFinderNTupleWriter,
0023     RootPropagationStepsWriter,
0024     RootParticleWriter,
0025     RootTrackParameterWriter,
0026     RootMaterialTrackWriter,
0027     RootMaterialWriter,
0028     RootSimHitWriter,
0029     RootTrackStatesWriter,
0030     RootTrackSummaryWriter,
0031     VertexNTupleWriter,
0032     RootMeasurementWriter,
0033     CsvParticleWriter,
0034     CsvSimHitWriter,
0035     CsvTrackParameterWriter,
0036     CsvTrackWriter,
0037     CsvTrackingGeometryWriter,
0038     CsvMeasurementWriter,
0039     JsonMaterialWriter,
0040     JsonFormat,
0041     Sequencer,
0042     GenericDetector,
0043 )
0044 from acts.examples.odd import getOpenDataDetectorDirectory
0045 
0046 
def assert_csv_output(csv_path, stem, num_files, size_threshold=100):
    """Assert that a directory holds the expected CSV output files.

    A directory entry counts as a match when its name ends with
    ``stem + ".csv"``. Only direct entries of ``csv_path`` are inspected.

    Args:
        csv_path: Directory to scan (an object providing ``iterdir()``,
            e.g. a ``pathlib.Path``).
        stem: Trailing part of the expected file names, without ``.csv``.
        num_files: Exact number of matching files that must be present.
        size_threshold: Each matching file must be strictly larger than
            this many bytes.

    Raises:
        AssertionError: If the number of matching files differs from
            ``num_files`` or any matching file is not larger than
            ``size_threshold`` bytes.
    """
    # pytest convention: hide this helper's frame from failure tracebacks.
    __tracebackhide__ = True

    suffix = stem + ".csv"
    # Scan the directory once so both checks look at the same set of files.
    matches = sorted(f for f in csv_path.iterdir() if f.name.endswith(suffix))

    assert len(matches) == num_files, (
        f"expected {num_files} '*{suffix}' files in {csv_path}, "
        f"found {len(matches)}: {[f.name for f in matches]}"
    )

    too_small = [f.name for f in matches if not f.stat().st_size > size_threshold]
    assert not too_small, (
        f"files not larger than {size_threshold} bytes in {csv_path}: {too_small}"
    )
0060 
0061 
@pytest.mark.obj
def test_obj_propagation_step_writer(tmp_path, trk_geo, conf_const, basic_prop_seq):
    """Run a propagation sequence with an ObjPropagationStepsWriter attached.

    Checks that the writer cannot be built without arguments, that the
    output directory ends up with as many regular files as configured
    events, and that every entry in it is larger than 1024 bytes.
    """
    # A writer without any configuration must be rejected.
    with pytest.raises(TypeError):
        ObjPropagationStepsWriter()

    out_dir = tmp_path / "obj"
    out_dir.mkdir()

    seq, prop_alg = basic_prop_seq(trk_geo)
    seq.addWriter(
        conf_const(
            ObjPropagationStepsWriter,
            acts.logging.INFO,
            collection=prop_alg.config.outputSummaryCollection,
            outputDir=str(out_dir),
        )
    )
    seq.run()

    regular_files = [entry for entry in out_dir.iterdir() if entry.is_file()]
    assert len(regular_files) == seq.config.events

    for entry in out_dir.iterdir():
        assert entry.stat().st_size > 1024
0085 
0086 
@pytest.mark.csv
def test_csv_particle_writer(tmp_path, conf_const, ptcl_gun):
    """Write particles from the particle gun to CSV and check the output.

    Expects one ``*particle.csv`` file per event in the output directory,
    each larger than 200 bytes.
    """
    seq = Sequencer(numThreads=1, events=10)
    _, converter = ptcl_gun(seq)

    csv_dir = tmp_path / "csv"
    csv_dir.mkdir()

    writer = conf_const(
        CsvParticleWriter,
        acts.logging.INFO,
        inputParticles=converter.config.outputParticles,
        outputStem="particle",
        outputDir=str(csv_dir),
    )
    seq.addWriter(writer)
    seq.run()

    assert_csv_output(csv_dir, "particle", seq.config.events, size_threshold=200)
0109 
0110 
@pytest.mark.root
def test_root_prop_step_writer(
    tmp_path, trk_geo, conf_const, basic_prop_seq, assert_root_hash
):
    """Run a propagation sequence with a RootPropagationStepsWriter attached.

    Checks that the writer cannot be built without arguments, that the
    ROOT file appears only after the run, that it is larger than 50 KiB,
    and that it passes the ``assert_root_hash`` check.
    """
    # A writer without any configuration must be rejected.
    with pytest.raises(TypeError):
        RootPropagationStepsWriter()

    root_file = tmp_path / "prop_steps.root"
    assert not root_file.exists()

    seq, prop_alg = basic_prop_seq(trk_geo)
    seq.addWriter(
        conf_const(
            RootPropagationStepsWriter,
            acts.logging.INFO,
            collection=prop_alg.config.outputSummaryCollection,
            filePath=str(root_file),
        )
    )
    seq.run()

    assert root_file.exists()
    min_size = 50 * 2**10
    assert root_file.stat().st_size > min_size
    assert_root_hash(root_file.name, root_file)
0136 
0137 
@pytest.mark.root
def test_root_particle_writer(tmp_path, conf_const, ptcl_gun, assert_root_hash):
    """Write particles from the particle gun to a ROOT file and check it.

    The file must not exist beforehand, must exist after the run with a
    size above 10 KiB, and must pass the ``assert_root_hash`` check.
    """
    seq = Sequencer(numThreads=1, events=10)
    _, converter = ptcl_gun(seq)

    root_file = tmp_path / "particles.root"
    assert not root_file.exists()

    writer = conf_const(
        RootParticleWriter,
        acts.logging.INFO,
        inputParticles=converter.config.outputParticles,
        filePath=str(root_file),
    )
    seq.addWriter(writer)
    seq.run()

    assert root_file.exists()
    assert root_file.stat().st_size > 10 * 1024
    assert_root_hash(root_file.name, root_file)
0161 
0162 
@pytest.mark.root
def test_root_meas_writer(tmp_path, fatras, trk_geo, assert_root_hash):
    """Write digitized measurements to a ROOT file and check the result.

    The writer is wired to the outputs of the simulation and digitization
    algorithms returned by the ``fatras`` fixture. The file must exceed
    40000 bytes and pass the ``assert_root_hash`` check.
    """
    seq = Sequencer(numThreads=1, events=10)
    _, sim_alg, digi_alg = fatras(seq)

    root_file = tmp_path / "meas.root"
    assert not root_file.exists()

    digi_cfg = digi_alg.config
    writer_cfg = RootMeasurementWriter.Config(
        inputMeasurements=digi_cfg.outputMeasurements,
        inputClusters=digi_cfg.outputClusters,
        inputSimHits=sim_alg.config.outputSimHits,
        inputMeasurementSimHitsMap=digi_cfg.outputMeasurementSimHitsMap,
        filePath=str(root_file),
        surfaceByIdentifier=trk_geo.geoIdSurfaceMap(),
    )
    writer = RootMeasurementWriter(level=acts.logging.INFO, config=writer_cfg)
    seq.addWriter(writer)
    seq.run()

    assert root_file.exists()
    assert root_file.stat().st_size > 40000
    assert_root_hash(root_file.name, root_file)
0186 
0187 
@pytest.mark.root
def test_root_simhits_writer(tmp_path, fatras, conf_const, assert_root_hash):
    """Write simulated hits to a ROOT file and check the result.

    The file must not exist beforehand, must exist after the run with a
    size above 2e4 bytes, and must pass the ``assert_root_hash`` check.
    """
    seq = Sequencer(numThreads=1, events=10)
    _, sim_alg, _ = fatras(seq)

    # NOTE: the file name is passed to assert_root_hash, so it is kept as is.
    root_file = tmp_path / "meas.root"
    assert not root_file.exists()

    writer = conf_const(
        RootSimHitWriter,
        level=acts.logging.INFO,
        inputSimHits=sim_alg.config.outputSimHits,
        filePath=str(root_file),
    )
    seq.addWriter(writer)
    seq.run()

    assert root_file.exists()
    assert root_file.stat().st_size > 2e4
    assert_root_hash(root_file.name, root_file)
0210 
0211 
@pytest.mark.root
def test_root_tracksummary_writer(tmp_path, fatras, conf_const):
    """Check that track summary ROOT files are written with and without truth.

    The truth-tracking Kalman chain is set up on a GenericDetector with a
    constant 2 T field along z; an extra RootTrackSummaryWriter without
    truth inputs is then added. Both output files must exist after the run.
    """
    detector = GenericDetector()
    geometry = detector.trackingGeometry()
    b_field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
    seq = Sequencer(numThreads=1, events=10)

    from truth_tracking_kalman import runTruthTrackingKalman

    repo_root = Path(__file__).parent.parent.parent.parent
    digi_config = Path(
        str(repo_root / "Examples/Configs/generic-digi-smearing-config.json")
    )

    # This also runs the RootTrackSummaryWriter with truth information
    runTruthTrackingKalman(
        geometry,
        b_field,
        digiConfigFile=digi_config,
        outputDir=tmp_path,
        s=seq,
    )

    # Run the RootTrackSummaryWriter without the truth information
    no_truth_file = tmp_path / "track_summary_kf_no_truth.root"
    seq.addWriter(
        conf_const(
            RootTrackSummaryWriter,
            level=acts.logging.INFO,
            inputTracks="tracks",
            filePath=str(no_truth_file),
        )
    )

    seq.run()

    assert (tmp_path / "tracksummary_kf.root").exists()
    assert no_truth_file.exists()
0248 
0249 
@pytest.mark.csv
def test_csv_meas_writer(tmp_path, fatras, trk_geo, conf_const):
    """Write digitized measurements to CSV and check the per-event files.

    Expects one file per event, each larger than 10 bytes, for each of the
    ``measurements``, ``measurement-simhit-map`` and ``cells`` stems.
    """
    seq = Sequencer(numThreads=1, events=10)
    _, _, digi_alg = fatras(seq)

    csv_dir = tmp_path / "csv"
    csv_dir.mkdir()

    digi_cfg = digi_alg.config
    writer = conf_const(
        CsvMeasurementWriter,
        level=acts.logging.INFO,
        inputMeasurements=digi_cfg.outputMeasurements,
        inputClusters=digi_cfg.outputClusters,
        inputMeasurementSimHitsMap=digi_cfg.outputMeasurementSimHitsMap,
        outputDir=str(csv_dir),
    )
    seq.addWriter(writer)
    seq.run()

    for stem in ("measurements", "measurement-simhit-map", "cells"):
        assert_csv_output(csv_dir, stem, seq.config.events, size_threshold=10)
0273 
0274 
@pytest.mark.csv
def test_csv_simhits_writer(tmp_path, fatras, conf_const):
    """Write simulated hits to CSV and check the per-event files.

    Expects one ``*hits.csv`` file per event, each larger than 200 bytes.
    """
    seq = Sequencer(numThreads=1, events=10)
    _, sim_alg, _ = fatras(seq)

    csv_dir = tmp_path / "csv"
    csv_dir.mkdir()

    writer = conf_const(
        CsvSimHitWriter,
        level=acts.logging.INFO,
        inputSimHits=sim_alg.config.outputSimHits,
        outputDir=str(csv_dir),
        outputStem="hits",
    )
    seq.addWriter(writer)
    seq.run()

    assert_csv_output(csv_dir, "hits", seq.config.events, size_threshold=200)
0295 
0296 
@pytest.mark.parametrize(
    "writer",
    [
        RootPropagationStepsWriter,
        RootParticleWriter,
        TrackFinderNTupleWriter,
        RootTrackParameterWriter,
        RootMaterialTrackWriter,
        RootMeasurementWriter,
        RootMaterialWriter,
        RootSimHitWriter,
        RootTrackStatesWriter,
        RootTrackSummaryWriter,
        VertexNTupleWriter,
    ],
)
@pytest.mark.root
def test_root_writer_interface(writer, conf_const, tmp_path, trk_geo):
    """Check the common configuration interface of the ROOT writers.

    Each writer must expose a ``Config`` with ``filePath`` and ``fileMode``
    members, must be constructible with placeholder inputs, and must have
    created its target file once constructed.
    """
    assert hasattr(writer, "Config")
    cfg_cls = writer.Config

    for required in ("filePath", "fileMode"):
        assert hasattr(cfg_cls, required)

    target = tmp_path / "target.root"
    assert not target.exists()

    kwargs = {"level": acts.logging.INFO, "filePath": str(target)}

    # Fill every input member with a dummy collection name; a surface map
    # is supplied from the tracking geometry when the config has that member.
    for member, _ in inspect.getmembers(cfg_cls):
        if member.startswith("input"):
            kwargs[member] = "collection"
        elif member == "surfaceByIdentifier":
            kwargs[member] = trk_geo.geoIdSurfaceMap()

    assert conf_const(writer, **kwargs)

    assert target.exists()
0336 
0337 
@pytest.mark.parametrize(
    "writer",
    [
        CsvParticleWriter,
        CsvMeasurementWriter,
        CsvSimHitWriter,
        CsvTrackWriter,
        CsvTrackingGeometryWriter,
    ],
)
@pytest.mark.csv
def test_csv_writer_interface(writer, conf_const, tmp_path, trk_geo):
    """Check the common configuration interface of the CSV writers.

    Each writer must expose a ``Config`` with an ``outputDir`` member and
    must be constructible with placeholder values for its other members.
    """
    assert hasattr(writer, "Config")
    cfg_cls = writer.Config
    assert hasattr(cfg_cls, "outputDir")

    kwargs = {"level": acts.logging.INFO, "outputDir": str(tmp_path)}

    # Fill members the config declares: dummy collection names for inputs,
    # the tracking geometry fixture, and a fixed output stem.
    for member, _ in inspect.getmembers(cfg_cls):
        if member.startswith("input"):
            kwargs[member] = "collection"
        elif member == "trackingGeometry":
            kwargs[member] = trk_geo
        elif member == "outputStem":
            kwargs[member] = "stem"

    assert conf_const(writer, **kwargs)
0367 
0368 
@pytest.mark.root
@pytest.mark.odd
@pytest.mark.skipif(not dd4hepEnabled, reason="DD4hep not set up")
def test_root_material_writer(tmp_path, assert_root_hash):
    """Write the Open Data Detector material to a ROOT file and check it.

    Checks that constructing the writer already creates a small file
    (between 0 and 500 bytes, exclusive), that ``write`` grows it beyond
    1000 bytes, and that the result passes the ``assert_root_hash`` check.
    """
    from acts.examples.odd import getOpenDataDetector

    with getOpenDataDetector() as detector:
        geometry = detector.trackingGeometry()

        root_file = tmp_path / "material.root"
        assert not root_file.exists()

        writer = RootMaterialWriter(
            level=acts.logging.WARNING, filePath=str(root_file)
        )
        assert root_file.exists()
        assert 0 < root_file.stat().st_size < 500

        writer.write(geometry)

        assert root_file.stat().st_size > 1000
        assert_root_hash(root_file.name, root_file)
0389 
0390 
@pytest.mark.json
@pytest.mark.odd
@pytest.mark.parametrize("fmt", [JsonFormat.Json, JsonFormat.Cbor])
@pytest.mark.skipif(not dd4hepEnabled, reason="DD4hep not set up")
def test_json_material_writer(tmp_path, fmt):
    """Write the Open Data Detector material in the given JSON format.

    The writer receives the output path without extension; the test expects
    the output at that path with the lower-cased format name as suffix.
    The file must not exist before ``write`` and must exceed 1000 bytes
    afterwards.
    """
    from acts.examples.dd4hep import DD4hepDetector

    odd_xml = getOpenDataDetectorDirectory() / "xml/OpenDataDetector.xml"
    detector = DD4hepDetector(xmlFileNames=[str(odd_xml)])
    geometry = detector.trackingGeometry()

    base = tmp_path / "material"
    expected = base.with_suffix("." + fmt.name.lower())
    assert not expected.exists()

    writer = JsonMaterialWriter(
        level=acts.logging.WARNING,
        fileName=str(expected.with_suffix("")),
        writeFormat=fmt,
    )
    # Constructing the writer alone must not create the file.
    assert not expected.exists()

    writer.write(geometry)

    assert expected.stat().st_size > 1000
0414 
0415 
@pytest.mark.csv
def test_csv_multitrajectory_writer(tmp_path):
    """Write tracks from the truth-tracking Kalman chain to CSV.

    The chain is set up on a GenericDetector with a constant 2 T field
    along z. Expects one ``*CKFtracks.csv`` file per event, each larger
    than 20 bytes.
    """
    detector = GenericDetector()
    geometry = detector.trackingGeometry()
    b_field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))

    from truth_tracking_kalman import runTruthTrackingKalman

    seq = Sequencer(numThreads=1, events=10)

    repo_root = Path(__file__).parent.parent.parent.parent
    digi_config = Path(
        str(repo_root / "Examples/Configs/generic-digi-smearing-config.json")
    )
    runTruthTrackingKalman(
        geometry,
        b_field,
        digiConfigFile=digi_config,
        outputDir=tmp_path,
        s=seq,
    )

    out_dir = tmp_path / "csv"
    out_dir.mkdir()

    writer = CsvTrackWriter(
        level=acts.logging.INFO,
        inputTracks="tracks",
        inputMeasurementParticlesMap="measurement_particles_map",
        outputDir=str(out_dir),
    )
    seq.addWriter(writer)
    seq.run()

    assert_csv_output(out_dir, "CKFtracks", seq.config.events, size_threshold=20)
0450 
0451 
@pytest.mark.csv
def test_csv_trackparameter_writer(tmp_path):
    """Write track parameters from the truth-tracking Kalman chain to CSV.

    The chain is set up on a GenericDetector with a constant 2 T field
    along z. Expects one ``*track_parameters.csv`` file per event, each
    larger than 20 bytes.
    """
    detector = GenericDetector()
    geometry = detector.trackingGeometry()
    b_field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))

    from truth_tracking_kalman import runTruthTrackingKalman

    seq = Sequencer(numThreads=1, events=10)

    repo_root = Path(__file__).parent.parent.parent.parent
    digi_config = Path(
        str(repo_root / "Examples/Configs/generic-digi-smearing-config.json")
    )
    runTruthTrackingKalman(
        geometry,
        b_field,
        digiConfigFile=digi_config,
        outputDir=tmp_path,
        s=seq,
    )

    out_dir = tmp_path / "csv"
    out_dir.mkdir()

    writer = CsvTrackParameterWriter(
        level=acts.logging.INFO,
        inputTracks="tracks",
        outputStem="track_parameters",
        outputDir=str(out_dir),
    )
    seq.addWriter(writer)
    seq.run()

    assert_csv_output(
        out_dir, "track_parameters", seq.config.events, size_threshold=20
    )