""" WebSocket message handling and protocol implementation. """ import asyncio import json import logging import os import cv2 from datetime import datetime, timezone, timedelta from pathlib import Path from typing import Optional from fastapi import WebSocket, WebSocketDisconnect from websockets.exceptions import ConnectionClosedError from .messages import ( parse_incoming_message, serialize_outgoing_message, MessageTypes, create_state_report ) from .models import ( SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage, RequestStateMessage, PatchSessionResultMessage ) from .state import worker_state, SystemMetrics from ..models import ModelManager from ..streaming.manager import shared_stream_manager from ..tracking.integration import TrackingPipelineIntegration logger = logging.getLogger(__name__) # Constants HEARTBEAT_INTERVAL = 2.0 # seconds WORKER_TIMEOUT_MS = 10000 # Global model manager instance model_manager = ModelManager() class WebSocketHandler: """ Handles WebSocket connection lifecycle and message processing. """ def __init__(self, websocket: WebSocket): self.websocket = websocket self.connected = False self._heartbeat_task: Optional[asyncio.Task] = None self._message_task: Optional[asyncio.Task] = None self._heartbeat_count = 0 self._last_processed_models: set = set() # Cache of last processed model IDs async def handle_connection(self) -> None: """ Main connection handler that manages the WebSocket lifecycle. Based on the original architecture from archive/app.py """ client_info = f"{self.websocket.client.host}:{self.websocket.client.port}" if self.websocket.client else "unknown" logger.info(f"Starting WebSocket handler for {client_info}") stream_task = None try: logger.info(f"Accepting WebSocket connection from {client_info}") await self.websocket.accept() self.connected = True logger.info(f"WebSocket connection accepted and established for {client_info}") # Send immediate heartbeat to show connection is alive await self._send_immediate_heartbeat() # Start background tasks (matching original architecture) stream_task = asyncio.create_task(self._process_streams()) heartbeat_task = asyncio.create_task(self._send_heartbeat()) message_task = asyncio.create_task(self._handle_messages()) logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)") # Wait for heartbeat and message tasks (stream runs independently) await asyncio.gather(heartbeat_task, message_task) except Exception as e: logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True) finally: logger.info(f"Cleaning up connection for {client_info}") # Cancel stream task if stream_task and not stream_task.done(): stream_task.cancel() try: await stream_task except asyncio.CancelledError: logger.debug(f"Stream task cancelled for {client_info}") await self._cleanup() async def _send_immediate_heartbeat(self) -> None: """Send immediate heartbeat on connection to show we're alive.""" try: cpu_usage = SystemMetrics.get_cpu_usage() memory_usage = SystemMetrics.get_memory_usage() gpu_usage = SystemMetrics.get_gpu_usage() gpu_memory_usage = SystemMetrics.get_gpu_memory_usage() camera_connections = worker_state.get_camera_connections() state_report = create_state_report( cpu_usage=cpu_usage, memory_usage=memory_usage, gpu_usage=gpu_usage, gpu_memory_usage=gpu_memory_usage, camera_connections=camera_connections ) await self._send_message(state_report) logger.info(f"[TX → Backend] Initial stateReport: CPU {cpu_usage:.1f}%, Memory 

    async def _send_heartbeat(self) -> None:
        """Send periodic state reports as heartbeat."""
        while self.connected:
            try:
                # Collect system metrics
                cpu_usage = SystemMetrics.get_cpu_usage()
                memory_usage = SystemMetrics.get_memory_usage()
                gpu_usage = SystemMetrics.get_gpu_usage()
                gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
                camera_connections = worker_state.get_camera_connections()

                # Create and send state report
                state_report = create_state_report(
                    cpu_usage=cpu_usage,
                    memory_usage=memory_usage,
                    gpu_usage=gpu_usage,
                    gpu_memory_usage=gpu_memory_usage,
                    camera_connections=camera_connections
                )

                await self._send_message(state_report)

                # Only log full details every 10th heartbeat, otherwise just show a dot
                self._heartbeat_count += 1
                if self._heartbeat_count % 10 == 0:
                    logger.info(f"[TX → Backend] Heartbeat #{self._heartbeat_count}: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
                                f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
                else:
                    print(".", end="", flush=True)  # Just show a dot to indicate heartbeat activity

                await asyncio.sleep(HEARTBEAT_INTERVAL)
            except Exception as e:
                logger.error(f"Error sending heartbeat: {e}")
                break

    async def _handle_messages(self) -> None:
        """Handle incoming WebSocket messages."""
        while self.connected:
            try:
                raw_message = await self.websocket.receive_text()
                logger.info(f"[RX ← Backend] {raw_message}")

                # Parse incoming message
                message = parse_incoming_message(raw_message)
                if not message:
                    logger.warning("Failed to parse incoming message")
                    continue

                # Route message to appropriate handler
                await self._route_message(message)

            except (WebSocketDisconnect, ConnectionClosedError) as e:
                logger.warning(f"WebSocket disconnected: {e}")
                break
            except json.JSONDecodeError:
                logger.error("Received invalid JSON message")
            except Exception as e:
                logger.error(f"Error handling message: {e}")
                break

    async def _route_message(self, message) -> None:
        """Route parsed message to appropriate handler."""
        message_type = message.type

        try:
            if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
                await self._handle_set_subscription_list(message)
            elif message_type == MessageTypes.SET_SESSION_ID:
                await self._handle_set_session_id(message)
            elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
                await self._handle_set_progression_stage(message)
            elif message_type == MessageTypes.REQUEST_STATE:
                await self._handle_request_state(message)
            elif message_type == MessageTypes.PATCH_SESSION_RESULT:
                await self._handle_patch_session_result(message)
            else:
                logger.warning(f"Unknown message type: {message_type}")
        except Exception as e:
            logger.error(f"Error handling {message_type} message: {e}")

    async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None:
        """Handle setSubscriptionList message for declarative subscription management."""
        logger.info(f"[RX Processing] setSubscriptionList with {len(message.subscriptions)} subscriptions")

        # Update worker state with new subscriptions
        worker_state.set_subscriptions(message.subscriptions)

        # Phase 2: Download and manage models
        await self._ensure_models(message.subscriptions)

        # Phase 3 & 4: Integrate with streaming management and tracking
        await self._update_stream_subscriptions(message.subscriptions)

        logger.info("Subscription list updated successfully")
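
    # Illustrative sketch of a subscription entry (an assumption inferred from
    # the attributes read by _ensure_models and _update_stream_subscriptions
    # below, not a protocol definition):
    #
    #     subscription.subscriptionIdentifier    # e.g. "displayId;cameraId"
    #     subscription.rtspUrl / snapshotUrl / snapshotInterval
    #     subscription.modelId / modelUrl / modelName
    #     subscription.cropX1, cropY1, cropX2, cropY2   # optional crop region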

    async def _ensure_models(self, subscriptions) -> None:
        """Ensure all required models are downloaded and available."""
        # Extract unique model requirements
        unique_models = {}
        for subscription in subscriptions:
            model_id = subscription.modelId
            if model_id not in unique_models:
                unique_models[model_id] = {
                    'model_url': subscription.modelUrl,
                    'model_name': subscription.modelName
                }

        # Check if model set has changed to avoid redundant processing
        current_model_ids = set(unique_models.keys())
        if current_model_ids == self._last_processed_models:
            logger.debug(f"[Model Management] Model set unchanged {list(current_model_ids)}, skipping checks")
            return

        logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}")
        self._last_processed_models = current_model_ids

        # Check and download models concurrently
        download_tasks = []
        for model_id, model_info in unique_models.items():
            task = asyncio.create_task(
                self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name'])
            )
            download_tasks.append(task)

        # Wait for all downloads to complete
        if download_tasks:
            results = await asyncio.gather(*download_tasks, return_exceptions=True)

            # Log results
            success_count = 0
            for i, result in enumerate(results):
                model_id = list(unique_models.keys())[i]
                if isinstance(result, Exception):
                    logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}")
                elif result:
                    success_count += 1
                    logger.info(f"[Model Management] Model {model_id} ready for use")
                else:
                    logger.error(f"[Model Management] Failed to ensure model {model_id}")

            logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models")

    async def _update_stream_subscriptions(self, subscriptions) -> None:
        """Update streaming subscriptions with tracking integration."""
        try:
            # Convert subscriptions to the format expected by StreamManager
            subscription_payloads = []
            for subscription in subscriptions:
                payload = {
                    'subscriptionIdentifier': subscription.subscriptionIdentifier,
                    'rtspUrl': subscription.rtspUrl,
                    'snapshotUrl': subscription.snapshotUrl,
                    'snapshotInterval': subscription.snapshotInterval,
                    'modelId': subscription.modelId,
                    'modelUrl': subscription.modelUrl,
                    'modelName': subscription.modelName
                }
                # Add crop coordinates if present
                if hasattr(subscription, 'cropX1'):
                    payload.update({
                        'cropX1': subscription.cropX1,
                        'cropY1': subscription.cropY1,
                        'cropX2': subscription.cropX2,
                        'cropY2': subscription.cropY2
                    })
                subscription_payloads.append(payload)

            # Reconcile subscriptions with StreamManager
            logger.info("[Streaming] Reconciling stream subscriptions with tracking")
            reconcile_result = await self._reconcile_subscriptions_with_tracking(subscription_payloads)

            logger.info(f"[Streaming] Subscription reconciliation complete: "
                        f"added={reconcile_result.get('added', 0)}, "
                        f"removed={reconcile_result.get('removed', 0)}, "
                        f"failed={reconcile_result.get('failed', 0)}")

        except Exception as e:
            logger.error(f"Error updating stream subscriptions: {e}", exc_info=True)

    async def _reconcile_subscriptions_with_tracking(self, target_subscriptions) -> dict:
        """Reconcile subscriptions with tracking integration."""
        try:
            # First, we need to create tracking integrations for each unique model
            tracking_integrations = {}

            for subscription_payload in target_subscriptions:
                model_id = subscription_payload['modelId']

                # Create tracking integration if not already created
                if model_id not in tracking_integrations:
                    # Get pipeline configuration for this model
                    pipeline_parser = model_manager.get_pipeline_config(model_id)
                    if pipeline_parser:
                        # Create tracking integration with message sender
                        tracking_integration = TrackingPipelineIntegration(
                            pipeline_parser, model_manager, model_id, self._send_message
                        )

                        # Initialize tracking model
                        success = await tracking_integration.initialize_tracking_model()
                        if success:
                            tracking_integrations[model_id] = tracking_integration
                            logger.info(f"[Tracking] Created tracking integration for model {model_id}")
                        else:
                            logger.warning(f"[Tracking] Failed to initialize tracking for model {model_id}")
                    else:
                        logger.warning(f"[Tracking] No pipeline config found for model {model_id}")

            # Now reconcile with StreamManager, adding tracking integrations
            current_subscription_ids = set()
            for subscription_info in shared_stream_manager.get_all_subscriptions():
                current_subscription_ids.add(subscription_info.subscription_id)

            target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions}

            # Find subscriptions to remove and add
            to_remove = current_subscription_ids - target_subscription_ids
            to_add = target_subscription_ids - current_subscription_ids

            # Remove old subscriptions
            removed_count = 0
            for subscription_id in to_remove:
                if shared_stream_manager.remove_subscription(subscription_id):
                    removed_count += 1
                    logger.info(f"[Streaming] Removed subscription {subscription_id}")

            # Add new subscriptions with tracking
            added_count = 0
            failed_count = 0
            for subscription_payload in target_subscriptions:
                subscription_id = subscription_payload['subscriptionIdentifier']
                if subscription_id in to_add:
                    success = await self._add_subscription_with_tracking(
                        subscription_payload, tracking_integrations
                    )
                    if success:
                        added_count += 1
                        logger.info(f"[Streaming] Added subscription {subscription_id} with tracking")
                    else:
                        failed_count += 1
                        logger.error(f"[Streaming] Failed to add subscription {subscription_id}")

            return {
                'removed': removed_count,
                'added': added_count,
                'failed': failed_count,
                'total_active': len(shared_stream_manager.get_all_subscriptions())
            }

        except Exception as e:
            logger.error(f"Error in subscription reconciliation with tracking: {e}", exc_info=True)
            return {'removed': 0, 'added': 0, 'failed': 0, 'total_active': 0}

    async def _add_subscription_with_tracking(self, payload, tracking_integrations) -> bool:
        """Add a subscription with tracking integration."""
        try:
            from ..streaming.manager import StreamConfig

            subscription_id = payload['subscriptionIdentifier']
            camera_id = subscription_id.split(';')[-1]
            model_id = payload['modelId']

            logger.info(f"[SUBSCRIPTION_MAPPING] subscription_id='{subscription_id}' → camera_id='{camera_id}'")

            # Get tracking integration for this model
            tracking_integration = tracking_integrations.get(model_id)

            # Extract crop coordinates if present
            crop_coords = None
            if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']):
                crop_coords = (
                    payload['cropX1'],
                    payload['cropY1'],
                    payload['cropX2'],
                    payload['cropY2']
                )

            # Create stream configuration
            stream_config = StreamConfig(
                camera_id=camera_id,
                rtsp_url=payload.get('rtspUrl'),
                snapshot_url=payload.get('snapshotUrl'),
                snapshot_interval=payload.get('snapshotInterval', 5000),
                max_retries=3,
            )

            # Add subscription to StreamManager with tracking
            success = shared_stream_manager.add_subscription(
                subscription_id=subscription_id,
                stream_config=stream_config,
                crop_coords=crop_coords,
                model_id=model_id,
                model_url=payload.get('modelUrl'),
                tracking_integration=tracking_integration
            )

            if success and tracking_integration:
                logger.info(f"[Tracking] Subscription {subscription_id} configured with tracking for model {model_id}")

            return success

        except Exception as e:
            logger.error(f"Error adding subscription with tracking: {e}", exc_info=True)
            return False

    async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool:
        """Ensure a single model is downloaded and available."""
        try:
            # Check if model is already available
            if model_manager.is_model_downloaded(model_id):
                logger.info(f"[Model Management] Model {model_id} ({model_name}) already available")
                return True

            # Download and extract model in a thread pool to avoid blocking the event loop
            logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}")

            # Use asyncio.to_thread for CPU-bound operations (Python 3.9+)
            # For compatibility, we'll use run_in_executor
            loop = asyncio.get_event_loop()
            model_path = await loop.run_in_executor(
                None,
                model_manager.ensure_model,
                model_id,
                model_url,
                model_name
            )

            if model_path:
                logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}")
                return True
            else:
                logger.error(f"[Model Management] Failed to prepare model {model_id}")
                return False

        except Exception as e:
            logger.error(f"[Model Management] Exception ensuring model {model_id}: {str(e)}", exc_info=True)
            return False

    async def _save_snapshot(self, display_identifier: str, session_id: int) -> None:
        """
        Save snapshot image to images folder after receiving sessionId.

        Args:
            display_identifier: Display identifier to match with subscriptionIdentifier
            session_id: Session ID to include in filename
        """
        try:
            # Find subscription that matches the displayIdentifier
            matching_subscription = None
            for subscription in worker_state.get_all_subscriptions():
                # Extract display ID from subscriptionIdentifier (format: displayId;cameraId)
                from .messages import extract_display_identifier
                sub_display_id = extract_display_identifier(subscription.subscriptionIdentifier)
                if sub_display_id == display_identifier:
                    matching_subscription = subscription
                    break

            if not matching_subscription:
                logger.error(f"[Snapshot Save] No subscription found for display {display_identifier}")
                return

            if not matching_subscription.snapshotUrl:
                logger.error(f"[Snapshot Save] No snapshotUrl found for display {display_identifier}")
                return

            # Ensure images directory exists (relative path for Docker bind mount)
            images_dir = Path("images")
            images_dir.mkdir(exist_ok=True)

            # Generate filename with timestamp and session ID
            timestamp = datetime.now(tz=timezone(timedelta(hours=7))).strftime("%Y%m%d_%H%M%S")
            filename = f"{session_id}_{display_identifier}_{timestamp}.jpg"
            filepath = images_dir / filename

            # Use existing HTTPSnapshotReader to fetch snapshot
            logger.info(f"[Snapshot Save] Fetching snapshot from {matching_subscription.snapshotUrl}")

            # Run snapshot fetch in thread pool to avoid blocking async loop
            loop = asyncio.get_event_loop()
            frame = await loop.run_in_executor(None, self._fetch_snapshot_sync, matching_subscription.snapshotUrl)

            if frame is not None:
                # Save the image using OpenCV
                success = cv2.imwrite(str(filepath), frame)
                if success:
                    logger.info(f"[Snapshot Save] Successfully saved snapshot to {filepath}")
                else:
                    logger.error(f"[Snapshot Save] Failed to save image file {filepath}")
            else:
                logger.error(f"[Snapshot Save] Failed to fetch snapshot from {matching_subscription.snapshotUrl}")

        except Exception as e:
            logger.error(f"[Snapshot Save] Error saving snapshot for display {display_identifier}: {e}", exc_info=True)

    def _fetch_snapshot_sync(self, snapshot_url: str):
        """
        Synchronous snapshot fetching using existing HTTPSnapshotReader infrastructure.

        Args:
            snapshot_url: URL to fetch snapshot from

        Returns:
            np.ndarray or None: Fetched frame or None on error
        """
        try:
            from ..streaming.readers import HTTPSnapshotReader

            # Create temporary snapshot reader for single fetch
            snapshot_reader = HTTPSnapshotReader(
                camera_id="temp_snapshot",
                snapshot_url=snapshot_url,
                interval_ms=5000  # Not used for single fetch
            )

            # Use existing fetch_single_snapshot method
            return snapshot_reader.fetch_single_snapshot()

        except Exception as e:
            logger.error(f"Error in sync snapshot fetch: {e}")
            return None

    async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
        """Handle setSessionId message."""
        display_identifier = message.payload.displayIdentifier
        session_id = message.payload.sessionId

        logger.info(f"[RX Processing] setSessionId for display {display_identifier}: {session_id}")

        # Update worker state
        worker_state.set_session_id(display_identifier, session_id)

        # Update tracking integrations with session ID
        shared_stream_manager.set_session_id(display_identifier, session_id)

        # Save snapshot image after getting sessionId
        if session_id:
            await self._save_snapshot(display_identifier, session_id)

    async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None:
        """Handle setProgressionStage message."""
        display_identifier = message.payload.displayIdentifier
        stage = message.payload.progressionStage

        logger.info(f"[RX Processing] setProgressionStage for display {display_identifier}: {stage}")

        # Update worker state
        worker_state.set_progression_stage(display_identifier, stage)

        # Update tracking integration for car abandonment detection
        session_id = worker_state.get_session_id(display_identifier)
        if session_id:
            shared_stream_manager.set_progression_stage(session_id, stage)

        # If stage indicates session is cleared/finished, clear from tracking
        if stage in ['finished', 'cleared', 'idle']:
            # Get session ID for this display and clear it
            if session_id:
                shared_stream_manager.clear_session_id(session_id)
                logger.info(f"[Tracking] Cleared session {session_id} due to progression stage: {stage}")

    async def _handle_request_state(self, message: RequestStateMessage) -> None:
        """Handle requestState message by sending immediate state report."""
        logger.debug("[RX Processing] requestState - sending immediate state report")

        # Collect metrics and send state report
        cpu_usage = SystemMetrics.get_cpu_usage()
        memory_usage = SystemMetrics.get_memory_usage()
        gpu_usage = SystemMetrics.get_gpu_usage()
        gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
        camera_connections = worker_state.get_camera_connections()

        state_report = create_state_report(
            cpu_usage=cpu_usage,
            memory_usage=memory_usage,
            gpu_usage=gpu_usage,
            gpu_memory_usage=gpu_memory_usage,
            camera_connections=camera_connections
        )

        await self._send_message(state_report)

    async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None:
        """Handle patchSessionResult message."""
        payload = message.payload

        logger.info(f"[RX Processing] patchSessionResult for session {payload.sessionId}: "
                    f"success={payload.success}, message='{payload.message}'")

        # TODO: Handle patch session result if needed
        # For now, just log the response

    async def _send_message(self, message) -> None:
        """Send message to backend via WebSocket."""
        if not self.connected:
            logger.warning("Cannot send message: WebSocket not connected")
            return

        try:
            json_message = serialize_outgoing_message(message)
            await self.websocket.send_text(json_message)
            # Log non-heartbeat messages only (heartbeats are logged in their respective functions)
            if not (hasattr(message, 'type') and message.type == 'stateReport'):
                logger.info(f"[TX → Backend] {json_message}")
        except Exception as e:
            logger.error(f"Failed to send WebSocket message: {e}")
            raise

    async def _process_streams(self) -> None:
        """
        Stream processing task that handles frame processing and detection.

        This is a placeholder for Phase 2 - currently just logs that it's running.
        """
        logger.info("Stream processing task started")
        try:
            while self.connected:
                # Get current subscriptions
                subscriptions = worker_state.get_all_subscriptions()

                # TODO: Phase 2 - Add actual frame processing logic here
                # This will include:
                # - Frame reading from RTSP/HTTP streams
                # - Model inference using loaded pipelines
                # - Detection result sending via WebSocket

                # Sleep to prevent excessive CPU usage (similar to old poll_interval)
                await asyncio.sleep(0.1)  # 100ms polling interval

        except asyncio.CancelledError:
            logger.info("Stream processing task cancelled")
        except Exception as e:
            logger.error(f"Error in stream processing: {e}", exc_info=True)

    async def _cleanup(self) -> None:
        """Clean up resources when connection closes."""
        logger.info("Cleaning up WebSocket connection")
        self.connected = False

        # Cancel background tasks
        if self._heartbeat_task and not self._heartbeat_task.done():
            self._heartbeat_task.cancel()
        if self._message_task and not self._message_task.done():
            self._message_task.cancel()

        # Clear worker state
        worker_state.set_subscriptions([])
        worker_state.session_ids.clear()
        worker_state.progression_stages.clear()

        logger.info("WebSocket connection cleanup completed")


# Factory function for FastAPI integration
async def websocket_endpoint(websocket: WebSocket) -> None:
    """
    FastAPI WebSocket endpoint handler.

    Args:
        websocket: FastAPI WebSocket connection
    """
    handler = WebSocketHandler(websocket)
    await handler.handle_connection()
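
# Example wiring (illustrative sketch, not executed by this module): a FastAPI
# application created in a separate entrypoint can register the handler via the
# standard WebSocket routing API. The route path "/" and the import path are
# assumptions; adjust them to the actual project layout.
#
#     from fastapi import FastAPI
#     from .websocket import websocket_endpoint
#
#     app = FastAPI()
#     app.add_api_websocket_route("/", websocket_endpoint)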