Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:04:51

0001 import json
0002 import os
0003 import tempfile
0004 import pytest
0005 from click.testing import CliRunner
0006 from pyrobird.cli.merge import merge, is_valid_dex_file, merge_dex_files, merge_entries, merge_entry_components, create_merged_header
0007 
# Sample Firebird DEX JSON data for testing.
#
# SAMPLE_DEX_1 and SAMPLE_DEX_2 are the two "compatible" inputs: both contain
# an entry with id "event_0" (holding differently named components), and each
# also has one entry id the other lacks ("event_1" here, "event_2" in
# SAMPLE_DEX_2). Both carry an "origin" block whose "entries_count" the merge
# tests expect to be summed in the merged header.
SAMPLE_DEX_1 = {
    "type": "firebird-dex-json",
    "version": "0.02",
    "origin": {
        "file": "sample1.root",
        "entries_count": 2
    },
    "entries": [
        {
            "id": "event_0",
            "components": [
                {
                    "name": "BarrelVertexHits",
                    "type": "BoxTrackerHit",
                    "originType": "edm4eic::TrackerHitData",
                    "hits": [
                        {
                            "pos": [1, 2, 3],
                            "dim": [0.1, 0.1, 0.1],
                            "t": [0, 0],
                            "ed": [0.001, 0]
                        }
                    ]
                }
            ]
        },
        {
            "id": "event_1",
            "components": [
                {
                    "name": "BarrelTracks",
                    "type": "TrackerLinePointTrajectory",
                    "originType": "edm4eic::TrackSegmentData",
                    "lines": []
                }
            ]
        }
    ]
}

# Second compatible input. Its version ("0.01") is lower than SAMPLE_DEX_1's
# ("0.02"), and its "event_0" component name differs from SAMPLE_DEX_1's, so
# merging the two by id should not produce a name collision.
SAMPLE_DEX_2 = {
    "type": "firebird-dex-json",
    "version": "0.01",
    "origin": {
        "file": "sample2.root",
        "entries_count": 2
    },
    "entries": [
        {
            "id": "event_0",
            "components": [
                {
                    "name": "EndcapVertexHits",
                    "type": "BoxTrackerHit",
                    "originType": "edm4eic::TrackerHitData",
                    "hits": [
                        {
                            "pos": [10, 20, 30],
                            "dim": [0.2, 0.2, 0.2],
                            "t": [1, 0],
                            "ed": [0.002, 0]
                        }
                    ]
                }
            ]
        },
        {
            "id": "event_2",
            "components": [
                {
                    "name": "EndcapTracks",
                    "type": "TrackerLinePointTrajectory",
                    "originType": "edm4eic::TrackSegmentData",
                    "lines": []
                }
            ]
        }
    ]
}

# Sample with conflicting component names.
# Its "event_0" reuses the component name "BarrelVertexHits" from SAMPLE_DEX_1
# but with different hit values, so the --ignore / --overwrite tests can tell
# which copy survived. It deliberately has no "type" or "origin" keys;
# test_create_merged_header expects it to add nothing to the merged
# "entries_count" (presumably because "origin" is missing).
SAMPLE_DEX_CONFLICT = {
    "version": "0.03",
    "entries": [
        {
            "id": "event_0",
            "components": [
                {
                    "name": "BarrelVertexHits",  # Same name as in SAMPLE_DEX_1
                    "type": "BoxTrackerHit",
                    "originType": "edm4eic::TrackerHitData",
                    "hits": [
                        {
                            "pos": [100, 200, 300],
                            "dim": [1, 1, 1],
                            "t": [10, 1],
                            "ed": [0.01, 0.001]
                        }
                    ]
                }
            ]
        }
    ]
}

