Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-06-25 07:48:53

0001 #!/usr/bin/env python3
0002 
0003 from pathlib import Path
0004 from typing import Optional
0005 
0006 import acts
0007 import acts.examples
0008 from acts.examples.odd import getOpenDataDetector, getOpenDataDetectorDirectory
0009 from acts.examples.simulation import (
0010     addDigiParticleSelection,
0011     ParticleSelectorConfig,
0012 )
0013 from acts.examples.reconstruction import (
0014     SeedingAlgorithm,
0015     addSeeding,
0016     addKalmanTracks,
0017 )
0018 
# Shorthand for the ACTS unit constants (used below as u.GeV, u.mm, u.T, ...).
u = acts.UnitConstants

# Directory four levels above this file's resolved location (presumably the
# ACTS source-tree root — NOTE(review): confirm against the repo layout).
# Used to locate Examples/Configs/odd-seeding-config.json for seeding.
_srcdir = Path(__file__).resolve().parent.parent.parent.parent
0022 
0023 
def runColliderMLTruthTracking(
    trackingGeometry: acts.TrackingGeometry,
    field: acts.MagneticFieldProvider,
    outputDir: Path,
    inputDir: Path,
    geoIdMapPath: Optional[Path] = None,
    geoIdMapSourcePrefix: str = "gen1",
    geoIdMapTargetPrefix: str = "gen3",
    decorators=None,
    events: int = 10,
    numThreads: int = 1,
    sample: str = "ttbar_pu200",
    s: Optional[acts.examples.Sequencer] = None,
):
    """Set up a ColliderML truth-tracking sequencer and its performance writers.

    Chain: Parquet reader -> ColliderML release-1 input converter ->
    particle selection (pt range (1 GeV, None), measurements (5, None),
    neutrals removed) -> truth-estimated seeding -> Kalman fit ->
    track selection (minMeasurements=7) -> ROOT writers and two
    PythonTrackFinderPerformanceWriters.

    Parameters
    ----------
    trackingGeometry
        Geometry passed to the converter, the seeding and the Kalman fit.
    field
        Magnetic field provider used by the seeding and the Kalman fit.
    outputDir
        Directory for the ROOT outputs; created if missing. ``str`` accepted.
    inputDir
        ColliderML sample root containing
        ``{sample}_particles/data/{sample}_particles`` and
        ``{sample}_tracker_hits/data/{sample}_tracker_hits``. ``str`` accepted.
    geoIdMapPath
        Optional geometry-ID map. When given, it is forwarded to the converter
        together with ``geoIdMapSourcePrefix`` and ``geoIdMapTargetPrefix``.
    decorators
        Iterable of context decorators added to the sequencer (default: none).
    events, numThreads
        Only used when a new sequencer is created (``s`` is None).
    sample
        Sample-name prefix used to build the input directory names.
    s
        Existing sequencer to extend; a new one is created if None.

    Returns
    -------
    (Sequencer, PythonTrackFinderPerformanceWriter, PythonTrackFinderPerformanceWriter)
        ``(s, perf_proto_writer, perf_writer)``: the sequencer, the writer
        evaluating seed-tracks, and the writer evaluating fitted + selected
        tracks. Call ``s.run()``, then read each writer's ``histograms()``.
    """
    from acts.examples.arrow import (
        ColliderMLRelease1InputConverter,
        ParquetReader,
    )
    from acts.examples.root import (
        RootTrackStatesWriter,
        RootTrackSummaryWriter,
        RootTrackFitterPerformanceWriter,
    )
    from acts.examples import PythonTrackFinderPerformanceWriter

    # Accept plain strings as well as Path objects for both directories;
    # inputDir is joined with "/" below, which a str does not support.
    inputDir = Path(inputDir)
    outputDir = Path(outputDir)
    outputDir.mkdir(parents=True, exist_ok=True)

    if s is None:
        s = acts.examples.Sequencer(
            events=events,
            numThreads=numThreads,
            logLevel=acts.logging.INFO,
            outputDir=str(outputDir),
            failOnUnmaskedFpe=False,
        )

    # `decorators` defaults to None rather than a shared mutable [] default.
    for d in decorators or ():
        s.addContextDecorator(d)

    rnd = acts.examples.RandomNumbers(seed=42)

    particles_dir = (
        inputDir / f"{sample}_particles" / "data" / f"{sample}_particles"
    ).resolve()
    hits_dir = (
        inputDir / f"{sample}_tracker_hits" / "data" / f"{sample}_tracker_hits"
    ).resolve()

    s.addReader(
        ParquetReader(
            level=acts.logging.INFO,
            inputDir=str(particles_dir.parent),
            collections={
                "cml_particles": str(particles_dir),
                "cml_hits": str(hits_dir),
            },
            expectedSchemas={
                "cml_particles": ColliderMLRelease1InputConverter.particleSchema(),
                "cml_hits": ColliderMLRelease1InputConverter.hitSchema(),
            },
        )
    )

    converter_kwargs = dict(
        level=acts.logging.INFO,
        inputParticlesTable="cml_particles",
        inputHitsTable="cml_hits",
        outputParticles="particles",
        outputSimHits="simhits",
        outputMeasurements="measurements",
        outputMeasurementSubset="measurement_subset",
        outputMeasSimHitsMap="measurement_simhits_map",
        outputMeasParticlesMap="measurement_particles_map",
        outputParticleMeasurementsMap="particle_measurements_map",
        trackingGeometry=trackingGeometry,
    )
    # The geometry-ID remapping options are only forwarded when a map is given.
    if geoIdMapPath is not None:
        converter_kwargs["geoIdMapPath"] = geoIdMapPath
        converter_kwargs["geoIdMapSourcePrefix"] = geoIdMapSourcePrefix
        converter_kwargs["geoIdMapTargetPrefix"] = geoIdMapTargetPrefix

    s.addAlgorithm(ColliderMLRelease1InputConverter(**converter_kwargs))

    s.addWhiteboardAlias("particles_simulated_selected", "particles")
    addDigiParticleSelection(
        s,
        ParticleSelectorConfig(
            pt=(1.0 * u.GeV, None),
            measurements=(5, None),
            removeNeutral=True,
        ),
    )

    addSeeding(
        s,
        trackingGeometry=trackingGeometry,
        field=field,
        rnd=rnd,
        seedingAlgorithm=SeedingAlgorithm.TruthEstimated,
        selectedParticles="particles_selected",
        geoSelectionConfigFile=_srcdir / "Examples/Configs/odd-seeding-config.json",
        initialSigmas=[
            1 * u.mm,
            1 * u.mm,
            1 * u.degree,
            1 * u.degree,
            0 / u.GeV,
            1 * u.ns,
        ],
        initialSigmaQoverPt=0.1 / u.GeV,
        initialSigmaPtRel=0.1,
        initialVarInflation=[1e0, 1e0, 1e0, 1e0, 1e0, 1e0],
        logLevel=acts.logging.INFO,
    )

    addKalmanTracks(
        s,
        trackingGeometry=trackingGeometry,
        field=field,
        logLevel=acts.logging.INFO,
    )

    s.addAlgorithm(
        acts.examples.TrackSelectorAlgorithm(
            level=acts.logging.INFO,
            inputTracks="tracks",
            outputTracks="selected_tracks",
            selectorConfig=acts.TrackSelector.Config(
                minMeasurements=7,
            ),
        )
    )
    # Downstream consumers reading "tracks" now see the selected tracks.
    s.addWhiteboardAlias("tracks", "selected_tracks")

    s.addWriter(
        RootTrackStatesWriter(
            level=acts.logging.INFO,
            inputTracks="tracks",
            inputParticles="particles_selected",
            inputTrackParticleMatching="track_particle_matching",
            inputSimHits="simhits",
            inputMeasurementSimHitsMap="measurement_simhits_map",
            filePath=str(outputDir / "trackstates_kf.root"),
        )
    )
    s.addWriter(
        RootTrackSummaryWriter(
            level=acts.logging.INFO,
            inputTracks="tracks",
            inputParticles="particles_selected",
            inputTrackParticleMatching="track_particle_matching",
            filePath=str(outputDir / "tracksummary_kf.root"),
        )
    )
    s.addWriter(
        RootTrackFitterPerformanceWriter(
            level=acts.logging.INFO,
            inputTracks="tracks",
            inputParticles="particles_selected",
            inputTrackParticleMatching="track_particle_matching",
            filePath=str(outputDir / "performance_kf.root"),
        )
    )

    def _addPerfWriter(tracks, trackParticleMatching, particleTrackMatching):
        """Add a PythonTrackFinderPerformanceWriter for `tracks` and return it."""
        cfg = PythonTrackFinderPerformanceWriter.Config()
        cfg.inputTracks = tracks
        # Both writers use truth_seeded_particles as the denominator so
        # efficiency is evaluated only on particles that the seeding algorithm
        # actually found. This cleanly separates seeding coverage (determined
        # upstream by the geoSelection config) from KF quality.
        cfg.inputParticles = "truth_seeded_particles"
        cfg.inputTrackParticleMatching = trackParticleMatching
        cfg.inputParticleTrackMatching = particleTrackMatching
        cfg.inputParticleMeasurementsMap = "particle_measurements_map"
        writer = PythonTrackFinderPerformanceWriter(cfg, acts.logging.INFO)
        s.addWriter(writer)
        return writer

    # Proto-track level: how many seeded particles survive as seed-tracks.
    perf_proto_writer = _addPerfWriter(
        "seed-tracks", "seed_particle_matching", "particle_seed_matching"
    )
    # KF track level: how many seeded particles survive KF fitting + selection.
    perf_writer = _addPerfWriter(
        "tracks", "track_particle_matching", "particle_track_matching"
    )

    return s, perf_proto_writer, perf_writer
