File indexing completed on 2026-04-09 08:38:39
0001 import ast
0002 import inspect
0003 import re
0004 import sys
0005 import threading
0006 import time
0007 import typing
0008 from functools import wraps
0009 from types import ModuleType, UnionType
0010 from typing import Union, get_args, get_origin
0011
0012 from pandacommon.pandalogger.LogWrapper import LogWrapper
0013
0014 import pandaserver.jobdispatcher.Protocol as Protocol
0015 from pandaserver.config import panda_config
0016 from pandaserver.dataservice.ddm import rucioAPI
0017 from pandaserver.srvcore import CoreUtils
0018
# Placeholder value of TimedMethod.result until the wrapped call has completed.
TIME_OUT: str = "TimeOut"

# Standard error message strings (presumably shared with the API modules; the SSL and
# production-role ones are returned by request_validation in failed responses).
MESSAGE_SSL: str = "SSL secure connection is required"
MESSAGE_PROD_ROLE: str = "production or pilot role required"
MESSAGE_TASK_ID: str = "jediTaskID must be an integer"
MESSAGE_DATABASE: str = "database error in the PanDA server"
MESSAGE_JSON: str = "failed to load JSON"
0026
0027
def get_endpoint(protocol):
    """
    Return the ``host:port`` endpoint of this PanDA server for the given protocol.

    :param protocol: "http" or "https"
    :return: ``(True, "host:port")`` on success; ``(False, error message)`` when the protocol is
        not supported or the endpoint could not be built from ``panda_config``
    """
    if protocol not in ("http", "https"):
        return False, "Protocol must be either 'http' or 'https'"

    try:
        if protocol == "https":
            host, port = panda_config.pserverhost, panda_config.pserverport
        else:
            host, port = panda_config.pserverhosthttp, panda_config.pserverporthttp
        endpoint = f"{host}:{port}"
    except Exception as exc:
        # e.g. the configuration lacks one of the host/port attributes
        return False, str(exc)

    return True, endpoint
0041
0042
def extract_allowed_methods(module: ModuleType) -> list[str]:
    """
    Generate the allowed methods dynamically from the functions defined in an API module.

    A function is included only if it is
      * defined in ``module`` itself (functions imported from other modules are skipped),
      * public, i.e. its name does not start with an underscore, and
      * not ``init_task_buffer``.

    :param module: The module to extract the allowed methods from
    :return: A list of allowed method names, in the (name-sorted) order of ``inspect.getmembers``
    """
    return [
        name
        for name, obj in inspect.getmembers(module, inspect.isfunction)
        if obj.__module__ == module.__name__ and name != "init_task_buffer" and not name.startswith("_")
    ]
0056
0057
def generate_response(success, message="", data=None):
    """
    Build the standard response dictionary.

    :param success: whether the operation succeeded
    :param message: human-readable message (empty by default)
    :param data: optional payload
    :return: dict with the keys ``success``, ``message`` and ``data`` (in that order)
    """
    return dict(success=success, message=message, data=data)
0061
0062
0063
def get_fqan(req):
    """
    Collect the client's FQANs from the request environment.

    Two kinds of entries in ``req.subprocess_env`` are recognised:
      * ``GRST_CRED_*`` whose value starts with "VOMS": the FQAN is the last whitespace-separated token;
      * ``GRST_CONN_*`` whose value is exactly ``fqan:<FQAN>`` (a single colon).

    :param req: request object exposing ``subprocess_env``
    :return: list of FQAN strings in environment iteration order
    """
    env = req.subprocess_env
    collected = []
    for key in env:
        value = env[key]
        if key.startswith("GRST_CRED_") and value.startswith("VOMS"):
            collected.append(value.split()[-1])
            continue
        if key.startswith("GRST_CONN_"):
            parts = value.split(":")
            # only "fqan:<value>" with exactly one separator is accepted
            if len(parts) == 2 and parts[0] == "fqan":
                collected.append(parts[1])
    return collected
