Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 08:38:39

0001 import ast
0002 import inspect
0003 import re
0004 import sys
0005 import threading
0006 import time
0007 import typing
0008 from functools import wraps
0009 from types import ModuleType, UnionType
0010 from typing import Union, get_args, get_origin
0011 
0012 from pandacommon.pandalogger.LogWrapper import LogWrapper
0013 
0014 import pandaserver.jobdispatcher.Protocol as Protocol
0015 from pandaserver.config import panda_config
0016 from pandaserver.dataservice.ddm import rucioAPI
0017 from pandaserver.srvcore import CoreUtils
0018 
# Sentinel stored as the initial result of TimedMethod; it is still there if the wrapped method never assigned one
TIME_OUT = "TimeOut"

# Standard messages returned to API clients through generate_response()
MESSAGE_SSL = "SSL secure connection is required"
MESSAGE_PROD_ROLE = "production or pilot role required"
# NOTE(review): the three messages below are not referenced in this part of the file; presumably used by the API modules
MESSAGE_TASK_ID = "jediTaskID must be an integer"
MESSAGE_DATABASE = "database error in the PanDA server"
MESSAGE_JSON = "failed to load JSON"
0026 
0027 
def get_endpoint(protocol):
    """
    Build the "host:port" endpoint of the PanDA server for the requested protocol.

    :param protocol: Either "http" or "https"
    :return: A tuple (True, endpoint) on success, or (False, error message) when the protocol is
             not supported or reading the configuration raised an exception
    """
    if protocol not in ("http", "https"):
        return False, "Protocol must be either 'http' or 'https'"

    try:
        # The secure and plain endpoints are configured through separate host/port attributes
        if protocol == "https":
            host, port = panda_config.pserverhost, panda_config.pserverport
        else:
            host, port = panda_config.pserverhosthttp, panda_config.pserverporthttp
        endpoint = f"{host}:{port}"
    except Exception as exc:
        return False, str(exc)

    return True, endpoint
0041 
0042 
def extract_allowed_methods(module: ModuleType) -> list:
    """
    Build the list of callable API method names exposed by a module.

    A function is kept only if it is defined in the module itself (not imported from elsewhere),
    is not called init_task_buffer and does not start with an underscore.

    :param module: The module to extract the allowed methods from
    :return: A list of allowed method names, in the order returned by inspect.getmembers
    """
    allowed = []
    for name, obj in inspect.getmembers(module, inspect.isfunction):
        # Functions imported from other modules report a different __module__
        if obj.__module__ != module.__name__:
            continue
        # The task buffer initializer and private helpers are not part of the API
        if name == "init_task_buffer" or name.startswith("_"):
            continue
        allowed.append(name)
    return allowed
0056 
0057 
def generate_response(success, message="", data=None):
    """
    Assemble the standard response dictionary returned by the API.

    :param success: Outcome flag of the call
    :param message: Optional message, empty string by default
    :param data: Optional payload, None by default
    :return: A dictionary with the keys "success", "message" and "data"
    """
    return dict(success=success, message=message, data=data)
0061 
0062 
0063 # get FQANs
def get_fqan(req):
    """
    Collect the VOMS FQANs found in the environment of a request.

    Two layouts are recognised:
    - compact style: keys starting with "GRST_CRED_" whose value starts with "VOMS"; the FQAN is
      the last whitespace-separated token of the value
    - old style: keys starting with "GRST_CONN_" whose value has the form "fqan:<FQAN>"

    :param req: Request object exposing the environment as req.subprocess_env
    :return: A list of FQAN strings in the iteration order of the environment
    """
    env = req.subprocess_env
    collected = []
    for key in env:
        value = env[key]

        # compact style
        if key.startswith("GRST_CRED_") and value.startswith("VOMS"):
            collected.append(value.split()[-1])
            continue

        # old style: exactly one colon, with "fqan" on the left-hand side
        if key.startswith("GRST_CONN_"):
            pieces = value.split(":")
            if len(pieces) == 2 and pieces[0] == "fqan":
                collected.append(pieces[1])

    return collected
