Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-11 07:54:09

0001 import json
0002 import os
0003 import tempfile
0004 import pytest
0005 from click.testing import CliRunner
0006 from pyrobird.cli.merge import merge, is_valid_dex_file, merge_dex_files, merge_events, merge_event_groups, create_merged_header
0007 
# Sample Firebird DEX JSON data for testing.
# SAMPLE_DEX_1 carries a full header (type/version/origin) and two events:
# "event_0" with one hit group and "event_1" with one (empty) track group.
SAMPLE_DEX_1 = {
    "type": "firebird-dex-json",
    "version": "0.02",
    "origin": {
        "file": "sample1.root",
        "entries_count": 2
    },
    "events": [
        {
            "id": "event_0",
            "groups": [
                {
                    "name": "BarrelVertexHits",
                    "type": "BoxTrackerHit",
                    "origin": {"type": "edm4eic::TrackerHitData"},
                    "hits": [
                        {
                            "pos": [1, 2, 3],
                            "dim": [0.1, 0.1, 0.1],
                            "t": [0, 0],
                            "ed": [0.001, 0]
                        }
                    ]
                }
            ]
        },
        {
            "id": "event_1",
            "groups": [
                {
                    "name": "BarrelTracks",
                    "type": "TrackerLinePointTrajectory",
                    "origin": {"type":"edm4eic::TrackSegmentData"},
                    "lines": []
                }
            ]
        }
    ]
}
0048 
# Second sample: shares the event id "event_0" with SAMPLE_DEX_1 (with a
# differently named group) and adds "event_2", which SAMPLE_DEX_1 lacks.
SAMPLE_DEX_2 = {
    "type": "firebird-dex-json",
    "version": "0.01",
    "origin": {
        "file": "sample2.root",
        "entries_count": 2
    },
    "events": [
        {
            "id": "event_0",
            "groups": [
                {
                    "name": "EndcapVertexHits",
                    "type": "BoxTrackerHit",
                    # NOTE(review): origin "type" is a list here but a plain string in
                    # every other sample group - confirm this is intentional.
                    "origin": {"type": ["edm4eic::TrackerHitData"]},
                    "hits": [
                        {
                            "pos": [10, 20, 30],
                            "dim": [0.2, 0.2, 0.2],
                            "t": [1, 0],
                            "ed": [0.002, 0]
                        }
                    ]
                }
            ]
        },
        {
            "id": "event_2",
            "groups": [
                {
                    "name": "EndcapTracks",
                    "type": "TrackerLinePointTrajectory",
                    "origin": {"type": "edm4eic::TrackSegmentData"},
                    "lines": []
                }
            ]
        }
    ]
}
0088 
# Sample with conflicting group names.
# Its single event reuses both the event id ("event_0") and the group name
# ("BarrelVertexHits") found in SAMPLE_DEX_1, but with different hit values.
# Unlike the other samples it has no top-level "type" or "origin" keys.
SAMPLE_DEX_CONFLICT = {
    "version": "0.03",
    "events": [
        {
            "id": "event_0",
            "groups": [
                {
                    "name": "BarrelVertexHits",  # Same name as in SAMPLE_DEX_1
                    "type": "BoxTrackerHit",
                    "origin": {"type": "edm4eic::TrackerHitData"},
                    "hits": [
                        {
                            "pos": [100, 200, 300],
                            "dim": [1, 1, 1],
                            "t": [10, 1],
                            "ed": [0.01, 0.001]
                        }
                    ]
                }
            ]
        }
    ]
}
0113 
# Invalid DEX format (missing required fields).
# The only event has an "id" but no "groups" key; test_invalid_file expects
# the merge command to reject a file with this content.
INVALID_DEX = {
    "version": "0.01",
    "events": [
        {
            "id": "event_0",
            # Missing "groups" field
        }
    ]
}
0124 
0125 
@pytest.fixture
def temp_dex_files():
    """Write the sample DEX payloads into a temporary directory and yield their paths.

    Yields a dict with the keys "file1", "file2", "conflict" and "invalid"
    (JSON files written from the module-level samples) plus "output", a path
    in the same directory that is reserved for merge results but not created.
    The directory and everything in it is removed when the fixture finishes.
    """
    samples = {
        "file1": ("sample1.firebird.json", SAMPLE_DEX_1),
        "file2": ("sample2.firebird.json", SAMPLE_DEX_2),
        "conflict": ("conflict.firebird.json", SAMPLE_DEX_CONFLICT),
        "invalid": ("invalid.firebird.json", INVALID_DEX),
    }

    with tempfile.TemporaryDirectory() as workdir:
        paths = {}
        for key, (filename, payload) in samples.items():
            paths[key] = os.path.join(workdir, filename)
            with open(paths[key], 'w') as handle:
                json.dump(payload, handle)

        # Only the path is handed out; the tests decide whether to write to it.
        paths["output"] = os.path.join(workdir, "output.firebird.json")

        yield paths
