Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:58:19

0001 #!/usr/bin/env python
0002 #
0003 # Licensed under the Apache License, Version 2.0 (the "License");
0004 # you may not use this file except in compliance with the License.
0005 # You may obtain a copy of the License at
0006 # http://www.apache.org/licenses/LICENSE-2.0
0007 #
0008 # Authors:
0009 # - Wen Guan, <wen.guan@cern.ch>, 2024
0010 
0011 import base64
0012 import json
0013 import jwt
0014 import traceback
0015 
0016 # from cryptography import x509
0017 from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
0018 from cryptography.hazmat.backends import default_backend
0019 from cryptography.hazmat.primitives import serialization
0020 
0021 from idds.common import authentication
0022 
0023 
def decode_value(val):
    """Decode a base64url string into an unsigned big-endian integer.

    Used for the ``n`` and ``e`` members of an RSA JWK.

    :param val: base64url text, as ``str`` or ``bytes``; padding may be absent.
    :returns: the decoded bytes interpreted as a big-endian ``int``.
    """
    raw = val.encode() if isinstance(val, str) else val
    # JWK values usually arrive without '=' padding. Appending two pad
    # characters makes any valid length decodable; the decoder ignores the
    # surplus padding when the input was already complete.
    return int.from_bytes(base64.urlsafe_b64decode(raw + b'=='), byteorder='big')
0029 
0030 
class OIDCAuthentication(authentication.OIDCAuthentication):
    """OIDC token verification against the provider's JSON Web Key Set.

    Builds on ``idds.common.authentication.OIDCAuthentication``. That class
    provides ``get_auth_endpoint_config``, ``get_http_content``,
    ``get_cache_value`` and ``set_cache_value``, which are used here but not
    defined here. This class adds RS256 signature verification using the
    public key that matches the token's ``kid``.
    """

    def __init__(self, timeout=None):
        """Create the authenticator; ``timeout`` is forwarded to the base class."""
        super(OIDCAuthentication, self).__init__(timeout=timeout)

    def get_public_key(self, token, jwks_uri, no_verify=False, with_cache=True):
        """Return the PEM-encoded RSA public key matching ``token``'s ``kid``.

        :param token: encoded JWT; only its unverified header is read here.
        :param jwks_uri: URL of the provider's JWKS document.
        :param no_verify: forwarded to ``get_http_content`` when the JWKS is fetched.
        :param with_cache: when True, try the cached JWKS for ``jwks_uri`` first.
        :returns: the key as PEM ``SubjectPublicKeyInfo`` bytes.
        :raises jwt.exceptions.InvalidTokenError: if the header has no ``kid``,
            or no key in the JWKS carries that ``kid``.
        """
        headers = jwt.get_unverified_header(token)
        if headers is None or 'kid' not in headers:
            raise jwt.exceptions.InvalidTokenError('cannot extract kid from headers')
        kid = headers['kid']

        jwks = self.get_cache_value(jwks_uri) if with_cache else None
        if not jwks:
            # Cache miss (or cache bypassed): fetch the key set and refresh the cache.
            jwks_content = self.get_http_content(jwks_uri, no_verify=no_verify)
            jwks = json.loads(jwks_content)
            self.set_cache_value(jwks_uri, jwks)

        # Select the first key whose kid matches the token header.
        jwk = next((key for key in jwks.get('keys', []) if key.get('kid') == kid), None)
        if jwk is None:
            raise jwt.exceptions.InvalidTokenError('JWK not found for kid={0}: {1}'.format(kid, str(jwks)))

        # 'n' (modulus) and 'e' (exponent) are base64url big-endian integers.
        public_num = RSAPublicNumbers(n=decode_value(jwk['n']), e=decode_value(jwk['e']))
        public_key = public_num.public_key(default_backend())
        return public_key.public_bytes(encoding=serialization.Encoding.PEM,
                                       format=serialization.PublicFormat.SubjectPublicKeyInfo)

    def verify_id_token_cache(self, vo, token, with_cache=True):
        """Verify ``token`` for ``vo`` and return its claims.

        The unverified payload is inspected first so the audience can be
        checked against the VO configuration. The token is then fully decoded
        (RS256 signature, audience and issuer) with the JWKS key for its ``kid``.

        :param vo: VO whose auth/endpoint configuration is used.
        :param token: encoded JWT.
        :param with_cache: whether a cached JWKS may be used.
        :returns: ``(True, claims, username)`` on success. ``claims`` is the
            decoded payload with ``'vo'`` added, and ``username`` is its
            ``name`` claim or None. On any failure, returns
            ``(False, error_message, None)``.
        """
        try:
            auth_config, endpoint_config = self.get_auth_endpoint_config(vo)

            # Peek at the claims without verification to pre-check the audience.
            # NOTE(review): verify=False plus options= is presumably kept for
            # compatibility with both PyJWT 1.x and 2.x.
            decoded_token = jwt.decode(token, verify=False, options={"verify_signature": False})
            audience = decoded_token['aud']
            allowed_audiences = [auth_config['audience'], auth_config['client_id']]
            # RFC 7519 allows 'aud' to be either a single string or a list of strings.
            if isinstance(audience, list):
                matched = [aud for aud in audience if aud in allowed_audiences]
            else:
                matched = [audience] if audience in allowed_audiences else []
            if not matched:
                return (False,
                        "The audience %s of the token doesn't match vo configuration(client_id: %s)."
                        % (audience, auth_config['client_id']),
                        None)
            expected_audience = matched[0]

            public_key = self.get_public_key(token, endpoint_config['jwks_uri'],
                                             no_verify=auth_config['no_verify'],
                                             with_cache=with_cache)

            issuer = endpoint_config['issuer']
            token_issuer = decoded_token.get('iss')
            # Access tokens may carry the configured issuer without its trailing '/'.
            # Accept exactly that variant rather than any prefix of the issuer.
            if token_issuer and token_issuer != issuer and issuer.rstrip('/') == token_issuer:
                issuer = token_issuer

            # Only RS256 is accepted; PyJWT documents ``algorithms`` as a list.
            decoded = jwt.decode(token, public_key, algorithms=['RS256'],
                                 audience=expected_audience, issuer=issuer)
            decoded['vo'] = vo
            username = decoded.get('name')
            return True, decoded, username
        except Exception as error:
            # Any failure (config lookup, JWKS fetch, decoding, claim checks) is
            # reported as a failed verification instead of being raised.
            print(error)
            print(traceback.format_exc())
            return False, 'Failed to verify oidc token: ' + str(error), None

    def verify_id_token(self, vo, token):
        """Verify ``token``, retrying once without the JWKS cache.

        The first attempt may use a cached key set. If it fails for any reason,
        verification is repeated with ``with_cache=False``, presumably so that a
        rotated signing key missing from a stale cache is picked up.
        NOTE(review): every failure (e.g. an expired token) triggers a JWKS refetch.

        :returns: the ``(status, data, username)`` tuple of the last attempt.
        """
        status, data, username = self.verify_id_token_cache(vo, token, with_cache=True)
        if status:
            return status, data, username
        return self.verify_id_token_cache(vo, token, with_cache=False)
