File indexing completed on 2026-04-09 07:58:18
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 import logging
0012 import math
0013 import os
0014 import traceback
0015 import threading
0016 import uuid
0017
0018 from idds.common import exceptions
0019 from idds.common.constants import Sections
0020 from idds.common.constants import (MessageType, MessageTypeStr,
0021 MessageStatus, MessageSource)
0022 from idds.common.plugin.plugin_base import PluginBase
0023 from idds.common.plugin.plugin_utils import load_plugins, load_plugin_sequence
0024 from idds.common.utils import get_process_thread_info
0025 from idds.common.utils import setup_logging, pid_exists, json_dumps, json_loads
0026 from idds.core import health as core_health, messages as core_messages, requests as core_requests
0027 from idds.agents.common.timerscheduler import TimerScheduler
0028 from idds.agents.common.eventbus.eventbus import EventBus
0029 from idds.agents.common.cache.redis import get_redis_cache
0030
0031
# Module-level side effect: initialise logging through the shared iDDS helper,
# passing this module's name. NOTE(review): the exact handlers/levels it
# configures are defined in idds.common.utils - confirm there.
setup_logging(__name__)
0033
0034
class PrefixFilter(logging.Filter):
    """
    Logging filter that stamps a fixed ``prefix`` attribute on each record.

    The filter never rejects a record; it only makes ``%(prefix)s`` available
    to formatters attached to the same handler/logger.
    """

    def __init__(self, prefix):
        """
        :param prefix: value stored on every record as ``record.prefix``.
        """
        super(PrefixFilter, self).__init__()
        self.prefix = prefix

    def filter(self, record):
        """
        Attach the configured prefix to ``record`` and let it pass.

        :param record: the ``logging.LogRecord`` being processed.
        :returns: always ``True`` (the record is never dropped).
        """
        setattr(record, "prefix", self.prefix)
        return True
