# python-detector-worker/core/communication/websocket.py
"""
WebSocket message handling and protocol implementation.
"""
import asyncio
import json
import logging
import cv2
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Optional
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from .messages import (
parse_incoming_message, serialize_outgoing_message,
MessageTypes, create_state_report
)
from .models import (
SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
RequestStateMessage, PatchSessionResultMessage
)
from .state import worker_state, SystemMetrics
from ..models import ModelManager
from ..streaming.manager import shared_stream_manager
from ..tracking.integration import TrackingPipelineIntegration
logger = logging.getLogger(__name__)
# Constants
HEARTBEAT_INTERVAL = 2.0 # seconds
WORKER_TIMEOUT_MS = 10000
# Global model manager instance
model_manager = ModelManager()
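# Illustrative stateReport heartbeat (a sketch only; field names here are
# assumptions in the protocol's camelCase style, and the authoritative schema
# is whatever create_state_report() in .messages produces):
#
#     {
#         "type": "stateReport",
#         "cpuUsage": 12.5,
#         "memoryUsage": 48.2,
#         "gpuUsage": 31.0,
#         "gpuMemoryUsage": 22.4,
#         "cameraConnections": [...]
#     }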
class WebSocketHandler:
"""
Handles WebSocket connection lifecycle and message processing.
"""
def __init__(self, websocket: WebSocket):
self.websocket = websocket
self.connected = False
self._heartbeat_task: Optional[asyncio.Task] = None
self._message_task: Optional[asyncio.Task] = None
self._heartbeat_count = 0
self._last_processed_models: set = set() # Cache of last processed model IDs
async def handle_connection(self) -> None:
"""
Main connection handler that manages the WebSocket lifecycle.
Based on the original architecture from archive/app.py
"""
client_info = f"{self.websocket.client.host}:{self.websocket.client.port}" if self.websocket.client else "unknown"
logger.info(f"Starting WebSocket handler for {client_info}")
stream_task = None
try:
logger.info(f"Accepting WebSocket connection from {client_info}")
await self.websocket.accept()
self.connected = True
logger.info(f"WebSocket connection accepted and established for {client_info}")
# Send immediate heartbeat to show connection is alive
await self._send_immediate_heartbeat()
# Start background tasks (matching original architecture)
stream_task = asyncio.create_task(self._process_streams())
            self._heartbeat_task = asyncio.create_task(self._send_heartbeat())
            self._message_task = asyncio.create_task(self._handle_messages())
            logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)")
            # Wait for heartbeat and message tasks (stream runs independently);
            # storing them on self lets _cleanup() actually cancel them
            await asyncio.gather(self._heartbeat_task, self._message_task)
except Exception as e:
logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True)
finally:
logger.info(f"Cleaning up connection for {client_info}")
# Cancel stream task
if stream_task and not stream_task.done():
stream_task.cancel()
try:
await stream_task
except asyncio.CancelledError:
logger.debug(f"Stream task cancelled for {client_info}")
await self._cleanup()
async def _send_immediate_heartbeat(self) -> None:
"""Send immediate heartbeat on connection to show we're alive."""
try:
cpu_usage = SystemMetrics.get_cpu_usage()
memory_usage = SystemMetrics.get_memory_usage()
gpu_usage = SystemMetrics.get_gpu_usage()
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
camera_connections = worker_state.get_camera_connections()
state_report = create_state_report(
cpu_usage=cpu_usage,
memory_usage=memory_usage,
gpu_usage=gpu_usage,
gpu_memory_usage=gpu_memory_usage,
camera_connections=camera_connections
)
await self._send_message(state_report)
logger.info(f"[TX → Backend] Initial stateReport: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
except Exception as e:
logger.error(f"Error sending immediate heartbeat: {e}")
async def _send_heartbeat(self) -> None:
"""Send periodic state reports as heartbeat."""
while self.connected:
try:
# Collect system metrics
cpu_usage = SystemMetrics.get_cpu_usage()
memory_usage = SystemMetrics.get_memory_usage()
gpu_usage = SystemMetrics.get_gpu_usage()
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
camera_connections = worker_state.get_camera_connections()
# Create and send state report
state_report = create_state_report(
cpu_usage=cpu_usage,
memory_usage=memory_usage,
gpu_usage=gpu_usage,
gpu_memory_usage=gpu_memory_usage,
camera_connections=camera_connections
)
await self._send_message(state_report)
# Only log full details every 10th heartbeat, otherwise just show a dot
self._heartbeat_count += 1
if self._heartbeat_count % 10 == 0:
logger.info(f"[TX → Backend] Heartbeat #{self._heartbeat_count}: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
else:
print(".", end="", flush=True) # Just show a dot to indicate heartbeat activity
await asyncio.sleep(HEARTBEAT_INTERVAL)
except Exception as e:
logger.error(f"Error sending heartbeat: {e}")
break
async def _handle_messages(self) -> None:
"""Handle incoming WebSocket messages."""
while self.connected:
try:
raw_message = await self.websocket.receive_text()
logger.info(f"[RX ← Backend] {raw_message}")
# Parse incoming message
message = parse_incoming_message(raw_message)
if not message:
logger.warning("Failed to parse incoming message")
continue
# Route message to appropriate handler
await self._route_message(message)
except (WebSocketDisconnect, ConnectionClosedError) as e:
logger.warning(f"WebSocket disconnected: {e}")
break
except json.JSONDecodeError:
logger.error("Received invalid JSON message")
except Exception as e:
logger.error(f"Error handling message: {e}")
break
async def _route_message(self, message) -> None:
"""Route parsed message to appropriate handler."""
message_type = message.type
try:
if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
await self._handle_set_subscription_list(message)
elif message_type == MessageTypes.SET_SESSION_ID:
await self._handle_set_session_id(message)
elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
await self._handle_set_progression_stage(message)
elif message_type == MessageTypes.REQUEST_STATE:
await self._handle_request_state(message)
elif message_type == MessageTypes.PATCH_SESSION_RESULT:
await self._handle_patch_session_result(message)
else:
logger.warning(f"Unknown message type: {message_type}")
except Exception as e:
logger.error(f"Error handling {message_type} message: {e}")
async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None:
"""Handle setSubscriptionList message for declarative subscription management."""
logger.info(f"[RX Processing] setSubscriptionList with {len(message.subscriptions)} subscriptions")
# Update worker state with new subscriptions
worker_state.set_subscriptions(message.subscriptions)
# Phase 2: Download and manage models
await self._ensure_models(message.subscriptions)
# Phase 3 & 4: Integrate with streaming management and tracking
await self._update_stream_subscriptions(message.subscriptions)
logger.info("Subscription list updated successfully")
async def _ensure_models(self, subscriptions) -> None:
"""Ensure all required models are downloaded and available."""
# Extract unique model requirements
unique_models = {}
for subscription in subscriptions:
model_id = subscription.modelId
if model_id not in unique_models:
unique_models[model_id] = {
'model_url': subscription.modelUrl,
'model_name': subscription.modelName
}
# Check if model set has changed to avoid redundant processing
current_model_ids = set(unique_models.keys())
if current_model_ids == self._last_processed_models:
logger.debug(f"[Model Management] Model set unchanged {list(current_model_ids)}, skipping checks")
return
logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}")
self._last_processed_models = current_model_ids
# Check and download models concurrently
download_tasks = []
for model_id, model_info in unique_models.items():
task = asyncio.create_task(
self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name'])
)
download_tasks.append(task)
# Wait for all downloads to complete
if download_tasks:
results = await asyncio.gather(*download_tasks, return_exceptions=True)
            # Log results
            success_count = 0
            for model_id, result in zip(unique_models.keys(), results):
                if isinstance(result, Exception):
                    logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}")
                elif result:
                    success_count += 1
                    logger.info(f"[Model Management] Model {model_id} ready for use")
                else:
                    logger.error(f"[Model Management] Failed to ensure model {model_id}")
logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models")
async def _update_stream_subscriptions(self, subscriptions) -> None:
"""Update streaming subscriptions with tracking integration."""
try:
# Convert subscriptions to the format expected by StreamManager
subscription_payloads = []
for subscription in subscriptions:
payload = {
'subscriptionIdentifier': subscription.subscriptionIdentifier,
'rtspUrl': subscription.rtspUrl,
'snapshotUrl': subscription.snapshotUrl,
'snapshotInterval': subscription.snapshotInterval,
'modelId': subscription.modelId,
'modelUrl': subscription.modelUrl,
'modelName': subscription.modelName
}
                # Add crop coordinates only if all four are present
                # (mirrors the all-keys check in _add_subscription_with_tracking)
                if all(hasattr(subscription, attr) for attr in ('cropX1', 'cropY1', 'cropX2', 'cropY2')):
                    payload.update({
                        'cropX1': subscription.cropX1,
                        'cropY1': subscription.cropY1,
                        'cropX2': subscription.cropX2,
                        'cropY2': subscription.cropY2
                    })
subscription_payloads.append(payload)
# Reconcile subscriptions with StreamManager
logger.info("[Streaming] Reconciling stream subscriptions with tracking")
reconcile_result = await self._reconcile_subscriptions_with_tracking(subscription_payloads)
logger.info(f"[Streaming] Subscription reconciliation complete: "
f"added={reconcile_result.get('added', 0)}, "
f"removed={reconcile_result.get('removed', 0)}, "
f"failed={reconcile_result.get('failed', 0)}")
except Exception as e:
logger.error(f"Error updating stream subscriptions: {e}", exc_info=True)
async def _reconcile_subscriptions_with_tracking(self, target_subscriptions) -> dict:
"""Reconcile subscriptions with tracking integration."""
try:
# First, we need to create tracking integrations for each unique model
tracking_integrations = {}
for subscription_payload in target_subscriptions:
model_id = subscription_payload['modelId']
# Create tracking integration if not already created
if model_id not in tracking_integrations:
# Get pipeline configuration for this model
pipeline_parser = model_manager.get_pipeline_config(model_id)
if pipeline_parser:
# Create tracking integration with message sender
tracking_integration = TrackingPipelineIntegration(
pipeline_parser, model_manager, model_id, self._send_message
)
# Initialize tracking model
success = await tracking_integration.initialize_tracking_model()
if success:
tracking_integrations[model_id] = tracking_integration
logger.info(f"[Tracking] Created tracking integration for model {model_id}")
else:
logger.warning(f"[Tracking] Failed to initialize tracking for model {model_id}")
else:
logger.warning(f"[Tracking] No pipeline config found for model {model_id}")
            # Now reconcile with StreamManager, adding tracking integrations
            current_subscription_ids = {
                info.subscription_id for info in shared_stream_manager.get_all_subscriptions()
            }
target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions}
# Find subscriptions to remove and add
to_remove = current_subscription_ids - target_subscription_ids
to_add = target_subscription_ids - current_subscription_ids
# Remove old subscriptions
removed_count = 0
for subscription_id in to_remove:
if shared_stream_manager.remove_subscription(subscription_id):
removed_count += 1
logger.info(f"[Streaming] Removed subscription {subscription_id}")
# Add new subscriptions with tracking
added_count = 0
failed_count = 0
for subscription_payload in target_subscriptions:
subscription_id = subscription_payload['subscriptionIdentifier']
if subscription_id in to_add:
success = await self._add_subscription_with_tracking(
subscription_payload, tracking_integrations
)
if success:
added_count += 1
logger.info(f"[Streaming] Added subscription {subscription_id} with tracking")
else:
failed_count += 1
logger.error(f"[Streaming] Failed to add subscription {subscription_id}")
return {
'removed': removed_count,
'added': added_count,
'failed': failed_count,
'total_active': len(shared_stream_manager.get_all_subscriptions())
}
except Exception as e:
logger.error(f"Error in subscription reconciliation with tracking: {e}", exc_info=True)
return {'removed': 0, 'added': 0, 'failed': 0, 'total_active': 0}
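    # Reconciliation is declarative: with current subscription IDs {A, B} and
    # target IDs {B, C}, A is removed, C is added (wired to its model's tracking
    # integration), and B is left untouched.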
async def _add_subscription_with_tracking(self, payload, tracking_integrations) -> bool:
"""Add a subscription with tracking integration."""
try:
from ..streaming.manager import StreamConfig
subscription_id = payload['subscriptionIdentifier']
camera_id = subscription_id.split(';')[-1]
model_id = payload['modelId']
logger.info(f"[SUBSCRIPTION_MAPPING] subscription_id='{subscription_id}' → camera_id='{camera_id}'")
# Get tracking integration for this model
tracking_integration = tracking_integrations.get(model_id)
# Extract crop coordinates if present
crop_coords = None
if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']):
crop_coords = (
payload['cropX1'],
payload['cropY1'],
payload['cropX2'],
payload['cropY2']
)
# Create stream configuration
stream_config = StreamConfig(
camera_id=camera_id,
rtsp_url=payload.get('rtspUrl'),
snapshot_url=payload.get('snapshotUrl'),
snapshot_interval=payload.get('snapshotInterval', 5000),
max_retries=3,
)
# Add subscription to StreamManager with tracking
success = shared_stream_manager.add_subscription(
subscription_id=subscription_id,
stream_config=stream_config,
crop_coords=crop_coords,
model_id=model_id,
model_url=payload.get('modelUrl'),
tracking_integration=tracking_integration
)
if success and tracking_integration:
logger.info(f"[Tracking] Subscription {subscription_id} configured with tracking for model {model_id}")
return success
except Exception as e:
logger.error(f"Error adding subscription with tracking: {e}", exc_info=True)
return False
async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool:
"""Ensure a single model is downloaded and available."""
try:
# Check if model is already available
if model_manager.is_model_downloaded(model_id):
logger.info(f"[Model Management] Model {model_id} ({model_name}) already available")
return True
# Download and extract model in a thread pool to avoid blocking the event loop
logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}")
            # asyncio.to_thread (Python 3.9+) would also work here; run_in_executor
            # keeps compatibility with older runtimes
            loop = asyncio.get_running_loop()
model_path = await loop.run_in_executor(
None,
model_manager.ensure_model,
model_id,
model_url,
model_name
)
if model_path:
logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}")
return True
else:
logger.error(f"[Model Management] Failed to prepare model {model_id}")
return False
except Exception as e:
logger.error(f"[Model Management] Exception ensuring model {model_id}: {str(e)}", exc_info=True)
return False
async def _save_snapshot(self, display_identifier: str, session_id: int) -> None:
"""
Save snapshot image to images folder after receiving sessionId.
Args:
display_identifier: Display identifier to match with subscriptionIdentifier
session_id: Session ID to include in filename
"""
try:
            from .messages import extract_display_identifier

            # Find the subscription that matches the displayIdentifier
            matching_subscription = None
            for subscription in worker_state.get_all_subscriptions():
                # Extract display ID from subscriptionIdentifier (format: displayId;cameraId)
                sub_display_id = extract_display_identifier(subscription.subscriptionIdentifier)
                if sub_display_id == display_identifier:
                    matching_subscription = subscription
                    break
if not matching_subscription:
logger.error(f"[Snapshot Save] No subscription found for display {display_identifier}")
return
if not matching_subscription.snapshotUrl:
logger.error(f"[Snapshot Save] No snapshotUrl found for display {display_identifier}")
return
# Ensure images directory exists (relative path for Docker bind mount)
images_dir = Path("images")
images_dir.mkdir(exist_ok=True)
# Generate filename with timestamp and session ID
timestamp = datetime.now(tz=timezone(timedelta(hours=7))).strftime("%Y%m%d_%H%M%S")
filename = f"{session_id}_{display_identifier}_{timestamp}.jpg"
filepath = images_dir / filename
# Use existing HTTPSnapshotReader to fetch snapshot
logger.info(f"[Snapshot Save] Fetching snapshot from {matching_subscription.snapshotUrl}")
            # Run snapshot fetch in a thread pool to avoid blocking the event loop
            loop = asyncio.get_running_loop()
            frame = await loop.run_in_executor(None, self._fetch_snapshot_sync, matching_subscription.snapshotUrl)
if frame is not None:
# Save the image using OpenCV
success = cv2.imwrite(str(filepath), frame)
if success:
logger.info(f"[Snapshot Save] Successfully saved snapshot to {filepath}")
else:
logger.error(f"[Snapshot Save] Failed to save image file {filepath}")
else:
logger.error(f"[Snapshot Save] Failed to fetch snapshot from {matching_subscription.snapshotUrl}")
except Exception as e:
logger.error(f"[Snapshot Save] Error saving snapshot for display {display_identifier}: {e}", exc_info=True)
def _fetch_snapshot_sync(self, snapshot_url: str):
"""
Synchronous snapshot fetching using existing HTTPSnapshotReader infrastructure.
Args:
snapshot_url: URL to fetch snapshot from
Returns:
np.ndarray or None: Fetched frame or None on error
"""
try:
from ..streaming.readers import HTTPSnapshotReader
# Create temporary snapshot reader for single fetch
snapshot_reader = HTTPSnapshotReader(
camera_id="temp_snapshot",
snapshot_url=snapshot_url,
interval_ms=5000 # Not used for single fetch
)
# Use existing fetch_single_snapshot method
return snapshot_reader.fetch_single_snapshot()
except Exception as e:
logger.error(f"Error in sync snapshot fetch: {e}")
return None
async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
"""Handle setSessionId message."""
display_identifier = message.payload.displayIdentifier
session_id = message.payload.sessionId
logger.info(f"[RX Processing] setSessionId for display {display_identifier}: {session_id}")
# Update worker state
worker_state.set_session_id(display_identifier, session_id)
# Update tracking integrations with session ID
shared_stream_manager.set_session_id(display_identifier, session_id)
# Save snapshot image after getting sessionId
if session_id:
await self._save_snapshot(display_identifier, session_id)
async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None:
"""Handle setProgressionStage message."""
display_identifier = message.payload.displayIdentifier
stage = message.payload.progressionStage
logger.info(f"[RX Processing] setProgressionStage for display {display_identifier}: {stage}")
# Update worker state
worker_state.set_progression_stage(display_identifier, stage)
        # Update tracking integration for car abandonment detection
        session_id = worker_state.get_session_id(display_identifier)
        if session_id:
            shared_stream_manager.set_progression_stage(session_id, stage)
            # If the stage indicates the session is cleared/finished, clear it from tracking
            if stage in ['finished', 'cleared', 'idle']:
                shared_stream_manager.clear_session_id(session_id)
                logger.info(f"[Tracking] Cleared session {session_id} due to progression stage: {stage}")
async def _handle_request_state(self, message: RequestStateMessage) -> None:
"""Handle requestState message by sending immediate state report."""
logger.debug("[RX Processing] requestState - sending immediate state report")
# Collect metrics and send state report
cpu_usage = SystemMetrics.get_cpu_usage()
memory_usage = SystemMetrics.get_memory_usage()
gpu_usage = SystemMetrics.get_gpu_usage()
gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
camera_connections = worker_state.get_camera_connections()
state_report = create_state_report(
cpu_usage=cpu_usage,
memory_usage=memory_usage,
gpu_usage=gpu_usage,
gpu_memory_usage=gpu_memory_usage,
camera_connections=camera_connections
)
await self._send_message(state_report)
async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None:
"""Handle patchSessionResult message."""
payload = message.payload
logger.info(f"[RX Processing] patchSessionResult for session {payload.sessionId}: "
f"success={payload.success}, message='{payload.message}'")
# TODO: Handle patch session result if needed
# For now, just log the response
async def _send_message(self, message) -> None:
"""Send message to backend via WebSocket."""
if not self.connected:
logger.warning("Cannot send message: WebSocket not connected")
return
try:
json_message = serialize_outgoing_message(message)
await self.websocket.send_text(json_message)
# Log non-heartbeat messages only (heartbeats are logged in their respective functions)
if not (hasattr(message, 'type') and message.type == 'stateReport'):
logger.info(f"[TX → Backend] {json_message}")
except Exception as e:
logger.error(f"Failed to send WebSocket message: {e}")
raise
async def _process_streams(self) -> None:
"""
Stream processing task that handles frame processing and detection.
This is a placeholder for Phase 2 - currently just logs that it's running.
"""
logger.info("Stream processing task started")
try:
while self.connected:
# Get current subscriptions
subscriptions = worker_state.get_all_subscriptions()
# TODO: Phase 2 - Add actual frame processing logic here
# This will include:
# - Frame reading from RTSP/HTTP streams
# - Model inference using loaded pipelines
# - Detection result sending via WebSocket
# Sleep to prevent excessive CPU usage (similar to old poll_interval)
await asyncio.sleep(0.1) # 100ms polling interval
except asyncio.CancelledError:
logger.info("Stream processing task cancelled")
except Exception as e:
logger.error(f"Error in stream processing: {e}", exc_info=True)
async def _cleanup(self) -> None:
"""Clean up resources when connection closes."""
logger.info("Cleaning up WebSocket connection")
self.connected = False
# Cancel background tasks
if self._heartbeat_task and not self._heartbeat_task.done():
self._heartbeat_task.cancel()
if self._message_task and not self._message_task.done():
self._message_task.cancel()
# Clear worker state
worker_state.set_subscriptions([])
worker_state.session_ids.clear()
worker_state.progression_stages.clear()
logger.info("WebSocket connection cleanup completed")
# Factory function for FastAPI integration
async def websocket_endpoint(websocket: WebSocket) -> None:
"""
FastAPI WebSocket endpoint handler.
Args:
websocket: FastAPI WebSocket connection
"""
handler = WebSocketHandler(websocket)
await handler.handle_connection()
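# Example wiring into a FastAPI app (illustrative; the actual route path and
# app module are deployment-specific):
#
#     from fastapi import FastAPI
#
#     app = FastAPI()
#
#     @app.websocket("/")
#     async def ws(websocket: WebSocket):
#         await websocket_endpoint(websocket)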