File indexing completed on 2026-04-09 07:58:16
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012 import base64
0013 import concurrent.futures
0014 import contextlib
0015 import errno
0016 import datetime
0017 import functools
0018 import importlib
0019 import hashlib
0020 import logging
0021 import json
0022 import os
0023 import re
0024 import requests
0025 import signal
0026 import socket
0027 import subprocess
0028 import sys
0029 import tarfile
0030 import threading
0031 import time
0032
0033 from enum import Enum
0034 from functools import wraps
0035 from logging.handlers import RotatingFileHandler
0036 from itertools import groupby
0037 from operator import itemgetter
0038 from packaging import version as packaging_version
0039 from typing import Any, Callable
0040
0041 from idds.common.config import (
0042 config_has_section,
0043 config_has_option,
0044 config_get,
0045 config_get_bool,
0046 config_get_int,
0047 )
0048 from idds.common.constants import (
0049 IDDSEnum,
0050 RequestType,
0051 RequestStatus,
0052 TransformType,
0053 TransformStatus,
0054 CollectionType,
0055 CollectionRelationType,
0056 CollectionStatus,
0057 ContentType,
0058 ContentStatus,
0059 GranularityType,
0060 ProcessingStatus,
0061 )
0062 from idds.common.dict_class import DictClass
0063 from idds.common.exceptions import IDDSException
0064
0065
0066
0067 DATE_FORMAT = "%a, %d %b %Y %H:%M:%S UTC"
0068
0069
def get_log_dir():
    """Return the configured log directory, falling back to /var/log/idds."""
    has_cfg = config_has_section("common") and config_has_option("common", "logdir")
    return config_get("common", "logdir") if has_cfg else "/var/log/idds"
0074
0075
def setup_logging(name, stream=None, log_file=None, loglevel=None):
    """
    Configure the root logger.

    Level resolution order: explicit ``loglevel`` argument > ``IDDS_LOG_LEVEL``
    environment variable > config ``[common] loglevel`` > INFO.

    :param name: logger name (not used for configuration; kept for API compatibility).
    :param stream: stream to log to when no file destination is resolved.
    :param log_file: log file path; relative paths are placed under the config logdir.
    :param loglevel: logging level as int, level-name string, or None to auto-resolve.
    """
    if loglevel is None:
        if config_has_section("common") and config_has_option("common", "loglevel"):
            loglevel = getattr(logging, config_get("common", "loglevel").upper())
        else:
            loglevel = logging.INFO

    # Environment override (single lookup instead of the original double get).
    idds_log_level = os.environ.get("IDDS_LOG_LEVEL", None)
    if idds_log_level:
        idds_log_level = idds_log_level.upper()
        if idds_log_level in ["DEBUG", "CRITICAL", "ERROR", "WARNING", "INFO"]:
            loglevel = getattr(logging, idds_log_level)

    if isinstance(loglevel, str):
        loglevel = getattr(logging, loglevel.upper())

    # Single source of truth for the record format (was repeated five times).
    log_format = "%(asctime)s\t%(threadName)s\t%(name)s\t%(levelname)s\t%(message)s"

    # Make a relative log_file absolute under the configured log directory.
    if log_file is not None:
        if not log_file.startswith("/"):
            logdir = None
            if config_has_section("common") and config_has_option("common", "logdir"):
                logdir = config_get("common", "logdir")
            if not logdir:
                logdir = "/var/log/idds"
            log_file = os.path.join(logdir, log_file)

    if log_file:
        logging.basicConfig(filename=log_file, level=loglevel, format=log_format)
    elif stream is None:
        if os.environ.get("IDDS_LOG_FILE", None):
            idds_log_file = os.environ.get("IDDS_LOG_FILE", None)
            logging.basicConfig(filename=idds_log_file, level=loglevel, format=log_format)
        elif (
            config_has_section("common")
            and config_has_option("common", "logdir")
            and config_has_option("common", "logfile")
        ):
            # NOTE: the original also tested "or log_file" / "if log_file" here;
            # that was dead code because log_file is always falsy in this branch
            # (the "if log_file:" arm above already handled the truthy case).
            log_filename = config_get("common", "logfile")
            if not log_filename.startswith("/"):
                log_filename = os.path.join(
                    config_get("common", "logdir"), log_filename
                )
            logging.basicConfig(filename=log_filename, level=loglevel, format=log_format)
        else:
            logging.basicConfig(stream=sys.stdout, level=loglevel, format=log_format)
    else:
        logging.basicConfig(stream=stream, level=loglevel, format=log_format)
    # Timestamps are rendered in UTC.
    logging.Formatter.converter = time.gmtime
0150
0151
def get_logger(name, filename=None, loglevel=None):
    """
    Build and return a named logger backed by a rotating-file handler.

    :param name: logger name; also the default log file stem.
    :param filename: log file path; relative names go under the config logdir.
    :param loglevel: logging level; defaults to the config value or INFO.
    """
    if loglevel is None:
        use_cfg = config_has_section("common") and config_has_option("common", "loglevel")
        loglevel = getattr(logging, config_get("common", "loglevel").upper()) if use_cfg else logging.INFO

    if filename is None:
        filename = "%s.log" % name
    if not filename.startswith("/"):
        logdir = None
        if config_has_section("common") and config_has_option("common", "logdir"):
            logdir = config_get("common", "logdir")
        filename = os.path.join(logdir or "/var/log/idds", filename)

    # Rotate at 2 GiB, keeping three backups.
    file_handler = RotatingFileHandler(
        filename, maxBytes=2 * 1024 * 1024 * 1024, backupCount=3
    )
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s\t%(threadName)s\t%(name)s\t%(levelname)s\t%(message)s")
    )

    logger = logging.getLogger(name)
    logger.setLevel(loglevel)
    logger.addHandler(file_handler)
    # Do not bubble records up to the root logger.
    logger.propagate = False
    return logger
