File indexing completed on 2026-06-16 07:48:12
0001 import math
0002 import warnings
0003 from pathlib import Path
0004
0005 import pytest
0006
0007 import acts
0008 import acts.examples
0009 from acts import UnitConstants as u
0010 from acts.examples import Sequencer
0011
0012 from helpers import AssertCollectionExistsAlg, arrowEnabled, isCI, pythia8Enabled
0013
# Module-wide marker: every test here is skipped when `arrowEnabled` is false.
pytestmark = pytest.mark.skipif(not arrowEnabled, reason="Arrow/Parquet bindings not built")
0017
0018
def test_coexist_with_pyarrow():
    """Check that pyarrow and acts.examples.arrow load into one process.

    Regression guard for the linker-isolation design: if libActsPluginArrow
    leaked any arrow symbols, or if pyarrow's bundled libarrow got shadowed
    by a differently-built libarrow, this import sequence would fail with a
    missing-vtable / missing-symbol ImportError.

    Skipped if pyarrow isn't installed.
    """
    # importorskip hands back the imported module, so it doubles as the import.
    pa = pytest.importorskip("pyarrow")

    # Imported purely for the side effect of loading the bindings after
    # pyarrow; the name itself is not used.
    from acts.examples.arrow import ParquetWriter  # noqa: F401

    # pyarrow must still be usable once both libraries are loaded.
    table = pa.table({"x": [1, 2, 3]})
    assert table.num_rows == 3
0036
0037
# Column names that `_assert_particles_parquet` requires in a particles
# dataset; each one is also checked to be a list-typed field.
PARTICLE_FIELDS = {
    # identity and intrinsic properties
    "particle_id", "pdg_id", "mass", "energy", "charge",
    # vertex four-position
    "vx", "vy", "vz", "time",
    # momentum components
    "px", "py", "pz",
    # perigee parameters
    "perigee_d0", "perigee_z0",
    # provenance
    "vertex_primary", "parent_id", "primary",
}
0057
0058
def with_pyarrow(fn):
    """Run ``fn(pyarrow, pyarrow.parquet)`` right away, tolerating a missing pyarrow.

    Some assertions use pyarrow to inspect the written Parquet files.
    Locally, if pyarrow is not available (or incompatible with the linked
    libarrow), those checks are skipped with a warning. In CI mode
    (``isCI`` truthy) the ImportError is re-raised, so missing pyarrow
    becomes a failure.

    Used as a decorator this executes ``fn`` immediately and returns nothing,
    so the decorated name ends up bound to ``None``. An ImportError raised
    from inside ``fn`` is handled exactly like a failed pyarrow import,
    because the call sits inside the same ``try``.
    """
    try:
        import pyarrow as pa_module
        import pyarrow.parquet as pq_module

        fn(pa_module, pq_module)
    except ImportError as exc:
        if not isCI:
            warnings.warn(f"pyarrow not available ({exc}), skipping parquet checks")
            return
        raise