0046
0047
class BaseAgentWorker(object):
    """
    Base class for agent workers.

    Provides the class-name lookup and a per-process logger factory used by
    concrete worker classes.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for signature compatibility only;
        # nothing is read from them here.
        super(BaseAgentWorker, self).__init__()

    def get_class_name(self):
        """Return the name of the concrete (possibly derived) class."""
        return type(self).__name__

    def get_logger(self, log_prefix=None):
        """
        Return a logger named ``<class name>-<pid>``, configuring it on first use.

        On the first call for a given name a stream handler is attached whose
        format includes ``%(prefix)s``; the prefix value is supplied by a
        ``PrefixFilter`` on that handler. When the logger already has handlers
        it is returned untouched, so ``log_prefix`` only takes effect on the
        call that creates the handler.

        :param log_prefix: text for the ``%(prefix)s`` field; falls back to
                           the class name when empty or ``None``.
        :returns: the ``logging.Logger`` instance.
        """
        cls_name = self.get_class_name()
        worker_logger = logging.getLogger("%s-%s" % (cls_name, os.getpid()))

        # Already configured (by an earlier call): hand it back as-is.
        if worker_logger.handlers:
            return worker_logger

        prefix = log_prefix if log_prefix else cls_name

        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter(
            fmt="%(asctime)s [%(process)d] [%(levelname)s] %(prefix)s %(name)s: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        ))
        # The filter injects record.prefix, which the format string above needs.
        stream_handler.addFilter(PrefixFilter(prefix))

        worker_logger.addHandler(stream_handler)
        worker_logger.setLevel(logging.INFO)
        return worker_logger
0085
0086
class BaseAgent(TimerScheduler, PluginBase):
    """
    The base IDDS agent class.

    On top of its two parents (``TimerScheduler`` and ``PluginBase``) this
    class wires together:

    * plugin loading for the agent's configuration section,
    * an ``EventBus`` plus a map from event type to handler function,
    * health heartbeats and health messages written through ``idds.core``.
    """

    # Class-level state shared by all agent instances in the process. Within
    # this class only ``min_request_id`` is written (see load_min_request_id);
    # the remaining counters/cache are presumably maintained by subclasses.
    min_request_id = None
    min_request_id_cache = {}
    checking_min_request_id_times = 0
    poll_new_min_request_id_times = 0
    poll_running_min_request_id_times = 0

    def __init__(self, num_threads=1, name="BaseAgent", logger=None, use_process_pool=False, **kwargs):
        """
        Initialise the agent.

        :param num_threads: forwarded to the parent constructor.
        :param name: forwarded to the parent constructor; ``self.name`` is
                     afterwards set to the concrete class name regardless.
        :param logger: logger handed to ``setup_logger``.
        :param use_process_pool: forwarded to the parent constructor.
        :param kwargs: every entry becomes an instance attribute (this can
                       override the attributes set before it); dotted keys are
                       additionally collected into ``self.agent_attributes``.
        """
        super(BaseAgent, self).__init__(num_threads, name=name, use_process_pool=use_process_pool)
        self.name = self.__class__.__name__
        self.id = str(uuid.uuid4())[:8]
        self.logger = logger
        self.setup_logger(self.logger)

        self.thread_id = None
        self.thread_name = None

        self.config_section = Sections.Common

        for key in kwargs:
            setattr(self, key, kwargs[key])

        # NOTE(review): heartbeat_delay / health_message_delay are used as-is
        # when supplied through kwargs (no numeric conversion) - confirm that
        # callers pass numbers, since heartbeat_delay is later multiplied by 3.
        if not hasattr(self, 'heartbeat_delay'):
            self.heartbeat_delay = 60

        if not hasattr(self, 'health_message_delay'):
            self.health_message_delay = 600

        if not hasattr(self, 'poll_operation_time_period'):
            self.poll_operation_time_period = 120
        else:
            self.poll_operation_time_period = int(self.poll_operation_time_period)

        if not hasattr(self, 'event_interval_delay'):
            self.event_interval_delay = 0.0001
        else:
            # The default is fractional, so a configured value must be parsed
            # as float: int() would truncate e.g. 0.5 to 0 and would raise on
            # a string such as "0.0001".
            self.event_interval_delay = float(self.event_interval_delay)

        if not hasattr(self, 'max_worker_exec_time'):
            self.max_worker_exec_time = 3600
        else:
            self.max_worker_exec_time = int(self.max_worker_exec_time)
        self.num_hang_workers, self.num_active_workers = 0, 0

        self.plugins = {}
        self.plugin_sequence = []

        self.agent_attributes = self.load_agent_attributes(kwargs)

        self.logger.info("agent_attributes: %s" % self.agent_attributes)

        self.event_bus = EventBus()
        self.event_func_map = {}
        self.event_futures = {}

        self.cache = get_redis_cache()

    def set_max_workers(self):
        """
        Reset the worker counter and normalise ``max_number_workers``.

        ``max_number_workers`` defaults to 3 when missing or falsy, otherwise
        it is converted to ``int``.
        """
        self.number_workers = 0
        if not hasattr(self, 'max_number_workers') or not self.max_number_workers:
            self.max_number_workers = 3
        else:
            self.max_number_workers = int(self.max_number_workers)

    def get_event_bus(self):
        """Return the agent's ``EventBus`` instance."""
        # The original evaluated the attribute without returning it, so every
        # caller received None.
        return self.event_bus

    def get_name(self):
        """Return the agent name (the concrete class name)."""
        return self.name

    def init_thread_info(self):
        """Record the ident and name of the thread that calls this method."""
        current = threading.current_thread()
        self.thread_id = current.ident
        self.thread_name = current.name

    def get_thread_id(self):
        """Return the thread ident recorded by ``init_thread_info`` (or None)."""
        return self.thread_id

    def get_thread_name(self):
        """Return the thread name recorded by ``init_thread_info`` (or None)."""
        return self.thread_name

    def load_agent_attributes(self, kwargs):
        """
        Build a nested dict from the dotted keys of ``kwargs``.

        A key such as ``"a.b.c"`` ends up as ``rets["a"]["b"]["c"]``; keys
        without a dot are ignored.

        :param kwargs: mapping of option name to value.
        :returns: the nested dictionary.
        """
        rets = {}
        for key in kwargs:
            if '.' not in key:
                continue
            *parents, leaf = key.split('.')

            node = rets
            for item in parents:
                node = node.setdefault(item, {})
            node[leaf] = kwargs[key]
        return rets

    def load_plugin_sequence(self):
        """Load the plugin sequence configured for ``self.config_section``."""
        self.plugin_sequence = load_plugin_sequence(self.config_section)

    def load_plugins(self):
        """Load the plugins configured for ``self.config_section`` and log them."""
        self.plugins = load_plugins(self.config_section, logger=self.logger)
        self.logger.info("plugins: %s" % str(self.plugins))
        # NOTE: a cross-check between ``plugin_sequence`` and the loaded
        # plugins existed here only as disabled code; it is not enforced.

    def get_plugin(self, plugin_name):
        """
        Return the loaded plugin registered under ``plugin_name``.

        :raises exceptions.AgentPluginError: when the name is unknown or maps
                                             to a falsy entry.
        """
        if plugin_name in self.plugins and self.plugins[plugin_name]:
            return self.plugins[plugin_name]
        raise exceptions.AgentPluginError("No corresponding plugin configured for %s" % plugin_name)

    def load_min_request_id(self):
        """
        Load the minimal request id and store it on the class.

        Falls back to 1 when the lookup fails. The value is written to
        ``BaseAgent.min_request_id`` and is therefore shared by all agents.
        """
        try:
            min_request_id = core_requests.get_min_request_id()
            self.logger.info(f"loaded min_request_id: {min_request_id}")
        except Exception as ex:
            self.logger.error(f"failed to load min_request_id: {ex}")
            min_request_id = 1
            # NOTE(review): logged on the fallback path only - confirm this
            # matches the intended nesting of the original statement.
            self.logger.info(f"Set min_request_id to : {min_request_id}")
        BaseAgent.min_request_id = min_request_id

    def get_num_hang_active_workers(self):
        """Return the tuple ``(num_hang_workers, num_active_workers)``."""
        return self.num_hang_workers, self.num_active_workers

    def get_event_bulk_size(self):
        """Return half of the currently free workers, rounded up."""
        return math.ceil(self.get_num_free_workers() / 2)

    def init_event_function_map(self):
        """Reset the event-type to handler map; subclasses are expected to fill it."""
        self.event_func_map = {}

    def get_event_function_map(self):
        """Return the event-type to handler map."""
        return self.event_func_map

    def execute_event_schedule(self):
        """
        Fetch pending events for every registered event type and submit them.

        For each event type the bulk size is recomputed, because submitting
        events changes the number of free workers. Nothing is fetched for a
        type when the bulk size is not positive.
        """
        event_funcs = self.get_event_function_map()
        for event_type in event_funcs:
            exec_func = event_funcs[event_type]['exec_func']
            bulk_size = self.get_event_bulk_size()
            if bulk_size > 0:
                events = self.event_bus.get(event_type, num_events=bulk_size, wait=2, callback=None)
                for event in events:
                    self.submit(exec_func, event)

    def execute_schedules(self):
        """Run one pass of the timer schedule followed by one pass of the event schedule."""
        self.execute_timer_schedule_thread()
        self.execute_event_schedule()

    def execute(self):
        """
        Main loop: run both schedules until ``graceful_stop`` is set.

        Any exception raised inside one iteration is logged at critical level
        together with its traceback, and the loop carries on.
        """
        while not self.graceful_stop.is_set():
            try:
                self.execute_timer_schedule_thread()
                self.execute_event_schedule()
                # Very short wait so the loop yields and reacts to the stop flag.
                self.graceful_stop.wait(0.00001)
            except Exception as error:
                self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc()))

    def run(self):
        """
        Main run function.

        Records the thread info, loads the plugins and enters ``execute``.
        A ``KeyboardInterrupt`` stops the agent.
        """
        try:
            self.logger.info("Starting main thread")

            self.init_thread_info()
            self.load_plugins()

            self.execute()
        except KeyboardInterrupt:
            self.stop()

    def __call__(self):
        """Calling the agent is equivalent to calling ``run``."""
        self.run()

    def stop(self):
        """Stop the parent scheduler, then stop the event bus (best effort)."""
        super(BaseAgent, self).stop()
        try:
            self.event_bus.stop()
        except Exception as ex:
            # Stopping the bus stays best-effort (a failure must not prevent
            # shutdown), but the reason is no longer silently discarded.
            if self.logger:
                self.logger.warning("Failed to stop event bus: %s" % str(ex))

    def terminate(self):
        """Alias of ``stop``."""
        self.stop()

    def get_hostname(self):
        """Return the hostname reported by ``get_process_thread_info``."""
        hostname, _, _, _ = get_process_thread_info()
        return hostname

    def get_process_thread_info(self):
        """Return ``(hostname, pid, thread_id, thread_name)`` of the caller."""
        # Inside the method body the bare name resolves to the module-level
        # helper imported from idds.common.utils, not to this method.
        hostname, pid, thread_id, thread_name = get_process_thread_info()
        return hostname, pid, thread_id, thread_name

    def is_self(self, health_item):
        """
        Tell whether ``health_item`` was written by this agent in this process.

        The item must contain ``hostname``, ``pid``, ``agent`` and
        ``thread_id``; hostname, pid and agent name must match. The
        ``thread_id`` value itself is required to be present but is not
        compared.

        :param health_item: mapping describing one health record.
        :returns: ``True`` on a match, ``False`` otherwise.
        """
        hostname, pid, _, _ = get_process_thread_info()
        required_keys = ('hostname', 'pid', 'agent', 'thread_id')
        if not all(key in health_item for key in required_keys):
            return False
        return (health_item['hostname'] == hostname
                and health_item['pid'] == pid
                and health_item['agent'] == self.get_name())

    def get_health_payload(self):
        """Return the worker counters as the payload dict of a health record."""
        num_hang_workers, num_active_workers = self.get_num_hang_active_workers()
        return {'num_hang_workers': num_hang_workers, 'num_active_workers': num_active_workers}

    def is_ready(self):
        """Return whether the agent may publish heartbeats; always True here."""
        return True

    def _clean_dead_pid_health_items(self, hostname):
        """Delete health records of ``hostname`` whose process no longer exists."""
        health_items = core_health.retrieve_health_items()
        # dict.fromkeys removes duplicates while keeping first-seen order.
        local_pids = list(dict.fromkeys(
            item['pid'] for item in health_items if item['hostname'] == hostname))
        dead_pids = [local_pid for local_pid in local_pids if not pid_exists(local_pid)]
        if dead_pids:
            core_health.clean_health(hostname=hostname, pids=dead_pids, older_than=None)

    def health_heartbeat(self, heartbeat_delay=None):
        """
        Publish this agent's health record and prune stale ones.

        When the agent is ready: add/refresh the record for this
        host/pid/thread, drop records older than three heartbeat periods and
        drop records of local processes that no longer exist.

        :param heartbeat_delay: when truthy, replaces ``self.heartbeat_delay``.
        """
        if heartbeat_delay:
            self.heartbeat_delay = heartbeat_delay
        hostname, pid, thread_id, thread_name = get_process_thread_info()
        payload = self.get_health_payload()
        if payload:
            payload = json_dumps(payload)
        # NOTE(review): the whole publish-and-clean sequence is kept under the
        # readiness check - confirm against the original nesting.
        if self.is_ready():
            self.logger.debug("health heartbeat: agent %s, pid %s, thread %s, delay %s, payload %s" % (self.get_name(), pid, thread_name, self.heartbeat_delay, payload))
            core_health.add_health_item(agent=self.get_name(), hostname=hostname, pid=pid,
                                        thread_id=thread_id, thread_name=thread_name, payload=payload)
            core_health.clean_health(older_than=self.heartbeat_delay * 3)
            self._clean_dead_pid_health_items(hostname)

    def get_health_items(self):
        """
        Return the current health records after pruning stale ones.

        :returns: the records from ``core_health.retrieve_health_items()``,
                  or an empty list when anything fails (the failure is logged).
        """
        try:
            hostname, _, _, _ = get_process_thread_info()
            core_health.clean_health(older_than=self.heartbeat_delay * 3)
            self._clean_dead_pid_health_items(hostname)

            # Read again so that the records removed above are not returned.
            return core_health.retrieve_health_items()
        except Exception as ex:
            self.logger.warning("Failed to get health items: %s" % str(ex))

        return []

    def get_availability(self):
        """
        Summarise worker counters per agent for the local host.

        :returns: ``{agent: {'num_hang_workers': n, 'num_active_workers': m}}``
                  built from the health records of this hostname; counters
                  default to 0 when a record has no payload. An empty dict is
                  returned when anything fails (the failure is logged).
        """
        try:
            availability = {}
            health_items = self.get_health_items()
            hostname, _, _, _ = get_process_thread_info()
            for item in health_items:
                if item['hostname'] != hostname:
                    continue

                agent_availability = availability.setdefault(item['agent'], {})
                payload = item['payload']
                num_hang_workers = 0
                num_active_workers = 0
                if payload:
                    payload = json_loads(payload)
                    num_hang_workers = payload.get('num_hang_workers', 0)
                    num_active_workers = payload.get('num_active_workers', 0)

                # A later record of the same agent overwrites an earlier one.
                agent_availability['num_hang_workers'] = num_hang_workers
                agent_availability['num_active_workers'] = num_active_workers

            return availability
        except Exception as ex:
            self.logger.warning("Failed to get availability: %s" % str(ex))
        return {}

    def add_default_tasks(self):
        """Schedule ``health_heartbeat`` with ``heartbeat_delay`` as its delay time."""
        task = self.create_task(task_func=self.health_heartbeat, task_output_queue=None,
                                task_args=tuple(), task_kwargs={}, delay_time=self.heartbeat_delay,
                                priority=1)
        self.add_task(task)

    def generate_health_messages(self):
        """
        Store one HealthHeartbeat message that lists all current health records.

        Records older than three heartbeat periods are removed first.
        """
        core_health.clean_health(older_than=self.heartbeat_delay * 3)
        items = core_health.retrieve_health_items()
        msg_content = {'msg_type': MessageTypeStr.HealthHeartbeat.value,
                       'agents': items}

        message = {'msg_type': MessageType.HealthHeartbeat,
                   'status': MessageStatus.New,
                   'source': MessageSource.Conductor,
                   'request_id': None,
                   'workload_id': None,
                   'transform_id': None,
                   'num_contents': len(items),
                   'msg_content': msg_content}
        # The dict keys are exactly the keyword names add_message is called with.
        core_messages.add_message(**message)

    def add_health_message_task(self):
        """Schedule ``generate_health_messages`` with ``health_message_delay`` as its delay time."""
        task = self.create_task(task_func=self.generate_health_messages, task_output_queue=None,
                                task_args=tuple(), task_kwargs={}, delay_time=self.health_message_delay,
                                priority=1)
        self.add_task(task)

    def get_request_message(self, request_id, bulk_size=1):
        """Return up to ``bulk_size`` messages retrieved for ``request_id``."""
        return core_messages.retrieve_request_messages(request_id, bulk_size=bulk_size)

    def get_transform_message(self, request_id, transform_id, bulk_size=1):
        """Return up to ``bulk_size`` messages retrieved for the given transform."""
        return core_messages.retrieve_transform_messages(request_id=request_id, transform_id=transform_id, bulk_size=bulk_size)

    def get_processing_message(self, request_id, processing_id, bulk_size=1):
        """Return up to ``bulk_size`` messages retrieved for the given processing."""
        return core_messages.retrieve_processing_messages(request_id=request_id, processing_id=processing_id, bulk_size=bulk_size)
0427
0428
if __name__ == '__main__':
    # Script entry point: build an agent with default settings and invoke it;
    # calling the instance goes through BaseAgent.__call__.
    main_agent = BaseAgent()
    main_agent()