import os
import re
import stat
import subprocess
import tempfile
from math import ceil

from pandaharvester.harvesterconfig import harvester_config
from pandaharvester.harvestercore import core_utils
from pandaharvester.harvestercore.plugin_base import PluginBase

baseLogger = core_utils.setup_logger("slurm_submitter")


# submitter for SLURM batch system
class SlurmSubmitter(PluginBase):
    # constructor
    def __init__(self, **kwarg):
        self.uploadLog = False
        self.logBaseURL = None
        PluginBase.__init__(self, **kwarg)
        if not hasattr(self, "localQueueName"):
            self.localQueueName = "grid"
        # nCoreFactor can be a plain integer or a dict keyed by job type and resource type
        try:
            if hasattr(self, "nCoreFactor"):
                if isinstance(self.nCoreFactor, dict):
                    # keep the dict form as-is; it is resolved per worker in get_core_factor
                    pass
                else:
                    self.nCoreFactor = int(self.nCoreFactor)
                    if (not self.nCoreFactor) or (self.nCoreFactor < 1):
                        self.nCoreFactor = 1
            else:
                self.nCoreFactor = 1
        except AttributeError:
            self.nCoreFactor = 1

    # submit workers
    def submit_workers(self, workspec_list):
        retList = []
        for workSpec in workspec_list:
            # make logger
            tmpLog = self.make_logger(baseLogger, f"workerID={workSpec.workerID}", method_name="submit_workers")
            # set nCore
            if self.nCore > 0:
                workSpec.nCore = self.nCore
            # make batch script
            batchFile = self.make_batch_script(workSpec, tmpLog)
            # command
            comStr = f"sbatch -D {workSpec.get_access_point()} {batchFile}"
            # submit
            tmpLog.debug(f"submit with {batchFile}")
            p = subprocess.Popen(comStr.split(), shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            # check return code
            stdOut, stdErr = p.communicate()
            retCode = p.returncode
            tmpLog.debug(f"retCode={retCode}")
            stdOut_str = stdOut if (isinstance(stdOut, str) or stdOut is None) else stdOut.decode()
            stdErr_str = stdErr if (isinstance(stdErr, str) or stdErr is None) else stdErr.decode()
            if retCode == 0:
                # extract batchID from the sbatch output
                workSpec.batchID = re.search("[^0-9]*([0-9]+)[^0-9]*$", f"{stdOut_str}").group(1)
                tmpLog.debug(f"batchID={workSpec.batchID}")
                # set log file URLs
                if self.logBaseURL and self.logDir:
                    stdOut, stdErr = self.get_log_file_names(workSpec.accessPoint, workSpec.workerID)
                    rel_stdOut = os.path.relpath(stdOut, self.logDir)
                    rel_stdErr = os.path.relpath(stdErr, self.logDir)
                    log_stdOut = os.path.join(self.logBaseURL, rel_stdOut)
                    log_stdErr = os.path.join(self.logBaseURL, rel_stdErr)
                    workSpec.set_log_file("stdout", log_stdOut)
                    workSpec.set_log_file("stderr", log_stdErr)
                tmpRetVal = (True, "")
            else:
                # submission failed
                errStr = f"{stdOut_str} {stdErr_str}"
                tmpLog.error(errStr)
                tmpRetVal = (False, errStr)
            retList.append(tmpRetVal)
        return retList

    # get the core factor for a worker, resolving the dict form by job type and resource type
    def get_core_factor(self, workspec, logger):
        try:
            if isinstance(self.nCoreFactor, dict):
                n_core_factor = self.nCoreFactor.get(workspec.jobType, {}).get(workspec.resourceType, 1)
                return int(n_core_factor)
            return int(self.nCoreFactor)
        except Exception as ex:
            logger.warning(f"Failed to get core factor: {ex}")
        return 1

    # make the map of placeholders substituted into the batch script template
    def make_placeholder_map(self, workspec, logger):
        timeNow = core_utils.naive_utcnow()

        panda_queue_name = self.queueName
        this_panda_queue_dict = dict()

        # number of cores per node, overridable via the nCorePerNode plugin attribute
        n_core_per_node_from_queue = this_panda_queue_dict.get("corecount", 1) if this_panda_queue_dict.get("corecount", 1) else 1
        n_core_per_node = getattr(self, "nCorePerNode", n_core_per_node_from_queue)
        if not n_core_per_node:
            n_core_per_node = self.nCore

        n_core_factor = self.get_core_factor(workspec, logger)

        # resource requests derived from the worker spec
        n_core_total = workspec.nCore if workspec.nCore else n_core_per_node
        n_core_total_factor = n_core_total * n_core_factor
        request_ram = max(workspec.minRamCount, 1 * n_core_total) if workspec.minRamCount else 1 * n_core_total
        request_disk = workspec.maxDiskCount * 1024 if workspec.maxDiskCount else 1
        request_walltime = workspec.maxWalltime if workspec.maxWalltime else 0

        n_node = getattr(self, "nNode", None)
        if not n_node:
            n_node = ceil(n_core_total / n_core_per_node)

        request_ram_factor = request_ram * n_core_factor
        request_ram_bytes = request_ram * 2**20
        request_ram_bytes_factor = request_ram * 2**20 * n_core_factor
        request_ram_per_core = ceil(request_ram / n_core_total)
        request_ram_bytes_per_core = ceil(request_ram_bytes / n_core_total)
        request_cputime = request_walltime * n_core_total
        request_walltime_minute = ceil(request_walltime / 60)
        request_cputime_minute = ceil(request_cputime / 60)

        # URL of the stdout log, also used for the gtag placeholder
        if self.logBaseURL and self.logDir:
            stdOut, stdErr = self.get_log_file_names(workspec.accessPoint, workspec.workerID)
            rel_stdOut = os.path.relpath(stdOut, self.logDir)
            log_stdOut = os.path.join(self.logBaseURL, rel_stdOut)
            gtag = log_stdOut
        else:
            gtag = "unknown"

        placeholder_map = {
            "nCorePerNode": n_core_per_node,
            "nCoreTotal": n_core_total_factor,
            "nCoreFactor": n_core_factor,
            "nNode": n_node,
            "requestRam": request_ram_factor,
            "requestRamBytes": request_ram_bytes_factor,
            "requestRamPerCore": request_ram_per_core,
            "requestRamBytesPerCore": request_ram_bytes_per_core,
            "requestDisk": request_disk,
            "requestWalltime": request_walltime,
            "requestWalltimeMinute": request_walltime_minute,
            "requestCputime": request_cputime,
            "requestCputimeMinute": request_cputime_minute,
            "accessPoint": workspec.accessPoint,
            "harvesterID": harvester_config.master.harvester_id,
            "workerID": workspec.workerID,
            "computingSite": workspec.computingSite,
            "pandaQueueName": panda_queue_name,
            "localQueueName": self.localQueueName,
            "logDir": self.logDir,
            "logSubDir": os.path.join(self.logDir, timeNow.strftime("%y-%m-%d_%H")),
            "jobType": workspec.jobType,
            "gtag": gtag,
        }
        # optional attributes passed through only if configured on the plugin
        for k in ["tokenDir", "tokenName", "tokenOrigin", "submitMode"]:
            try:
                placeholder_map[k] = getattr(self, k)
            except Exception:
                pass
        return placeholder_map

    # make a batch script by filling the template with the placeholder map
    def make_batch_script(self, workspec, logger):
        # read the template
        with open(self.templateFile) as f:
            template = f.read()
        tmpFile = tempfile.NamedTemporaryFile(delete=False, suffix="_submit.sh", dir=workspec.get_access_point())
        placeholder = self.make_placeholder_map(workspec, logger)
        tmpFile.write(str(template.format_map(core_utils.SafeDict(placeholder))).encode("latin_1"))
        tmpFile.close()

        # add owner-execute, group read/write, and other read permissions
        st = os.stat(tmpFile.name)
        os.chmod(tmpFile.name, st.st_mode | stat.S_IEXEC | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH)

        return tmpFile.name

    # legacy: parse stdout/stderr file names from the #SBATCH directives of a batch script
    def get_log_file_names_old(self, batch_script, batch_id):
        stdOut = None
        stdErr = None
        with open(batch_script) as f:
            for line in f:
                if not line.startswith("#SBATCH"):
                    continue
                items = line.split()
                if "-o" in items:
                    stdOut = items[-1].replace("$SLURM_JOB_ID", batch_id)
                elif "-e" in items:
                    stdErr = items[-1].replace("$SLURM_JOB_ID", batch_id)
        return stdOut, stdErr

    # stdout/stderr file names under the worker's access point
    def get_log_file_names(self, access_point, worker_id):
        stdOut = os.path.join(access_point, f"{worker_id}.out")
        stdErr = os.path.join(access_point, f"{worker_id}.err")
        return stdOut, stdErr
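

# The file referenced by self.templateFile is site specific and not part of this module.
# A minimal illustrative sketch (an assumption, not an official template) using the
# placeholder names produced by make_placeholder_map could look like the lines below.
# Writing stdout/stderr to {accessPoint}/{workerID}.out and .err keeps the files where
# get_log_file_names expects them, so the log URLs set in submit_workers stay valid:
#
#   #!/bin/bash
#   #SBATCH --partition={localQueueName}
#   #SBATCH --nodes={nNode}
#   #SBATCH --ntasks-per-node={nCorePerNode}
#   #SBATCH --mem={requestRam}
#   #SBATCH --time={requestWalltimeMinute}
#   #SBATCH -o {accessPoint}/{workerID}.out
#   #SBATCH -e {accessPoint}/{workerID}.err
#   cd {accessPoint}
#   # ... launch the payload/pilot here ...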