0075
0076
def _assert_particles_parquet(directory: Path, expected_events: int) -> None:
    """Validate a particles dataset directory written as Parquet shards.

    Without pyarrow: the directory must exist, be a directory, and hold at
    least one non-empty ``*.parquet`` file.

    With pyarrow (via ``with_pyarrow``): all shards share one schema, the
    row count summed over shards equals ``expected_events``, the event ids
    are exactly ``0..expected_events-1``, every name in ``PARTICLE_FIELDS``
    is present as a list-typed column, and at least one event carries a
    particle.
    """
    assert directory.exists(), f"{directory} does not exist"
    assert directory.is_dir(), f"{directory} is not a directory"

    shards = sorted(directory.glob("*.parquet"))
    assert shards, f"{directory} contains no parquet shards"
    for shard_path in shards:
        assert shard_path.stat().st_size > 0, f"{shard_path} is empty"

    @with_pyarrow
    def _check(pa, pq):
        row_total = 0
        seen_event_ids: list[int] = []
        reference_schema = None
        for shard_path in shards:
            shard = pq.ParquetFile(str(shard_path))
            row_total += shard.metadata.num_rows
            shard_schema = shard.schema_arrow
            if reference_schema is None:
                # The first shard defines the schema the others must match.
                reference_schema = shard_schema
            else:
                assert shard_schema.equals(
                    reference_schema
                ), f"{shard_path.name}: schema differs from {shards[0].name}"
            event_column = shard.read(columns=["event_id"]).column("event_id")
            seen_event_ids.extend(event_column.to_pylist())

        assert row_total == expected_events, (
            f"{directory.name}: expected {expected_events} events, got {row_total}"
        )

        # Shards may hold events in any order; only the sorted union matters.
        ordered_ids = sorted(seen_event_ids)
        assert ordered_ids == list(range(expected_events)), (
            f"{directory.name}: event ids not contiguous 0..{expected_events-1}: "
            f"{ordered_ids}"
        )

        names = set(reference_schema.names)
        assert "event_id" in names, f"{directory.name}: event_id column missing"
        missing = PARTICLE_FIELDS - names
        assert not missing, f"{directory.name}: missing fields {missing}"

        # Each particle column is expected to be a list-typed field.
        for field_name in PARTICLE_FIELDS:
            ftype = reference_schema.field(field_name).type
            assert pa.types.is_list(
                ftype
            ), f"{directory.name}: field '{field_name}' should be list, got {ftype}"

        # Read the whole dataset and make sure it is not all-empty events.
        dataset_table = pq.ParquetDataset(str(directory)).read()
        particle_ids = dataset_table.column("particle_id").to_pylist()
        counts = [len(row) for row in particle_ids]
        assert any(
            c > 0 for c in counts
        ), f"{directory.name}: all events are empty ({counts})"
0134
0135
def _assert_parquet_reader_config(
    inputDir: Path,
    collections: dict[str, str],
    expectedSchemas: dict[str, "acts.arrow.ArrowSchema"],
    expected_events: int,
) -> None:
    """Build a ParquetReader on `inputDir` and check its advertised event range.

    The reader is constructed with the given `collections` and
    `expectedSchemas`; its ``availableEvents()`` must equal
    ``(0, expected_events)``.
    """
    from acts.examples.arrow import ParquetReader

    expected_range = (0, expected_events)
    reader = ParquetReader(
        level=acts.logging.INFO,
        inputDir=str(inputDir),
        collections=collections,
        expectedSchemas=expectedSchemas,
    )
    assert reader.availableEvents() == expected_range
0152
0153
def _add_arrow_writer(
    s: Sequencer,
    outputDir: Path,
    inputs_to_tables: dict[str, str],
    eventsPerShard: int = 2,
) -> None:
    """Attach particle-to-Arrow converters and a single ParquetWriter to `s`.

    One ArrowParticleOutputConverter is added per (input, table) pair, and
    one ParquetWriter that picks up all the resulting tables.

    @param s: sequencer receiving the algorithms and the writer.
    @param outputDir: directory handed to the ParquetWriter as `outputDir`.
    @param inputs_to_tables: maps whiteboard key of SimParticleContainer to
        the desired whiteboard key / filename basename of the Arrow table.
        These must differ — the same key can't hold both an ACTS container
        and an arrow::Table.
    @param eventsPerShard: forwarded unchanged to the ParquetWriter.
    """
    from acts.arrow import particleSchema
    from acts.examples.arrow import ArrowParticleOutputConverter, ParquetWriter

    for source_key, arrow_key in inputs_to_tables.items():
        assert (
            source_key != arrow_key
        ), "Arrow output key must differ from the SimParticleContainer key"
        converter = ArrowParticleOutputConverter(
            level=acts.logging.INFO,
            inputParticles=source_key,
            outputTable=arrow_key,
        )
        s.addAlgorithm(converter)

    # Every table is written under its own whiteboard key, each with a
    # freshly constructed particle schema.
    arrow_keys = list(inputs_to_tables.values())
    writer = ParquetWriter(
        level=acts.logging.INFO,
        outputDir=str(outputDir),
        collections={key: key for key in arrow_keys},
        expectedSchemas={key: particleSchema() for key in arrow_keys},
        eventsPerShard=eventsPerShard,
    )
    s.addWriter(writer)
