# File indexing completed on 2026-04-19 08:00:04
0001 import os
0002 import re
0003 import stat
0004 import subprocess
0005 import time
0006 import yaml
0007 from typing import Optional
0008 from math import ceil
0009
0010 from globus_compute_sdk import Executor, Client
0011 from globus_compute_sdk.sdk.shell_function import ShellFunction, ShellResult
0012
0013 from pandaharvester.harvesterconfig import harvester_config
0014 from pandaharvester.harvestercore import core_utils
0015 from pandaharvester.harvestercore.plugin_base import PluginBase
0016 from pandaharvester.harvestercore.plugin_factory import PluginFactory
0017 from pandaharvester.harvestercore.queue_config_mapper import QueueConfigMapper
0018
0019
# module-level logger shared by all instances of this submitter plugin
baseLogger = core_utils.setup_logger("globus_compute_slurm_submitter")
0021
0022
class GlobusComputeSlurmSubmitter(PluginBase):
    """Harvester submitter that launches workers on Slurm via Globus Compute.

    For each worker, a shell script is rendered from a template file,
    registered as a Globus Compute ``ShellFunction`` and submitted in a batch
    to the (multi-user) endpoint identified by ``mep_id``.  Site-specific
    Slurm settings (account, partition, walltime, ...) are read from a YAML
    config file given as ``config_file``.
    """

    def __init__(self, **kwargs):
        # log upload is not handled by this plugin
        self.uploadLog = False
        self.logBaseURL = None
        PluginBase.__init__(self, **kwargs)
        # local batch-queue name exposed to the rendered template
        if not hasattr(self, "localQueueName"):
            self.localQueueName = "grid"
        # nCoreFactor is either a scalar multiplier or a nested mapping of
        # {jobType: {resourceType: factor}}; anything invalid falls back to 1
        try:
            if hasattr(self, "nCoreFactor"):
                if isinstance(self.nCoreFactor, dict):
                    # per-jobType/resourceType mapping, resolved per worker
                    # in get_core_factor
                    pass
                else:
                    self.nCoreFactor = int(self.nCoreFactor)
                    if (not self.nCoreFactor) or (self.nCoreFactor < 1):
                        self.nCoreFactor = 1
            else:
                self.nCoreFactor = 1
        except (AttributeError, TypeError, ValueError):
            # nCoreFactor present but not convertible to int: fall back to 1
            # (the original code only caught AttributeError and would crash)
            self.nCoreFactor = 1

        self.mep_id = kwargs.get("mep_id")  # Globus Compute endpoint id
        self.template_file = kwargs.get("templateFile")
        self.gc_client = Client()
        self.config_file = kwargs.get("config_file")
        self.slurm_log_dir = kwargs.get("slurm_log_dir")

        # site configuration: account, partition, walltime, scheduler options, ...
        with open(self.config_file, "r") as cf:
            self.config = yaml.safe_load(cf)

        # shell-script template; rendered with str.format in render_template
        with open(self.template_file, "r") as tf:
            self.template_init = tf.read()

    def get_core_factor(self, workspec, logger):
        """Return the integer core multiplier for *workspec*.

        When ``nCoreFactor`` is a mapping, it is looked up by the worker's
        jobType and resourceType; otherwise the scalar value is used.
        Falls back to 1 on any error.
        """
        try:
            if isinstance(self.nCoreFactor, dict):
                n_core_factor = self.nCoreFactor.get(workspec.jobType, {}).get(workspec.resourceType, 1)
                return int(n_core_factor)
            return int(self.nCoreFactor)
        except Exception as ex:
            logger.warning(f"Failed to get core factor: {ex}")
            return 1

    def render_template(self, workspec, logger):
        """Render the per-worker submit script and endpoint configuration.

        Returns a tuple ``(task_content, user_endpoint_config)`` where
        ``task_content`` is the formatted shell script and
        ``user_endpoint_config`` is the dict of Slurm options passed to the
        Globus Compute batch.

        Raises RuntimeError when ATHENA_COMPUTE_POLICY in the config is not
        a supported value.
        """
        # NOTE(review): queue information is not actually fetched here, so the
        # corecount lookup below always yields the default of 1 — TODO confirm
        # whether a real panda-queue query was intended.
        this_panda_queue_dict = dict()
        n_core_per_node_from_queue = this_panda_queue_dict.get("corecount", 1) if this_panda_queue_dict.get("corecount", 1) else 1
        try:
            n_core_per_node = self.nCorePerNode if self.nCorePerNode else n_core_per_node_from_queue
        except AttributeError:
            n_core_per_node = n_core_per_node_from_queue
        if not n_core_per_node:
            n_core_per_node = self.nCore

        # resource requests derived from the work spec (RAM in MB, disk in MB,
        # walltime in seconds — presumably; verify against the queue config)
        n_core_factor = self.get_core_factor(workspec, logger)
        n_core_total = workspec.nCore if workspec.nCore else n_core_per_node
        n_core_total_factor = n_core_total * n_core_factor
        request_ram = max(workspec.minRamCount, 1 * n_core_total) if workspec.minRamCount else 1 * n_core_total
        request_disk = workspec.maxDiskCount * 1024 if workspec.maxDiskCount else 1
        request_walltime = workspec.maxWalltime if workspec.maxWalltime else 0

        n_node = ceil(n_core_total / n_core_per_node)
        request_ram_factor = request_ram * n_core_factor
        request_ram_bytes = request_ram * 2**20
        request_ram_bytes_factor = request_ram * 2**20 * n_core_factor
        request_ram_per_core = ceil(request_ram * n_node / n_core_total)
        request_ram_bytes_per_core = ceil(request_ram_bytes * n_node / n_core_total)
        request_cputime = request_walltime * n_core_total
        request_walltime_minute = ceil(request_walltime / 60)
        request_cputime_minute = ceil(request_cputime / 60)

        # prefer a dedicated remote work directory when configured, otherwise
        # reuse the worker's local access point
        if getattr(self, "remote_workdir", None):
            remote_accessPoint = os.path.join(self.remote_workdir, self.queueName, str(workspec.workerID))
            logger.debug(f"In render_template, enable remote_accessPoint, set it to be {remote_accessPoint}")
        else:
            remote_accessPoint = workspec.accessPoint
            logger.debug(f"In render_template, disable remote_accessPoint, set it to be {remote_accessPoint}")

        # placeholders substituted into the shell-script template
        variables = {
            'nCorePerNode': n_core_per_node,
            'nCoreTotal': n_core_total_factor,
            'nCoreFactor': n_core_factor,
            'nNode': n_node,
            'requestRam': request_ram_factor,
            'requestRamBytes': request_ram_bytes_factor,
            'requestRamPerCore': request_ram_per_core,
            'requestRamBytesPerCore': request_ram_bytes_per_core,
            'requestDisk': request_disk,
            'requestWalltime': request_walltime,
            'requestWalltimeMinute': request_walltime_minute,
            'requestCputime': request_cputime,
            'requestCputimeMinute': request_cputime_minute,
            'accessPoint': workspec.accessPoint,
            'remote_accessPoint': remote_accessPoint,
            'harvesterID': harvester_config.master.harvester_id,
            'workerID': workspec.workerID,
            'computingSite': workspec.computingSite,
            'pandaQueueName': self.queueName,
            'localQueueName': self.localQueueName,
            'jobType': workspec.jobType,
            'tokenOrigin': self.tokenOrigin,
            'tokenDir': self.tokenDir,
            'tokenName': self.tokenName,
            'harvester_dir': self.config.get('harvester_dir', '/global/common/software/m2616/harvester-perlmutter'),
            'harvester_tasks_per_node': self.config.get('tasks_per_node', 4),
        }

        ATHENA_COMPUTE_POLICY = self.config.get('ATHENA_COMPUTE_POLICY', 'normal')
        if ATHENA_COMPUTE_POLICY == 'normal':
            # split the node's cores evenly among the harvester tasks on it
            variables['ATHENA_nCorePerNode'] = self.config.get('ATHENA_nCorePerNode', 256)
            variables['ATHENA_PROC_NUMBER_JOB'] = variables['ATHENA_nCorePerNode'] // variables['harvester_tasks_per_node']
            variables['ATHENA_PROC_NUMBER'] = variables['ATHENA_PROC_NUMBER_JOB']
            variables['ATHENA_CORE_NUMBER'] = variables['ATHENA_PROC_NUMBER_JOB']
        else:
            # Raise instead of exit(1): SystemExit is not caught by the
            # `except Exception` in submit_workers and would kill the whole
            # harvester process instead of failing just this worker.
            raise RuntimeError(f"Unsupported ATHENA_COMPUTE_POLICY: {ATHENA_COMPUTE_POLICY}")

        task_content = self.template_init.format(**variables)
        user_endpoint_config = {
            "account": self.config["account"],
            "partition": self.config["partition"],
            "parallelism": n_node,
            "nodes_per_block": n_node,
            "walltime": self.config["walltime"],
            "scheduler_options": self.config["scheduler_options"],
            "worker_init_extra": self.config["worker_init_extra"],
            "max_blocks": self.config["max_blocks"],
        }
        return task_content, user_endpoint_config

    def submit_workers(self, workspec_list):
        """Submit one Globus Compute batch per work spec.

        Returns a list of ``(bool, message)`` tuples, one per input workspec,
        in the same order.
        """
        retList = []
        for workSpec in workspec_list:
            tmpLog = self.make_logger(baseLogger, f"workerID={workSpec.workerID}", method_name="submit_workers")
            tmpLog.debug(f"Step 1: Start submission with workSpec with detail of {workSpec.__dict__}")
            try:
                task_content, user_endpoint_config = self.render_template(workSpec, tmpLog)
                tmpLog.debug(f"Step 2: Rendering finished. \nGet task_content: \n{task_content}\nGet user_endpoint_config: \n{user_endpoint_config}")

                batch = self.gc_client.create_batch(user_endpoint_config=user_endpoint_config)

                # Register the rendered script as a shell function.  These are
                # locals (the original stored them on self), so concurrent or
                # repeated submissions cannot clobber each other's state.
                bf = ShellFunction(cmd=task_content, snippet_lines=2000, log_dir=self.slurm_log_dir)
                func_id = self.gc_client.register_function(bf)
                tmpLog.debug(f"Finish registration of function with func_id = {func_id}")

                batch.add(function_id=func_id)
                tmpLog.debug(f"Step 3: Finish creating batch, now going to submit batch with GC.")
                batch_res = self.gc_client.batch_run(batch=batch, endpoint_id=self.mep_id)
                tmpLog.debug(f"Step 4: Made a batch submission to GC. Got batch_res = \n{batch_res}")
                for func_id, each_task_list in batch_res['tasks'].items():
                    # no Slurm batch id exists at submission time; the GC task
                    # id is stored in the worker attributes for later monitoring
                    workSpec.batchID = None
                    globus_compute_attr_dict = {}
                    globus_compute_attr_dict["sandbox_dir"] = os.path.join(self.slurm_log_dir, each_task_list[0])
                    globus_compute_attr_dict["gc_task_id"] = each_task_list[0]
                    workSpec.set_work_attributes({"globus_compute_attr": globus_compute_attr_dict})
                    tmpLog.debug(f"Now setting: \nbatchID = {workSpec.batchID}, \nGC sandbox dir = {globus_compute_attr_dict['sandbox_dir']}")
                tmpRetVal = (True, "")
            except Exception as e:
                tmpLog.error(f"Error during submit workers: {e}")
                tmpRetVal = (False, str(e))
            retList.append(tmpRetVal)
        return retList
0188 return retList