Back to home page

EIC code displayed by LXR

 
 

    


Warning, file /panda-server/pandaserver/pandamcp/client_utils/panda_mcp_proxy.py was not indexed or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).

0001 #!/usr/bin/env python3
0002 """
0003 panda_mcp_proxy.py
0004 
0005 Cross-platform MCP proxy with automatic token refresh.
0006 Drop-in replacement for panda_mcp_wrapper.sh that works on Windows, macOS, and Linux.
0007 
0008 The core problem it solves: panda_mcp_wrapper.sh passes the id_token as a static
0009 header to mcp-remote at startup. After 15 minutes the token expires and every MCP
call fails until Claude Desktop is restarted. This proxy checks the token before every
outbound request and refreshes it transparently when it is about to expire, so sessions
can run indefinitely (for as long as the refresh token remains valid).
0012 
0013 Architecture:
    LLM client (e.g. Claude Desktop or LM Studio) <--stdio (JSON-RPC)--> this proxy <--HTTPS/SSE--> remote MCP server
0015 
0016 Requirements:
0017     pip install httpx
0018 
0019 Configuration in LLM client (claude_desktop_config.json or mcp.json):
0020     {
0021       "mcpServers": {
0022         "panda-mcp": {
0023           "command": "python3",
0024           "args": ["/path/to/panda_mcp_proxy.py"]
0025         }
0026       }
0027     }
0028 
0029 Environment variables (all optional, same defaults as panda_mcp_wrapper.sh):
0030     PANDA_SERVER     PanDA server URL          (default: https://pandaserver.cern.ch:25443)
0031     VO               Virtual organisation      (default: atlas)
0032     TOKEN_FILE       Path to token cache file  (default: ~/.panda_id_token)
0033     MCP_URL          Remote MCP server URL     (default: https://aipanda120.cern.ch:8443/mcp/)
0034     SSL_CERT_FILE    Path to CA bundle file    (optional, for custom CAs)
0035     REQUESTS_CA_BUNDLE  Alternative to SSL_CERT_FILE
0036 """
0037 
0038 import asyncio
0039 import base64
0040 import json
0041 import logging
0042 import os
0043 import pathlib
0044 import ssl
0045 import sys
0046 import time
0047 import urllib.error
0048 import urllib.parse
0049 import urllib.request
0050 
0051 try:
0052     import httpx
0053 except ImportError:
0054     print("ERROR: httpx is required. Install with: pip install httpx", file=sys.stderr)
0055     sys.exit(1)
0056 
0057 
0058 # ── Configuration ─────────────────────────────────────────────────────────────
0059 
# Empty environment values are treated as unset (``or``) so that a blank entry
# in a client config does not produce an unusable URL or path.
PANDA_SERVER = (os.environ.get("PANDA_SERVER") or "https://pandaserver.cern.ch:25443").rstrip("/")
VO = os.environ.get("VO") or "atlas"
# expanduser(): env values from a JSON client config are not shell-expanded, so
# a value such as "~/.panda_id_token" would otherwise be taken literally.
TOKEN_FILE = pathlib.Path(os.environ.get("TOKEN_FILE") or pathlib.Path.home() / ".panda_id_token").expanduser()
MCP_URL = os.environ.get("MCP_URL") or "https://aipanda120.cern.ch:8443/mcp/"

# Refresh the token this many seconds before it actually expires
TOKEN_REFRESH_MARGIN = 300