0196
0197
def test_particle_gun_generated(tmp_path, ptcl_gun):
    """Particle gun → generated particles → Parquet.

    After the run, the written dataset is validated and a ParquetReader
    pointed at the output directory must report the same number of events.
    """
    from acts.arrow import particleSchema

    nevents = 5
    table_key = "particles_generated_arrow"

    s = Sequencer(numThreads=1, events=nevents)
    ptcl_gun(s)
    _add_arrow_writer(s, tmp_path, {"particles_generated": table_key})
    s.run()

    _assert_particles_parquet(tmp_path / table_key, nevents)
    _assert_parquet_reader_config(
        tmp_path,
        {table_key: table_key},
        {table_key: particleSchema()},
        nevents,
    )
0215
0216
def test_particle_gun_roundtrip(tmp_path, ptcl_gun):
    """Write sharded Parquet, then replay it through a ParquetReader.

    A second Sequencer is driven off the reader; the reader must expose —
    and the sequencer must process — the same number of events that were
    written.
    """
    from acts.arrow import particleSchema
    from acts.examples.arrow import ParquetReader

    nevents = 5
    events_per_shard = 2
    # Ceil division: a trailing partial shard still counts as one shard.
    full_shards, leftover = divmod(nevents, events_per_shard)
    expected_shards = full_shards + (1 if leftover else 0)

    table_key = "particles_generated_arrow"

    # --- write phase ---
    s_write = Sequencer(numThreads=1, events=nevents)
    ptcl_gun(s_write)
    _add_arrow_writer(
        s_write,
        tmp_path,
        {"particles_generated": table_key},
        eventsPerShard=events_per_shard,
    )
    s_write.run()

    out_dir = tmp_path / table_key
    _assert_particles_parquet(out_dir, nevents)

    shards = sorted(out_dir.glob("*.parquet"))
    assert len(shards) == expected_shards, (
        f"expected {expected_shards} shards for {nevents} events at "
        f"{events_per_shard} events/shard, got {len(shards)}: "
        f"{[shard.name for shard in shards]}"
    )

    # --- read phase ---
    reader = ParquetReader(
        level=acts.logging.INFO,
        inputDir=str(tmp_path),
        collections={table_key: table_key},
        expectedSchemas={table_key: particleSchema()},
    )
    assert reader.availableEvents() == (0, nevents)

    # No explicit `events=` here: the sequencer is driven by the reader.
    s_read = Sequencer(numThreads=1)
    s_read.addReader(reader)
    counter = AssertCollectionExistsAlg(
        collections=table_key,
        name="roundtrip_check",
        level=acts.logging.INFO,
    )
    s_read.addAlgorithm(counter)
    s_read.run()

    assert counter.events_seen == nevents, (
        f"reader-driven sequencer processed {counter.events_seen} events, "
        f"expected {nevents}"
    )
0275
0276
def test_particle_gun_fatras(tmp_path, fatras):
    """Particle gun + Fatras → both generated and simulated particles → Parquet."""
    nevents = 5
    particle_tables = {
        "particles_generated": "particles_generated_arrow",
        "particles_simulated": "particles_simulated_arrow",
    }

    s = Sequencer(numThreads=1, events=nevents)
    fatras(s)
    _add_arrow_writer(s, tmp_path, particle_tables)
    s.run()

    # Generated first, then simulated — each dataset is validated separately.
    for table_key in particle_tables.values():
        _assert_particles_parquet(tmp_path / table_key, nevents)