0081
0082
def get_email_address(user, tmp_logger, *, n_tries=3, retry_delay=1):
    """
    Look up the e-mail address of a user via ``rucioAPI.finger``.

    A lookup that reports failure (false status) is retried up to ``n_tries`` times in total,
    sleeping ``retry_delay`` seconds between attempts (never after the last one). Any exception
    aborts the lookup and is logged as an error.

    :param user: user identifier passed to ``rucioAPI.finger``
    :param tmp_logger: logger providing ``debug`` and ``error``
    :param n_tries: maximum number of lookup attempts (keyword-only)
    :param retry_delay: seconds to wait between failed attempts (keyword-only)
    :return: the e-mail address, or None if it could not be obtained
    """
    tmp_logger.debug(f"Getting mail address for {user}")
    email = None
    try:
        for attempt in range(1, n_tries + 1):
            status, user_info = rucioAPI.finger(user)
            if status:
                email = user_info["email"]
                tmp_logger.debug(f"User {user} got email {email}")
                break
            if attempt < n_tries:
                tmp_logger.debug(f"Attempt {attempt} of {n_tries} failed. Retrying...")
                # only wait when another attempt will actually follow
                time.sleep(retry_delay)
            else:
                tmp_logger.debug(f"Attempt {attempt} of {n_tries} failed. Giving up.")
    except Exception as exc:
        # best effort: callers get None instead of an exception
        tmp_logger.error(f"Failed to convert email address {user} : {type(exc)} {exc}")

    return email
0102
0103
def get_request_method(req):
    """
    Return the HTTP method of the request.

    :param req: request object exposing ``subprocess_env``
    :return: the ``REQUEST_METHOD`` environment value, or None when it is not set
    """
    return req.subprocess_env.get("REQUEST_METHOD", None)
0109
0110
0111
def get_dn(req):
    """
    Return the client's DN taken from the SSL environment.

    :param req: request object exposing ``subprocess_env``
    :return: ``SSL_CLIENT_S_DN`` passed through ``CoreUtils.get_bare_dn`` (with ``keep_proxy=True``),
        or an empty string when the request carries no client DN
    """
    env = req.subprocess_env
    if "SSL_CLIENT_S_DN" not in env:
        return ""
    return CoreUtils.get_bare_dn(env["SSL_CLIENT_S_DN"], keep_proxy=True)
0118
0119
0120
def has_production_role(req):
    """
    Check whether the client of a request holds a production role.

    The client qualifies when
      * its DN contains one of ``panda_config.production_dns``, or
      * one of its FQANs starts with ``/atlas/usatlas/Role=production`` or with
        ``/<vo>/Role=production`` (a VO-level production role).

    :param req: request object exposing ``subprocess_env``
    :return: True if the client has a production role, False otherwise
    """
    user = get_dn(req)
    if any(sdn in user for sdn in panda_config.production_dns):
        return True

    # The roles must appear at the start of the FQAN. An unanchored search would also accept
    # FQANs that only contain e.g. "/atlas/Role=production" as a sub-path, such as a group
    # called "atlas" inside a different VO ("/othervo/atlas/Role=production").
    for fqan in get_fqan(req):
        if fqan.startswith("/atlas/usatlas/Role=production"):
            return True
        # also covers "/atlas/Role=production"
        if re.match(r"/[^/]+/Role=production", fqan):
            return True
    return False
0142
0143
def extract_production_working_groups(fqans):
    """
    Collect the working groups in which the client has an ATLAS production role.

    Each FQAN matching ``/atlas/<group>/Role=production`` contributes ``<group>`` and
    ``gr_<group>`` (each group only once); the ``usatlas`` group is ignored.

    :param fqans: iterable of FQAN strings
    :return: list of working-group names in first-seen order
    """
    roles = []
    for fqan in fqans:
        found = re.search(r"/atlas/([^/]+)/Role=production", fqan)
        if found is None:
            continue
        group = found.group(1)
        if not group or group == "usatlas" or group in roles:
            continue
        roles += [group, f"gr_{group}"]
    return roles
0157
0158
def extract_primary_production_working_group(fqans):
    """
    Return the primary production working group derived from the client's FQANs.

    FQANs of the form ``/<vo>/<group>/Role=production`` are considered, skipping the
    ``usatlas`` group. The result is the part of ``<group>`` after its last dash, lower-cased.
    When several FQANs qualify, the last one wins.

    :param fqans: iterable of FQAN strings
    :return: the working-group name, or None if no FQAN qualifies
    """
    primary = None
    for fqan in fqans:
        if (found := re.search("/[^/]+/([^/]+)/Role=production", fqan)) is None:
            continue
        group = found.group(1)
        if group in ("", "usatlas"):
            continue
        primary = group.split("-")[-1].lower()
    return primary
