python-detector-worker/detector_worker/communication/websocket_handler.py

713 lines
No EOL
30 KiB
Python

"""
WebSocket handler module.
This module manages WebSocket connections, message processing, heartbeat functionality,
and coordination between stream processing and detection pipelines.
"""
import asyncio
import json
import logging
import time
import traceback
import uuid
from typing import Dict, Any, Optional, Callable, List, Set, Awaitable
from contextlib import asynccontextmanager

from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError

from ..core.config import config, subscription_to_camera, latest_frames
from ..core.constants import HEARTBEAT_INTERVAL
from ..core.exceptions import WebSocketError, StreamError
from ..streams.stream_manager import StreamManager
from ..streams.camera_monitor import CameraMonitor
from ..detection.detection_result import DetectionResult
from ..models.model_manager import ModelManager
from ..pipeline.pipeline_executor import PipelineExecutor
from ..storage.session_cache import SessionCacheManager
from ..storage.redis_client import RedisClientManager
from ..utils.system_monitor import get_system_metrics
# Setup logging: general handler events, WebSocket events, and raw RX/TX traffic.
logger = logging.getLogger("detector_worker.websocket_handler")
ws_logger = logging.getLogger("websocket")
ws_rxtx_logger = logging.getLogger("websocket.rxtx")  # Dedicated RX/TX logger

# Type definitions for callbacks.
# ``asyncio.coroutine`` was removed in Python 3.11, so using it as the return
# type made this module fail with AttributeError at import time. The handlers
# are ``async def`` functions, whose call result is an awaitable.
MessageHandler = Callable[[Dict[str, Any]], Awaitable[None]]
DetectionHandler = Callable[
    [str, Dict[str, Any], Any, WebSocket, Any, Dict[str, Any]],
    Awaitable[Dict[str, Any]],
]
class WebSocketHandler:
    """
    Drives one WebSocket connection of the detection worker.

    Responsibilities:
    - WebSocket lifecycle management
    - Message routing and processing
    - Heartbeat/state reporting
    - Stream subscription management
    - Detection result forwarding
    - Session management
    """

    def __init__(
        self,
        stream_manager: StreamManager,
        model_manager: ModelManager,
        pipeline_executor: PipelineExecutor,
        session_cache: SessionCacheManager,
        redis_client: Optional[RedisClientManager] = None
    ):
        """
        Store the collaborating services and reset per-connection state.

        Args:
            stream_manager: Manager for camera streams
            model_manager: Manager for ML models
            pipeline_executor: Pipeline execution engine
            session_cache: Session state cache
            redis_client: Optional Redis client for pub/sub
        """
        # Injected collaborators
        self.stream_manager = stream_manager
        self.model_manager = model_manager
        self.pipeline_executor = pipeline_executor
        self.session_cache = session_cache
        self.redis_client = redis_client

        # Connection state; filled in by handle_connection()
        self.websocket: Optional[WebSocket] = None
        self.connected: bool = False
        self.tasks: List[asyncio.Task] = []

        # Incoming message "type" -> bound coroutine method that handles it
        self.message_handlers: Dict[str, MessageHandler] = dict(
            setSubscriptionList=self._handle_set_subscription_list,
            requestState=self._handle_request_state,
            setSessionId=self._handle_set_session,
            patchSession=self._handle_patch_session,
            setProgressionStage=self._handle_set_progression_stage,
            patchSessionResult=self._handle_patch_session_result,
        )

        # Per-display bookkeeping
        self.session_ids: Dict[str, str] = {}  # display_identifier -> session_id
        self.progression_stages: Dict[str, str] = {}  # display_identifier -> progression_stage
        self.display_identifiers: Set[str] = set()

        # Tracks camera disconnect/reconnect notification state
        self.camera_monitor = CameraMonitor()