0081 
0082 
def get_email_address(user, tmp_logger):
    """
    Look up the email address of a user through rucioAPI.finger, trying up to three times.

    The lookup is best effort: any exception is logged and None is returned instead of propagating.

    :param user: User identifier handed to rucioAPI.finger
    :param tmp_logger: Logger exposing debug() and error()
    :return: The email address, or None if all attempts failed or an exception occurred
    """
    tmp_logger.debug(f"Getting mail address for {user}")
    n_tries = 3
    email = None
    try:
        for attempt in range(1, n_tries + 1):
            status, user_info = rucioAPI.finger(user)
            if status:
                email = user_info["email"]
                tmp_logger.debug(f"User {user} got email {email}")
                break

            if attempt < n_tries:
                tmp_logger.debug(f"Attempt {attempt} of {n_tries} failed. Retrying...")
                # Pause only when another attempt follows; sleeping after the last one just delays the caller
                time.sleep(1)
            else:
                tmp_logger.debug(f"Attempt {attempt} of {n_tries} failed. Giving up.")
    except Exception as exc:
        tmp_logger.error(f"Failed to convert email address {user} : {type(exc)} {exc}")

    return email
0102 
0103 
def get_request_method(req):
    """
    Return the HTTP method (GET, POST, PUT, DELETE, ...) of a request.

    :param req: Request object exposing the environment as req.subprocess_env
    :return: The value of "REQUEST_METHOD" in the environment, or None if it is not set
    """
    return req.subprocess_env.get("REQUEST_METHOD")
0109 
0110 
0111 # get DN
def get_dn(req):
    """
    Return the distinguished name of the client certificate of a request.

    The raw "SSL_CLIENT_S_DN" value is passed through CoreUtils.get_bare_dn with keep_proxy=True
    (per the original comment this removes the redundant CN).

    :param req: Request object exposing the environment as req.subprocess_env
    :return: The processed DN, or an empty string if the environment has no "SSL_CLIENT_S_DN"
    """
    env = req.subprocess_env
    if "SSL_CLIENT_S_DN" not in env:
        return ""
    return CoreUtils.get_bare_dn(env["SSL_CLIENT_S_DN"], keep_proxy=True)
0118 
0119 
0120 # check role
def has_production_role(req):
    """
    Tell whether the client of a request is entitled to the production role.

    The role is granted when either
    - one of the entries of panda_config.production_dns is contained in the client DN, or
    - one of the client FQANs starts with, or matches as a regular expression, one of the
      production role patterns.

    :param req: Request object exposing the environment as req.subprocess_env
    :return: True if the client has the production role, False otherwise
    """
    # check DN
    client_dn = get_dn(req)
    if any(production_dn in client_dn for production_dn in panda_config.production_dns):
        return True

    # check the production role in all FQANs
    role_patterns = (
        "/atlas/usatlas/Role=production",
        "/atlas/Role=production",
        "^/[^/]+/Role=production",
    )
    for fqan in get_fqan(req):
        for pattern in role_patterns:
            # Each pattern is tried both as a literal prefix and as a regular expression
            if fqan.startswith(pattern) or re.search(pattern, fqan):
                return True
    return False
0142 
0143 
def extract_production_working_groups(fqans):
    """
    Extract the working groups that carry a production role from a list of FQANs.

    Only FQANs of the form "/atlas/<group>/Role=production" contribute. The group "usatlas" is
    excluded, and a group already present in the result is not added again. Every accepted group
    is added twice: as is, and prefixed with "gr_".

    :param fqans: Iterable of FQAN strings
    :return: A list like [group, "gr_" + group, ...] in order of first appearance
    """
    production_pattern = r"/atlas/([^/]+)/Role=production"
    collected = []
    for fqan in fqans:
        found = re.search(production_pattern, fqan)
        if found is None:
            continue

        group = found.group(1)
        # Skip empty names, the excluded 'usatlas' group and anything collected before
        if not group or group == "usatlas" or group in collected:
            continue

        collected += [group, f"gr_{group}"]

    return collected
0157 
0158 
def extract_primary_production_working_group(fqans):
    """
    Derive a single production working group name from a list of FQANs.

    FQANs of the form "/<vo>/<group>/Role=production" are considered. The group "usatlas" is
    ignored since it is used as the atlas production role. The returned name is the part of the
    group after its last "-", in lower case. When several FQANs qualify, the last one wins.

    :param fqans: Iterable of FQAN strings
    :return: The working group name, or None if no FQAN qualifies
    """
    selected = None
    for fqan in fqans:
        found = re.search("/[^/]+/([^/]+)/Role=production", fqan)
        if not found:
            continue

        candidate = found.group(1)
        if candidate in ("", "usatlas"):
            continue

        # No break here: a later qualifying FQAN overrides an earlier one
        selected = candidate.rsplit("-", 1)[-1].lower()

    return selected
