"""
WebSocket handler module.

This module manages WebSocket connections, message processing, heartbeat
functionality, and coordination between stream processing and detection
pipelines.
"""
import asyncio
import json
import logging
import time
import traceback
import uuid
from typing import Dict, Any, Optional, Callable, List, Set, Awaitable
from contextlib import asynccontextmanager

from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError

from ..core.config import config, subscription_to_camera, latest_frames
from ..core.constants import HEARTBEAT_INTERVAL
from ..core.exceptions import WebSocketError, StreamError
from ..streams.stream_manager import StreamManager
from ..streams.camera_monitor import CameraMonitor
from ..detection.detection_result import DetectionResult
from ..models.model_manager import ModelManager
from ..pipeline.pipeline_executor import PipelineExecutor
from ..storage.session_cache import SessionCacheManager
from ..storage.redis_client import RedisClientManager
from ..utils.system_monitor import get_system_metrics

# Setup logging
logger = logging.getLogger("detector_worker.websocket_handler")
ws_logger = logging.getLogger("websocket")
ws_rxtx_logger = logging.getLogger("websocket.rxtx")  # Dedicated RX/TX logger

# Type definitions for callbacks.
# NOTE: these used to reference ``asyncio.coroutine``, which was removed in
# Python 3.11 and made the module fail at import time; ``Awaitable`` is the
# correct way to type "an async callable's result".
MessageHandler = Callable[[Dict[str, Any]], Awaitable[None]]
DetectionHandler = Callable[
    [str, Dict[str, Any], Any, WebSocket, Any, Dict[str, Any]],
    Awaitable[Dict[str, Any]],
]


