Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:58:16

0001 #!/usr/bin/env python
0002 #
0003 # Licensed under the Apache License, Version 2.0 (the "License");
0004 # You may not use this file except in compliance with the License.
0005 # You may obtain a copy of the License at
0006 # http://www.apache.org/licenses/LICENSE-2.0
0007 #
0008 # Authors:
0009 # - Wen Guan, <wen.guan@cern.ch>, 2019 - 2025
0010 # - Lino Oscar Gerlach, <lino.oscar.gerlach@cern.ch>, 2024
0011 
0012 import base64
0013 import concurrent.futures
0014 import contextlib
0015 import errno
0016 import datetime
0017 import functools
0018 import importlib
0019 import hashlib
0020 import logging
0021 import json
0022 import os
0023 import re
0024 import requests
0025 import signal
0026 import socket
0027 import subprocess
0028 import sys
0029 import tarfile
0030 import threading
0031 import time
0032 
0033 from enum import Enum
0034 from functools import wraps
0035 from logging.handlers import RotatingFileHandler
0036 from itertools import groupby
0037 from operator import itemgetter
0038 from packaging import version as packaging_version
0039 from typing import Any, Callable
0040 
0041 from idds.common.config import (
0042     config_has_section,
0043     config_has_option,
0044     config_get,
0045     config_get_bool,
0046     config_get_int,
0047 )
0048 from idds.common.constants import (
0049     IDDSEnum,
0050     RequestType,
0051     RequestStatus,
0052     TransformType,
0053     TransformStatus,
0054     CollectionType,
0055     CollectionRelationType,
0056     CollectionStatus,
0057     ContentType,
0058     ContentStatus,
0059     GranularityType,
0060     ProcessingStatus,
0061 )
0062 from idds.common.dict_class import DictClass
0063 from idds.common.exceptions import IDDSException
0064 
0065 
# RFC 1123 date format (with a literal "UTC" suffix) used by
# str_to_date()/date_to_str() for all iDDS timestamp strings.
DATE_FORMAT = "%a, %d %b %Y %H:%M:%S UTC"
0068 
0069 
def get_log_dir():
    """Return the configured common.logdir, defaulting to /var/log/idds."""
    configured = config_has_section("common") and config_has_option("common", "logdir")
    return config_get("common", "logdir") if configured else "/var/log/idds"
0074 
0075 
def setup_logging(name, stream=None, log_file=None, loglevel=None):
    """
    Set up the root logging configuration.

    :param name: logger name (kept for interface compatibility).
    :param stream: optional stream to log to instead of a file.
    :param log_file: optional log file; relative paths are placed under the
                     configured log directory (default /var/log/idds).
    :param loglevel: optional level (int or level-name string); when None it
                     falls back to the config file, then IDDS_LOG_LEVEL,
                     then logging.INFO.
    """
    log_format = "%(asctime)s\t%(threadName)s\t%(name)s\t%(levelname)s\t%(message)s"

    if loglevel is None:
        if config_has_section("common") and config_has_option("common", "loglevel"):
            loglevel = getattr(logging, config_get("common", "loglevel").upper())
        else:
            loglevel = logging.INFO

        # The IDDS_LOG_LEVEL environment variable overrides the config file.
        idds_log_level = os.environ.get("IDDS_LOG_LEVEL", None)
        if idds_log_level:
            idds_log_level = idds_log_level.upper()
            if idds_log_level in ["DEBUG", "CRITICAL", "ERROR", "WARNING", "INFO"]:
                loglevel = getattr(logging, idds_log_level)

    if isinstance(loglevel, str):
        loglevel = getattr(logging, loglevel.upper())

    if log_file is not None and not log_file.startswith("/"):
        # Resolve relative log files against the configured log directory.
        logdir = None
        if config_has_section("common") and config_has_option("common", "logdir"):
            logdir = config_get("common", "logdir")
        if not logdir:
            logdir = "/var/log/idds"
        log_file = os.path.join(logdir, log_file)

    if log_file:
        logging.basicConfig(filename=log_file, level=loglevel, format=log_format)
    elif stream is None:
        if os.environ.get("IDDS_LOG_FILE", None):
            logging.basicConfig(
                filename=os.environ.get("IDDS_LOG_FILE"),
                level=loglevel,
                format=log_format,
            )
        elif (
            config_has_section("common")
            and config_has_option("common", "logdir")
            and config_has_option("common", "logfile")
        ):
            # NOTE: the original also tested "or log_file" on this branch and
            # re-checked "if log_file:" inside it, but log_file is always
            # falsy here (truthy values are handled above) - dead code removed.
            log_filename = config_get("common", "logfile")
            if not log_filename.startswith("/"):
                log_filename = os.path.join(
                    config_get("common", "logdir"), log_filename
                )
            logging.basicConfig(
                filename=log_filename, level=loglevel, format=log_format
            )
        else:
            logging.basicConfig(stream=sys.stdout, level=loglevel, format=log_format)
    else:
        logging.basicConfig(stream=stream, level=loglevel, format=log_format)
    # Emit all timestamps in UTC.
    logging.Formatter.converter = time.gmtime
