File indexing completed on 2026-04-19 08:00:05
0001 import os
0002 import re
0003 import stat
0004 import subprocess
0005 import tempfile
0006 from math import ceil
0007
0008 from pandaharvester.harvesterconfig import harvester_config
0009 from pandaharvester.harvestercore import core_utils
0010 from pandaharvester.harvestercore.plugin_base import PluginBase
0011 from pandaharvester.harvestercore.resource_type_mapper import ResourceTypeMapper
0012 from pandaharvester.harvestermisc.info_utils import PandaQueuesDict
0013 from pandaharvester.harvestersubmitter import submitter_common
0014
0015
0016 baseLogger = core_utils.setup_logger("slurm_submitter_rubin")
0017
0018
0019
class SlurmSubmitter(PluginBase):
    """Harvester submitter plugin that launches workers as SLURM batch jobs (Rubin flavor).

    For each worker a batch script is rendered from ``self.templateFile`` using a
    per-worker placeholder map and submitted with ``sbatch``.  When ``checkPartition``
    is enabled, the partition with the fewest pending jobs of the submitting user is
    selected from the configured ``partitions`` list (re-checked every
    ``nWorkersToCheckPartition`` workers).
    """

    # constructor: plugin config attributes are injected by PluginBase.__init__
    def __init__(self, **kwarg):
        self.uploadLog = False
        self.logBaseURL = None
        PluginBase.__init__(self, **kwarg)
        if not hasattr(self, "localQueueName"):
            self.localQueueName = "grid"

        # nCoreFactor is either a mapping {jobType: {resourceType: factor}}
        # (resolved per worker in get_core_factor) or a plain integer >= 1
        try:
            if hasattr(self, "nCoreFactor"):
                if isinstance(self.nCoreFactor, dict):
                    # dict form is resolved lazily in get_core_factor
                    pass
                else:
                    self.nCoreFactor = int(self.nCoreFactor)
                    if (not self.nCoreFactor) or (self.nCoreFactor < 1):
                        self.nCoreFactor = 1
            else:
                self.nCoreFactor = 1
        except (AttributeError, TypeError, ValueError):
            # fixed: a malformed config value (e.g. non-numeric string) used to
            # escape the original AttributeError-only handler and crash __init__
            self.nCoreFactor = 1

        # whether to probe squeue to pick the least-loaded partition
        try:
            self.checkPartition = bool(self.checkPartition)
        except AttributeError:
            self.checkPartition = False

        # how often (in workers) the partition choice is re-evaluated
        try:
            self.nWorkersToCheckPartition = int(self.nWorkersToCheckPartition)
        except (AttributeError, TypeError, ValueError):
            self.nWorkersToCheckPartition = 10
        if (not self.nWorkersToCheckPartition) or (self.nWorkersToCheckPartition < 1):
            self.nWorkersToCheckPartition = 1

        # partitions may be configured as a list/tuple or a comma-separated string
        try:
            if not isinstance(self.partitions, (list, tuple)):
                self.partitions = self.partitions.split(",")
            self.partitions = [p.strip() for p in self.partitions]
        except AttributeError:
            self.partitions = None

    def get_queued_jobs(self, partition, logger):
        """Count this user's pending jobs (SLURM states CF/PD) in a partition.

        :param partition: SLURM partition name
        :param logger: logger to report through
        :return: tuple (status, num_pending_jobs); status is False when squeue failed
        """
        status = False
        num_pending_jobs = 0

        # fixed: os.getlogin() requires a controlling terminal and raises OSError
        # in daemonized processes; fall back to the environment in that case
        try:
            username = os.getlogin()
        except OSError:
            username = os.environ.get("USER") or os.environ.get("LOGNAME") or ""

        command = f"squeue -u {username} --partition={partition}"
        logger.debug(f"check jobs in partition {partition} command: {command.split()}")
        p = subprocess.Popen(command.split(), shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = p.communicate()
        ret_code = p.returncode

        if ret_code == 0:
            stdout_str = stdout if (isinstance(stdout, str) or stdout is None) else stdout.decode()

            num_pending_jobs = 0
            for line in stdout_str.split("\n"):
                if len(line) == 0 or line.startswith("JobID") or line.startswith("--"):
                    continue
                fields = line.split()
                # squeue default format: JOBID PARTITION NAME USER ST TIME NODES NODELIST;
                # fixed: guard against short/unexpected lines (was an IndexError)
                if len(fields) < 5:
                    continue
                batch_status = fields[4].strip()
                if batch_status in ["CF", "PD"]:
                    num_pending_jobs += 1
            logger.debug(f"number of pending jobs in partition {partition} with user {username}: {num_pending_jobs}")
            status = True
        else:
            logger.error(f"returncode: {ret_code}, stdout: {stdout}, stderr: {stderr}")

        return status, num_pending_jobs

    def get_partition(self, logger):
        """Choose the partition to submit to.

        :return: the partition with the fewest pending jobs when checkPartition is
                 enabled, otherwise the first configured partition; None when no
                 partitions are configured or none could be queried
        """
        if not self.partitions:
            return None

        logger.debug(f"partitions: {self.partitions}")
        if not self.checkPartition:
            return self.partitions[0]

        num_pending_by_partition = {}
        for partition in self.partitions:
            status, num_pending_jobs = self.get_queued_jobs(partition, logger)
            if status:
                num_pending_by_partition[partition] = num_pending_jobs
        logger.debug(f"num_pending_by_partition: {num_pending_by_partition}")

        if num_pending_by_partition:
            # partition with the fewest pending jobs (first one wins on ties,
            # same as the previous stable-sort implementation)
            return min(num_pending_by_partition, key=num_pending_by_partition.get)
        return None

    def get_core_factor(self, workspec, is_unified_queue, logger):
        """Resolve the core scaling factor for a worker.

        When nCoreFactor is a dict it is keyed by jobType (falling back to "Any")
        and then by resourceType ("Undefined" for non-unified queues), defaulting
        to 1.  Any failure is logged and 1 is returned.
        """
        try:
            if isinstance(self.nCoreFactor, dict):
                job_type = workspec.jobType if workspec.jobType in self.nCoreFactor else "Any"
                resource_type = workspec.resourceType if is_unified_queue else "Undefined"
                n_core_factor = self.nCoreFactor.get(job_type, {}).get(resource_type, 1)
                return int(n_core_factor)
            return int(self.nCoreFactor)
        except Exception as ex:
            logger.warning(f"Failed to get core factor: {ex}")
            return 1

    def submit_workers(self, workspec_list):
        """Submit one SLURM batch job per worker.

        :param workspec_list: list of WorkSpec objects to submit
        :return: list of (status, error_message) tuples, one per worker
        """
        tmpLog = self.make_logger(baseLogger, f"site={self.queueName}", method_name="submit_workers")

        # queue configuration from CRIC
        panda_queues_dict = PandaQueuesDict()
        this_panda_queue_dict = panda_queues_dict.get(self.queueName, {})

        retList = []
        num_workSpec = 0
        for workSpec in workspec_list:
            # per-worker logger
            tmpLog = self.make_logger(
                baseLogger, f"site={self.queueName} workerID={workSpec.workerID} resourceType={workSpec.resourceType}", method_name="submit_workers"
            )

            if self.nCore > 0:
                workSpec.nCore = self.nCore
            if self.checkPartition:
                # probe squeue only every nWorkersToCheckPartition workers;
                # in between, the previously selected partition is reused
                if num_workSpec % self.nWorkersToCheckPartition == 0:
                    partition = self.get_partition(tmpLog)
                num_workSpec += 1
            else:
                partition = self.get_partition(tmpLog)

            # render the batch script and submit it
            batchFile = self.make_batch_script(workSpec, partition, this_panda_queue_dict, tmpLog)

            comStr = f"sbatch --exclusive=user -D {workSpec.get_access_point()} {batchFile}"

            tmpLog.debug(f"submit with {batchFile}")
            p = subprocess.Popen(comStr.split(), shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            stdOut, stdErr = p.communicate()
            retCode = p.returncode
            tmpLog.debug(f"retCode={retCode}")
            stdOut_str = stdOut if (isinstance(stdOut, str) or stdOut is None) else stdOut.decode()
            stdErr_str = stdErr if (isinstance(stdErr, str) or stdErr is None) else stdErr.decode()
            if retCode == 0:
                # extract the batch ID from the sbatch response
                # (normally "Submitted batch job <id>")
                match = re.search("[^0-9]*([0-9]+)[^0-9]*$", f"{stdOut_str}")
                if match:
                    workSpec.batchID = match.group(1)
                    tmpLog.debug(f"batchID={workSpec.batchID}")

                    # publish log file URLs when log serving is configured
                    if self.logBaseURL and getattr(self, "logDir", None):
                        stdOut, stdErr = self.get_log_file_names(workSpec.accessPoint, workSpec.workerID)
                        rel_stdOut = os.path.relpath(stdOut, self.logDir)
                        rel_stdErr = os.path.relpath(stdErr, self.logDir)
                        log_stdOut = os.path.join(self.logBaseURL, rel_stdOut)
                        log_stdErr = os.path.join(self.logBaseURL, rel_stdErr)
                        workSpec.set_log_file("stdout", log_stdOut)
                        workSpec.set_log_file("stderr", log_stdErr)
                    tmpRetVal = (True, "")
                else:
                    # fixed: an unexpected sbatch output previously raised
                    # AttributeError and aborted submission of remaining workers
                    errStr = f"no batch ID in sbatch output: {stdOut_str}"
                    tmpLog.error(errStr)
                    tmpRetVal = (False, errStr)
            else:
                # submission failed
                errStr = f"{stdOut_str} {stdErr_str}"
                tmpLog.error(errStr)
                tmpRetVal = (False, errStr)
            retList.append(tmpRetVal)
        return retList

    def make_placeholder_map(self, workspec, partition, this_panda_queue_dict, logger):
        """Build the substitution map used to render the batch script template.

        Core and RAM requests may be scaled by the configured core factor and
        individual values can be overridden through the CRIC "special_par"
        attribute, e.g. "(queue=q)(maxWallTime=60)(xcount=8)".
        """
        timeNow = core_utils.naive_utcnow()

        # corecount from CRIC, defaulting to 1 when absent or falsy
        n_core_per_node_from_queue = this_panda_queue_dict.get("corecount", 1) if this_panda_queue_dict.get("corecount", 1) else 1

        is_unified_queue = this_panda_queue_dict.get("capability", "") == "ucore"
        special_par = this_panda_queue_dict.get("special_par", "")

        # plugin config nCorePerNode takes precedence over CRIC corecount
        n_core_per_node = getattr(self, "nCorePerNode", n_core_per_node_from_queue)
        if not n_core_per_node:
            n_core_per_node = self.nCore

        n_core_factor = self.get_core_factor(workspec, is_unified_queue, logger)

        logger.debug(f"workspec.nCore: {workspec.nCore}, n_core_per_node: {n_core_per_node}")
        logger.debug(f"workspec.minRamCount: {workspec.minRamCount}")

        n_core_total = workspec.nCore if workspec.nCore else n_core_per_node
        # request at least 1 MB per core when minRamCount is missing or too small
        request_ram = max(workspec.minRamCount, 1 * n_core_total) if workspec.minRamCount else 1 * n_core_total
        logger.debug(f"n_core_total: {n_core_total}, request_ram: {request_ram}")

        request_disk = workspec.maxDiskCount * 1024 if workspec.maxDiskCount else 1
        request_walltime = workspec.maxWalltime if workspec.maxWalltime else 0

        ce_queue_name = None

        # apply overrides from CRIC special_par, e.g. "(maxWallTime=60)"
        if special_par:
            special_par_attr_list = [
                "queue",
                "maxWallTime",
                "xcount",
            ]
            _match_special_par_dict = {attr: re.search(f"\\({attr}=([^)]+)\\)", special_par) for attr in special_par_attr_list}
            for attr, _match in _match_special_par_dict.items():
                if not _match:
                    continue
                elif attr == "queue":
                    ce_queue_name = str(_match.group(1))
                elif attr == "maxWallTime":
                    request_walltime = int(_match.group(1))
                elif attr == "xcount":
                    n_core_total = int(_match.group(1))
                logger.debug(f"job attributes override by CRIC special_par: {attr}={str(_match.group(1))}")

        n_node = getattr(self, "nNode", None)
        if not n_node:
            n_node = ceil(n_core_total / n_core_per_node)

        # derived request quantities (RAM in MB; *_bytes in bytes; times in seconds)
        n_core_total_factor = n_core_total * n_core_factor
        request_ram_factor = request_ram * n_core_factor
        request_ram_bytes = request_ram * 2**20
        request_ram_bytes_factor = request_ram * 2**20 * n_core_factor
        request_ram_per_core = ceil(request_ram / n_core_total)
        request_ram_bytes_per_core = ceil(request_ram_bytes / n_core_total)
        request_cputime = request_walltime * n_core_total
        request_walltime_minute = ceil(request_walltime / 60)
        request_walltime_hour = ceil(request_walltime / 3600)
        request_cputime_minute = ceil(request_cputime / 60)

        rt_mapper = ResourceTypeMapper()
        all_resource_types = rt_mapper.get_all_resource_types()

        # gtag points at the published worker stdout URL when log serving is
        # configured; fixed: missing logDir attribute no longer raises
        log_dir = getattr(self, "logDir", None)
        if self.logBaseURL and log_dir:
            stdOut, stdErr = self.get_log_file_names(workspec.accessPoint, workspec.workerID)
            rel_stdOut = os.path.relpath(stdOut, log_dir)
            log_stdOut = os.path.join(self.logBaseURL, rel_stdOut)
            gtag = log_stdOut
        else:
            gtag = "unknown"

        placeholder_map = {
            "nCorePerNode": n_core_per_node,
            "nCoreTotal": n_core_total_factor,
            "nCoreFactor": n_core_factor,
            "nNode": n_node,
            "requestRam": request_ram_factor,
            "requestRamBytes": request_ram_bytes_factor,
            "requestRamPerCore": request_ram_per_core,
            "requestRamBytesPerCore": request_ram_bytes_per_core,
            "requestDisk": request_disk,
            "requestWalltime": request_walltime,
            "requestWalltimeMinute": request_walltime_minute,
            "requestWalltimeHour": request_walltime_hour,
            "requestCputime": request_cputime,
            "requestCputimeMinute": request_cputime_minute,
            "accessPoint": workspec.accessPoint,
            "harvesterID": harvester_config.master.harvester_id,
            "workerID": workspec.workerID,
            "computingSite": workspec.computingSite,
            "pandaQueueName": self.queueName,
            "localQueueName": self.localQueueName,
            "ceQueueName": ce_queue_name,
            "logDir": log_dir,
            "logSubDir": os.path.join(log_dir, timeNow.strftime("%y-%m-%d_%H")) if log_dir else None,
            "jobType": workspec.jobType,
            "gtag": gtag,
            "partition": partition,
            "resourceType": submitter_common.get_resource_type(workspec.resourceType, is_unified_queue, all_resource_types),
            "pilotResourceTypeOption": submitter_common.get_resource_type(workspec.resourceType, is_unified_queue, all_resource_types, is_pilot_option=True),
        }
        # optional attributes that are only present for some queue configurations
        for k in ["tokenDir", "tokenName", "tokenOrigin", "submitMode"]:
            try:
                placeholder_map[k] = getattr(self, k)
            except Exception:
                pass
        return placeholder_map

    def make_batch_script(self, workspec, partition, this_panda_queue_dict, logger):
        """Render the submission template into an executable script inside the
        worker's access point and return the script path."""
        with open(self.templateFile) as f:
            template = f.read()
        tmpFile = tempfile.NamedTemporaryFile(delete=False, suffix="_submit.sh", dir=workspec.get_access_point())
        placeholder = self.make_placeholder_map(workspec, partition, this_panda_queue_dict, logger)
        try:
            # SafeDict leaves unresolved placeholders in place instead of raising
            tmpFile.write(str(template.format_map(core_utils.SafeDict(placeholder))).encode("latin_1"))
        finally:
            # fixed: close even when rendering fails, so the fd is not leaked
            tmpFile.close()

        # make the script executable and group/world readable
        st = os.stat(tmpFile.name)
        os.chmod(tmpFile.name, st.st_mode | stat.S_IEXEC | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH)

        return tmpFile.name

    def get_log_file_names_old(self, batch_script, batch_id):
        """Extract stdout/stderr paths from "#SBATCH -o/-e" directives in a batch
        script, substituting $SLURM_JOB_ID with the batch ID (legacy helper)."""
        stdOut = None
        stdErr = None
        with open(batch_script) as f:
            for line in f:
                if not line.startswith("#SBATCH"):
                    continue
                items = line.split()
                if "-o" in items:
                    stdOut = items[-1].replace("$SLURM_JOB_ID", batch_id)
                elif "-e" in items:
                    stdErr = items[-1].replace("$SLURM_JOB_ID", batch_id)
        return stdOut, stdErr

    def get_log_file_names(self, access_point, worker_id):
        """Return the (stdout, stderr) file paths of a worker in its access point."""
        stdOut = os.path.join(access_point, f"{worker_id}.out")
        stdErr = os.path.join(access_point, f"{worker_id}.err")
        return stdOut, stdErr