Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:58:20

0001 #!/usr/bin/env python
0002 #
0003 # Licensed under the Apache License, Version 2.0 (the "License");
0004 # You may not use this file except in compliance with the License.
0005 # You may obtain a copy of the License at
0006 # http://www.apache.org/licenses/LICENSE-2.0
0007 #
0008 # Authors:
0009 # - Wen Guan, <wen.guan@cern.ch>, 2019 - 2024
0010 
0011 
0012 """
0013 operations related to Messages.
0014 """
0015 
0016 import datetime
0017 import re
0018 import copy
0019 
0020 from sqlalchemy import or_, asc
0021 from sqlalchemy.exc import DatabaseError, IntegrityError
0022 
0023 from idds.common import exceptions
0024 from idds.common.constants import MessageDestination
0025 from idds.common.utils import group_list
0026 from idds.orm.base import models
0027 from idds.orm.base.session import transactional_session
0028 
0029 
@transactional_session
def add_message(msg_type, status, source, request_id, workload_id, transform_id,
                num_contents, msg_content, internal_id=None, bulk_size=None, processing_id=None,
                destination=MessageDestination.Outside, session=None):
    """
    Add a message to be submitted asynchronously to a message broker.

    When bulk_size is set, num_contents exceeds it and msg_content has a
    'files' entry, the message is split into several rows. Each row carries at
    most bulk_size files plus a copy of every other msg_content entry, and its
    num_contents is the number of files in that row. In every other case a
    single row is stored with msg_content and num_contents as given.

    :param msg_type: The type of the msg as a number, e.g., finished_stagein.
    :param status: The status about the message
    :param source: The source where the message is from.
    :param request_id: The request id.
    :param workload_id: The workload id.
    :param transform_id: The transform id.
    :param num_contents: Number of items in msg_content.
    :param msg_content: The message msg_content as JSON.
    :param internal_id: The internal id stored with the message.
    :param bulk_size: Maximum number of 'files' entries per stored message;
                      a falsy value disables splitting.
    :param processing_id: The processing id stored with the message.
    :param destination: The destination stored with the message.
    :param session: The database session.

    :raises DatabaseException: If msg_content cannot be handled (TypeError)
                               or the database rejects the insert.
    """

    try:
        # One (content, count) pair per row that will be inserted.
        pieces = []
        if bulk_size and num_contents > bulk_size and 'files' in msg_content:
            files = msg_content['files']
            # Only the non-'files' entries need an independent copy per row.
            # Deep-copying the whole msg_content for every chunk would copy
            # the complete file list each time just to overwrite it.
            common = {key: value for key, value in msg_content.items() if key != 'files'}
            for start in range(0, len(files), bulk_size):
                chunk = files[start:start + bulk_size]
                chunk_content = copy.deepcopy(common)
                chunk_content['files'] = chunk
                pieces.append((chunk_content, len(chunk)))

        if not pieces:
            # No splitting requested or possible (including an empty 'files'
            # list): store the message unchanged instead of dropping it.
            pieces.append((msg_content, num_contents))

        msgs = [{'msg_type': msg_type, 'status': status, 'request_id': request_id,
                 'workload_id': workload_id, 'transform_id': transform_id,
                 'internal_id': internal_id, 'source': source, 'num_contents': piece_num_contents,
                 'destination': destination, 'processing_id': processing_id,
                 'locking': 0, 'msg_content': piece_content}
                for piece_content, piece_num_contents in pieces]

        session.bulk_insert_mappings(models.Message, msgs)
    except TypeError as e:
        raise exceptions.DatabaseException('Invalid JSON for msg_content: %s' % str(e))
    except DatabaseError as e:
        # ORA-12899 / 1406 are matched to report an oversized msg_content.
        if re.match('.*ORA-12899.*', e.args[0]) \
           or re.match('.*1406.*', e.args[0]):
            raise exceptions.DatabaseException('Could not persist message, msg_content too large: %s' % str(e))
        else:
            raise exceptions.DatabaseException('Could not persist message: %s' % str(e))
