Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-25 08:29:10

0001 import json
0002 import logging
0003 import threading
0004 import queue
0005 import time
0006 import uuid
0007 from typing import Dict, Optional
0008 from django.http import StreamingHttpResponse, HttpResponse, JsonResponse
0009 from django.utils import timezone
0010 from django.db.models import F
0011 from django.contrib.auth.models import AnonymousUser
0012 from rest_framework.decorators import api_view, authentication_classes, permission_classes
0013 from rest_framework.authentication import SessionAuthentication, TokenAuthentication
0014 from rest_framework.authtoken.models import Token
0015 from rest_framework.permissions import IsAuthenticated
0016 from django.conf import settings
0017 from channels.layers import get_channel_layer
0018 from asgiref.sync import async_to_sync
0019 from .models import Subscriber
0020 
# Module logger: the broadcaster, the SSE views and the channel-layer subscriber
# thread in this module all report connects, disconnects and errors through it.
logger = logging.getLogger(__name__)
0022 
0023 
class SSEMessageBroadcaster:
    """
    Process-wide singleton that tracks SSE clients and fans messages out to them.

    Each connected client owns a bounded ``queue.Queue`` plus an optional filter
    dict. ``broadcast_message`` pushes matching messages onto those queues,
    dropping the oldest entry when a queue is full, and ``sse_event_generator``
    drains them. Every client is mirrored as a ``Subscriber`` row so connection
    state and message counters are visible in the database.

    Messages reach ``broadcast_message`` from the ActiveMQ processor and from the
    Channels group subscriber thread started on first construction.
    """

    _instance = None
    # Class-level lock: guards singleton creation in __new__ and the one-time
    # setup in __init__. (__init__ binds a separate per-instance ``_lock`` that
    # guards the client dicts.)
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking so steady-state construction skips the lock.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Python runs __init__ on every SSEMessageBroadcaster() call. Holding the
        # class-level lock makes the "initialise once" check atomic. Otherwise two
        # threads constructing concurrently could both reset the client dicts
        # (losing registered clients) or start two subscriber threads
        # (duplicating every event).
        with type(self)._lock:
            if getattr(self, 'initialized', False):
                return
            self.client_queues: Dict[str, queue.Queue] = {}
            self.client_filters: Dict[str, Dict] = {}
            self.client_subscribers: Dict[str, int] = {}  # client_id -> Subscriber.subscriber_id
            # Guards the three dicts above; deliberately never held across DB calls.
            self._lock = threading.Lock()
            self.initialized = True
        logger.info("SSE Message Broadcaster initialized")
        self._start_channel_layer_subscriber()

    def _start_channel_layer_subscriber(self):
        """Start the daemon thread that relays Channels group messages, if a channel layer is configured."""
        try:
            channel_layer = get_channel_layer()
            if channel_layer is None:
                return
            group = getattr(settings, 'SSE_CHANNEL_GROUP', 'workflow_events')
            threading.Thread(
                target=_channel_layer_subscriber_loop,
                args=(group, ),
                name="SSEChannelLayerSubscriber",
                daemon=True,
            ).start()
            logger.info(f"SSE Channel layer subscriber started for group '{group}'")
        except Exception as e:
            # No channel layer returns None (handled above), so an exception here
            # points at a misconfigured backend. Surface it above debug level.
            logger.warning(f"SSE channel layer subscriber not started: {e}")

    def add_client(self, client_id: str, request, filters: Optional[Dict] = None,
                   max_queue_size: int = 100) -> queue.Queue:
        """
        Register a new SSE client and record it as an active ``Subscriber``.

        :param client_id: unique client identifier (e.g. the UUID string
            generated by ``sse_message_stream``).
        :param request: Django request, used for the client IP and location.
        :param filters: optional subscription filters (``msg_types`` /
            ``agents`` / ``run_ids``); ``None`` means "all messages".
        :param max_queue_size: capacity of the client's queue. When it is full,
            the oldest pending message is dropped to make room.
        :return: the queue the SSE generator should drain.
        :raises Exception: any database error from ``update_or_create``. In that
            case nothing is registered, so no orphaned queue is left behind.
        """
        filters = filters or {}
        client_ip = self._get_client_ip(request)
        now = timezone.now()
        subscriber_name = f"sse_{client_id[:8]}"  # Use first 8 chars of UUID

        # The DB write happens before registration and outside the lock. A
        # failure then leaves no half-registered client, and broadcasts are not
        # stalled on database latency.
        subscriber, _created = Subscriber.objects.update_or_create(
            subscriber_name=subscriber_name,
            defaults={
                'delivery_type': 'sse',
                'client_ip': client_ip,
                'client_location': self._get_client_location(request),
                'connected_at': now,
                'disconnected_at': None,
                'last_activity': now,
                'is_active': True,
                'message_filters': filters,
                'description': f"SSE client from {client_ip}",
            },
        )

        client_queue = queue.Queue(maxsize=max_queue_size)
        with self._lock:
            self.client_queues[client_id] = client_queue
            self.client_filters[client_id] = filters
            self.client_subscribers[client_id] = subscriber.subscriber_id

        logger.info(f"Added SSE client {client_id} as subscriber {subscriber_name}")
        return client_queue

    def remove_client(self, client_id: str):
        """
        Unregister an SSE client and mark its ``Subscriber`` row inactive.

        Safe to call for an unknown or already-removed ``client_id``. Database
        failures are logged rather than raised, because this runs on cleanup
        paths (the stream's ``finally`` block and ``broadcast_message``).
        """
        with self._lock:
            self.client_queues.pop(client_id, None)
            self.client_filters.pop(client_id, None)
            subscriber_id = self.client_subscribers.pop(client_id, None)

        if subscriber_id is not None:
            try:
                # Targeted UPDATE instead of get()+save(). save() would rewrite
                # every column and could clobber counters that
                # _update_subscriber_stats increments concurrently via F().
                # NOTE(review): update() bypasses Subscriber.save() and model
                # signals; confirm none are relied on for disconnects.
                Subscriber.objects.filter(subscriber_id=subscriber_id).update(
                    disconnected_at=timezone.now(),
                    is_active=False,
                )
            except Exception as e:
                logger.error(f"Failed to mark subscriber {subscriber_id} disconnected: {e}")

        logger.info(f"Removed SSE client {client_id}")

    def broadcast_message(self, message_data: Dict):
        """
        Deliver ``message_data`` to every connected client whose filters match.

        Delivery is non-blocking. If a client's queue is full, its oldest
        pending message is discarded (and counted as dropped) to make room. A
        client whose queue raises an unexpected error is removed. Called by the
        ActiveMQ processor and the channel-layer subscriber thread.

        The lock is held only for queue operations. Subscriber-statistics
        updates and client removal run after it is released: ``remove_client``
        takes the same non-reentrant lock, so calling it while holding the lock
        would deadlock.
        """
        disconnected_clients = []
        stat_updates = []  # (subscriber_id, 'sent' | 'dropped'), applied after unlocking

        with self._lock:
            for client_id, client_queue in self.client_queues.items():
                try:
                    if not self._message_matches_filters(
                        message_data, self.client_filters.get(client_id, {})
                    ):
                        continue
                    subscriber_id = self.client_subscribers.get(client_id)
                    if subscriber_id is not None:
                        stat_updates.append((subscriber_id, 'sent'))
                    try:
                        client_queue.put_nowait(message_data)
                    except queue.Full:
                        # Make room by discarding the oldest pending message.
                        try:
                            client_queue.get_nowait()
                            dropped = True
                        except queue.Empty:
                            dropped = False  # consumer drained the queue meanwhile
                        client_queue.put_nowait(message_data)
                        if dropped and subscriber_id is not None:
                            stat_updates.append((subscriber_id, 'dropped'))
                except Exception as e:
                    logger.error(f"Error broadcasting to client {client_id}: {e}")
                    disconnected_clients.append(client_id)

        for subscriber_id, stat_type in stat_updates:
            self._update_subscriber_stats(subscriber_id, stat_type)
        for client_id in disconnected_clients:
            self.remove_client(client_id)

    def _message_matches_filters(self, message: Dict, filters: Dict) -> bool:
        """
        Return True if ``message`` passes every filter present in ``filters``.

        Each filter key is optional: ``msg_types`` matches ``message['msg_type']``,
        ``agents`` matches ``message['processed_by']``, and ``run_ids`` matches
        ``message['run_id']``. An empty filter dict matches everything.
        """
        if not filters:
            return True
        if 'msg_types' in filters and message.get('msg_type') not in filters['msg_types']:
            return False
        if 'agents' in filters and message.get('processed_by', '') not in filters['agents']:
            return False
        if 'run_ids' in filters and message.get('run_id') not in filters['run_ids']:
            return False
        return True

    def _get_client_ip(self, request):
        """
        Return the client IP: the first ``X-Forwarded-For`` hop, else ``REMOTE_ADDR``.

        May return ``None`` when neither is present. ``X-Forwarded-For`` is
        client-controlled unless a trusted proxy rewrites it, so the result is
        only suitable for informational records like the ``Subscriber`` row.
        """
        forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if forwarded_for:
            first_hop = forwarded_for.split(',')[0].strip()
            if first_hop:
                return first_hop
        return request.META.get('REMOTE_ADDR')

    def _get_client_location(self, request):
        """
        Return a coarse location label for the client.

        Uses the ``X-Client-Location`` header when present. Otherwise it applies
        a simple private-range heuristic on the client IP (customise for your
        network): ``192.168.*`` gives 'Local', ``10.*`` gives 'Internal', and
        anything else (including an unknown IP) gives 'Remote'.
        """
        location = request.META.get('HTTP_X_CLIENT_LOCATION', '')
        if location:
            return location
        ip = self._get_client_ip(request) or ''  # REMOTE_ADDR may be absent
        if ip.startswith('192.168.'):
            return 'Local'
        if ip.startswith('10.'):
            return 'Internal'
        return 'Remote'

    def _update_subscriber_stats(self, subscriber_id: int, stat_type: str):
        """
        Increment a subscriber counter in the database (best effort).

        'sent' bumps ``messages_sent`` and refreshes ``last_activity``; 'dropped'
        bumps ``messages_dropped``. F() expressions make each increment a single
        atomic UPDATE. Errors are logged, never raised.
        """
        try:
            if stat_type == 'sent':
                Subscriber.objects.filter(subscriber_id=subscriber_id).update(
                    messages_sent=F('messages_sent') + 1,
                    last_activity=timezone.now()
                )
            elif stat_type == 'dropped':
                Subscriber.objects.filter(subscriber_id=subscriber_id).update(
                    messages_dropped=F('messages_dropped') + 1
                )
        except Exception as e:
            logger.error(f"Failed to update subscriber stats: {e}")
