File indexing completed on 2026-04-09 07:58:19
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 import base64
0012 import json
0013 import jwt
0014 import traceback
0015
0016
0017 from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
0018 from cryptography.hazmat.backends import default_backend
0019 from cryptography.hazmat.primitives import serialization
0020
0021 from idds.common import authentication
0022
0023
def decode_value(val):
    """Decode an unpadded base64url value into a non-negative integer.

    :param val: base64url text (``str``) or bytes; ``str`` input is encoded
                to bytes first.
    :returns: the decoded bytes interpreted as a big-endian integer.
    """
    raw = val.encode() if isinstance(val, str) else val
    # '==' is appended unconditionally to restore any stripped padding;
    # the decoder ignores padding it does not need.
    return int.from_bytes(base64.urlsafe_b64decode(raw + b'=='), 'big')
0029
0030
class OIDCAuthentication(authentication.OIDCAuthentication):
    """OIDC ID-token verifier layered on ``authentication.OIDCAuthentication``.

    The caching, HTTP and configuration helpers used below
    (``get_cache_value``, ``set_cache_value``, ``get_http_content``,
    ``get_auth_endpoint_config``) are not defined in this class and are
    expected to come from the base class.
    """

    def __init__(self, timeout=None):
        """Initialise the base class, forwarding ``timeout`` unchanged."""
        super().__init__(timeout=timeout)

    def get_public_key(self, token, jwks_uri, no_verify=False, with_cache=True):
        """Return the PEM-encoded RSA public key matching the token's ``kid``.

        :param token: encoded JWT; only its unverified header is read here.
        :param jwks_uri: location of the JWKS document, also used as cache key.
        :param no_verify: forwarded to ``get_http_content`` (presumably
                          disables TLS verification - confirm in base class).
        :param with_cache: when true, consult the cache before fetching.
        :returns: the public key as PEM bytes in SubjectPublicKeyInfo format.
        :raises jwt.exceptions.InvalidTokenError: if the header has no ``kid``
                 or no key in the JWKS carries that ``kid``.
        """
        headers = jwt.get_unverified_header(token)
        if headers is None or 'kid' not in headers:
            raise jwt.exceptions.InvalidTokenError('cannot extract kid from headers')
        kid = headers['kid']

        jwks = self.get_cache_value(jwks_uri) if with_cache else None

        if not jwks:
            # Cache miss (or cache bypassed): fetch the key set and store it.
            jwks_content = self.get_http_content(jwks_uri, no_verify=no_verify)
            jwks = json.loads(jwks_content)
            self.set_cache_value(jwks_uri, jwks)

        jwk = None
        for candidate in jwks.get('keys', []):
            if candidate.get('kid') == kid:
                jwk = candidate
                # Key ids identify a single key; no need to scan further.
                break
        if jwk is None:
            raise jwt.exceptions.InvalidTokenError('JWK not found for kid={0}: {1}'.format(kid, str(jwks)))

        # Rebuild the RSA key from its modulus (n) and public exponent (e).
        public_num = RSAPublicNumbers(n=decode_value(jwk['n']), e=decode_value(jwk['e']))
        public_key = public_num.public_key(default_backend())
        pem = public_key.public_bytes(encoding=serialization.Encoding.PEM,
                                      format=serialization.PublicFormat.SubjectPublicKeyInfo)
        return pem

    def verify_id_token_cache(self, vo, token, with_cache=True):
        """Verify an ID token for ``vo``, optionally using the cached JWKS.

        :param vo: name passed to ``get_auth_endpoint_config`` and stored in
                   the decoded claims under ``'vo'``.
        :param token: encoded JWT to verify.
        :param with_cache: forwarded to ``get_public_key``.
        :returns: ``(True, claims, username)`` on success, where ``username``
                  is the ``name`` claim or ``None``; ``(False, message, None)``
                  on any failure. Exceptions are caught and reported in
                  ``message`` rather than propagated.
        """
        try:
            auth_config, endpoint_config = self.get_auth_endpoint_config(vo)

            # Unverified peek at the claims to pick audience and issuer; the
            # signature is checked by the second decode below.
            # NOTE(review): ``verify=False`` and ``options`` both look like ways
            # to disable verification, presumably kept together for
            # compatibility across PyJWT major versions - confirm before
            # removing either.
            decoded_token = jwt.decode(token, verify=False, options={"verify_signature": False})
            audience = decoded_token['aud']
            # NOTE(review): a list-valued ``aud`` claim will not equal either
            # configured value and is therefore rejected here.
            if audience not in [auth_config['audience'], auth_config['client_id']]:
                return False, "The audience %s of the token doesn't match vo configuration(client_id: %s)." % (audience, auth_config['client_id']), None

            public_key = self.get_public_key(token, endpoint_config['jwks_uri'],
                                             no_verify=auth_config['no_verify'],
                                             with_cache=with_cache)

            configured_issuer = endpoint_config['issuer']
            token_issuer = decoded_token.get('iss')
            # Accept the token's issuer when it differs from the configured
            # one only by being a prefix of it; otherwise require the
            # configured issuer.
            if token_issuer and token_issuer != configured_issuer and configured_issuer.startswith(token_issuer):
                issuer = token_issuer
            else:
                issuer = configured_issuer

            # ``algorithms`` must be a list: with a bare string the library's
            # membership test degrades to a substring match.
            decoded = jwt.decode(token, public_key, algorithms=['RS256'],
                                 audience=audience, issuer=issuer)
            decoded['vo'] = vo
            username = decoded.get('name')
            return True, decoded, username
        except Exception as error:
            print(error)
            print(traceback.format_exc())
            return False, 'Failed to verify oidc token: ' + str(error), None

    def verify_id_token(self, vo, token):
        """Verify ``token`` with the cached JWKS, retrying once without cache.

        The retry makes ``get_public_key`` fetch the JWKS again instead of
        reading it from the cache.

        :returns: the ``(status, data, username)`` tuple of the first
                  successful attempt, or of the uncached attempt.
        """
        status, data, username = self.verify_id_token_cache(vo, token, with_cache=True)
        if status:
            return status, data, username
        return self.verify_id_token_cache(vo, token, with_cache=False)
