Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-19 08:00:05

0001 import os
0002 import re
0003 import stat
0004 import subprocess
0005 import tempfile
0006 from math import ceil
0007 
0008 from pandaharvester.harvesterconfig import harvester_config
0009 from pandaharvester.harvestercore import core_utils
0010 from pandaharvester.harvestercore.plugin_base import PluginBase
0011 
# logger
baseLogger = core_utils.setup_logger("slurm_submitter")

# sbatch prints e.g. "Submitted batch job 12345"; the batch ID is taken as the last run of digits
_BATCH_ID_PATTERN = re.compile(r"[^0-9]*([0-9]+)[^0-9]*$")


# submitter for SLURM batch system
class SlurmSubmitter(PluginBase):
    """Harvester submitter plugin that submits one SLURM job per worker with ``sbatch``.

    For each worker a batch script is rendered from ``self.templateFile`` using the
    placeholders returned by :meth:`make_placeholder_map`, written into the worker's
    access point, and submitted with ``sbatch -D <access point> <script>``.

    Configuration attributes read by this class include ``templateFile``, ``queueName``,
    ``logDir``, ``logBaseURL``, ``localQueueName``, ``nCore``, ``nCorePerNode``, ``nNode``
    and ``nCoreFactor``. They are presumably populated from the plugin kwargs by
    ``PluginBase.__init__`` (confirm against PluginBase).
    """

    # constructor
    def __init__(self, **kwarg):
        # defaults assigned before PluginBase.__init__, presumably so configured kwargs override them
        self.uploadLog = False
        self.logBaseURL = None
        PluginBase.__init__(self, **kwarg)
        if not hasattr(self, "localQueueName"):
            self.localQueueName = "grid"
        # ncore factor: either a dict (resolved per worker in get_core_factor) or a positive int
        self.nCoreFactor = self._normalize_core_factor(getattr(self, "nCoreFactor", 1))

    @staticmethod
    def _normalize_core_factor(value):
        """Return *value* unchanged if it is a dict, otherwise coerce it to an int >= 1.

        Values that cannot be converted to int (``None``, non-numeric strings, ...) and
        values below 1 fall back to 1.
        """
        if isinstance(value, dict):
            # mapping jobType -> resourceType -> factor, looked up in get_core_factor
            return value
        try:
            factor = int(value)
        except (TypeError, ValueError):
            return 1
        return factor if factor >= 1 else 1

    # submit workers
    def submit_workers(self, workspec_list):
        """Render and submit one SLURM batch script per worker.

        :param workspec_list: iterable of worker specs to submit
        :return: list of ``(ok, message)`` tuples, one per worker and in input order.
                 ``ok`` is True with an empty message on success. On failure it is False
                 and ``message`` carries the error text.
        """
        retList = []
        for workSpec in workspec_list:
            # make logger
            tmpLog = self.make_logger(baseLogger, f"workerID={workSpec.workerID}", method_name="submit_workers")
            retList.append(self._submit_one(workSpec, tmpLog))
        return retList

    def _submit_one(self, workSpec, tmpLog):
        """Submit a single worker; return ``(True, "")`` or ``(False, error_text)``."""
        # set nCore
        n_core = getattr(self, "nCore", 0)
        if n_core and n_core > 0:
            workSpec.nCore = n_core
        # make batch script
        batchFile = self.make_batch_script(workSpec, tmpLog)
        # argv built as a list so an access point containing whitespace stays a single argument
        comList = ["sbatch", "-D", workSpec.get_access_point(), batchFile]
        # submit
        tmpLog.debug(f"submit with {batchFile}")
        try:
            proc = subprocess.run(comList, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError as ex:
            # e.g. sbatch not found; fail this worker instead of the whole batch
            errStr = f"failed to run sbatch: {ex}"
            tmpLog.error(errStr)
            return False, errStr
        # check return code
        retCode = proc.returncode
        tmpLog.debug(f"retCode={retCode}")
        stdOut_str = self._to_str(proc.stdout)
        stdErr_str = self._to_str(proc.stderr)
        if retCode != 0:
            # failed
            errStr = f"{stdOut_str} {stdErr_str}"
            tmpLog.error(errStr)
            return False, errStr
        # extract batchID
        match = _BATCH_ID_PATTERN.search(stdOut_str)
        if match is None:
            # sbatch reported success, but without a batch ID the job cannot be tracked
            errStr = f"failed to extract batchID from sbatch output: {stdOut_str} {stdErr_str}"
            tmpLog.error(errStr)
            return False, errStr
        workSpec.batchID = match.group(1)
        tmpLog.debug(f"batchID={workSpec.batchID}")
        # set log files
        log_urls = self._get_log_urls(workSpec)
        if log_urls is not None:
            log_stdOut, log_stdErr = log_urls
            workSpec.set_log_file("stdout", log_stdOut)
            workSpec.set_log_file("stderr", log_stdErr)
        return True, ""

    @staticmethod
    def _to_str(data):
        """Convert subprocess output to ``str``; ``None`` becomes ``""``."""
        if data is None:
            return ""
        if isinstance(data, str):
            return data
        # undecodable bytes must not abort the submission bookkeeping
        return data.decode(errors="replace")

    def _get_log_urls(self, workspec):
        """Return ``(stdout_url, stderr_url)`` for *workspec*, or ``None`` when unavailable.

        URLs are built only when both ``logBaseURL`` and ``logDir`` are truthy. Each log
        path from :meth:`get_log_file_names` is made relative to ``logDir`` and joined onto
        ``logBaseURL``. ``os.path.join`` is used for the URL, which assumes a POSIX host.
        """
        if not (self.logBaseURL and self.logDir):
            return None
        stdout_path, stderr_path = self.get_log_file_names(workspec.accessPoint, workspec.workerID)
        stdout_url = os.path.join(self.logBaseURL, os.path.relpath(stdout_path, self.logDir))
        stderr_url = os.path.join(self.logBaseURL, os.path.relpath(stderr_path, self.logDir))
        return stdout_url, stderr_url

    def get_core_factor(self, workspec, logger):
        """Return the integer core multiplier for *workspec*.

        If ``nCoreFactor`` is a dict, it is indexed by ``workspec.jobType`` and then by
        ``workspec.resourceType``, defaulting to 1 when either key is missing. Any error
        is logged as a warning and the factor falls back to 1.
        """
        try:
            if isinstance(self.nCoreFactor, dict):
                n_core_factor = self.nCoreFactor.get(workspec.jobType, {}).get(workspec.resourceType, 1)
                return int(n_core_factor)
            return int(self.nCoreFactor)
        except Exception as ex:
            logger.warning(f"Failed to get core factor: {ex}")
        return 1

    def make_placeholder_map(self, workspec, logger):
        """Build the placeholder dict used to render the batch script template.

        :param workspec: worker spec whose resource requests (nCore, minRamCount,
                         maxDiskCount, maxWalltime, ...) drive the values
        :param logger: logger passed on to :meth:`get_core_factor`
        :return: dict mapping template placeholder names to values
        """
        timeNow = core_utils.naive_utcnow()

        panda_queue_name = self.queueName

        # queue info (corecount) is not consulted here, so the per-node default is 1
        n_core_per_node = getattr(self, "nCorePerNode", 1)
        if not n_core_per_node:
            # fall back to nCore, and finally to 1 so the divisions below never divide by zero
            n_core_per_node = getattr(self, "nCore", 0) or 1

        n_core_factor = self.get_core_factor(workspec, logger)

        n_core_total = workspec.nCore if workspec.nCore else n_core_per_node
        n_core_total_factor = n_core_total * n_core_factor
        # at least 1 unit of RAM per core (units presumably MB, given the 2**20 byte conversion below)
        request_ram = max(workspec.minRamCount, n_core_total) if workspec.minRamCount else n_core_total
        request_disk = workspec.maxDiskCount * 1024 if workspec.maxDiskCount else 1
        request_walltime = workspec.maxWalltime if workspec.maxWalltime else 0

        n_node = getattr(self, "nNode", None)
        if not n_node:
            n_node = ceil(n_core_total / n_core_per_node)

        request_ram_factor = request_ram * n_core_factor
        request_ram_bytes = request_ram * 2**20
        request_ram_bytes_factor = request_ram_bytes * n_core_factor
        request_ram_per_core = ceil(request_ram / n_core_total)
        request_ram_bytes_per_core = ceil(request_ram_bytes / n_core_total)
        request_cputime = request_walltime * n_core_total
        request_walltime_minute = ceil(request_walltime / 60)
        request_cputime_minute = ceil(request_cputime / 60)

        # GTAG: URL of the worker's stdout log, if log URLs are configured
        log_urls = self._get_log_urls(workspec)
        gtag = log_urls[0] if log_urls is not None else "unknown"

        placeholder_map = {
            "nCorePerNode": n_core_per_node,
            "nCoreTotal": n_core_total_factor,
            "nCoreFactor": n_core_factor,
            "nNode": n_node,
            "requestRam": request_ram_factor,
            "requestRamBytes": request_ram_bytes_factor,
            "requestRamPerCore": request_ram_per_core,
            "requestRamBytesPerCore": request_ram_bytes_per_core,
            "requestDisk": request_disk,
            "requestWalltime": request_walltime,
            "requestWalltimeMinute": request_walltime_minute,
            "requestCputime": request_cputime,
            "requestCputimeMinute": request_cputime_minute,
            "accessPoint": workspec.accessPoint,
            "harvesterID": harvester_config.master.harvester_id,
            "workerID": workspec.workerID,
            "computingSite": workspec.computingSite,
            "pandaQueueName": panda_queue_name,
            "localQueueName": self.localQueueName,
            "logDir": self.logDir,
            "logSubDir": os.path.join(self.logDir, timeNow.strftime("%y-%m-%d_%H")),
            "jobType": workspec.jobType,
            "gtag": gtag,
        }
        # optional settings are forwarded only when configured on the plugin
        for key in ("tokenDir", "tokenName", "tokenOrigin", "submitMode"):
            if hasattr(self, key):
                placeholder_map[key] = getattr(self, key)
        return placeholder_map

    # make batch script
    def make_batch_script(self, workspec, logger):
        """Render the batch script for *workspec* into its access point and return the file path.

        The template is filled via ``str.format_map`` with ``core_utils.SafeDict``. It is
        written as Latin-1, so characters outside Latin-1 raise ``UnicodeEncodeError``.
        """
        # template for batch script
        with open(self.templateFile) as f:
            template = f.read()
        placeholder = self.make_placeholder_map(workspec, logger)
        # render before creating the file so a rendering error leaves no open handle or empty script behind
        content = str(template.format_map(core_utils.SafeDict(placeholder))).encode("latin_1")
        with tempfile.NamedTemporaryFile(delete=False, suffix="_submit.sh", dir=workspec.get_access_point()) as tmpFile:
            tmpFile.write(content)
            script_path = tmpFile.name

        # set execution bit and group permissions on the temp file
        st = os.stat(script_path)
        os.chmod(script_path, st.st_mode | stat.S_IEXEC | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH)

        return script_path

    # get log file names (legacy: parse "#SBATCH -o/-e" lines from a batch script)
    def get_log_file_names_old(self, batch_script, batch_id):
        """Return ``(stdout, stderr)`` paths declared via ``#SBATCH -o`` / ``#SBATCH -e``.

        The last token of each matching line is taken as the path, with ``$SLURM_JOB_ID``
        replaced by *batch_id*. Long-form options (``--output=...``) are not recognized.
        Either element is ``None`` when the corresponding line is absent.
        """
        stdOut = None
        stdErr = None
        job_id = str(batch_id)
        with open(batch_script) as f:
            for line in f:
                if not line.startswith("#SBATCH"):
                    continue
                items = line.split()
                if "-o" in items:
                    stdOut = items[-1].replace("$SLURM_JOB_ID", job_id)
                elif "-e" in items:
                    stdErr = items[-1].replace("$SLURM_JOB_ID", job_id)
        return stdOut, stdErr

    def get_log_file_names(self, access_point, worker_id):
        """Return the ``(stdout, stderr)`` paths ``<access_point>/<worker_id>.out`` and ``.err``.

        NOTE(review): assumes the batch template directs SLURM output to these paths — confirm.
        """
        stdOut = os.path.join(access_point, f"{worker_id}.out")
        stdErr = os.path.join(access_point, f"{worker_id}.err")
        return stdOut, stdErr