Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-10 08:39:02

0001 import base64
0002 import datetime
0003 from threading import Lock
0004 
0005 import jwt
0006 import requests
0007 from cryptography.hazmat.backends import default_backend
0008 from cryptography.hazmat.primitives import serialization
0009 from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
0010 from jwt.exceptions import InvalidTokenError
0011 from pandacommon.pandautils.PandaUtils import naive_utcnow
0012 
0013 
def decode_value(val):
    """Decode a base64url-encoded value into a non-negative big-endian integer.

    :param val: base64url text as ``str`` or ``bytes``; trailing padding may be omitted
    :return: the decoded bytes interpreted as an unsigned big-endian int
    """
    raw = val.encode() if isinstance(val, str) else val
    # Unpadded input is completed by appending "=="; the non-strict decoder
    # ignores any surplus padding, so this covers every valid input length.
    return int.from_bytes(base64.urlsafe_b64decode(raw + b"=="), "big")
0019 
0020 
def rsa_pem_from_jwk(jwk):
    """Build a PEM-encoded RSA public key from the "n"/"e" members of a JWK.

    :param jwk: JWK dict whose "n" (modulus) and "e" (public exponent) are
        base64url-encoded, as accepted by decode_value
    :return: PEM bytes in SubjectPublicKeyInfo format
    """
    modulus = decode_value(jwk["n"])
    exponent = decode_value(jwk["e"])
    rsa_key = RSAPublicNumbers(e=exponent, n=modulus).public_key(default_backend())
    return rsa_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
0029 
0030 
def get_jwk(kid, jwks):
    """Return the first JWK in a JWK set whose "kid" equals ``kid``.

    :param kid: key ID taken from the token header
    :param jwks: JWK set document (dict expected to hold a "keys" list)
    :return: the matching JWK dict
    :raises InvalidTokenError: when no key in the set carries that ``kid``
    """
    # "keys" may be absent or null in a malformed document; treat both as empty
    for jwk in jwks.get("keys") or []:
        if jwk.get("kid") == kid:
            return jwk
    raise InvalidTokenError(f"JWK not found for kid={kid}")