0294
0295
# Column names that `_assert_simhits_parquet` requires in a simhits dataset;
# each one is also checked to be a list-typed field.
SIMHIT_FIELDS = {
    # digitized position
    "x", "y", "z",
    # truth position and time
    "true_x", "true_y", "true_z", "time",
    # associations and geometry identifiers
    "particle_id", "detector", "volume_id", "layer_id", "surface_id",
}
0310
0311
def _add_simhit_arrow_writer(
    s: Sequencer,
    outputDir: Path,
    table_key: str = "simhits_arrow",
    *,
    withClusters: bool,
    eventsPerShard: int = 2,
) -> None:
    """Attach an ArrowSimHitOutputConverter + ParquetWriter for the simhits.

    @param s: sequencer receiving the converter and the writer.
    @param outputDir: directory handed to the ParquetWriter as `outputDir`.
    @param table_key: whiteboard key of the produced Arrow table; also used
        as the writer's collection name.
    @param withClusters: if True, feed the cluster container and the
        sim-hit→measurement map so the digitized x,y,z columns are filled from
        the precomputed Cluster::globalPosition; otherwise leave them unwired so
        those columns come out NaN.
    @param eventsPerShard: forwarded unchanged to the ParquetWriter.
    """
    from acts.arrow import simHitSchema
    from acts.examples.arrow import ArrowSimHitOutputConverter, ParquetWriter

    # The cluster-related inputs are only passed when requested.
    cluster_inputs = (
        {
            "inputClusters": "clusters",
            "inputSimHitMeasurementsMap": "simhit_measurements_map",
        }
        if withClusters
        else {}
    )

    converter = ArrowSimHitOutputConverter(
        level=acts.logging.INFO,
        inputSimHits="simhits",
        inputParticles="particles_simulated",
        outputTable=table_key,
        **cluster_inputs,
    )
    s.addAlgorithm(converter)

    writer = ParquetWriter(
        level=acts.logging.INFO,
        outputDir=str(outputDir),
        collections={table_key: table_key},
        expectedSchemas={table_key: simHitSchema()},
        eventsPerShard=eventsPerShard,
    )
    s.addWriter(writer)
0353 )
0354
0355
0356 def _assert_simhits_parquet(
0357 directory: Path, expected_events: int, *, expect_digitized: bool
0358 ) -> None:
0359 """Verify a simhits dataset directory: schema, row count, and the
0360 digitized-position semantics.
0361
0362 The truth columns (true_x/true_y/true_z/time) are always finite. The
0363 digitized columns (x/y/z) are finite only for hits matched to a cluster
0364 when clusters were wired in; otherwise every value is NaN.
0365 """
0366 assert directory.exists(), f"{directory} does not exist"
0367 assert directory.is_dir(), f"{directory} is not a directory"
0368
0369 fragments = sorted(directory.glob("*.parquet"))
0370 assert fragments, f"{directory} contains no parquet shards"
0371 for f in fragments:
0372 assert f.stat().st_size > 0, f"{f} is empty"
0373
0374 @with_pyarrow
0375 def _check(pa, pq):
0376 first_schema = pq.ParquetFile(str(fragments[0])).schema_arrow
0377 names = {first_schema.field(i).name for i in range(len(first_schema))}
0378 assert "event_id" in names, f"{directory.name}: event_id column missing"
0379 missing = SIMHIT_FIELDS - names
0380 assert not missing, f"{directory.name}: missing fields {missing}"
0381 for field_name in SIMHIT_FIELDS:
0382 ftype = first_schema.field(field_name).type
0383 assert pa.types.is_list(
0384 ftype
0385 ), f"{directory.name}: field '{field_name}' should be list, got {ftype}"
0386
0387 table = pq.ParquetDataset(str(directory)).read()
0388 assert table.num_rows == expected_events, (
0389 f"{directory.name}: expected {expected_events} events, "
0390 f"got {table.num_rows}"
0391 )
0392
0393 x = table.column("x").to_pylist()
0394 y = table.column("y").to_pylist()
0395 z = table.column("z").to_pylist()
0396 tx = table.column("true_x").to_pylist()
0397 ty = table.column("true_y").to_pylist()
0398 tz = table.column("true_z").to_pylist()
0399
0400 total_hits = sum(len(row) for row in tx)
0401 assert total_hits > 0, f"{directory.name}: no sim hits across any event"
0402
0403
0404 for col in (tx, ty, tz):
0405 for row in col:
0406 for v in row:
0407 assert math.isfinite(v), f"{directory.name}: non-finite truth {v}"
0408
0409 n_digitized = 0
0410 for ex, ey, ez, etx, ety, etz in zip(x, y, z, tx, ty, tz):
0411 for xv, yv, zv, txv, tyv, tzv in zip(ex, ey, ez, etx, ety, etz):
0412 if not expect_digitized:
0413 assert (
0414 math.isnan(xv) and math.isnan(yv) and math.isnan(zv)
0415 ), f"{directory.name}: expected NaN digitized pos, got ({xv},{yv},{zv})"
0416 continue
0417
0418
0419
0420 if math.isnan(xv):
0421 continue
0422 n_digitized += 1
0423 assert math.isfinite(yv) and math.isfinite(
0424 zv
0425 ), f"{directory.name}: partial digitized pos ({xv},{yv},{zv})"
0426
0427
0428
0429 r = math.hypot(xv, yv)
0430 assert r < 1500.0 and abs(zv) < 4000.0, (
0431 f"{directory.name}: digitized pos ({xv},{yv},{zv}) outside "
0432 f"detector envelope — units or stale-position regression?"
0433 )
0434
0435
0436
0437
0438 dist = math.sqrt((xv - txv) ** 2 + (yv - tyv) ** 2 + (zv - tzv) ** 2)
0439 assert dist < 250.0, (
0440 f"{directory.name}: digitized pos {dist:.1f} mm from truth "
0441 f"({txv},{tyv},{tzv}) — wrong surface or unit error?"
0442 )
0443
0444 if expect_digitized:
0445 assert n_digitized > 0, (
0446 f"{directory.name}: clusters were wired but no hit got a "
0447 f"digitized position"
0448 )
0449
0450
def test_fatras_simhits_digitized(tmp_path, fatras):
    """Fatras + digitization → ArrowSimHitOutputConverter → Parquet.

    With clusters wired in, the converter reads cluster positions, so the
    matched-hit x,y,z must be the precomputed cluster global positions
    (finite, near truth).
    """
    nevents = 3
    s = Sequencer(numThreads=1, events=nevents)
    fatras(s)
    _add_simhit_arrow_writer(s, tmp_path, withClusters=True)
    s.run()

    simhits_dir = tmp_path / "simhits_arrow"
    _assert_simhits_parquet(simhits_dir, nevents, expect_digitized=True)