async def handle_connection(self, websocket: WebSocket) -> None:
    """
    Accept a WebSocket and run it until the connection ends.

    Starts the stream, heartbeat and message tasks, waits for the
    heartbeat and message tasks to finish, and always runs cleanup.

    Args:
        websocket: The WebSocket connection to handle
    """
    try:
        await websocket.accept()
        self.websocket = websocket
        self.connected = True

        # Record who connected
        peer = websocket.client
        host = getattr(peer, 'host', 'unknown')
        port = getattr(peer, 'port', 'unknown')
        logger.info(f"🔗 WebSocket connection accepted from {host}:{port}")
        ws_rxtx_logger.info(f"CONNECT -> Client: {host}:{port}")

        # Launch the three concurrent workers; the stream task is only
        # tracked for cancellation, it is not awaited here.
        stream_task = asyncio.create_task(self._process_streams())
        heartbeat_task = asyncio.create_task(self._send_heartbeat())
        message_task = asyncio.create_task(self._process_messages())
        self.tasks = [stream_task, heartbeat_task, message_task]

        await asyncio.gather(heartbeat_task, message_task)
    except Exception as e:
        logger.error(f"Error in WebSocket handler: {e}")
    finally:
        self.connected = False
        if websocket.client:
            host = getattr(websocket.client, 'host', 'unknown')
            port = getattr(websocket.client, 'port', 'unknown')
        else:
            host = port = 'unknown'
        ws_rxtx_logger.info(f"DISCONNECT -> Client: {host}:{port}")
        await self._cleanup()
async def _cleanup(self) -> None:
    """
    Release everything tied to the connection that just closed.

    Cancels the background tasks, then cleans up streams, models, cached
    sessions, per-display state and camera monitor state.
    """
    logger.info("Cleaning up WebSocket connection")

    # Cancel every background task, then wait for all of them at once.
    # return_exceptions=True keeps a task that died with a real error from
    # aborting the rest of the cleanup (awaiting it directly would re-raise).
    pending = list(self.tasks)
    self.tasks = []
    for task in pending:
        task.cancel()
    if pending:
        outcomes = await asyncio.gather(*pending, return_exceptions=True)
        for outcome in outcomes:
            if isinstance(outcome, Exception) and not isinstance(outcome, asyncio.CancelledError):
                logger.warning(f"Background task failed before cleanup: {outcome}")

    # Clean up streams
    await self.stream_manager.cleanup_all_streams()

    # Clean up models
    self.model_manager.cleanup_all_models()

    # Clear session state (progression stages are per-display state as well,
    # so they are reset together with the session ids)
    self.session_cache.clear_all_sessions()
    self.session_ids.clear()
    self.progression_stages.clear()
    self.display_identifiers.clear()

    # Clear camera states
    self.camera_monitor.clear_all_states()

    logger.info("WebSocket cleanup completed")
async def _send_heartbeat(self) -> None:
    """
    Send a stateReport message every HEARTBEAT_INTERVAL seconds.

    Runs while ``self.connected`` is true and stops on the first
    disconnect or any other error.
    """
    crop_keys = ("cropX1", "cropY1", "cropX2", "cropY2")

    while self.connected:
        try:
            metrics = get_system_metrics()

            # One cameraConnections entry per known stream
            camera_connections = []
            with self.stream_manager.streams_lock:
                for camera_id, stream_info in self.stream_manager.streams.items():
                    is_online = self.stream_manager.is_stream_active(camera_id)
                    entry = {
                        "subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
                        "modelId": stream_info.get("modelId", 0),
                        "modelName": stream_info.get("modelName", "Unknown Model"),
                        "online": is_online
                    }
                    # Crop coordinates are copied only when the stream has them
                    entry.update(
                        {key: stream_info[key] for key in crop_keys if key in stream_info}
                    )
                    camera_connections.append(entry)

            # stateReport payload (field names follow the worker protocol)
            state_data = {
                "type": "stateReport",
                "cpuUsage": metrics.get("cpu_percent", 0),
                "memoryUsage": metrics.get("memory_percent", 0),
                "gpuUsage": metrics.get("gpu_percent", 0),
                "gpuMemoryUsage": metrics.get("gpu_memory_percent", 0),
                "cameraConnections": camera_connections
            }

            # Compact form is used only for the RX/TX log line
            compact_json = json.dumps(state_data, separators=(',', ':'))
            ws_rxtx_logger.info(f"TX -> {compact_json}")

            await self.websocket.send_json(state_data)
            await asyncio.sleep(HEARTBEAT_INTERVAL)
        except (WebSocketDisconnect, ConnectionClosedError):
            logger.info("WebSocket disconnected during heartbeat")
            break
        except Exception as e:
            logger.error(f"Error sending heartbeat: {e}")
            break
