File indexing completed on 2026-04-10 08:39:02
0001 import base64
0002 import datetime
0003 from threading import Lock
0004
0005 import jwt
0006 import requests
0007 from cryptography.hazmat.backends import default_backend
0008 from cryptography.hazmat.primitives import serialization
0009 from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
0010 from jwt.exceptions import InvalidTokenError
0011 from pandacommon.pandautils.PandaUtils import naive_utcnow
0012
0013
def decode_value(val):
    """Decode a base64url-encoded, big-endian unsigned integer.

    Two "=" padding characters are always appended; Python's non-validating
    base64 decoder tolerates the surplus when the input is already padded.

    :param val: base64url text as str or bytes, padded or unpadded
    :return: the decoded non-negative integer
    """
    raw = val.encode() if isinstance(val, str) else val
    padded = raw + b"=="
    return int.from_bytes(base64.urlsafe_b64decode(padded), byteorder="big")
0019
0020
def rsa_pem_from_jwk(jwk):
    """Build a PEM-encoded RSA public key from a JWK's modulus and exponent.

    :param jwk: JWK dict whose "n" (modulus) and "e" (exponent) members are
        base64url-encoded integers
    :return: the public key serialized as PEM, SubjectPublicKeyInfo format
    """
    modulus = decode_value(jwk["n"])
    exponent = decode_value(jwk["e"])
    rsa_key = RSAPublicNumbers(n=modulus, e=exponent).public_key(default_backend())
    return rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
0029
0030
def get_jwk(kid, jwks):
    """Return the key with the given key ID from a JWK Set.

    :param kid: key ID, as found in the token header
    :param jwks: JWK Set document (dict whose "keys" member lists JWK dicts)
    :return: the first JWK dict whose "kid" equals ``kid``
    :raises InvalidTokenError: when no key in the set matches ``kid``
    """
    # A missing or null "keys" member is treated as an empty set, so the caller
    # gets the InvalidTokenError below rather than a TypeError.
    for jwk in jwks.get("keys") or []:
        if jwk.get("kid") == kid:
            return jwk
    raise InvalidTokenError(f"JWK not found for kid={kid}")