0155 
0156 
def test_basic_merge(temp_dex_files):
    """Merge two compatible DEX files through the CLI and inspect the written file."""
    out_path = temp_dex_files["output"]
    cli_args = [temp_dex_files["file1"], temp_dex_files["file2"], "-o", out_path]
    result = CliRunner().invoke(merge, cli_args)

    assert result.exit_code == 0

    # The command must have produced the requested output file
    assert os.path.exists(out_path)

    with open(out_path, 'r') as f:
        merged_data = json.load(f)

    # Header: type and version
    assert "type" in merged_data
    assert merged_data["type"] == "firebird-dex-json"
    assert "version" in merged_data
    assert merged_data["version"] == "0.04"  # Should use the latest version

    # Header: origin metadata
    assert "origin" in merged_data
    origin = merged_data["origin"]
    assert "merged_from" in origin
    assert len(origin["merged_from"]) == 2
    assert "entries_count" in origin
    assert origin["entries_count"] == 4  # 2 declared by each input file

    # event_0 appears in both inputs, so 2 + 2 events collapse to 3
    events = merged_data["events"]
    assert len(events) == 3

    event_ids = [event["id"] for event in events]
    for expected_id in ("event_0", "event_1", "event_2"):
        assert expected_id in event_ids

    # The shared event must carry the groups contributed by both files
    shared_events = [event for event in events if event["id"] == "event_0"]
    for event in shared_events:
        group_names = [group["name"] for group in event["groups"]]
        assert "BarrelVertexHits" in group_names
        assert "EndcapVertexHits" in group_names
        assert len(event["groups"]) == 2
0201 
0202 
def test_reset_id_flag(temp_dex_files):
    """With --reset-id the events are expected to be paired by position, not by id."""
    cli_args = [
        "--reset-id",
        temp_dex_files["file1"],
        temp_dex_files["file2"],
        "-o",
        temp_dex_files["output"],
    ]
    result = CliRunner().invoke(merge, cli_args)

    assert result.exit_code == 0

    with open(temp_dex_files["output"], 'r') as f:
        merged_data = json.load(f)

    # File1.event0 + File2.event0, File1.event1 + File2.event1
    events = merged_data["events"]
    assert len(events) == 2

    # Groups expected at each position: one from file1 and one from file2
    expected_by_position = [
        ("BarrelVertexHits", "EndcapVertexHits"),
        ("BarrelTracks", "EndcapTracks"),
    ]
    for position, expected_names in enumerate(expected_by_position):
        group_names = [group["name"] for group in events[position]["groups"]]
        for name in expected_names:
            assert name in group_names
0228 
0229 
def test_conflict_detection(temp_dex_files):
    """Merging files with a duplicated group name must raise a ValueError.

    SAMPLE_DEX_1 and SAMPLE_DEX_CONFLICT both define "BarrelVertexHits" in
    "event_0"; without --ignore or --overwrite this test expects the merge
    to fail with a message that mentions the duplicate.
    """
    runner = CliRunner()

    # catch_exceptions=False makes the runner re-raise instead of storing the
    # exception on the result, so pytest.raises can observe it directly.
    with pytest.raises(ValueError) as exc_info:
        runner.invoke(merge, [temp_dex_files["file1"], temp_dex_files["conflict"]], catch_exceptions=False)

    # The error message should be about duplicate groups
    error_msg = str(exc_info.value).lower()
    assert any(phrase in error_msg for phrase in [
        "duplicate group",
        "duplicate name",
        "already exists"
    ]), f"Expected error about duplicate groups, got: {error_msg}"
0246 
0247 
def test_ignore_flag(temp_dex_files):
    """Test the ignore flag for conflicting group names.

    With --ignore, the "BarrelVertexHits" group of "event_0" is expected to
    keep the values from the first file (SAMPLE_DEX_1) rather than those of
    the conflicting file.
    """
    runner = CliRunner()
    result = runner.invoke(
        merge,
        ["--ignore", temp_dex_files["file1"], temp_dex_files["conflict"], "-o", temp_dex_files["output"]]
    )

    assert result.exit_code == 0

    # Load and verify the merged content
    with open(temp_dex_files["output"], 'r') as f:
        merged_data = json.load(f)

    # Count the groups actually inspected so the test cannot pass vacuously
    # when event_0 or its BarrelVertexHits group is missing from the output.
    checked_groups = 0
    for event in merged_data["events"]:
        if event["id"] != "event_0":
            continue
        for group in event["groups"]:
            if group["name"] != "BarrelVertexHits":
                continue
            # Verify it's the one from file1, not from conflict
            assert group["hits"][0]["pos"] == [1, 2, 3]  # Values from SAMPLE_DEX_1
            assert group["hits"][0]["dim"] == [0.1, 0.1, 0.1]  # Values from SAMPLE_DEX_1
            checked_groups += 1

    assert checked_groups > 0, "BarrelVertexHits group of event_0 not found in merged output"