0170
0171
0172
def is_secure(req, logger=None):
    """
    Check that the request came over a secure connection and not from a limited proxy.

    :param req: request object exposing ``subprocess_env``
    :param logger: optional logger; a warning is emitted when a limited proxy is rejected
    :return: False if ``Protocol.isSecure`` rejects the request or the client DN contains
        "/CN=limited proxy", True otherwise
    """
    if not Protocol.isSecure(req):
        return False

    client_dn = req.subprocess_env["SSL_CLIENT_S_DN"]
    if "/CN=limited proxy" not in client_dn:
        return True

    # limited proxies are refused
    if logger:
        logger.warning(f"access via limited proxy : {client_dn}")
    return False
0185
0186
# Bare typing aliases and the builtin classes they stand for.
_TYPING_ALIASES = {
    typing.List: list,
    typing.Dict: dict,
    typing.Set: set,
    typing.Tuple: tuple,
}


def normalize_type(t):
    """
    Map the bare ``typing`` aliases ``List``/``Dict``/``Set``/``Tuple`` to their builtin classes.

    :param t: a type annotation
    :return: the matching builtin class, or ``t`` itself for any other annotation
    """
    return _TYPING_ALIASES.get(t, t)
0195
0196
def _type_name(type_hint):
    """
    Return a printable name for a type annotation.

    Plain classes expose ``__name__``, but other annotation objects (e.g. ``int | None``) may not.
    Falling back to ``str()`` keeps log/error message construction from raising AttributeError.

    :param type_hint: any annotation object
    :return: ``type_hint.__name__`` if present, otherwise ``str(type_hint)``
    """
    return getattr(type_hint, "__name__", str(type_hint))


