EIC code displayed by LXR

#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "typer",
#   "httpx",
#   "rich",
# ]
# ///
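
# Usage sketch (illustrative; the filename below is a placeholder, since the
# script's actual name is not shown on this page):
#
#   uv run download-datasets.py              # install any missing datasets
#   uv run download-datasets.py --dry-run    # preview what would be fetched
#   uv run download-datasets.py -j 8 --force # redownload everything, 8 at a time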

import asyncio
import concurrent.futures
import hashlib
import shutil
import tarfile
import tempfile
from pathlib import Path
from typing import Annotated

import httpx
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)

console = Console()
app = typer.Typer()


def find_geant4_config() -> Path:
    """Find geant4-config in PATH."""
    result = shutil.which("geant4-config")
    if not result:
        console.print("[red]Error: geant4-config not found in PATH[/red]")
        raise typer.Exit(1)
    return Path(result)


def parse_datasets(config_path: Path) -> list[dict]:
    """Parse dataset information from geant4-config script."""
    with open(config_path) as f:
        content = f.read()

    # Find the dataset_list line
    for line in content.splitlines():
        if line.strip().startswith("dataset_list="):
            break
    else:
        console.print("[red]Error: Could not find dataset_list in geant4-config[/red]")
        raise typer.Exit(1)

    # Extract the awk script dataset string
    # Format: NAME|ENVVAR|PATH|FILENAME|MD5;...
    # The string is within quotes before the ", array," part
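    # An illustrative (made-up) example of the string being parsed:
    #   dataset_list="G4NDL|G4NEUTRONHPDATA|/opt/geant4/data/G4NDL4.7|G4NDL.4.7.tar.gz|<md5>;...", array, ...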
    start = line.find('"') + 1
    # Find the end quote before ", array,"
    end = line.find('", array,')
    if end == -1:
        end = line.rfind('"')
    dataset_string = line[start:end]

    datasets = []
    for entry in dataset_string.split(";"):
        if not entry.strip():
            continue
        parts = entry.split("|")
        if len(parts) >= 5:
            datasets.append(
                {
                    "name": parts[0],
                    "envvar": parts[1],
                    "path": parts[2],
                    "filename": parts[3],
                    "md5": parts[4],
                }
            )

    return datasets


def get_dataset_url(config_path: Path) -> str:
    """Get the base URL for datasets from geant4-config."""
    with open(config_path) as f:
        content = f.read()

    # Find the dataset_url line
    for line in content.splitlines():
        if line.strip().startswith("dataset_url="):
            # Extract URL from line like: dataset_url="https://cern.ch/geant4-data/datasets"
            start = line.find('"') + 1
            end = line.rfind('"')
            return line[start:end]

    # Fallback to default URL
    return "https://cern.ch/geant4-data/datasets"


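# Note: the MD5 check below only confirms the download matches the checksum
# published via geant4-config; it is an integrity check, not a security measure.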
def verify_md5(filepath: Path, expected_md5: str) -> bool:
    """Verify MD5 checksum of a file."""
    md5 = hashlib.md5()
    with open(filepath, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            md5.update(chunk)
    return md5.hexdigest() == expected_md5


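# extract_and_install assumes each tarball unpacks to a single top-level
# directory named after the install path's basename (e.g., a hypothetical
# G4NDL.4.7.tar.gz unpacking to G4NDL4.7/), which is why src_dir is derived
# from dest_dir.name below.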
def extract_and_install(
    tarball_path: Path, temp_dir: Path, dest_dir: Path
) -> tuple[bool, str]:
    """Extract tarball and install to final location (runs in process pool)."""
    try:
        # Extract; the "data" filter (Python 3.11.4+) rejects unsafe members
        # such as absolute paths during extraction
        with tarfile.open(tarball_path, "r:gz") as tar:
            tar.extractall(temp_dir, filter="data")

        # Install to final location
        dataset_dir_name = dest_dir.name
        src_dir = temp_dir / dataset_dir_name

        dest_dir.parent.mkdir(parents=True, exist_ok=True)

        if dest_dir.exists():
            shutil.rmtree(dest_dir)

        shutil.move(str(src_dir), str(dest_dir))

        # Clean up tarball
        tarball_path.unlink()

        return True, f"Successfully installed {dataset_dir_name}"
    except Exception as e:
        return False, f"Failed to extract/install: {e}"


async def download_dataset(
    client: httpx.AsyncClient,
    dataset: dict,
    base_url: str,
    temp_dir: Path,
    progress: Progress,
    executor: concurrent.futures.ProcessPoolExecutor,
) -> tuple[bool, str]:
    """Download and verify a single dataset."""
    filename = dataset["filename"]
    url = f"{base_url}/{filename}"
    dest_path = temp_dir / filename

    task_id = progress.add_task(f"[cyan]{filename}", total=None)

    try:
        # Download
        async with client.stream("GET", url, follow_redirects=True) as response:
            response.raise_for_status()
            total = int(response.headers.get("content-length", 0))
            progress.update(task_id, total=total)

            with open(dest_path, "wb") as f:
                downloaded = 0
                async for chunk in response.aiter_bytes(chunk_size=8192):
                    f.write(chunk)
                    downloaded += len(chunk)
                    progress.update(task_id, completed=downloaded)

        # Verify MD5 in process pool (non-blocking)
        progress.update(task_id, description=f"[yellow]{filename} (verifying)")
        loop = asyncio.get_running_loop()
        md5_valid = await loop.run_in_executor(
            executor,
            verify_md5,
            dest_path,
            dataset["md5"],
        )

        if not md5_valid:
            progress.update(task_id, description=f"[red]{filename} (MD5 mismatch)")
            return False, f"MD5 mismatch for {filename}"

        # Extract and install in process pool (non-blocking)
        progress.update(task_id, description=f"[yellow]{filename} (extracting)")
        dest_dir = Path(dataset["path"])
        success, msg = await loop.run_in_executor(
            executor,
            extract_and_install,
            dest_path,
            temp_dir,
            dest_dir,
        )

        if success:
            progress.update(task_id, description=f"[green]{filename} (installed)")
            return True, f"Successfully installed {dataset['name']}"
        else:
            progress.update(task_id, description=f"[red]{filename} (failed)")
            return False, msg

    except Exception as e:
        progress.update(task_id, description=f"[red]{filename} (failed)")
        return False, f"Failed to download {filename}: {e}"


async def download_all_datasets(
    datasets: list[dict],
    base_url: str,
    temp_dir: Path,
    max_concurrent: int,
    dry_run: bool = False,
    force: bool = False,
) -> int:
    """Download all datasets with limited concurrency.

    Returns:
        Number of failures (0 if all successful)
    """
    # Filter out already installed datasets
    if force:
        datasets_to_install = datasets
    else:
        datasets_to_install = [ds for ds in datasets if not Path(ds["path"]).exists()]

    if not datasets_to_install:
        console.print("[green]All datasets already installed[/green]")
        return 0

    if dry_run:
        console.print(
            f"[yellow]DRY RUN: Would download {len(datasets_to_install)} datasets:[/yellow]"
        )
        for ds in datasets_to_install:
            console.print(f"  [cyan]•[/cyan] {ds['name']} ({ds['filename']})")
            console.print(f"    URL: {base_url}/{ds['filename']}")
            console.print(f"    Destination: {ds['path']}")
            console.print(f"    MD5: {ds['md5']}")
        return 0

    console.print(f"[cyan]Downloading {len(datasets_to_install)} datasets...[/cyan]")

    progress = Progress(
        TextColumn("[bold blue]{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        console=console,
    )

    # Use process pool for extraction
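    # (MD5 hashing and tarball extraction are CPU-bound, so worker processes
    # keep that work off the event loop and can use multiple cores)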
    with concurrent.futures.ProcessPoolExecutor() as executor:
        async with httpx.AsyncClient(timeout=1800.0) as client:
            with progress:
                # Use semaphore to limit concurrent downloads
                semaphore = asyncio.Semaphore(max_concurrent)

                async def bounded_download(dataset):
                    async with semaphore:
                        return await download_dataset(
                            client, dataset, base_url, temp_dir, progress, executor
                        )

                results = await asyncio.gather(
                    *[bounded_download(ds) for ds in datasets_to_install],
                    return_exceptions=True,
                )

    # Print summary
    console.print()
    successes = sum(1 for r in results if not isinstance(r, Exception) and r[0])
    failures = len(results) - successes

    if failures == 0:
        console.print(f"[green]✓ Successfully installed {successes} datasets[/green]")
    else:
        console.print(
            f"[yellow]⚠ Installed {successes} datasets, {failures} failed[/yellow]"
        )
        for result in results:
            if isinstance(result, Exception) or not result[0]:
                msg = str(result) if isinstance(result, Exception) else result[1]
                console.print(f"[red]  • {msg}[/red]")

    return failures


@app.command()
def main(
    max_concurrent: Annotated[
        int, typer.Option("--jobs", "-j", help="Maximum concurrent downloads")
    ] = 4,
    dry_run: Annotated[
        bool,
        typer.Option(
            "--dry-run",
            help="Show what would be downloaded without actually downloading",
        ),
    ] = False,
    force: Annotated[
        bool,
        typer.Option(
            "--force",
            help="Redownload and reinstall datasets even if they already exist",
        ),
    ] = False,
    config: Annotated[
        Path | None, typer.Option("--config", help="Path to geant4-config script")
    ] = None,
) -> None:
    """Download Geant4 datasets in parallel."""
    # Find geant4-config
    if config:
        config_path = config
        if not config_path.exists():
            console.print(f"[red]Error: {config_path} does not exist[/red]")
            raise typer.Exit(1)
    else:
        config_path = find_geant4_config()
    console.print(f"[cyan]Found geant4-config at: {config_path}[/cyan]")

    # Parse datasets
    datasets = parse_datasets(config_path)
    console.print(f"[cyan]Found {len(datasets)} datasets[/cyan]")

    # Get base URL
    base_url = get_dataset_url(config_path)
    console.print(f"[cyan]Base URL: {base_url}[/cyan]")
    console.print()

    # Create temp directory
    with tempfile.TemporaryDirectory(prefix="geant4-downloads-") as temp_dir:
        # Download datasets
        failures = asyncio.run(
            download_all_datasets(
                datasets, base_url, Path(temp_dir), max_concurrent, dry_run, force
            )
        )

    # Exit with error code if any downloads failed
    if failures > 0:
        raise typer.Exit(1)


if __name__ == "__main__":
    app()