0462
0463
def test_fatras_simhits_no_clusters_are_nan(tmp_path, fatras):
    """Simhits written without cluster inputs carry NaN digitized positions.

    Without the cluster container and sim-hit→measurement map wired, the
    digitized x,y,z columns fall back to NaN while truth positions still
    populate.
    """
    nevents = 3
    s = Sequencer(numThreads=1, events=nevents)
    fatras(s)
    _add_simhit_arrow_writer(s, tmp_path, withClusters=False)
    s.run()

    simhits_dir = tmp_path / "simhits_arrow"
    _assert_simhits_parquet(simhits_dir, nevents, expect_digitized=False)
0475
0476
@pytest.mark.skipif(not pythia8Enabled, reason="Pythia8 not built")
def test_pythia8_fatras(tmp_path, rng, trk_geo):
    """Pythia8 ttbar + Fatras → generated AND simulated particles → Parquet."""
    from acts.examples.simulation import (
        addPythia8,
        addGenParticleSelection,
        ParticleSelectorConfig,
    )

    nevents = 3
    s = Sequencer(numThreads=1, events=nevents)

    # Gaussian vertex smearing centred on the origin.
    vertex_spread = acts.Vector4(50 * u.um, 50 * u.um, 150 * u.mm, 20 * u.ns)
    vtxGen = acts.examples.GaussianVertexGenerator(
        stddev=vertex_spread,
        mean=acts.Vector4(0, 0, 0, 0),
    )

    # CSV/ROOT outputs are switched off; only the Parquet output is checked.
    addPythia8(
        s,
        rnd=rng,
        hardProcess=["Top:qqbar2ttbar=on"],
        npileup=2,
        vtxGen=vtxGen,
        outputDirCsv=None,
        outputDirRoot=None,
        logLevel=acts.logging.WARNING,
    )

    # Kinematic/vertex selection of the generated particles; the simulation
    # below reads "particles_generated_selected" (presumably the output of
    # this selection step — confirm against addGenParticleSelection).
    selection = ParticleSelectorConfig(
        rho=(0.0, 24 * u.mm),
        absZ=(0.0, 1.0 * u.m),
        eta=(-3.0, 3.0),
        pt=(150 * u.MeV, None),
    )
    addGenParticleSelection(s, selection)

    field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T))
    simulation = acts.examples.FatrasSimulation(
        level=acts.logging.WARNING,
        inputParticles="particles_generated_selected",
        outputParticles="particles_simulated",
        outputSimHits="simhits",
        randomNumbers=rng,
        trackingGeometry=trk_geo,
        magneticField=field,
        generateHitsOnSensitive=True,
        emScattering=False,
        emEnergyLossIonisation=False,
        emEnergyLossRadiation=False,
        emPhotonConversion=False,
    )
    s.addAlgorithm(simulation)

    particle_tables = {
        "particles_generated": "particles_generated_arrow",
        "particles_simulated": "particles_simulated_arrow",
    }
    _add_arrow_writer(s, tmp_path, particle_tables)
    s.run()

    for table_key in particle_tables.values():
        _assert_particles_parquet(tmp_path / table_key, nevents)