0086 
0087 
@transactional_session
def add_messages(messages, bulk_size=1000, session=None):
    """
    Add several messages by calling add_message once per entry.

    :param messages: Iterable of dictionaries; each one is expanded into the
                     keyword arguments of add_message.
    :param bulk_size: Forwarded to add_message as its bulk_size.
    :param session: The database session, shared by all add_message calls.

    :raises DatabaseException: On a TypeError or a DatabaseError raised while
                               adding the messages.
    """
    try:
        # session.bulk_insert_mappings(models.Message, messages)
        for message in messages:
            add_message(bulk_size=bulk_size, session=session, **message)
    except TypeError as err:
        raise exceptions.DatabaseException('Invalid JSON for msg_content: %s' % str(err))
    except DatabaseError as err:
        detail = err.args[0]
        too_large = re.match('.*ORA-12899.*', detail) or re.match('.*1406.*', detail)
        if too_large:
            raise exceptions.DatabaseException('Could not persist message, msg_content too large: %s' % str(err))
        raise exceptions.DatabaseException('Could not persist message: %s' % str(err))
0102 
0103 
@transactional_session
def update_messages(messages, bulk_size=1000, use_bulk_update_mappings=False, request_id=None, transform_id=None, min_request_id=None, session=None):
    """
    Update messages.

    :param messages: The messages to update as a list of dictionaries. They
                     are grouped with group_list(messages, key='msg_id').
    :param bulk_size: Accepted but not used by this function.
    :param use_bulk_update_mappings: If true, hand messages directly to
                                     session.bulk_update_mappings and skip the
                                     grouped updates.
    :param request_id: If set, only update rows with this request id.
    :param transform_id: If set, only update rows with this transform id.
    :param min_request_id: Used only when request_id is not set: restrict the
                           update to rows whose request id is at least this
                           value or is NULL.
    :param session: The database session.

    :raises DatabaseException: On a TypeError or a DatabaseError raised while
                               updating.
    """
    try:
        if use_bulk_update_mappings:
            session.bulk_update_mappings(models.Message, messages)
            return

        # Each group exposes the msg ids under 'keys' and the values to write
        # under 'items' (presumably one group per distinct set of values -
        # verify against group_list).
        grouped = group_list(messages, key='msg_id')
        for group_name in grouped:
            entry = grouped[group_name]
            msg_ids = entry['keys']
            new_values = entry['items']

            selection = session.query(models.Message)
            if request_id:
                selection = selection.filter(models.Message.request_id == request_id)
            elif min_request_id:
                selection = selection.filter(or_(models.Message.request_id >= min_request_id,
                                                 models.Message.request_id.is_(None)))
            if transform_id:
                selection = selection.filter(models.Message.transform_id == transform_id)

            selection = selection.filter(models.Message.msg_id.in_(msg_ids))
            selection.update(new_values, synchronize_session=False)
    except TypeError as err:
        raise exceptions.DatabaseException('Invalid JSON for msg_content: %s' % str(err))
    except DatabaseError as err:
        detail = err.args[0]
        too_large = re.match('.*ORA-12899.*', detail) or re.match('.*1406.*', detail)
        if too_large:
            raise exceptions.DatabaseException('Could not persist message, msg_content too large: %s' % str(err))
        raise exceptions.DatabaseException('Could not persist message: %s' % str(err))
