Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-10 08:38:58

0001 import datetime
0002 import os
0003 import re
0004 import socket
0005 import sys
0006 import traceback
0007 
0008 from pandacommon.pandalogger.PandaLogger import PandaLogger
0009 from pandacommon.pandautils.PandaUtils import naive_utcnow
0010 
0011 from pandajedi.jedicore.MsgWrapper import MsgWrapper
0012 from pandajedi.jedicore.ThreadUtils import ListWithLock, ThreadPool, WorkerThread
0013 
0014 from .WatchDogBase import WatchDogBase
0015 
# module-level logger named after the final component of this module's dotted path
logger = PandaLogger().getLogger(__name__.split(".")[-1])
0017 
0018 
# data locality updater for ATLAS
class AtlasDataLocalityUpdaterWatchDog(WatchDogBase):
    """Watchdog that refreshes and cleans dataset-locality records for the ATLAS VO.

    Each cycle first deletes outdated locality records, then updates the
    locality of input datasets of current tasks using a pool of worker threads.
    """

    # constructor
    def __init__(self, taskBufferIF, ddmIF):
        WatchDogBase.__init__(self, taskBufferIF, ddmIF)
        # unique identifier of this process (<host>-<pid>-dog), used for DB-level process locks
        self.pid = f"{socket.getfqdn().split('.')[0]}-{os.getpid()}-dog"
        self.vo = "atlas"
        # DDM interface bound to the ATLAS VO
        self.ddmIF = ddmIF.getInterface(self.vo)

    # get list-with-lock of datasets to update
    def get_datasets_list(self):
        """Fetch input datasets of tasks and wrap them in a thread-safe list.

        Returns:
            ListWithLock: items consumed concurrently by updater worker threads.
        """
        datasets_list = self.taskBufferIF.get_tasks_inputdatasets_JEDI(self.vo)
        # wrap so that multiple DataLocalityUpdaterThread workers can pop items safely
        return ListWithLock(datasets_list)

    # update data locality records to DB table
    def doUpdateDataLocality(self):
        """Update data-locality records with a pool of worker threads.

        Does nothing when another process already holds the DB-level lock.
        Exceptions are logged and swallowed so the watchdog cycle never dies.
        """
        tmpLog = MsgWrapper(logger, " #ATM #KV doUpdateDataLocality")
        tmpLog.debug("start")
        try:
            # acquire the cross-process lock; only one updater may run at a time
            got_lock = self.taskBufferIF.lockProcess_JEDI(
                vo=self.vo,
                prodSourceLabel="default",
                cloud=None,
                workqueue_id=None,
                resource_name=None,
                component="AtlasDataLocaUpdDog.doUpdateDataLoca",
                pid=self.pid,
                timeLimit=240,
            )
            if not got_lock:
                tmpLog.debug("locked by another process. Skipped")
                return
            tmpLog.debug("got lock")
            # get list of datasets
            datasets_list = self.get_datasets_list()
            tmpLog.debug(f"got {len(datasets_list)} datasets to update")
            # make thread pool
            thread_pool = ThreadPool()
            # make workers; they all drain the same ListWithLock
            n_workers = 8
            for _ in range(n_workers):
                thr = DataLocalityUpdaterThread(
                    taskDsList=datasets_list, threadPool=thread_pool, taskbufferIF=self.taskBufferIF, ddmIF=self.ddmIF, pid=self.pid, loggerObj=tmpLog
                )
                thr.start()
            tmpLog.debug(f"started {n_workers} updater workers")
            # wait for all workers to finish
            thread_pool.join()
            # done
            tmpLog.debug("done")
        except Exception as e:
            # broad catch: a watchdog method must not propagate; log full traceback
            tmpLog.error(f"failed with {type(e)} {e} {traceback.format_exc()}")

    # clean up old data locality records in DB table
    def doCleanDataLocality(self):
        """Delete locality records older than the configured lifetime.

        Does nothing when another process already holds the DB-level lock.
        Exceptions are logged and swallowed so the watchdog cycle never dies.
        """
        tmpLog = MsgWrapper(logger, " #ATM #KV doCleanDataLocality")
        tmpLog.debug("start")
        try:
            # acquire the cross-process lock; only one cleaner may run at a time
            got_lock = self.taskBufferIF.lockProcess_JEDI(
                vo=self.vo,
                prodSourceLabel="default",
                cloud=None,
                workqueue_id=None,
                resource_name=None,
                component="AtlasDataLocaUpdDog.doCleanDataLoca",
                pid=self.pid,
                timeLimit=1440,
            )
            if not got_lock:
                tmpLog.debug("locked by another process. Skipped")
                return
            tmpLog.debug("got lock")
            # lifetime of records
            record_lifetime_hours = 24
            # delete everything last updated before the cutoff timestamp
            now_timestamp = naive_utcnow()
            before_timestamp = now_timestamp - datetime.timedelta(hours=record_lifetime_hours)
            n_rows = self.taskBufferIF.deleteOutdatedDatasetLocality_JEDI(before_timestamp)
            tmpLog.info(f"cleaned up {n_rows} records")
            # done
            tmpLog.debug("done")
        except Exception as e:
            # broad catch: a watchdog method must not propagate; log full traceback
            tmpLog.error(f"failed with {type(e)} {e} {traceback.format_exc()}")

    # main
    def doAction(self):
        """Entry point called by the watchdog framework: clean, then update.

        Returns:
            Watchdog status code; always SC_SUCCEEDED (errors are only logged).
        """
        # create the logger BEFORE the try block so the except clause and the
        # trailing debug call can never hit a NameError on origTmpLog
        origTmpLog = MsgWrapper(logger)
        try:
            origTmpLog.debug("start")
            # clean up data locality
            self.doCleanDataLocality()
            # update data locality
            self.doUpdateDataLocality()
        except Exception as e:
            origTmpLog.error(f"failed with {type(e)} {e}")
        # return
        origTmpLog.debug("done")
        return self.SC_SUCCEEDED
