Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-19 08:00:05

0001 import re
0002 import subprocess
0003 import tempfile
0004 
0005 import jinja2
0006 
0007 from pandaharvester.harvestercore import core_utils
0008 from pandaharvester.harvestercore.plugin_base import PluginBase
0009 
# Module-level logger named "slurm_submitter"; submit_workers passes it to
# self.make_logger() to build a per-worker logger.
baseLogger = core_utils.setup_logger("slurm_submitter")
0012 
0013 
0014 # submitter for SLURM batch system
class SlurmSubmitterJinja(PluginBase):
    """Submit workers to SLURM with ``sbatch`` using a templated batch script.

    The batch script is produced from the file at ``self.templateFile``,
    rendered either with Jinja2 (``make_batch_script_jinja``, used by
    ``submit_workers``) or with ``str.format`` (``make_batch_script``).

    NOTE(review): ``nCore``, ``nCorePerNode`` and ``templateFile`` are read
    from ``self`` but never assigned in this class; presumably
    ``PluginBase.__init__`` sets them from the plugin configuration -- confirm.
    """

    # constructor
    def __init__(self, **kwarg):
        """Set log-upload defaults, then let PluginBase consume ``kwarg``."""
        # Defaults are assigned before PluginBase.__init__ runs, so anything
        # that call sets on the instance takes precedence over them.
        self.uploadLog = False
        self.logBaseURL = None
        PluginBase.__init__(self, **kwarg)

    # submit workers
    def submit_workers(self, workspec_list):
        """Submit each worker with ``sbatch`` and report per-worker status.

        For every work spec this sets ``nCore``, writes a batch script into the
        worker's access point, runs ``sbatch -D <access point> <script>`` and,
        on success, stores the batch ID parsed from sbatch's stdout in
        ``workSpec.batchID``. When ``self.uploadLog`` is true, the stdout and
        stderr log locations found in the script are registered on the spec.

        :param workspec_list: iterable of work specs to submit
        :return: list with one ``(succeeded, error_message)`` tuple per work
                 spec, in input order; ``error_message`` is ``""`` on success
        """
        retList = []
        for workSpec in workspec_list:
            # make logger
            tmpLog = self.make_logger(baseLogger, f"workerID={workSpec.workerID}", method_name="submit_workers")
            # set nCore
            workSpec.nCore = self.nCore
            # make batch script
            batchFile = self.make_batch_script_jinja(workSpec)
            # command: built as an argv list rather than splitting a string,
            # so paths that contain whitespace reach sbatch intact
            comList = ["sbatch", "-D", str(workSpec.get_access_point()), batchFile]
            # submit
            tmpLog.debug(f"submit with {batchFile}")
            p = subprocess.Popen(comList, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            # check return code
            stdOut, stdErr = p.communicate()
            retCode = p.returncode
            tmpLog.debug(f"retCode={retCode}")
            stdOut_str = stdOut if (isinstance(stdOut, str) or stdOut is None) else stdOut.decode()
            stdErr_str = stdErr if (isinstance(stdErr, str) or stdErr is None) else stdErr.decode()
            if retCode == 0:
                # extract batchID: the first run of digits in sbatch's stdout
                idMatch = re.search(r"[^0-9]*([0-9]+)[^0-9]*", f"{stdOut_str}")
                if idMatch is None:
                    # sbatch exited 0 but printed no job number; report this
                    # worker as failed instead of raising and aborting the
                    # remaining submissions
                    errStr = f"failed to extract batchID from sbatch output: {stdOut_str} {stdErr_str}"
                    tmpLog.error(errStr)
                    tmpRetVal = (False, errStr)
                else:
                    workSpec.batchID = idMatch.group(1)
                    tmpLog.debug(f"batchID={workSpec.batchID}")
                    # set log files
                    if self.uploadLog:
                        if self.logBaseURL is None:
                            baseDir = workSpec.get_access_point()
                        else:
                            baseDir = self.logBaseURL
                        logOut, logErr = self.get_log_file_names(batchFile, workSpec.batchID)
                        if logOut is not None:
                            workSpec.set_log_file("stdout", f"{baseDir}/{logOut}")
                        if logErr is not None:
                            workSpec.set_log_file("stderr", f"{baseDir}/{logErr}")
                    tmpRetVal = (True, "")
            else:
                # failed
                errStr = f"{stdOut_str} {stdErr_str}"
                tmpLog.error(errStr)
                tmpRetVal = (False, errStr)
            retList.append(tmpRetVal)
        return retList

    # read the batch script template
    def _read_template(self):
        """Return the text of the template file at ``self.templateFile``."""
        with open(self.templateFile) as templateFile:
            return templateFile.read()

    # write a rendered batch script
    def _write_batch_file(self, workspec, content):
        """Write ``content`` to a new ``*_submit.sh`` file in the access point.

        :param workspec: work spec whose access point receives the file
        :param content: rendered batch script text
        :return: path of the file that was written
        """
        # Encode before creating the file so that an encoding error does not
        # leave an empty script behind (the file is created with delete=False).
        payload = str(content).encode("latin_1")
        with tempfile.NamedTemporaryFile(delete=False, suffix="_submit.sh", dir=workspec.get_access_point()) as tmpFile:
            tmpFile.write(payload)
        return tmpFile.name

    # make batch script
    def make_batch_script(self, workspec):
        """Create a batch script by filling the template with ``str.format``.

        The template text is also cached on ``self.template``. Placeholders
        supplied: ``nCorePerNode``, ``nNode`` (``workspec.nCore`` floor-divided
        by ``self.nCorePerNode``), ``accessPoint`` and ``workerID``.

        :param workspec: work spec the script is generated for
        :return: path of the generated batch script
        """
        # template for batch script
        self.template = self._read_template()
        # Render first; the file is only created once the content is known.
        content = self.template.format(
            nCorePerNode=self.nCorePerNode,
            nNode=workspec.nCore // self.nCorePerNode,
            accessPoint=workspec.accessPoint,
            workerID=workspec.workerID,
        )
        return self._write_batch_file(workspec, content)

    # make batch script
    def make_batch_script_jinja(self, workspec):
        """Create a batch script by rendering the template with Jinja2.

        The template text is also cached on ``self.template``. Variables
        supplied: ``nCorePerNode``, ``nNode`` (``workspec.nCore`` floor-divided
        by ``self.nCorePerNode``), ``accessPoint``, ``workerID`` and the whole
        ``workspec`` object.

        :param workspec: work spec the script is generated for
        :return: path of the generated batch script
        """
        # template for batch script
        self.template = self._read_template()
        # Render first; the file is only created once the content is known.
        content = jinja2.Template(self.template).render(
            nCorePerNode=self.nCorePerNode,
            nNode=workspec.nCore // self.nCorePerNode,
            accessPoint=workspec.accessPoint,
            workerID=workspec.workerID,
            workspec=workspec,
        )
        return self._write_batch_file(workspec, content)

    # get log file names
    def get_log_file_names(self, batch_script, batch_id):
        """Find the stdout/stderr file names declared in a batch script.

        Only lines starting with ``#SBATCH`` are inspected. On a line with a
        ``-o`` token the last whitespace-separated token is taken as the
        stdout name; otherwise, on a line with a ``-e`` token, as the stderr
        name. Any literal ``$SLURM_JOB_ID`` in the name is replaced by
        ``batch_id``. When several lines match, the last one wins.

        :param batch_script: path of the batch script to scan
        :param batch_id: batch job ID substituted for ``$SLURM_JOB_ID``
        :return: ``(stdout_name, stderr_name)``; an entry is ``None`` when the
                 script has no matching directive
        """
        stdOut = None
        stdErr = None
        # str.replace needs a string, so accept numeric IDs as well
        batchIdStr = str(batch_id)
        with open(batch_script) as f:
            for line in f:
                if not line.startswith("#SBATCH"):
                    continue
                items = line.split()
                if "-o" in items:
                    stdOut = items[-1].replace("$SLURM_JOB_ID", batchIdStr)
                elif "-e" in items:
                    stdErr = items[-1].replace("$SLURM_JOB_ID", batchIdStr)
        return stdOut, stdErr