0134 
0135 
@transactional_session
def retrieve_messages(bulk_size=1000, msg_type=None, status=None, source=None,
                      destination=None, request_id=None, workload_id=None,
                      transform_id=None, processing_id=None, fetching_id=None,
                      min_request_id=None, use_poll_period=False, retries=None,
                      delay=None, internal_id=None, session=None):
    """
    Retrieve up to bulk_size messages, oldest updated_at first.

    :param bulk_size: Maximum number of messages to return; a falsy value
                      disables the limit.
    :param msg_type: Return only the specified msg_type (single value or
                     list/tuple of values).
    :param status: Return only the specified status (single value or
                   list/tuple of values).
    :param source: The source where the message is from.
    :param destination: Return only the specified destination (single value
                        or list/tuple of values).
    :param request_id: Return only messages with this request id.
    :param workload_id: Return only messages with this workload id.
    :param transform_id: Return only messages with this transform id.
    :param processing_id: Return only messages with this processing id.
    :param fetching_id: Accepted but not used by this function.
    :param min_request_id: Used only when request_id is None: return messages
                           whose request id is at least this value or is NULL.
    :param use_poll_period: Used only when delay is not set: return messages
                            whose updated_at plus poll_period is not later
                            than the current UTC time.
    :param retries: If set (truthy), return only messages with this number of
                    retries.
    :param delay: If set (truthy), return only messages whose updated_at is
                  more than this many seconds before the current UTC time.
    :param internal_id: Return only messages with this internal id.
    :param session: The database session.

    :raises DatabaseException: On an IntegrityError from the database.

    :returns messages: List of dictionaries
    """

    def _as_selection(value):
        # Wrap a scalar into a list; a one-element selection is duplicated so
        # the IN clause never receives exactly one value (presumably a
        # database workaround - TODO confirm).
        if value is None:
            return None
        if not isinstance(value, (list, tuple)):
            value = [value]
        if len(value) == 1:
            value = [value[0], value[0]]
        return value

    try:
        destination = _as_selection(destination)
        msg_type = _as_selection(msg_type)
        status = _as_selection(status)

        selection = session.query(models.Message)

        if msg_type is not None:
            selection = selection.filter(models.Message.msg_type.in_(msg_type))
        if status is not None:
            selection = selection.filter(models.Message.status.in_(status))
        if source is not None:
            selection = selection.filter_by(source=source)
        if destination is not None:
            selection = selection.filter(models.Message.destination.in_(destination))

        if request_id is not None:
            selection = selection.filter_by(request_id=request_id)
        elif min_request_id:
            selection = selection.filter(or_(models.Message.request_id >= min_request_id,
                                             models.Message.request_id.is_(None)))

        if workload_id is not None:
            selection = selection.filter_by(workload_id=workload_id)
        if transform_id is not None:
            selection = selection.filter_by(transform_id=transform_id)
        if processing_id is not None:
            selection = selection.filter_by(processing_id=processing_id)
        if internal_id is not None:
            selection = selection.filter_by(internal_id=internal_id)
        if retries:
            selection = selection.filter_by(retries=retries)

        # An explicit delay takes precedence over the per-message poll period.
        if delay:
            selection = selection.filter(models.Message.updated_at < datetime.datetime.utcnow() - datetime.timedelta(seconds=delay))
        elif use_poll_period:
            selection = selection.filter(models.Message.updated_at + models.Message.poll_period <= datetime.datetime.utcnow())

        selection = selection.order_by(asc(models.Message.updated_at))

        if bulk_size:
            selection = selection.order_by(models.Message.created_at)
            selection = selection.limit(bulk_size)
        # selection = selection.with_for_update(nowait=True)

        return [row.to_dict() for row in selection.all()]
    except IntegrityError as err:
        raise exceptions.DatabaseException(err.args)
0216 
0217 
@transactional_session
def delete_messages(messages, session=None):
    """
    Delete all messages with the given IDs.

    Nothing is sent to the database when messages is empty.

    :param messages: The messages to delete as a list of dictionaries; only
                     the 'msg_id' entry of each dictionary is used.
    :param session: The database session.

    :raises DatabaseException: On an IntegrityError from the database.
    """
    # One equality condition per message, OR-ed together in a single DELETE.
    id_matches = [models.Message.msg_id == item['msg_id'] for item in messages]
    if not id_matches:
        return

    try:
        doomed = session.query(models.Message).filter(or_(*id_matches))
        doomed.delete(synchronize_session=False)
    except IntegrityError as err:
        raise exceptions.DatabaseException(err.args)
0236 
0237 
@transactional_session
def clean_old_messages(request_id, session=None):
    """
    Delete messages whose request id is less than or equal to request_id.

    :param request_id: The highest request id to delete (inclusive).
    :param session: The database session.
    """
    outdated = session.query(models.Message).filter(models.Message.request_id <= request_id)
    outdated.delete(synchronize_session=False)
0248 
0249 # @transactional_session
0250 # def update_messages(messages, session=None):
0251 #     """
0252 #     Update all messages status with the given IDs.
0253 #
0254 #     :param messages: The messages to be updated as a list of dictionaries.
0255 #     """
0256 #     try:
0257 #         for msg in messages:
0258 #             session.query(models.Message).filter_by(msg_id=msg['msg_id']).update({'status': msg['status']}, synchronize_session=False)
0259 #     except IntegrityError as e:
0260 #         raise exceptions.DatabaseException(e.args)