0150 
0151 
def get_logger(name, filename=None, loglevel=None):
    """
    Create (or fetch) a named logger backed by a rotating file handler.

    :param name: logger name; also the default base name of the log file.
    :param filename: optional log file; relative paths go under the
                     configured log directory (default /var/log/idds).
    :param loglevel: optional level; falls back to the config file,
                     then logging.INFO.
    :returns: a logging.Logger instance.
    """
    if loglevel is None:
        if config_has_section("common") and config_has_option("common", "loglevel"):
            loglevel = getattr(logging, config_get("common", "loglevel").upper())
        else:
            loglevel = logging.INFO

    if filename is None:
        filename = name + ".log"
    if not filename.startswith("/"):
        logdir = None
        if config_has_section("common") and config_has_option("common", "logdir"):
            logdir = config_get("common", "logdir")
        if not logdir:
            logdir = "/var/log/idds"
        filename = os.path.join(logdir, filename)

    logger = logging.getLogger(name)
    logger.setLevel(loglevel)
    # Fix: guard against attaching a second handler (and duplicating every
    # log line) when get_logger() is called repeatedly with the same name.
    if not logger.handlers:
        formatter = logging.Formatter(
            "%(asctime)s\t%(threadName)s\t%(name)s\t%(levelname)s\t%(message)s"
        )
        handler = RotatingFileHandler(
            filename, maxBytes=2 * 1024 * 1024 * 1024, backupCount=3
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    logger.propagate = False
    return logger
0185 
0186 
def get_rest_url_prefix():
    """Return the configured rest.url_prefix normalized to '/prefix' form, or None."""
    url_prefix = None
    if config_has_section("rest") and config_has_option("rest", "url_prefix"):
        url_prefix = config_get("rest", "url_prefix")
    if url_prefix:
        # Strip surrounding slashes, then re-attach exactly one leading slash.
        url_prefix = "/" + url_prefix.strip("/")
    return url_prefix
0199 
0200 
def get_rest_debug():
    """Return the rest.debug config flag, defaulting to False."""
    configured = config_has_section("rest") and config_has_option("rest", "debug")
    return config_get_bool("rest", "debug") if configured else False
0205 
0206 
def get_rest_cacher_dir():
    """
    Return the configured REST cacher directory.

    :raises Exception: when rest.cacher_dir is unset or does not exist on disk.
    """
    cacher_dir = None
    if config_has_section("rest") and config_has_option("rest", "cacher_dir"):
        cacher_dir = config_get("rest", "cacher_dir")
    if not cacher_dir or not os.path.exists(cacher_dir):
        raise Exception("cacher_dir is not defined or it doesn't exist")
    return cacher_dir
0214 
0215 
def get_asyncresult_config():
    """
    Collect the [asyncresult] broker settings from the config file.

    :returns: dict with broker_type, brokers, broker_destination,
              broker_timeout (int, default 360), broker_username,
              broker_password and broker_x509.
    """
    ret = {
        "broker_type": None,
        "brokers": None,
        "broker_destination": None,
        "broker_timeout": 360,
        "broker_username": None,
        "broker_password": None,
        "broker_x509": None,
    }
    if config_has_section("asyncresult"):
        for option in ret:
            if not config_has_option("asyncresult", option):
                continue
            if option == "broker_timeout":
                # The timeout is the only integer-typed option.
                ret[option] = config_get_int("asyncresult", option)
            else:
                ret[option] = config_get("asyncresult", option)
    return ret
0251 
0252 
def get_prompt_broker_config():
    """
    Collect the [prompt] broker settings; each configured value is parsed as JSON.

    :returns: dict with transformer_broker, transformer_broadcast_broker and
              result_broker (None when not configured).
    """
    ret = {
        "transformer_broker": None,
        "transformer_broadcast_broker": None,
        "result_broker": None,
    }
    if config_has_section("prompt"):
        for option in ret:
            if config_has_option("prompt", option):
                ret[option] = json_loads(config_get("prompt", option))
    return ret
0272 
0273 
def str_to_date(string):
    """
    Convert an RFC 1123 string (DATE_FORMAT) to a datetime.

    :param string: the string to parse; falsy input returns None.
    """
    if not string:
        return None
    return datetime.datetime.strptime(string, DATE_FORMAT)
0281 
0282 
def date_to_str(date):
    """
    Convert a datetime to an RFC 1123 string (DATE_FORMAT).

    :param date: the datetime to format; falsy input returns None.
    """
    if not date:
        return None
    return datetime.datetime.strftime(date, DATE_FORMAT)
0290 
0291 
def has_config():
    """
    Check whether an iDDS config file is available.

    Honours $IDDS_CONFIG when set; otherwise probes the standard locations
    under $IDDS_HOME, /etc/idds and $VIRTUAL_ENV.

    :returns: True when a config file exists, otherwise False.
    """
    configfile = os.environ.get("IDDS_CONFIG", None)
    if configfile:
        return os.path.exists(configfile)

    candidates = [
        "%s/etc/idds/idds.cfg" % os.environ.get("IDDS_HOME", ""),
        "/etc/idds/idds.cfg",
        "%s/etc/idds/idds.cfg" % os.environ.get("VIRTUAL_ENV", ""),
    ]
    return any(c and os.path.exists(c) for c in candidates)
0311 
0312 
def check_rest_host():
    """
    Check whether a REST host is defined in the config.
    Used to decide whether to skip some test functions.

    :returns: True if a rest host is configured, otherwise False.
    """
    if not config_has_option("rest", "host"):
        return False
    return bool(config_get("rest", "host"))
0325 
0326 
def get_rest_host():
    """
    Return the REST host, preferring $IDDS_HOST over the config file.

    The configured host has trailing slashes removed and the configured
    URL prefix (if any) appended.
    """
    if "IDDS_HOST" in os.environ:
        return os.environ.get("IDDS_HOST")
    host = config_get("rest", "host").rstrip("/")
    url_prefix = get_rest_url_prefix()
    if url_prefix:
        host = host + url_prefix
    return host
0340 
0341 
def check_user_proxy():
    """
    Check whether a user grid proxy file exists.

    Uses $X509_USER_PROXY when set, otherwise /tmp/x509up_u<euid>.

    :returns: True when the proxy file is present, otherwise False.
    """
    client_proxy = os.environ.get(
        "X509_USER_PROXY", "/tmp/x509up_u%d" % os.geteuid()
    )
    return os.path.exists(client_proxy)
0355 
0356 
def check_database():
    """
    Check whether a default database is defined in the config.
    Used to decide whether to skip some test functions.

    :returns: True if database.default is configured, otherwise False.
    """
    if not config_has_option("database", "default"):
        return False
    return bool(config_get("database", "default"))
0369 
0370 
def kill_process_group(pgrp, nap=10):
    """
    Kill the process group.
    DO NOT MOVE TO PROCESSES.PY - will lead to circular import since execute() needs it as well.
    :param pgrp: process group id (int).
    :param nap: napping time between kill signals in seconds (int)
    :return: boolean (True if SIGTERM followed by SIGKILL signalling was successful)
    """
    # Ask nicely first with SIGTERM, give the processes time to exit,
    # then follow up with an unconditional SIGKILL.
    print(f"killing group process {pgrp}")
    sigterm_sent = True
    try:
        os.killpg(pgrp, signal.SIGTERM)
    except Exception as error:
        print(
            f"exception thrown when killing child group process under SIGTERM: {error}"
        )
        sigterm_sent = False
    else:
        print(f"SIGTERM sent to process group {pgrp}")

    if sigterm_sent:
        print(f"sleeping {nap} s to allow processes to exit")
        time.sleep(nap)

    status = False
    try:
        os.killpg(pgrp, signal.SIGKILL)
    except Exception as error:
        print(
            f"exception thrown when killing child group process with SIGKILL: {error}"
        )
    else:
        print(f"SIGKILL sent to process group {pgrp}")
        status = True

    return status
0410 
0411 
def kill_all(process: Any) -> str:
    """
    Kill all processes after a time-out exception in process.communication().

    First tries to kill the whole process group of *process*, then signals
    the process itself with SIGTERM followed (after a grace period) by
    SIGKILL.  All failures are collected rather than raised.

    :param process: process object (must expose .pid).
    :return: stderr (str) - accumulated error descriptions, empty on success.
    """

    stderr = ""
    try:
        # Take down the entire process group first, so children die too.
        print("killing lingering subprocess and process group")
        time.sleep(1)
        # process.kill()
        kill_process_group(os.getpgid(process.pid))
    except ProcessLookupError as exc:
        stderr += f"\n(kill process group) ProcessLookupError={exc}"
    except Exception as exc:
        stderr += f"\n(kill_all 1) exception caught: {exc}"
    try:
        # Then signal the process directly: SIGTERM, grace period, SIGKILL.
        print("killing lingering process")
        time.sleep(1)
        os.kill(process.pid, signal.SIGTERM)
        print("sleeping a bit before sending SIGKILL")
        time.sleep(10)
        os.kill(process.pid, signal.SIGKILL)
    except ProcessLookupError as exc:
        stderr += f"\n(kill process) ProcessLookupError={exc}"
    except Exception as exc:
        stderr += f"\n(kill_all 2) exception caught: {exc}"
    print(f"sent soft kill signals - final stderr: {stderr}")
    return stderr
0443 
0444 
def run_process(cmd, stdout=None, stderr=None, wait=False, timeout=7 * 24 * 3600):
    """
    Runs a command in an out-of-process shell.

    :param cmd: the shell command string.
    :param stdout: optional stream for the child's stdout (only used when
                   stderr is also given).
    :param stderr: optional stream for the child's stderr.
    :param wait: when False, return the Popen object immediately; when True,
                 wait for completion and return the exit code.
    :param timeout: maximum seconds to wait for completion (default: one week).
    :returns: the Popen object (wait=False) or an exit code (wait=True);
              -1 when the command timed out or communicate() raised.
    """
    print(f"To run command: {cmd}")
    if stdout and stderr:
        process = subprocess.Popen(
            cmd,
            shell=True,
            stdout=stdout,
            stderr=stderr,
            preexec_fn=os.setsid,  # new session, so kill_all can kill the whole group
            encoding="utf-8",
        )
    else:
        process = subprocess.Popen(
            cmd, shell=True, preexec_fn=os.setsid, encoding="utf-8"
        )
    if not wait:
        return process

    try:
        print(f"subprocess.communicate() will use timeout={timeout} s")
        process.communicate(timeout=timeout)
    except subprocess.TimeoutExpired as ex:
        stderr = f"subprocess communicate sent TimeoutExpired: {ex}"
        print(stderr)
        stderr = kill_all(process)
        print(f"Killing process: {stderr}")
        exit_code = -1
    except Exception as ex:
        stderr = f"subprocess has an exception: {ex}"
        print(stderr)
        stderr = kill_all(process)
        print(f"Killing process: {stderr}")
        exit_code = -1
    else:
        exit_code = process.poll()

    try:
        # Reap the child so it does not linger as a zombie.
        process.wait(timeout=60)
    except subprocess.TimeoutExpired:
        print("process did not complete within the timeout of 60s - terminating")
        process.terminate()
    return exit_code
0490 
0491 
def run_command(cmd):
    """
    Run a command in an out-of-process shell and wait for it.

    :param cmd: the shell command string.
    :returns: tuple (status, stdout, stderr) with outputs decoded to str.
    """
    process = subprocess.Popen(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        preexec_fn=os.setsid,
    )
    stdout, stderr = process.communicate()
    if isinstance(stdout, bytes):
        stdout = stdout.decode()
    if isinstance(stderr, bytes):
        stderr = stderr.decode()
    return process.returncode, stdout, stderr
0510 
0511 
def get_space_from_string(space_str):
    """
    Convert a space string with an M, G, T or P suffix to an integer.

    Multipliers are powers of 1024 (M=1024, G=1024^2, ...); a string
    without a known suffix is parsed as a plain integer.
    """
    # Check in the same order as the original (M, G, T, P).
    multipliers = (
        ("M", 1024),
        ("G", 1024 ** 2),
        ("T", 1024 ** 3),
        ("P", 1024 ** 4),
    )
    for suffix, multiplier in multipliers:
        if suffix in space_str:
            return int(float(space_str.split(suffix)[0]) * multiplier)
    return int(space_str)
0531 
0532 
def urlretrieve(url, dest, timeout=300):
    """
    Download a file.

    :param url: The url of the source file.
    :param dest: destination file path.
    :param timeout: request timeout in seconds.
    :returns: 0 on success, -1 when the server does not return HTTP 200.
    """
    r = requests.get(url, allow_redirects=True, timeout=timeout)
    if r.status_code == 200:
        # Only touch dest on success; the original opened (and truncated)
        # the destination file before the request, so a failed download
        # clobbered any existing file with an empty one.
        with open(dest, "wb") as f:
            f.write(r.content)
        return 0
    return -1
0547 
0548 
def convert_nojsontype_to_value(params):
    """
    Recursively convert non-JSON values (enums, datetimes) to plain values.

    :param params: list or dict of parameters.

    :returns: the converted list or dict.
    """
    def _convert(value):
        # Enums become their .value, datetimes become formatted strings,
        # and nested containers are converted recursively.
        if value is None:
            return value
        if isinstance(value, Enum):
            value = value.value
        if isinstance(value, datetime.datetime):
            value = date_to_str(value)
        if isinstance(value, (list, dict)):
            value = convert_nojsontype_to_value(value)
        return value

    if isinstance(params, list):
        params = [_convert(item) for item in params]
    elif isinstance(params, dict):
        for key in params:
            params[key] = _convert(params[key])
    return params
0579 
0580 
def convert_value_to_nojsontype(params):
    """
    Convert plain (JSON) values back to their enum types.

    Integer values of well-known type/status keys are mapped back to the
    matching iDDS enum classes; nested lists and dicts are converted
    recursively.

    :param params: dict of parameters.

    :returns: dict of parameters.
    """
    req_keys = {"request_type": RequestType, "status": RequestStatus}
    transform_keys = {"transform_type": TransformType, "status": TransformStatus}
    coll_keys = {
        "coll_type": CollectionType,
        "relation_type": CollectionRelationType,
        "coll_status": CollectionStatus,
    }
    content_keys = {"content_type": ContentType, "status": ContentStatus}
    process_keys = {"granularity_type": GranularityType, "status": ProcessingStatus}

    # NOTE: the original also computed `keys` at this point, but that value
    # was unused (recomputed in the dict branch below) - dead code removed.

    if isinstance(params, list):
        new_params = []
        for v in params:
            if v is not None and isinstance(v, (list, dict)):
                v = convert_value_to_nojsontype(v)
            new_params.append(v)
        params = new_params
    elif isinstance(params, dict):
        # Pick the enum mapping based on the discriminating key present.
        # Fix: the original initialized `keys = []` and then called
        # keys.keys(), which raised AttributeError for dicts containing
        # none of the discriminator keys; an empty dict is safe.
        keys = {}
        if "request_type" in params:
            keys = req_keys
        elif "transform_type" in params:
            keys = transform_keys
        elif "coll_type" in params:
            keys = coll_keys
        elif "content_type" in params:
            keys = content_keys
        elif "granularity_type" in params:
            keys = process_keys

        for key in keys:
            if (
                key in params
                and params[key] is not None
                and isinstance(params[key], int)
            ):
                params[key] = keys[key](params[key])

        for key in params:
            if params[key] is not None:
                if isinstance(params[key], (list, dict)):
                    params[key] = convert_value_to_nojsontype(params[key])

    return params
0644 
0645 
def convert_request_type_to_transform_type(request_type):
    """Map a RequestType (or its integer value) to the matching TransformType."""
    value = request_type.value if isinstance(request_type, RequestType) else request_type
    return TransformType(value)
0650 
0651 
class DictClassEncoder(json.JSONEncoder):
    """JSON encoder aware of iDDS enums, DictClass objects and datetimes."""

    def default(self, obj):
        # iDDS enums and DictClass instances serialize via their to_dict().
        if isinstance(obj, (IDDSEnum, DictClass)):
            return obj.to_dict()
        if isinstance(obj, datetime.datetime):
            return date_to_str(obj)
        if isinstance(obj, datetime.timedelta):
            return str(obj)
        # Let the base class default method raise the TypeError
        return super().default(obj)
0668 
0669 
def as_has_dict(dct):
    """json object_hook: revive DictClass instances from their dict form."""
    return DictClass.from_dict(dct) if DictClass.is_class(dct) else dct
0674 
0675 
def json_dumps(obj, indent=None, sort_keys=False):
    """Serialize obj to JSON using the iDDS-aware DictClassEncoder."""
    return json.dumps(obj, cls=DictClassEncoder, indent=indent, sort_keys=sort_keys)
0678 
0679 
def json_loads(obj):
    """Deserialize JSON, reviving DictClass objects via the as_has_dict hook."""
    return json.loads(obj, object_hook=as_has_dict)
0682 
0683 
def get_parameters_from_string(text):
    """
    Find all parameter names marked with '%'. For example, for this string below, it should return ['NUM_POINTS', 'IN', 'OUT']
    'run --rm -it -v "$(pwd)":/payload gitlab-registry.cern.ch/zhangruihpc/endpointcontainer:latest /bin/bash -c "echo "--num_points %NUM_POINTS"; /bin/cat /payload/%IN>/payload/%OUT"'

    :param text: the string to scan.
    :returns: list of unique parameter names in first-appearance order.
    """
    names = (match.replace("%", "") for match in re.findall(r"[%]\w+", text))
    # dict.fromkeys deduplicates while keeping a deterministic first-seen
    # order; the original used list(set(...)), whose order is arbitrary.
    return list(dict.fromkeys(names))
0694 
0695 
def replace_parameters_with_values(text, values):
    """
    Replace all parameters marked with '%' by their values. For example, for the string below it replaces ['%NUM_POINTS', '%IN', '%OUT']:
    'run --rm -it -v "$(pwd)":/payload gitlab-registry.cern.ch/zhangruihpc/endpointcontainer:latest /bin/bash -c "echo "--num_points %NUM_POINTS"; /bin/cat /payload/%IN>/payload/%OUT"'

    :param text
    :param values: parameter values, for example {'NUM_POINTS': 5, 'IN': 'input.json', 'OUT': 'output.json'}
    :returns: the text with all parameters substituted.
    """
    for key, value in values.items():
        # Fixes over the original:
        # - re.escape(): the key is matched literally, not as a raw regex;
        # - trailing \b: '%NUM' can no longer clobber part of '%NUM_POINTS';
        # - callable replacement: backslashes in values are not interpreted
        #   as regex group references/escapes by re.sub().
        pattern = "%" + re.escape(key) + r"\b"
        text = re.sub(pattern, lambda _m, _v=value: str(_v), text)
    return text
0708 
0709 
def tar_zip_files(output_dir, output_filename, files):
    """
    Create a gzipped tar archive under output_dir containing the given files,
    stored flat under their basenames.

    :param output_dir: directory where the archive is created.
    :param output_filename: name of the archive file.
    :param files: iterable of file paths to add.
    """
    archive_path = os.path.join(output_dir, output_filename)
    with tarfile.open(archive_path, "w:gz") as archive:
        for path in files:
            archive.add(path, arcname=os.path.basename(path))
0715 
0716 
def exception_handler(function):
    """
    Decorator: run the wrapped function, logging any raised exception and
    returning (False, str(exception)) instead of propagating it.
    """
    @wraps(function)
    def wrapper(*args, **kwargs):
        try:
            return function(*args, **kwargs)
        except IDDSException as ex:
            logging.error(ex)
            return False, str(ex)
        except Exception as ex:
            logging.error(ex)
            return False, str(ex)

    return wrapper
0732 
0733 
def is_sub(a, b):
    """
    Check whether a is a sub-collection of b.

    :returns: True when a is empty/None, or every element of a is in b.
    """
    if not a:
        return True
    return all(item in b for item in a)
0742 
0743 
def get_proxy_path():
    """
    Locate a readable user proxy file.

    Checks $X509_USER_PROXY first, then /tmp/x509up_u<uid>.

    :returns: the proxy path, or None when no readable proxy is found.
    :raises IDDSException: when probing for the proxy fails unexpectedly.
    """
    try:
        candidates = []
        if "X509_USER_PROXY" in os.environ:
            candidates.append(os.environ["X509_USER_PROXY"])
        candidates.append("/tmp/x509up_u%s" % os.getuid())
        for proxy in candidates:
            if os.path.exists(proxy) and os.access(proxy, os.R_OK):
                return proxy
    except Exception as ex:
        raise IDDSException("Cannot find User proxy: %s" % str(ex))
    return None
0756 
0757 
def get_proxy():
    """
    Read and return the content of the user proxy file.

    :returns: the proxy content as a string, or None when no proxy exists.
    :raises IDDSException: when the proxy cannot be read.
    """
    try:
        proxy = get_proxy_path()
        if not proxy:
            return proxy
        with open(proxy, "r") as fp:
            data = fp.read()
        return data
    except Exception as ex:
        raise IDDSException("Cannot find User proxy: %s" % str(ex))
    # NOTE: the original had an unreachable "return None" here (every path
    # in the try/except above returns or raises) - dead code removed.
0769 
0770 
def is_new_version(version1, version2):
    """Return True when version1 is strictly newer than version2 (PEP 440 ordering)."""
    parsed1 = packaging_version.parse(version1)
    parsed2 = packaging_version.parse(version2)
    return parsed1 > parsed2
0773 
0774 
def extract_scope_atlas(did, scopes):
    """
    Extract (scope, name) from an ATLAS DID.

    A DID containing a colon is 'scope:name'; otherwise the scope is derived
    from the first dotted field (first two fields for user/group DIDs).
    A single trailing slash is stripped from the returned name.

    :param did: the dataset identifier.
    :param scopes: unused (kept for interface compatibility).
    :raises IDDSException: when the DID contains more than one colon.
    """
    if ":" in did:
        parts = did.split(":")
        if len(parts) > 2:
            raise IDDSException("Too many colons. Cannot extract scope and name")
        scope, name = parts[0], parts[1]
        if name.endswith("/"):
            name = name[:-1]
        return scope, name

    scope = did.split(".")[0]
    if did.startswith(("user", "group")):
        scope = ".".join(did.split(".")[0:2])
    if did.endswith("/"):
        did = did[:-1]
    return scope, did
0791 
0792 
def truncate_string(string, length=800):
    """Truncate string to at most `length` characters, appending '...' when cut."""
    if string and len(string) > length:
        return string[:length] + "..."
    return string
0796 
0797 
def merge_dict(dict1, dict2):
    """
    Merge dict2 into dict1 in place and return dict1.

    None values lose to non-None ones; equal values are kept; lists/tuples/
    strings are concatenated; numbers are added; differing booleans become
    True; nested dicts are merged recursively.

    :raises Exception: when the two values for a key have mismatched types.
    """
    all_keys = list(dict1.keys())
    for key in dict2.keys():
        if key not in all_keys:
            all_keys.append(key)

    for key in all_keys:
        if key not in dict2:
            continue
        new_value = dict2[key]
        if key not in dict1 or dict1[key] is None:
            dict1[key] = new_value
            continue
        old_value = dict1[key]
        if new_value is None:
            continue
        if not isinstance(old_value, type(new_value)):
            raise Exception(
                "type of %s is different from %s, cannot merge"
                % (type(old_value), type(new_value))
            )
        if old_value == new_value:
            continue
        if type(old_value) in (list, tuple, str):
            dict1[key] = old_value + new_value
        elif type(old_value) in (int, float, complex):
            dict1[key] = old_value + new_value
        elif type(old_value) is bool:
            dict1[key] = True
        elif type(old_value) is dict:
            dict1[key] = merge_dict(old_value, new_value)
    return dict1
0826 
0827 
def pid_exists(pid):
    """
    Check whether pid exists in the current process table.
    UNIX only.

    :param pid: the process id (int).
    :returns: True when a process with this pid exists.
    :raises ValueError: for pid 0, which kill() treats as "this process group".
    """
    if pid < 0:
        return False
    if pid == 0:
        # According to "man 2 kill" PID 0 refers to every process in the
        # process group of the calling process. On certain systems 0 is a
        # valid PID but there is no portable way to check it.
        raise ValueError("invalid PID 0")
    try:
        # Signal 0 performs error checking only; no signal is delivered.
        os.kill(pid, 0)
    except OSError as err:
        if err.errno == errno.ESRCH:
            # ESRCH == No such process
            return False
        if err.errno == errno.EPERM:
            # EPERM clearly means there's a process to deny access to
            return True
        # According to "man 2 kill" possible error values are
        # (EINVAL, EPERM, ESRCH)
        raise
    return True
0856 
0857 
def get_list_chunks(full_list, bulk_size=2000):
    """Split full_list into consecutive chunks of at most bulk_size items."""
    return [
        full_list[start:start + bulk_size]
        for start in range(0, len(full_list), bulk_size)
    ]
0861 
0862 
def report_availability(availability):
    """
    Persist the service availability report to <logdir>/idds_availability,
    or print it when no log directory is configured.

    Failures are reported and logged but never raised (best-effort).

    :param availability: JSON-serializable availability data.
    """
    try:
        log_dir = get_log_dir()
        if log_dir:
            filename = os.path.join(log_dir, "idds_availability")
            with open(filename, "w") as f:
                json.dump(availability, f)
        else:
            print("availability: %s" % str(availability))
    except Exception as ex:
        # Typo fix: the original message read "availablity".
        error = "Failed to report availability: %s" % str(ex)
        print(error)
        logging.debug(error)
0876 
0877 
def split_chunks_not_continous(data):
    """
    Split a sorted sequence of integers into runs of consecutive values.

    e.g. [1, 2, 3, 5, 6, 9] -> [[1, 2, 3], [5, 6], [9]]
    """
    # Consecutive values share the same (index - value) difference, so
    # grouping on that difference yields each consecutive run.
    return [
        [value for _idx, value in group]
        for _diff, group in groupby(enumerate(data), lambda pair: pair[0] - pair[1])
    ]
0883 
0884 
def group_list(input_list, key):
    """
    Group items by their remaining content after removing `key`.

    WARNING: mutates the input items (removes `key` from each dict),
    matching the original behavior.

    :param input_list: list of dicts, each containing `key`.
    :param key: the dict key whose values are collected per group.
    :returns: dict mapping the stringified sorted item tuple to
              {'keys': [key values...], 'items': representative item dict}.
    """
    groups = {}
    for item in input_list:
        key_value = item.pop(key)
        signature = str(tuple(sorted(item.items())))
        if signature not in groups:
            groups[signature] = {"keys": [], "items": item}
        groups[signature]["keys"].append(key_value)
    return groups
0895 
0896 
def import_func(name: str) -> Callable[..., Any]:
    """Returns a function from a dotted path name. Example: `path.to.module:func`.

    When the attribute we look for is a staticmethod, module name in its
    dotted path is not the last-before-end word

    E.g.: package_a.package_b.module_a:ClassA.my_static_method

    Thus we remove the bits from the end of the name until we can import it

    Args:
        name (str): The name (reference) to the path.

    Raises:
        ValueError: If no module is found or invalid attribute name.

    Returns:
        Any: An attribute (normally a Callable)
    """
    # Bug fix: the module/attribute parts are *strings* and must be split
    # individually; the original called .split(".") on lists produced by
    # name.split(":")[:-1], raising AttributeError for every input.
    if ":" in name:
        module_part, attribute_part = name.split(":", 1)
        module_name_bits = module_part.split(".")
    else:
        # No explicit separator: treat the last dotted component as the
        # attribute (mirrors import_attribute).
        bits = name.split(".")
        module_name_bits, attribute_part = bits[:-1], bits[-1]
    attribute_bits = attribute_part.split(".")

    # Try the longest candidate module path first, shifting trailing
    # components over to the attribute side until an import succeeds.
    module = None
    while len(module_name_bits):
        try:
            module_name = ".".join(module_name_bits)
            module = importlib.import_module(module_name)
            break
        except ImportError:
            attribute_bits.insert(0, module_name_bits.pop())

    if module is None:
        # Maybe it's a builtin. Use the builtins module instead of
        # __builtins__, which is a dict or a module depending on context.
        import builtins
        try:
            return getattr(builtins, name)
        except AttributeError:
            raise ValueError("Invalid attribute name: %s" % name)

    attribute_name = ".".join(attribute_bits)
    if hasattr(module, attribute_name):
        return getattr(module, attribute_name)
    # staticmethods: resolve the owner (e.g. the class) first, then the
    # final attribute on it.
    attribute_name = attribute_bits.pop()
    attribute_owner_name = ".".join(attribute_bits)
    try:
        attribute_owner = getattr(module, attribute_owner_name)
    except Exception:
        raise ValueError("Invalid attribute name: %s" % attribute_name)

    if not hasattr(attribute_owner, attribute_name):
        raise ValueError("Invalid attribute name: %s" % name)
    return getattr(attribute_owner, attribute_name)
0950 
0951 
def import_attribute(name: str) -> Callable[..., Any]:
    """Returns an attribute from a dotted path name. Example: `path.to.func`.

    When the attribute we look for is a staticmethod, module name in its
    dotted path is not the last-before-end word

    E.g.: package_a.package_b.module_a.ClassA.my_static_method

    Thus we remove the bits from the end of the name until we can import it

    Args:
        name (str): The name (reference) to the path.

    Raises:
        ValueError: If no module is found or invalid attribute name.

    Returns:
        Any: An attribute (normally a Callable)
    """
    parts = name.split(".")
    module_parts, attr_parts = parts[:-1], [parts[-1]]

    # Try the longest candidate module path first, shifting trailing
    # components over to the attribute side until an import succeeds.
    module = None
    while module_parts:
        try:
            module = importlib.import_module(".".join(module_parts))
            break
        except ImportError:
            attr_parts.insert(0, module_parts.pop())

    if module is None:
        # maybe it's a builtin
        try:
            return __builtins__[name]
        except KeyError:
            raise ValueError("Invalid attribute name: %s" % name)

    dotted_attr = ".".join(attr_parts)
    if hasattr(module, dotted_attr):
        return getattr(module, dotted_attr)
    # staticmethods: resolve the owner (e.g. the class) first, then the
    # final attribute on it.
    final_attr = attr_parts.pop()
    owner_name = ".".join(attr_parts)
    try:
        owner = getattr(module, owner_name)
    except BaseException:  # noqa
        raise ValueError("Invalid attribute name: %s" % final_attr)

    if not hasattr(owner, final_attr):
        raise ValueError("Invalid attribute name: %s" % name)
    return getattr(owner, final_attr)
1003 
1004 
def decode_base64(sb, remove_quotes=False):
    """Decode a base64 str/bytes value into utf-8 text.

    Inputs that are neither str nor bytes are returned unchanged, as is
    the original input on any decoding failure (logged, never raised).
    With remove_quotes=True the first and last characters (surrounding
    quotes) are stripped from the decoded result.
    """
    try:
        if isinstance(sb, bytes):
            raw = sb
        elif isinstance(sb, str):
            raw = bytes(sb, "ascii")
        else:
            return sb
        text = base64.b64decode(raw).decode("utf-8")
        return text[1:-1] if remove_quotes else text
    except Exception as ex:
        logging.error("decode_base64 %s: %s" % (sb, ex))
        return sb
1021 
1022 
def encode_base64(sb):
    """Encode a str/bytes value into a base64 string.

    Inputs that are neither str nor bytes are returned unchanged
    (consistent with decode_base64). Any encoding failure is logged and
    the original input is returned; nothing is raised.
    """
    try:
        if isinstance(sb, str):
            sb_bytes = bytes(sb, "ascii")
        elif isinstance(sb, bytes):
            sb_bytes = sb
        else:
            # Previously this fell through with sb_bytes unbound, relying on
            # the NameError being caught by the broad except below; return
            # the input explicitly instead.
            return sb
        return base64.b64encode(sb_bytes).decode("utf-8")
    except Exception as ex:
        logging.error("encode_base64 %s: %s" % (sb, ex))
        return sb
1033 
1034 
def is_execluded_file(file, exclude_files=[]):
    """Return True when *file* matches one of the *exclude_files* patterns.

    Patterns containing regex metacharacters are matched as regular
    expressions via re.match; plain patterns require an exact name match.
    (Name keeps the historical "execluded" spelling for caller
    compatibility.)
    """
    meta_chars = '*?[]^$.()|+{}'
    for pattern in (exclude_files or []):
        if any(ch in meta_chars for ch in pattern):
            # Regex-style pattern
            if re.match(re.compile(pattern), file):
                return True
        elif file == pattern:
            # Exact match: only exclude file/dir named exactly 'pattern'
            return True
    return False
1048 
1049 
def create_archive_file(work_dir, archive_filename, files, exclude_files=[]):
    """Create a gzipped tar archive containing *files*.

    A relative *archive_filename* is placed under *work_dir*. Each entry in
    *files* may be a file (added under its basename) or a directory (its
    contents added, recursively, with paths relative to that directory).
    Entries matching *exclude_files* (see is_execluded_file) are skipped.

    Returns the (possibly work_dir-joined) path of the created archive.
    """
    if not archive_filename.startswith("/"):
        archive_filename = os.path.join(work_dir, archive_filename)

    def safe_relpath(path, base):
        """Compute relative path, resolving symlinks to avoid '../' chains."""
        rel = os.path.relpath(os.path.realpath(path), os.path.realpath(base))
        if rel.startswith('..'):
            # Fallback to basename if paths don't share a common prefix
            rel = os.path.basename(path)
        return rel

    with tarfile.open(archive_filename, "w:gz", dereference=True) as tar:
        for local_file in files:
            if os.path.isfile(local_file):
                if is_execluded_file(local_file, exclude_files):
                    continue
                tar.add(local_file, arcname=os.path.basename(local_file))
            elif os.path.isdir(local_file):
                for filename in os.listdir(local_file):
                    if is_execluded_file(filename, exclude_files):
                        continue
                    # Bug fix: test/walk the joined path. The original passed
                    # the bare entry name to isfile/isdir/os.walk, which
                    # resolved it against the CWD instead of local_file.
                    entry_path = os.path.join(local_file, filename)
                    if os.path.isfile(entry_path):
                        tar.add(
                            entry_path, arcname=safe_relpath(entry_path, local_file)
                        )
                    elif os.path.isdir(entry_path):
                        for root, dirs, fs in os.walk(entry_path):
                            for f in fs:
                                if not is_execluded_file(f, exclude_files):
                                    file_path = os.path.join(root, f)
                                    tar.add(
                                        file_path,
                                        arcname=safe_relpath(file_path, local_file),
                                    )
    return archive_filename
1088 
1089 
class SecureString(object):
    """Wrapper that masks a sensitive value when rendered as a string."""

    def __init__(self, value):
        # The real value stays reachable via the private attribute for
        # code that needs it; only str() output is masked.
        self._value = value

    def __str__(self):
        return "****"
1096 
1097 
def is_panda_client_verbose():
    """Return True when $PANDA_CLIENT_VERBOSE is set to "true" (case-insensitive)."""
    value = os.environ.get("PANDA_CLIENT_VERBOSE", None)
    return bool(value) and value.lower() == "true"
1105 
1106 
def get_unique_id_for_dict(dict_):
    """Return a deterministic SHA1 hex digest of *dict_* (JSON, keys sorted)."""
    serialized = json.dumps(dict_, sort_keys=True).encode()
    return hashlib.sha1(serialized).hexdigest()
1111 
1112 
def idds_mask(dict_):
    """Return a copy of *dict_* with values of sensitive-looking keys masked.

    A key is sensitive when it contains any of the markers below as a
    substring (so e.g. "my_token" and "passwd_file" are masked too).
    """
    sensitive_markers = ("pass", "password", "passwd", "token", "security", "secure")
    masked = {}
    for key in dict_:
        if any(marker in key for marker in sensitive_markers):
            masked[key] = "***"
        else:
            masked[key] = dict_[key]
    return masked
1128 
1129 
@contextlib.contextmanager
def modified_environ(*remove, **update):
    """
    Temporarily updates the ``os.environ`` dictionary in-place.
    The ``os.environ`` dictionary is updated in-place so that the modification
    is sure to work in all situations.
    :param remove: Environment variables to remove.
    :param update: Dictionary of environment variables and values to add/update.
    """
    env = os.environ
    update = update or {}
    remove = remove or []

    # Variables we will overwrite or delete that already exist -> restore
    # their original values on exit.
    touched = (set(update) | set(remove)) & set(env)
    restore_values = {name: env[name] for name in touched}
    # Variables we are adding fresh -> delete them on exit.
    added = frozenset(name for name in update if name not in env)

    try:
        env.update(update)
        for name in remove:
            env.pop(name, None)
        yield
    finally:
        env.update(restore_values)
        for name in added:
            env.pop(name)
1157 
1158 
def run_with_timeout(func, args=(), kwargs={}, timeout=None, retries=1):
    """
    Run a function with a timeout, retrying up to *retries* times.

    Parameters:
        func (callable): The function to run.
        args (tuple): The arguments to pass to the function.
        kwargs (dict): The keyword arguments to pass to the function.
        timeout (float or int): The time limit in seconds per attempt.
        retries (int): Number of attempts before giving up.

    Returns:
        The function's result if it finishes within the timeout.

    Raises:
        TimeoutError: if every attempt exceeds *timeout*. (The original
        code *returned* the TimeoutError instance, contradicting its own
        docstring; callers could mistake it for a valid result.)
    """
    for i in range(retries):
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(func, *args, **kwargs)
        try:
            if i > 0:
                logging.info(f"retry {i} to execute function.")
            return future.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
            logging.error(
                f"Function '{func.__name__}' timed out after {timeout} seconds in retry {i}."
            )
        finally:
            # Do not wait for the (possibly still running) worker thread:
            # the original `with ThreadPoolExecutor()` blocked here until
            # the function finished, which defeated the timeout entirely.
            executor.shutdown(wait=False)
    raise TimeoutError(
        f"Function '{func.__name__}' timed out after {timeout} seconds."
    )
1188 
1189 
def timeout_wrapper(timeout, retries=1):
    """
    Decorator to time out a function after a given number of seconds.

    Parameters:
        timeout (int or float): The time limit in seconds per attempt.
        retries (int): Number of attempts before giving up.

    Raises:
        TimeoutError: If every attempt exceeds the time limit. (The
        original code *returned* the TimeoutError instance instead of
        raising it, so callers could mistake it for a valid result.)
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for i in range(retries):
                executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
                future = executor.submit(func, *args, **kwargs)
                try:
                    if i > 0:
                        logging.info(f"retry {i} to execute function.")

                    return future.result(timeout=timeout)
                except concurrent.futures.TimeoutError:
                    logging.error(
                        f"Function '{func.__name__}' timed out after {timeout} seconds in retry {i}."
                    )
                finally:
                    # Don't block waiting for the abandoned worker thread
                    # (the original `with` statement did, defeating the
                    # timeout).
                    executor.shutdown(wait=False)
            raise TimeoutError(
                f"Function '{func.__name__}' timed out after {timeout} seconds."
            )

        return wrapper

    return decorator
1224 
1225 
def get_process_thread_info():
    """
    Returns: short hostname (first label of the FQDN), process id,
    thread id and thread name of the caller.
    """
    short_host = socket.getfqdn().split(".")[0]
    current = threading.current_thread()
    return short_host, os.getpid(), current.ident, current.name
1237 
1238 
def run_command_with_timeout(
    command, timeout=600, stdout=sys.stdout, stderr=sys.stderr
):
    """
    Run a command and monitor its output. Terminate if no output within timeout.

    :param command: command passed to subprocess.Popen (list or string).
    :param timeout: seconds of output *silence* allowed before the process
        is killed via kill_all (not a limit on total runtime).
    :param stdout: stream receiving the child's stdout; must expose a
        ``.buffer`` attribute (e.g. sys.stdout), since raw bytes are written.
    :param stderr: stream receiving the child's stderr (same requirement).
    :returns: the finished subprocess.Popen object.
    """
    # Timestamp of the most recent output from either stream; refreshed by
    # the reader threads below and polled by the main loop.
    last_output_time = time.time()

    def monitor_output(stream, output, timeout):
        # Forward child output line by line and reset the inactivity timer.
        # NOTE(review): the `timeout` parameter is unused here.
        nonlocal last_output_time
        for line in iter(stream.readline, b""):
            output.buffer.write(line)
            output.flush()
            last_output_time = time.time()  # Reset timer on new output

    # Start the process in its own session so kill_all can signal the whole
    # process group.
    process = subprocess.Popen(
        command,
        preexec_fn=os.setsid,  # setpgrp
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    # Start one reader thread per stream so stdout/stderr are drained
    # concurrently (avoids pipe-buffer deadlock).
    stdout_thread = threading.Thread(
        target=monitor_output, args=(process.stdout, stdout, timeout)
    )
    stderr_thread = threading.Thread(
        target=monitor_output, args=(process.stderr, stderr, timeout)
    )
    stdout_thread.start()
    stderr_thread.start()

    # Poll the child and enforce the inactivity timeout.
    while process.poll() is None:
        time_elapsed = time.time() - last_output_time
        if time_elapsed > timeout:
            print(f"No output for {time_elapsed} seconds. Terminating process.")
            kill_all(process)
            break
        time.sleep(10)  # Poll for inactivity every 10 seconds

    # Wait for the process to complete and join the monitoring thread
    stdout_thread.join()
    stderr_thread.join()
    process.wait()
    return process