0100
0101
class OIDCAuthenticationUtils(authentication.OIDCAuthenticationUtils):
    """Local subclass of ``authentication.OIDCAuthenticationUtils``; adds no behavior."""

    def __init__(self):
        """Run the base-class initialiser with no arguments."""
        super().__init__()
0105
0106
class X509Authentication(authentication.X509Authentication):
    """Local subclass of ``authentication.X509Authentication``; adds no behavior."""

    def __init__(self, timeout=None):
        """Run the base-class initialiser, forwarding ``timeout`` unchanged."""
        super().__init__(timeout=timeout)
0110
0111
def get_user_name_from_dn1(dn):
    """Return ``authentication.get_user_name_from_dn1(dn)`` unchanged."""
    result = authentication.get_user_name_from_dn1(dn)
    return result
0114
0115
def get_user_name_from_dn2(dn):
    """Return ``authentication.get_user_name_from_dn2(dn)`` unchanged."""
    result = authentication.get_user_name_from_dn2(dn)
    return result
0118
0119
def get_user_name_from_dn(dn):
    """Pass ``dn`` through ``get_user_name_from_dn1`` and then ``get_user_name_from_dn2``.

    :param dn: value handed to the first helper.
    :returns: the output of the second helper applied to the first's output.
    """
    return get_user_name_from_dn2(get_user_name_from_dn1(dn))
0124
0125
def authenticate_x509(vo, dn, client_cert):
    """Return ``authentication.authenticate_x509(vo, dn, client_cert)`` unchanged."""
    outcome = authentication.authenticate_x509(vo, dn, client_cert)
    return outcome
0128
0129
def authenticate_oidc(vo, token):
    """Verify an OIDC token for ``vo`` with a fresh ``OIDCAuthentication``.

    :param vo: passed to ``OIDCAuthentication.verify_id_token``.
    :param token: encoded token to verify.
    :returns: the ``(status, data, username)`` tuple produced by
              ``verify_id_token``, for success and failure alike.
    """
    oidc_auth = OIDCAuthentication()
    # Both outcomes return the same tuple, so no branching on status is needed.
    status, data, username = oidc_auth.verify_id_token(vo, token)
    return status, data, username
0137
0138
def authenticate_is_super_user(username, dn=None):
    """Return ``authentication.authenticate_is_super_user`` for ``username`` and ``dn``.

    Both arguments are forwarded by keyword.
    """
    verdict = authentication.authenticate_is_super_user(username=username, dn=dn)
    return verdict