Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:58:18

0001 #!/usr/bin/env python
0002 #
0003 # Licensed under the Apache License, Version 2.0 (the "License");
0004 # you may not use this file except in compliance with the License.
0005 # You may obtain a copy of the License at
0006 # http://www.apache.org/licenses/LICENSE-2.0
0007 #
0008 # Authors:
0009 # - Wen Guan, <wen.guan@cern.ch>, 2019 - 2025
0010 
0011 import logging
0012 import math
0013 import os
0014 import traceback
0015 import threading
0016 import uuid
0017 
0018 from idds.common import exceptions
0019 from idds.common.constants import Sections
0020 from idds.common.constants import (MessageType, MessageTypeStr,
0021                                    MessageStatus, MessageSource)
0022 from idds.common.plugin.plugin_base import PluginBase
0023 from idds.common.plugin.plugin_utils import load_plugins, load_plugin_sequence
0024 from idds.common.utils import get_process_thread_info
0025 from idds.common.utils import setup_logging, pid_exists, json_dumps, json_loads
0026 from idds.core import health as core_health, messages as core_messages, requests as core_requests
0027 from idds.agents.common.timerscheduler import TimerScheduler
0028 from idds.agents.common.eventbus.eventbus import EventBus
0029 from idds.agents.common.cache.redis import get_redis_cache
0030 
0031 
# Module-level logging setup via idds.common.utils.setup_logging; this runs
# at import time, before any class in this module is instantiated.
setup_logging(__name__)
0033 
0034 
class PrefixFilter(logging.Filter):
    """
    Logging filter that stamps a fixed ``prefix`` attribute onto every record.

    It never rejects a record; it only decorates it so that a formatter can
    reference ``%(prefix)s``.
    """

    def __init__(self, prefix):
        super(PrefixFilter, self).__init__()
        # Text copied onto each record passing through this filter.
        self.prefix = prefix

    def filter(self, record):
        # Decorate the record, then always let it through.
        setattr(record, 'prefix', self.prefix)
        return True