0217 
0218 
def sse_event_generator(client_id: str, client_queue: queue.Queue,
                        heartbeat_interval: float = 30.0, poll_timeout: float = 1.0):
    """
    Yield Server-Sent Events for one client, reading from its message queue.

    Emits a ``connected`` event first, then one event per queued message. The
    event name is the message's ``msg_type`` (``message`` when missing or empty)
    and the data is the JSON-encoded message. While the queue stays empty, a
    ``heartbeat`` event carrying a wall-clock timestamp is emitted at most once
    per ``heartbeat_interval`` seconds, so an idle connection still sees periodic
    traffic. Runs until the consumer closes the generator.

    Messages that are not dicts or cannot be JSON-encoded are logged and
    skipped, so one bad message does not end the stream.

    :param client_id: identifier echoed in the ``connected`` event and in logs.
    :param client_queue: queue filled by ``SSEMessageBroadcaster.broadcast_message``.
    :param heartbeat_interval: minimum seconds between heartbeats while idle.
    :param poll_timeout: seconds to block on the queue before re-checking the heartbeat.
    """
    logger.info(f"Starting SSE event stream for client {client_id}")

    # Send initial connection message
    yield f"event: connected\ndata: {json.dumps({'client_id': client_id, 'status': 'connected'})}\n\n"

    # Measure the interval on the monotonic clock so wall-clock adjustments cannot
    # stall or burst heartbeats. The heartbeat payload still reports wall-clock time.
    last_heartbeat = time.monotonic()

    try:
        while True:
            try:
                message = client_queue.get(timeout=poll_timeout)
            except queue.Empty:
                now = time.monotonic()
                if now - last_heartbeat > heartbeat_interval:
                    yield f"event: heartbeat\ndata: {json.dumps({'timestamp': time.time()})}\n\n"
                    last_heartbeat = now
                continue

            try:
                event_type = str(message.get('msg_type') or 'message')
                event_data = json.dumps(message)
            except (AttributeError, TypeError, ValueError) as e:
                # Non-dict or non-JSON-serializable message: skip it, keep the stream.
                logger.error(f"Skipping unserializable SSE message for client {client_id}: {e}")
                continue

            # A CR/LF inside the event name would end the "event:" field early and
            # corrupt SSE framing. (json.dumps already escapes newlines in the data.)
            event_type = event_type.replace('\r', '').replace('\n', '') or 'message'
            yield f"event: {event_type}\ndata: {event_data}\n\n"

    except GeneratorExit:
        logger.info(f"SSE client {client_id} disconnected")
    except Exception as e:
        logger.error(f"Error in SSE event generator for client {client_id}: {e}")
