Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:58:18

0001 #!/usr/bin/env python
0002 #
0003 # Licensed under the Apache License, Version 2.0 (the "License");
0004 # You may not use this file except in compliance with the License.
0005 # You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
0007 #
0008 # Authors:
0009 # - Wen Guan, <wen.guan@cern.ch>, 2019 - 2024
0010 
0011 import datetime
0012 import random
0013 import time
0014 import traceback
0015 try:
0016     # python 3
0017     from queue import Queue
0018 except ImportError:
0019     # Python 2
0020     from Queue import Queue
0021 
0022 from idds.common.constants import (Sections, MessageStatus, MessageDestination, MessageType,
0023                                    ProcessingStatus, ContentStatus, ContentRelationType)
0024 from idds.common.exceptions import AgentPluginError, IDDSException
0025 from idds.common.utils import setup_logging, get_logger
0026 from idds.core import (messages as core_messages,
0027                        catalog as core_catalog,
0028                        processings as core_processings,
0029                        health as core_health)
0030 from idds.agents.common.baseagent import BaseAgent
0031 
0032 
# Module-level side effect: initialise logging for this module through the
# project helper imported from idds.common.utils.
setup_logging(__name__)
0034 
0035 
class Conductor(BaseAgent):
    """
    Conductor works to notify workload management that the data is available.

    The agent retrieves new and retry-able messages through ``core_messages``,
    pushes the ones that still have to be delivered into ``message_queue`` for the
    ``notifier`` plugin, and updates status, retry counter and poll period of the
    messages that the notifier hands back through ``output_message_queue``.
    """

    def __init__(self, num_threads=1, retrieve_bulk_size=200, threshold_to_release_messages=None,
                 random_delay=None, delay=300, interval_delay=10, max_retry_delay=3600,
                 max_normal_retries=10, max_retries=30, replay_times=2, mode='multiple',
                 use_process_pool=False, retry_executor_threads=4, queue_throller=30, **kwargs):
        """
        Initialise configuration, the notifier queues and the retry executor.

        Numeric options are converted with ``int()`` so that values coming from a
        configuration file as strings are accepted.

        :param num_threads: Forwarded to ``BaseAgent``.
        :param retrieve_bulk_size: Maximum number of messages fetched per retrieval.
        :param threshold_to_release_messages: Stored as int (or None); not used in this class.
        :param random_delay: Stored as int with a lower bound of 5 (or None); not used in this class.
        :param delay: Base retry delay in seconds; None falls back to 60.
        :param interval_delay: Stored as int; a falsy value falls back to 10.
        :param max_retry_delay: Upper bound in seconds of the retry delay; falsy falls back to 3600.
        :param max_normal_retries: Retry count up to which the randomised delay is used.
        :param max_retries: Retry count at which a message is considered processed.
        :param replay_times: Number of deliveries for message types other than ProcessingFile.
        :param mode: 'single' enables the selection of one active conductor.
        :param use_process_pool: Forwarded to ``BaseAgent``.
        :param retry_executor_threads: Number of workers of the retry executor.
        :param queue_throller: Size of ``message_queue`` above which the main loop
                               stops fetching and only waits for the notifier.
        """
        super(Conductor, self).__init__(num_threads=num_threads, name='Conductor', use_process_pool=use_process_pool, **kwargs)
        self.config_section = Sections.Conductor
        self.retrieve_bulk_size = int(retrieve_bulk_size)

        # Request/response queues shared with the notifier plugin (see start_notifier).
        self.message_queue = Queue()
        self.output_message_queue = Queue()

        if threshold_to_release_messages is None:
            self.threshold_to_release_messages = None
        else:
            self.threshold_to_release_messages = int(threshold_to_release_messages)

        if random_delay is None:
            self.random_delay = None
        else:
            # Never go below 5.
            self.random_delay = max(int(random_delay), 5)

        if delay is None:
            delay = 60
        self.delay = int(delay)

        if not max_retry_delay:
            max_retry_delay = 3600
        self.max_retry_delay = int(max_retry_delay)

        self.max_normal_retries = int(max_normal_retries)
        self.max_retries = int(max_retries)

        if replay_times is None:
            replay_times = 2
        self.replay_times = int(replay_times)

        if not interval_delay:
            interval_delay = 10
        self.interval_delay = int(interval_delay)

        self.logger = get_logger(self.__class__.__name__)

        self.mode = mode
        # Last result of is_selected(); None until the first evaluation.
        self.selected = None
        # Agent chosen by monitor_conductor() in 'single' mode.
        self.selected_conductor = None

        self.queue_throller = int(queue_throller)
        self.retry_executor_threads = int(retry_executor_threads)
        self.retry_executor_name = self.executor_name + "_Retry"
        self.retry_executor = self.create_executors(self.retry_executor_name, max_workers=self.retry_executor_threads)

    def __del__(self):
        """Stop the notifier plugin when the agent object is destroyed."""
        self.stop_notifier()

    def is_selected(self):
        """
        Tell whether this conductor is the one that should deliver messages.

        Without a selected conductor (e.g. when ``monitor_conductor`` never stored
        one) every instance is selected; otherwise the decision is delegated to
        ``self.is_self``. A change of the result is logged.

        :returns: The current selection state.
        """
        if not self.selected_conductor:
            selected = True
        else:
            selected = self.is_self(self.selected_conductor)

        if self.selected is None or self.selected != selected:
            self.logger.info("is_selected changed from %s to %s" % (self.selected, selected))
        self.selected = selected
        return self.selected

    def monitor_conductor(self):
        """In 'single' mode, refresh ``selected_conductor`` from the health records."""
        if self.mode == "single":
            self.logger.info("Conductor single mode")
            self.selected_conductor = core_health.select_agent(name='Conductor', newer_than=self.heartbeat_delay * 2)
            self.logger.debug("Selected conductor: %s" % self.selected_conductor)

    def add_conductor_monitor_task(self):
        """Schedule ``monitor_conductor`` as a periodic task every ``heartbeat_delay``."""
        task = self.create_task(task_func=self.monitor_conductor, task_output_queue=None,
                                task_args=tuple(), task_kwargs={}, delay_time=self.heartbeat_delay,
                                priority=1)
        self.add_task(task)

    def get_new_messages(self):
        """
        Get messages in status New that are addressed to external destinations.

        :returns: The retrieved messages; an empty list while
                  ``BaseAgent.min_request_id`` has not been loaded yet.
        """
        if BaseAgent.min_request_id is None:
            return []

        destination = [MessageDestination.Outside, MessageDestination.ContentExt, MessageDestination.AsyncResult]
        messages = core_messages.retrieve_messages(status=MessageStatus.New,
                                                   min_request_id=BaseAgent.min_request_id,
                                                   delay=60,
                                                   record_fetched=True,
                                                   bulk_size=self.retrieve_bulk_size,
                                                   destination=destination)

        if messages:
            self.logger.info("Main thread get %s new messages" % len(messages))

        return messages

    def get_retry_messages(self):
        """
        Get messages in status Delivered or Fetched that may have to be sent again.

        :returns: A new list with the retrieved messages (empty when there are none).
        """
        destination = [MessageDestination.Outside, MessageDestination.ContentExt, MessageDestination.AsyncResult]
        messages_d = core_messages.retrieve_messages(status=[MessageStatus.Delivered, MessageStatus.Fetched],
                                                     min_request_id=BaseAgent.min_request_id,
                                                     use_poll_period=True,
                                                     delay=120,
                                                     record_fetched=True,
                                                     record_fetched_status=MessageStatus.Delivered,
                                                     bulk_size=self.retrieve_bulk_size,
                                                     destination=destination)
        if not messages_d:
            return []

        self.logger.info("Main thread get %s retries messages" % len(messages_d))
        return list(messages_d)

    def clean_messages(self, msgs, confirm=False):
        """
        Record a delivery attempt for the given messages.

        For every message the retry counter is incremented and a new poll period is
        set: below ``max_normal_retries`` it is ``delay`` times a random factor in
        ``[1, retries + 1]`` capped at ``max_retry_delay``; from there on it is
        ``max_retry_delay``.

        :param msgs: Messages (dicts with 'msg_id', 'request_id' and 'retries').
        :param confirm: When True the status becomes ConfirmDelivered instead of Delivered.
        """
        msg_status = MessageStatus.ConfirmDelivered if confirm else MessageStatus.Delivered

        to_updates = []
        for msg in msgs:
            retries = msg['retries']
            if retries < self.max_normal_retries:
                delay = min(int(self.delay) * random.randint(1, retries + 1), self.max_retry_delay)
            else:
                delay = self.max_retry_delay
            to_updates.append({'msg_id': msg['msg_id'],
                               'request_id': msg['request_id'],
                               'retries': retries + 1,
                               'poll_period': datetime.timedelta(seconds=delay),
                               'status': msg_status})

        if not to_updates:
            # The main loop calls this every second, mostly with nothing to do;
            # do not issue an update call for an empty batch.
            return
        core_messages.update_messages(to_updates, min_request_id=BaseAgent.min_request_id)

    def start_notifier(self):
        """
        Wire the queues and the logger into the 'notifier' plugin and start it.

        :raises AgentPluginError: If no 'notifier' plugin is loaded.
        """
        if 'notifier' not in self.plugins:
            raise AgentPluginError('Plugin notifier is required')
        self.notifier = self.plugins['notifier']

        self.logger.info("Starting notifier: %s" % self.notifier)
        self.notifier.set_request_queue(self.message_queue)
        self.notifier.set_response_queue(self.output_message_queue)
        self.notifier.set_logger(self.logger)
        self.notifier.start()

    def stop_notifier(self):
        """Stop the notifier plugin if one has been started."""
        if hasattr(self, 'notifier') and self.notifier:
            self.logger.info("Stopping notifier: %s" % self.notifier)
            self.notifier.stop()

    def get_output_messages(self):
        """
        Drain ``output_message_queue`` without blocking.

        :returns: List of the non-empty messages taken from the queue. Errors while
                  reading are logged and the messages collected so far are returned.
        """
        msgs = []
        try:
            while not self.output_message_queue.empty():
                msg = self.output_message_queue.get(False)
                if msg:
                    msgs.append(msg)
        except Exception as error:
            self.logger.error("Failed to get output messages: %s, %s" % (error, traceback.format_exc()))
        return msgs

    def _handle_output_messages(self):
        """Mark every message returned by the notifier as delivered."""
        self.clean_messages(self.get_output_messages())

    def _is_processing_file_message_processed(self, message):
        """
        Decide whether a ProcessingFile message does not have to be sent (again).

        The message is regarded as processed when it carries no input files, when
        the processing matching its workload id is in a terminated status, or when
        none of its map ids has outputs that are only New or contain Missing.

        :param message: Message dict with 'msg_content', 'request_id' and 'transform_id'.
        :returns: True if processed, False if it must be (re)sent.
        """
        msg_content = message['msg_content']
        request_id = message['request_id']
        transform_id = message['transform_id']

        if 'files' not in msg_content or not msg_content['files']:
            return True
        if 'relation_type' not in msg_content or msg_content['relation_type'] != 'input':
            return True

        workload_id = msg_content['workload_id']
        processings = core_processings.get_processings_by_transform_id(transform_id=transform_id)
        find_processing = None
        # No early exit on purpose: with several matches the last one is used.
        for processing in processings or []:
            if processing['workload_id'] == workload_id:
                find_processing = processing

        terminated_status = [ProcessingStatus.Finished, ProcessingStatus.Failed,
                             ProcessingStatus.Lost, ProcessingStatus.SubFinished,
                             ProcessingStatus.Cancelled, ProcessingStatus.Expired,
                             ProcessingStatus.Suspended, ProcessingStatus.Broken]
        if find_processing and find_processing['status'] in terminated_status:
            return True

        contents = core_catalog.get_contents_by_request_transform(request_id=request_id,
                                                                  transform_id=transform_id)
        # map_id -> distinct statuses of the output contents of that map_id.
        output_statuses = {}
        for content in contents:
            if content['content_relation_type'] == ContentRelationType.Output:
                statuses = output_statuses.setdefault(content['map_id'], [])
                if content['status'] not in statuses:
                    statuses.append(content['status'])

        for file_item in msg_content['files']:
            statuses = output_statuses.get(file_item['map_id'], [])
            if statuses == [ContentStatus.New] or ContentStatus.Missing in statuses:
                # Outputs untouched or missing: the message still has to be sent.
                return False
        return True

    def is_message_processed(self, message):
        """
        Tell whether a message can be discarded instead of being sent.

        * New messages are never processed.
        * Messages that reached ``max_retries`` are processed.
        * AsyncResult messages that are no longer New are processed.
        * Other types except ProcessingFile are processed after ``replay_times`` retries.
        * ProcessingFile messages are checked against processing and content status.

        :param message: Message dict ('retries', 'status', 'msg_id', 'msg_type', ...).
        :returns: True if the message is processed; False otherwise, including the
                  case where the check itself raised (the error is logged).
        """
        retries = message['retries']
        try:
            if message['status'] in [MessageStatus.New]:
                return False
            if retries >= self.max_retries:
                self.logger.info("message %s has reached max retries %s" % (message['msg_id'], self.max_retries))
                return True

            msg_type = message['msg_type']
            if msg_type in [MessageType.AsyncResult]:
                return True
            if msg_type not in [MessageType.ProcessingFile]:
                return retries >= self.replay_times
            return self._is_processing_file_message_processed(message)
        except Exception as ex:
            self.logger.error(ex)
            self.logger.error(traceback.format_exc())
        # The check failed: keep the message so that it is sent again.
        return False

    def process_messages(self, messages):
        """
        Dispatch retrieved messages to the notifier or discard processed ones.

        Each message gets its 'destination' replaced by the destination's name and
        'from_idds' set to True. Processed messages are confirmed as delivered, the
        others are put on ``message_queue``. Errors are logged, not raised.

        :param messages: Messages as returned by the retrieval methods.
        """
        try:
            to_discard_messages = []
            for message in messages:
                message['destination'] = message['destination'].name
                message['from_idds'] = True

                if self.is_message_processed(message):
                    self.logger.debug("message (msg_id: %s) is already processed, not resend it again" % message['msg_id'])
                    to_discard_messages.append(message)
                else:
                    self.message_queue.put(message)
            if to_discard_messages:
                self.clean_messages(to_discard_messages, confirm=True)
        except Exception as ex:
            self.logger.error(f"Failed to process messages: {ex}")
            self.logger.error(traceback.format_exc())

    def run(self):
        """
        Main run function.

        Sets up plugins and periodic tasks, starts the notifier and then loops until
        ``graceful_stop`` is set: new messages are processed in the main thread,
        retry messages are handed to the retry executor, and messages returned by the
        notifier are marked as delivered.
        """
        try:
            self.logger.info("Starting main thread")
            self.init_thread_info()
            self.load_plugins()

            self.add_default_tasks()

            task = self.create_task(task_func=self.load_min_request_id, task_output_queue=None, task_args=tuple(), task_kwargs={}, delay_time=600, priority=1)
            self.add_task(task)

            if self.mode == "single":
                self.logger.debug("single mode")
                self.add_conductor_monitor_task()

            self.start_notifier()

            while not self.graceful_stop.is_set():
                # execute timer task
                self.execute_schedules()

                try:
                    if self.is_selected():
                        new_messages = self.get_new_messages()
                        if new_messages:
                            self.process_messages(new_messages)

                        # Only fetch retries when a worker can take them right away.
                        if self.retry_executor.has_free_workers():
                            retry_messages = self.get_retry_messages()
                            if retry_messages:
                                self.retry_executor.submit(self.process_messages, retry_messages)

                    # Throttle: wait for the notifier to drain the request queue, but
                    # give up waiting as soon as a stop is requested so that shutdown
                    # cannot hang on a queue that is no longer consumed.
                    while (not self.graceful_stop.is_set()
                           and self.message_queue.qsize() > self.queue_throller):
                        time.sleep(1)
                        self._handle_output_messages()

                    self._handle_output_messages()
                    time.sleep(1)
                except IDDSException as error:
                    self.logger.error("Main thread IDDSException: %s" % str(error))
                    self.logger.error(traceback.format_exc())
                except Exception as error:
                    self.logger.critical("Main thread exception: %s\n%s" % (str(error), traceback.format_exc()))
        except KeyboardInterrupt:
            self.stop()

    def stop(self):
        """Stop the agent and then the notifier plugin."""
        super(Conductor, self).stop()
        self.stop_notifier()
0356 
0357 
if __name__ == '__main__':
    # Script entry point: build a Conductor with default settings and invoke it.
    conductor = Conductor()
    conductor()