async def _process_messages(self) -> None:
    """
    Receive text messages and dispatch them by their ``type`` field.

    Invalid JSON and JSON payloads that are not objects are logged and
    skipped. A disconnect or any other error ends the loop.
    """
    while self.connected:
        try:
            text_data = await self.websocket.receive_text()
            ws_rxtx_logger.info(f"RX <- {text_data}")

            data = json.loads(text_data)
            # json.loads accepts arrays, strings and numbers too; only an
            # object has a "type" to route on. Without this guard such a
            # payload raised AttributeError and terminated the message loop.
            if not isinstance(data, dict):
                logger.error("Received JSON message that is not an object")
                continue

            msg_type = data.get("type")
            logger.debug(f"📥 Processing message type: {msg_type}")

            handler = self.message_handlers.get(msg_type)
            if handler is not None:
                await handler(data)
                logger.debug(f"✅ Message {msg_type} processed successfully")
            else:
                logger.error(f"❌ Unknown message type: {msg_type}")
                ws_rxtx_logger.error(f"UNKNOWN_MSG_TYPE -> {msg_type}")
        except json.JSONDecodeError:
            logger.error("Received invalid JSON message")
        except (WebSocketDisconnect, ConnectionClosedError) as e:
            logger.warning(f"WebSocket disconnected: {e}")
            break
        except Exception as e:
            # Any other failure (including a handler error) stops the loop
            logger.error(f"Error handling message: {e}")
            traceback.print_exc()
            break
async def _process_streams(self) -> None:
    """
    Poll the active streams and run one detection per ready camera.

    A camera is skipped for a cycle when it has no frame, no ``modelId``
    or no loaded model. Dict results returned by the detection tasks are
    written back to the session cache for the camera that produced them.
    """
    while self.connected:
        try:
            active_streams = self.stream_manager.get_active_streams()

            if active_streams:
                # scheduled_ids[i] is the camera that tasks[i] belongs to.
                # Cameras can be skipped below, so results must be matched
                # against this list and not against active_streams; the
                # previous positional match could store one camera's data
                # under another camera.
                scheduled_ids = []
                tasks = []
                for camera_id, stream_info in active_streams.items():
                    # Get latest frame
                    frame = self.stream_manager.get_latest_frame(camera_id)
                    if frame is None:
                        continue

                    # Get model for this camera
                    model_id = stream_info.get("modelId")
                    if not model_id:
                        continue
                    model_tree = self.model_manager.get_model(camera_id, model_id)
                    if not model_tree:
                        continue

                    # Create detection task
                    persistent_data = self.session_cache.get_persistent_data(camera_id)
                    task = asyncio.create_task(
                        self._handle_detection(
                            camera_id, stream_info, frame,
                            self.websocket, model_tree, persistent_data
                        )
                    )
                    scheduled_ids.append(camera_id)
                    tasks.append(task)

                # Wait for all detection tasks; exceptions come back as values
                if tasks:
                    results = await asyncio.gather(*tasks, return_exceptions=True)
                    for camera_id, result in zip(scheduled_ids, results):
                        if isinstance(result, dict):
                            self.session_cache.update_persistent_data(camera_id, result)

            # Polling interval (also applies when there are no active streams)
            poll_interval = config.get("poll_interval_ms", 100) / 1000.0
            await asyncio.sleep(poll_interval)
        except asyncio.CancelledError:
            logger.info("Stream processing cancelled")
            break
        except Exception as e:
            logger.error(f"Error in stream processing: {e}")
            await asyncio.sleep(1)