# How long to wait before retrying a failed SSE connection (seconds)
SSE_RECONNECT_DELAY = 5
0070 
0071 
0072 # ── Logging ───────────────────────────────────────────────────────────────────
0073 
# stdout carries the JSON-RPC stream to the LLM client, so every diagnostic
# goes to stderr. The threshold is WARNING: status messages must be emitted
# with log.warning() or higher to be shown.
logging.basicConfig(
    stream=sys.stderr,
    format="%(asctime)s [panda_mcp_proxy] %(levelname)s %(message)s",
    level=logging.WARNING,
)
log = logging.getLogger(__name__)
0080 
0081 
0082 # ── SSL context ───────────────────────────────────────────────────────────────
0083 
0084 
0085 def _build_ssl_context() -> ssl.SSLContext:
0086     ctx = ssl.create_default_context()
0087     ca_file = os.environ.get("SSL_CERT_FILE") or os.environ.get("REQUESTS_CA_BUNDLE")
0088     if ca_file:
0089         ctx.load_verify_locations(cafile=ca_file)
0090     return ctx
0091 
0092 
0093 # ── Token management ──────────────────────────────────────────────────────────
0094 
0095 
class TokenManager:
    """Cache the PanDA id_token and refresh it silently with the refresh_token.

    The internal asyncio.Lock serialises concurrent coroutines on one event
    loop, so at most one refresh runs at a time; it does not make the class
    safe to use from multiple threads.
    """

    def __init__(self):
        self._lock = asyncio.Lock()
        self._id_token: str = ""  # best token known so far (may be near expiry)
        self._exp: float = 0.0  # its ``exp`` claim, epoch seconds (0.0 = unknown)

    # ── private helpers ───────────────────────────────────────────────────────

    @staticmethod
    def _decode_exp(id_token: str) -> float:
        """Return the ``exp`` claim of JWT *id_token* in epoch seconds, or 0.0.

        The signature is not verified; the value is only used to decide when
        to refresh. Any malformed token yields 0.0.
        """
        try:
            payload = id_token.split(".")[1]
            # JWTs strip base64 padding; restore it before decoding
            payload += "=" * (-len(payload) % 4)
            claims = json.loads(base64.urlsafe_b64decode(payload))
            return float(claims.get("exp", 0))
        except Exception:
            return 0.0

    @staticmethod
    def _load_file() -> dict:
        """Return the JSON object stored in TOKEN_FILE, or {} if unreadable or not an object."""
        try:
            data = json.loads(TOKEN_FILE.read_text())
        except Exception:
            return {}
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _save_file(data: dict) -> None:
        """Best-effort overwrite of TOKEN_FILE with *data* as JSON; failures are only logged."""
        try:
            TOKEN_FILE.write_text(json.dumps(data))
        except Exception as exc:
            log.warning("Could not update token file: %s", exc)

    # urllib is used here (no httpx) so this can be called without an async client
    @staticmethod
    def _http_get_json(url: str, ssl_ctx: ssl.SSLContext) -> dict:
        """Blocking GET of *url*, returning the decoded JSON body (15 s timeout)."""
        req = urllib.request.Request(url)
        with urllib.request.urlopen(req, context=ssl_ctx, timeout=15) as r:
            return json.load(r)

    @staticmethod
    def _http_post_form(url: str, data: dict, ssl_ctx: ssl.SSLContext) -> dict:
        """Blocking form-encoded POST of *data* to *url*, returning the decoded JSON body."""
        encoded = urllib.parse.urlencode(data).encode()
        req = urllib.request.Request(url, data=encoded, method="POST")
        req.add_header("Content-Type", "application/x-www-form-urlencoded")
        with urllib.request.urlopen(req, context=ssl_ctx, timeout=15) as r:
            return json.load(r)

    async def _do_refresh(self, refresh_token: str) -> str:
        """Use *refresh_token* to obtain a fresh id_token. Returns '' on failure.

        Discovers the token endpoint from PanDA's ``<VO>_auth_config.json``
        and the OIDC configuration it points to, then performs a
        ``refresh_token`` grant. The blocking HTTP calls run in the default
        executor so the event loop keeps serving other traffic. On success the
        token response is written back to TOKEN_FILE.
        """
        ssl_ctx = _build_ssl_context()
        loop = asyncio.get_running_loop()

        try:
            auth_cfg = await loop.run_in_executor(None, self._http_get_json, f"{PANDA_SERVER}/auth/{VO}_auth_config.json", ssl_ctx)
            oidc_cfg = await loop.run_in_executor(None, self._http_get_json, auth_cfg["oidc_config_url"], ssl_ctx)
            token_resp = await loop.run_in_executor(
                None,
                self._http_post_form,
                oidc_cfg["token_endpoint"],
                {
                    "grant_type": "refresh_token",
                    "client_id": auth_cfg["client_id"],
                    "client_secret": auth_cfg.get("client_secret", ""),
                    "refresh_token": refresh_token,
                },
                ssl_ctx,
            )
        except Exception as exc:
            log.error("Token refresh request failed: %s", exc)
            return ""

        id_token = token_resp.get("id_token", "")
        if id_token:
            # Providers that do not rotate refresh tokens omit refresh_token from
            # the response; keep the one just used so the next refresh still works.
            self._save_file({"refresh_token": refresh_token, **token_resp})
            log.warning("id_token refreshed successfully.")
        else:
            log.error("Refresh response missing id_token: %s", token_resp.get("error", "unknown"))
        return id_token

    # ── public API ────────────────────────────────────────────────────────────

    async def get(self) -> str:
        """Return a usable id_token, refreshing it silently when needed.

        Preference order: the in-memory token, the token cached in TOKEN_FILE,
        then a new token obtained with the cached refresh_token. A token is
        "fresh" while more than TOKEN_REFRESH_MARGIN seconds remain before its
        ``exp``. If refreshing is impossible or fails, a token that is inside
        the margin but not yet expired is still returned, so a transient
        network error does not break requests early.

        Raises:
            RuntimeError: if no unexpired token can be obtained.
        """
        async with self._lock:
            now = time.time()

            # In-memory token still good
            if self._id_token and self._exp - now > TOKEN_REFRESH_MARGIN:
                return self._id_token

            # Token file (it may have been rewritten externally, e.g. by
            # get_panda_token). Adopt it whenever it outlives the in-memory token,
            # even inside the margin, so it can serve as a fallback below.
            data = self._load_file()
            file_token = data.get("id_token", "")
            if file_token:
                file_exp = self._decode_exp(file_token)
                if file_exp > self._exp:
                    self._id_token, self._exp = file_token, file_exp
                if self._exp - now > TOKEN_REFRESH_MARGIN:
                    return self._id_token

            # Attempt silent refresh
            refresh_token = data.get("refresh_token", "")
            if refresh_token:
                log.warning("id_token expired or close to expiry — refreshing silently…")
                new_token = await self._do_refresh(refresh_token)
                if new_token:
                    self._id_token = new_token
                    self._exp = self._decode_exp(new_token)
                    return new_token

            # Refresh impossible or failed: fall back to a token that has not expired yet
            if self._id_token and self._exp > now:
                log.warning("Could not refresh id_token; reusing the current one (expires in %ds).", int(self._exp - now))
                return self._id_token

            raise RuntimeError("No valid token found. Run get_panda_token.sh (or get_panda_token.py) first.")
