Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-03-29 08:37:26

0001 #!/usr/bin/env python3
0002 import argparse
0003 import ast
0004 import json
0005 import sys
0006 from pathlib import Path
0007 
0008 
0009 def _strip_magics(code: str) -> str:
0010     lines = []
0011     for line in code.splitlines():
0012         stripped = line.lstrip()
0013         if stripped.startswith(("%%", "%", "!")):
0014             continue
0015         lines.append(line)
0016     return "\n".join(lines)
0017 
0018 
def load_code(path: Path) -> str:
    """Return the Python source contained in *path*.

    For a Jupyter notebook (suffix ``.ipynb``, any case), the source of every
    code cell is passed through ``_strip_magics``. The results are joined with
    newlines in cell order; markdown and raw cells are skipped.

    Any other file is read as Python source. The encoding comes from its
    PEP 263 coding cookie or UTF-8 BOM (default UTF-8), not the locale.

    Raises ``OSError`` if the file cannot be read and ``ValueError`` (incl.
    ``json.JSONDecodeError`` / ``UnicodeDecodeError``) if it cannot be decoded.
    """
    import tokenize

    if path.suffix.lower() == ".ipynb":
        # Notebook files are JSON, which is UTF-8 by specification; "utf-8-sig"
        # also tolerates a leading BOM, which json.loads would otherwise reject.
        nb = json.loads(path.read_text(encoding="utf-8-sig"))
        cells = []
        for cell in nb.get("cells", []):
            if cell.get("cell_type") != "code":
                continue
            # nbformat stores "source" as a list of lines or a single string;
            # "".join handles both.
            src = "".join(cell.get("source", []))
            cells.append(_strip_magics(src))
        return "\n".join(cells)
    with tokenize.open(path) as fh:
        return fh.read()
0030 
0031 
def stdlib_modules() -> set:
    """Return the top-level module names that belong to the standard library.

    Python 3.10+ exposes the authoritative list as ``sys.stdlib_module_names``.
    Older interpreters lack it, so there the modules found in the
    interpreter's own library directories are enumerated instead. The
    hand-written list below alone misses most of the stdlib (``ast``,
    ``functools``, ``datetime`` ...), which would then be reported as
    packages to install.

    Compiled-in modules (``sys.builtin_module_names``) and the hand-written
    list are always added as a safety net.
    """
    names = set(sys.builtin_module_names)
    known = getattr(sys, "stdlib_module_names", None)
    if known is not None:
        names.update(known)
    else:
        import pkgutil
        import sysconfig

        paths = sysconfig.get_paths()
        search = []
        for key in ("stdlib", "platstdlib"):
            base = paths.get(key)
            if base:
                # Compiled stdlib extensions live in lib-dynload on POSIX.
                search += [base, str(Path(base) / "lib-dynload")]
        # Windows keeps compiled stdlib extension modules in <prefix>\DLLs.
        search.append(str(Path(sys.base_exec_prefix) / "DLLs"))
        # Non-existent directories are skipped by pkgutil; site-packages is
        # not a package (no __init__.py), so third-party code is not listed.
        names.update(info.name for info in pkgutil.iter_modules(search))
    names.update(
        {
            "__future__",
            "typing",
            "pathlib",
            "sys",
            "os",
            "math",
            "json",
            "re",
            "time",
            "types",
            "fnmatch",
            "itertools",
            "collections",
            "subprocess",
            "logging",
            "argparse",
        }
    )
    return names
