Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-10 08:38:58

0001 import os
0002 import socket
0003 import sys
0004 import traceback
0005 
0006 from pandacommon.pandalogger.PandaLogger import PandaLogger
0007 
0008 from pandajedi.jedibrokerage import AtlasBrokerUtils
0009 from pandajedi.jediconfig import jedi_config
0010 from pandajedi.jedicore.MsgWrapper import MsgWrapper
0011 from pandaserver.dataservice import DataServiceUtils
0012 
0013 from .WatchDogBase import WatchDogBase
0014 
# module-level logger, named after the last component of this module's dotted import path
logger = PandaLogger().getLogger(__name__.rpartition(".")[2])
0016 
0017 
0018 # task withholder watchdog for ATLAS
class AtlasTaskWithholderWatchDog(WatchDogBase):
    """Task withholder watchdog for ATLAS.

    On each cycle, tasks are made pending when all of the following hold (see the
    SQL in ``do_for_data_locality``):
    the task is in ready/running/scouting and not locked, belongs to the global share
    being processed, has ioIntensity >= the configured threshold and currentPriority
    < the configured threshold, has at least one input dataset, and has no
    JEDI_Dataset_Locality entry at an RSE outside the excluded set.
    The excluded set is the union of blacklisted input RSEs and the RSEs of busy sites.
    """

    # Oracle allows at most 1000 expressions in an IN list (ORA-01795). Longer RSE
    # exclusion lists are split into several "NOT IN" groups joined with AND, which
    # is logically equivalent to a single NOT IN over the whole list.
    _MAX_IN_LIST_SIZE = 1000

    # constructor
    def __init__(self, taskBufferIF, ddmIF):
        WatchDogBase.__init__(self, taskBufferIF, ddmIF)
        # identifier used when taking the process lock: <short hostname>-<pid>-dog
        self.pid = f"{socket.getfqdn().split('.')[0]}-{os.getpid()}-dog"
        self.vo = "atlas"
        self.prodSourceLabelList = ["managed"]
        # load work queue and site information
        self.refresh()

    def _get_lock(self):
        """Try to take the JEDI process lock for this component.

        :return: the value returned by ``taskBufferIF.lockProcess_JEDI``; ``doAction``
            treats a falsy value as "locked by another process"
        """
        return self.taskBufferIF.lockProcess_JEDI(
            vo=self.vo,
            prodSourceLabel="managed",
            cloud=None,
            workqueue_id=None,
            resource_name=None,
            component="AtlasTaskWithholderWatchDog",
            pid=self.pid,
            timeLimit=5,
        )

    def refresh(self):
        """Reload the work queue map, the site mapper and the list of all site names."""
        self.workQueueMapper = self.taskBufferIF.getWorkQueueMap()
        self.siteMapper = self.taskBufferIF.get_site_mapper()
        # every site known to the site mapper is considered; no filtering by site type
        self.allSiteList = list(self.siteMapper.siteSpecList)

    def get_site_rse_map_and_blacklisted_rse_set(self, prod_source_label):
        """Map each unified site name to its input RSEs and collect blacklisted input RSEs.

        :param prod_source_label: label used to select the input scope of each site
        :return: tuple ``(site_rse_map, blacklisted_rse_set)`` where ``site_rse_map`` maps
            unified site name -> list of values of the site's input token map, and
            ``blacklisted_rse_set`` holds the input endpoint names flagged ``blacklisted == "Y"``
        """
        site_rse_map = {}
        blacklisted_rse_set = set()
        for pseudo_site_name in self.allSiteList:
            site_spec = self.siteMapper.getSite(pseudo_site_name)
            site_name = site_spec.get_unified_name()
            scope_input, _ = DataServiceUtils.select_scope(site_spec, prod_source_label, prod_source_label)
            try:
                ddm_spec = site_spec.ddm_endpoints_input[scope_input]
                endpoint_name = site_spec.ddm_input[scope_input]
                endpoint_token_map = ddm_spec.getTokenMap("input")
                endpoint = ddm_spec.getEndPoint(endpoint_name)
            except KeyError:
                # no input endpoint configured for this scope: skip the site
                continue
            site_rse_map[site_name] = list(endpoint_token_map.values())
            # .get() so an endpoint record without the flag does not abort the whole cycle
            if endpoint is not None and endpoint.get("blacklisted") == "Y":
                blacklisted_rse_set.add(endpoint_name)
        return site_rse_map, blacklisted_rse_set

    def _get_job_statistics(self):
        """Return the job statistics map by global share for this VO, or None if the query failed."""
        tmp_st, job_stat_prio_map = self.taskBufferIF.getJobStatisticsByGlobalShare(self.vo)
        if not tmp_st:
            return None
        return job_stat_prio_map

    def get_busy_sites(self, gshare, cutoff, job_stat_prio_map=None):
        """Return unified names of sites that are busy for a global share.

        A site is busy when nQueue > max(cutoff, 2 * nRunning). nQueue is the number of
        jobs in defined/assigned/activated/starting; nRunning is the number of running jobs.

        :param gshare: global share name passed as ``workQueue_tag`` to ``getNumJobs``
        :param cutoff: minimum queue size for a site to be considered busy
        :param job_stat_prio_map: pre-fetched job statistics; when None they are queried here
        :return: list of unified site names (each at most once); empty if statistics are unavailable
        """
        busy_sites_list = []
        if job_stat_prio_map is None:
            job_stat_prio_map = self._get_job_statistics()
            if job_stat_prio_map is None:
                # got nothing...
                return busy_sites_list
        checked_sites = set()
        for pseudo_site_name in self.allSiteList:
            site_name = self.siteMapper.getSite(pseudo_site_name).get_unified_name()
            # several pseudo sites can share one unified name; counts are per unified name
            if site_name in checked_sites:
                continue
            checked_sites.add(site_name)
            n_running = AtlasBrokerUtils.getNumJobs(job_stat_prio_map, site_name, "running", workQueue_tag=gshare)
            n_queue = sum(
                AtlasBrokerUtils.getNumJobs(job_stat_prio_map, site_name, job_status, workQueue_tag=gshare)
                for job_status in ("defined", "assigned", "activated", "starting")
            )
            if n_queue > max(cutoff, n_running * 2):
                busy_sites_list.append(site_name)
        return busy_sites_list

    def _get_limits(self, prod_source_label):
        """Read the task_withholder thresholds for a production source label.

        :return: tuple ``(upplimit_ioIntensity, lowlimit_currentPriority)``. Tasks are
            candidates when ioIntensity >= the first value and currentPriority < the second.
        """
        upplimit_ioIntensity = self.taskBufferIF.getConfigValue("task_withholder", f"LIMIT_IOINTENSITY_{prod_source_label}", "jedi", self.vo)
        lowlimit_currentPriority = self.taskBufferIF.getConfigValue("task_withholder", f"LIMIT_PRIORITY_{prod_source_label}", "jedi", self.vo)
        if upplimit_ioIntensity is None:
            upplimit_ioIntensity = 999999
        if lowlimit_currentPriority is None:
            lowlimit_currentPriority = -999999
        # the ioIntensity threshold is never allowed below 100
        upplimit_ioIntensity = max(upplimit_ioIntensity, 100)
        return upplimit_ioIntensity, lowlimit_currentPriority

    @classmethod
    def _make_rse_not_in_condition(cls, rses):
        """Build the ``dl.rse NOT IN (...)`` condition with bind variables.

        Bind variables are named ``:rse_1`` ... ``:rse_N`` in list order. Lists longer than
        ``_MAX_IN_LIST_SIZE`` are split into several NOT IN groups joined with AND.

        :param rses: non-empty list of RSE names to exclude
        :return: tuple ``(condition_sql, params_map)``
        """
        params_map = {}
        groups = []
        for start in range(0, len(rses), cls._MAX_IN_LIST_SIZE):
            bind_names = []
            for idx, rse in enumerate(rses[start : start + cls._MAX_IN_LIST_SIZE], start=start + 1):
                bind_name = f":rse_{idx}"
                bind_names.append(bind_name)
                params_map[bind_name] = rse
            groups.append(f"dl.rse NOT IN ({','.join(bind_names)})")
        return " AND ".join(groups), params_map

    def do_for_data_locality(self, dry_run=False):
        """Set tasks pending when their input data is only at unusable RSEs.

        For each production source label and each aligned work queue (global share), RSEs
        that are blacklisted or belong to busy sites are excluded. Matching tasks are then
        passed to ``taskBufferIF.queryTasksToBePending_JEDI``.

        :param dry_run: if True, only run the selection query and log the tasks that would
            be made pending, without changing anything
        """
        tmp_log = MsgWrapper(logger)
        self.refresh()
        # job statistics do not depend on the share or the label: fetch them once per cycle
        job_stat_prio_map = self._get_job_statistics()
        for prod_source_label in self.prodSourceLabelList:
            site_rse_map, blacklisted_rse_set = self.get_site_rse_map_and_blacklisted_rse_set(prod_source_label)
            tmp_log.debug(f"Found {len(blacklisted_rse_set)} blacklisted RSEs : {','.join(sorted(blacklisted_rse_set))}")
            upplimit_ioIntensity, lowlimit_currentPriority = self._get_limits(prod_source_label)
            work_queue_list = self.workQueueMapper.getAlignedQueueList(self.vo, prod_source_label)
            for work_queue in work_queue_list:
                gshare = work_queue.queue_name
                # queue-size cutoff for busy sites; unset or 0 falls back to 20
                cutoff = self.taskBufferIF.getConfigValue("jobbroker", f"NQUEUELIMITSITE_{gshare}", "jedi", self.vo)
                if not cutoff:
                    cutoff = 20
                if job_stat_prio_map is None:
                    # statistics unavailable: no site is considered busy
                    busy_sites_list = []
                else:
                    busy_sites_list = self.get_busy_sites(gshare, cutoff, job_stat_prio_map=job_stat_prio_map)
                busy_rses = set()
                for site in busy_sites_list:
                    busy_rses.update(site_rse_map.get(site, ()))
                to_exclude_rses = list(busy_rses | blacklisted_rse_set)
                if not to_exclude_rses:
                    continue
                rse_condition, rse_params_map = self._make_rse_not_in_condition(to_exclude_rses)
                # candidate tasks have input datasets but no locality entry outside the excluded RSEs
                sql_query = (
                    "SELECT t.jediTaskID "
                    "FROM {jedi_schema}.JEDI_Tasks t "
                    "WHERE t.status IN ('ready','running','scouting') AND t.lockedBy IS NULL "
                    "AND t.gshare=:gshare "
                    "AND t.ioIntensity>=:ioIntensity AND t.currentPriority<:currentPriority "
                    "AND EXISTS ( "
                    "SELECT * FROM {jedi_schema}.JEDI_Datasets d "
                    "WHERE d.jediTaskID=t.jediTaskID "
                    "AND d.type='input' "
                    ") "
                    "AND NOT EXISTS ( "
                    "SELECT * FROM {jedi_schema}.JEDI_Dataset_Locality dl "
                    "WHERE dl.jediTaskID=t.jediTaskID "
                    "AND {rse_condition} "
                    ") "
                ).format(jedi_schema=jedi_config.db.schemaJEDI, rse_condition=rse_condition)
                params_map = {
                    ":gshare": gshare,
                    ":ioIntensity": upplimit_ioIntensity,
                    ":currentPriority": lowlimit_currentPriority,
                }
                params_map.update(rse_params_map)
                reason = (
                    "no local input data, ioIntensity>={ioIntensity}, currentPriority<{currentPriority},"
                    "nQueue>max({cutOff},nRunning*2) at all sites where the task can run".format(
                        ioIntensity=upplimit_ioIntensity, currentPriority=lowlimit_currentPriority, cutOff=cutoff
                    )
                )
                if dry_run:
                    res = self.taskBufferIF.querySQL(sql_query, params_map)
                    n_tasks = 0 if res is None else len(res)
                    if n_tasks > 0:
                        result = [x[0] for x in res]
                        tmp_log.debug(f'[dry run] gshare: {gshare:<16} {n_tasks:>5} tasks would be pending : {result} ; reason="{reason}" ')
                else:
                    n_tasks = self.taskBufferIF.queryTasksToBePending_JEDI(sql_query, params_map, reason)
                    if n_tasks is not None and n_tasks > 0:
                        tmp_log.info(f'gshare: {gshare:<16} {str(n_tasks):>5} tasks got pending ; reason="{reason}" ')

    # main
    def doAction(self):
        """Run one watchdog cycle: take the process lock, then withhold tasks by data locality.

        Exceptions raised during the cycle are logged, not propagated.

        :return: ``self.SC_SUCCEEDED`` in all handled cases
        """
        # create the logger before the try so the except handler can always use it
        orig_tmp_log = MsgWrapper(logger)
        orig_tmp_log.debug("start")
        try:
            got_lock = self._get_lock()
            if not got_lock:
                orig_tmp_log.debug("locked by another process. Skipped")
                return self.SC_SUCCEEDED
            orig_tmp_log.debug("got lock")
            # make tasks pending under certain conditions
            self.do_for_data_locality()
        except Exception:
            errtype, errvalue = sys.exc_info()[:2]
            err_str = traceback.format_exc()
            orig_tmp_log.error(f"failed with {errtype} {errvalue} ; {err_str}")
        orig_tmp_log.debug("done")
        return self.SC_SUCCEEDED