0271 
0272 
def test_overwrite_flag(temp_dex_files):
    """Test the overwrite flag for conflicting group names.

    With --overwrite, the "BarrelVertexHits" group of "event_0" is expected
    to carry the values from the later, conflicting file
    (SAMPLE_DEX_CONFLICT) instead of those from the first file.
    """
    runner = CliRunner()
    result = runner.invoke(
        merge,
        ["--overwrite", temp_dex_files["file1"], temp_dex_files["conflict"], "-o", temp_dex_files["output"]]
    )

    assert result.exit_code == 0

    # Load and verify the merged content
    with open(temp_dex_files["output"], 'r') as f:
        merged_data = json.load(f)

    # Count the groups actually inspected so the test cannot pass vacuously
    # when event_0 or its BarrelVertexHits group is missing from the output.
    checked_groups = 0
    for event in merged_data["events"]:
        if event["id"] != "event_0":
            continue
        for group in event["groups"]:
            if group["name"] != "BarrelVertexHits":
                continue
            # Verify it's the one from conflict, not from file1
            assert group["hits"][0]["pos"] == [100, 200, 300]  # Values from SAMPLE_DEX_CONFLICT
            assert group["hits"][0]["dim"] == [1, 1, 1]  # Values from SAMPLE_DEX_CONFLICT
            checked_groups += 1

    assert checked_groups > 0, "BarrelVertexHits group of event_0 not found in merged output"
0296 
0297 
def test_invalid_file(temp_dex_files):
    """The merge command should refuse an input that is not a valid DEX file."""
    cli_args = [temp_dex_files["file1"], temp_dex_files["invalid"]]
    result = CliRunner().invoke(merge, cli_args, catch_exceptions=False)

    # A non-zero exit code plus a message naming the format problem is expected
    assert result.exit_code != 0
    output = result.output
    assert "not a valid Firebird DEX file" in output or "valid firebird dex" in output.lower()
0306 
0307 
def test_conflict_between_flags(temp_dex_files):
    """Passing --ignore together with --overwrite should be rejected."""
    cli_args = ["--ignore", "--overwrite", temp_dex_files["file1"], temp_dex_files["conflict"]]
    result = CliRunner().invoke(merge, cli_args, catch_exceptions=False)

    # The command is expected to exit with an error that names both flags
    assert result.exit_code != 0
    output = result.output
    assert (
        "--ignore and --overwrite flags cannot be used together" in output
        or "ignore and overwrite" in output.lower()
    )
0320 
0321 
def test_missing_files():
    """The merge command should fail cleanly when its input files do not exist."""
    cli_args = ["nonexistent1.json", "nonexistent2.json"]
    result = CliRunner().invoke(merge, cli_args, catch_exceptions=False)

    # Non-zero exit and an output mentioning the missing/unopenable file
    assert result.exit_code != 0
    lowered_output = result.output.lower()
    accepted_fragments = ("no such file", "not found", "error opening", "could not open")
    assert any(fragment in lowered_output for fragment in accepted_fragments)
0330 
0331 
def test_stdout_output(temp_dex_files):
    """Without -o the merged JSON is expected to appear in the command output."""
    cli_args = [temp_dex_files["file1"], temp_dex_files["file2"]]
    result = CliRunner().invoke(merge, cli_args)

    assert result.exit_code == 0

    # Group names from both inputs and the DEX type marker must be printed
    captured = result.output
    for marker in ("BarrelVertexHits", "EndcapVertexHits", "firebird-dex-json"):
        assert marker in captured
0342 
0343 
def test_merge_more_than_two_files(temp_dex_files):
    """Merge three DEX files in one invocation and verify header and events."""
    # Third input: a single event whose id occurs in neither of the other samples
    third_sample = {
        "type": "firebird-dex-json",
        "version": "0.03",
        "origin": {
            "file": "sample3.root",
            "entries_count": 1
        },
        "events": [
            {
                "id": "event_3",
                "groups": [
                    {
                        "name": "CalorHits",
                        "type": "BoxTrackerHit",
                        "origin": {"type": "edm4eic::TrackerHitData"},
                        "hits": []
                    }
                ]
            }
        ]
    }

    # Put it next to the fixture files so it is cleaned up with them
    fixture_dir = os.path.dirname(temp_dex_files["file1"])
    file3_path = os.path.join(fixture_dir, "sample3.firebird.json")
    with open(file3_path, 'w') as f:
        json.dump(third_sample, f)

    out_path = temp_dex_files["output"]
    cli_args = [temp_dex_files["file1"], temp_dex_files["file2"], file3_path, "-o", out_path]
    result = CliRunner().invoke(merge, cli_args)

    assert result.exit_code == 0

    with open(out_path, 'r') as f:
        merged_data = json.load(f)

    # Header
    assert merged_data["version"] == "0.04"  # Should use the latest version
    origin = merged_data["origin"]
    assert len(origin["merged_from"]) == 3
    assert origin["entries_count"] == 5  # 2 + 2 + 1 declared by the inputs

    # Events: event_0 is shared by file1 and file2, so 2 + 2 + 1 yields 4
    events = merged_data["events"]
    assert len(events) == 4

    event_ids = [event["id"] for event in events]
    for expected_id in ("event_0", "event_1", "event_2", "event_3"):
        assert expected_id in event_ids