0059 
0060 
def extract_modules(code: str) -> list:
    """Return the sorted, de-duplicated top-level module names imported by *code*.

    Both ``import a.b`` and ``from a.b import c`` contribute ``"a"``, wherever
    they occur in the tree (including inside functions). Relative imports
    (``from . import x``, ``from ..pkg import y``) are ignored. Raises
    ``SyntaxError`` if *code* does not parse.
    """
    found = set()

    def _root(dotted):
        # "a.b.c" -> "a"
        return dotted.partition(".")[0]

    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            found.update(_root(alias.name) for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and not node.level and node.module:
            found.add(_root(node.module))
    return sorted(found)
0076 
0077 
# Import names whose PyPI distribution is published under a different name;
# ``pip install <import name>`` fails or installs an unrelated project for these.
_PIP_NAMES = {
    "Crypto": "pycryptodome",
    "OpenSSL": "pyOpenSSL",
    "PIL": "Pillow",
    "attr": "attrs",
    "bs4": "beautifulsoup4",
    "cv2": "opencv-python",
    "dateutil": "python-dateutil",
    "docx": "python-docx",
    "dotenv": "python-dotenv",
    "fitz": "PyMuPDF",
    "git": "GitPython",
    "jwt": "PyJWT",
    "mpl_toolkits": "matplotlib",
    "pptx": "python-pptx",
    "serial": "pyserial",
    "skimage": "scikit-image",
    "sklearn": "scikit-learn",
    "usb": "pyusb",
    "wx": "wxPython",
    "yaml": "PyYAML",
    "zmq": "pyzmq",
}


def build_script(packages: list, pip_names: "dict | None" = None) -> str:
    """Return a bash script that pip-installs whichever *packages* are missing.

    *packages* are top-level import names. At run time the script asks
    ``python3`` which of them ``importlib.util.find_spec`` cannot locate. It
    then installs the matching pip distributions with
    ``python3 -m pip install``, passing ``--user`` only when not running
    inside a virtualenv (where ``--user`` is rejected).

    *pip_names* optionally maps import names to distribution names. It
    extends and overrides the built-in ``_PIP_NAMES`` table; unmapped names
    are installed under their import name.
    """
    names = dict(_PIP_NAMES)
    if pip_names:
        names.update(pip_names)
    # Built outside the f-string: a backslash inside an f-string expression
    # is a SyntaxError before Python 3.12.
    entries = ",\n".join(
        f"    ({module!r}, {names.get(module, module)!r})" for module in packages
    )
    return f"""#!/usr/bin/env bash
set -euo pipefail

missing=()
while IFS= read -r pkg; do
  [ -z "${{pkg}}" ] && continue
  missing+=("${{pkg}}")
done < <(python3 - <<'PY'
import importlib.util

# (import name, pip distribution name) for every third-party import.
packages = [
{entries}
]

seen = set()
for module, dist in packages:
    if importlib.util.find_spec(module) is None and dist not in seen:
        seen.add(dist)
        print(dist)
PY
)

if [ ${{#missing[@]}} -eq 0 ]; then
  echo "All packages already installed."
  exit 0
fi

pip_args=(install)
# pip rejects --user inside a virtualenv, so only add it for a base interpreter.
if python3 -c 'import sys; sys.exit(0 if sys.prefix == sys.base_prefix else 1)'; then
  pip_args+=(--user)
fi

echo "Installing missing packages: ${{missing[*]}}"
python3 -m pip "${{pip_args[@]}}" "${{missing[@]}}"
"""
0112 
0113 
def _is_local_module(name: str, base_dir: Path) -> bool:
    """Return True if *name* is a module file or regular package in *base_dir*."""
    return (base_dir / f"{name}.py").is_file() or (
        base_dir / name / "__init__.py"
    ).is_file()


def main() -> int:
    """Command-line entry point: write a dependency-install script for one file.

    Returns the process exit status. It is 0 on success, and 1 if the input
    cannot be read or parsed; the reason is printed to stderr.
    """
    import shlex

    parser = argparse.ArgumentParser(
        description="Generate a pip install-check script from a .py or .ipynb file."
    )
    parser.add_argument("input", type=Path, help="Path to .py or .ipynb file")
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        help='Output script path (default: "install_deps_<input_stem>.sh")',
    )
    args = parser.parse_args()

    try:
        code = load_code(args.input)
        modules = extract_modules(code)
    except (OSError, SyntaxError, ValueError) as exc:
        # OSError: missing/unreadable file. ValueError: bad notebook JSON or
        # undecodable text. SyntaxError: code that ast cannot parse.
        print(f"error: cannot read imports from {args.input}: {exc}", file=sys.stderr)
        return 1

    stdlib = stdlib_modules()
    base_dir = args.input.parent
    # Modules living next to the input are project code, not PyPI packages;
    # pip-installing a same-named distribution would install the wrong thing.
    packages = [
        m for m in modules if m not in stdlib and not _is_local_module(m, base_dir)
    ]
    script = build_script(packages)

    output = args.output
    if output is None:
        output = Path(f"install_deps_{args.input.stem}.sh")
    # Force LF line endings: text mode would write CRLF on Windows, which
    # breaks bash.
    with output.open("w", encoding="utf-8", newline="\n") as fh:
        fh.write(script)
    print(f"done. please run: bash {shlex.quote(str(output))}")
    return 0
0139 
0140 
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())