0100 
0101 
class OIDCAuthenticationUtils(authentication.OIDCAuthenticationUtils):
    """Local subclass of ``idds.common.authentication.OIDCAuthenticationUtils``.

    It adds no behaviour. The constructor takes no arguments and only runs
    the base-class initialiser.
    """

    def __init__(self):
        super().__init__()
0105 
0106 
class X509Authentication(authentication.X509Authentication):
    """Local subclass of ``idds.common.authentication.X509Authentication``.

    It adds no behaviour. ``timeout`` is forwarded unchanged to the base class.
    """

    def __init__(self, timeout=None):
        super().__init__(timeout=timeout)
0110 
0111 
def get_user_name_from_dn1(dn):
    """Delegate to ``idds.common.authentication.get_user_name_from_dn1``."""
    user_name = authentication.get_user_name_from_dn1(dn)
    return user_name
0114 
0115 
def get_user_name_from_dn2(dn):
    """Delegate to ``idds.common.authentication.get_user_name_from_dn2``."""
    user_name = authentication.get_user_name_from_dn2(dn)
    return user_name
0118 
0119 
def get_user_name_from_dn(dn):
    """Apply ``get_user_name_from_dn1`` and then ``get_user_name_from_dn2`` to ``dn``.

    The actual transformations are implemented in ``idds.common.authentication``.
    """
    return get_user_name_from_dn2(get_user_name_from_dn1(dn))
0124 
0125 
def authenticate_x509(vo, dn, client_cert):
    """Delegate X.509 authentication to ``idds.common.authentication.authenticate_x509``."""
    outcome = authentication.authenticate_x509(vo, dn, client_cert)
    return outcome
0128 
0129 
def authenticate_oidc(vo, token):
    """Authenticate an OIDC ``token`` for ``vo``.

    A fresh ``OIDCAuthentication`` (default timeout) is created for each call.

    :returns: ``(status, data, username)`` as produced by
        ``OIDCAuthentication.verify_id_token``. On success ``data`` holds the
        verified claims; on failure it holds an error message.
    """
    # Success and failure are both passed through unchanged, so no branching is needed.
    status, data, username = OIDCAuthentication().verify_id_token(vo, token)
    return status, data, username
0137 
0138 
def authenticate_is_super_user(username, dn=None):
    """Delegate the super-user check to ``idds.common.authentication``.

    Both arguments are forwarded by keyword.
    """
    is_super = authentication.authenticate_is_super_user(username=username, dn=dn)
    return is_super