File indexing completed on 2026-04-19 08:00:05
0001 import re
0002 import subprocess
0003 import tempfile
0004
0005 import jinja2
0006
0007 from pandaharvester.harvestercore import core_utils
0008 from pandaharvester.harvestercore.plugin_base import PluginBase
0009
0010
# Module-level logger created with core_utils.setup_logger; submit_workers passes it,
# together with the worker ID, to self.make_logger to obtain the logger it writes to.
0011 baseLogger = core_utils.setup_logger("slurm_submitter")
0012
0013
0014
class SlurmSubmitterJinja(PluginBase):
    """Submitter plugin that hands workers to Slurm through ``sbatch``.

    A batch script is generated per worker from the template file named by
    ``self.templateFile`` and written into the worker's access point.

    ``templateFile``, ``nCore`` and ``nCorePerNode`` are read from the instance
    but never assigned in this class; presumably ``PluginBase.__init__`` sets
    them from the keyword arguments -- TODO confirm against the plugin config.
    """

    def __init__(self, **kwarg):
        """Set the log-upload defaults, then run ``PluginBase.__init__``.

        :param kwarg: keyword arguments forwarded unchanged to ``PluginBase``
        """
        # Defaults are assigned before the base initialiser runs, so any value
        # the base class sets for these attributes takes precedence.
        self.uploadLog = False
        self.logBaseURL = None
        PluginBase.__init__(self, **kwarg)

    def submit_workers(self, workspec_list):
        """Submit every worker in ``workspec_list`` with ``sbatch``.

        For each worker the Jinja batch script is rendered, ``sbatch -D
        <access point> <script>`` is executed, and on success the first run of
        digits in the ``sbatch`` standard output is stored as
        ``workSpec.batchID``. When ``self.uploadLog`` is true, the stdout and
        stderr file names found in the batch script are registered on the
        worker through ``set_log_file``.

        :param workspec_list: iterable of work specs to submit
        :return: list with one ``(succeeded, message)`` tuple per worker, in
                 input order; ``message`` is empty on success
        """
        retList = []
        for workSpec in workspec_list:
            tmpLog = self.make_logger(baseLogger, f"workerID={workSpec.workerID}", method_name="submit_workers")
            workSpec.nCore = self.nCore
            batchFile = self.make_batch_script_jinja(workSpec)
            # Pass an argument list instead of splitting a formatted string so
            # that paths containing whitespace reach sbatch as single arguments.
            command = ["sbatch", "-D", workSpec.get_access_point(), batchFile]
            tmpLog.debug(f"submit with {batchFile}")
            p = subprocess.Popen(command, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdOut, stdErr = p.communicate()
            retCode = p.returncode
            tmpLog.debug(f"retCode={retCode}")
            stdOut_str = self._as_text(stdOut)
            stdErr_str = self._as_text(stdErr)

            if retCode != 0:
                errStr = f"{stdOut_str} {stdErr_str}"
                tmpLog.error(errStr)
                retList.append((False, errStr))
                continue

            # The job id is taken to be the first run of digits in the output.
            matched = re.search(r"[^0-9]*([0-9]+)[^0-9]*", f"{stdOut_str}")
            if matched is None:
                # Report the problem for this worker instead of raising, which
                # would discard the results collected for the earlier workers.
                errStr = f"failed to extract batchID from sbatch output: {stdOut_str} {stdErr_str}"
                tmpLog.error(errStr)
                retList.append((False, errStr))
                continue
            workSpec.batchID = matched.group(1)
            tmpLog.debug(f"batchID={workSpec.batchID}")

            if self.uploadLog:
                baseDir = workSpec.get_access_point() if self.logBaseURL is None else self.logBaseURL
                logOut, logErr = self.get_log_file_names(batchFile, workSpec.batchID)
                if logOut is not None:
                    workSpec.set_log_file("stdout", f"{baseDir}/{logOut}")
                if logErr is not None:
                    workSpec.set_log_file("stderr", f"{baseDir}/{logErr}")
            retList.append((True, ""))
        return retList

    def make_batch_script(self, workspec):
        """Create a batch script by filling the template with ``str.format``.

        The template receives ``nCorePerNode``, ``nNode`` (``workspec.nCore``
        floor-divided by ``self.nCorePerNode``), ``accessPoint`` and
        ``workerID``. This variant is not called by ``submit_workers`` in this
        class, which uses ``make_batch_script_jinja``.

        :param workspec: work spec the script is generated for
        :return: path of the script written into the worker's access point
        """
        content = self._read_template().format(
            nCorePerNode=self.nCorePerNode,
            nNode=workspec.nCore // self.nCorePerNode,
            accessPoint=workspec.accessPoint,
            workerID=workspec.workerID,
        )
        return self._write_batch_script(workspec, content)

    def make_batch_script_jinja(self, workspec):
        """Create a batch script by rendering the template with Jinja2.

        In addition to the values passed by ``make_batch_script``, the whole
        ``workspec`` object is exposed to the template.

        :param workspec: work spec the script is generated for
        :return: path of the script written into the worker's access point
        """
        content = jinja2.Template(self._read_template()).render(
            nCorePerNode=self.nCorePerNode,
            nNode=workspec.nCore // self.nCorePerNode,
            accessPoint=workspec.accessPoint,
            workerID=workspec.workerID,
            workspec=workspec,
        )
        return self._write_batch_script(workspec, content)

    def get_log_file_names(self, batch_script, batch_id):
        """Read the stdout/stderr file names declared in a batch script.

        Only lines starting with ``#SBATCH`` are inspected. The token that
        follows ``-o`` is taken as the stdout name and the token that follows
        ``-e`` as the stderr name; both flags may appear on the same line.
        Every ``$SLURM_JOB_ID`` in a name is replaced by ``batch_id``. When a
        flag occurs on several lines, the last occurrence wins.

        :param batch_script: path of the batch script to scan
        :param batch_id: job id substituted for ``$SLURM_JOB_ID``
        :return: ``(stdout_name, stderr_name)``; an entry is ``None`` when the
                 script does not declare it
        """
        stdOut = None
        stdErr = None
        jobID = str(batch_id)
        with open(batch_script) as f:
            for line in f:
                if not line.startswith("#SBATCH"):
                    continue
                items = line.split()
                outName = self._option_value(items, "-o")
                if outName is not None:
                    stdOut = outName.replace("$SLURM_JOB_ID", jobID)
                errName = self._option_value(items, "-e")
                if errName is not None:
                    stdErr = errName.replace("$SLURM_JOB_ID", jobID)
        return stdOut, stdErr

    def _read_template(self):
        """Load ``self.templateFile`` into ``self.template`` and return the text."""
        with open(self.templateFile) as templateFile:
            self.template = templateFile.read()
        return self.template

    def _write_batch_script(self, workspec, content):
        """Write ``content`` to a new ``*_submit.sh`` file in the access point.

        The file is created only after ``content`` has been produced, so a
        template error does not leave an empty script behind. It is not
        deleted on close because ``sbatch`` reads it afterwards.

        :return: path of the file that was written
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix="_submit.sh", dir=workspec.get_access_point()) as scriptFile:
            scriptFile.write(content.encode("latin_1"))
        return scriptFile.name

    @staticmethod
    def _as_text(stream):
        """Return subprocess output as ``str``; ``None`` and ``str`` pass through."""
        if stream is None or isinstance(stream, str):
            return stream
        return stream.decode()

    @staticmethod
    def _option_value(items, flag):
        """Return the token following ``flag`` in ``items``, or ``None`` if there is none."""
        try:
            return items[items.index(flag) + 1]
        except (ValueError, IndexError):
            # ValueError: flag absent; IndexError: flag is the last token.
            return None