0125 
0126 
# thread for data locality update
class DataLocalityUpdaterThread(WorkerThread):
    """Worker thread that consumes (jediTaskID, datasetID, datasetName) items
    from a shared ListWithLock and updates the dataset-locality DB table for
    each acceptable replica.
    """

    # constructor
    def __init__(self, taskDsList, threadPool, taskbufferIF, ddmIF, pid, loggerObj):
        # initialize worker with no semaphore
        WorkerThread.__init__(self, None, threadPool, loggerObj)
        # attributes
        self.taskDsList = taskDsList  # ListWithLock shared among all workers
        self.taskBufferIF = taskbufferIF
        self.ddmIF = ddmIF
        self.msgType = "datalocalityupdate"
        self.pid = pid
        self.logger = loggerObj

    # main
    def runImpl(self):
        """Drain the shared dataset list in chunks, updating locality records.

        The worker exits when the shared list is empty, or on the first
        unexpected exception (logged with traceback; other workers continue).
        """
        # counters for the summary log line emitted at termination
        n_updated_ds = 0
        n_skipped_ds = 0
        n_updated_replicas = 0
        n_skipped_replicas = 0
        # pattern extracting the RSE type from a rule's RSE expression;
        # compiled once here instead of on every rule of every dataset
        rse_type_pattern = re.compile(r"type=([^)]+)")
        while True:
            try:
                # take a small chunk of datasets from the shared list
                nDatasets = 5
                taskDsList = self.taskDsList.get(nDatasets)
                if len(taskDsList) == 0:
                    # no more datasets, quit
                    self.logger.debug(
                        f"{self.name} terminating since no more items; updated {n_updated_ds} datasets and {n_updated_replicas} replicas; skipped {n_skipped_ds} datasets and {n_skipped_replicas} replicas"
                    )
                    return
                # loop over these datasets
                for item in taskDsList:
                    if item is None:
                        n_skipped_ds += 1
                        continue
                    jedi_task_id, dataset_id, dataset_name = item
                    _, task_spec = self.taskBufferIF.getTaskWithID_JEDI(jedi_task_id)
                    dataset_replicas_map = self.ddmIF.listDatasetReplicas(dataset_name)
                    is_distributed_ds = self.ddmIF.isDistributedDataset(dataset_name)
                    # get rules when using data carousel
                    rule_rse_list = []
                    rule_rse_types = []
                    if task_spec.inputPreStaging():
                        # collect rse expressions from rules
                        _, tmp_rules = self.ddmIF.get_rules_state(dataset_name)
                        rule_rse_list = [r["rse_expression"] for r in tmp_rules.values()]
                        # extract rse types from rse expressions
                        rule_rse_types = [m.group(1) for m in map(rse_type_pattern.search, rule_rse_list) if m]
                    # loop over all replicas
                    for tmp_rse, tmp_stat_list in dataset_replicas_map.items():
                        # pre-checks
                        if is_distributed_ds:
                            # no checks for distributed datasets
                            pass
                        elif task_spec.inputPreStaging():
                            # use only replicas with rules when using data carousel
                            if tmp_rse not in rule_rse_list:
                                # keep the replica only if its RSE name contains any rule's RSE type
                                to_skip = not any(rse_type in tmp_rse for rse_type in rule_rse_types)
                                if to_skip:
                                    n_skipped_replicas += 1
                                    self.logger.debug(
                                        f"skipped {tmp_rse} for dataset {dataset_name} due to missing rule in {rule_rse_list} or rse type in {rule_rse_types}"
                                    )
                                    continue
                        else:
                            # use only complete replicas unless input is distributed or uses data carousel
                            tmp_statistics = tmp_stat_list[-1]
                            # skip unknown and incomplete
                            if tmp_statistics["found"] is None or tmp_statistics["found"] != tmp_statistics["total"]:
                                n_skipped_replicas += 1
                                continue
                        # update dataset locality table
                        self.taskBufferIF.updateDatasetLocality_JEDI(jedi_taskid=jedi_task_id, datasetid=dataset_id, rse=tmp_rse)
                        n_updated_replicas += 1
                    n_updated_ds += 1
            except Exception as e:
                # terminate this worker on any unexpected error; siblings keep running
                self.logger.error(f"{self.__class__.__name__} failed in runImpl() with {str(e)}: {traceback.format_exc()}")
                return