python-detector-worker/core/communication/websocket.py
2025-09-23 15:44:09 +07:00

322 lines
No EOL
13 KiB
Python

"""
WebSocket message handling and protocol implementation.
"""
import asyncio
import json
import logging
from typing import Optional
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from .messages import (
parse_incoming_message, serialize_outgoing_message,
MessageTypes, create_state_report
)
from .models import (
SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
RequestStateMessage, PatchSessionResultMessage
)
from .state import worker_state, SystemMetrics
logger = logging.getLogger(__name__)

# --- Protocol timing constants ---
# Pause between periodic stateReport heartbeats sent to the backend, in seconds.
HEARTBEAT_INTERVAL: float = 2.0
# Timeout expressed in milliseconds; not referenced in this module
# (presumably consumed elsewhere — verify before changing).
WORKER_TIMEOUT_MS: int = 10_000
class WebSocketHandler:
    """
    Handles WebSocket connection lifecycle and message processing.

    One instance serves one connection. It accepts the socket and then runs
    three background tasks: a stream-processing loop, a heartbeat loop and an
    inbound-message loop. When either the heartbeat or the message loop
    finishes, the connection is considered over. All tasks are then
    cancelled and awaited, and the shared worker state is cleared.
    """

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket
        # True between a successful accept() and _cleanup(); every background
        # loop polls this flag to decide whether to keep running.
        self.connected = False
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._message_task: Optional[asyncio.Task] = None
        self._stream_task: Optional[asyncio.Task] = None

    async def handle_connection(self) -> None:
        """
        Main connection handler that manages the WebSocket lifecycle.

        Based on the original architecture from archive/app.py. Returns once the
        connection has ended and every background task has been cancelled and
        awaited. Errors are logged rather than raised.
        """
        client = self.websocket.client
        client_info = f"{client.host}:{client.port}" if client else "unknown"
        logger.info(f"Starting WebSocket handler for {client_info}")
        try:
            logger.info(f"Accepting WebSocket connection from {client_info}")
            await self.websocket.accept()
            self.connected = True
            logger.info(f"WebSocket connection accepted and established for {client_info}")

            # Send immediate heartbeat to show connection is alive
            await self._send_immediate_heartbeat()

            # Start background tasks (matching original architecture). They are kept
            # on the instance so that _cleanup() can cancel and await them.
            self._stream_task = asyncio.create_task(self._process_streams())
            self._heartbeat_task = asyncio.create_task(self._send_heartbeat())
            self._message_task = asyncio.create_task(self._handle_messages())
            logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)")

            # The connection is finished as soon as EITHER loop exits (disconnect,
            # send failure, metrics error, ...). Waiting for both would leave a
            # half-dead connection, e.g. heartbeats stopped but the socket still open.
            await asyncio.wait(
                {self._heartbeat_task, self._message_task},
                return_when=asyncio.FIRST_COMPLETED,
            )
        except Exception as e:
            logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True)
        finally:
            logger.info(f"Cleaning up connection for {client_info}")
            await self._cleanup()

    async def _send_state_report(self, log_label: Optional[str] = None) -> None:
        """
        Collect system/camera metrics and send them to the backend as a stateReport.

        Args:
            log_label: When given, an INFO summary line prefixed with this label is
                logged after a successful send. When None, no summary is logged.

        Raises:
            Anything raised by the metric collectors or by _send_message; each
            caller decides whether that is fatal.
        """
        cpu_usage = SystemMetrics.get_cpu_usage()
        memory_usage = SystemMetrics.get_memory_usage()
        gpu_usage = SystemMetrics.get_gpu_usage()
        gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
        camera_connections = worker_state.get_camera_connections()
        state_report = create_state_report(
            cpu_usage=cpu_usage,
            memory_usage=memory_usage,
            gpu_usage=gpu_usage,
            gpu_memory_usage=gpu_memory_usage,
            camera_connections=camera_connections,
        )
        await self._send_message(state_report)
        if log_label is not None:
            # Explicit None check: a 0.0% GPU reading is a real value and must not
            # be shown as "N/A", which a truthiness test would do.
            gpu_display = 'N/A' if gpu_usage is None else gpu_usage
            logger.info(f"[TX → Backend] {log_label}: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
                        f"GPU {gpu_display}, {len(camera_connections)} cameras")

    async def _send_immediate_heartbeat(self) -> None:
        """Send one state report right after accept; failures are logged, not raised."""
        try:
            await self._send_state_report("stateReport")
        except Exception as e:
            logger.error(f"Error sending immediate heartbeat: {e}")

    async def _send_heartbeat(self) -> None:
        """Send a state report every HEARTBEAT_INTERVAL seconds until disconnected.

        The loop stops on the first error. handle_connection then tears down the
        whole connection.
        """
        while self.connected:
            try:
                await self._send_state_report("Heartbeat")
                await asyncio.sleep(HEARTBEAT_INTERVAL)
            except Exception as e:
                logger.error(f"Error sending heartbeat: {e}")
                break

    async def _handle_messages(self) -> None:
        """Receive, parse and route incoming text messages until disconnect or error."""
        while self.connected:
            try:
                raw_message = await self.websocket.receive_text()
                logger.info(f"[RX ← Backend] {raw_message}")

                # Parse incoming message
                message = parse_incoming_message(raw_message)
                if not message:
                    logger.warning("Failed to parse incoming message")
                    continue

                # Route message to appropriate handler
                await self._route_message(message)
            except (WebSocketDisconnect, ConnectionClosedError) as e:
                logger.warning(f"WebSocket disconnected: {e}")
                break
            except json.JSONDecodeError:
                # A malformed frame is skipped; the connection stays open.
                logger.error("Received invalid JSON message")
            except Exception as e:
                logger.error(f"Error handling message: {e}", exc_info=True)
                break

    async def _route_message(self, message) -> None:
        """Dispatch a parsed message to its handler; handler errors are logged, not raised."""
        # getattr: a message object without `.type` must not kill the receive loop.
        message_type = getattr(message, "type", None)
        try:
            if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
                await self._handle_set_subscription_list(message)
            elif message_type == MessageTypes.SET_SESSION_ID:
                await self._handle_set_session_id(message)
            elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
                await self._handle_set_progression_stage(message)
            elif message_type == MessageTypes.REQUEST_STATE:
                await self._handle_request_state(message)
            elif message_type == MessageTypes.PATCH_SESSION_RESULT:
                await self._handle_patch_session_result(message)
            else:
                logger.warning(f"Unknown message type: {message_type}")
        except Exception as e:
            logger.error(f"Error handling {message_type} message: {e}", exc_info=True)

    async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None:
        """Handle setSubscriptionList message for declarative subscription management."""
        logger.info(f"[RX Processing] setSubscriptionList with {len(message.subscriptions)} subscriptions")

        # Replace the worker's subscriptions with the new list
        worker_state.set_subscriptions(message.subscriptions)

        # TODO: Phase 2 - Integrate with model management and streaming
        # For now, just log the subscription changes
        for subscription in message.subscriptions:
            logger.info(f"  Subscription: {subscription.subscriptionIdentifier} -> "
                        f"Model {subscription.modelId} ({subscription.modelName})")
            if subscription.rtspUrl:
                logger.debug(f"    RTSP: {subscription.rtspUrl}")
            if subscription.snapshotUrl:
                logger.debug(f"    Snapshot: {subscription.snapshotUrl} ({subscription.snapshotInterval}ms)")
            if subscription.modelUrl:
                logger.debug(f"    Model: {subscription.modelUrl}")

        logger.info("Subscription list updated successfully")

    async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
        """Handle setSessionId message by recording the session for the display."""
        display_identifier = message.payload.displayIdentifier
        session_id = message.payload.sessionId
        logger.info(f"[RX Processing] setSessionId for display {display_identifier}: {session_id}")
        worker_state.set_session_id(display_identifier, session_id)

    async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None:
        """Handle setProgressionStage message by recording the stage for the display."""
        display_identifier = message.payload.displayIdentifier
        stage = message.payload.progressionStage
        logger.info(f"[RX Processing] setProgressionStage for display {display_identifier}: {stage}")
        worker_state.set_progression_stage(display_identifier, stage)

    async def _handle_request_state(self, message: RequestStateMessage) -> None:
        """Handle requestState message by sending an immediate state report."""
        logger.debug("[RX Processing] requestState - sending immediate state report")
        await self._send_state_report()

    async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None:
        """Handle patchSessionResult message (currently log-only)."""
        payload = message.payload
        logger.info(f"[RX Processing] patchSessionResult for session {payload.sessionId}: "
                    f"success={payload.success}, message='{payload.message}'")
        # TODO: Handle patch session result if needed

    async def _send_message(self, message) -> None:
        """Serialize and send a message to the backend.

        Does nothing (with a warning) when not connected. Re-raises send
        errors, so callers such as the heartbeat loop can stop.
        """
        if not self.connected:
            logger.warning("Cannot send message: WebSocket not connected")
            return

        try:
            json_message = serialize_outgoing_message(message)
            await self.websocket.send_text(json_message)
            # Heartbeats are logged in short form to keep the INFO log readable
            if hasattr(message, 'type') and message.type == 'stateReport':
                logger.info(f"[TX → Backend] {message.type}")
            else:
                logger.info(f"[TX → Backend] {json_message}")
        except Exception as e:
            logger.error(f"Failed to send WebSocket message: {e}")
            raise

    async def _process_streams(self) -> None:
        """
        Stream processing task that handles frame processing and detection.

        This is a placeholder for Phase 2 - currently it only polls the
        subscription list every 100 ms while connected.
        """
        logger.info("Stream processing task started")
        try:
            while self.connected:
                # Get current subscriptions (unused until Phase 2 processing lands)
                subscriptions = worker_state.get_all_subscriptions()

                # TODO: Phase 2 - Add actual frame processing logic here
                # This will include:
                # - Frame reading from RTSP/HTTP streams
                # - Model inference using loaded pipelines
                # - Detection result sending via WebSocket

                # Sleep to prevent excessive CPU usage (similar to old poll_interval)
                await asyncio.sleep(0.1)  # 100ms polling interval
        except asyncio.CancelledError:
            logger.info("Stream processing task cancelled")
        except Exception as e:
            logger.error(f"Error in stream processing: {e}", exc_info=True)

    async def _cleanup(self) -> None:
        """Stop all background tasks and reset the shared worker state."""
        logger.info("Cleaning up WebSocket connection")
        # Flip the flag first so loops exit even between cancellation points.
        self.connected = False

        # Never cancel/await the task we are running in (awaiting itself would fail).
        current = asyncio.current_task()
        tasks = [
            task
            for task in (self._stream_task, self._heartbeat_task, self._message_task)
            if task is not None and task is not current
        ]
        for task in tasks:
            if not task.done():
                task.cancel()
        # Await the cancelled tasks so none outlives the connection and any late
        # exception is retrieved instead of surfacing as "never retrieved".
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
        self._stream_task = None
        self._heartbeat_task = None
        self._message_task = None

        # Clear worker state
        worker_state.set_subscriptions([])
        worker_state.session_ids.clear()
        worker_state.progression_stages.clear()

        logger.info("WebSocket connection cleanup completed")
# FastAPI integration: one WebSocketHandler per incoming WebSocket connection.
async def websocket_endpoint(websocket: WebSocket) -> None:
    """
    Coroutine to register as the FastAPI WebSocket route.

    Builds a fresh WebSocketHandler for the connection and hands the whole
    lifecycle to it. Returns when the handler has finished with the connection.

    Args:
        websocket: FastAPI WebSocket connection
    """
    await WebSocketHandler(websocket).handle_connection()