0399 
0400 
def test_is_valid_dex_file():
    """Check is_valid_dex_file against one valid sample and several malformed inputs."""
    # A complete sample must be accepted
    assert is_valid_dex_file(SAMPLE_DEX_1)

    # Each of these is expected to be rejected
    rejected_inputs = [
        {},  # Empty dict
        {"events": "not a list"},  # events not a list
        {"version": "0.01"},  # Missing events
        {
            "version": "0.01",
            "events": [{"not_id": "event_0"}]  # Event without id
        },
        {
            "version": "0.01",
            "events": [{
                "id": "event_0",
                "groups": [{"not_name": "BarrelVertexHits"}]  # Group without name
            }]
        },
    ]
    for candidate in rejected_inputs:
        assert not is_valid_dex_file(candidate)
0427 
0428 
def test_create_merged_header():
    """Build a merged header from three samples and verify its contents."""
    sources = [
        ("file1.json", SAMPLE_DEX_1),
        ("file2.json", SAMPLE_DEX_2),
        ("file3.json", SAMPLE_DEX_CONFLICT),
    ]

    header = create_merged_header(sources)

    # Top-level structure
    assert header["type"] == "firebird-dex-json"
    assert header["version"] == "0.04"  # Should use the latest version
    assert "origin" in header
    origin = header["origin"]
    assert "merged_from" in origin
    assert "entries_count" in origin

    # One merged_from entry per source, in input order
    merged_from = origin["merged_from"]
    assert len(merged_from) == 3
    for index, expected_file in enumerate(["file1.json", "file2.json", "file3.json"]):
        assert merged_from[index]["file"] == expected_file

    # 2 from file1 + 2 from file2; SAMPLE_DEX_CONFLICT declares no origin
    assert origin["entries_count"] == 4
0455 
0456 
def test_merge_event_groups():
    """Test merge_event_groups in its default, ignore and overwrite modes.

    Two events with the same id are merged; both define a group named
    "group1" with different data, which is the conflict under test.
    """
    # Set up test events
    event1 = {
        "id": "test_event",
        "groups": [
            {"name": "group1", "type": "type1", "origin": {"type": "TypeA"}, "data": [1, 2, 3]},
            {"name": "group2", "type": "type2", "origin": {"type": "TypeB"}, "data": [4, 5, 6]}
        ]
    }

    event2 = {
        "id": "test_event",
        "groups": [
            {"name": "group3", "type": "type3", "origin": {"type": "TypeC"}, "data": [7, 8, 9]},
            {"name": "group1", "type": "type1", "origin": {"type": "TypeA"}, "data": [10, 11, 12]}  # Conflict with event1
        ]
    }

    # NOTE(review): the same event dicts are passed to all three calls below;
    # this presumes merge_event_groups does not mutate its inputs - confirm.
    events_with_id = [
        ("file1.json", event1),
        ("file2.json", event2)
    ]

    # Default mode: the duplicated "group1" must raise. The return value is
    # deliberately not bound, since the call is expected never to return.
    with pytest.raises(ValueError):
        merge_event_groups("test_event", events_with_id)

    # Test with ignore flag
    merged_event = merge_event_groups("test_event", events_with_id, ignore=True)
    group_names = [group["name"] for group in merged_event["groups"]]
    assert "group1" in group_names
    assert "group2" in group_names
    assert "group3" in group_names

    # Verify that group1 from file1 was kept
    for group in merged_event["groups"]:
        if group["name"] == "group1":
            assert group["data"] == [1, 2, 3]  # From event1, not event2

    # Test with overwrite flag
    merged_event = merge_event_groups("test_event", events_with_id, overwrite=True)
    group_names = [group["name"] for group in merged_event["groups"]]
    assert "group1" in group_names
    assert "group2" in group_names
    assert "group3" in group_names

    # Verify that group1 from file2 overwrote the one from file1
    for group in merged_event["groups"]:
        if group["name"] == "group1":
            assert group["data"] == [10, 11, 12]  # From event2, not event1