0185
0186
def get_rest_url_prefix():
    """Return the REST url prefix normalized to '/prefix' form, or the raw falsy value when unset."""
    url_prefix = None
    if config_has_section("rest") and config_has_option("rest", "url_prefix"):
        url_prefix = config_get("rest", "url_prefix")
    if url_prefix:
        # strip('/') removes all leading and trailing slashes, matching the
        # original while-loops, then a single leading slash is restored.
        url_prefix = "/" + url_prefix.strip("/")
    return url_prefix
0199
0200
def get_rest_debug():
    """Return the REST debug flag from config (False when unset)."""
    has_flag = config_has_section("rest") and config_has_option("rest", "debug")
    return config_get_bool("rest", "debug") if has_flag else False
0205
0206
def get_rest_cacher_dir():
    """Return the configured cacher directory; raise when undefined or missing on disk."""
    cacher_dir = None
    if config_has_section("rest") and config_has_option("rest", "cacher_dir"):
        cacher_dir = config_get("rest", "cacher_dir")
    if not cacher_dir or not os.path.exists(cacher_dir):
        raise Exception("cacher_dir is not defined or it doesn't exist")
    return cacher_dir
0214
0215
def get_asyncresult_config():
    """
    Collect the [asyncresult] broker settings into a dict.

    Missing options keep their defaults (None, except broker_timeout=360).
    """
    ret = {
        "broker_type": None,
        "brokers": None,
        "broker_destination": None,
        "broker_timeout": 360,
        "broker_username": None,
        "broker_password": None,
        "broker_x509": None,
    }
    if config_has_section("asyncresult"):
        for option in ret:
            if config_has_option("asyncresult", option):
                # broker_timeout is the only integer-typed option.
                getter = config_get_int if option == "broker_timeout" else config_get
                ret[option] = getter("asyncresult", option)
    return ret
0251
0252
def get_prompt_broker_config():
    """Collect the [prompt] broker settings (JSON-decoded) into a dict; None when unset."""
    ret = {
        "transformer_broker": None,
        "transformer_broadcast_broker": None,
        "result_broker": None,
    }
    if config_has_section("prompt"):
        for option in ret:
            if config_has_option("prompt", option):
                ret[option] = json_loads(config_get("prompt", option))
    return ret
0272
0273
def str_to_date(string):
    """
    Converts a string to the corresponding datetime value.

    :param string: the string to convert to datetime value.
    """
    if not string:
        return None
    return datetime.datetime.strptime(string, DATE_FORMAT)
0281
0282
def date_to_str(date):
    """
    Converts a datetime value to a string.

    :param date: the datetime value to convert.
    """
    if not date:
        return None
    return datetime.datetime.strftime(date, DATE_FORMAT)
0290
0291
def has_config():
    """
    Check whether an idds config file is present.

    $IDDS_CONFIG, when set, is authoritative; otherwise the standard
    locations under $IDDS_HOME, /etc and $VIRTUAL_ENV are probed.
    """
    explicit = os.environ.get("IDDS_CONFIG", None)
    if explicit:
        # An explicitly named config is the only one considered.
        return os.path.exists(explicit)
    candidates = [
        "%s/etc/idds/idds.cfg" % os.environ.get("IDDS_HOME", ""),
        "/etc/idds/idds.cfg",
        "%s/etc/idds/idds.cfg" % os.environ.get("VIRTUAL_ENV", ""),
    ]
    return any(c and os.path.exists(c) for c in candidates)
0311
0312
def check_rest_host():
    """
    Function to check whether rest host is defined in config.
    To be used to decide whether to skip some test functions.

    :returns True: if rest host is available. Otherwise False.
    """
    return bool(config_has_option("rest", "host") and config_get("rest", "host"))
0325
0326
def get_rest_host():
    """
    Return the REST host, preferring the IDDS_HOST environment variable.

    The configured host is stripped of trailing slashes and the url
    prefix (if any) is appended.
    """
    if "IDDS_HOST" in os.environ:
        return os.environ.get("IDDS_HOST")
    host = config_get("rest", "host").rstrip("/")
    prefix = get_rest_url_prefix()
    if prefix:
        host = host + prefix
    return host
0340
0341
def check_user_proxy():
    """
    Check whether there is a user proxy.
    """
    default_proxy = "/tmp/x509up_u%d" % os.geteuid()
    client_proxy = os.environ.get("X509_USER_PROXY", default_proxy)
    return os.path.exists(client_proxy)
0355
0356
def check_database():
    """
    Function to check whether database is defined in config.
    To be used to decide whether to skip some test functions.

    :returns True: if database.default is available. Otherwise False.
    """
    return bool(config_has_option("database", "default") and config_get("database", "default"))
0369
0370
def kill_process_group(pgrp, nap=10):
    """
    Kill the process group.
    DO NOT MOVE TO PROCESSES.PY - will lead to circular import since execute() needs it as well.
    :param pgrp: process group id (int).
    :param nap: napping time between kill signals in seconds (int)
    :return: boolean (True if SIGTERM followed by SIGKILL signalling was successful)
    """
    status = False

    print(f"killing group process {pgrp}")
    # Phase 1: ask nicely with SIGTERM; only nap when the signal was delivered.
    try:
        os.killpg(pgrp, signal.SIGTERM)
    except Exception as error:
        print(
            f"exception thrown when killing child group process under SIGTERM: {error}"
        )
    else:
        print(f"SIGTERM sent to process group {pgrp}")
        print(f"sleeping {nap} s to allow processes to exit")
        time.sleep(nap)

    # Phase 2: force termination with SIGKILL.
    try:
        os.killpg(pgrp, signal.SIGKILL)
    except Exception as error:
        print(
            f"exception thrown when killing child group process with SIGKILL: {error}"
        )
    else:
        print(f"SIGKILL sent to process group {pgrp}")
        status = True

    return status