class WebSocketHandler:
    """
    Manages WebSocket connections and message processing for the detection worker.

    This class handles:
    - WebSocket lifecycle management
    - Message routing and processing
    - Heartbeat/state reporting
    - Stream subscription management
    - Detection result forwarding
    - Session management
    """

    def __init__(
        self,
        stream_manager: StreamManager,
        model_manager: ModelManager,
        pipeline_executor: PipelineExecutor,
        session_cache: SessionCacheManager,
        redis_client: Optional[RedisClientManager] = None
    ):
        """
        Initialize the WebSocket handler.

        Args:
            stream_manager: Manager for camera streams
            model_manager: Manager for ML models
            pipeline_executor: Pipeline execution engine
            session_cache: Session state cache
            redis_client: Optional Redis client for pub/sub
        """
        self.stream_manager = stream_manager
        self.model_manager = model_manager
        self.pipeline_executor = pipeline_executor
        self.session_cache = session_cache
        self.redis_client = redis_client

        # Connection state
        self.websocket: Optional[WebSocket] = None
        self.connected: bool = False
        self.tasks: List[asyncio.Task] = []

        # Message handlers, keyed by the incoming message's "type" field
        self.message_handlers: Dict[str, MessageHandler] = {
            "subscribe": self._handle_subscribe,
            "unsubscribe": self._handle_unsubscribe,
            "requestState": self._handle_request_state,
            "setSessionId": self._handle_set_session,
            "patchSession": self._handle_patch_session,
            "setProgressionStage": self._handle_set_progression_stage
        }

        # Session and display management
        self.session_ids: Dict[str, str] = {}  # display_identifier -> session_id
        self.display_identifiers: Set[str] = set()

        # Camera monitor
        self.camera_monitor = CameraMonitor()

    async def handle_connection(self, websocket: WebSocket) -> None:
        """
        Main entry point for handling a WebSocket connection.

        Accepts the connection, starts the stream, heartbeat and message
        tasks, waits until the heartbeat or message task ends, and then
        always runs ``_cleanup``.

        Args:
            websocket: The WebSocket connection to handle
        """
        try:
            await websocket.accept()
            self.websocket = websocket
            self.connected = True

            # Log connection details
            client_host = getattr(websocket.client, 'host', 'unknown')
            client_port = getattr(websocket.client, 'port', 'unknown')
            logger.info(f"πŸ”— WebSocket connection accepted from {client_host}:{client_port}")
            ws_rxtx_logger.info(f"CONNECT -> Client: {client_host}:{client_port}")

            # Create concurrent tasks
            stream_task = asyncio.create_task(self._process_streams())
            heartbeat_task = asyncio.create_task(self._send_heartbeat())
            message_task = asyncio.create_task(self._process_messages())
            self.tasks = [stream_task, heartbeat_task, message_task]

            # Either task ending means the connection is no longer usable, so
            # stop as soon as the FIRST one finishes. Waiting for both would
            # delay teardown until the next heartbeat send fails.
            await asyncio.wait(
                [heartbeat_task, message_task],
                return_when=asyncio.FIRST_COMPLETED,
            )

        except Exception as e:
            logger.error(f"Error in WebSocket handler: {e}")

        finally:
            self.connected = False
            client_host = getattr(websocket.client, 'host', 'unknown') if websocket.client else 'unknown'
            client_port = getattr(websocket.client, 'port', 'unknown') if websocket.client else 'unknown'
            ws_rxtx_logger.info(f"DISCONNECT -> Client: {client_host}:{client_port}")
            await self._cleanup()

    async def _cleanup(self) -> None:
        """Clean up resources when connection closes."""
        logger.info("Cleaning up WebSocket connection")

        # Cancel all tasks, then collect their outcomes without re-raising:
        # a task that already failed must not abort the rest of the cleanup.
        for task in self.tasks:
            task.cancel()
        if self.tasks:
            outcomes = await asyncio.gather(*self.tasks, return_exceptions=True)
            for outcome in outcomes:
                if isinstance(outcome, Exception) and not isinstance(outcome, asyncio.CancelledError):
                    logger.error(f"Background task failed during shutdown: {outcome}")
        self.tasks = []

        # Clean up streams
        await self.stream_manager.cleanup_all_streams()

        # Clean up models
        self.model_manager.cleanup_all_models()

        # Clear session state
        self.session_cache.clear_all_sessions()
        self.session_ids.clear()
        self.display_identifiers.clear()

        # Clear camera states
        self.camera_monitor.clear_all_states()

        logger.info("WebSocket cleanup completed")

    def _build_state_report(self) -> Dict[str, Any]:
        """Build the ``stateReport`` message from current metrics, streams and models."""
        metrics = get_system_metrics()
        active_streams = self.stream_manager.get_active_streams()
        active_models = self.model_manager.get_loaded_models()

        return {
            "type": "stateReport",
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "data": {
                "activeStreams": len(active_streams),
                "loadedModels": len(active_models),
                "cpuUsage": metrics.get("cpu_percent", 0),
                "memoryUsage": metrics.get("memory_percent", 0),
                "gpuUsage": metrics.get("gpu_percent", 0),
                "gpuMemory": metrics.get("gpu_memory_percent", 0),
                "uptime": time.time() - metrics.get("start_time", time.time())
            }
        }

    async def _send_state_report(self) -> None:
        """Send exactly one ``stateReport`` message over the WebSocket."""
        state_data = self._build_state_report()

        # Compact JSON for RX/TX logging
        compact_json = json.dumps(state_data, separators=(',', ':'))
        ws_rxtx_logger.info(f"TX -> {compact_json}")

        await self.websocket.send_json(state_data)

    async def _send_heartbeat(self) -> None:
        """Send periodic heartbeat/state reports to maintain connection.

        Loops while ``self.connected`` is true, sending one state report every
        ``HEARTBEAT_INTERVAL`` seconds, and stops on the first send error.
        """
        while self.connected:
            try:
                await self._send_state_report()
                await asyncio.sleep(HEARTBEAT_INTERVAL)

            except (WebSocketDisconnect, ConnectionClosedError):
                logger.info("WebSocket disconnected during heartbeat")
                break
            except Exception as e:
                logger.error(f"Error sending heartbeat: {e}")
                break

    async def _process_messages(self) -> None:
        """Process incoming WebSocket messages.

        Each text frame is parsed as JSON and dispatched by its ``type`` field
        through ``self.message_handlers``. Malformed or unknown messages are
        logged and skipped; a disconnect or handler error ends the loop.
        """
        while self.connected:
            try:
                text_data = await self.websocket.receive_text()
                ws_rxtx_logger.info(f"RX <- {text_data}")

                data = json.loads(text_data)
                # Valid JSON need not be an object; without this guard
                # ``data.get`` would raise and tear the connection down.
                if not isinstance(data, dict):
                    logger.error("Received JSON message that is not an object")
                    continue

                msg_type = data.get("type")

                # Log message processing
                logger.debug(f"πŸ“₯ Processing message type: {msg_type}")

                # Only string types can name a handler; anything else (e.g. an
                # unhashable list) is treated as unknown instead of raising.
                handler = self.message_handlers.get(msg_type) if isinstance(msg_type, str) else None
                if handler is not None:
                    await handler(data)
                    logger.debug(f"βœ… Message {msg_type} processed successfully")
                else:
                    logger.error(f"❌ Unknown message type: {msg_type}")
                    ws_rxtx_logger.error(f"UNKNOWN_MSG_TYPE -> {msg_type}")

            except json.JSONDecodeError:
                logger.error("Received invalid JSON message")
            except (WebSocketDisconnect, ConnectionClosedError) as e:
                logger.warning(f"WebSocket disconnected: {e}")
                break
            except Exception as e:
                logger.error(f"Error handling message: {e}")
                traceback.print_exc()
                break

    async def _process_streams(self) -> None:
        """Process active camera streams and run detection pipelines.

        On every poll, one detection task is started per active stream that
        has both a frame and a loaded model. Dict results are stored back as
        that camera's persistent data.
        """
        while self.connected:
            try:
                active_streams = self.stream_manager.get_active_streams()

                if active_streams:
                    # camera_ids[i] is the camera that tasks[i] belongs to.
                    # Skipped cameras produce no task, so results must NOT be
                    # matched against active_streams by position.
                    camera_ids: List[str] = []
                    tasks = []
                    for camera_id, stream_info in active_streams.items():
                        # Get latest frame
                        frame = self.stream_manager.get_latest_frame(camera_id)
                        if frame is None:
                            continue

                        # Get model for this camera
                        model_id = stream_info.get("modelId")
                        if not model_id:
                            continue

                        model_tree = self.model_manager.get_model(camera_id, model_id)
                        if not model_tree:
                            continue

                        # Create detection task
                        persistent_data = self.session_cache.get_persistent_data(camera_id)
                        task = asyncio.create_task(
                            self._handle_detection(
                                camera_id, stream_info, frame,
                                self.websocket, model_tree, persistent_data
                            )
                        )
                        camera_ids.append(camera_id)
                        tasks.append(task)

                    # Wait for all detection tasks
                    if tasks:
                        results = await asyncio.gather(*tasks, return_exceptions=True)

                        # Update persistent data for the camera each result came from
                        for camera_id, result in zip(camera_ids, results):
                            if isinstance(result, dict):
                                self.session_cache.update_persistent_data(camera_id, result)

                # Polling interval
                poll_interval = config.get("poll_interval_ms", 100) / 1000.0
                await asyncio.sleep(poll_interval)

            except asyncio.CancelledError:
                logger.info("Stream processing cancelled")
                break
            except Exception as e:
                logger.error(f"Error in stream processing: {e}")
                await asyncio.sleep(1)

    async def _handle_detection(
        self,
        camera_id: str,
        stream_info: Dict[str, Any],
        frame: Any,
        websocket: WebSocket,
        model_tree: Any,
        persistent_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Handle detection for a single camera frame.

        Args:
            camera_id: Camera identifier
            stream_info: Stream configuration
            frame: Video frame to process
            websocket: WebSocket connection
            model_tree: Model pipeline tree
            persistent_data: Persistent data for this camera

        Returns:
            The ``persistent_data`` object that was passed in; it is returned
            on every path, including errors.
        """
        try:
            # Check camera connection state
            if self.camera_monitor.should_notify_disconnection(camera_id):
                await self._send_disconnection_notification(camera_id, stream_info)
                return persistent_data

            # Apply crop if specified
            cropped_frame = self._apply_crop(frame, stream_info)

            # Get session pipeline state
            pipeline_state = self.session_cache.get_session_pipeline_state(camera_id)

            # Run detection pipeline
            detection_result = await self.pipeline_executor.execute_pipeline(
                camera_id,
                stream_info,
                cropped_frame,
                model_tree,
                persistent_data,
                pipeline_state
            )

            # Send detection result
            if detection_result:
                await self._send_detection_result(
                    camera_id, stream_info, detection_result
                )

            # Handle camera reconnection
            if self.camera_monitor.should_notify_reconnection(camera_id):
                self.camera_monitor.mark_reconnection_notified(camera_id)
                logger.info(f"Camera {camera_id} reconnected successfully")

            return persistent_data

        except Exception as e:
            logger.error(f"Error in detection handling for camera {camera_id}: {e}")
            traceback.print_exc()
            return persistent_data

    async def _handle_subscribe(self, data: Dict[str, Any]) -> None:
        """Handle stream subscription request.

        Reads ``subscriptionIdentifier`` from the payload, starts the camera
        stream, and loads the model when both ``modelId`` and ``modelUrl``
        are given. Errors are logged, not raised.
        """
        payload = data.get("payload", {})
        subscription_id = payload.get("subscriptionIdentifier")

        if not subscription_id:
            logger.error("Missing subscriptionIdentifier in subscribe payload")
            return

        try:
            # Extract display and camera IDs ("<display>;<camera>[;...]")
            parts = subscription_id.split(";")
            if len(parts) >= 2:
                display_id = parts[0]
                camera_id = parts[1]
                self.display_identifiers.add(display_id)
            else:
                # No separator: the whole identifier is used as the camera id
                camera_id = subscription_id

            # Store subscription mapping
            subscription_to_camera[subscription_id] = camera_id

            # Start camera stream
            await self.stream_manager.start_stream(camera_id, payload)

            # Load model
            model_id = payload.get("modelId")
            model_url = payload.get("modelUrl")
            if model_id and model_url:
                await self.model_manager.load_model(camera_id, model_id, model_url)

            logger.info(f"Subscribed to stream: {subscription_id}")

        except Exception as e:
            logger.error(f"Error handling subscription: {e}")
            traceback.print_exc()

    async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
        """Handle stream unsubscription request.

        Stops the stream, unloads the models and clears the session for the
        camera mapped to the payload's ``subscriptionIdentifier``. Errors are
        logged, not raised.
        """
        payload = data.get("payload", {})
        subscription_id = payload.get("subscriptionIdentifier")

        if not subscription_id:
            logger.error("Missing subscriptionIdentifier in unsubscribe payload")
            return

        try:
            # Get camera ID from subscription
            camera_id = subscription_to_camera.get(subscription_id)
            if not camera_id:
                logger.warning(f"No camera found for subscription: {subscription_id}")
                return

            # Stop stream
            await self.stream_manager.stop_stream(camera_id)

            # Unload model
            self.model_manager.unload_models(camera_id)

            # Clean up mappings
            subscription_to_camera.pop(subscription_id, None)

            # Clean up session state
            self.session_cache.clear_session(camera_id)

            logger.info(f"Unsubscribed from stream: {subscription_id}")

        except Exception as e:
            logger.error(f"Error handling unsubscription: {e}")
            traceback.print_exc()

    async def _handle_request_state(self, data: Dict[str, Any]) -> None:
        """Handle state request message by sending one immediate state report."""
        # Must be the single-shot sender: awaiting the heartbeat loop here
        # would block the message loop for the rest of the connection.
        await self._send_state_report()

    async def _handle_set_session(self, data: Dict[str, Any]) -> None:
        """Handle setSessionId message.

        When both ``displayIdentifier`` and ``sessionId`` are present, records
        the session for the display, updates every stream whose subscription
        identifier starts with ``"<display>;"``, and sends an ``ack``.
        """
        payload = data.get("payload", {})
        display_id = payload.get("displayIdentifier")
        session_id = payload.get("sessionId")

        if display_id and session_id:
            self.session_ids[display_id] = session_id

            # Update session for all cameras of this display
            with self.stream_manager.streams_lock:
                for camera_id, stream in self.stream_manager.streams.items():
                    if stream["subscriptionIdentifier"].startswith(display_id + ";"):
                        self.session_cache.update_session_id(camera_id, session_id)

            # Send acknowledgment
            response = {
                "type": "ack",
                "requestId": data.get("requestId", str(uuid.uuid4())),
                "code": "200",
                "data": {
                    "message": f"Session ID set for display {display_id}",
                    "sessionId": session_id
                }
            }
            ws_rxtx_logger.info(f"TX -> {json.dumps(response, separators=(',', ':'))}")
            await self.websocket.send_json(response)

            logger.info(f"Set session {session_id} for display {display_id}")

    async def _handle_patch_session(self, data: Dict[str, Any]) -> None:
        """Handle patchSession message.

        The patch data is only logged and echoed back in an ``ack``; nothing
        is sent when ``sessionId`` is missing.
        """
        payload = data.get("payload", {})
        session_id = payload.get("sessionId")
        patch_data = payload.get("data", {})

        if session_id:
            # Store patch data (could be used for session updates)
            logger.info(f"Received patch for session {session_id}: {patch_data}")

            # Send acknowledgment
            response = {
                "type": "ack",
                "requestId": data.get("requestId", str(uuid.uuid4())),
                "code": "200",
                "data": {
                    "message": f"Session {session_id} patched successfully",
                    "sessionId": session_id,
                    "patchData": patch_data
                }
            }
            ws_rxtx_logger.info(f"TX -> {json.dumps(response, separators=(',', ':'))}")
            await self.websocket.send_json(response)

    async def _handle_set_progression_stage(self, data: Dict[str, Any]) -> None:
        """Handle setProgressionStage message.

        Applies the requested stage to the pipeline state of every camera
        whose subscription identifier starts with ``"<display>;"``.
        """
        payload = data.get("payload", {})
        display_id = payload.get("displayIdentifier")
        progression_stage = payload.get("progressionStage")

        logger.info(f"🏁 PROGRESSION STAGE RECEIVED: displayId={display_id}, stage={progression_stage}")

        if not display_id:
            logger.warning("Missing displayIdentifier in setProgressionStage")
            return

        # Find all cameras for this display
        affected_cameras = []
        with self.stream_manager.streams_lock:
            for camera_id, stream in self.stream_manager.streams.items():
                if stream["subscriptionIdentifier"].startswith(display_id + ";"):
                    affected_cameras.append(camera_id)

        logger.debug(f"🎯 Found {len(affected_cameras)} cameras for display {display_id}: {affected_cameras}")

        # Update progression stage for each camera
        for camera_id in affected_cameras:
            pipeline_state = self.session_cache.get_or_init_session_pipeline_state(camera_id)
            current_mode = pipeline_state.get("mode", "validation_detecting")

            if progression_stage == "car_fueling":
                # Stop YOLO inference during fueling (lightweight mode only)
                if current_mode == "lightweight":
                    pipeline_state["yolo_inference_enabled"] = False
                    pipeline_state["progression_stage"] = "car_fueling"
                    logger.info(f"⏸️ Camera {camera_id}: YOLO inference DISABLED for car_fueling stage")
                else:
                    logger.debug(f"πŸ“Š Camera {camera_id}: car_fueling received but not in lightweight mode (mode: {current_mode})")

            elif progression_stage == "car_waitpayment":
                # Resume YOLO inference for absence counter
                pipeline_state["yolo_inference_enabled"] = True
                pipeline_state["progression_stage"] = "car_waitpayment"
                logger.info(f"▢️ Camera {camera_id}: YOLO inference RE-ENABLED for car_waitpayment stage")

            elif progression_stage == "welcome":
                # Ignore welcome messages during car_waitpayment
                current_progression = pipeline_state.get("progression_stage")
                if current_progression == "car_waitpayment":
                    logger.info(f"🚫 Camera {camera_id}: IGNORING welcome stage (currently in car_waitpayment)")
                else:
                    pipeline_state["progression_stage"] = "welcome"
                    logger.info(f"πŸŽ‰ Camera {camera_id}: Progression stage set to welcome")

            elif progression_stage in ["car_wait_staff"]:
                pipeline_state["progression_stage"] = progression_stage
                logger.info(f"πŸ“‹ Camera {camera_id}: Progression stage set to {progression_stage}")

    async def _send_detection_result(
        self,
        camera_id: str,
        stream_info: Dict[str, Any],
        detection_result: DetectionResult
    ) -> None:
        """Send detection result over WebSocket.

        A ``RuntimeError`` mentioning ``websocket.close`` is logged as a
        warning; any other ``RuntimeError`` is re-raised.
        """
        detection_data = {
            "type": "imageDetection",
            "subscriptionIdentifier": stream_info["subscriptionIdentifier"],
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "data": {
                "detection": detection_result.to_dict(),
                "modelId": stream_info["modelId"],
                "modelName": stream_info["modelName"]
            }
        }

        try:
            ws_rxtx_logger.info(f"TX -> {json.dumps(detection_data, separators=(',', ':'))}")
            await self.websocket.send_json(detection_data)
        except RuntimeError as e:
            if "websocket.close" in str(e):
                logger.warning(f"WebSocket closed - cannot send detection for camera {camera_id}")
            else:
                raise

    async def _send_disconnection_notification(
        self,
        camera_id: str,
        stream_info: Dict[str, Any]
    ) -> None:
        """Send camera disconnection notification.

        Clears the camera's cached session, sends an ``imageDetection``
        message whose ``detection`` is null, then marks the disconnection as
        notified on the camera monitor.
        """
        logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null")

        # Clear cached data
        self.session_cache.clear_session(camera_id)

        # Send null detection
        detection_data = {
            "type": "imageDetection",
            "subscriptionIdentifier": stream_info["subscriptionIdentifier"],
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "data": {
                "detection": None,
                "modelId": stream_info["modelId"],
                "modelName": stream_info["modelName"]
            }
        }

        try:
            ws_rxtx_logger.info(f"TX -> {json.dumps(detection_data, separators=(',', ':'))}")
            await self.websocket.send_json(detection_data)
        except RuntimeError as e:
            if "websocket.close" in str(e):
                logger.warning(f"WebSocket closed - cannot send disconnection signal for camera {camera_id}")
            else:
                raise

        self.camera_monitor.mark_disconnection_notified(camera_id)
        logger.info(f"πŸ“‘ SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session")

    def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any:
        """Apply crop to frame if crop coordinates are specified.

        The frame is sliced as ``frame[y1:y2, x1:x2]`` only when all four of
        ``cropX1``, ``cropY1``, ``cropX2`` and ``cropY2`` are present in
        ``stream_info``; otherwise it is returned unchanged.
        """
        crop_coords = [
            stream_info.get("cropX1"),
            stream_info.get("cropY1"),
            stream_info.get("cropX2"),
            stream_info.get("cropY2")
        ]

        if all(coord is not None for coord in crop_coords):
            x1, y1, x2, y2 = crop_coords
            return frame[y1:y2, x1:x2]

        return frame


# Convenience function for backward compatibility
async def handle_websocket_connection(
    websocket: WebSocket,
    stream_manager: StreamManager,
    model_manager: ModelManager,
    pipeline_executor: PipelineExecutor,
    session_cache: SessionCacheManager,
    redis_client: Optional[RedisClientManager] = None
) -> None:
    """
    Handle a WebSocket connection using the WebSocketHandler.

    This is a convenience function that creates a handler instance
    and processes the connection.
    """
    handler = WebSocketHandler(
        stream_manager,
        model_manager,
        pipeline_executor,
        session_cache,
        redis_client
    )
    await handler.handle_connection(websocket)