0036
0037
0038
class TokenDecoder:
    """Verify OIDC-issued JWTs, caching fetched JSON documents per URL.

    Discovery documents and JWK Sets are kept in memory and re-fetched once
    they are older than ``refresh_interval`` minutes. A single lock serializes
    cache checks and fetches across threads.
    """

    def __init__(self, refresh_interval=10, timeout=60):
        """
        :param refresh_interval: age in minutes after which a cached document is re-fetched
        :param timeout: timeout in seconds for each HTTP request that fetches a document
        """
        self.lock = Lock()
        # url -> {"data": parsed JSON document, "last_update": naive UTC time of the fetch}
        self.data = {}
        self.refresh_interval = refresh_interval
        self.timeout = timeout

    def get_data(self, url, log_stream):
        """Return the JSON document at ``url``, fetching it when missing or stale.

        :param url: URL of a JSON document (OIDC discovery document or JWK Set)
        :param log_stream: logger-like object with ``debug`` and ``error`` methods
        :return: the parsed JSON document
        :raises Exception: any error from the HTTP request or JSON parsing, re-raised after logging
        """
        try:
            with self.lock:
                entry = self.data.get(url)
                if entry is None or naive_utcnow() - entry["last_update"] > datetime.timedelta(minutes=self.refresh_interval):
                    log_stream.debug(f"to refresh {url}")
                    # The lock is held during the request, so bound it: a hung
                    # endpoint must not block every other caller indefinitely.
                    response = requests.get(url, timeout=self.timeout)
                    # Never cache an HTTP error body as if it were the document.
                    response.raise_for_status()
                    tmp_data = response.json()
                    log_stream.debug("refreshed")
                    entry = {
                        "data": tmp_data,
                        "last_update": naive_utcnow(),
                    }
                    self.data[url] = entry
                return entry["data"]
        except Exception as e:
            log_stream.error(f"failed to refresh with {str(e)}")
            raise

    def deserialize_token(self, token, auth_config, vo, log_stream, legacy_token_issuers):
        """Verify ``token`` and return its claims.

        The ``auth_config`` entry is selected by the token's "aud" claim when it
        matches a configured key, otherwise by its "sub" claim. The signing key is
        looked up by the header's "kid" in the JWK Set advertised by that entry's
        OIDC discovery document, and the signature is checked with RS256.

        :param token: encoded JWT
        :param auth_config: mapping of audience/subject to a dict holding
            "oidc_config_url" and, when ``vo`` is None, "vo"
        :param vo: VO name stored in the result; when None, the configured "vo" is used
        :param log_stream: logger-like object passed to get_data
        :param legacy_token_issuers: optional sequence of additional issuers tried
            after the primary issuer
        :return: verified claims dict with a "vo" entry added
        :raises jwt.exceptions.InvalidTokenError: (or a subclass) when the header lacks
            "kid", the key is unknown, or verification fails
        :raises KeyError: when neither "aud" nor "sub" selects an auth_config entry
        """
        # Unverified peek, used only to choose configuration and issuer. The
        # returned claims come from the verified decode further down.
        unverified = jwt.decode(token, options={"verify_signature": False})

        # RFC 7519 allows "aud" to be a single string or a list of strings.
        aud_claim = unverified.get("aud")
        if isinstance(aud_claim, str):
            aud_candidates = [aud_claim]
        elif isinstance(aud_claim, list):
            aud_candidates = [aud for aud in aud_claim if isinstance(aud, str)]
        else:
            aud_candidates = []
        conf_key = next((aud for aud in aud_candidates if aud in auth_config), None)
        # Verify against the matched configured audience, else against the raw claim.
        audience = conf_key if conf_key else aud_claim
        if not conf_key:
            conf_key = unverified["sub"]
        discovery_endpoint = auth_config[conf_key]["oidc_config_url"]

        headers = jwt.get_unverified_header(token)
        if headers is None or "kid" not in headers:
            raise jwt.exceptions.InvalidTokenError("cannot extract kid from headers")
        kid = headers["kid"]

        oidc_config = self.get_data(discovery_endpoint, log_stream)
        jwks = self.get_data(oidc_config["jwks_uri"], log_stream)
        public_key = rsa_pem_from_jwk(get_jwk(kid, jwks))

        configured_issuer = oidc_config["issuer"]
        token_issuer = unverified.get("iss")
        # Use the token's own "iss" when the configured issuer merely extends it
        # (presumably to tolerate e.g. a trailing-slash difference). The signature
        # is still checked against keys from the configured discovery document.
        if token_issuer and token_issuer != configured_issuer and configured_issuer.startswith(token_issuer):
            issuer = token_issuer
        else:
            issuer = configured_issuer
        # Primary issuer first, then legacy ones, without duplicates.
        issuers = list(dict.fromkeys([issuer, *(legacy_token_issuers or [])]))

        decoded = None
        err_msg = None
        for tmp_issuer in issuers:
            try:
                decoded = jwt.decode(
                    token,
                    public_key,
                    algorithms=["RS256"],
                    audience=audience,
                    issuer=tmp_issuer,
                )
                break
            except jwt.exceptions.InvalidIssuerError as e:
                err_msg = str(e)
        if decoded is None:
            raise jwt.exceptions.InvalidTokenError(f"failed to decode: {err_msg}")
        decoded["vo"] = vo if vo is not None else auth_config[conf_key]["vo"]
        return decoded
0124
0125
0126
def get_access_token(token_endpoint: str, client_id: str, client_secret: str, scope: str = None, timeout: int = 180) -> tuple[bool, str]:
    """
    Get an access token with client_credentials flow

    :param token_endpoint: URL for token request
    :param client_id: client ID
    :param client_secret: client secret
    :param scope: space separated string of scopes; omitted from the request when empty
    :param timeout: timeout in seconds

    :return: (True, access_token) or (False, error_str); never raises
    """
    try:
        token_request = {
            "grant_type": "client_credentials",
            "client_id": client_id,
            "client_secret": client_secret,
        }
        if scope:
            token_request["scope"] = scope
        token_response = requests.post(token_endpoint, data=token_request, timeout=timeout)
        try:
            token_data = token_response.json()
        except ValueError:
            # Body is not JSON (e.g. an HTML error page); rely on the status code below.
            token_data = None
        if not token_response.ok:
            # RFC 6749 section 5.2: the reason for a rejected request is in the JSON
            # body's "error"/"error_description", which raise_for_status() discards.
            if isinstance(token_data, dict) and "error" in token_data:
                return False, (
                    f"failed to get access token with HTTP {token_response.status_code}: "
                    f"error={token_data.get('error')} error_description={token_data.get('error_description')}"
                )
            token_response.raise_for_status()
        if not isinstance(token_data, dict) or "access_token" not in token_data:
            return False, "failed to get access token: response has no access_token"
        return True, token_data["access_token"]
    except Exception as e:
        # Best-effort by contract: every failure is reported through the return value.
        error_str = f"failed to get access token with {str(e)}"
        return False, error_str