# Invalid DEX format (missing required fields): its only entry has an "id"
# but no "components" list.
INVALID_DEX = {
    "version": "0.01",
    "entries": [
        {
            "id": "event_0",
            # Missing "components" field
        }
    ]
}
0124 
0125 
@pytest.fixture
def temp_dex_files():
    """Write the sample DEX payloads into a scratch directory and yield their paths.

    The yielded dict maps "file1", "file2", "conflict" and "invalid" to JSON files
    holding SAMPLE_DEX_1, SAMPLE_DEX_2, SAMPLE_DEX_CONFLICT and INVALID_DEX, and
    "output" to a path in the same directory that is NOT created. The whole
    directory is deleted once the test using the fixture finishes.
    """
    # (result key, file name, payload) for every input file to materialize.
    fixture_table = (
        ("file1", "sample1.firebird.json", SAMPLE_DEX_1),
        ("file2", "sample2.firebird.json", SAMPLE_DEX_2),
        ("conflict", "conflict.firebird.json", SAMPLE_DEX_CONFLICT),
        ("invalid", "invalid.firebird.json", INVALID_DEX),
    )
    with tempfile.TemporaryDirectory() as scratch_dir:
        paths = {}
        for key, filename, payload in fixture_table:
            paths[key] = os.path.join(scratch_dir, filename)
            with open(paths[key], 'w') as stream:
                json.dump(payload, stream)
        # Destination for tests that pass "-o"; left for the CLI to create.
        paths["output"] = os.path.join(scratch_dir, "output.firebird.json")
        yield paths
0155 
0156 
def test_basic_merge(temp_dex_files):
    """Test basic merging of two compatible DEX files.

    Merges sample1 and sample2 into an output file, then checks the header
    (type, version, origin metadata), the union of entry ids, and that the
    shared "event_0" entry holds the components of both inputs.
    """
    output_path = temp_dex_files["output"]
    cli_args = [temp_dex_files["file1"], temp_dex_files["file2"], "-o", output_path]
    result = CliRunner().invoke(merge, cli_args)
    assert result.exit_code == 0

    # The command must have written the output file.
    assert os.path.exists(output_path)
    with open(output_path, 'r') as handle:
        merged = json.load(handle)

    # Header structure.
    assert "type" in merged
    assert merged["type"] == "firebird-dex-json"
    assert "version" in merged
    assert merged["version"] in ["0.02", "0.03"]  # Should use the highest version

    # Origin metadata: one provenance record per input, plus the summed counts.
    assert "origin" in merged
    origin = merged["origin"]
    assert "merged_from" in origin
    assert len(origin["merged_from"]) == 2
    assert "entries_count" in origin
    assert origin["entries_count"] == 4  # Total from both files

    # "event_0" exists in both inputs, so four input entries become three.
    entries = merged["entries"]
    assert len(entries) == 3
    entry_ids = {entry["id"] for entry in entries}
    assert {"event_0", "event_1", "event_2"} <= entry_ids

    # The shared entry must carry exactly the components from both files.
    for shared in (entry for entry in entries if entry["id"] == "event_0"):
        names = [component["name"] for component in shared["components"]]
        assert "BarrelVertexHits" in names
        assert "EndcapVertexHits" in names
        assert len(shared["components"]) == 2
0201 
0202 
def test_reset_id_flag(temp_dex_files):
    """Test merging with reset-id flag.

    With --reset-id, entries are merged by position instead of by id, so the
    two two-entry inputs yield two merged entries. Each merged entry holds the
    components of the same-position entries from both files.
    """
    output_path = temp_dex_files["output"]
    result = CliRunner().invoke(
        merge,
        ["--reset-id", temp_dex_files["file1"], temp_dex_files["file2"], "-o", output_path],
    )
    assert result.exit_code == 0

    with open(output_path, 'r') as handle:
        entries = json.load(handle)["entries"]

    # File1.entry0 + File2.entry0, File1.entry1 + File2.entry1
    assert len(entries) == 2

    # Component names expected in each merged entry, by position.
    expected_by_position = (
        ("BarrelVertexHits", "EndcapVertexHits"),
        ("BarrelTracks", "EndcapTracks"),
    )
    for entry, expected_names in zip(entries, expected_by_position):
        names = [component["name"] for component in entry["components"]]
        for expected in expected_names:
            assert expected in names
0228 
0229 
def test_conflict_detection(temp_dex_files):
    """Test detection of conflicting component names.

    Without --ignore/--overwrite, merging two files that both define
    "BarrelVertexHits" in "event_0" must raise ValueError. catch_exceptions=False
    lets the exception propagate out of CliRunner so pytest.raises can see it.
    """
    expected_message = "Duplicate component name.*Use --ignore or --overwrite flags to handle duplicates"
    cli_args = [temp_dex_files["file1"], temp_dex_files["conflict"]]
    with pytest.raises(ValueError, match=expected_message):
        CliRunner().invoke(merge, cli_args, catch_exceptions=False)