0410
0411
def kill_all(process: Any) -> str:
    """
    Kill all processes after a time-out exception in process.communication().

    :param process: process object
    :return: stderr (str).
    """

    stderr = ""
    try:
        print("killing lingering subprocess and process group")
        time.sleep(1)
        # Kill the whole process group first so grandchildren do not survive.
        kill_process_group(os.getpgid(process.pid))
    except ProcessLookupError as exc:
        stderr += f"\n(kill process group) ProcessLookupError={exc}"
    except Exception as exc:
        stderr += f"\n(kill_all 1) exception caught: {exc}"
    try:
        print("killing lingering process")
        time.sleep(1)
        # Then the direct child: SIGTERM for a graceful stop ...
        os.kill(process.pid, signal.SIGTERM)
        print("sleeping a bit before sending SIGKILL")
        time.sleep(10)
        # ... followed by SIGKILL in case SIGTERM was ignored.
        os.kill(process.pid, signal.SIGKILL)
    except ProcessLookupError as exc:
        stderr += f"\n(kill process) ProcessLookupError={exc}"
    except Exception as exc:
        stderr += f"\n(kill_all 2) exception caught: {exc}"
    # Errors are accumulated, not raised: callers only need the description.
    print(f"sent soft kill signals - final stderr: {stderr}")
    return stderr
0443
0444
def run_process(cmd, stdout=None, stderr=None, wait=False, timeout=7 * 24 * 3600):
    """
    Runs a command in an out-of-procees shell.

    :param cmd: shell command string.
    :param stdout: optional stdout target (used only when stderr is also given).
    :param stderr: optional stderr target.
    :param wait: when False, return the Popen object immediately.
    :param timeout: communicate() timeout in seconds (default one week).
    :return: the Popen object (wait=False) or the exit code (wait=True, -1 on error).
    """
    print(f"To run command: {cmd}")
    popen_kwargs = {"shell": True, "preexec_fn": os.setsid, "encoding": "utf-8"}
    if stdout and stderr:
        popen_kwargs["stdout"] = stdout
        popen_kwargs["stderr"] = stderr
    process = subprocess.Popen(cmd, **popen_kwargs)
    if not wait:
        return process

    exit_code = -1
    try:
        print(f"subprocess.communicate() will use timeout={timeout} s")
        process.communicate(timeout=timeout)
    except Exception as ex:
        if isinstance(ex, subprocess.TimeoutExpired):
            msg = f"subprocess communicate sent TimeoutExpired: {ex}"
        else:
            msg = f"subprocess has an exception: {ex}"
        print(msg)
        killed = kill_all(process)
        print(f"Killing process: {killed}")
    else:
        exit_code = process.poll()

    # Give the process a final minute to be reaped before terminating it.
    try:
        process.wait(timeout=60)
    except subprocess.TimeoutExpired:
        print("process did not complete within the timeout of 60s - terminating")
        process.terminate()
    return exit_code
0490
0491
def run_command(cmd):
    """
    Run *cmd* in an out-of-process shell and capture its output.

    :param cmd: shell command string.
    :return: tuple (returncode, stdout, stderr) with outputs decoded to str.
    """
    proc = subprocess.Popen(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        preexec_fn=os.setsid,
    )
    out, err = proc.communicate()
    if isinstance(out, bytes):
        out = out.decode()
    if isinstance(err, bytes):
        err = err.decode()
    return proc.returncode, out, err
0510
0511
def get_space_from_string(space_str):
    """
    Convert space with P, T, G, M to int.

    Units follow the original convention: 'M' is 1024, 'G' is 1024**2,
    'T' is 1024**3 and 'P' is 1024**4. Matching is now case-insensitive
    and surrounding whitespace is tolerated (backward compatible: all
    previously accepted inputs yield the same results). A bare number is
    returned as int.

    :param space_str: e.g. "2M", "1.5G", "10".
    :returns: integer amount of space.
    :raises ValueError: if the string cannot be parsed as a number.
    """
    # Order matters only cosmetically; it mirrors the original M/G/T/P checks.
    multipliers = (("M", 1024), ("G", 1024 ** 2), ("T", 1024 ** 3), ("P", 1024 ** 4))
    normalized = space_str.strip().upper()
    for suffix, factor in multipliers:
        if suffix in normalized:
            return int(float(normalized.split(suffix)[0]) * factor)
    return int(normalized)
0531
0532
def urlretrieve(url, dest, timeout=300):
    """
    Download a file.

    :param url: The url of the source file.
    :param dest: destination file path.
    :param timeout: request timeout in seconds.
    :returns: 0 on success, -1 on a non-200 response.
    """
    # Fetch before opening dest: the original opened the file first, which
    # left an empty destination file behind on any non-200 response.
    r = requests.get(url, allow_redirects=True, timeout=timeout)
    if r.status_code == 200:
        with open(dest, "wb") as f:
            f.write(r.content)
        return 0
    return -1
0547
0548
def convert_nojsontype_to_value(params):
    """
    Convert enum to its value

    :param params: dict of parameters.

    :returns: dict of parameters.
    """
    def _convert(value):
        # Normalize one element: enums to raw values, datetimes to strings,
        # containers recursively; None passes through untouched.
        if value is None:
            return value
        if isinstance(value, Enum):
            value = value.value
        if isinstance(value, datetime.datetime):
            value = date_to_str(value)
        if isinstance(value, (list, dict)):
            value = convert_nojsontype_to_value(value)
        return value

    if isinstance(params, list):
        params = [_convert(v) for v in params]
    elif isinstance(params, dict):
        for key in params:
            params[key] = _convert(params[key])
    return params