0223 
0224 
0225 def _serialize_hists(hists):
0226     """Convert PythonTrackFinderPerformanceWriter histograms to a picklable dict.
0227 
0228     The writer exposes three C++ histogram wrapper types, none of which are
0229     picklable. We extract numpy arrays keyed by type tag:
0230       "efficiency"  — Efficiency1/2  (has .accepted / .total BoostHistogram)
0231       "profile"     — ProfileHistogram1 (.histogram is BoostProfileHistogram)
0232       "histogram"   — Histogram1/2/3   (.histogram is plain BoostHistogram)
0233     """
0234     import numpy as np
0235 
0236     out = {}
0237     for key, h in hists.items():
0238         if hasattr(h, "accepted"):
0239             # Efficiency1 / Efficiency2
0240             out[key] = {
0241                 "type": "efficiency",
0242                 "edges": np.asarray(h.total.axis(0).edges),
0243                 "accepted": np.asarray(h.accepted.values()),
0244                 "total": np.asarray(h.total.values()),
0245                 "label": h.total.axis(0).label,
0246             }
0247         elif hasattr(h, "histogram"):
0248             bh = h.histogram
0249             if hasattr(bh, "counts"):
0250                 # ProfileHistogram1 — BoostProfileHistogram
0251                 out[key] = {
0252                     "type": "profile",
0253                     "edges": np.asarray(bh.axis(0).edges),
0254                     "counts": np.asarray(bh.counts()),
0255                     "means": np.asarray(bh.means()),
0256                     "sum_of_deltas_squared": np.asarray(bh.sum_of_deltas_squared()),
0257                     "label": bh.axis(0).label,
0258                 }
0259             else:
0260                 # Histogram1/2/3 — plain BoostHistogram (from FakePlotTool)
0261                 out[key] = {
0262                     "type": "histogram",
0263                     "edges": np.asarray(bh.axis(0).edges),
0264                     "values": np.asarray(bh.values()),
0265                     "label": bh.axis(0).label,
0266                 }
0267     return out
0268 
0269 
if __name__ == "__main__":
    import argparse
    import pickle

    parser = argparse.ArgumentParser(
        description=(
            "ColliderML truth-tracking Kalman filter demo "
            "(default sample: ttbar PU200)."
        )
    )
    parser.add_argument(
        "--input",
        "-i",
        type=Path,
        required=True,
        help="ColliderML sample root (contains {sample}_{particles,tracker_hits}/)",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=Path.cwd() / "colliderml_output",
        help="Output directory (default: ./colliderml_output)",
    )
    parser.add_argument(
        "--events",
        "-n",
        type=int,
        default=10,
        help="Number of events (default: 10)",
    )
    parser.add_argument(
        "-j",
        "--jobs",
        type=int,
        default=1,
        help="Number of parallel threads (default: 1)",
    )
    # The pipeline already takes a `sample` argument; expose it on the CLI
    # with the previously hard-coded value as the default.
    parser.add_argument(
        "--sample",
        default="ttbar_pu200",
        help="Sample name prefix of the input directories (default: ttbar_pu200)",
    )
    args = parser.parse_args()

    detector = getOpenDataDetector()
    trackingGeometry = detector.trackingGeometry()
    decorators = detector.contextDecorators()
    field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))

    s, perf_proto_writer, perf_writer = runColliderMLTruthTracking(
        trackingGeometry=trackingGeometry,
        field=field,
        outputDir=args.output,
        inputDir=args.input,
        decorators=decorators,
        events=args.events,
        numThreads=args.jobs,
        sample=args.sample,
    )
    s.run()

    # The writers' C++ histogram objects are not picklable, so convert them
    # to plain numpy-based dicts before dumping.
    for filename, writer, what in (
        ("histograms.pkl", perf_writer, "histograms"),
        ("histograms_proto.pkl", perf_proto_writer, "proto-track histograms"),
    ):
        out_path = args.output / filename
        with open(out_path, "wb") as f:
            pickle.dump(_serialize_hists(writer.histograms()), f)
        print(f"Saved {what} → {out_path}")