"""
|
|
WebSocket message handling and protocol implementation.
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import cv2
|
|
from datetime import datetime, timezone, timedelta
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from websockets.exceptions import ConnectionClosedError
|
|
|
|
from .messages import (
|
|
parse_incoming_message, serialize_outgoing_message,
|
|
MessageTypes, create_state_report
|
|
)
|
|
from .models import (
|
|
SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
|
|
RequestStateMessage, PatchSessionResultMessage
|
|
)
|
|
from .state import worker_state, SystemMetrics
|
|
from ..models import ModelManager
|
|
from ..streaming.manager import shared_stream_manager
|
|
from ..tracking.integration import TrackingPipelineIntegration
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Constants
|
|
HEARTBEAT_INTERVAL = 2.0 # seconds
|
|
WORKER_TIMEOUT_MS = 10000
|
|
|
|
# Global model manager instance
|
|
model_manager = ModelManager()
|
|
|
|
|
|
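

# Protocol overview (summary of this module's message flow):
#   RX from backend: setSubscriptionList, setSessionId, setProgressionStage,
#                    requestState, patchSessionResult
#   TX to backend:   stateReport (sent immediately on connect, then every
#                    HEARTBEAT_INTERVAL seconds as a heartbeat)
# Note: WORKER_TIMEOUT_MS is assumed to be the backend-side liveness window;
# at a 2 s heartbeat interval that allows about 5 reports per timeout window.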
class WebSocketHandler:
    """
    Handles WebSocket connection lifecycle and message processing.
    """

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket
        self.connected = False
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._message_task: Optional[asyncio.Task] = None
        self._heartbeat_count = 0
        self._last_processed_models: set = set()  # Cache of last processed model IDs

    async def handle_connection(self) -> None:
        """
        Main connection handler that manages the WebSocket lifecycle.
        Based on the original architecture from archive/app.py.
        """
        client_info = f"{self.websocket.client.host}:{self.websocket.client.port}" if self.websocket.client else "unknown"
        logger.info(f"Starting WebSocket handler for {client_info}")

        stream_task = None
        try:
            logger.info(f"Accepting WebSocket connection from {client_info}")
            await self.websocket.accept()
            self.connected = True
            logger.info(f"WebSocket connection accepted and established for {client_info}")

            # Send immediate heartbeat to show connection is alive
            await self._send_immediate_heartbeat()

            # Start background tasks (matching original architecture).
            # Keep references on self so _cleanup() can cancel them.
            stream_task = asyncio.create_task(self._process_streams())
            self._heartbeat_task = asyncio.create_task(self._send_heartbeat())
            self._message_task = asyncio.create_task(self._handle_messages())

            logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)")

            # Wait for heartbeat and message tasks (stream runs independently)
            await asyncio.gather(self._heartbeat_task, self._message_task)

        except Exception as e:
            logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True)
        finally:
            logger.info(f"Cleaning up connection for {client_info}")
            # Cancel stream task
            if stream_task and not stream_task.done():
                stream_task.cancel()
                try:
                    await stream_task
                except asyncio.CancelledError:
                    logger.debug(f"Stream task cancelled for {client_info}")
            await self._cleanup()

    async def _send_immediate_heartbeat(self) -> None:
        """Send an immediate heartbeat on connection to show we're alive."""
        try:
            cpu_usage = SystemMetrics.get_cpu_usage()
            memory_usage = SystemMetrics.get_memory_usage()
            gpu_usage = SystemMetrics.get_gpu_usage()
            gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
            camera_connections = worker_state.get_camera_connections()

            state_report = create_state_report(
                cpu_usage=cpu_usage,
                memory_usage=memory_usage,
                gpu_usage=gpu_usage,
                gpu_memory_usage=gpu_memory_usage,
                camera_connections=camera_connections
            )

            await self._send_message(state_report)
            logger.info(f"[TX → Backend] Initial stateReport: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
                        f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")

        except Exception as e:
            logger.error(f"Error sending immediate heartbeat: {e}")

    async def _send_heartbeat(self) -> None:
        """Send periodic state reports as a heartbeat."""
        while self.connected:
            try:
                # Collect system metrics
                cpu_usage = SystemMetrics.get_cpu_usage()
                memory_usage = SystemMetrics.get_memory_usage()
                gpu_usage = SystemMetrics.get_gpu_usage()
                gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
                camera_connections = worker_state.get_camera_connections()

                # Create and send state report
                state_report = create_state_report(
                    cpu_usage=cpu_usage,
                    memory_usage=memory_usage,
                    gpu_usage=gpu_usage,
                    gpu_memory_usage=gpu_memory_usage,
                    camera_connections=camera_connections
                )

                await self._send_message(state_report)

                # Log full details only on every 10th heartbeat; otherwise just print a dot
                self._heartbeat_count += 1
                if self._heartbeat_count % 10 == 0:
                    logger.info(f"[TX → Backend] Heartbeat #{self._heartbeat_count}: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
                                f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
                else:
                    print(".", end="", flush=True)  # Dot indicates heartbeat activity

                await asyncio.sleep(HEARTBEAT_INTERVAL)

            except Exception as e:
                logger.error(f"Error sending heartbeat: {e}")
                break

    async def _handle_messages(self) -> None:
        """Handle incoming WebSocket messages."""
        while self.connected:
            try:
                raw_message = await self.websocket.receive_text()
                logger.info(f"[RX ← Backend] {raw_message}")

                # Parse incoming message
                message = parse_incoming_message(raw_message)
                if not message:
                    logger.warning("Failed to parse incoming message")
                    continue

                # Route message to appropriate handler
                await self._route_message(message)

            except (WebSocketDisconnect, ConnectionClosedError) as e:
                logger.warning(f"WebSocket disconnected: {e}")
                break
            except json.JSONDecodeError:
                logger.error("Received invalid JSON message")
            except Exception as e:
                logger.error(f"Error handling message: {e}")
                break

    async def _route_message(self, message) -> None:
        """Route a parsed message to the appropriate handler."""
        message_type = message.type

        try:
            if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
                await self._handle_set_subscription_list(message)
            elif message_type == MessageTypes.SET_SESSION_ID:
                await self._handle_set_session_id(message)
            elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
                await self._handle_set_progression_stage(message)
            elif message_type == MessageTypes.REQUEST_STATE:
                await self._handle_request_state(message)
            elif message_type == MessageTypes.PATCH_SESSION_RESULT:
                await self._handle_patch_session_result(message)
            else:
                logger.warning(f"Unknown message type: {message_type}")

        except Exception as e:
            logger.error(f"Error handling {message_type} message: {e}")

    async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None:
        """Handle setSubscriptionList message for declarative subscription management."""
        logger.info(f"[RX Processing] setSubscriptionList with {len(message.subscriptions)} subscriptions")

        # Update worker state with new subscriptions
        worker_state.set_subscriptions(message.subscriptions)

        # Phase 2: Download and manage models
        await self._ensure_models(message.subscriptions)

        # Phase 3 & 4: Integrate with streaming management and tracking
        await self._update_stream_subscriptions(message.subscriptions)

        logger.info("Subscription list updated successfully")

    async def _ensure_models(self, subscriptions) -> None:
        """Ensure all required models are downloaded and available."""
        # Extract unique model requirements
        unique_models = {}
        for subscription in subscriptions:
            model_id = subscription.modelId
            if model_id not in unique_models:
                unique_models[model_id] = {
                    'model_url': subscription.modelUrl,
                    'model_name': subscription.modelName
                }

        # Check whether the model set has changed to avoid redundant processing
        current_model_ids = set(unique_models.keys())
        if current_model_ids == self._last_processed_models:
            logger.debug(f"[Model Management] Model set unchanged {list(current_model_ids)}, skipping checks")
            return

        logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}")
        self._last_processed_models = current_model_ids

        # Check and download models concurrently
        download_tasks = []
        for model_id, model_info in unique_models.items():
            task = asyncio.create_task(
                self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name'])
            )
            download_tasks.append(task)

        # Wait for all downloads to complete
        if download_tasks:
            results = await asyncio.gather(*download_tasks, return_exceptions=True)

            # Log results
            success_count = 0
            for i, result in enumerate(results):
                model_id = list(unique_models.keys())[i]
                if isinstance(result, Exception):
                    logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}")
                elif result:
                    success_count += 1
                    logger.info(f"[Model Management] Model {model_id} ready for use")
                else:
                    logger.error(f"[Model Management] Failed to ensure model {model_id}")

            logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models")

    async def _update_stream_subscriptions(self, subscriptions) -> None:
        """Update streaming subscriptions with tracking integration."""
        try:
            # Convert subscriptions to the format expected by StreamManager
            subscription_payloads = []
            for subscription in subscriptions:
                payload = {
                    'subscriptionIdentifier': subscription.subscriptionIdentifier,
                    'rtspUrl': subscription.rtspUrl,
                    'snapshotUrl': subscription.snapshotUrl,
                    'snapshotInterval': subscription.snapshotInterval,
                    'modelId': subscription.modelId,
                    'modelUrl': subscription.modelUrl,
                    'modelName': subscription.modelName
                }
                # Add crop coordinates if present
                if hasattr(subscription, 'cropX1'):
                    payload.update({
                        'cropX1': subscription.cropX1,
                        'cropY1': subscription.cropY1,
                        'cropX2': subscription.cropX2,
                        'cropY2': subscription.cropY2
                    })
                subscription_payloads.append(payload)

            # Reconcile subscriptions with StreamManager
            logger.info("[Streaming] Reconciling stream subscriptions with tracking")
            reconcile_result = await self._reconcile_subscriptions_with_tracking(subscription_payloads)

            logger.info(f"[Streaming] Subscription reconciliation complete: "
                        f"added={reconcile_result.get('added', 0)}, "
                        f"removed={reconcile_result.get('removed', 0)}, "
                        f"failed={reconcile_result.get('failed', 0)}")

        except Exception as e:
            logger.error(f"Error updating stream subscriptions: {e}", exc_info=True)
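
    # Reconciliation is declarative: the backend sends the full desired
    # subscription set, and the worker diffs it against what is currently
    # running (to_remove = current - target, to_add = target - current).
    # Each added subscription gets its own TrackingPipelineIntegration
    # instance so per-camera tracking state stays isolated.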
    async def _reconcile_subscriptions_with_tracking(self, target_subscriptions) -> dict:
        """Reconcile subscriptions with tracking integration."""
        try:
            # Create a separate tracking integration per subscription (camera isolation)
            tracking_integrations = {}

            for subscription_payload in target_subscriptions:
                subscription_id = subscription_payload['subscriptionIdentifier']
                model_id = subscription_payload['modelId']

                # Get pipeline configuration for this model
                pipeline_parser = model_manager.get_pipeline_config(model_id)
                if pipeline_parser:
                    # Create tracking integration with message sender (separate instance per camera)
                    tracking_integration = TrackingPipelineIntegration(
                        pipeline_parser, model_manager, model_id, self._send_message
                    )

                    # Initialize tracking model
                    success = await tracking_integration.initialize_tracking_model()
                    if success:
                        tracking_integrations[subscription_id] = tracking_integration
                        logger.info(f"[Tracking] Created isolated tracking integration for subscription {subscription_id} (model {model_id})")
                    else:
                        logger.warning(f"[Tracking] Failed to initialize tracking for subscription {subscription_id} (model {model_id})")
                else:
                    logger.warning(f"[Tracking] No pipeline config found for model {model_id} in subscription {subscription_id}")

            # Now reconcile with StreamManager, adding tracking integrations
            current_subscription_ids = set()
            for subscription_info in shared_stream_manager.get_all_subscriptions():
                current_subscription_ids.add(subscription_info.subscription_id)

            target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions}

            # Find subscriptions to remove and add
            to_remove = current_subscription_ids - target_subscription_ids
            to_add = target_subscription_ids - current_subscription_ids

            # Remove old subscriptions
            removed_count = 0
            for subscription_id in to_remove:
                if shared_stream_manager.remove_subscription(subscription_id):
                    removed_count += 1
                    logger.info(f"[Streaming] Removed subscription {subscription_id}")

            # Add new subscriptions with tracking
            added_count = 0
            failed_count = 0
            for subscription_payload in target_subscriptions:
                subscription_id = subscription_payload['subscriptionIdentifier']
                if subscription_id in to_add:
                    success = await self._add_subscription_with_tracking(
                        subscription_payload, tracking_integrations
                    )
                    if success:
                        added_count += 1
                        logger.info(f"[Streaming] Added subscription {subscription_id} with tracking")
                    else:
                        failed_count += 1
                        logger.error(f"[Streaming] Failed to add subscription {subscription_id}")

            return {
                'removed': removed_count,
                'added': added_count,
                'failed': failed_count,
                'total_active': len(shared_stream_manager.get_all_subscriptions())
            }

        except Exception as e:
            logger.error(f"Error in subscription reconciliation with tracking: {e}", exc_info=True)
            return {'removed': 0, 'added': 0, 'failed': 0, 'total_active': 0}

    async def _add_subscription_with_tracking(self, payload, tracking_integrations) -> bool:
        """Add a subscription with tracking integration."""
        try:
            from ..streaming.manager import StreamConfig

            subscription_id = payload['subscriptionIdentifier']
            # subscriptionIdentifier has the form "displayId;cameraId"
            camera_id = subscription_id.split(';')[-1]
            model_id = payload['modelId']

            logger.info(f"[SUBSCRIPTION_MAPPING] subscription_id='{subscription_id}' → camera_id='{camera_id}'")

            # Get tracking integration for this subscription (camera-isolated)
            tracking_integration = tracking_integrations.get(subscription_id)

            # Extract crop coordinates if present
            crop_coords = None
            if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']):
                crop_coords = (
                    payload['cropX1'],
                    payload['cropY1'],
                    payload['cropX2'],
                    payload['cropY2']
                )

            # Create stream configuration
            stream_config = StreamConfig(
                camera_id=camera_id,
                rtsp_url=payload.get('rtspUrl'),
                snapshot_url=payload.get('snapshotUrl'),
                snapshot_interval=payload.get('snapshotInterval', 5000),
                max_retries=3,
            )

            # Add subscription to StreamManager with tracking
            success = shared_stream_manager.add_subscription(
                subscription_id=subscription_id,
                stream_config=stream_config,
                crop_coords=crop_coords,
                model_id=model_id,
                model_url=payload.get('modelUrl'),
                tracking_integration=tracking_integration
            )

            if success and tracking_integration:
                logger.info(f"[Tracking] Subscription {subscription_id} configured with isolated tracking for model {model_id}")

            return success

        except Exception as e:
            logger.error(f"Error adding subscription with tracking: {e}", exc_info=True)
            return False

    async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool:
        """Ensure a single model is downloaded and available."""
        try:
            # Check if model is already available
            if model_manager.is_model_downloaded(model_id):
                logger.info(f"[Model Management] Model {model_id} ({model_name}) already available")
                return True

            # Download and extract the model in a thread pool to avoid blocking the event loop
            logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}")

            # asyncio.to_thread (Python 3.9+) would also work here;
            # run_in_executor is used for wider compatibility
            loop = asyncio.get_running_loop()
            model_path = await loop.run_in_executor(
                None,
                model_manager.ensure_model,
                model_id,
                model_url,
                model_name
            )

            if model_path:
                logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}")
                return True
            else:
                logger.error(f"[Model Management] Failed to prepare model {model_id}")
                return False

        except Exception as e:
            logger.error(f"[Model Management] Exception ensuring model {model_id}: {e}", exc_info=True)
            return False
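
    # Snapshot saving: _handle_set_progression_stage() calls _save_snapshot()
    # when the stage becomes 'car_fueling'. The image is written under the
    # relative "images/" directory (bind-mounted in Docker) with a UTC+7
    # timestamp in the filename.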
    async def _save_snapshot(self, display_identifier: str, session_id: str) -> None:
        """
        Save a snapshot image to the images folder after receiving a sessionId.

        Args:
            display_identifier: Display identifier to match with subscriptionIdentifier
            session_id: Session ID to include in the filename
        """
        try:
            from .messages import extract_display_identifier

            # Find the subscription that matches the displayIdentifier
            matching_subscription = None
            for subscription in worker_state.get_all_subscriptions():
                # Extract display ID from subscriptionIdentifier (format: displayId;cameraId)
                sub_display_id = extract_display_identifier(subscription.subscriptionIdentifier)
                if sub_display_id == display_identifier:
                    matching_subscription = subscription
                    break

            if not matching_subscription:
                logger.error(f"[Snapshot Save] No subscription found for display {display_identifier}")
                return

            if not matching_subscription.snapshotUrl:
                logger.error(f"[Snapshot Save] No snapshotUrl found for display {display_identifier}")
                return

            # Ensure images directory exists (relative path for Docker bind mount)
            images_dir = Path("images")
            images_dir.mkdir(exist_ok=True)

            # Generate filename with timestamp (UTC+7) and session ID
            timestamp = datetime.now(tz=timezone(timedelta(hours=7))).strftime("%Y%m%d_%H%M%S")
            filename = f"{session_id}_{display_identifier}_{timestamp}.jpg"
            filepath = images_dir / filename

            # Use the existing HTTPSnapshotReader to fetch the snapshot
            logger.info(f"[Snapshot Save] Fetching snapshot from {matching_subscription.snapshotUrl}")

            # Run the snapshot fetch in a thread pool to avoid blocking the event loop
            loop = asyncio.get_running_loop()
            frame = await loop.run_in_executor(None, self._fetch_snapshot_sync, matching_subscription.snapshotUrl)

            if frame is not None:
                # Save the image using OpenCV
                success = cv2.imwrite(str(filepath), frame)
                if success:
                    logger.info(f"[Snapshot Save] Successfully saved snapshot to {filepath}")
                else:
                    logger.error(f"[Snapshot Save] Failed to save image file {filepath}")
            else:
                logger.error(f"[Snapshot Save] Failed to fetch snapshot from {matching_subscription.snapshotUrl}")

        except Exception as e:
            logger.error(f"[Snapshot Save] Error saving snapshot for display {display_identifier}: {e}", exc_info=True)

    def _fetch_snapshot_sync(self, snapshot_url: str):
        """
        Synchronous snapshot fetch using the existing HTTPSnapshotReader infrastructure.

        Args:
            snapshot_url: URL to fetch the snapshot from

        Returns:
            np.ndarray or None: Fetched frame, or None on error
        """
        try:
            from ..streaming.readers import HTTPSnapshotReader

            # Create a temporary snapshot reader for a single fetch
            snapshot_reader = HTTPSnapshotReader(
                camera_id="temp_snapshot",
                snapshot_url=snapshot_url,
                interval_ms=5000  # Not used for a single fetch
            )

            # Use the existing fetch_single_snapshot method
            return snapshot_reader.fetch_single_snapshot()

        except Exception as e:
            logger.error(f"Error in sync snapshot fetch: {e}")
            return None

    async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
        """Handle setSessionId message."""
        display_identifier = message.payload.displayIdentifier
        session_id = str(message.payload.sessionId) if message.payload.sessionId is not None else None

        logger.info(f"[RX Processing] setSessionId for display {display_identifier}: {session_id}")

        # Update worker state
        worker_state.set_session_id(display_identifier, session_id)

        # Update tracking integrations with the session ID
        shared_stream_manager.set_session_id(display_identifier, session_id)

    async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None:
        """Handle setProgressionStage message."""
        display_identifier = message.payload.displayIdentifier
        stage = message.payload.progressionStage

        logger.info(f"[RX Processing] setProgressionStage for display {display_identifier}: {stage}")

        # Update worker state
        worker_state.set_progression_stage(display_identifier, stage)

        # Update tracking integration for car abandonment detection
        session_id = worker_state.get_session_id(display_identifier)
        if session_id:
            shared_stream_manager.set_progression_stage(session_id, stage)

        # Save a snapshot image when the progression stage is car_fueling
        if stage == 'car_fueling' and session_id:
            await self._save_snapshot(display_identifier, session_id)

        # If the stage indicates the session is cleared/finished, clear it from tracking
        if stage in ['finished', 'cleared', 'idle'] and session_id:
            shared_stream_manager.clear_session_id(session_id)
            logger.info(f"[Tracking] Cleared session {session_id} due to progression stage: {stage}")

    async def _handle_request_state(self, message: RequestStateMessage) -> None:
        """Handle requestState message by sending an immediate state report."""
        logger.debug("[RX Processing] requestState - sending immediate state report")

        # Collect metrics and send a state report
        cpu_usage = SystemMetrics.get_cpu_usage()
        memory_usage = SystemMetrics.get_memory_usage()
        gpu_usage = SystemMetrics.get_gpu_usage()
        gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
        camera_connections = worker_state.get_camera_connections()

        state_report = create_state_report(
            cpu_usage=cpu_usage,
            memory_usage=memory_usage,
            gpu_usage=gpu_usage,
            gpu_memory_usage=gpu_memory_usage,
            camera_connections=camera_connections
        )

        await self._send_message(state_report)

    async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None:
        """Handle patchSessionResult message."""
        payload = message.payload
        logger.info(f"[RX Processing] patchSessionResult for session {payload.sessionId}: "
                    f"success={payload.success}, message='{payload.message}'")

        # TODO: Handle the patch session result if needed
        # For now, just log the response

    async def _send_message(self, message) -> None:
        """Send a message to the backend via WebSocket."""
        if not self.connected:
            logger.warning("Cannot send message: WebSocket not connected")
            return

        try:
            json_message = serialize_outgoing_message(message)
            await self.websocket.send_text(json_message)
            # Log non-heartbeat messages only (heartbeats are logged by their callers)
            if not (hasattr(message, 'type') and message.type == 'stateReport'):
                logger.info(f"[TX → Backend] {json_message}")
        except Exception as e:
            logger.error(f"Failed to send WebSocket message: {e}")
            raise

    async def _process_streams(self) -> None:
        """
        Stream processing task that handles frame processing and detection.
        This is a placeholder for Phase 2 - currently it only polls the
        subscription list and sleeps.
        """
        logger.info("Stream processing task started")
        try:
            while self.connected:
                # Get current subscriptions
                subscriptions = worker_state.get_all_subscriptions()

                # TODO: Phase 2 - Add actual frame processing logic here
                # This will include:
                # - Frame reading from RTSP/HTTP streams
                # - Model inference using loaded pipelines
                # - Detection result sending via WebSocket

                # Sleep to prevent excessive CPU usage (similar to the old poll_interval)
                await asyncio.sleep(0.1)  # 100ms polling interval

        except asyncio.CancelledError:
            logger.info("Stream processing task cancelled")
        except Exception as e:
            logger.error(f"Error in stream processing: {e}", exc_info=True)

    async def _cleanup(self) -> None:
        """Clean up resources when the connection closes."""
        logger.info("Cleaning up WebSocket connection")
        self.connected = False

        # Cancel background tasks
        if self._heartbeat_task and not self._heartbeat_task.done():
            self._heartbeat_task.cancel()
        if self._message_task and not self._message_task.done():
            self._message_task.cancel()

        # Clear worker state
        worker_state.set_subscriptions([])
        worker_state.session_ids.clear()
        worker_state.progression_stages.clear()

        logger.info("WebSocket connection cleanup completed")


# Factory function for FastAPI integration
async def websocket_endpoint(websocket: WebSocket) -> None:
    """
    FastAPI WebSocket endpoint handler.

    Args:
        websocket: FastAPI WebSocket connection
    """
    handler = WebSocketHandler(websocket)
    await handler.handle_connection()
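
# Example wiring (a sketch; assumes a FastAPI app defined elsewhere and a
# hypothetical "/ws" route path):
#     from fastapi import FastAPI
#     app = FastAPI()
#     app.add_api_websocket_route("/ws", websocket_endpoint)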