async def _handle_detection(
    self,
    camera_id: str,
    stream_info: Dict[str, Any],
    frame: Any,
    websocket: WebSocket,
    model_tree: Any,
    persistent_data: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Run the detection pipeline on one frame of one camera.

    Args:
        camera_id: Camera identifier
        stream_info: Stream configuration
        frame: Video frame to process
        websocket: WebSocket connection (not used by this method)
        model_tree: Model pipeline tree
        persistent_data: Persistent data for this camera

    Returns:
        The ``persistent_data`` object that was passed in; errors are
        logged and do not propagate.
    """
    try:
        # A pending disconnection notice takes priority over detection
        if self.camera_monitor.should_notify_disconnection(camera_id):
            await self._send_disconnection_notification(camera_id, stream_info)
            return persistent_data

        region = self._apply_crop(frame, stream_info)
        state = self.session_cache.get_session_pipeline_state(camera_id)

        outcome = await self.pipeline_executor.execute_pipeline(
            camera_id,
            stream_info,
            region,
            model_tree,
            persistent_data,
            state
        )

        # Falsy results are not forwarded
        if outcome:
            await self._send_detection_result(camera_id, stream_info, outcome)

        # Acknowledge a reconnection once
        if self.camera_monitor.should_notify_reconnection(camera_id):
            self.camera_monitor.mark_reconnection_notified(camera_id)
            logger.info(f"Camera {camera_id} reconnected successfully")
    except Exception as e:
        logger.error(f"Error in detection handling for camera {camera_id}: {e}")
        traceback.print_exc()

    return persistent_data
async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None:
    """
    Handle setSubscriptionList: reconcile the desired list with current state.

    Subscriptions missing from the new list are stopped, new ones are
    started, and retained ones are restarted only when their model URL
    differs from the one currently in use. Errors are logged, not raised.

    Args:
        data: Decoded message; ``data["subscriptions"]`` is a list of
            subscription configs carrying ``subscriptionIdentifier``.
    """
    subscriptions = data.get("subscriptions", [])
    try:
        # Current vs. desired subscription identifiers
        current_subscriptions = set(subscription_to_camera.keys())
        desired_subscriptions = set()
        subscription_configs = {}
        for sub_config in subscriptions:
            sub_id = sub_config.get("subscriptionIdentifier")
            if sub_id:
                desired_subscriptions.add(sub_id)
                subscription_configs[sub_id] = sub_config
                # Identifier format is "<display>;<camera>"; remember the display
                parts = sub_id.split(";")
                if len(parts) >= 2:
                    self.display_identifiers.add(parts[0])

        # Calculate changes needed
        to_add = desired_subscriptions - current_subscriptions
        to_remove = current_subscriptions - desired_subscriptions
        to_update = desired_subscriptions & current_subscriptions
        logger.info(f"Subscription reconciliation: add={len(to_add)}, remove={len(to_remove)}, update={len(to_update)}")

        # Remove obsolete subscriptions
        for sub_id in to_remove:
            camera_id = subscription_to_camera.get(sub_id)
            if camera_id:
                await self.stream_manager.stop_stream(camera_id)
                self.model_manager.unload_models(camera_id)
                subscription_to_camera.pop(sub_id, None)
                self.session_cache.clear_session(camera_id)
                logger.info(f"Removed subscription: {sub_id}")

        # Add new subscriptions
        for sub_id in to_add:
            await self._start_subscription(sub_id, subscription_configs[sub_id])
            logger.info(f"Added subscription: {sub_id}")

        # Restart retained subscriptions whose model URL changed (handles S3 expiration)
        for sub_id in to_update:
            camera_id = subscription_to_camera.get(sub_id)
            new_config = subscription_configs[sub_id]
            new_model_url = new_config.get("modelUrl")
            current_model_url = self._current_model_url(camera_id)
            if current_model_url != new_model_url:
                if camera_id:
                    await self.stream_manager.stop_stream(camera_id)
                    self.model_manager.unload_models(camera_id)
                await self._start_subscription(sub_id, new_config)
                logger.info(f"Updated subscription: {sub_id}")

        logger.info(f"Subscription list reconciliation completed. Active: {len(desired_subscriptions)}")
    except Exception as e:
        logger.error(f"Error handling setSubscriptionList: {e}")
        traceback.print_exc()

def _current_model_url(self, camera_id: Any) -> Optional[str]:
    """
    Return the model URL currently associated with ``camera_id``, if known.

    ``subscription_to_camera`` stores a plain camera id, so reading
    ``model_url`` off that value always gave None and every retained
    subscription was restarted on every setSubscriptionList message.
    """
    if not camera_id:
        return None
    # Keep honouring a mapping value that does carry ``model_url``
    attached_url = getattr(camera_id, 'model_url', None)
    if attached_url is not None:
        return attached_url
    # NOTE(review): assumes the stream manager keeps the subscription config
    # (including "modelUrl") in streams[camera_id] - TODO confirm. If the key
    # is absent this returns None and the subscription restarts as before.
    with self.stream_manager.streams_lock:
        stream_info = self.stream_manager.streams.get(camera_id)
    if isinstance(stream_info, dict):
        return stream_info.get("modelUrl")
    return None
async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None:
    """
    Start the stream and load the model for one subscription.

    Args:
        subscription_id: Identifier of the form "<display>;<camera>"; when
            it has no ";" the whole identifier is used as the camera id.
        config: Subscription configuration. Note that this parameter hides
            the module-level ``config`` object inside this method.

    Raises:
        Exception: Whatever the stream or model manager raised, after it
            has been logged.
    """
    try:
        # Extract camera ID from subscription identifier
        parts = subscription_id.split(";")
        camera_id = parts[1] if len(parts) >= 2 else subscription_id

        # Store subscription mapping
        subscription_to_camera[subscription_id] = camera_id

        # Start camera stream
        await self.stream_manager.start_stream(camera_id, config)

        # Load model only when both the id and the URL are provided
        model_id = config.get("modelId")
        model_url = config.get("modelUrl")
        if model_id and model_url:
            await self.model_manager.load_model(camera_id, model_id, model_url)
    except Exception as e:
        logger.error(f"Error starting subscription {subscription_id}: {e}")
        # Print the traceback before re-raising; it used to sit after the
        # ``raise`` and could never run.
        traceback.print_exc()
        raise
async def _handle_request_state(self, data: Dict[str, Any]) -> None:
    """
    Handle requestState by sending exactly one stateReport.

    This used to await ``_send_heartbeat()``, which loops for as long as
    the connection is up; the message loop therefore never got control
    back after the first requestState. The report is now built and sent
    once, with the same fields the periodic heartbeat sends.

    Args:
        data: Decoded requestState message (its content is not used).
    """
    try:
        metrics = get_system_metrics()

        camera_connections = []
        with self.stream_manager.streams_lock:
            for camera_id, stream_info in self.stream_manager.streams.items():
                connection_info = {
                    "subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
                    "modelId": stream_info.get("modelId", 0),
                    "modelName": stream_info.get("modelName", "Unknown Model"),
                    "online": self.stream_manager.is_stream_active(camera_id)
                }
                # Crop coordinates are included only when configured
                for key in ("cropX1", "cropY1", "cropX2", "cropY2"):
                    if key in stream_info:
                        connection_info[key] = stream_info[key]
                camera_connections.append(connection_info)

        state_data = {
            "type": "stateReport",
            "cpuUsage": metrics.get("cpu_percent", 0),
            "memoryUsage": metrics.get("memory_percent", 0),
            "gpuUsage": metrics.get("gpu_percent", 0),
            "gpuMemoryUsage": metrics.get("gpu_memory_percent", 0),
            "cameraConnections": camera_connections
        }
        ws_rxtx_logger.info(f"TX -> {json.dumps(state_data, separators=(',', ':'))}")
        await self.websocket.send_json(state_data)
    except (WebSocketDisconnect, ConnectionClosedError):
        # Let the message loop see the disconnect and stop
        raise
    except Exception as e:
        # Same containment the heartbeat had: log, do not kill the message loop
        logger.error(f"Error sending state report: {e}")
async def _handle_set_session(self, data: Dict[str, Any]) -> None:
    """
    Handle setSessionId: bind a session id to a display and acknowledge.

    Nothing happens unless the payload has both ``displayIdentifier`` and
    ``sessionId``.

    Args:
        data: Decoded message with a ``payload`` object and an optional
            ``requestId``.
    """
    payload = data.get("payload", {})
    display_id = payload.get("displayIdentifier")
    session_id = payload.get("sessionId")
    if not (display_id and session_id):
        return

    self.session_ids[display_id] = session_id

    # Propagate the session to every camera whose subscription belongs to
    # this display ("<display>;<camera>")
    display_prefix = display_id + ";"
    with self.stream_manager.streams_lock:
        for camera_id, stream in self.stream_manager.streams.items():
            if stream["subscriptionIdentifier"].startswith(display_prefix):
                self.session_cache.update_session_id(camera_id, session_id)

    # Acknowledge; a fresh UUID is used when the request carried no id
    ack = {
        "type": "ack",
        "requestId": data.get("requestId", str(uuid.uuid4())),
        "code": "200",
        "data": {
            "message": f"Session ID set for display {display_id}",
            "sessionId": session_id
        }
    }
    ws_rxtx_logger.info(f"TX -> {json.dumps(ack, separators=(',', ':'))}")
    await self.websocket.send_json(ack)
    logger.info(f"Set session {session_id} for display {display_id}")
async def _handle_patch_session(self, data: Dict[str, Any]) -> None:
    """
    Handle patchSession: log the patch and acknowledge it.

    The patch data is only logged and echoed back; it is not stored.
    Messages without a ``sessionId`` are ignored.

    Args:
        data: Decoded message with a ``payload`` object and an optional
            ``requestId``.
    """
    payload = data.get("payload", {})
    session_id = payload.get("sessionId")
    patch_data = payload.get("data", {})
    if not session_id:
        return

    logger.info(f"Received patch for session {session_id}: {patch_data}")

    # Acknowledge; a fresh UUID is used when the request carried no id
    ack = {
        "type": "ack",
        "requestId": data.get("requestId", str(uuid.uuid4())),
        "code": "200",
        "data": {
            "message": f"Session {session_id} patched successfully",
            "sessionId": session_id,
            "patchData": patch_data
        }
    }
    ws_rxtx_logger.info(f"TX -> {json.dumps(ack, separators=(',', ':'))}")
    await self.websocket.send_json(ack)
async def _handle_set_progression_stage(self, data: Dict[str, Any]) -> None:
    """
    Handle setProgressionStage for every camera of one display.

    Updates each camera's session pipeline state according to the stage,
    then records the stage for the display, or clears the record when the
    stage is null/empty.

    Args:
        data: Decoded message whose ``payload`` carries
            ``displayIdentifier`` and ``progressionStage``.
    """
    payload = data.get("payload", {})
    display_id = payload.get("displayIdentifier")
    progression_stage = payload.get("progressionStage")
    logger.info(f"🏁 PROGRESSION STAGE RECEIVED: displayId={display_id}, stage={progression_stage}")

    if not display_id:
        logger.warning("Missing displayIdentifier in setProgressionStage")
        return

    # Find all cameras for this display. The heartbeat treats
    # "subscriptionIdentifier" as optional in stream info, so read it with
    # the same fallback instead of risking a KeyError that would end the
    # message loop.
    display_prefix = display_id + ";"
    affected_cameras = []
    with self.stream_manager.streams_lock:
        for camera_id, stream in self.stream_manager.streams.items():
            if stream.get("subscriptionIdentifier", camera_id).startswith(display_prefix):
                affected_cameras.append(camera_id)
    logger.debug(f"🎯 Found {len(affected_cameras)} cameras for display {display_id}: {affected_cameras}")

    # Update progression stage for each camera
    for camera_id in affected_cameras:
        pipeline_state = self.session_cache.get_or_init_session_pipeline_state(camera_id)
        current_mode = pipeline_state.get("mode", "validation_detecting")

        if progression_stage == "car_fueling":
            # Stop YOLO inference during fueling, but only in lightweight mode
            if current_mode == "lightweight":
                pipeline_state["yolo_inference_enabled"] = False
                pipeline_state["progression_stage"] = "car_fueling"
                logger.info(f"⏸️ Camera {camera_id}: YOLO inference DISABLED for car_fueling stage")
            else:
                logger.debug(f"📊 Camera {camera_id}: car_fueling received but not in lightweight mode (mode: {current_mode})")
        elif progression_stage == "car_waitpayment":
            # Resume YOLO inference for absence counter
            pipeline_state["yolo_inference_enabled"] = True
            pipeline_state["progression_stage"] = "car_waitpayment"
            logger.info(f"▶️ Camera {camera_id}: YOLO inference RE-ENABLED for car_waitpayment stage")
        elif progression_stage == "welcome":
            # Ignore welcome messages during car_waitpayment
            current_progression = pipeline_state.get("progression_stage")
            if current_progression == "car_waitpayment":
                logger.info(f"🚫 Camera {camera_id}: IGNORING welcome stage (currently in car_waitpayment)")
            else:
                pipeline_state["progression_stage"] = "welcome"
                logger.info(f"🎉 Camera {camera_id}: Progression stage set to welcome")
        elif progression_stage in ["car_wait_staff"]:
            pipeline_state["progression_stage"] = progression_stage
            logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}")

    # Store the stage for this display, or clear it when the stage is null.
    # The clearing branch used to be guarded by ``is not None`` and so could
    # never run for a null stage, leaving a stale entry behind.
    if progression_stage:
        self.progression_stages[display_id] = progression_stage
    else:
        self.progression_stages.pop(display_id, None)
async def _handle_patch_session_result(self, data: Dict[str, Any]) -> None:
    """
    Handle patchSessionResult by logging the reported outcome.

    Args:
        data: Decoded message whose ``payload`` may carry ``sessionId``,
            ``success`` (defaults to False) and ``message`` (defaults to "").
    """
    payload = data.get("payload", {})
    session_id = payload.get("sessionId")
    message = payload.get("message", "")

    if not payload.get("success", False):
        logger.warning(f"Patch session {session_id} failed: {message}")
        return
    logger.info(f"Patch session {session_id} successful: {message}")
async def _send_detection_result(
    self,
    camera_id: str,
    stream_info: Dict[str, Any],
    detection_result: DetectionResult
) -> None:
    """
    Send an imageDetection message for one detection result.

    Args:
        camera_id: Camera identifier (used in log output only)
        stream_info: Stream configuration; must contain
            ``subscriptionIdentifier``, ``modelId`` and ``modelName``
        detection_result: Result whose ``to_dict()`` becomes the payload

    Raises:
        RuntimeError: When sending fails for a reason other than the
            socket having been closed.
    """
    subscription_id = stream_info["subscriptionIdentifier"]
    # The display id is the part before the first ";" (or the whole id)
    display_id = subscription_id.partition(";")[0]
    session_id = self.session_ids.get(display_id)

    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    detection_data = {
        "type": "imageDetection",
        "subscriptionIdentifier": subscription_id,
        "timestamp": timestamp,
        "sessionId": session_id,  # Required by protocol
        "data": {
            "detection": detection_result.to_dict(),
            "modelId": stream_info["modelId"],
            "modelName": stream_info["modelName"]
        }
    }

    try:
        ws_rxtx_logger.info(f"TX -> {json.dumps(detection_data, separators=(',', ':'))}")
        await self.websocket.send_json(detection_data)
    except RuntimeError as e:
        # A closed socket is expected during shutdown; anything else is not
        if "websocket.close" not in str(e):
            raise
        logger.warning(f"WebSocket closed - cannot send detection for camera {camera_id}")
async def _send_disconnection_notification(
    self,
    camera_id: str,
    stream_info: Dict[str, Any]
) -> None:
    """
    Tell the backend that a camera disconnected.

    Clears the camera's cached session, sends an imageDetection message
    whose ``detection`` is null, and marks the disconnection as notified.

    Args:
        camera_id: Camera identifier
        stream_info: Stream configuration; must contain
            ``subscriptionIdentifier``, ``modelId`` and ``modelName``

    Raises:
        RuntimeError: When sending fails for a reason other than the
            socket having been closed.
    """
    logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null")

    # Drop cached data for this camera
    self.session_cache.clear_session(camera_id)

    # Null detection payload
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    detection_data = {
        "type": "imageDetection",
        "subscriptionIdentifier": stream_info["subscriptionIdentifier"],
        "timestamp": timestamp,
        "data": {
            "detection": None,
            "modelId": stream_info["modelId"],
            "modelName": stream_info["modelName"]
        }
    }

    try:
        ws_rxtx_logger.info(f"TX -> {json.dumps(detection_data, separators=(',', ':'))}")
        await self.websocket.send_json(detection_data)
    except RuntimeError as e:
        # A closed socket is tolerated; any other RuntimeError propagates
        if "websocket.close" not in str(e):
            raise
        logger.warning(f"WebSocket closed - cannot send disconnection signal for camera {camera_id}")

    self.camera_monitor.mark_disconnection_notified(camera_id)
    logger.info(f"📡 SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session")
def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any:
"""Apply crop to frame if crop coordinates are specified."""
crop_coords = [
stream_info.get("cropX1"),
stream_info.get("cropY1"),
stream_info.get("cropX2"),
stream_info.get("cropY2")
]
if all(coord is not None for coord in crop_coords):
x1, y1, x2, y2 = crop_coords
return frame[y1:y2, x1:x2]
return frame
# Convenience function for backward compatibility
async def handle_websocket_connection(
    websocket: WebSocket,
    stream_manager: StreamManager,
    model_manager: ModelManager,
    pipeline_executor: PipelineExecutor,
    session_cache: SessionCacheManager,
    redis_client: Optional[RedisClientManager] = None
) -> None:
    """
    Serve one WebSocket connection with a fresh WebSocketHandler.

    Builds a handler from the given services and delegates the whole
    connection to its ``handle_connection`` method.
    """
    handler = WebSocketHandler(
        stream_manager=stream_manager,
        model_manager=model_manager,
        pipeline_executor=pipeline_executor,
        session_cache=session_cache,
        redis_client=redis_client,
    )
    await handler.handle_connection(websocket)