Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-19 08:00:05

0001 import os
0002 import re
0003 import stat
0004 import subprocess
0005 import tempfile
0006 from math import ceil
0007 
0008 from pandaharvester.harvesterconfig import harvester_config
0009 from pandaharvester.harvestercore import core_utils
0010 from pandaharvester.harvestercore.plugin_base import PluginBase
0011 from pandaharvester.harvestercore.resource_type_mapper import ResourceTypeMapper
0012 from pandaharvester.harvestermisc.info_utils import PandaQueuesDict
0013 from pandaharvester.harvestersubmitter import submitter_common
0014 
0015 # logger
0016 baseLogger = core_utils.setup_logger("slurm_submitter_rubin")
0017 
0018 
0019 # submitter for SLURM batch system
0020 class SlurmSubmitter(PluginBase):
0021     # constructor
0022     def __init__(self, **kwarg):
0023         self.uploadLog = False
0024         self.logBaseURL = None
0025         PluginBase.__init__(self, **kwarg)
0026         if not hasattr(self, "localQueueName"):
0027             self.localQueueName = "grid"
0028         # ncore factor
0029         try:
0030             if hasattr(self, "nCoreFactor"):
0031                 if type(self.nCoreFactor) in [dict]:
0032                     # self.nCoreFactor is a dict for ucore
0033                     # self.nCoreFactor = self.nCoreFactor
0034                     pass
0035                 else:
0036                     self.nCoreFactor = int(self.nCoreFactor)
0037                     if (not self.nCoreFactor) or (self.nCoreFactor < 1):
0038                         self.nCoreFactor = 1
0039             else:
0040                 self.nCoreFactor = 1
0041         except AttributeError:
0042             self.nCoreFactor = 1
0043 
0044         try:
0045             self.checkPartition = bool(self.checkPartition)
0046         except AttributeError:
0047             self.checkPartition = False
0048         # num workers to check the partition
0049         try:
0050             self.nWorkersToCheckPartition = int(self.nWorkersToCheckPartition)
0051         except AttributeError:
0052             self.nWorkersToCheckPartition = 10
0053         if (not self.nWorkersToCheckPartition) or (self.nWorkersToCheckPartition < 1):
0054             self.nWorkersToCheckPartition = 1
0055         # partition configuration
0056         try:
0057             if not isinstance(self.partitions, (list, tuple)):
0058                 self.partitions = self.partitions.split(",")
0059             self.partitions = [p.strip() for p in self.partitions]
0060         except AttributeError:
0061             self.partitions = None
0062 
0063     def get_queued_jobs(self, partition, logger):
0064         status = False
0065         num_pending_jobs = 0
0066 
0067         username = os.getlogin()
0068         # command = f"squeue -u {username} --partition={partition} | grep -e PD -e CF"
0069         command = f"squeue -u {username} --partition={partition}"
0070         logger.debug(f"check jobs in partition {partition} command: {command.split()}")
0071         p = subprocess.Popen(command.split(), shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
0072         stdout, stderr = p.communicate()
0073         ret_code = p.returncode
0074 
0075         if ret_code == 0:
0076             stdout_str = stdout if (isinstance(stdout, str) or stdout is None) else stdout.decode()
0077             # stderr_str = stderr if (isinstance(stderr, str) or stderr is None) else stderr.decode()
0078 
0079             num_pending_jobs = 0
0080             for line in stdout_str.split("\n"):
0081                 if len(line) == 0 or line.startswith("JobID") or line.startswith("--"):
0082                     continue
0083 
0084                 batch_status = line.split()[4].strip()
0085                 if batch_status in ["CF", "PD"]:
0086                     num_pending_jobs += 1
0087             logger.debug(f"number of pending jobs in partition {partition} with user {username}: {num_pending_jobs}")
0088             status = True
0089         else:
0090             logger.error(f"returncode: {ret_code}, stdout: {stdout}, stderr: {stderr}")
0091 
0092         return status, num_pending_jobs
0093 
0094     # get partition
0095     def get_partition(self, logger):
0096         if not self.partitions:
0097             return None
0098 
0099         logger.debug(f"partitions: {self.partitions}")
0100         if not self.checkPartition:
0101             return self.partitions[0]
0102 
0103         num_pending_by_partition = {}
0104         for partition in self.partitions:
0105             status, num_pending_jobs = self.get_queued_jobs(partition, logger)
0106             if status:
0107                 num_pending_by_partition[partition] = num_pending_jobs
0108         logger.debug(f"num_pending_by_partition: {num_pending_by_partition}")
0109 
0110         sorted_num_pending = dict(sorted(num_pending_by_partition.items(), key=lambda item: item[1]))
0111         if sorted_num_pending:
0112             selected_partition = list(sorted_num_pending.keys())[0]
0113             return selected_partition
0114         return None
0115 
0116     def get_core_factor(self, workspec, is_unified_queue, logger):
0117         try:
0118             if type(self.nCoreFactor) in [dict]:
0119                 if workspec.jobType in self.nCoreFactor:
0120                     job_type = workspec.jobType
0121                 else:
0122                     job_type = "Any"
0123                 if is_unified_queue:
0124                     resource_type = workspec.resourceType
0125                 else:
0126                     resource_type = "Undefined"
0127                 n_core_factor = self.nCoreFactor.get(job_type, {}).get(resource_type, 1)
0128                 return int(n_core_factor)
0129             else:
0130                 return int(self.nCoreFactor)
0131         except Exception as ex:
0132             logger.warning(f"Failed to get core factor: {ex}")
0133         return 1
0134 
0135     # submit workers
0136     def submit_workers(self, workspec_list):
0137         tmpLog = self.make_logger(baseLogger, f"site={self.queueName}", method_name="submit_workers")
0138 
0139         # get info from harvester queue config
0140         # _queueConfigMapper = QueueConfigMapper()
0141         # harvester_queue_config = _queueConfigMapper.get_queue(self.queueName)
0142 
0143         # get the queue configuration from CRIC
0144         panda_queues_dict = PandaQueuesDict()
0145         this_panda_queue_dict = panda_queues_dict.get(self.queueName, {})
0146         # associated_params_dict = panda_queues_dict.get_harvester_params(self.queueName)
0147 
0148         retList = []
0149         num_workSpec = 0
0150         for workSpec in workspec_list:
0151             # make logger
0152             tmpLog = self.make_logger(
0153                 baseLogger, f"site={self.queueName} workerID={workSpec.workerID} resourceType={workSpec.resourceType}", method_name="submit_workers"
0154             )
0155             # set nCore
0156             if self.nCore > 0:
0157                 workSpec.nCore = self.nCore
0158             if self.checkPartition:
0159                 if num_workSpec % self.nWorkersToCheckPartition == 0:
0160                     partition = self.get_partition(tmpLog)
0161                     num_workSpec += 1
0162             else:
0163                 partition = self.get_partition(tmpLog)
0164 
0165             # make batch script
0166             batchFile = self.make_batch_script(workSpec, partition, this_panda_queue_dict, tmpLog)
0167             # command
0168             comStr = f"sbatch --exclusive=user -D {workSpec.get_access_point()} {batchFile}"
0169             # submit
0170             tmpLog.debug(f"submit with {batchFile}")
0171             p = subprocess.Popen(comStr.split(), shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
0172             # check return code
0173             stdOut, stdErr = p.communicate()
0174             retCode = p.returncode
0175             tmpLog.debug(f"retCode={retCode}")
0176             stdOut_str = stdOut if (isinstance(stdOut, str) or stdOut is None) else stdOut.decode()
0177             stdErr_str = stdErr if (isinstance(stdErr, str) or stdErr is None) else stdErr.decode()
0178             if retCode == 0:
0179                 # extract batchID
0180                 workSpec.batchID = re.search("[^0-9]*([0-9]+)[^0-9]*$", f"{stdOut_str}").group(1)
0181                 tmpLog.debug(f"batchID={workSpec.batchID}")
0182                 # set log files
0183                 if self.logBaseURL and self.logDir:
0184                     stdOut, stdErr = self.get_log_file_names(workSpec.accessPoint, workSpec.workerID)
0185                     rel_stdOut = os.path.relpath(stdOut, self.logDir)
0186                     rel_stdErr = os.path.relpath(stdErr, self.logDir)
0187                     log_stdOut = os.path.join(self.logBaseURL, rel_stdOut)
0188                     log_stdErr = os.path.join(self.logBaseURL, rel_stdErr)
0189                     workSpec.set_log_file("stdout", log_stdOut)
0190                     workSpec.set_log_file("stderr", log_stdErr)
0191                 tmpRetVal = (True, "")
0192             else:
0193                 # failed
0194                 errStr = f"{stdOut_str} {stdErr_str}"
0195                 tmpLog.error(errStr)
0196                 tmpRetVal = (False, errStr)
0197             retList.append(tmpRetVal)
0198         return retList
0199 
    def make_placeholder_map(self, workspec, partition, this_panda_queue_dict, logger):
        """Build the placeholder dictionary used to render the sbatch template.

        :param workspec: worker specification (nCore, minRamCount, maxDiskCount, ...)
        :param partition: SLURM partition name chosen for this worker (may be None)
        :param this_panda_queue_dict: queue configuration dict from CRIC
        :param logger: logger instance
        :return: dict mapping placeholder names to values for str.format_map
        """
        timeNow = core_utils.naive_utcnow()

        # get default information from queue info
        # ("corecount" may be present but falsy, so fall back to 1 in that case too)
        n_core_per_node_from_queue = this_panda_queue_dict.get("corecount", 1) if this_panda_queue_dict.get("corecount", 1) else 1

        is_unified_queue = this_panda_queue_dict.get("capability", "") == "ucore"
        special_par = this_panda_queue_dict.get("special_par", "")

        # plugin config (nCorePerNode) overrides the CRIC corecount; fall back
        # to the plugin's nCore when both are falsy
        n_core_per_node = getattr(self, "nCorePerNode", n_core_per_node_from_queue)
        if not n_core_per_node:
            n_core_per_node = self.nCore

        n_core_factor = self.get_core_factor(workspec, is_unified_queue, logger)

        logger.debug(f"workspec.nCore: {workspec.nCore}, n_core_per_node: {n_core_per_node}")
        logger.debug(f"workspec.minRamCount: {workspec.minRamCount}")

        n_core_total = workspec.nCore if workspec.nCore else n_core_per_node
        # at least 1 RAM unit per core; minRamCount presumably in MiB given the
        # 2**20 conversion below -- TODO confirm
        request_ram = max(workspec.minRamCount, 1 * n_core_total) if workspec.minRamCount else 1 * n_core_total
        logger.debug(f"n_core_total: {n_core_total}, request_ram: {request_ram}")

        # maxDiskCount scaled by 1024; presumably MB -> KB -- TODO confirm units
        request_disk = workspec.maxDiskCount * 1024 if workspec.maxDiskCount else 1
        # walltime in seconds (divided by 60/3600 below); 0 when unset
        request_walltime = workspec.maxWalltime if workspec.maxWalltime else 0

        ce_queue_name = None

        # possible override by CRIC special_par, e.g. "(queue=foo)(maxWallTime=3600)(xcount=8)"
        if special_par:
            special_par_attr_list = [
                "queue",
                "maxWallTime",
                "xcount",
            ]
            _match_special_par_dict = {attr: re.search(f"\\({attr}=([^)]+)\\)", special_par) for attr in special_par_attr_list}
            for attr, _match in _match_special_par_dict.items():
                if not _match:
                    continue
                elif attr == "queue":
                    ce_queue_name = str(_match.group(1))
                elif attr == "maxWallTime":
                    request_walltime = int(_match.group(1))
                elif attr == "xcount":
                    n_core_total = int(_match.group(1))
                logger.debug(f"job attributes override by CRIC special_par: {attr}={str(_match.group(1))}")

        # number of nodes: explicit plugin config wins, otherwise derived from cores
        n_node = getattr(self, "nNode", None)
        if not n_node:
            n_node = ceil(n_core_total / n_core_per_node)

        # "factor"-scaled quantities multiply by n_core_factor
        n_core_total_factor = n_core_total * n_core_factor
        request_ram_factor = request_ram * n_core_factor
        request_ram_bytes = request_ram * 2**20
        request_ram_bytes_factor = request_ram * 2**20 * n_core_factor
        # NOTE(review): per-core values divide the UNfactored totals -- confirm intended
        request_ram_per_core = ceil(request_ram / n_core_total)
        request_ram_bytes_per_core = ceil(request_ram_bytes / n_core_total)
        request_cputime = request_walltime * n_core_total
        request_walltime_minute = ceil(request_walltime / 60)
        request_walltime_hour = ceil(request_walltime / 3600)
        request_cputime_minute = ceil(request_cputime / 60)

        # instance of resource type mapper
        rt_mapper = ResourceTypeMapper()
        all_resource_types = rt_mapper.get_all_resource_types()

        # GTAG: URL of the worker's stdout log, used as the pilot's log tag
        if self.logBaseURL and self.logDir:
            stdOut, stdErr = self.get_log_file_names(workspec.accessPoint, workspec.workerID)
            rel_stdOut = os.path.relpath(stdOut, self.logDir)
            log_stdOut = os.path.join(self.logBaseURL, rel_stdOut)
            gtag = log_stdOut
        else:
            gtag = "unknown"

        placeholder_map = {
            "nCorePerNode": n_core_per_node,
            "nCoreTotal": n_core_total_factor,
            "nCoreFactor": n_core_factor,
            "nNode": n_node,
            "requestRam": request_ram_factor,
            "requestRamBytes": request_ram_bytes_factor,
            "requestRamPerCore": request_ram_per_core,
            "requestRamBytesPerCore": request_ram_bytes_per_core,
            "requestDisk": request_disk,
            "requestWalltime": request_walltime,
            "requestWalltimeMinute": request_walltime_minute,
            "requestWalltimeHour": request_walltime_hour,
            "requestCputime": request_cputime,
            "requestCputimeMinute": request_cputime_minute,
            "accessPoint": workspec.accessPoint,
            "harvesterID": harvester_config.master.harvester_id,
            "workerID": workspec.workerID,
            "computingSite": workspec.computingSite,
            "pandaQueueName": self.queueName,
            "localQueueName": self.localQueueName,
            "ceQueueName": ce_queue_name,
            # 'x509UserProxy': x509_user_proxy,
            "logDir": self.logDir,
            "logSubDir": os.path.join(self.logDir, timeNow.strftime("%y-%m-%d_%H")),
            "jobType": workspec.jobType,
            "gtag": gtag,
            "partition": partition,
            "resourceType": submitter_common.get_resource_type(workspec.resourceType, is_unified_queue, all_resource_types),
            "pilotResourceTypeOption": submitter_common.get_resource_type(workspec.resourceType, is_unified_queue, all_resource_types, is_pilot_option=True),
        }
        # optional token/submit-mode attributes: only included when configured
        for k in ["tokenDir", "tokenName", "tokenOrigin", "submitMode"]:
            try:
                placeholder_map[k] = getattr(self, k)
            except Exception:
                pass
        return placeholder_map
0311 
0312     # make batch script
0313     def make_batch_script(self, workspec, partition, this_panda_queue_dict, logger):
0314         # template for batch script
0315         with open(self.templateFile) as f:
0316             template = f.read()
0317         tmpFile = tempfile.NamedTemporaryFile(delete=False, suffix="_submit.sh", dir=workspec.get_access_point())
0318         placeholder = self.make_placeholder_map(workspec, partition, this_panda_queue_dict, logger)
0319         tmpFile.write(str(template.format_map(core_utils.SafeDict(placeholder))).encode("latin_1"))
0320         tmpFile.close()
0321 
0322         # set execution bit and group permissions on the temp file
0323         st = os.stat(tmpFile.name)
0324         os.chmod(tmpFile.name, st.st_mode | stat.S_IEXEC | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH)
0325 
0326         return tmpFile.name
0327 
0328     # get log file names
0329     def get_log_file_names_old(self, batch_script, batch_id):
0330         stdOut = None
0331         stdErr = None
0332         with open(batch_script) as f:
0333             for line in f:
0334                 if not line.startswith("#SBATCH"):
0335                     continue
0336                 items = line.split()
0337                 if "-o" in items:
0338                     stdOut = items[-1].replace("$SLURM_JOB_ID", batch_id)
0339                 elif "-e" in items:
0340                     stdErr = items[-1].replace("$SLURM_JOB_ID", batch_id)
0341         return stdOut, stdErr
0342 
0343     def get_log_file_names(self, access_point, worker_id):
0344         stdOut = os.path.join(access_point, f"{worker_id}.out")
0345         stdErr = os.path.join(access_point, f"{worker_id}.err")
0346         return stdOut, stdErr