0548
0549
def test_reader_schema_evolution_added_optional_column(tmp_path):
    """Read shards written without an optional column and verify the reader
    materializes it as null.

    Concretely: hand-write track shards using the production track schema
    *minus* the optional `t` field, then drive a sequencer with
    `ParquetReader` configured with the *full* track schema (which has `t`
    as a nullable column). The dataset scanner should project missing
    columns to null per fragment, so the table on the whiteboard must
    carry `t` and it must be all-null.

    This is the canonical added-optional-column schema-evolution case.
    """
    pa = pytest.importorskip("pyarrow")
    pq = pytest.importorskip("pyarrow.parquet")

    from acts.arrow import ArrowTable, trackSchema
    from acts.examples import ReadDataHandle
    from acts.examples.arrow import ParquetReader

    # Convert the production track schema into a pyarrow schema.
    full_track_schema_pa = pa.schema(trackSchema())
    assert "t" in full_track_schema_pa.names, (
        f"trackSchema() unexpectedly lacks the 't' field; this test relies "
        f"on it being present. Schema:\n{full_track_schema_pa}"
    )

    # "Old" schema: the full schema with the `t` field removed.
    t_index = full_track_schema_pa.get_field_index("t")
    old_track_schema = full_track_schema_pa.remove(t_index)

    nevents = 4
    events_per_shard = 2
    collection_dir = tmp_path / "tracks_arrow"
    collection_dir.mkdir()

    # On-disk layout: a non-nullable uint32 `event_id` column followed by
    # the old track fields.
    on_disk_schema = pa.schema(
        [pa.field("event_id", pa.uint32(), nullable=False), *old_track_schema]
    )

    def field_type(name: str) -> "pa.DataType":
        return old_track_schema.field(name).type

    # One track per event; the values are arbitrary placeholders.
    track_values = {
        "d0": [[0.1]],
        "z0": [[0.2]],
        "phi": [[0.3]],
        "theta": [[0.4]],
        "qop": [[0.5]],
        "majority_particle_id": [[1]],
        "hit_ids": [[[1, 2, 3]]],
        "track_id": [[7]],
    }

    def make_event_table(event_id: int) -> "pa.Table":
        columns = {"event_id": pa.array([event_id], type=pa.uint32())}
        for name, values in track_values.items():
            columns[name] = pa.array(values, type=field_type(name))
        return pa.table(columns, schema=on_disk_schema)

    # Write the shards, one single-row table per event.
    for first_event in range(0, nevents, events_per_shard):
        end_event = first_event + events_per_shard
        shard_name = f"tracks_{first_event:06d}-{end_event:06d}.parquet"
        with pq.ParquetWriter(
            str(collection_dir / shard_name), on_disk_schema
        ) as writer:
            for event_id in range(first_event, end_event):
                writer.write_table(make_event_table(event_id))

    # Precondition: none of the written shards contains `t`.
    for shard_path in sorted(collection_dir.glob("*.parquet")):
        on_disk = pq.ParquetFile(str(shard_path)).schema_arrow
        assert "t" not in on_disk.names, (
            f"{shard_path.name}: precondition broken, on-disk shard "
            f"unexpectedly contains 't'. Schema: {on_disk}"
        )

    reader = ParquetReader(
        level=acts.logging.INFO,
        inputDir=str(tmp_path),
        collections={"tracks_arrow": "tracks_arrow"},
        expectedSchemas={"tracks_arrow": trackSchema()},
    )
    assert reader.availableEvents() == (0, nevents)

    class TrackTableCheck(acts.examples.IAlgorithm):
        """Per-event check that `t` exists and is entirely null."""

        # Class-level counter, bumped once per successfully checked event.
        events_seen = 0

        def __init__(self, name="TrackTableCheck"):
            super().__init__(name=name, level=acts.logging.INFO)
            self._handle = ReadDataHandle(self, ArrowTable, "tracks_arrow")
            self._handle.initialize("tracks_arrow")

        def execute(self, ctx):
            table = self._handle(ctx.eventStore).as_table()
            assert "t" in table.column_names, (
                f"event {ctx.eventNumber}: 't' column missing from "
                f"projected table; schema: {table.schema}"
            )
            t_col = table.column("t")
            assert t_col.null_count == t_col.length(), (
                f"event {ctx.eventNumber}: 't' expected all-null, got "
                f"{t_col.length() - t_col.null_count} non-null of "
                f"{t_col.length()} values"
            )
            for required in ("d0", "z0", "phi", "theta", "qop"):
                assert required in table.column_names, (
                    f"event {ctx.eventNumber}: required column '{required}' missing"
                )
            type(self).events_seen += 1
            return acts.examples.ProcessCode.SUCCESS

    # No explicit `events=` here: the sequencer is driven by the reader.
    s = Sequencer(numThreads=1)
    s.addReader(reader)
    s.addAlgorithm(TrackTableCheck())
    s.run()

    assert (
        TrackTableCheck.events_seen == nevents
    ), f"checker saw {TrackTableCheck.events_seen} events, expected {nevents}"
