Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-03-28 07:46:23

0001 import pytest
0002 
0003 import acts
0004 from acts.examples import Sequencer
0005 
0006 from helpers import failure_threshold
0007 
0008 u = acts.UnitConstants
0009 
0010 
0011 def assert_entries(root_file, tree_name, exp=None, non_zero=False):
0012     __tracebackhide__ = True
0013     import ROOT
0014 
0015     ROOT.PyConfig.IgnoreCommandLineOptions = True
0016     ROOT.gROOT.SetBatch(True)
0017 
0018     rf = ROOT.TFile.Open(str(root_file))
0019     keys = [k.GetName() for k in rf.GetListOfKeys()]
0020     assert tree_name in keys
0021     print("Entries:", rf.Get(tree_name).GetEntries())
0022     if non_zero:
0023         assert rf.Get(tree_name).GetEntries() > 0, f"{root_file}:{tree_name}"
0024     if exp is not None:
0025         assert rf.Get(tree_name).GetEntries() == exp, f"{root_file}:{tree_name}"
0026 
0027 
0028 def assert_has_entries(root_file, tree_name):
0029     __tracebackhide__ = True
0030     assert_entries(root_file, tree_name, non_zero=True)
0031 
0032 
0033 @pytest.mark.parametrize("revFiltMomThresh", [0 * u.GeV, 1 * u.TeV])
0034 def test_truth_tracking_kalman(
0035     tmp_path, assert_root_hash, revFiltMomThresh, detector_config
0036 ):
0037     root_files = [
0038         ("trackstates_kf.root", "trackstates", 19),
0039         ("tracksummary_kf.root", "tracksummary", 10),
0040         ("performance_kf.root", None, -1),
0041     ]
0042 
0043     for fn, _, _ in root_files:
0044         fp = tmp_path / fn
0045         assert not fp.exists()
0046 
0047     with detector_config.detector:
0048         from truth_tracking_kalman import runTruthTrackingKalman
0049 
0050         field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
0051 
0052         seq = Sequencer(events=10, numThreads=1)
0053 
0054         runTruthTrackingKalman(
0055             trackingGeometry=detector_config.trackingGeometry,
0056             field=field,
0057             digiConfigFile=detector_config.digiConfigFile,
0058             outputDir=tmp_path,
0059             reverseFilteringMomThreshold=revFiltMomThresh,
0060             s=seq,
0061         )
0062 
0063         seq.run()
0064 
0065     for fn, tn, ee in root_files:
0066         fp = tmp_path / fn
0067         assert fp.exists()
0068         assert fp.stat().st_size > 1024
0069         if tn is not None:
0070             assert_has_entries(fp, tn)
0071             assert_root_hash(fn, fp)
0072 
0073     import ROOT
0074 
0075     ROOT.PyConfig.IgnoreCommandLineOptions = True
0076     ROOT.gROOT.SetBatch(True)
0077     rf = ROOT.TFile.Open(str(tmp_path / "tracksummary_kf.root"))
0078     keys = [k.GetName() for k in rf.GetListOfKeys()]
0079     assert "tracksummary" in keys
0080     for entry in rf.Get("tracksummary"):
0081         assert entry.hasFittedParams
0082 
0083 
0084 def test_python_track_access(generic_detector_config, tmp_path):
0085     with generic_detector_config.detector:
0086         from truth_tracking_kalman import runTruthTrackingKalman
0087 
0088         field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
0089 
0090         seq = Sequencer(events=100, numThreads=-1)
0091 
0092         runTruthTrackingKalman(
0093             trackingGeometry=generic_detector_config.trackingGeometry,
0094             field=field,
0095             digiConfigFile=generic_detector_config.digiConfigFile,
0096             outputDir=tmp_path,
0097             numParticles=100,
0098             s=seq,
0099         )
0100 
0101         import hist
0102         import math
0103         import numpy as np
0104 
0105         class TrackAccess(acts.examples.IAlgorithm):
0106             def __init__(self):
0107                 super().__init__("TrackAccess", acts.logging.INFO)
0108 
0109                 self.tracks = acts.examples.ReadDataHandle(
0110                     self, acts.examples.ConstTrackContainer, "InputTracks"
0111                 )
0112                 self.tracks.initialize("selected-tracks")
0113 
0114                 self.protoTracks = acts.examples.WriteDataHandle(
0115                     self, acts.examples.ProtoTrackContainer, "ProtoTracks"
0116                 )
0117                 self.protoTracks.initialize("proto-tracks")
0118 
0119                 self.hists = {}
0120                 self.hists["d0"] = hist.Hist(
0121                     hist.axis.Regular(10, -0.1, 0.1, name="d0"),
0122                     label="d0",
0123                 )
0124                 self.hists["z0"] = hist.Hist(
0125                     hist.axis.Regular(10, -1, 1, name="z0"),
0126                     label="z0",
0127                 )
0128                 self.hists["phi"] = hist.Hist(
0129                     hist.axis.Regular(10, math.pi, -math.pi, name="phi"),
0130                     label="phi",
0131                 )
0132                 self.hists["theta"] = hist.Hist(
0133                     hist.axis.Regular(10, 0, math.pi, name="theta"),
0134                     label="theta",
0135                 )
0136                 self.hists["eta"] = hist.Hist(
0137                     hist.axis.Regular(10, -5, 5, name="eta"),
0138                     label="eta",
0139                 )
0140                 self.hists["qop"] = hist.Hist(
0141                     hist.axis.Regular(10, -1, 1, name="qop"),
0142                     label="qop",
0143                 )
0144                 self.hists["nTracks"] = hist.Hist(
0145                     hist.axis.Regular(10, 9800, 10200, name="nTracks"),
0146                     label="nTracks",
0147                 )
0148 
0149             def execute(self, context):
0150                 self.logger.info("Track access")
0151 
0152                 tracks = self.tracks(context.eventStore)
0153                 assert isinstance(tracks, acts.examples.ConstTrackContainer)
0154 
0155                 self.logger.info("Tracks: {}", len(tracks))
0156 
0157                 self.hists["d0"].fill(tracks.parameters[:, 0])
0158                 self.hists["z0"].fill(tracks.parameters[:, 1])
0159                 self.hists["phi"].fill(tracks.parameters[:, 2])
0160                 self.hists["theta"].fill(tracks.parameters[:, 3])
0161                 self.hists["eta"].fill(-np.log(np.tan(tracks.parameters[:, 3] / 2)))
0162                 self.hists["qop"].fill(tracks.parameters[:, 4])
0163                 self.hists["nTracks"].fill(len(tracks))
0164                 # for track in tracks:
0165                 #     self.logger.info("Track: {}", track)
0166                 #     self.logger.info("Track index: {}", track.index)
0167                 #     self.logger.info("Track tip index: {}", track.tipIndex)
0168                 #     self.logger.info("Track stem index: {}", track.stemIndex)
0169                 #     self.logger.info(
0170                 #         "Track reference surface: {}", track.referenceSurface
0171                 #     )
0172                 #     self.logger.info(
0173                 #         "Track has reference surface: {}", track.hasReferenceSurface
0174                 #     )
0175                 #     self.logger.info("Track parameters: {}", track.parameters)
0176                 #     self.logger.info("Track covariance: {}", track.covariance)
0177                 #     self.logger.info(
0178                 #         "Track particle hypothesis: {}", track.particleHypothesis
0179                 #     )
0180 
0181                 wb = context.eventStore
0182                 assert not wb.exists("proto-tracks")
0183                 myProtoTracks = acts.examples.ProtoTrackContainer()
0184                 self.protoTracks(context, myProtoTracks)
0185                 assert wb.exists("proto-tracks")
0186 
0187                 return acts.examples.ProcessCode.SUCCESS
0188 
0189             def finalize(self):
0190                 for h in self.hists.values():
0191                     print(h.label)
0192                     print(h)
0193                 return acts.examples.ProcessCode.SUCCESS
0194 
0195         seq.addAlgorithm(TrackAccess())
0196 
0197         with acts.logging.ScopedFailureThreshold(acts.logging.ERROR):
0198             seq.run()
0199 
0200 
0201 def test_python_track_state_access(generic_detector_config, tmp_path):
0202     with generic_detector_config.detector:
0203         from truth_tracking_kalman import runTruthTrackingKalman
0204 
0205         field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
0206 
0207         seq = Sequencer(events=10, numThreads=1)
0208 
0209         runTruthTrackingKalman(
0210             trackingGeometry=generic_detector_config.trackingGeometry,
0211             field=field,
0212             digiConfigFile=generic_detector_config.digiConfigFile,
0213             outputDir=tmp_path,
0214             numParticles=10,
0215             linkForward=True,
0216             s=seq,
0217         )
0218 
0219         class TrackStateAccess(acts.examples.IAlgorithm):
0220             def __init__(self):
0221                 super().__init__("TrackStateAccess", acts.logging.INFO)
0222 
0223                 self.tracks = acts.examples.ReadDataHandle(
0224                     self, acts.examples.ConstTrackContainer, "InputTracks"
0225                 )
0226                 self.tracks.initialize("selected-tracks")
0227 
0228             def execute(self, context):
0229                 tracks = self.tracks(context.eventStore)
0230 
0231                 import numpy as np
0232 
0233                 for track in tracks:
0234                     params = track.parameters
0235                     assert isinstance(params, acts.BoundVector)
0236                     cov = track.covariance
0237                     assert isinstance(cov, acts.BoundMatrix)
0238                     # parameters from per-proxy accessor must match the
0239                     # bulk numpy array at the same index
0240                     assert np.allclose(
0241                         [params[i] for i in range(6)],
0242                         tracks.parameters[track.index],
0243                     )
0244 
0245                     n_meas_from_summary = track.nMeasurements
0246                     n_meas_counted = 0
0247                     n_holes_counted = 0
0248 
0249                     for state in track.trackStatesReversed:
0250                         assert isinstance(state, acts.examples.ConstTrackStateProxy)
0251 
0252                         flags = state.typeFlags
0253                         # every state must be at least one of the known types
0254                         assert (
0255                             flags.isMeasurement
0256                             or flags.isOutlier
0257                             or flags.isHole
0258                             or flags.hasMaterial
0259                         )
0260 
0261                         if flags.isMeasurement:
0262                             n_meas_counted += 1
0263 
0264                         if flags.isHole:
0265                             n_holes_counted += 1
0266 
0267                         if state.hasPredicted:
0268                             pred = state.predicted
0269                             assert isinstance(pred, acts.BoundVector)
0270 
0271                         if state.hasFiltered:
0272                             filt = state.filtered
0273                             assert isinstance(filt, acts.BoundVector)
0274 
0275                         if state.hasSmoothed:
0276                             smth = state.smoothed
0277                             assert isinstance(smth, acts.BoundVector)
0278 
0279                     assert n_meas_counted == n_meas_from_summary
0280                     assert n_holes_counted == track.nHoles
0281 
0282                     assert track.isForwardLinked
0283 
0284                     rev_predicted = [
0285                         state.predicted
0286                         for state in track.trackStatesReversed
0287                         if state.hasPredicted
0288                     ]
0289                     fwd_predicted = [
0290                         state.predicted
0291                         for state in track.trackStates
0292                         if state.hasPredicted
0293                     ]
0294                     assert len(fwd_predicted) == len(rev_predicted)
0295                     for fwd, rev in zip(fwd_predicted, reversed(rev_predicted)):
0296                         assert all(fwd[i] == pytest.approx(rev[i]) for i in range(6))
0297 
0298                 return acts.examples.ProcessCode.SUCCESS
0299 
0300         seq.addAlgorithm(TrackStateAccess())
0301 
0302         with acts.logging.ScopedFailureThreshold(acts.logging.ERROR):
0303             seq.run()
0304 
0305 
0306 @pytest.mark.skip(reason="Needs updating after converter became unnecessary")
0307 def test_python_space_point_access(generic_detector_config, tmp_path):
0308     from acts.examples.simulation import (
0309         addParticleGun,
0310         ParticleConfig,
0311         EtaConfig,
0312         PhiConfig,
0313         MomentumConfig,
0314         addFatras,
0315         addDigitization,
0316         ParticleSelectorConfig,
0317         addDigiParticleSelection,
0318     )
0319 
0320     from acts.examples.reconstruction import (
0321         addSeeding,
0322         SeedingAlgorithm,
0323     )
0324 
0325     with generic_detector_config.detector:
0326         s = acts.examples.Sequencer(
0327             events=100, numThreads=1, logLevel=acts.logging.INFO
0328         )
0329         trackingGeometry = generic_detector_config.trackingGeometry
0330         field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
0331         digiConfigFile = generic_detector_config.digiConfigFile
0332 
0333         rnd = acts.examples.RandomNumbers(seed=42)
0334 
0335         logger = acts.getDefaultLogger("Truth tracking example", acts.logging.INFO)
0336 
0337         addParticleGun(
0338             s,
0339             ParticleConfig(num=10, pdg=acts.PdgParticle.eMuon, randomizeCharge=True),
0340             EtaConfig(-3.0, 3.0, uniform=True),
0341             MomentumConfig(1.0 * u.GeV, 100.0 * u.GeV, transverse=True),
0342             PhiConfig(0.0, 360.0 * u.degree),
0343             vtxGen=acts.examples.GaussianVertexGenerator(
0344                 mean=acts.Vector4(0, 0, 0, 0),
0345                 stddev=acts.Vector4(0, 0, 0, 0),
0346             ),
0347             multiplicity=1,
0348             rnd=rnd,
0349         )
0350 
0351         addFatras(
0352             s,
0353             trackingGeometry,
0354             field,
0355             rnd=rnd,
0356             enableInteractions=True,
0357         )
0358 
0359         addDigitization(
0360             s,
0361             trackingGeometry,
0362             field,
0363             digiConfigFile=digiConfigFile,
0364             rnd=rnd,
0365         )
0366 
0367         addDigiParticleSelection(
0368             s,
0369             ParticleSelectorConfig(
0370                 pt=(0.9 * u.GeV, None),
0371                 measurements=(7, None),
0372                 removeNeutral=True,
0373                 removeSecondaries=True,
0374             ),
0375         )
0376 
0377         addSeeding(
0378             s,
0379             trackingGeometry,
0380             field,
0381             rnd=rnd,
0382             inputParticles="particles_generated",
0383             seedingAlgorithm=SeedingAlgorithm.GridTriplet,
0384             geoSelectionConfigFile=generic_detector_config.geometrySelection,
0385             particleHypothesis=acts.ParticleHypothesis.muon,
0386             initialSigmas=[
0387                 1 * u.mm,
0388                 1 * u.mm,
0389                 1 * u.degree,
0390                 1 * u.degree,
0391                 0 / u.GeV,
0392                 1 * u.ns,
0393             ],
0394             initialSigmaQoverPt=0.1 / u.GeV,
0395             initialSigmaPtRel=0.1,
0396             initialVarInflation=[1e0, 1e0, 1e0, 1e0, 1e0, 1e0],
0397         )
0398 
0399         spConverter = acts.examples.SpacePointConverter(
0400             inputSpacePoints="spacepoints",
0401             outputSpacePoints="spacepoints2",
0402             logger=logger,
0403         )
0404         s.addAlgorithm(spConverter)
0405 
0406         class SpacePointAccess(acts.examples.IAlgorithm):
0407 
0408             def __init__(self):
0409                 super().__init__("SpacePointAccess", acts.logging.INFO)
0410 
0411                 self.spacePoints = acts.examples.ReadDataHandle(
0412                     self, acts.SpacePointContainer2, "InputSpacePoints"
0413                 )
0414                 self.spacePoints.initialize("spacepoints2")
0415 
0416             def execute(self, context: acts.examples.AlgorithmContext):
0417                 self.logger.info("Space point access")
0418                 spacePoints: acts.SpacePointContainer2 = self.spacePoints(
0419                     context.eventStore
0420                 )
0421 
0422                 for sp in spacePoints:
0423                     self.logger.info("Space point: {}", sp.x)
0424                     self.logger.info("Space point: {}", sp.y)
0425                     self.logger.info("Space point: {}", sp.z)
0426                     self.logger.info("Space point: {}", sp.r)
0427                     self.logger.info("Space point: {}", sp.xy)
0428                     self.logger.info("Space point: {}", sp.xy[:, 0])
0429                     self.logger.info("Space point: {}", sp.xy[:, 1])
0430                     self.logger.info("Space point: {}", sp.xy[:, 2])
0431                     self.logger.info("Space point: {}", sp.xy[:, 3])
0432 
0433                 return acts.examples.ProcessCode.SUCCESS
0434 
0435         s.addAlgorithm(SpacePointAccess())
0436 
0437         s.run()
0438 
0439 
0440 def test_truth_tracking_gsf(tmp_path, assert_root_hash, detector_config):
0441     from truth_tracking_gsf import runTruthTrackingGsf
0442 
0443     field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
0444 
0445     seq = Sequencer(
0446         events=10,
0447         numThreads=1,
0448     )
0449 
0450     root_files = [
0451         ("trackstates_gsf.root", "trackstates"),
0452         ("tracksummary_gsf.root", "tracksummary"),
0453     ]
0454 
0455     for fn, _ in root_files:
0456         fp = tmp_path / fn
0457         assert not fp.exists()
0458 
0459     with detector_config.detector:
0460         runTruthTrackingGsf(
0461             trackingGeometry=detector_config.trackingGeometry,
0462             decorators=detector_config.decorators,
0463             field=field,
0464             digiConfigFile=detector_config.digiConfigFile,
0465             outputDir=tmp_path,
0466             s=seq,
0467         )
0468 
0469         # See https://github.com/acts-project/acts/issues/1300
0470         with failure_threshold(acts.logging.FATAL):
0471             seq.run()
0472 
0473     for fn, tn in root_files:
0474         fp = tmp_path / fn
0475         assert fp.exists()
0476         assert fp.stat().st_size > 1024
0477         if tn is not None:
0478             assert_root_hash(fn, fp)
0479 
0480 
0481 def test_refitting(tmp_path, detector_config, assert_root_hash):
0482     from truth_tracking_gsf_refitting import runRefittingGsf
0483 
0484     field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
0485 
0486     seq = Sequencer(
0487         events=10,
0488         numThreads=1,
0489     )
0490 
0491     with detector_config.detector:
0492         # Only check if it runs without errors right known
0493         # Changes in fitter behaviour should be caught by other tests
0494         runRefittingGsf(
0495             trackingGeometry=detector_config.trackingGeometry,
0496             field=field,
0497             digiConfigFile=detector_config.digiConfigFile,
0498             outputDir=tmp_path,
0499             s=seq,
0500         ).run()
0501 
0502     root_files = [
0503         ("trackstates_gsf_refit.root", "trackstates"),
0504         ("tracksummary_gsf_refit.root", "tracksummary"),
0505     ]
0506 
0507     for fn, tn in root_files:
0508         fp = tmp_path / fn
0509         assert fp.exists()
0510         assert fp.stat().st_size > 1024
0511         if tn is not None:
0512             assert_root_hash(fn, fp)