0579
0580
def convert_value_to_nojsontype(params):
    """
    Convert value to enum

    Integer values stored under well-known type keys are re-instantiated as
    their enum types, recursing into nested lists and dicts.

    :param params: dict (or list) of parameters.

    :returns: dict of parameters.
    """
    req_keys = {"request_type": RequestType, "status": RequestStatus}
    transform_keys = {"transform_type": TransformType, "status": TransformStatus}
    coll_keys = {
        "coll_type": CollectionType,
        "relation_type": CollectionRelationType,
        "coll_status": CollectionStatus,
    }
    content_keys = {"content_type": ContentType, "status": ContentStatus}
    process_keys = {"granularity_type": GranularityType, "status": ProcessingStatus}

    # NOTE: the original also computed a `keys` mapping at function level here;
    # that chain was dead code (the dict branch below always recomputes it) and
    # raised TypeError for non-container inputs, so it has been removed.

    if isinstance(params, list):
        new_params = []
        for v in params:
            if v is not None and isinstance(v, (list, dict)):
                v = convert_value_to_nojsontype(v)
            new_params.append(v)
        params = new_params
    elif isinstance(params, dict):
        # Pick the enum mapping matching the record type. This must default to
        # an empty *dict* — the original used a list, which made `keys.keys()`
        # below raise AttributeError for dicts without any recognized type key.
        keys = {}
        if "request_type" in params:
            keys = req_keys
        elif "transform_type" in params:
            keys = transform_keys
        elif "coll_type" in params:
            keys = coll_keys
        elif "content_type" in params:
            keys = content_keys
        elif "granularity_type" in params:
            keys = process_keys

        for key in keys.keys():
            if (
                key in params
                and params[key] is not None
                and isinstance(params[key], int)
            ):
                params[key] = keys[key](params[key])

        for key in params:
            if params[key] is not None:
                if isinstance(params[key], (list, dict)):
                    params[key] = convert_value_to_nojsontype(params[key])

    return params
0644
0645
def convert_request_type_to_transform_type(request_type):
    """Map a RequestType (or its raw integer value) onto the matching TransformType."""
    raw = request_type.value if isinstance(request_type, RequestType) else request_type
    return TransformType(raw)
0650
0651
class DictClassEncoder(json.JSONEncoder):
    """JSON encoder aware of IDDS enums, DictClass objects and datetime values."""

    def default(self, obj):
        # Project objects serialize via their own dict representation.
        if isinstance(obj, (IDDSEnum, DictClass)):
            return obj.to_dict()
        if isinstance(obj, datetime.datetime):
            return date_to_str(obj)
        if isinstance(obj, datetime.timedelta):
            return str(obj)
        # Anything else falls back to the base implementation (raises TypeError).
        return super(DictClassEncoder, self).default(obj)
0668
0669
def as_has_dict(dct):
    """json object_hook: revive DictClass instances from their dict form, else pass through."""
    return DictClass.from_dict(dct) if DictClass.is_class(dct) else dct
0674
0675
def json_dumps(obj, indent=None, sort_keys=False):
    """Serialize *obj* to JSON using the idds-aware DictClassEncoder."""
    return json.dumps(obj, cls=DictClassEncoder, indent=indent, sort_keys=sort_keys)
0678
0679
def json_loads(obj):
    """Deserialize JSON text, reviving DictClass objects via the as_has_dict hook."""
    decoded = json.loads(obj, object_hook=as_has_dict)
    return decoded
0682
0683
def get_parameters_from_string(text):
    """
    Find all strings starting with '%'. For example, for this string below, it should return ['NUM_POINTS', 'IN', 'OUT']
    'run --rm -it -v "$(pwd)":/payload gitlab-registry.cern.ch/zhangruihpc/endpointcontainer:latest /bin/bash -c "echo "--num_points %NUM_POINTS"; /bin/cat /payload/%IN>/payload/%OUT"'
    """
    # Each match starts with a single '%'; strip it and deduplicate.
    names = {match[1:] for match in re.findall(r"%\w+", text)}
    return list(names)
0694
0695
def replace_parameters_with_values(text, values):
    """
    Replace all strings starting with '%'. For example, for this string below, it should replace ['%NUM_POINTS', '%IN', '%OUT']
    'run --rm -it -v "$(pwd)":/payload gitlab-registry.cern.ch/zhangruihpc/endpointcontainer:latest /bin/bash -c "echo "--num_points %NUM_POINTS"; /bin/cat /payload/%IN>/payload/%OUT"'

    :param text
    :param values: parameter values, for example {'NUM_POINTS': 5, 'IN': 'input.json', 'OUT': 'output.json'}
    :returns: text with every %KEY replaced by str(values[KEY]).
    """
    for key, value in values.items():
        placeholder = "%" + key
        # re.escape guards keys containing regex metacharacters, and the
        # callable replacement keeps backslashes in values literal — the
        # original passed both raw to re.sub, which misinterprets them.
        text = re.sub(re.escape(placeholder), lambda _m, _v=str(value): _v, text)
    return text
0708
0709
def tar_zip_files(output_dir, output_filename, files):
    """Create a gzipped tar archive in *output_dir* containing *files* (flattened to basenames)."""
    archive_path = os.path.join(output_dir, output_filename)
    with tarfile.open(archive_path, "w:gz") as archive:
        for path in files:
            archive.add(path, arcname=os.path.basename(path))