0681
0682
def test_python_alg_writes_arrow_table(tmp_path):
    """Smoke test for the write direction.

    A pure-Python algorithm constructs a per-event pyarrow table, wraps it
    via `ArrowTable.from_arrow`, and writes it onto the WhiteBoard through
    a typed `WriteDataHandle`. A second pure-Python algorithm reads it back
    via `ReadDataHandle`, slurps it into pyarrow via `as_table()`, and
    asserts the values survived the round-trip.

    Exercises: C-Data import (pyarrow → ArrowTable), `WhiteBoardRegistry`
    fromPython (ArrowTable → WhiteBoard storage), C-Data export (ArrowTable
    → pyarrow). End-to-end zero-copy across two libarrow instances.
    """
    pa = pytest.importorskip("pyarrow")

    from acts.arrow import ArrowTable, trackSchema
    from acts.examples import ReadDataHandle, WriteDataHandle

    # Production track schema as a pyarrow schema; column types come from it.
    track_schema_pa = pa.schema(trackSchema())

    def field_type(name):
        return track_schema_pa.field(name).type

    class TrackProducer(acts.examples.IAlgorithm):
        """Builds one row per event in the production track schema and
        writes it onto the whiteboard as an ArrowTable."""

        def __init__(self, key, name="TrackProducer"):
            super().__init__(name=name, level=acts.logging.INFO)
            self._out = WriteDataHandle(self, ArrowTable, key)
            self._out.initialize(key)

        def execute(self, ctx):
            evt = float(ctx.eventNumber)
            # d0/z0 are offset by the event number so the consumer can tell
            # events apart; `t` is written as a single null.
            raw_columns = {
                "d0": [[0.1 + evt]],
                "z0": [[0.2 + evt]],
                "phi": [[0.3]],
                "theta": [[0.4]],
                "qop": [[0.5]],
                "majority_particle_id": [[1]],
                "hit_ids": [[[1, 2, 3]]],
                "track_id": [[7]],
                "t": [None],
            }
            pa_table = pa.table(
                {
                    name: pa.array(values, type=field_type(name))
                    for name, values in raw_columns.items()
                },
                schema=track_schema_pa,
            )
            self._out(ctx, ArrowTable.from_arrow(pa_table))
            return acts.examples.ProcessCode.SUCCESS

    class TrackConsumer(acts.examples.IAlgorithm):
        """Reads the table back, exports through C-Data into pyarrow, and
        asserts the per-event values match what TrackProducer wrote."""

        # Class-level counter, bumped once per successfully checked event.
        events_seen = 0

        def __init__(self, key, name="TrackConsumer"):
            super().__init__(name=name, level=acts.logging.INFO)
            self._in = ReadDataHandle(self, ArrowTable, key)
            self._in.initialize(key)

        def execute(self, ctx):
            evt = float(ctx.eventNumber)
            table = self._in(ctx.eventStore).as_table()
            assert table.num_rows == 1
            d0 = table.column("d0").to_pylist()[0]
            z0 = table.column("z0").to_pylist()[0]
            expected_d0 = [pytest.approx(0.1 + evt)]
            expected_z0 = [pytest.approx(0.2 + evt)]
            assert (
                d0 == expected_d0
            ), f"event {ctx.eventNumber}: d0 round-trip mismatch: {d0}"
            assert (
                z0 == expected_z0
            ), f"event {ctx.eventNumber}: z0 round-trip mismatch: {z0}"
            type(self).events_seen += 1
            return acts.examples.ProcessCode.SUCCESS

    nevents = 3
    whiteboard_key = "produced_tracks_arrow"
    s = Sequencer(numThreads=1, events=nevents)
    s.addAlgorithm(TrackProducer(key=whiteboard_key))
    s.addAlgorithm(TrackConsumer(key=whiteboard_key))
    s.run()

    assert (
        TrackConsumer.events_seen == nevents
    ), f"consumer saw {TrackConsumer.events_seen} events, expected {nevents}"