0046 
0047 
class BaseAgentWorker(object):
    """
    Base class for agent workers.

    Provides the concrete class name and a per-process logger factory.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted so subclasses can share this
        # signature, but this base class ignores them.
        super(BaseAgentWorker, self).__init__()

    def get_class_name(self):
        """Return the name of this worker's concrete class."""
        return self.__class__.__name__

    def get_logger(self, log_prefix=None):
        """
        Return a logger unique to this worker class and the current process.

        The logger is named ``"<ClassName>-<pid>"``. The first time such a
        logger is requested (i.e. while it has no handlers), it gets a stream
        handler whose format includes ``log_prefix`` (default: the class name),
        and its level is set to INFO. Subsequent calls return the same logger
        untouched, so a different ``log_prefix`` passed later has no effect.

        :param log_prefix: text exposed to the formatter as ``%(prefix)s``.
        :returns: the :class:`logging.Logger` for this class/process.
        """
        cls_name = self.get_class_name()
        logger = logging.getLogger("%s-%s" % (cls_name, os.getpid()))

        # Already configured in this process: handlers persist, reuse as-is.
        if logger.handlers:
            return logger

        prefix = log_prefix or cls_name
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter(
            fmt="%(asctime)s [%(process)d] [%(levelname)s] %(prefix)s %(name)s: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        ))
        # The filter injects ``prefix`` into each record for the format above.
        stream_handler.addFilter(PrefixFilter(prefix))
        logger.addHandler(stream_handler)
        logger.setLevel(logging.INFO)
        return logger
0085 
0086 
class BaseAgent(TimerScheduler, PluginBase):
    """
    The base iDDS agent class.

    Combines the inherited timer/worker scheduling with plugin loading, an
    EventBus-driven event loop, and health heartbeat reporting. Subclasses
    register timer tasks and populate ``self.event_func_map`` (event type ->
    ``{'exec_func': callable}``) to have events dispatched to workers.
    """

    # Class-level attributes. load_min_request_id() assigns
    # BaseAgent.min_request_id directly, so that value is shared by every
    # agent in the process; min_request_id_cache is one dict object shared by
    # all instances unless an instance rebinds it.
    min_request_id = None
    min_request_id_cache = {}
    checking_min_request_id_times = 0
    poll_new_min_request_id_times = 0
    poll_running_min_request_id_times = 0

    def __init__(self, num_threads=1, name="BaseAgent", logger=None, use_process_pool=False, **kwargs):
        """
        :param num_threads: worker count forwarded to the scheduler base class.
        :param name: name forwarded to the base class; ``self.name`` is
            afterwards reset to the concrete class name.
        :param logger: optional logger passed to the inherited ``setup_logger``.
        :param use_process_pool: forwarded to the scheduler base class.
        :param kwargs: every entry is set as an attribute on the agent; dotted
            keys (``"a.b.c"``) are additionally collected into the nested
            ``self.agent_attributes`` dict.
        """
        super(BaseAgent, self).__init__(num_threads, name=name, use_process_pool=use_process_pool)
        self.name = self.__class__.__name__
        self.id = str(uuid.uuid4())[:8]
        self.logger = logger
        self.setup_logger(self.logger)

        self.thread_id = None
        self.thread_name = None

        self.config_section = Sections.Common

        # Configuration values arrive as keyword arguments and become attributes.
        for key in kwargs:
            setattr(self, key, kwargs[key])

        if not hasattr(self, 'heartbeat_delay'):
            self.heartbeat_delay = 60

        if not hasattr(self, 'health_message_delay'):
            self.health_message_delay = 600

        if not hasattr(self, 'poll_operation_time_period'):
            self.poll_operation_time_period = 120
        else:
            self.poll_operation_time_period = int(self.poll_operation_time_period)

        # The default (0.0001) is fractional, so convert with float(): int()
        # would truncate a configured 0.001 to 0, and int("0.001") raises.
        if not hasattr(self, 'event_interval_delay'):
            self.event_interval_delay = 0.0001
        else:
            self.event_interval_delay = float(self.event_interval_delay)

        if not hasattr(self, 'max_worker_exec_time'):
            self.max_worker_exec_time = 3600
        else:
            self.max_worker_exec_time = int(self.max_worker_exec_time)
        self.num_hang_workers, self.num_active_workers = 0, 0

        self.plugins = {}
        self.plugin_sequence = []

        self.agent_attributes = self.load_agent_attributes(kwargs)

        self.logger.info("agent_attributes: %s" % self.agent_attributes)

        self.event_bus = EventBus()
        self.event_func_map = {}
        self.event_futures = {}

        self.cache = get_redis_cache()

    def set_max_workers(self):
        """Reset the worker counter and normalize ``max_number_workers`` (default 3)."""
        self.number_workers = 0
        if not hasattr(self, 'max_number_workers') or not self.max_number_workers:
            self.max_number_workers = 3
        else:
            self.max_number_workers = int(self.max_number_workers)

    def get_event_bus(self):
        """Return the agent's EventBus instance."""
        # The original body evaluated ``self.event_bus`` without returning it.
        return self.event_bus

    def get_name(self):
        """Return the agent name (the concrete class name)."""
        return self.name

    def init_thread_info(self):
        """Record the ident and name of the thread calling this method."""
        current = threading.current_thread()
        self.thread_id = current.ident
        self.thread_name = current.name

    def get_thread_id(self):
        return self.thread_id

    def get_thread_name(self):
        return self.thread_name

    def load_agent_attributes(self, kwargs):
        """
        Build a nested dict from dotted keyword-argument names.

        ``{"a.b.c": 1, "plain": 2}`` becomes ``{"a": {"b": {"c": 1}}}``; keys
        without a dot are ignored.
        """
        rets = {}
        for key, value in kwargs.items():
            if '.' not in key:
                continue
            *parents, leaf = key.split('.')
            node = rets
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
        return rets

    def load_plugin_sequence(self):
        """Load the plugin sequence configured for ``self.config_section``."""
        # Resolves to the module-level load_plugin_sequence helper, not this method.
        self.plugin_sequence = load_plugin_sequence(self.config_section)

    def load_plugins(self):
        """Load the plugins configured for ``self.config_section`` into ``self.plugins``."""
        # Resolves to the module-level load_plugins helper, not this method.
        self.plugins = load_plugins(self.config_section, logger=self.logger)
        self.logger.info("plugins: %s" % str(self.plugins))
        # NOTE: cross-checking self.plugins against self.plugin_sequence (and
        # raising AgentPluginError on a mismatch) is currently disabled.

    def get_plugin(self, plugin_name):
        """
        Return the configured plugin called ``plugin_name``.

        :raises exceptions.AgentPluginError: if it is missing or falsy.
        """
        if plugin_name in self.plugins and self.plugins[plugin_name]:
            return self.plugins[plugin_name]
        raise exceptions.AgentPluginError("No corresponding plugin configured for %s" % plugin_name)

    def load_min_request_id(self):
        """Query the minimum request id and store it on the class (falls back to 1 on error)."""
        try:
            min_request_id = core_requests.get_min_request_id()
            self.logger.info(f"loaded min_request_id: {min_request_id}")
        except Exception as ex:
            self.logger.error(f"failed to load min_request_id: {ex}")
            min_request_id = 1
        self.logger.info(f"Set min_request_id to : {min_request_id}")
        BaseAgent.min_request_id = min_request_id

    def get_num_hang_active_workers(self):
        return self.num_hang_workers, self.num_active_workers

    def get_event_bulk_size(self):
        """Return how many events to fetch per type: half the free workers, rounded up."""
        # get_num_free_workers() is inherited from a base class.
        return math.ceil(self.get_num_free_workers() / 2)

    def init_event_function_map(self):
        self.event_func_map = {}

    def get_event_function_map(self):
        return self.event_func_map

    def execute_event_schedule(self):
        """
        For each registered event type, pull up to ``get_event_bulk_size()``
        events from the event bus and submit each to that type's ``exec_func``.
        """
        for event_type, handler in self.get_event_function_map().items():
            exec_func = handler['exec_func']
            bulk_size = self.get_event_bulk_size()
            if bulk_size > 0:
                events = self.event_bus.get(event_type, num_events=bulk_size, wait=2, callback=None)
                for event in events:
                    self.submit(exec_func, event)

    def execute_schedules(self):
        """Run one pass of the timer schedule followed by the event schedule."""
        self.execute_timer_schedule_thread()
        self.execute_event_schedule()

    def execute(self):
        """
        Main loop: until ``graceful_stop`` is set, run the timer and event
        schedules. Exceptions are logged and the loop keeps going.
        """
        while not self.graceful_stop.is_set():
            try:
                self.execute_timer_schedule_thread()
                self.execute_event_schedule()
                self.graceful_stop.wait(0.00001)
            except Exception as error:
                self.logger.critical("Caught an exception: %s\n%s" % (str(error), traceback.format_exc()))

    def run(self):
        """
        Main run function: record thread info, load plugins, enter execute().
        A KeyboardInterrupt triggers stop().
        """
        try:
            self.logger.info("Starting main thread")

            self.init_thread_info()
            self.load_plugins()

            self.execute()
        except KeyboardInterrupt:
            self.stop()

    def __call__(self):
        self.run()

    def stop(self):
        """Stop the scheduler, then stop the event bus (best effort)."""
        super(BaseAgent, self).stop()
        try:
            self.event_bus.stop()
        except Exception as ex:
            # Best effort: shutting down must not fail because of the bus,
            # but leave a trace instead of swallowing the error silently.
            self.logger.warning("Failed to stop event bus: %s" % str(ex))

    def terminate(self):
        self.stop()

    def get_hostname(self):
        """Return the hostname reported by get_process_thread_info()."""
        return get_process_thread_info()[0]

    def get_process_thread_info(self):
        """Return ``(hostname, pid, thread_id, thread_name)`` for the caller."""
        # Resolves to the module-level helper imported from idds.common.utils.
        hostname, pid, thread_id, thread_name = get_process_thread_info()
        return hostname, pid, thread_id, thread_name

    def is_self(self, health_item):
        """
        Return True if ``health_item`` was produced by this agent in this process.

        The item must contain 'hostname', 'pid', 'agent' and 'thread_id'; the
        first three must match this process/agent. 'thread_id' must be present
        but is not compared.
        """
        hostname, pid, _, _ = get_process_thread_info()
        required_keys = ('hostname', 'pid', 'agent', 'thread_id')
        if not all(key in health_item for key in required_keys):
            return False
        return bool(health_item['hostname'] == hostname
                    and health_item['pid'] == pid
                    and health_item['agent'] == self.get_name())

    def get_health_payload(self):
        """Return the worker statistics published with each heartbeat."""
        num_hang_workers, num_active_workers = self.get_num_hang_active_workers()
        return {'num_hang_workers': num_hang_workers, 'num_active_workers': num_active_workers}

    def is_ready(self):
        return True

    def _clean_dead_process_health_items(self, hostname):
        """Remove health items on ``hostname`` whose recorded pid no longer exists."""
        pids = []
        for item in core_health.retrieve_health_items():
            if item['hostname'] == hostname and item['pid'] not in pids:
                pids.append(item['pid'])
        dead_pids = [p for p in pids if not pid_exists(p)]
        if dead_pids:
            core_health.clean_health(hostname=hostname, pids=dead_pids, older_than=None)

    def health_heartbeat(self, heartbeat_delay=None):
        """
        Publish a health item for this agent, then prune stale entries.

        Entries older than ``3 * heartbeat_delay`` are removed, as are entries
        on this host whose process no longer exists. Nothing is published when
        ``is_ready()`` is false.

        :param heartbeat_delay: if truthy, replaces ``self.heartbeat_delay``.
        """
        if heartbeat_delay:
            self.heartbeat_delay = heartbeat_delay
        hostname, pid, thread_id, thread_name = get_process_thread_info()
        payload = self.get_health_payload()
        if payload:
            payload = json_dumps(payload)
        if not self.is_ready():
            return

        self.logger.debug("health heartbeat: agent %s, pid %s, thread %s, delay %s, payload %s",
                          self.get_name(), pid, thread_name, self.heartbeat_delay, payload)
        core_health.add_health_item(agent=self.get_name(), hostname=hostname, pid=pid,
                                    thread_id=thread_id, thread_name=thread_name, payload=payload)
        core_health.clean_health(older_than=self.heartbeat_delay * 3)
        self._clean_dead_process_health_items(hostname)

    def get_health_items(self):
        """
        Prune stale health entries and return the remaining ones.

        Returns an empty list (after logging a warning) if anything fails.
        """
        try:
            hostname = get_process_thread_info()[0]
            core_health.clean_health(older_than=self.heartbeat_delay * 3)
            self._clean_dead_process_health_items(hostname)
            return core_health.retrieve_health_items()
        except Exception as ex:
            self.logger.warning("Failed to get health items: %s" % str(ex))
        return []

    def get_availability(self):
        """
        Summarize worker statistics per agent on this host.

        :returns: ``{agent: {'num_hang_workers': int, 'num_active_workers': int}}``
            built from the JSON payloads of health items on this host; an empty
            dict (after logging a warning) on failure.
        """
        try:
            availability = {}
            health_items = self.get_health_items()
            hostname = get_process_thread_info()[0]
            for item in health_items:
                if item['hostname'] != hostname:
                    continue
                agent_stats = availability.setdefault(item['agent'], {})
                num_hang_workers = 0
                num_active_workers = 0
                payload = item['payload']
                if payload:
                    payload = json_loads(payload)
                    num_hang_workers = payload.get('num_hang_workers', 0)
                    num_active_workers = payload.get('num_active_workers', 0)

                agent_stats['num_hang_workers'] = num_hang_workers
                agent_stats['num_active_workers'] = num_active_workers

            return availability
        except Exception as ex:
            self.logger.warning("Failed to get availability: %s" % str(ex))
        return {}

    def add_default_tasks(self):
        """Schedule the periodic health heartbeat every ``heartbeat_delay``."""
        task = self.create_task(task_func=self.health_heartbeat, task_output_queue=None,
                                task_args=tuple(), task_kwargs={}, delay_time=self.heartbeat_delay,
                                priority=1)
        self.add_task(task)

    def generate_health_messages(self):
        """Prune stale health items and queue a HealthHeartbeat message listing the rest."""
        core_health.clean_health(older_than=self.heartbeat_delay * 3)
        items = core_health.retrieve_health_items()
        msg_content = {'msg_type': MessageTypeStr.HealthHeartbeat.value,
                       'agents': items}
        core_messages.add_message(msg_type=MessageType.HealthHeartbeat,
                                  status=MessageStatus.New,
                                  source=MessageSource.Conductor,
                                  request_id=None,
                                  workload_id=None,
                                  transform_id=None,
                                  num_contents=len(items),
                                  msg_content=msg_content)

    def add_health_message_task(self):
        """Schedule generate_health_messages every ``health_message_delay``."""
        task = self.create_task(task_func=self.generate_health_messages, task_output_queue=None,
                                task_args=tuple(), task_kwargs={}, delay_time=self.health_message_delay,
                                priority=1)
        self.add_task(task)

    def get_request_message(self, request_id, bulk_size=1):
        return core_messages.retrieve_request_messages(request_id, bulk_size=bulk_size)

    def get_transform_message(self, request_id, transform_id, bulk_size=1):
        return core_messages.retrieve_transform_messages(request_id=request_id, transform_id=transform_id, bulk_size=bulk_size)

    def get_processing_message(self, request_id, processing_id, bulk_size=1):
        return core_messages.retrieve_processing_messages(request_id=request_id, processing_id=processing_id, bulk_size=bulk_size)
0428 
if __name__ == '__main__':
    # Standalone entry point: build a default agent and enter its main loop
    # (same as calling the instance, whose __call__ just invokes run()).
    standalone_agent = BaseAgent()
    standalone_agent.run()