0036 
0037 
0038 # token decoder
class TokenDecoder:
    """Decode and verify JWTs against keys published through OIDC discovery.

    Discovery documents and JWK sets are fetched over HTTP and cached per URL in
    ``self.data``; ``self.lock`` serializes every read and refresh of that cache.
    """

    # constructor
    def __init__(self, refresh_interval=10, request_timeout=180):
        """
        :param refresh_interval: minutes a cached document is reused before it is fetched again
        :param request_timeout: timeout in seconds for each HTTP fetch; the fetch runs while
            the cache lock is held, so an unbounded request would stall every caller
        """
        self.lock = Lock()
        # url -> {"data": parsed JSON body, "last_update": naive_utcnow() at fetch time}
        self.data = {}
        self.refresh_interval = refresh_interval
        self.request_timeout = request_timeout

    # get cached data
    def get_data(self, url, log_stream):
        """Return the JSON document at ``url``, refetching it when the cached copy is stale.

        :param url: URL of a JSON document (e.g. OIDC discovery document or JWK set)
        :param log_stream: logger used for debug/error messages
        :return: the parsed JSON body
        :raises Exception: network errors, HTTP error statuses, or JSON parse errors
            from ``requests``; each is logged before being re-raised
        """
        try:
            with self.lock:
                entry = self.data.get(url)
                max_age = datetime.timedelta(minutes=self.refresh_interval)
                if entry is None or naive_utcnow() - entry["last_update"] > max_age:
                    log_stream.debug(f"to refresh {url}")
                    response = requests.get(url, timeout=self.request_timeout)
                    # never cache an HTTP error body for refresh_interval minutes
                    response.raise_for_status()
                    tmp_data = response.json()
                    log_stream.debug("refreshed")
                    entry = {
                        "data": tmp_data,
                        "last_update": naive_utcnow(),
                    }
                    self.data[url] = entry
                return entry["data"]
        except Exception as e:
            log_stream.error(f"failed to refresh with {str(e)}")
            raise

    # decode and verify JWT token
    def deserialize_token(self, token, auth_config, vo, log_stream, legacy_token_issuers):
        """Verify an RS256-signed JWT and return its claims with a "vo" entry added.

        The ``auth_config`` entry is chosen by the token's ``aud`` claim when it
        matches a key (the first matching element if ``aud`` is a list), otherwise
        by its ``sub`` claim. That entry's "oidc_config_url" locates the discovery
        document whose "jwks_uri" supplies the public key for the token's ``kid``.

        :param token: encoded JWT
        :param auth_config: mapping of config key -> dict with "oidc_config_url" and "vo"
        :param vo: value stored as "vo" in the result; when None the config entry's "vo" is used
        :param log_stream: logger passed to get_data
        :param legacy_token_issuers: optional iterable of extra accepted issuers,
            tried after the issuer derived from the discovery document
        :return: dict of verified claims plus "vo"
        :raises jwt.exceptions.InvalidTokenError: (or a subclass) when verification fails
        :raises KeyError: when no configuration entry can be selected
        """
        # read claims without verifying them, only to select the configuration
        # NOTE(review): `verify=` is presumably only honoured by PyJWT 1.x; kept for
        # compatibility until the pinned PyJWT version is confirmed
        unverified = jwt.decode(token, verify=False, options={"verify_signature": False})
        audience = unverified.get("aud")
        conf_key = None
        if isinstance(audience, list):
            # "aud" may be an array; a list cannot be used as a dict lookup key
            conf_key = next((aud for aud in audience if isinstance(aud, str) and aud in auth_config), None)
        elif isinstance(audience, str) and audience in auth_config:
            conf_key = audience
        if not conf_key:
            # use sub as config key for access token
            conf_key = unverified["sub"]
        discovery_endpoint = auth_config[conf_key]["oidc_config_url"]

        # the key id in the header selects the signing key from the JWK set
        headers = jwt.get_unverified_header(token)
        if headers is None or "kid" not in headers:
            raise jwt.exceptions.InvalidTokenError("cannot extract kid from headers")
        kid = headers["kid"]

        # retrieve OIDC configuration and JWK set
        oidc_config = self.get_data(discovery_endpoint, log_stream)
        jwks = self.get_data(oidc_config["jwks_uri"], log_stream)
        public_key = rsa_pem_from_jwk(get_jwk(kid, jwks))

        token_issuer = unverified.get("iss")
        config_issuer = oidc_config["issuer"]
        if isinstance(token_issuer, str) and token_issuer and token_issuer != config_issuer and config_issuer.startswith(token_issuer):
            # iss is missing the last '/' in access tokens
            issuer = token_issuer
        else:
            issuer = config_issuer
        # dict.fromkeys removes duplicates while keeping the primary issuer first
        if legacy_token_issuers:
            issuers = list(dict.fromkeys([issuer, *legacy_token_issuers]))
        else:
            issuers = [issuer]

        # decode token only with RS256; only an issuer mismatch moves on to the next issuer
        decoded = None
        err_msg = None
        for tmp_issuer in issuers:
            try:
                decoded = jwt.decode(
                    token,
                    public_key,
                    verify=True,
                    algorithms=["RS256"],
                    audience=audience,
                    issuer=tmp_issuer,
                )
                break
            except jwt.exceptions.InvalidIssuerError as e:
                err_msg = str(e)
        if decoded is None:
            raise jwt.exceptions.InvalidTokenError(f"failed to decode: {err_msg}")
        decoded["vo"] = vo if vo is not None else auth_config[conf_key]["vo"]
        return decoded
0124 
0125 
0126 # get an access token with client_credentials flow
def get_access_token(token_endpoint: str, client_id: str, client_secret: str, scope: str = None, timeout: int = 180) -> tuple[bool, str]:
    """
    Request an access token from an OAuth2 token endpoint using the client_credentials grant.

    Every failure (request construction, network, HTTP status, response parsing) is
    reported through the return value instead of being raised.

    :param token_endpoint: URL for token request
    :param client_id: client ID
    :param client_secret: client secret
    :param scope: space separated string of scopes; omitted from the request when empty
    :param timeout: timeout in seconds

    :return: (True, access_token) or (False, error_str)
    """
    try:
        form = dict(grant_type="client_credentials", client_id=client_id, client_secret=client_secret)
        if scope:
            form["scope"] = scope
        response = requests.post(token_endpoint, data=form, timeout=timeout)
        response.raise_for_status()
        access_token = response.json()["access_token"]
    except Exception as exc:
        return False, f"failed to get access token with {str(exc)}"
    return True, access_token