0170 
0171 
0172 # security check
def is_secure(req, logger=None):
    """
    Check that a request is secure and was not made with a limited proxy.

    :param req: Request object exposing the environment as req.subprocess_env
    :param logger: Optional logger; when given, access via a limited proxy is reported as a warning
    :return: False if Protocol.isSecure rejects the request or the client DN contains
             "/CN=limited proxy", True otherwise
    """
    # check security
    if not Protocol.isSecure(req):
        return False

    client_dn = req.subprocess_env["SSL_CLIENT_S_DN"]
    if "/CN=limited proxy" not in client_dn:
        return True

    # limited proxies are not accepted
    if logger:
        logger.warning(f"access via limited proxy : {client_dn}")
    return False
0185 
0186 
def normalize_type(t):
    """
    Translate the bare typing aliases List, Dict, Set and Tuple into the matching builtin type.

    :param t: A type or annotation
    :return: list, dict, set or tuple for the corresponding typing alias; otherwise t unchanged
    """
    typing_to_builtin = {
        typing.List: list,
        typing.Dict: dict,
        typing.Set: set,
        typing.Tuple: tuple,
    }
    try:
        return typing_to_builtin[t]
    except KeyError:
        # Anything else (builtins, classes, subscripted generics, ...) is passed through
        return t
0195 
0196 
def _type_name(annotation):
    """Return a printable name for a type annotation, even when it does not define ``__name__``."""
    # Union annotations such as ``int | None`` are not guaranteed to expose __name__
    return getattr(annotation, "__name__", None) or str(annotation)