0236 
0237 
0238 
def test_ignore_flag(temp_dex_files):
    """Test the ignore flag for conflicting component names.

    With --ignore, "event_0" must contain exactly one "BarrelVertexHits"
    component, and it must be the one from SAMPLE_DEX_1 (file1), not the one
    from the conflict file.
    """
    runner = CliRunner()
    result = runner.invoke(
        merge,
        ["--ignore", temp_dex_files["file1"], temp_dex_files["conflict"], "-o", temp_dex_files["output"]]
    )

    assert result.exit_code == 0

    # Load and verify the merged content
    with open(temp_dex_files["output"], 'r') as f:
        merged_data = json.load(f)

    # Collect every BarrelVertexHits component in event_0 and assert on the
    # collection. Asserting only inside a nested loop would let the test pass
    # vacuously if the entry or the component were missing.
    barrel_hits = []
    for entry in merged_data["entries"]:
        if entry["id"] != "event_0":
            continue
        barrel_hits.extend(
            comp for comp in entry["components"] if comp["name"] == "BarrelVertexHits"
        )
    assert len(barrel_hits) == 1, f"expected one BarrelVertexHits in event_0, got {len(barrel_hits)}"

    # Verify it's the one from file1, not from conflict
    kept = barrel_hits[0]
    assert kept["hits"][0]["pos"] == [1, 2, 3]  # Values from SAMPLE_DEX_1
    assert kept["hits"][0]["dim"] == [0.1, 0.1, 0.1]  # Values from SAMPLE_DEX_1
0262 
0263 
def test_overwrite_flag(temp_dex_files):
    """Test the overwrite flag for conflicting component names.

    With --overwrite, "event_0" must contain exactly one "BarrelVertexHits"
    component, and it must be the one from SAMPLE_DEX_CONFLICT (the later
    file), not the one from file1.
    """
    runner = CliRunner()
    result = runner.invoke(
        merge,
        ["--overwrite", temp_dex_files["file1"], temp_dex_files["conflict"], "-o", temp_dex_files["output"]]
    )

    assert result.exit_code == 0

    # Load and verify the merged content
    with open(temp_dex_files["output"], 'r') as f:
        merged_data = json.load(f)

    # Collect every BarrelVertexHits component in event_0 and assert on the
    # collection. Asserting only inside a nested loop would let the test pass
    # vacuously if the entry or the component were missing.
    barrel_hits = []
    for entry in merged_data["entries"]:
        if entry["id"] != "event_0":
            continue
        barrel_hits.extend(
            comp for comp in entry["components"] if comp["name"] == "BarrelVertexHits"
        )
    assert len(barrel_hits) == 1, f"expected one BarrelVertexHits in event_0, got {len(barrel_hits)}"

    # Verify it's the one from conflict, not from file1
    kept = barrel_hits[0]
    assert kept["hits"][0]["pos"] == [100, 200, 300]  # Values from SAMPLE_DEX_CONFLICT
    assert kept["hits"][0]["dim"] == [1, 1, 1]  # Values from SAMPLE_DEX_CONFLICT
0287 
0288 
def test_invalid_file(temp_dex_files):
    """Test handling of invalid DEX files.

    The "invalid" fixture has an entry without "components"; merging it must
    exit with a non-zero code and report that the input is not a valid
    Firebird DEX file.
    """
    cli_args = [temp_dex_files["file1"], temp_dex_files["invalid"]]
    result = CliRunner().invoke(merge, cli_args, catch_exceptions=False)

    # Should fail due to invalid file format.
    assert result.exit_code != 0
    # Case-insensitive check; it also covers the exact phrase
    # "not a valid Firebird DEX file".
    assert "valid firebird dex" in result.output.lower()
0297 
0298 
def test_conflict_between_flags(temp_dex_files):
    """Test that ignore and overwrite flags cannot be used together.

    Passing both --ignore and --overwrite must make the command exit non-zero
    with a message saying the two flags are mutually exclusive.
    """
    cli_args = ["--ignore", "--overwrite", temp_dex_files["file1"], temp_dex_files["conflict"]]
    result = CliRunner().invoke(merge, cli_args, catch_exceptions=False)

    # Should fail due to conflicting flags.
    assert result.exit_code != 0
    output = result.output
    # Accept either the exact CLI wording or a looser case-insensitive phrase.
    accepted = (
        "--ignore and --overwrite flags cannot be used together" in output,
        "ignore and overwrite" in output.lower(),
    )
    assert any(accepted)