0772
0773
def test_writer_rejects_missing_schema(tmp_path):
    """ParquetWriter requires an expected schema for every collection.

    Constructing a writer whose `expectedSchemas` lacks an entry for a
    configured collection must fail at config time, not at run time.
    """
    from acts.examples.arrow import ParquetWriter

    writer_config = dict(
        level=acts.logging.INFO,
        outputDir=str(tmp_path),
        collections={"some_collection": "some_collection"},
        expectedSchemas={},
    )
    with pytest.raises(ValueError, match="no expected schema"):
        ParquetWriter(**writer_config)
0787
0788
def test_writer_aborts_on_per_event_schema_mismatch(tmp_path):
    """A schema mismatch at write time must abort the sequencer.

    A pure-Python algorithm produces a table whose schema doesn't match
    the writer's declared expectedSchemas. The writer must abort the
    sequencer with a clear message rather than silently writing garbage.
    """
    pa = pytest.importorskip("pyarrow")

    from acts.arrow import ArrowTable, particleSchema
    from acts.examples import WriteDataHandle
    from acts.examples.arrow import ParquetWriter

    class WrongShapeProducer(acts.examples.IAlgorithm):
        """Emits a one-column int32 table unrelated to the particle schema."""

        def __init__(self, key, name="WrongShapeProducer"):
            super().__init__(name=name, level=acts.logging.INFO)
            self._out = WriteDataHandle(self, ArrowTable, key)
            self._out.initialize(key)

        def execute(self, ctx):
            mismatched = pa.table({"unexpected": pa.array([1], type=pa.int32())})
            self._out(ctx, ArrowTable.from_arrow(mismatched))
            return acts.examples.ProcessCode.SUCCESS

    whiteboard_key = "bogus_arrow"
    s = Sequencer(numThreads=1, events=1)
    s.addAlgorithm(WrongShapeProducer(key=whiteboard_key))

    # The writer is told to expect the particle schema for this collection.
    writer = ParquetWriter(
        level=acts.logging.INFO,
        outputDir=str(tmp_path),
        collections={whiteboard_key: whiteboard_key},
        expectedSchemas={whiteboard_key: particleSchema()},
    )
    s.addWriter(writer)

    # Only the exception type is pinned here; the message is not matched.
    with pytest.raises(RuntimeError):
        s.run()