def request_validation(logger, secure=True, production=False, request_method=None):
    """
    Decorator factory validating a request before the decorated API function is called.

    The wrapper performs, in this order:
    1. SSL check through is_secure (when secure is True)
    2. production role check through has_production_role (when production is True)
    3. HTTP method check (when request_method is given)
    4. binding of the received arguments to the signature of the decorated function
    5. for GET requests, casting of the string parameters to the annotated types
    6. type check of every annotated parameter whose value differs from its default

    On any failure the decorated function is not called and the wrapper returns
    generate_response(False, message=<reason>).

    :param logger: Base logger; a LogWrapper tagged with the function name is created per call
    :param secure: Require a secure connection
    :param production: Require the production role
    :param request_method: Expected HTTP method (e.g. "GET", "POST"), or None to accept any
    :return: The decorator to apply to an API function taking the request as first argument "req"
    """

    def decorator(func):
        # The signature does not change between calls, so it is inspected once at decoration time
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(req, *args, **kwargs):
            # Generate a logger with the underlying function name
            tmp_logger = LogWrapper(logger, func.__name__)
            tmp_logger_context = LogWrapper(logger, f"{func.__name__} args:{args} kwargs:{kwargs}")

            # expected and received request methods
            expected_request_method = request_method
            received_request_method = get_request_method(req)

            # check SSL if required
            if secure and not is_secure(req, tmp_logger):
                tmp_logger.error(f"{MESSAGE_SSL}")
                return generate_response(False, message=MESSAGE_SSL)

            # check production role if required
            if production and not has_production_role(req):
                tmp_logger.error(f"{MESSAGE_PROD_ROLE}")
                return generate_response(False, message=MESSAGE_PROD_ROLE)

            # check method if required
            if expected_request_method and expected_request_method != received_request_method:
                message = f"expecting {expected_request_method}, received {received_request_method}"
                tmp_logger.error(f"{message}")
                return generate_response(False, message=message)

            # Bind the received arguments to the function signature
            try:
                bound_args = sig.bind(req, *args, **kwargs)
            except TypeError as e:
                message = f"Argument error: {str(e)}"
                tmp_logger_context.error(message)
                return generate_response(False, message=message)
            bound_args.apply_defaults()

            for param_name, param_value in bound_args.arguments.items():
                # Skip the first argument (req)
                if param_name == "req":
                    continue

                # Skip if no type hint
                expected_type = sig.parameters[param_name].annotation
                if expected_type is inspect.Parameter.empty:
                    continue

                # Skip if value is the default value
                default_value = sig.parameters[param_name].default
                if default_value == param_value:
                    continue

                # Handle generics like List[int]
                origin = get_origin(expected_type)
                type_args = get_args(expected_type)

                # GET methods are URL encoded. Parameters will lose the type and come as string. We need to cast them to the expected type
                if received_request_method == "GET":
                    try:
                        tmp_logger.debug(f"Casting '{param_name}' to type {_type_name(expected_type)}.")
                        tmp_logger.debug(type(param_value))
                        if param_value == "None" and default_value is None:
                            param_value = None
                        # Don't cast if the type is already a string
                        elif expected_type is str:
                            pass
                        # Booleans need to be handled separately, since bool("False") == True
                        elif expected_type is bool:
                            param_value = param_value.lower() in ("true", "1")
                        # Convert to float first, then to int. This is a courtesy for cases passing decimal numbers.
                        elif expected_type is int:
                            param_value = int(float(param_value))
                        elif origin is list and type_args:
                            element_type = type_args[0]  # Get the type inside List[<type>]

                            # If only one element, convert it to a list
                            if isinstance(param_value, str):
                                param_value = [param_value]

                            # Convert the elements of the list to the expected type
                            if element_type is int:
                                param_value = [int(float(i)) for i in param_value]
                            elif element_type is float:
                                param_value = [float(i) for i in param_value]
                            elif element_type is bool:
                                param_value = [i.lower() in ("true", "1") for i in param_value]
                        else:
                            # Normalize type, e.g. typing.Dict -> dict
                            expected_type = normalize_type(expected_type)
                            if not isinstance(param_value, expected_type):
                                # NOTE(review): literal_eval does not execute code, but it parses client-supplied
                                # text; very large or deeply nested input can still be expensive
                                param_value = ast.literal_eval(param_value)
                            if not isinstance(param_value, expected_type):
                                raise TypeError(f"Expected {expected_type}, received {type(param_value)}")
                        bound_args.arguments[param_name] = param_value  # Ensure the cast value is used
                    except (ValueError, TypeError, SyntaxError, OverflowError, AttributeError):
                        # Beyond ValueError/TypeError, malformed client input can also raise:
                        # - SyntaxError from ast.literal_eval on text that is not a valid literal
                        # - OverflowError from int(float("inf"))
                        # - AttributeError from .lower() when the value is not a string
                        message = f"Type error: '{param_name}' with value '{param_value}' could not be casted to type {_type_name(expected_type)} from {type(param_value).__name__}."
                        tmp_logger_context.error(message)
                        return generate_response(False, message=message)

                # Check type
                if origin and (origin is not Union and origin is not UnionType):  # Handle generics (e.g., List[int])
                    if not isinstance(param_value, origin):
                        message = f"Type error: '{param_name}' must be of type {_type_name(origin)}, got {type(param_value).__name__}."
                        tmp_logger_context.error(message)
                        return generate_response(False, message=message)

                    if type_args:  # Check inner types for lists, dicts, etc.
                        if origin is list and not all(isinstance(i, type_args[0]) for i in param_value):
                            message = f"Type error: All elements in '{param_name}' must be {_type_name(type_args[0])}."
                            tmp_logger_context.error(message)
                            return generate_response(False, message=message)
                elif not isinstance(param_value, expected_type) and not (param_value is None and param_value == default_value):
                    message = f"Type error: '{param_name}' must be of type {_type_name(expected_type)}, got {type(param_value).__name__}."
                    tmp_logger_context.error(message)
                    return generate_response(False, message=message)

            return func(*bound_args.args, **bound_args.kwargs)

        return wrapper

    return decorator
0322 
0323 
0324 # a wrapper to install timeout into a method
class TimedMethod:
    """
    Run a method in a separate thread and keep its return value in ``result``.

    ``result`` starts as TIME_OUT and is overwritten once the wrapped method returns.

    NOTE(review): ``timeout`` is stored but run() joins the thread without passing it, so run()
    currently waits until the wrapped method finishes.
    """

    def __init__(self, method, timeout):
        self.method, self.timeout = method, timeout
        self.result = TIME_OUT

    def __call__(self, *var, **kwargs):
        """Thread target: invoke the wrapped method and record what it returns."""
        outcome = self.method(*var, **kwargs)
        self.result = outcome

    def run(self, *var, **kwargs):
        """Execute the wrapped method with the given arguments in a worker thread and wait for it."""
        worker = threading.Thread(target=self, args=var, kwargs=kwargs)
        worker.start()
        worker.join()