""" WebSocket message handling and protocol implementation. """ import asyncio import json import logging from typing import Optional from fastapi import WebSocket, WebSocketDisconnect from websockets.exceptions import ConnectionClosedError from .messages import ( parse_incoming_message, serialize_outgoing_message, MessageTypes, create_state_report ) from .models import ( SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage, RequestStateMessage, PatchSessionResultMessage ) from .state import worker_state, SystemMetrics from ..models import ModelManager from ..streaming.manager import shared_stream_manager from ..tracking.integration import TrackingPipelineIntegration logger = logging.getLogger(__name__) # Constants HEARTBEAT_INTERVAL = 2.0 # seconds WORKER_TIMEOUT_MS = 10000 # Global model manager instance model_manager = ModelManager() class WebSocketHandler: """ Handles WebSocket connection lifecycle and message processing. """ def __init__(self, websocket: WebSocket): self.websocket = websocket self.connected = False self._heartbeat_task: Optional[asyncio.Task] = None self._message_task: Optional[asyncio.Task] = None self._heartbeat_count = 0 self._last_processed_models: set = set() # Cache of last processed model IDs async def handle_connection(self) -> None: """ Main connection handler that manages the WebSocket lifecycle. Based on the original architecture from archive/app.py """ client_info = f"{self.websocket.client.host}:{self.websocket.client.port}" if self.websocket.client else "unknown" logger.info(f"Starting WebSocket handler for {client_info}") stream_task = None try: logger.info(f"Accepting WebSocket connection from {client_info}") await self.websocket.accept() self.connected = True logger.info(f"WebSocket connection accepted and established for {client_info}") # Send immediate heartbeat to show connection is alive await self._send_immediate_heartbeat() # Start background tasks (matching original architecture) stream_task = asyncio.create_task(self._process_streams()) heartbeat_task = asyncio.create_task(self._send_heartbeat()) message_task = asyncio.create_task(self._handle_messages()) logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)") # Wait for heartbeat and message tasks (stream runs independently) await asyncio.gather(heartbeat_task, message_task) except Exception as e: logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True) finally: logger.info(f"Cleaning up connection for {client_info}") # Cancel stream task if stream_task and not stream_task.done(): stream_task.cancel() try: await stream_task except asyncio.CancelledError: logger.debug(f"Stream task cancelled for {client_info}") await self._cleanup() async def _send_immediate_heartbeat(self) -> None: """Send immediate heartbeat on connection to show we're alive.""" try: cpu_usage = SystemMetrics.get_cpu_usage() memory_usage = SystemMetrics.get_memory_usage() gpu_usage = SystemMetrics.get_gpu_usage() gpu_memory_usage = SystemMetrics.get_gpu_memory_usage() camera_connections = worker_state.get_camera_connections() state_report = create_state_report( cpu_usage=cpu_usage, memory_usage=memory_usage, gpu_usage=gpu_usage, gpu_memory_usage=gpu_memory_usage, camera_connections=camera_connections ) await self._send_message(state_report) logger.info(f"[TX → Backend] Initial stateReport: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, " f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras") except Exception as e: logger.error(f"Error sending immediate heartbeat: {e}") async def _send_heartbeat(self) -> None: """Send periodic state reports as heartbeat.""" while self.connected: try: # Collect system metrics cpu_usage = SystemMetrics.get_cpu_usage() memory_usage = SystemMetrics.get_memory_usage() gpu_usage = SystemMetrics.get_gpu_usage() gpu_memory_usage = SystemMetrics.get_gpu_memory_usage() camera_connections = worker_state.get_camera_connections() # Create and send state report state_report = create_state_report( cpu_usage=cpu_usage, memory_usage=memory_usage, gpu_usage=gpu_usage, gpu_memory_usage=gpu_memory_usage, camera_connections=camera_connections ) await self._send_message(state_report) # Only log full details every 10th heartbeat, otherwise just show a dot self._heartbeat_count += 1 if self._heartbeat_count % 10 == 0: logger.info(f"[TX → Backend] Heartbeat #{self._heartbeat_count}: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, " f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras") else: print(".", end="", flush=True) # Just show a dot to indicate heartbeat activity await asyncio.sleep(HEARTBEAT_INTERVAL) except Exception as e: logger.error(f"Error sending heartbeat: {e}") break async def _handle_messages(self) -> None: """Handle incoming WebSocket messages.""" while self.connected: try: raw_message = await self.websocket.receive_text() logger.info(f"[RX ← Backend] {raw_message}") # Parse incoming message message = parse_incoming_message(raw_message) if not message: logger.warning("Failed to parse incoming message") continue # Route message to appropriate handler await self._route_message(message) except (WebSocketDisconnect, ConnectionClosedError) as e: logger.warning(f"WebSocket disconnected: {e}") break except json.JSONDecodeError: logger.error("Received invalid JSON message") except Exception as e: logger.error(f"Error handling message: {e}") break async def _route_message(self, message) -> None: """Route parsed message to appropriate handler.""" message_type = message.type try: if message_type == MessageTypes.SET_SUBSCRIPTION_LIST: await self._handle_set_subscription_list(message) elif message_type == MessageTypes.SET_SESSION_ID: await self._handle_set_session_id(message) elif message_type == MessageTypes.SET_PROGRESSION_STAGE: await self._handle_set_progression_stage(message) elif message_type == MessageTypes.REQUEST_STATE: await self._handle_request_state(message) elif message_type == MessageTypes.PATCH_SESSION_RESULT: await self._handle_patch_session_result(message) else: logger.warning(f"Unknown message type: {message_type}") except Exception as e: logger.error(f"Error handling {message_type} message: {e}") async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None: """Handle setSubscriptionList message for declarative subscription management.""" logger.info(f"[RX Processing] setSubscriptionList with {len(message.subscriptions)} subscriptions") # Update worker state with new subscriptions worker_state.set_subscriptions(message.subscriptions) # Phase 2: Download and manage models await self._ensure_models(message.subscriptions) # Phase 3 & 4: Integrate with streaming management and tracking await self._update_stream_subscriptions(message.subscriptions) logger.info("Subscription list updated successfully") async def _ensure_models(self, subscriptions) -> None: """Ensure all required models are downloaded and available.""" # Extract unique model requirements unique_models = {} for subscription in subscriptions: model_id = subscription.modelId if model_id not in unique_models: unique_models[model_id] = { 'model_url': subscription.modelUrl, 'model_name': subscription.modelName } # Check if model set has changed to avoid redundant processing current_model_ids = set(unique_models.keys()) if current_model_ids == self._last_processed_models: logger.debug(f"[Model Management] Model set unchanged {list(current_model_ids)}, skipping checks") return logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}") self._last_processed_models = current_model_ids # Check and download models concurrently download_tasks = [] for model_id, model_info in unique_models.items(): task = asyncio.create_task( self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name']) ) download_tasks.append(task) # Wait for all downloads to complete if download_tasks: results = await asyncio.gather(*download_tasks, return_exceptions=True) # Log results success_count = 0 for i, result in enumerate(results): model_id = list(unique_models.keys())[i] if isinstance(result, Exception): logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}") elif result: success_count += 1 logger.info(f"[Model Management] Model {model_id} ready for use") else: logger.error(f"[Model Management] Failed to ensure model {model_id}") logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models") async def _update_stream_subscriptions(self, subscriptions) -> None: """Update streaming subscriptions with tracking integration.""" try: # Convert subscriptions to the format expected by StreamManager subscription_payloads = [] for subscription in subscriptions: payload = { 'subscriptionIdentifier': subscription.subscriptionIdentifier, 'rtspUrl': subscription.rtspUrl, 'snapshotUrl': subscription.snapshotUrl, 'snapshotInterval': subscription.snapshotInterval, 'modelId': subscription.modelId, 'modelUrl': subscription.modelUrl, 'modelName': subscription.modelName } # Add crop coordinates if present if hasattr(subscription, 'cropX1'): payload.update({ 'cropX1': subscription.cropX1, 'cropY1': subscription.cropY1, 'cropX2': subscription.cropX2, 'cropY2': subscription.cropY2 }) subscription_payloads.append(payload) # Reconcile subscriptions with StreamManager logger.info("[Streaming] Reconciling stream subscriptions with tracking") reconcile_result = await self._reconcile_subscriptions_with_tracking(subscription_payloads) logger.info(f"[Streaming] Subscription reconciliation complete: " f"added={reconcile_result.get('added', 0)}, " f"removed={reconcile_result.get('removed', 0)}, " f"failed={reconcile_result.get('failed', 0)}") except Exception as e: logger.error(f"Error updating stream subscriptions: {e}", exc_info=True) async def _reconcile_subscriptions_with_tracking(self, target_subscriptions) -> dict: """Reconcile subscriptions with tracking integration.""" try: # First, we need to create tracking integrations for each unique model tracking_integrations = {} for subscription_payload in target_subscriptions: model_id = subscription_payload['modelId'] # Create tracking integration if not already created if model_id not in tracking_integrations: # Get pipeline configuration for this model pipeline_parser = model_manager.get_pipeline_config(model_id) if pipeline_parser: # Create tracking integration tracking_integration = TrackingPipelineIntegration( pipeline_parser, model_manager ) # Initialize tracking model success = await tracking_integration.initialize_tracking_model() if success: tracking_integrations[model_id] = tracking_integration logger.info(f"[Tracking] Created tracking integration for model {model_id}") else: logger.warning(f"[Tracking] Failed to initialize tracking for model {model_id}") else: logger.warning(f"[Tracking] No pipeline config found for model {model_id}") # Now reconcile with StreamManager, adding tracking integrations current_subscription_ids = set() for subscription_info in shared_stream_manager.get_all_subscriptions(): current_subscription_ids.add(subscription_info.subscription_id) target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions} # Find subscriptions to remove and add to_remove = current_subscription_ids - target_subscription_ids to_add = target_subscription_ids - current_subscription_ids # Remove old subscriptions removed_count = 0 for subscription_id in to_remove: if shared_stream_manager.remove_subscription(subscription_id): removed_count += 1 logger.info(f"[Streaming] Removed subscription {subscription_id}") # Add new subscriptions with tracking added_count = 0 failed_count = 0 for subscription_payload in target_subscriptions: subscription_id = subscription_payload['subscriptionIdentifier'] if subscription_id in to_add: success = await self._add_subscription_with_tracking( subscription_payload, tracking_integrations ) if success: added_count += 1 logger.info(f"[Streaming] Added subscription {subscription_id} with tracking") else: failed_count += 1 logger.error(f"[Streaming] Failed to add subscription {subscription_id}") return { 'removed': removed_count, 'added': added_count, 'failed': failed_count, 'total_active': len(shared_stream_manager.get_all_subscriptions()) } except Exception as e: logger.error(f"Error in subscription reconciliation with tracking: {e}", exc_info=True) return {'removed': 0, 'added': 0, 'failed': 0, 'total_active': 0} async def _add_subscription_with_tracking(self, payload, tracking_integrations) -> bool: """Add a subscription with tracking integration.""" try: from ..streaming.manager import StreamConfig subscription_id = payload['subscriptionIdentifier'] camera_id = subscription_id.split(';')[-1] model_id = payload['modelId'] # Get tracking integration for this model tracking_integration = tracking_integrations.get(model_id) # Extract crop coordinates if present crop_coords = None if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']): crop_coords = ( payload['cropX1'], payload['cropY1'], payload['cropX2'], payload['cropY2'] ) # Create stream configuration stream_config = StreamConfig( camera_id=camera_id, rtsp_url=payload.get('rtspUrl'), snapshot_url=payload.get('snapshotUrl'), snapshot_interval=payload.get('snapshotInterval', 5000), max_retries=3, save_test_frames=False # Disable frame saving, focus on tracking ) # Add subscription to StreamManager with tracking success = shared_stream_manager.add_subscription( subscription_id=subscription_id, stream_config=stream_config, crop_coords=crop_coords, model_id=model_id, model_url=payload.get('modelUrl'), tracking_integration=tracking_integration ) if success and tracking_integration: logger.info(f"[Tracking] Subscription {subscription_id} configured with tracking for model {model_id}") return success except Exception as e: logger.error(f"Error adding subscription with tracking: {e}", exc_info=True) return False async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool: """Ensure a single model is downloaded and available.""" try: # Check if model is already available if model_manager.is_model_downloaded(model_id): logger.info(f"[Model Management] Model {model_id} ({model_name}) already available") return True # Download and extract model in a thread pool to avoid blocking the event loop logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}") # Use asyncio.to_thread for CPU-bound operations (Python 3.9+) # For compatibility, we'll use run_in_executor loop = asyncio.get_event_loop() model_path = await loop.run_in_executor( None, model_manager.ensure_model, model_id, model_url, model_name ) if model_path: logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}") return True else: logger.error(f"[Model Management] Failed to prepare model {model_id}") return False except Exception as e: logger.error(f"[Model Management] Exception ensuring model {model_id}: {str(e)}", exc_info=True) return False async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None: """Handle setSessionId message.""" display_identifier = message.payload.displayIdentifier session_id = message.payload.sessionId logger.info(f"[RX Processing] setSessionId for display {display_identifier}: {session_id}") # Update worker state worker_state.set_session_id(display_identifier, session_id) # Update tracking integrations with session ID shared_stream_manager.set_session_id(display_identifier, session_id) async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None: """Handle setProgressionStage message.""" display_identifier = message.payload.displayIdentifier stage = message.payload.progressionStage logger.info(f"[RX Processing] setProgressionStage for display {display_identifier}: {stage}") # Update worker state worker_state.set_progression_stage(display_identifier, stage) # If stage indicates session is cleared/finished, clear from tracking if stage in ['finished', 'cleared', 'idle']: # Get session ID for this display and clear it session_id = worker_state.get_session_id(display_identifier) if session_id: shared_stream_manager.clear_session_id(session_id) logger.info(f"[Tracking] Cleared session {session_id} due to progression stage: {stage}") async def _handle_request_state(self, message: RequestStateMessage) -> None: """Handle requestState message by sending immediate state report.""" logger.debug("[RX Processing] requestState - sending immediate state report") # Collect metrics and send state report cpu_usage = SystemMetrics.get_cpu_usage() memory_usage = SystemMetrics.get_memory_usage() gpu_usage = SystemMetrics.get_gpu_usage() gpu_memory_usage = SystemMetrics.get_gpu_memory_usage() camera_connections = worker_state.get_camera_connections() state_report = create_state_report( cpu_usage=cpu_usage, memory_usage=memory_usage, gpu_usage=gpu_usage, gpu_memory_usage=gpu_memory_usage, camera_connections=camera_connections ) await self._send_message(state_report) async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None: """Handle patchSessionResult message.""" payload = message.payload logger.info(f"[RX Processing] patchSessionResult for session {payload.sessionId}: " f"success={payload.success}, message='{payload.message}'") # TODO: Handle patch session result if needed # For now, just log the response async def _send_message(self, message) -> None: """Send message to backend via WebSocket.""" if not self.connected: logger.warning("Cannot send message: WebSocket not connected") return try: json_message = serialize_outgoing_message(message) await self.websocket.send_text(json_message) # Log non-heartbeat messages only (heartbeats are logged in their respective functions) if not (hasattr(message, 'type') and message.type == 'stateReport'): logger.info(f"[TX → Backend] {json_message}") except Exception as e: logger.error(f"Failed to send WebSocket message: {e}") raise async def _process_streams(self) -> None: """ Stream processing task that handles frame processing and detection. This is a placeholder for Phase 2 - currently just logs that it's running. """ logger.info("Stream processing task started") try: while self.connected: # Get current subscriptions subscriptions = worker_state.get_all_subscriptions() # TODO: Phase 2 - Add actual frame processing logic here # This will include: # - Frame reading from RTSP/HTTP streams # - Model inference using loaded pipelines # - Detection result sending via WebSocket # Sleep to prevent excessive CPU usage (similar to old poll_interval) await asyncio.sleep(0.1) # 100ms polling interval except asyncio.CancelledError: logger.info("Stream processing task cancelled") except Exception as e: logger.error(f"Error in stream processing: {e}", exc_info=True) async def _cleanup(self) -> None: """Clean up resources when connection closes.""" logger.info("Cleaning up WebSocket connection") self.connected = False # Cancel background tasks if self._heartbeat_task and not self._heartbeat_task.done(): self._heartbeat_task.cancel() if self._message_task and not self._message_task.done(): self._message_task.cancel() # Clear worker state worker_state.set_subscriptions([]) worker_state.session_ids.clear() worker_state.progression_stages.clear() logger.info("WebSocket connection cleanup completed") # Factory function for FastAPI integration async def websocket_endpoint(websocket: WebSocket) -> None: """ FastAPI WebSocket endpoint handler. Args: websocket: FastAPI WebSocket connection """ handler = WebSocketHandler(websocket) await handler.handle_connection()