0311 
0312 
def test_missing_files():
    """Test handling of missing input files.

    The CLI runs inside CliRunner's isolated filesystem, a fresh empty temporary
    working directory. That guarantees the relative input paths do not exist,
    whatever directory pytest was launched from.
    """
    runner = CliRunner()
    with runner.isolated_filesystem():
        result = runner.invoke(merge, ["nonexistent1.json", "nonexistent2.json"], catch_exceptions=False)

    # Should fail due to missing files
    assert result.exit_code != 0
    expected_fragments = ["no such file", "not found", "error opening", "could not open"]
    assert any(text in result.output.lower() for text in expected_fragments)
0321 
0322 
def test_stdout_output(temp_dex_files):
    """Test output to stdout when no output file is specified.

    Without "-o", the merged JSON is printed, so component names from both
    inputs and the DEX type marker must appear in the captured output.
    """
    result = CliRunner().invoke(merge, [temp_dex_files["file1"], temp_dex_files["file2"]])

    assert result.exit_code == 0
    printed = result.output
    for marker in ("BarrelVertexHits", "EndcapVertexHits", "firebird-dex-json"):
        assert marker in printed
0333 
0334 
def test_merge_more_than_two_files(temp_dex_files):
    """Test merging more than two files.

    A third file with a brand-new entry id ("event_3") and the highest version
    ("0.03") is merged alongside the two standard samples.
    """
    # Third input: a single entry whose id appears in neither fixture file.
    third_dex = {
        "type": "firebird-dex-json",
        "version": "0.03",
        "origin": {"file": "sample3.root", "entries_count": 1},
        "entries": [
            {
                "id": "event_3",
                "components": [
                    {
                        "name": "CalorHits",
                        "type": "BoxTrackerHit",
                        "originType": "edm4eic::TrackerHitData",
                        "hits": [],
                    }
                ],
            }
        ],
    }

    # Write it next to the fixture files so it is cleaned up with them.
    workdir = os.path.dirname(temp_dex_files["file1"])
    third_path = os.path.join(workdir, "sample3.firebird.json")
    with open(third_path, 'w') as handle:
        json.dump(third_dex, handle)

    output_path = temp_dex_files["output"]
    inputs = [temp_dex_files["file1"], temp_dex_files["file2"], third_path]
    result = CliRunner().invoke(merge, [*inputs, "-o", output_path])
    assert result.exit_code == 0

    with open(output_path, 'r') as handle:
        merged = json.load(handle)

    # Header: highest version wins, one provenance record per input, summed counts.
    assert merged["version"] == "0.03"
    assert len(merged["origin"]["merged_from"]) == 3
    assert merged["origin"]["entries_count"] == 5  # 2 + 2 + 1

    # "event_0" is shared by file1 and file2, so five input entries become four.
    assert len(merged["entries"]) == 4
    entry_ids = {entry["id"] for entry in merged["entries"]}
    for expected_id in ("event_0", "event_1", "event_2", "event_3"):
        assert expected_id in entry_ids
0390 
0391 
def test_is_valid_dex_file():
    """Test the is_valid_dex_file function.

    Covers the module-level samples (two valid ones and the deliberately
    invalid INVALID_DEX) plus hand-built documents that each break one
    structural rule: missing/non-list "entries", an entry without "id",
    and a component without "name".
    """
    # Valid DEX files
    assert is_valid_dex_file(SAMPLE_DEX_1)
    assert is_valid_dex_file(SAMPLE_DEX_2)

    # Invalid cases
    assert not is_valid_dex_file({})  # Empty dict
    assert not is_valid_dex_file({"entries": "not a list"})  # entries not a list
    assert not is_valid_dex_file({"version": "0.01"})  # Missing entries

    # The shared invalid fixture (entry without "components") is used by the
    # CLI test test_invalid_file; check the validator rejects it directly too.
    assert not is_valid_dex_file(INVALID_DEX)

    # Invalid entry
    invalid_entry = {
        "version": "0.01",
        "entries": [{"not_id": "event_0"}]  # Missing id
    }
    assert not is_valid_dex_file(invalid_entry)

    # Invalid component
    invalid_component = {
        "version": "0.01",
        "entries": [{
            "id": "event_0",
            "components": [{"not_name": "BarrelVertexHits"}]  # Missing name
        }]
    }
    assert not is_valid_dex_file(invalid_component)