0254 
0255 
def _parse_csv_param(request, name: str) -> Optional[list]:
    """Return the stripped, non-empty items of comma-separated query parameter ``name``, or None if there are none."""
    raw = request.GET.get(name)
    if not raw:
        return None
    items = [item.strip() for item in raw.split(',')]
    return [item for item in items if item] or None


def _user_from_token_header(request):
    """
    Resolve an ``Authorization: Token <key>`` header to an active user.

    Mirrors DRF ``TokenAuthentication``, including its rejection of inactive
    users. Returns ``None`` when the header is missing or malformed, the token
    is unknown, or the token's user is inactive.
    """
    auth_header = request.META.get('HTTP_AUTHORIZATION', '')
    if not auth_header.startswith('Token '):
        return None
    token_key = auth_header[len('Token '):].strip()
    if not token_key:
        return None
    try:
        token = Token.objects.select_related('user').get(key=token_key)
    except Token.DoesNotExist:
        return None
    if not token.user.is_active:
        return None
    return token.user


def sse_message_stream(request):
    """
    SSE endpoint for streaming ActiveMQ messages to remote clients.

    This is a plain Django view (not DRF) to avoid content negotiation issues
    with SSE. Authentication is handled manually to support text/event-stream
    responses: the session user is used if authenticated, otherwise an
    ``Authorization: Token <key>`` header belonging to an active user.
    Unauthenticated requests get a JSON 401.

    Query parameters (empty items are ignored):
    - msg_types: Comma-separated list of message types to filter (e.g., "stf_gen,data_ready")
    - agents: Comma-separated list of agent names to filter
    - run_ids: Comma-separated list of run IDs to filter

    The ``Access-Control-Allow-Origin`` header comes from the optional
    ``settings.SSE_CORS_ALLOW_ORIGIN`` and defaults to ``'*'``.

    Example:
    GET /api/messages/stream/?msg_types=stf_gen,data_ready&agents=daq-simulator
    """
    user = request.user if hasattr(request, 'user') else AnonymousUser()

    # Session authentication did not identify the user; fall back to a token header.
    if not user.is_authenticated:
        user = _user_from_token_header(request) or user

    if not user.is_authenticated:
        response = HttpResponse(
            json.dumps({'detail': 'Authentication credentials were not provided.'}),
            status=401,
            content_type='application/json'
        )
        # RFC 7235 requires a 401 to carry a WWW-Authenticate challenge; DRF uses "Token".
        response['WWW-Authenticate'] = 'Token'
        return response

    client_id = str(uuid.uuid4())

    # Filter keys match the query-parameter names.
    filters = {}
    for param in ('msg_types', 'agents', 'run_ids'):
        values = _parse_csv_param(request, param)
        if values:
            filters[param] = values

    broadcaster = SSEMessageBroadcaster()
    client_queue = broadcaster.add_client(client_id, request, filters)

    def event_stream():
        try:
            yield from sse_event_generator(client_id, client_queue)
        finally:
            # Runs when the generator is closed (e.g. by the server once the
            # client disconnects) or ends, so the client is always unregistered.
            broadcaster.remove_client(client_id)

    response = StreamingHttpResponse(
        event_stream(),
        content_type='text/event-stream'
    )
    response['Cache-Control'] = 'no-cache'
    response['X-Accel-Buffering'] = 'no'  # Disable Nginx buffering
    response['Access-Control-Allow-Origin'] = getattr(settings, 'SSE_CORS_ALLOW_ORIGIN', '*')

    return response