0209 
0210 
0211 # ── SSE helpers ───────────────────────────────────────────────────────────────
0212 
0213 
0214 def _parse_sse(lines: list[str]) -> tuple[str, str]:
0215     """Parse a complete SSE block (list of non-empty lines) into (event, data)."""
0216     event = "message"
0217     data = ""
0218     for line in lines:
0219         if line.startswith("event:"):
0220             event = line[6:].strip()
0221         elif line.startswith("data:"):
0222             data = line[5:].strip()
0223     return event, data
0224 
0225 
0226 # ── Stdio helpers (cross-platform) ────────────────────────────────────────────
0227 
0228 
0229 async def _read_stdin_lines(queue: asyncio.Queue) -> None:
0230     """Read newline-delimited JSON from stdin and push to queue. Runs in a thread."""
0231     loop = asyncio.get_event_loop()
0232 
0233     def _blocking_read():
0234         # sys.stdin.readline returns '' on EOF; works on Windows and Unix
0235         return sys.stdin.readline()
0236 
0237     while True:
0238         line = await loop.run_in_executor(None, _blocking_read)
0239         if not line:
0240             break
0241         line = line.strip()
0242         if line:
0243             await queue.put(line)
0244 
0245 
0246 async def _write_stdout_lines(queue: asyncio.Queue) -> None:
0247     """Write newline-delimited JSON from queue to stdout."""
0248     loop = asyncio.get_event_loop()
0249 
0250     def _blocking_write(msg: str):
0251         sys.stdout.write(msg + "\n")
0252         sys.stdout.flush()
0253 
0254     while True:
0255         msg = await queue.get()
0256         await loop.run_in_executor(None, _blocking_write, msg)
0257 
0258 
0259 # ── MCP proxy (Streamable HTTP transport) ────────────────────────────────────
0260 #
0261 # MCP spec 2025-03-26+ uses "Streamable HTTP":
0262 #   - Every client message is a POST to MCP_URL
0263 #   - The POST response is either plain JSON or an SSE stream (both handled below)
0264 #   - An optional GET to MCP_URL opens a server-push SSE channel; 405 means unsupported
0265 #   - The server returns Mcp-Session-Id on the initialize response; subsequent
0266 #     requests include it so the server can correlate the session
0267 
0268 
class MCPProxy:
    """Relay JSON-RPC between the LLM client's stdio and a Streamable HTTP MCP server.

    Each stdin line is POSTed to MCP_URL with a bearer token obtained from
    TokenManager (refreshed when needed). Replies (plain JSON or SSE) and
    events from the optional GET push channel are queued for stdout.
    """

    def __init__(self):
        self.tokens = TokenManager()
        self._session_id: str | None = None
        self._outbound: asyncio.Queue = asyncio.Queue()  # stdin  → remote
        self._inbound: asyncio.Queue = asyncio.Queue()  # remote → stdout
        # Set after a successful POST (normally initialize), i.e. once the session
        # ID, if the server issues one, is known; gates the GET push channel.
        self._ready = asyncio.Event()
        # In-flight request tasks; the event loop itself keeps only weak references.
        self._tasks: set[asyncio.Task] = set()

    async def _headers(self, extra: dict | None = None) -> dict:
        """Build request headers with a valid token, the session ID and *extra*.

        Raises RuntimeError (from TokenManager.get) when no token is available.
        """
        token = await self.tokens.get()
        h = {
            "Authorization": f"Bearer {token}",
            "Origin": VO,
            "X-Auth-Token": f"Bearer {token}",
        }
        if self._session_id:
            h["Mcp-Session-Id"] = self._session_id
        if extra:
            h.update(extra)
        return h

    def _remember_session(self, resp: httpx.Response) -> None:
        """Store the Mcp-Session-Id header of *resp* (set by the server on initialize)."""
        if sid := resp.headers.get("Mcp-Session-Id"):
            if self._session_id != sid:
                self._session_id = sid
                log.warning("MCP session ID: %s", sid)

    async def _pump_sse(self, resp: httpx.Response) -> None:
        """Queue the data of every complete SSE event in *resp* for stdout."""
        block: list[str] = []
        async for raw_line in resp.aiter_lines():
            if raw_line:
                block.append(raw_line)
            elif block:
                _, data = _parse_sse(block)
                block.clear()
                if data:
                    await self._inbound.put(data)

    async def _handle_response(self, resp: httpx.Response) -> None:
        """Read a POST response — either plain JSON or an SSE stream — into the stdout queue."""
        self._remember_session(resp)
        content_type = resp.headers.get("content-type", "")
        if "text/event-stream" in content_type:
            await self._pump_sse(resp)
        else:
            body = await resp.aread()
            if body.strip():
                await self._inbound.put(body.decode())

    @staticmethod
    def _inspect_outbound(msg: str) -> tuple[list, bool]:
        """Return (ids of JSON-RPC requests in *msg*, whether one of them is initialize).

        Notifications and responses carry no request id. Unparseable input
        yields ([], False) and is still forwarded unchanged.
        """
        try:
            payload = json.loads(msg)
        except ValueError:
            return [], False
        items = payload if isinstance(payload, list) else [payload]
        requests = [m for m in items if isinstance(m, dict) and "method" in m and "id" in m]
        return [m["id"] for m in requests], any(m["method"] == "initialize" for m in requests)

    @staticmethod
    def _error_reply(req_id, exc: Exception) -> str:
        """JSON-RPC error response (-32603, Internal error) for a request the proxy could not deliver."""
        return json.dumps(
            {
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32603, "message": f"panda_mcp_proxy: {exc}"},
            }
        )

    async def _post(self, client: httpx.AsyncClient, msg: str, request_ids: list) -> None:
        """POST one stdin message and forward the reply.

        On any failure the requests contained in *msg* are answered with a
        JSON-RPC error so the client does not wait for a response forever.
        """
        try:
            headers = await self._headers(
                {
                    "Content-Type": "application/json",
                    "Accept": "application/json, text/event-stream",
                }
            )
            async with client.stream("POST", MCP_URL, content=msg, headers=headers) as resp:
                sent_sid = headers.get("Mcp-Session-Id")
                if resp.status_code == 404 and sent_sid:
                    # Streamable HTTP: 404 for a request carrying a session ID means the
                    # server no longer knows that session; a new initialize is required.
                    log.warning("MCP session %s is no longer known to the server; the client must re-initialize.", sent_sid)
                    if self._session_id == sent_sid:
                        self._session_id = None
                        self._ready.clear()
                resp.raise_for_status()
                self._remember_session(resp)
                self._ready.set()
                await self._handle_response(resp)
        except Exception as exc:
            log.error("POST failed: %s", exc)
            for req_id in request_ids:
                await self._inbound.put(self._error_reply(req_id, exc))

    async def _sender_loop(self, client: httpx.AsyncClient) -> None:
        """Forward every stdin message to the remote server via POST.

        Requests are dispatched as concurrent tasks, so a long-running call (or
        a server request that waits for the client's answer on the same SSE
        stream) cannot block later messages. Notifications and responses are
        posted and awaited in order, preserving e.g. the
        "initialized notification before the next request" sequence.
        """
        while True:
            msg = await self._outbound.get()
            request_ids, is_initialize = self._inspect_outbound(msg)
            if is_initialize:
                # A (re-)initialize must start a fresh session without a stale ID
                self._session_id = None
                self._ready.clear()
            if request_ids:
                task = asyncio.create_task(self._post(client, msg, request_ids))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)
            else:
                await self._post(client, msg, request_ids)

    async def _server_push_loop(self, client: httpx.AsyncClient) -> None:
        """Optional GET SSE channel for server-initiated messages. Exits if unsupported (405/400)."""
        while True:
            # Opened only after a successful POST: servers that use sessions may
            # reject a GET that lacks the Mcp-Session-Id returned on initialize.
            await self._ready.wait()
            try:
                headers = await self._headers({"Accept": "text/event-stream"})
                log.warning("Opening server-push SSE channel to %s", MCP_URL)
                async with client.stream("GET", MCP_URL, headers=headers) as resp:
                    if resp.status_code == 405:
                        log.warning("Server does not support GET SSE push — skipping.")
                        return
                    resp.raise_for_status()
                    await self._pump_sse(resp)
            except httpx.HTTPStatusError as exc:
                if exc.response.status_code in (400, 405):
                    log.warning("Server does not support GET SSE push (%s) — skipping.", exc.response.status_code)
                    return
                log.error("Server-push SSE failed (%s) — retrying in %ds", exc.response.status_code, SSE_RECONNECT_DELAY)
            except Exception as exc:
                log.error("Server-push SSE error: %s — retrying in %ds", exc, SSE_RECONNECT_DELAY)

            await asyncio.sleep(SSE_RECONNECT_DELAY)

    async def _close_session(self, client: httpx.AsyncClient) -> None:
        """Best-effort HTTP DELETE telling the server the session is finished."""
        if not self._session_id:
            return
        try:
            headers = await self._headers()
            await client.delete(MCP_URL, headers=headers, timeout=5.0)
        except Exception as exc:
            log.warning("Could not terminate MCP session: %s", exc)

    async def run(self) -> None:
        """Serve until stdin reaches EOF (client closed the pipe) or a worker fails.

        Raises:
            RuntimeError: if no token is available at startup.
            Exception: re-raised from a crashed stdin/stdout/sender worker.
        """
        # Fail fast if no token is available before accepting any MCP traffic
        await self.tokens.get()

        ssl_ctx = _build_ssl_context()
        # No read/write/pool limit (SSE streams and tool calls can be long-lived),
        # but do not hang forever while connecting to an unreachable server.
        timeout = httpx.Timeout(None, connect=30.0)
        async with httpx.AsyncClient(verify=ssl_ctx, timeout=timeout, follow_redirects=True) as client:
            essential = {
                asyncio.create_task(_read_stdin_lines(self._outbound)),
                asyncio.create_task(_write_stdout_lines(self._inbound)),
                asyncio.create_task(self._sender_loop(client)),
            }
            # The push loop may legitimately return early (unsupported), so it
            # does not take part in the shutdown decision.
            push = asyncio.create_task(self._server_push_loop(client))
            try:
                # The stdin reader returns at EOF; the other two only end by crashing
                done, _ = await asyncio.wait(essential, return_when=asyncio.FIRST_COMPLETED)
            finally:
                leftovers = {*essential, push, *self._tasks}
                for task in leftovers:
                    task.cancel()
                await asyncio.gather(*leftovers, return_exceptions=True)
            await self._close_session(client)
            for task in done:
                if not task.cancelled():
                    exc = task.exception()
                    if exc is not None:
                        raise exc
0372 
0373 
0374 # ── Entry point ───────────────────────────────────────────────────────────────
0375 
if __name__ == "__main__":
    if sys.platform == "win32":
        import msvcrt

        # Stop the C runtime from translating line endings on the raw descriptors
        msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
        msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)

    # MCP stdio messages are UTF-8 JSON delimited by "\n". Enforce that
    # regardless of the locale encoding (e.g. an ANSI code page on Windows
    # pipes) and of platform newline translation ("\r\n" on Windows).
    if hasattr(sys.stdin, "reconfigure"):
        sys.stdin.reconfigure(encoding="utf-8")
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8", newline="\n")

    try:
        asyncio.run(MCPProxy().run())
    except KeyboardInterrupt:
        pass
    except RuntimeError as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(1)