def request_validation(logger, secure=True, production=False, request_method=None):
    """
    Decorator factory validating a request before the wrapped API function is called.

    The wrapper performs, in order:
      1. a secure-connection check (when ``secure``) via ``is_secure``;
      2. a production-role check (when ``production``) via ``has_production_role``;
      3. an HTTP-method check (when ``request_method`` is given);
      4. binding of the received arguments to the function signature;
      5. a type check of every annotated argument whose value differs from its default. For GET
         requests the values are first cast to the annotated type.

    If any step fails, the wrapped function is not called. Instead, a
    ``generate_response(False, message=...)`` dict is returned.

    :param logger: logger handed to ``LogWrapper``
    :param secure: require a secure connection
    :param production: require a production role
    :param request_method: expected HTTP method (e.g. "GET"), or None to accept any
    :return: the decorator
    """

    def decorator(func):
        # The signature of ``func`` does not change, so inspect it once rather than per request
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(req, *args, **kwargs):
            tmp_logger = LogWrapper(logger, func.__name__)
            tmp_logger_context = LogWrapper(logger, f"{func.__name__} args:{args} kwargs:{kwargs}")

            expected_request_method = request_method
            received_request_method = get_request_method(req)

            # Secure connection check
            if secure and not is_secure(req, tmp_logger):
                tmp_logger.error(f"{MESSAGE_SSL}")
                return generate_response(False, message=MESSAGE_SSL)

            # Production role check
            if production and not has_production_role(req):
                tmp_logger.error(f"{MESSAGE_PROD_ROLE}")
                return generate_response(False, message=MESSAGE_PROD_ROLE)

            # HTTP method check
            if expected_request_method and expected_request_method != received_request_method:
                message = f"expecting {expected_request_method}, received {received_request_method}"
                tmp_logger.error(f"{message}")
                return generate_response(False, message=message)

            # Bind the received arguments to the function signature
            try:
                bound_args = sig.bind(req, *args, **kwargs)
            except TypeError as e:
                message = f"Argument error: {str(e)}"
                tmp_logger_context.error(message)
                return generate_response(False, message=message)
            bound_args.apply_defaults()

            for param_name, param_value in bound_args.arguments.items():
                # The request object itself is not validated
                if param_name == "req":
                    continue

                # Nothing to validate against for un-annotated or Any-annotated parameters
                # (isinstance() would raise TypeError for typing.Any)
                expected_type = sig.parameters[param_name].annotation
                if expected_type is inspect.Parameter.empty or expected_type is typing.Any:
                    continue

                # A value equal to the declared default is accepted as-is
                default_value = sig.parameters[param_name].default
                if default_value == param_value:
                    continue

                # e.g. list[int] -> origin list, type_args (int,); int -> origin None, type_args ()
                origin = get_origin(expected_type)
                type_args = get_args(expected_type)

                # GET values presumably arrive as strings from the query string; cast them first
                if received_request_method == "GET":
                    try:
                        tmp_logger.debug(f"Casting '{param_name}' to type {_type_name(expected_type)}.")
                        tmp_logger.debug(type(param_value))
                        if param_value == "None" and default_value is None:
                            param_value = None

                        elif expected_type is str:
                            pass

                        elif expected_type is bool:
                            param_value = param_value.lower() in ("true", "1")

                        elif expected_type is int:
                            # accepts "3" as well as "3.0"
                            param_value = int(float(param_value))

                        elif origin is list and type_args:
                            element_type = type_args[0]

                            # a single value may arrive as a bare string instead of a list
                            if isinstance(param_value, str):
                                param_value = [param_value]

                            if element_type is int:
                                param_value = [int(float(i)) for i in param_value]
                            elif element_type is float:
                                param_value = [float(i) for i in param_value]
                            elif element_type is bool:
                                param_value = [i.lower() in ("true", "1") for i in param_value]

                        else:
                            # isinstance() rejects parameterized generics such as dict[str, int],
                            # so check against the bare container class; unions are checked as-is
                            if origin is not None and origin is not Union and origin is not UnionType:
                                runtime_type = origin
                            else:
                                runtime_type = normalize_type(expected_type)
                            if not isinstance(param_value, runtime_type):
                                # literal_eval only parses Python literals; it does not execute code
                                param_value = ast.literal_eval(param_value)
                                if not isinstance(param_value, runtime_type):
                                    raise TypeError(f"Expected {runtime_type}, received {type(param_value)}")
                        bound_args.arguments[param_name] = param_value
                    except (ValueError, TypeError, SyntaxError):
                        # literal_eval raises SyntaxError for unparsable input such as "{" or ""
                        message = f"Type error: '{param_name}' with value '{param_value}' could not be casted to type {_type_name(expected_type)} from {type(param_value).__name__}."
                        tmp_logger_context.error(message)
                        return generate_response(False, message=message)

                # Type validation for generic containers (list[...], dict[...], ...)
                if origin and (origin is not Union and origin is not UnionType):
                    if not isinstance(param_value, origin):
                        message = f"Type error: '{param_name}' must be of type {_type_name(origin)}, got {type(param_value).__name__}."
                        tmp_logger_context.error(message)
                        return generate_response(False, message=message)

                    # Only list elements are checked individually
                    if type_args:
                        if origin is list and not all(isinstance(i, type_args[0]) for i in param_value):
                            message = f"Type error: All elements in '{param_name}' must be {_type_name(type_args[0])}."
                            tmp_logger_context.error(message)
                            return generate_response(False, message=message)
                # Type validation for plain types and unions; None is accepted when it is the default
                elif not isinstance(param_value, expected_type) and not (param_value is None and param_value == default_value):
                    message = f"Type error: '{param_name}' must be of type {_type_name(expected_type)}, got {type(param_value).__name__}."
                    tmp_logger_context.error(message)
                    return generate_response(False, message=message)

            return func(*bound_args.args, **bound_args.kwargs)

        return wrapper

    return decorator
0322
0323
0324
class TimedMethod:
    """
    Run a callable in a separate thread and keep its return value in ``result``.

    ``result`` starts as ``TIME_OUT`` and is replaced by the callable's return value once the call
    returns.

    NOTE(review): ``timeout`` is stored but ``run`` joins the thread without a time limit. As a
    result, ``result`` only stays ``TIME_OUT`` if the callable raises. Confirm this is intentional
    before relying on the timeout.
    """

    def __init__(self, method, timeout):
        self.method = method  # callable executed by run()
        self.timeout = timeout  # kept for callers; not consulted by run()
        self.result = TIME_OUT  # placeholder until the call has returned

    def __call__(self, *args, **kwargs):
        # Thread target: invoke the wrapped callable and remember what it returned
        self.result = self.method(*args, **kwargs)

    def run(self, *args, **kwargs):
        """Execute the wrapped callable in a new thread and block until that thread finishes."""
        worker = threading.Thread(target=self, args=args, kwargs=kwargs)
        worker.start()
        worker.join()