0331 
0332 
@api_view(['GET'])
@authentication_classes([SessionAuthentication, TokenAuthentication])
@permission_classes([IsAuthenticated])
def sse_status(request):
    """
    Get current SSE broadcaster status including connected clients.

    Returns JSON with ``connected_clients`` (count), ``client_ids`` (list) and
    ``client_filters`` (client_id -> filter dict). All three come from one
    consistent snapshot.
    """
    broadcaster = SSEMessageBroadcaster()

    # Snapshot under the broadcaster's lock. Other request threads add and
    # remove clients concurrently, and iterating or serializing a dict while
    # it changes size raises RuntimeError.
    with broadcaster._lock:
        client_ids = list(broadcaster.client_queues)
        client_filters = {
            client_id: dict(client_filter)
            for client_id, client_filter in broadcaster.client_filters.items()
        }

    status = {
        'connected_clients': len(client_ids),
        'client_ids': client_ids,
        'client_filters': client_filters,
    }

    return JsonResponse(status)
0349 
0350 
def _channel_layer_subscriber_loop(group_name: str, retry_delay: float = 5.0,
                                   max_retry_delay: float = 60.0):
    """
    Background loop: receive messages from a Channels group and forward them to the SSE broadcaster.

    Runs forever in the daemon thread started by ``SSEMessageBroadcaster``.
    Only messages whose ``type`` is ``'broadcast'`` and whose ``payload`` is a
    dict are forwarded to ``SSEMessageBroadcaster.broadcast_message``.

    On a channel-layer error the loop does not exit. It waits (exponential
    backoff from ``retry_delay`` up to ``max_retry_delay`` seconds), re-joins
    the group and resumes receiving. A transient backend outage therefore no
    longer ends SSE delivery for the lifetime of the process.

    :param group_name: Channels group to join.
    :param retry_delay: initial delay in seconds before reconnecting after an error.
    :param max_retry_delay: upper bound on the backoff delay, in seconds.
    """
    try:
        channel_layer = get_channel_layer()
    except Exception as e:
        logger.error(f"Channel layer subscriber loop error: {e}")
        return
    if channel_layer is None:
        logger.debug("No channel layer available; subscriber loop exiting")
        return

    broadcaster = SSEMessageBroadcaster()
    channel_name = None
    delay = retry_delay

    while True:
        try:
            # Keep one channel name across reconnects so retries do not leave
            # orphaned channels subscribed to the group.
            if channel_name is None:
                channel_name = async_to_sync(channel_layer.new_channel)()
            # Join (or re-join after an error). Presumably a no-op for an
            # existing member; TODO confirm for the configured layer backend.
            async_to_sync(channel_layer.group_add)(group_name, channel_name)
            logger.info(f"Subscribed to channel layer group '{group_name}' as '{channel_name}'")
            delay = retry_delay

            while True:
                message = async_to_sync(channel_layer.receive)(channel_name)
                if not message or message.get('type') != 'broadcast':
                    continue
                payload = message.get('payload', {})
                if not isinstance(payload, dict):
                    # Non-dict payloads would make every filtered client error
                    # out and be disconnected by broadcast_message.
                    logger.warning(
                        f"Ignoring non-dict SSE payload from channel layer: {type(payload).__name__}"
                    )
                    continue
                try:
                    broadcaster.broadcast_message(payload)
                except Exception as e:
                    logger.error(f"Failed to broadcast SSE payload from channel layer: {e}")
        except Exception as e:
            logger.error(f"Channel layer subscriber loop error: {e}; retrying in {delay:.0f}s")
            time.sleep(delay)
            delay = min(delay * 2, max_retry_delay)