refactor: done phase 1
This commit is contained in:
parent
f7c464be21
commit
cbbed3d933
13 changed files with 1084 additions and 891 deletions
326
core/communication/websocket.py
Normal file
326
core/communication/websocket.py
Normal file
|
@ -0,0 +1,326 @@
|
|||
"""
|
||||
WebSocket message handling and protocol implementation.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
from .messages import (
|
||||
parse_incoming_message, serialize_outgoing_message,
|
||||
MessageTypes, create_state_report
|
||||
)
|
||||
from .models import (
|
||||
SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
|
||||
RequestStateMessage, PatchSessionResultMessage
|
||||
)
|
||||
from .state import worker_state, SystemMetrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
HEARTBEAT_INTERVAL = 2.0 # seconds
|
||||
WORKER_TIMEOUT_MS = 10000
|
||||
|
||||
|
||||
class WebSocketHandler:
|
||||
"""
|
||||
Handles WebSocket connection lifecycle and message processing.
|
||||
"""
|
||||
|
||||
def __init__(self, websocket: WebSocket):
|
||||
self.websocket = websocket
|
||||
self.connected = False
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._message_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def handle_connection(self) -> None:
|
||||
"""
|
||||
Main connection handler that manages the WebSocket lifecycle.
|
||||
Based on the original architecture from archive/app.py
|
||||
"""
|
||||
client_info = f"{self.websocket.client.host}:{self.websocket.client.port}" if self.websocket.client else "unknown"
|
||||
logger.info(f"Starting WebSocket handler for {client_info}")
|
||||
|
||||
stream_task = None
|
||||
try:
|
||||
logger.info(f"Accepting WebSocket connection from {client_info}")
|
||||
await self.websocket.accept()
|
||||
self.connected = True
|
||||
logger.info(f"WebSocket connection accepted and established for {client_info}")
|
||||
|
||||
# Send immediate heartbeat to show connection is alive
|
||||
await self._send_immediate_heartbeat()
|
||||
|
||||
# Start background tasks (matching original architecture)
|
||||
stream_task = asyncio.create_task(self._process_streams())
|
||||
heartbeat_task = asyncio.create_task(self._send_heartbeat())
|
||||
message_task = asyncio.create_task(self._handle_messages())
|
||||
|
||||
logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)")
|
||||
|
||||
# Wait for heartbeat and message tasks (stream runs independently)
|
||||
await asyncio.gather(heartbeat_task, message_task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True)
|
||||
finally:
|
||||
logger.info(f"Cleaning up connection for {client_info}")
|
||||
# Cancel stream task
|
||||
if stream_task and not stream_task.done():
|
||||
stream_task.cancel()
|
||||
try:
|
||||
await stream_task
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"Stream task cancelled for {client_info}")
|
||||
await self._cleanup()
|
||||
|
||||
async def _send_immediate_heartbeat(self) -> None:
|
||||
"""Send immediate heartbeat on connection to show we're alive."""
|
||||
try:
|
||||
cpu_usage = SystemMetrics.get_cpu_usage()
|
||||
memory_usage = SystemMetrics.get_memory_usage()
|
||||
gpu_usage = SystemMetrics.get_gpu_usage()
|
||||
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
|
||||
camera_connections = worker_state.get_camera_connections()
|
||||
|
||||
state_report = create_state_report(
|
||||
cpu_usage=cpu_usage,
|
||||
memory_usage=memory_usage,
|
||||
gpu_usage=gpu_usage,
|
||||
gpu_memory_usage=gpu_memory_usage,
|
||||
camera_connections=camera_connections
|
||||
)
|
||||
|
||||
await self._send_message(state_report)
|
||||
logger.info(f"Sent immediate stateReport: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
|
||||
f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending immediate heartbeat: {e}")
|
||||
|
||||
async def _send_heartbeat(self) -> None:
|
||||
"""Send periodic state reports as heartbeat."""
|
||||
while self.connected:
|
||||
try:
|
||||
# Collect system metrics
|
||||
cpu_usage = SystemMetrics.get_cpu_usage()
|
||||
memory_usage = SystemMetrics.get_memory_usage()
|
||||
gpu_usage = SystemMetrics.get_gpu_usage()
|
||||
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
|
||||
camera_connections = worker_state.get_camera_connections()
|
||||
|
||||
# Create and send state report
|
||||
state_report = create_state_report(
|
||||
cpu_usage=cpu_usage,
|
||||
memory_usage=memory_usage,
|
||||
gpu_usage=gpu_usage,
|
||||
gpu_memory_usage=gpu_memory_usage,
|
||||
camera_connections=camera_connections
|
||||
)
|
||||
|
||||
await self._send_message(state_report)
|
||||
logger.debug(f"Sent heartbeat: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
|
||||
f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
|
||||
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending heartbeat: {e}")
|
||||
break
|
||||
|
||||
async def _handle_messages(self) -> None:
|
||||
"""Handle incoming WebSocket messages."""
|
||||
while self.connected:
|
||||
try:
|
||||
raw_message = await self.websocket.receive_text()
|
||||
logger.info(f"Received message: {raw_message}")
|
||||
|
||||
# Parse incoming message
|
||||
message = parse_incoming_message(raw_message)
|
||||
if not message:
|
||||
logger.warning("Failed to parse incoming message")
|
||||
continue
|
||||
|
||||
# Route message to appropriate handler
|
||||
await self._route_message(message)
|
||||
|
||||
except (WebSocketDisconnect, ConnectionClosedError) as e:
|
||||
logger.warning(f"WebSocket disconnected: {e}")
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Received invalid JSON message")
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
break
|
||||
|
||||
async def _route_message(self, message) -> None:
|
||||
"""Route parsed message to appropriate handler."""
|
||||
message_type = message.type
|
||||
|
||||
try:
|
||||
if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
|
||||
await self._handle_set_subscription_list(message)
|
||||
elif message_type == MessageTypes.SET_SESSION_ID:
|
||||
await self._handle_set_session_id(message)
|
||||
elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
|
||||
await self._handle_set_progression_stage(message)
|
||||
elif message_type == MessageTypes.REQUEST_STATE:
|
||||
await self._handle_request_state(message)
|
||||
elif message_type == MessageTypes.PATCH_SESSION_RESULT:
|
||||
await self._handle_patch_session_result(message)
|
||||
else:
|
||||
logger.warning(f"Unknown message type: {message_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling {message_type} message: {e}")
|
||||
|
||||
async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None:
|
||||
"""Handle setSubscriptionList message for declarative subscription management."""
|
||||
logger.info(f"Processing setSubscriptionList with {len(message.subscriptions)} subscriptions")
|
||||
|
||||
# Update worker state with new subscriptions
|
||||
worker_state.set_subscriptions(message.subscriptions)
|
||||
|
||||
# TODO: Phase 2 - Integrate with model management and streaming
|
||||
# For now, just log the subscription changes
|
||||
for subscription in message.subscriptions:
|
||||
logger.info(f" Subscription: {subscription.subscriptionIdentifier} -> "
|
||||
f"Model {subscription.modelId} ({subscription.modelName})")
|
||||
if subscription.rtspUrl:
|
||||
logger.debug(f" RTSP: {subscription.rtspUrl}")
|
||||
if subscription.snapshotUrl:
|
||||
logger.debug(f" Snapshot: {subscription.snapshotUrl} ({subscription.snapshotInterval}ms)")
|
||||
if subscription.modelUrl:
|
||||
logger.debug(f" Model: {subscription.modelUrl}")
|
||||
|
||||
logger.info("Subscription list updated successfully")
|
||||
|
||||
async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
|
||||
"""Handle setSessionId message."""
|
||||
display_identifier = message.payload.displayIdentifier
|
||||
session_id = message.payload.sessionId
|
||||
|
||||
logger.info(f"Setting session ID for display {display_identifier}: {session_id}")
|
||||
|
||||
# Update worker state
|
||||
worker_state.set_session_id(display_identifier, session_id)
|
||||
|
||||
async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None:
|
||||
"""Handle setProgressionStage message."""
|
||||
display_identifier = message.payload.displayIdentifier
|
||||
stage = message.payload.progressionStage
|
||||
|
||||
logger.info(f"Setting progression stage for display {display_identifier}: {stage}")
|
||||
|
||||
# Update worker state
|
||||
worker_state.set_progression_stage(display_identifier, stage)
|
||||
|
||||
async def _handle_request_state(self, message: RequestStateMessage) -> None:
|
||||
"""Handle requestState message by sending immediate state report."""
|
||||
logger.debug("Received requestState, sending immediate state report")
|
||||
|
||||
# Collect metrics and send state report
|
||||
cpu_usage = SystemMetrics.get_cpu_usage()
|
||||
memory_usage = SystemMetrics.get_memory_usage()
|
||||
gpu_usage = SystemMetrics.get_gpu_usage()
|
||||
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
|
||||
camera_connections = worker_state.get_camera_connections()
|
||||
|
||||
state_report = create_state_report(
|
||||
cpu_usage=cpu_usage,
|
||||
memory_usage=memory_usage,
|
||||
gpu_usage=gpu_usage,
|
||||
gpu_memory_usage=gpu_memory_usage,
|
||||
camera_connections=camera_connections
|
||||
)
|
||||
|
||||
await self._send_message(state_report)
|
||||
|
||||
async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None:
|
||||
"""Handle patchSessionResult message."""
|
||||
payload = message.payload
|
||||
logger.info(f"Received patch session result for session {payload.sessionId}: "
|
||||
f"success={payload.success}, message='{payload.message}'")
|
||||
|
||||
# TODO: Handle patch session result if needed
|
||||
# For now, just log the response
|
||||
|
||||
async def _send_message(self, message) -> None:
|
||||
"""Send message to backend via WebSocket."""
|
||||
if not self.connected:
|
||||
logger.warning("Cannot send message: WebSocket not connected")
|
||||
return
|
||||
|
||||
try:
|
||||
json_message = serialize_outgoing_message(message)
|
||||
await self.websocket.send_text(json_message)
|
||||
# Don't log full message for heartbeats to avoid spam, just type
|
||||
if hasattr(message, 'type') and message.type == 'stateReport':
|
||||
logger.debug(f"Sent message: {message.type}")
|
||||
else:
|
||||
logger.debug(f"Sent message: {json_message}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send WebSocket message: {e}")
|
||||
raise
|
||||
|
||||
async def _process_streams(self) -> None:
|
||||
"""
|
||||
Stream processing task that handles frame processing and detection.
|
||||
This is a placeholder for Phase 2 - currently just logs that it's running.
|
||||
"""
|
||||
logger.info("Stream processing task started")
|
||||
try:
|
||||
while self.connected:
|
||||
# Get current subscriptions
|
||||
subscriptions = worker_state.get_all_subscriptions()
|
||||
|
||||
if subscriptions:
|
||||
logger.debug(f"Stream processor running with {len(subscriptions)} active subscriptions")
|
||||
# TODO: Phase 2 - Add actual frame processing logic here
|
||||
# This will include:
|
||||
# - Frame reading from RTSP/HTTP streams
|
||||
# - Model inference using loaded pipelines
|
||||
# - Detection result sending via WebSocket
|
||||
else:
|
||||
logger.debug("Stream processor running with no active subscriptions")
|
||||
|
||||
# Sleep to prevent excessive CPU usage (similar to old poll_interval)
|
||||
await asyncio.sleep(0.1) # 100ms polling interval
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Stream processing task cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream processing: {e}", exc_info=True)
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""Clean up resources when connection closes."""
|
||||
logger.info("Cleaning up WebSocket connection")
|
||||
self.connected = False
|
||||
|
||||
# Cancel background tasks
|
||||
if self._heartbeat_task and not self._heartbeat_task.done():
|
||||
self._heartbeat_task.cancel()
|
||||
if self._message_task and not self._message_task.done():
|
||||
self._message_task.cancel()
|
||||
|
||||
# Clear worker state
|
||||
worker_state.set_subscriptions([])
|
||||
worker_state.session_ids.clear()
|
||||
worker_state.progression_stages.clear()
|
||||
|
||||
logger.info("WebSocket connection cleanup completed")
|
||||
|
||||
|
||||
# Factory function for FastAPI integration
|
||||
async def websocket_endpoint(websocket: WebSocket) -> None:
|
||||
"""
|
||||
FastAPI WebSocket endpoint handler.
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket connection
|
||||
"""
|
||||
handler = WebSocketHandler(websocket)
|
||||
await handler.handle_connection()
|
Loading…
Add table
Add a link
Reference in a new issue