0418 
0419 
def test_create_merged_header():
    """Test the create_merged_header function.

    The three inputs carry versions "0.02", "0.01" and "0.03". Only the first
    two have an "origin" block (entries_count 2 each); SAMPLE_DEX_CONFLICT has
    none, and the expected total of 4 shows it contributes nothing to the
    merged entries_count even though it holds one entry.
    """
    dex_files = [
        ("file1.json", SAMPLE_DEX_1),
        ("file2.json", SAMPLE_DEX_2),
        ("file3.json", SAMPLE_DEX_CONFLICT)
    ]

    header = create_merged_header(dex_files)

    # Check basic structure
    assert header["type"] == "firebird-dex-json"
    assert header["version"] == "0.03"  # Should use the highest version
    assert "origin" in header
    assert "merged_from" in header["origin"]
    assert "entries_count" in header["origin"]

    # merged_from keeps one record per input, in input order
    merged_from = header["origin"]["merged_from"]
    assert len(merged_from) == 3
    assert [record["file"] for record in merged_from] == ["file1.json", "file2.json", "file3.json"]

    # 2 from file1 + 2 from file2; file3 (SAMPLE_DEX_CONFLICT) adds nothing,
    # presumably because it has no "origin" metadata to read a count from
    assert header["origin"]["entries_count"] == 4
0446 
0447 
def test_merge_entry_components():
    """Test the merge_entry_components function.

    Two entries share the id "test_entry" and both define "comp1" with
    different data. A plain merge must raise ValueError; ignore=True must keep
    file1's "comp1" and overwrite=True must keep file2's, with each component
    name appearing exactly once. Every call receives a fresh deep copy of the
    inputs, so an in-place mutation by one call cannot leak into the next.
    """
    import copy

    # Set up test entries
    entry1 = {
        "id": "test_entry",
        "components": [
            {"name": "comp1", "type": "type1", "data": [1, 2, 3]},
            {"name": "comp2", "type": "type2", "data": [4, 5, 6]}
        ]
    }

    entry2 = {
        "id": "test_entry",
        "components": [
            {"name": "comp3", "type": "type3", "data": [7, 8, 9]},
            {"name": "comp1", "type": "type1", "data": [10, 11, 12]}  # Conflict with entry1
        ]
    }

    entries_with_id = [
        ("file1.json", entry1),
        ("file2.json", entry2)
    ]

    # Test normal merge (should raise ValueError due to conflict)
    with pytest.raises(ValueError):
        merge_entry_components("test_entry", copy.deepcopy(entries_with_id))

    # Test with ignore flag: every name exactly once, comp1 taken from file1
    merged_entry = merge_entry_components("test_entry", copy.deepcopy(entries_with_id), ignore=True)
    component_names = [comp["name"] for comp in merged_entry["components"]]
    assert sorted(component_names) == ["comp1", "comp2", "comp3"]
    comp1 = next(comp for comp in merged_entry["components"] if comp["name"] == "comp1")
    assert comp1["data"] == [1, 2, 3]  # From entry1, not entry2

    # Test with overwrite flag: every name exactly once, comp1 taken from file2
    merged_entry = merge_entry_components("test_entry", copy.deepcopy(entries_with_id), overwrite=True)
    component_names = [comp["name"] for comp in merged_entry["components"]]
    assert sorted(component_names) == ["comp1", "comp2", "comp3"]
    comp1 = next(comp for comp in merged_entry["components"] if comp["name"] == "comp1")
    assert comp1["data"] == [10, 11, 12]  # From entry2, not entry1
0487 
0488     # Test with overwrite flag
0489     merged_entry = merge_entry_components("test_entry", entries_with_id, overwrite=True)
0490     component_names = [comp["name"] for comp in merged_entry["components"]]
0491     assert "comp1" in component_names
0492     assert "comp2" in component_names
0493     assert "comp3" in component_names
0494 
0495     # Verify that comp1 from file2 overwrote the one from file1
0496     for comp in merged_entry["components"]:
0497         if comp["name"] == "comp1":
0498             assert comp["data"] == [10, 11, 12]  # From entry2, not entry1