0715
0716
def exception_handler(function):
    """Decorator: log any IDDSException/Exception raised and return (False, str(ex)) instead."""
    @wraps(function)
    def wrapper(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        except IDDSException as ex:
            logging.error(ex)
            return False, str(ex)
        except Exception as ex:
            logging.error(ex)
            return False, str(ex)

    return wrapper
0732
0733
def is_sub(a, b):
    """Return True when every element of *a* is contained in *b* (falsy/empty a -> True)."""
    if not a:
        return True
    return all(item in b for item in a)
0742
0743
def get_proxy_path():
    """
    Return a readable user proxy path ($X509_USER_PROXY first, then
    /tmp/x509up_u<uid>), or None when neither exists and is readable.

    :raises IDDSException: when probing the filesystem fails unexpectedly.
    """
    def _readable(path):
        return os.path.exists(path) and os.access(path, os.R_OK)

    try:
        candidates = []
        if "X509_USER_PROXY" in os.environ:
            candidates.append(os.environ["X509_USER_PROXY"])
        candidates.append("/tmp/x509up_u%s" % os.getuid())
        for candidate in candidates:
            if _readable(candidate):
                return candidate
    except Exception as ex:
        raise IDDSException("Cannot find User proxy: %s" % str(ex))
    return None
0756
0757
def get_proxy():
    """
    Return the content of the user proxy file, or None when no proxy path is found.

    :raises IDDSException: when locating or reading the proxy fails.
    """
    try:
        proxy_path = get_proxy_path()
        if not proxy_path:
            return proxy_path
        with open(proxy_path, "r") as fp:
            return fp.read()
    except Exception as ex:
        raise IDDSException("Cannot find User proxy: %s" % str(ex))
0769
0770
def is_new_version(version1, version2):
    """Return True when *version1* is strictly newer than *version2* (PEP 440 comparison)."""
    parse = packaging_version.parse
    return parse(version1) > parse(version2)
0773
0774
def extract_scope_atlas(did, scopes):
    """
    Split an ATLAS DID into (scope, name).

    'scope:name' uses the explicit scope; otherwise the scope is derived from
    the first dotted field (the first two for user/group datasets). A trailing
    '/' is dropped from the returned name.

    :raises IDDSException: when the DID contains more than one colon.
    """
    if ":" in did:
        parts = did.split(":")
        if len(parts) > 2:
            raise IDDSException("Too many colons. Cannot extract scope and name")
        scope, name = parts[0], parts[1]
        if name.endswith("/"):
            name = name[:-1]
        return scope, name

    scope = did.split(".")[0]
    if did.startswith("user") or did.startswith("group"):
        scope = ".".join(did.split(".")[0:2])
    if did.endswith("/"):
        did = did[:-1]
    return scope, did
0791
0792
def truncate_string(string, length=800):
    """Truncate *string* to *length* chars and append '...'; falsy or short input passes through."""
    if string and len(string) > length:
        return string[:length] + "..."
    return string
0796
0797
def merge_dict(dict1, dict2):
    """
    Recursively merge dict2 into dict1; dict1 is modified in place and returned.

    Per-key rules:
      - key missing or None in dict1 -> take dict2's value;
      - None in dict2 -> keep dict1's value;
      - differing types -> raise;
      - equal values -> keep;
      - list/tuple/str -> concatenated; int/float/complex -> summed;
      - differing bools -> True; dicts -> merged recursively.

    :raises Exception: when the two values for a key have different types.
    """
    # Union of keys, preserving dict1's key order then appending new dict2 keys.
    keys = list(dict1.keys())
    for key in list(dict2.keys()):
        if key not in keys:
            keys.append(key)
    for key in keys:
        if key in dict2:
            if key not in dict1 or dict1[key] is None:
                dict1[key] = dict2[key]
            else:
                if dict2[key] is None:
                    continue
                elif not isinstance(dict1[key], type(dict2[key])):
                    raise Exception(
                        "type of %s is different from %s, cannot merge"
                        % (type(dict1[key]), type(dict2[key]))
                    )
                elif dict1[key] == dict2[key]:
                    continue
                # NOTE: exact type() checks below, so bool does NOT hit the
                # numeric branch even though bool subclasses int.
                elif type(dict1[key]) in (list, tuple, str):
                    dict1[key] = dict1[key] + dict2[key]
                elif type(dict1[key]) in (int, float, complex):
                    dict1[key] = dict1[key] + dict2[key]
                elif type(dict1[key]) in (bool, bool):  # duplicated bool kept as in original
                    # Values are known to differ here, so the OR is True.
                    dict1[key] = True
                elif type(dict1[key]) in (dict, dict):  # duplicated dict kept as in original
                    dict1[key] = merge_dict(dict1[key], dict2[key])
    return dict1
0826
0827
def pid_exists(pid):
    """
    Check whether pid exists in the current process table.
    UNIX only.
    """
    if pid < 0:
        return False
    if pid == 0:
        # Signalling PID 0 would target every process in the caller's
        # group; treat it as invalid input instead.
        raise ValueError("invalid PID 0")
    try:
        # Signal 0 performs only an existence/permission probe.
        os.kill(pid, 0)
    except ProcessLookupError:
        # ESRCH: no such process.
        return False
    except PermissionError:
        # EPERM: the process exists but belongs to another user.
        return True
    return True
0856
0857
def get_list_chunks(full_list, bulk_size=2000):
    """Split *full_list* into consecutive chunks of at most *bulk_size* items."""
    return [
        full_list[start:start + bulk_size]
        for start in range(0, len(full_list), bulk_size)
    ]
0861
0862
def report_availability(availability):
    """Best-effort: dump *availability* as JSON to <logdir>/idds_availability, or print it."""
    try:
        log_dir = get_log_dir()
        if not log_dir:
            print("availability: %s" % str(availability))
            return
        with open(os.path.join(log_dir, "idds_availability"), "w") as f:
            json.dump(availability, f)
    except Exception as ex:
        error = "Failed to report availablity: %s" % str(ex)
        print(error)
        logging.debug(error)
0876
0877
def split_chunks_not_continous(data):
    """Split a sorted sequence of ints into runs of consecutive values."""
    runs = []
    # index - value is constant within a run of consecutive integers.
    for _delta, run in groupby(enumerate(data), lambda pair: pair[0] - pair[1]):
        runs.append([value for _idx, value in run])
    return runs
0883
0884
def group_list(input_list, key):
    """
    Group items that are identical apart from the *key* field.

    :param input_list: list of dicts, each containing *key*.
    :param key: the field whose values are collected per group.
    :returns: {repr-of-remaining-items: {"keys": [key values], "items": dict-without-key}}.

    Unlike the original implementation, the caller's dicts are left intact
    (the *key* entry used to be deleted from them in place).
    """
    update_groups = {}
    for item in input_list:
        item = dict(item)  # shallow copy so the caller's dict keeps its *key* entry
        item_key = item.pop(key)
        item_tuple = str(tuple(sorted(item.items())))
        if item_tuple not in update_groups:
            update_groups[item_tuple] = {"keys": [], "items": item}
        update_groups[item_tuple]["keys"].append(item_key)
    return update_groups
0895
0896
def import_func(name: str) -> Callable[..., Any]:
    """Returns a function from a dotted path name. Example: `path.to.module:func`.

    When the attribute we look for is a staticmethod, module name in its
    dotted path is not the last-before-end word

    E.g.: package_a.package_b.module_a:ClassA.my_static_method

    Thus we remove the bits from the end of the name until we can import it

    Args:
        name (str): The name (reference) to the path.

    Raises:
        ValueError: If no module is found or invalid attribute name.

    Returns:
        Any: An attribute (normally a Callable)
    """
    name_bits = name.split(":")
    # BUG FIX: the original called .split(".") on these *lists*
    # (name_bits[:-1] / [name_bits[-1]]), which raised AttributeError for
    # every input. Split the underlying path strings instead.
    if len(name_bits) > 1:
        module_name_bits = ":".join(name_bits[:-1]).split(".")
    else:
        # No colon: treat the whole name as an attribute (builtins fallback).
        module_name_bits = []
    attribute_bits = name_bits[-1].split(".")

    module = None
    while len(module_name_bits):
        try:
            module_name = ".".join(module_name_bits)
            module = importlib.import_module(module_name)
            break
        except ImportError:
            attribute_bits.insert(0, module_name_bits.pop())

    if module is None:
        # __builtins__ is a dict in __main__ but a module elsewhere; normalize.
        builtin_ns = __builtins__ if isinstance(__builtins__, dict) else vars(__builtins__)
        try:
            return builtin_ns[name]
        except KeyError:
            raise ValueError("Invalid attribute name: %s" % name)

    attribute_name = ".".join(attribute_bits)
    if hasattr(module, attribute_name):
        return getattr(module, attribute_name)

    # Resolve Class.staticmethod style attributes through their owner.
    attribute_name = attribute_bits.pop()
    attribute_owner_name = ".".join(attribute_bits)
    try:
        attribute_owner = getattr(module, attribute_owner_name)
    except AttributeError:  # was a bare except: narrowed to what getattr raises
        raise ValueError("Invalid attribute name: %s" % attribute_name)

    if not hasattr(attribute_owner, attribute_name):
        raise ValueError("Invalid attribute name: %s" % name)
    return getattr(attribute_owner, attribute_name)
0950
0951
def import_attribute(name: str) -> Callable[..., Any]:
    """Returns an attribute from a dotted path name. Example: `path.to.func`.

    When the attribute we look for is a staticmethod, module name in its
    dotted path is not the last-before-end word

    E.g.: package_a.package_b.module_a.ClassA.my_static_method

    Thus we remove the bits from the end of the name until we can import it

    Args:
        name (str): The name (reference) to the path.

    Raises:
        ValueError: If no module is found or invalid attribute name.

    Returns:
        Any: An attribute (normally a Callable)
    """
    name_bits = name.split(".")
    module_name_bits, attribute_bits = name_bits[:-1], [name_bits[-1]]
    module = None
    while len(module_name_bits):
        try:
            module_name = ".".join(module_name_bits)
            module = importlib.import_module(module_name)
            break
        except ImportError:
            attribute_bits.insert(0, module_name_bits.pop())

    if module is None:
        # __builtins__ is a dict in __main__ but a module elsewhere; normalize
        # so the lookup works in both contexts.
        builtin_ns = __builtins__ if isinstance(__builtins__, dict) else vars(__builtins__)
        try:
            return builtin_ns[name]
        except KeyError:
            raise ValueError("Invalid attribute name: %s" % name)

    attribute_name = ".".join(attribute_bits)
    if hasattr(module, attribute_name):
        return getattr(module, attribute_name)

    # Resolve Class.staticmethod style attributes through their owner.
    attribute_name = attribute_bits.pop()
    attribute_owner_name = ".".join(attribute_bits)
    try:
        attribute_owner = getattr(module, attribute_owner_name)
    except AttributeError:  # was a bare except: narrowed to what getattr raises
        raise ValueError("Invalid attribute name: %s" % attribute_name)

    if not hasattr(attribute_owner, attribute_name):
        raise ValueError("Invalid attribute name: %s" % name)
    return getattr(attribute_owner, attribute_name)
1003
1004
def decode_base64(sb, remove_quotes=False):
    """
    Decode a base64 str/bytes value to utf-8 text.

    Non-str/bytes input and decode failures return the input unchanged
    (failures are logged). With remove_quotes=True the first and last
    characters of the decoded text are stripped.
    """
    try:
        if isinstance(sb, bytes):
            raw = sb
        elif isinstance(sb, str):
            raw = bytes(sb, "ascii")
        else:
            return sb
        decoded = base64.b64decode(raw).decode("utf-8")
        return decoded[1:-1] if remove_quotes else decoded
    except Exception as ex:
        logging.error("decode_base64 %s: %s" % (sb, ex))
        return sb
1021
1022
def encode_base64(sb):
    """
    Base64-encode a str (ascii) or bytes value and return it as a utf-8 string.

    Non-str/bytes input is returned unchanged. The original produced that
    result only by hitting an UnboundLocalError (sb_bytes never assigned)
    caught by the broad except — and logged a spurious error; the explicit
    else branch below handles it cleanly.
    """
    try:
        if isinstance(sb, str):
            sb_bytes = bytes(sb, "ascii")
        elif isinstance(sb, bytes):
            sb_bytes = sb
        else:
            return sb
        return base64.b64encode(sb_bytes).decode("utf-8")
    except Exception as ex:
        logging.error("encode_base64 %s: %s" % (sb, ex))
        return sb
1033
1034
def is_execluded_file(file, exclude_files=[]):
    """
    Return True when *file* matches one of *exclude_files*.

    Patterns containing regex metacharacters are applied with re.match
    (anchored at the start of *file*); plain patterns require exact equality.
    """
    meta_chars = '*?[]^$.()|+{}'
    for pattern in exclude_files or []:
        if any(ch in pattern for ch in meta_chars):
            if re.match(re.compile(pattern), file):
                return True
        elif file == pattern:
            return True
    return False
1048
1049
def create_archive_file(work_dir, archive_filename, files, exclude_files=[]):
    """
    Create a gzip tar archive containing *files* (files and/or directories),
    skipping anything matched by *exclude_files*.

    :param work_dir: base directory for a relative archive_filename.
    :param archive_filename: target archive path (absolute, or relative to work_dir).
    :param files: iterable of file or directory paths to add.
    :param exclude_files: patterns passed to is_execluded_file().
    :returns: the archive path actually written.
    """
    if not archive_filename.startswith("/"):
        archive_filename = os.path.join(work_dir, archive_filename)

    def safe_relpath(path, base):
        """Compute relative path, resolving symlinks to avoid '../' chains."""
        rel = os.path.relpath(os.path.realpath(path), os.path.realpath(base))
        if rel.startswith('..'):
            # Fall back to the bare filename when the path escapes base.
            rel = os.path.basename(path)
        return rel

    with tarfile.open(archive_filename, "w:gz", dereference=True) as tar:
        for local_file in files:
            if os.path.isfile(local_file):
                if is_execluded_file(local_file, exclude_files):
                    continue
                tar.add(local_file, arcname=os.path.basename(local_file))
            elif os.path.isdir(local_file):
                for filename in os.listdir(local_file):
                    if is_execluded_file(filename, exclude_files):
                        continue
                    # BUG FIX: the original tested os.path.isfile/isdir (and
                    # walked) on the bare listdir() name, which is resolved
                    # against the CWD, so directory contents were silently
                    # skipped unless the CWD happened to be local_file.
                    file_path = os.path.join(local_file, filename)
                    if os.path.isfile(file_path):
                        tar.add(
                            file_path, arcname=safe_relpath(file_path, local_file)
                        )
                    elif os.path.isdir(file_path):
                        for root, dirs, fs in os.walk(file_path):
                            for f in fs:
                                if not is_execluded_file(f, exclude_files):
                                    sub_path = os.path.join(root, f)
                                    tar.add(
                                        sub_path,
                                        arcname=safe_relpath(sub_path, local_file),
                                    )
    return archive_filename
1088
1089
class SecureString(object):
    """Wrapper that masks a sensitive value whenever it is rendered with str()."""

    def __init__(self, value):
        # The raw value is kept on a private attribute and never exposed
        # through the string representation.
        self._value = value

    def __str__(self):
        return "****"
1096
1097
def is_panda_client_verbose():
    """Return True only when PANDA_CLIENT_VERBOSE is set to 'true' (case-insensitive)."""
    flag = os.environ.get("PANDA_CLIENT_VERBOSE", None)
    # Unset, empty, or any value other than "true" counts as not verbose.
    return bool(flag) and flag.lower() == "true"
1105
1106
def get_unique_id_for_dict(dict_):
    """Return a deterministic SHA1 hex digest identifying *dict_* (keys sorted)."""
    # Canonical JSON (sorted keys) makes the digest independent of key order.
    canonical = json.dumps(dict_, sort_keys=True).encode()
    return hashlib.sha1(canonical).hexdigest()
1111
1112
def idds_mask(dict_):
    """
    Return a copy of *dict_* with sensitive-looking entries masked.

    A value is replaced by "***" when its key contains any of the markers
    below (substring match); all other entries are copied unchanged.
    """
    markers = ("pass", "password", "passwd", "token", "security", "secure")
    return {
        key: ("***" if any(m in key for m in markers) else dict_[key])
        for key in dict_
    }
1128
1129
@contextlib.contextmanager
def modified_environ(*remove, **update):
    """
    Temporarily updates the ``os.environ`` dictionary in-place.
    The ``os.environ`` dictionary is updated in-place so that the modification
    is sure to work in all situations.
    :param remove: Environment variables to remove.
    :param update: Dictionary of environment variables and values to add/update.
    """
    env = os.environ
    update = update or {}
    remove = remove or []

    # Variables that already exist and will be touched: remember their
    # original values so they can be restored on exit.
    touched = (set(update) | set(remove)) & set(env)
    saved = {name: env[name] for name in touched}
    # Variables introduced by `update` that must be deleted again on exit.
    added = frozenset(name for name in update if name not in env)

    try:
        env.update(update)
        for name in remove:
            env.pop(name, None)
        yield
    finally:
        env.update(saved)
        for name in added:
            env.pop(name)
1157
1158
def run_with_timeout(func, args=(), kwargs=None, timeout=None, retries=1):
    """
    Run a function in a worker thread with a per-attempt timeout.

    Parameters:
        func (callable): The function to run.
        args (tuple): Positional arguments to pass to the function.
        kwargs (dict or None): Keyword arguments to pass to the function.
        timeout (float or int or None): Per-attempt time limit in seconds.
        retries (int): Total number of attempts before giving up.

    Returns:
        The function's result if it finishes within the timeout; otherwise a
        TimeoutError instance (returned, not raised, preserving the original
        contract of this helper).
    """
    # BUG FIX: kwargs={} was a shared mutable default argument.
    if kwargs is None:
        kwargs = {}
    timed_out = None
    for i in range(retries):
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(func, *args, **kwargs)
            try:
                if i > 0:
                    logging.info(f"retry {i} to execute function.")
                return future.result(timeout=timeout)
            except concurrent.futures.TimeoutError:
                # BUG FIX: the original returned here on the first timeout,
                # so the `retries` parameter was never honoured. Record the
                # failure and fall through to the next attempt.
                logging.error(
                    f"Function '{func.__name__}' timed out after {timeout} seconds in retry {i}."
                )
                timed_out = TimeoutError(
                    f"Function '{func.__name__}' timed out after {timeout} seconds."
                )
        # NOTE: leaving the `with` block joins the worker thread, so the next
        # attempt starts only after the timed-out call actually returns.
    return timed_out
1188
1189
def timeout_wrapper(timeout, retries=1):
    """
    Decorator that runs the wrapped function in a worker thread and gives up
    after ``timeout`` seconds, retrying up to ``retries`` times.

    Parameters:
        timeout (int or float): The time limit in seconds for each attempt.
        retries (int): Total number of attempts before giving up.

    Returns:
        The wrapped function's result when it finishes in time, otherwise a
        TimeoutError instance (returned, not raised, preserving the original
        contract of this helper).
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            timed_out = None
            for i in range(retries):
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    future = executor.submit(func, *args, **kwargs)
                    try:
                        if i > 0:
                            logging.info(f"retry {i} to execute function.")
                        return future.result(timeout=timeout)
                    except concurrent.futures.TimeoutError:
                        # BUG FIX: the original returned here on the first
                        # timeout, so `retries` was never honoured. Record the
                        # failure and fall through to the next attempt.
                        logging.error(
                            f"Function '{func.__name__}' timed out after {timeout} seconds in retry {i}."
                        )
                        timed_out = TimeoutError(
                            f"Function '{func.__name__}' timed out after {timeout} seconds."
                        )
                # NOTE: leaving the `with` block joins the worker thread, so
                # the next attempt starts only after the timed-out call
                # actually returns.
            return timed_out

        return wrapper

    return decorator
1224
1225
def get_process_thread_info():
    """
    Collect identifiers for the current execution context.

    Returns: hostname (short form), process id, thread id and thread name.
    """
    short_host = socket.getfqdn().split(".")[0]
    current = threading.current_thread()
    return short_host, os.getpid(), current.ident, current.name
1237
1238
def run_command_with_timeout(
    command, timeout=600, stdout=sys.stdout, stderr=sys.stderr
):
    """
    Run a command and monitor its output. Terminate if no output within timeout.

    Parameters:
        command: Command to execute (passed straight to subprocess.Popen).
        timeout (int or float): Maximum seconds allowed with no new output
            before the process is terminated. This is an inactivity timeout,
            not a total-runtime limit.
        stdout / stderr: Text streams exposing a binary ``.buffer`` attribute
            (e.g. sys.stdout / sys.stderr) to which the child's output is
            forwarded.

    Returns:
        The subprocess.Popen object after the process has exited.
    """
    last_output_time = time.time()

    def monitor_output(stream, output, timeout):
        # Forward raw bytes from the child's pipe to `output`, refreshing
        # last_output_time on every line so the watchdog loop below can
        # detect inactivity. NOTE(review): `timeout` is unused here; the
        # watchdog loop does the actual timeout bookkeeping.
        nonlocal last_output_time
        for line in iter(stream.readline, b""):
            output.buffer.write(line)
            output.flush()
            last_output_time = time.time()

    # Start the child in its own session (os.setsid) so the whole process
    # group can be signalled when the watchdog decides to kill it.
    process = subprocess.Popen(
        command,
        preexec_fn=os.setsid,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    # Drain both pipes on dedicated threads so the child never blocks on a
    # full pipe buffer while the main thread polls for liveness.
    stdout_thread = threading.Thread(
        target=monitor_output, args=(process.stdout, stdout, timeout)
    )
    stderr_thread = threading.Thread(
        target=monitor_output, args=(process.stderr, stderr, timeout)
    )
    stdout_thread.start()
    stderr_thread.start()

    # Watchdog: poll every 10 seconds; if no output arrived within `timeout`
    # seconds, kill the process. (kill_all is presumably a module-level helper
    # defined elsewhere in this file -- verify it signals the whole group.)
    while process.poll() is None:
        time_elapsed = time.time() - last_output_time
        if time_elapsed > timeout:
            print(f"No output for {time_elapsed} seconds. Terminating process.")
            kill_all(process)
            break
        time.sleep(10)

    # Let the reader threads drain any remaining buffered output, then reap
    # the child process.
    stdout_thread.join()
    stderr_thread.join()
    process.wait()
    return process