"""
|
|
WebSocket handler module.
|
|
|
|
This module manages WebSocket connections, message processing, heartbeat functionality,
|
|
and coordination between stream processing and detection pipelines.
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
from typing import Dict, Any, Optional, Callable, List, Set
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import WebSocket
|
|
from fastapi.websockets import WebSocketDisconnect
|
|
from websockets.exceptions import ConnectionClosedError
|
|
|
|
from ..core.config import config, subscription_to_camera, latest_frames
|
|
from ..core.constants import HEARTBEAT_INTERVAL
|
|
from ..core.exceptions import WebSocketError, StreamError
|
|
from ..streams.stream_manager import StreamManager
|
|
from ..streams.camera_monitor import CameraMonitor
|
|
from ..detection.detection_result import DetectionResult
|
|
from ..models.model_manager import ModelManager
|
|
from ..pipeline.pipeline_executor import PipelineExecutor
|
|
from ..storage.session_cache import SessionCacheManager
|
|
from ..storage.redis_client import RedisClientManager
|
|
from ..utils.system_monitor import get_system_metrics
|
|
|
|
# Setup logging
|
|
logger = logging.getLogger("detector_worker.websocket_handler")
|
|
ws_logger = logging.getLogger("websocket")
|
|
ws_rxtx_logger = logging.getLogger("websocket.rxtx") # Dedicated RX/TX logger
|
|
|
|
# Import enhanced loggers
|
|
from ..utils.logging_utils import get_websocket_logger
|
|
enhanced_ws_logger = get_websocket_logger()
|
|
|
|
# Type definitions for callbacks
|
|
MessageHandler = Callable[[Dict[str, Any]], asyncio.coroutine]
|
|
DetectionHandler = Callable[[str, Dict[str, Any], Any, WebSocket, Any, Dict[str, Any]], asyncio.coroutine]
|
|
|
|
|
|
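
# Every handler registered in WebSocketHandler.message_handlers follows the
# MessageHandler shape above: an async method that receives the parsed JSON
# message as a dict. A minimal sketch (the handler name here is hypothetical):
#
#     async def _handle_example(self, data: Dict[str, Any]) -> None:
#         payload = data.get("payload", {})
#         ...

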
class WebSocketHandler:
    """
    Manages WebSocket connections and message processing for the detection worker.

    This class handles:
    - WebSocket lifecycle management
    - Message routing and processing
    - Heartbeat/state reporting
    - Stream subscription management
    - Detection result forwarding
    - Session management
    """

    def __init__(
        self,
        stream_manager: StreamManager,
        model_manager: ModelManager,
        pipeline_executor: PipelineExecutor,
        session_cache: SessionCacheManager,
        redis_client: Optional[RedisClientManager] = None
    ):
        """
        Initialize the WebSocket handler.

        Args:
            stream_manager: Manager for camera streams
            model_manager: Manager for ML models
            pipeline_executor: Pipeline execution engine
            session_cache: Session state cache
            redis_client: Optional Redis client for pub/sub
        """
        self.stream_manager = stream_manager
        self.model_manager = model_manager
        self.pipeline_executor = pipeline_executor
        self.session_cache = session_cache
        self.redis_client = redis_client

        # Connection state
        self.websocket: Optional[WebSocket] = None
        self.connected: bool = False
        self.tasks: List[asyncio.Task] = []

        # Message handlers
        self.message_handlers: Dict[str, MessageHandler] = {
            "subscribe": self._handle_subscribe,
            "unsubscribe": self._handle_unsubscribe,
            "setSubscriptionList": self._handle_set_subscription_list,
            "requestState": self._handle_request_state,
            "setSessionId": self._handle_set_session,
            "patchSession": self._handle_patch_session,
            "setProgressionStage": self._handle_set_progression_stage,
            "patchSessionResult": self._handle_patch_session_result
        }

        # Session and display management
        self.session_ids: Dict[str, str] = {}  # display_identifier -> session_id
        self.progression_stages: Dict[str, str] = {}  # display_identifier -> progression_stage
        self.display_identifiers: Set[str] = set()

        # Camera monitor
        self.camera_monitor = CameraMonitor()

    async def handle_connection(self, websocket: WebSocket) -> None:
        """
        Main entry point for handling a WebSocket connection.

        Args:
            websocket: The WebSocket connection to handle
        """
        try:
            await websocket.accept()
            self.websocket = websocket
            self.connected = True

            # Log connection details with bulletproof logging
            client_host = getattr(websocket.client, 'host', 'unknown')
            client_port = getattr(websocket.client, 'port', 'unknown')
            connection_msg = f"🔗 WebSocket connection accepted from {client_host}:{client_port}"

            print(f"\n{connection_msg}")  # Print to console (always visible)
            logger.info(connection_msg)
            ws_rxtx_logger.info(f"CONNECT -> Client: {client_host}:{client_port}")

            print("🔄 WebSocket handler ready - waiting for messages from CMS backend...")
            print("📡 All RX/TX communication will be logged below:")
            print("=" * 80)

            # Create concurrent tasks
            stream_task = asyncio.create_task(self._process_streams())
            heartbeat_task = asyncio.create_task(self._send_heartbeat())
            message_task = asyncio.create_task(self._process_messages())

            self.tasks = [stream_task, heartbeat_task, message_task]

            # Wait for tasks to complete
            await asyncio.gather(heartbeat_task, message_task)

        except Exception as e:
            logger.error(f"Error in WebSocket handler: {e}")

        finally:
            self.connected = False
            client_host = getattr(websocket.client, 'host', 'unknown') if websocket.client else 'unknown'
            client_port = getattr(websocket.client, 'port', 'unknown') if websocket.client else 'unknown'

            print(f"\n🔗 WEBSOCKET CONNECTION CLOSED: {client_host}:{client_port}")
            print("=" * 80)
            ws_rxtx_logger.info(f"DISCONNECT -> Client: {client_host}:{client_port}")
            await self._cleanup()

    async def _cleanup(self) -> None:
        """Clean up resources when connection closes."""
        logger.info("Cleaning up WebSocket connection")

        # Cancel all tasks
        for task in self.tasks:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Clean up streams
        await self.stream_manager.cleanup_all_streams()

        # Clean up models
        self.model_manager.cleanup_all_models()

        # Clear session state
        self.session_cache.clear_all_sessions()
        self.session_ids.clear()
        self.display_identifiers.clear()

        # Clear camera states
        self.camera_monitor.clear_all_states()

        logger.info("WebSocket cleanup completed")

    async def _send_heartbeat(self) -> None:
        """Send periodic heartbeat/state reports to maintain connection."""
        while self.connected:
            try:
                await self._send_state_report()
                await asyncio.sleep(HEARTBEAT_INTERVAL)

            except (WebSocketDisconnect, ConnectionClosedError):
                logger.info("WebSocket disconnected during heartbeat")
                break
            except Exception as e:
                logger.error(f"Error sending heartbeat: {e}")
                break

    async def _send_state_report(self) -> None:
        """Build and send a single protocol-compliant stateReport message."""
        # Get system metrics
        metrics = get_system_metrics()

        # Build cameraConnections array as required by protocol
        camera_connections = []
        with self.stream_manager.streams_lock:
            for camera_id, stream_info in self.stream_manager.streams.items():
                # Check if camera is online
                is_online = self.stream_manager.is_stream_active(camera_id)

                connection_info = {
                    "subscriptionIdentifier": getattr(stream_info, "subscriptionIdentifier", camera_id),
                    "modelId": getattr(stream_info, "modelId", 0),
                    "modelName": getattr(stream_info, "modelName", "Unknown Model"),
                    "online": is_online
                }

                # Add crop coordinates if available
                if hasattr(stream_info, "cropX1"):
                    connection_info["cropX1"] = stream_info.cropX1
                if hasattr(stream_info, "cropY1"):
                    connection_info["cropY1"] = stream_info.cropY1
                if hasattr(stream_info, "cropX2"):
                    connection_info["cropX2"] = stream_info.cropX2
                if hasattr(stream_info, "cropY2"):
                    connection_info["cropY2"] = stream_info.cropY2

                camera_connections.append(connection_info)

        # Protocol-compliant stateReport format (worker.md lines 169-189)
        state_data = {
            "type": "stateReport",
            "cpuUsage": metrics.get("cpu_percent", 0),
            "memoryUsage": metrics.get("memory_percent", 0),
            "gpuUsage": metrics.get("gpu_percent", 0),
            "gpuMemoryUsage": metrics.get("gpu_memory_percent", 0),  # Fixed field name
            "cameraConnections": camera_connections
        }

        # BULLETPROOF TX LOGGING - Multiple methods to ensure visibility
        tx_json = json.dumps(state_data, separators=(',', ':'))
        print(f"\n🟢 WEBSOCKET TX -> {tx_json}")  # Print to console (always visible)
        logger.info(f"🟢 TX -> {tx_json}")  # Standard logging
        ws_rxtx_logger.info(f"TX -> {tx_json}")  # WebSocket specific logging

        # Enhanced TX logging
        enhanced_ws_logger.log_tx(state_data)
        await self.websocket.send_json(state_data)

    async def _process_messages(self) -> None:
        """Process incoming WebSocket messages."""
        while self.connected:
            try:
                text_data = await self.websocket.receive_text()

                # BULLETPROOF RX LOGGING - Multiple methods to ensure visibility
                print(f"\n🔵 WEBSOCKET RX <- {text_data}")  # Print to console (always visible)
                logger.info(f"🔵 RX <- {text_data}")  # Standard logging
                ws_rxtx_logger.info(f"RX <- {text_data}")  # WebSocket specific logging

                # Enhanced RX logging with correlation
                correlation_id = enhanced_ws_logger.log_rx(text_data)

                data = json.loads(text_data)
                msg_type = data.get("type")

                # Log message processing - FORCE INFO LEVEL
                logger.info(f"📥 Processing message type: {msg_type} [corr:{correlation_id}]")

                if msg_type in self.message_handlers:
                    handler = self.message_handlers[msg_type]
                    await handler(data)
                    logger.info(f"✅ Message {msg_type} processed successfully [corr:{correlation_id}]")
                else:
                    logger.error(f"❌ Unknown message type: {msg_type} [corr:{correlation_id}]")
                    ws_rxtx_logger.error(f"UNKNOWN_MSG_TYPE -> {msg_type}")

            except json.JSONDecodeError as e:
                print(f"\n❌ WEBSOCKET ERROR - Invalid JSON received: {e}")
                print(f"🔍 Raw message data: {text_data}")
                logger.error(f"Received invalid JSON message: {e}")
                logger.error(f"Raw message data: {text_data}")
                enhanced_ws_logger.correlation_logger.error("Failed to parse JSON in received message")
            except (WebSocketDisconnect, ConnectionClosedError) as e:
                print(f"\n🔌 WEBSOCKET DISCONNECTED: {e}")
                logger.warning(f"WebSocket disconnected: {e}")
                break
            except Exception as e:
                print(f"\n💥 WEBSOCKET ERROR: {e}")
                logger.error(f"Error handling message: {e}")
                traceback.print_exc()
                break

    async def _process_streams(self) -> None:
        """Process active camera streams and run detection pipelines."""
        while self.connected:
            try:
                active_streams = self.stream_manager.get_active_streams()

                if active_streams:
                    # Process each active stream
                    tasks = []
                    task_camera_ids = []  # Cameras that actually got a detection task queued
                    for camera_id, stream_info in active_streams.items():
                        # Get latest frame
                        frame = self.stream_manager.get_latest_frame(camera_id)
                        if frame is None:
                            continue

                        # Get model for this camera
                        model_id = getattr(stream_info, "modelId", None)
                        if not model_id:
                            continue

                        model_tree = self.model_manager.get_model(camera_id, model_id)
                        if not model_tree:
                            continue

                        # Create detection task
                        persistent_data = self.session_cache.get_persistent_data(camera_id)
                        task = asyncio.create_task(
                            self._handle_detection(
                                camera_id, stream_info, frame,
                                self.websocket, model_tree, persistent_data
                            )
                        )
                        tasks.append(task)
                        task_camera_ids.append(camera_id)

                    # Wait for all detection tasks
                    if tasks:
                        results = await asyncio.gather(*tasks, return_exceptions=True)

                        # Update persistent data for the cameras whose tasks actually ran
                        # (results are aligned with task_camera_ids, not with all active streams)
                        for camera_id, result in zip(task_camera_ids, results):
                            if isinstance(result, dict):
                                self.session_cache.update_persistent_data(camera_id, result)

                # Polling interval
                poll_interval = config.get("poll_interval_ms", 100) / 1000.0
                await asyncio.sleep(poll_interval)

            except asyncio.CancelledError:
                logger.info("Stream processing cancelled")
                break
            except Exception as e:
                logger.error(f"Error in stream processing: {e}")
                await asyncio.sleep(1)

    async def _handle_detection(
        self,
        camera_id: str,
        stream_info: Dict[str, Any],
        frame: Any,
        websocket: WebSocket,
        model_tree: Any,
        persistent_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Handle detection for a single camera frame.

        Args:
            camera_id: Camera identifier
            stream_info: Stream configuration
            frame: Video frame to process
            websocket: WebSocket connection
            model_tree: Model pipeline tree
            persistent_data: Persistent data for this camera

        Returns:
            Updated persistent data
        """
        try:
            # Check camera connection state
            if self.camera_monitor.should_notify_disconnection(camera_id):
                await self._send_disconnection_notification(camera_id, stream_info)
                return persistent_data

            # Apply crop if specified
            cropped_frame = self._apply_crop(frame, stream_info)

            # Get session pipeline state
            pipeline_state = self.session_cache.get_session_pipeline_state(camera_id)

            # Run detection pipeline
            detection_result = await self.pipeline_executor.execute_pipeline(
                camera_id,
                stream_info,
                cropped_frame,
                model_tree,
                persistent_data,
                pipeline_state
            )

            # Send detection result
            if detection_result:
                await self._send_detection_result(
                    camera_id, stream_info, detection_result
                )

            # Handle camera reconnection
            if self.camera_monitor.should_notify_reconnection(camera_id):
                self.camera_monitor.mark_reconnection_notified(camera_id)
                logger.info(f"Camera {camera_id} reconnected successfully")

            return persistent_data

        except Exception as e:
            logger.error(f"Error in detection handling for camera {camera_id}: {e}")
            traceback.print_exc()
            return persistent_data

    async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None:
        """
        Handle setSubscriptionList command - declarative subscription management.

        This is the primary subscription command per worker.md protocol.
        Workers must reconcile the new subscription list with current state.
        """
        subscriptions = data.get("subscriptions", [])

        # DETAILED DEBUG LOGGING - Log the entire message payload
        print(f"\n📋 RECEIVED setSubscriptionList with {len(subscriptions)} subscriptions")
        logger.info(f"🔍 RECEIVED setSubscriptionList - Full payload: {json.dumps(data, indent=2)}")
        logger.info(f"📋 Number of subscriptions: {len(subscriptions)}")

        # Extract unique model URLs for download
        unique_models = {}  # model_id -> model_url
        valid_subscriptions = []

        for i, sub_config in enumerate(subscriptions):
            sub_id = sub_config.get("subscriptionIdentifier")
            model_id = sub_config.get("modelId")
            model_url = sub_config.get("modelUrl")

            print(f"📦 Subscription {i+1}: {sub_id} | Model {model_id}")

            # Track unique models for download - check if model already exists locally
            if model_id and model_url:
                if model_id not in unique_models:
                    # Check if model directory already exists on disk
                    from ..core.config import MODELS_DIR
                    model_dir = os.path.join(MODELS_DIR, str(model_id))

                    print(f"🔍 Checking model directory: {model_dir}")
                    logger.info(f"Checking if model {model_id} exists at: {model_dir}")

                    if os.path.exists(model_dir) and os.path.isdir(model_dir):
                        # Check if directory has content (not empty)
                        dir_contents = os.listdir(model_dir)
                        actual_contents = [f for f in dir_contents if not f.startswith('.')]

                        print(f"📋 Directory contents: {dir_contents}")
                        print(f"📋 Filtered contents: {actual_contents}")
                        logger.info(f"Model {model_id} directory contents: {actual_contents}")

                        if actual_contents:
                            print(f"📁 Model {model_id} already exists locally, skipping download")
                            logger.info(f"Skipping download for model {model_id} - already exists")
                        else:
                            print(f"📁 Model {model_id} directory exists but empty, will download")
                            unique_models[model_id] = model_url
                            print(f"🎯 New model found: ID {model_id}")
                            logger.info(f"Model {model_id} directory empty, adding to download queue")
                    else:
                        print(f"📁 Model {model_id} directory does not exist, will download")
                        unique_models[model_id] = model_url
                        print(f"🎯 New model found: ID {model_id}")
                        logger.info(f"Model {model_id} directory not found, adding to download queue")
                else:
                    print(f"🔄 Model {model_id} already tracked")

            logger.info(f"📦 Subscription {i+1}: {json.dumps(sub_config, indent=2)}")
            sub_id = sub_config.get("subscriptionIdentifier")
            logger.info(f"🏷️ Subscription ID: '{sub_id}' (type: {type(sub_id)})")

        print(f"📚 Unique models to download: {list(unique_models.keys())}")

        # Download unique models first (before processing subscriptions)
        if unique_models:
            print(f"⬇️ Starting download of {len(unique_models)} unique models...")
            await self._download_unique_models(unique_models)

        try:
            # Get current subscription identifiers
            current_subscriptions = set(subscription_to_camera.keys())

            # Get desired subscription identifiers
            desired_subscriptions = set()
            subscription_configs = {}

            for sub_config in subscriptions:
                sub_id = sub_config.get("subscriptionIdentifier")

                # Enhanced validation with detailed logging
                logger.info(f"🔍 Processing subscription config: subscriptionIdentifier='{sub_id}'")

                # Handle null/None subscription IDs
                if sub_id is None or sub_id == "null" or sub_id == "None" or not sub_id:
                    logger.error(f"❌ Invalid subscription ID received: '{sub_id}' (type: {type(sub_id)})")
                    logger.error(f"📋 Full subscription config: {json.dumps(sub_config, indent=2)}")

                    # Try to construct a valid subscription ID from available data
                    display_id = sub_config.get("displayId") or sub_config.get("displayIdentifier") or "unknown-display"
                    camera_id = sub_config.get("cameraId") or sub_config.get("camera") or "unknown-camera"
                    constructed_id = f"{display_id};{camera_id}"

                    logger.warning(f"🔧 Attempting to construct subscription ID: '{constructed_id}'")
                    logger.warning(f"📝 Available config keys: {list(sub_config.keys())}")

                    # Use constructed ID if it looks valid
                    if display_id != "unknown-display" or camera_id != "unknown-camera":
                        sub_id = constructed_id
                        logger.info(f"✅ Using constructed subscription ID: '{sub_id}'")
                    else:
                        logger.error("💥 Cannot construct valid subscription ID, skipping this subscription")
                        continue

                if sub_id:
                    desired_subscriptions.add(sub_id)
                    subscription_configs[sub_id] = sub_config

                    # Extract display ID for session management
                    parts = sub_id.split(";")
                    if len(parts) >= 2:
                        display_id = parts[0]
                        self.display_identifiers.add(display_id)

            # Calculate changes needed
            to_add = desired_subscriptions - current_subscriptions
            to_remove = current_subscriptions - desired_subscriptions
            to_update = desired_subscriptions & current_subscriptions

            logger.info(f"Subscription reconciliation: add={len(to_add)}, remove={len(to_remove)}, update={len(to_update)}")

            # Remove obsolete subscriptions
            for sub_id in to_remove:
                camera_id = subscription_to_camera.get(sub_id)
                if camera_id:
                    await self.stream_manager.stop_stream(camera_id)
                    self.model_manager.unload_models(camera_id)
                    subscription_to_camera.pop(sub_id, None)
                    # Clear cached data (SessionCacheManager handles this automatically)
                    # Note: clear_session method not available, cleanup happens automatically
                    logger.info(f"Removed subscription: {sub_id}")

            # Add new subscriptions
            for sub_id in to_add:
                await self._start_subscription(sub_id, subscription_configs[sub_id])
                logger.info(f"Added subscription: {sub_id}")

            # Update existing subscriptions if needed
            for sub_id in to_update:
                # Check if configuration changed (model URL, crop coordinates, etc.)
                current_config = subscription_to_camera.get(sub_id)
                new_config = subscription_configs[sub_id]

                # For now, restart subscription if model URL changed (handles S3 expiration)
                current_model_url = getattr(current_config, 'model_url', None) if current_config else None
                new_model_url = new_config.get("modelUrl")

                if current_model_url != new_model_url:
                    # Restart with new configuration
                    camera_id = subscription_to_camera.get(sub_id)
                    if camera_id:
                        await self.stream_manager.stop_stream(camera_id)
                        self.model_manager.unload_models(camera_id)

                    await self._start_subscription(sub_id, new_config)
                    logger.info(f"Updated subscription: {sub_id}")

            logger.info(f"Subscription list reconciliation completed. Active: {len(desired_subscriptions)}")

        except Exception as e:
            print(f"💥 Error handling setSubscriptionList: {e}")
            logger.error(f"Error handling setSubscriptionList: {e}")
            traceback.print_exc()

    async def _download_unique_models(self, unique_models: Dict[int, str]) -> None:
        """
        Download unique models to models/{model_id}/ folders.

        Args:
            unique_models: Dictionary of model_id -> model_url
        """
        try:
            # Use model manager to download models
            download_tasks = []

            for model_id, model_url in unique_models.items():
                print(f"🚀 Queuing download: Model {model_id} from {model_url[:50]}...")

                # Create download task using model manager
                task = asyncio.create_task(
                    self._download_single_model(model_id, model_url)
                )
                download_tasks.append(task)

            # Wait for all downloads to complete
            if download_tasks:
                print(f"⏳ Downloading {len(download_tasks)} models concurrently...")
                results = await asyncio.gather(*download_tasks, return_exceptions=True)

                # Check results
                successful = 0
                failed = 0
                for i, result in enumerate(results):
                    model_id = list(unique_models.keys())[i]
                    if isinstance(result, Exception):
                        print(f"❌ Model {model_id} download failed: {result}")
                        failed += 1
                    else:
                        print(f"✅ Model {model_id} downloaded successfully")
                        successful += 1

                print(f"📊 Download summary: {successful} successful, {failed} failed")
            else:
                print("📭 No models to download")

        except Exception as e:
            print(f"💥 Error in bulk model download: {e}")
            logger.error(f"Error downloading unique models: {e}")

    async def _download_single_model(self, model_id: int, model_url: str) -> None:
        """
        Download a single model using the model manager.

        Args:
            model_id: Model identifier
            model_url: URL to download from
        """
        try:
            # Create a temporary camera ID for the download
            temp_camera_id = f"download_temp_{model_id}_{int(time.time())}"

            print(f"📥 Downloading model {model_id}...")

            # Use model manager to load (download) the model
            await self.model_manager.load_model(
                camera_id=temp_camera_id,
                model_id=str(model_id),
                model_url=model_url,
                force_reload=False  # Use cached if already downloaded
            )

            # Clean up the temporary model reference
            self.model_manager.unload_models(temp_camera_id)

            print(f"✅ Model {model_id} successfully downloaded to models/{model_id}/")

        except Exception as e:
            print(f"❌ Failed to download model {model_id}: {e}")
            raise  # Re-raise for gather() to catch

    async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None:
        """Start a single subscription with given configuration and enhanced validation."""
        try:
            # Validate subscription_id
            if not subscription_id:
                raise ValueError("Empty subscription_id provided")

            # Extract camera ID from subscription identifier with enhanced validation
            parts = subscription_id.split(";")
            if len(parts) >= 2:
                camera_id = parts[1]
            else:
                # Fallback to using subscription_id as camera_id if format is unexpected
                camera_id = subscription_id
                logger.warning(f"Subscription ID format unexpected: '{subscription_id}', using as camera_id")

            # Validate camera_id
            if not camera_id or camera_id == "null" or camera_id == "None":
                raise ValueError(f"Invalid camera_id extracted from subscription_id '{subscription_id}': '{camera_id}'")

            logger.info(f"Starting subscription {subscription_id} for camera {camera_id}")
            logger.debug(f"Config keys for camera {camera_id}: {list(config.keys())}")

            # Store subscription mapping
            subscription_to_camera[subscription_id] = camera_id

            # Start camera stream with enhanced config validation
            if not config:
                raise ValueError(f"Empty config provided for camera {camera_id}")

            stream_started = await self.stream_manager.start_stream(camera_id, config)
            if not stream_started:
                raise RuntimeError(f"Failed to start stream for camera {camera_id}")

            # Load model
            model_id = config.get("modelId")
            model_url = config.get("modelUrl")

            if model_id and model_url:
                logger.info(f"Loading model {model_id} for camera {camera_id} from {model_url}")
                await self.model_manager.load_model(camera_id, model_id, model_url)
            elif model_id or model_url:
                logger.warning(f"Incomplete model config for camera {camera_id}: modelId={model_id}, modelUrl={model_url}")
            else:
                logger.info(f"No model specified for camera {camera_id}")

        except Exception as e:
            logger.error(f"Error starting subscription {subscription_id}: {e}")
            traceback.print_exc()
            raise

    async def _handle_request_state(self, data: Dict[str, Any]) -> None:
        """Handle state request message."""
        # Send a single immediate state report (the periodic heartbeat loop keeps running separately)
        await self._send_state_report()

    async def _handle_set_session(self, data: Dict[str, Any]) -> None:
        """Handle setSessionId message."""
        payload = data.get("payload", {})
        display_id = payload.get("displayIdentifier")
        session_id = payload.get("sessionId")

        if display_id and session_id:
            self.session_ids[display_id] = session_id

            # Update session for all cameras of this display
            with self.stream_manager.streams_lock:
                for camera_id, stream in self.stream_manager.streams.items():
                    if stream["subscriptionIdentifier"].startswith(display_id + ";"):
                        self.session_cache.update_session_id(camera_id, session_id)

            # Send acknowledgment
            response = {
                "type": "ack",
                "requestId": data.get("requestId", str(uuid.uuid4())),
                "code": "200",
                "data": {
                    "message": f"Session ID set for display {display_id}",
                    "sessionId": session_id
                }
            }
            # BULLETPROOF TX LOGGING for responses
            response_json = json.dumps(response, separators=(',', ':'))
            print(f"\n🟢 WEBSOCKET TX -> {response_json}")  # Print to console (always visible)
            enhanced_ws_logger.log_tx(response)
            ws_rxtx_logger.info(f"TX -> {response_json}")
            await self.websocket.send_json(response)

            logger.info(f"Set session {session_id} for display {display_id}")

    async def _handle_patch_session(self, data: Dict[str, Any]) -> None:
        """Handle patchSession message."""
        payload = data.get("payload", {})
        session_id = payload.get("sessionId")
        patch_data = payload.get("data", {})

        if session_id:
            # Store patch data (could be used for session updates)
            logger.info(f"Received patch for session {session_id}: {patch_data}")

            # Send acknowledgment
            response = {
                "type": "ack",
                "requestId": data.get("requestId", str(uuid.uuid4())),
                "code": "200",
                "data": {
                    "message": f"Session {session_id} patched successfully",
                    "sessionId": session_id,
                    "patchData": patch_data
                }
            }
            # BULLETPROOF TX LOGGING for responses
            response_json = json.dumps(response, separators=(',', ':'))
            print(f"\n🟢 WEBSOCKET TX -> {response_json}")  # Print to console (always visible)
            enhanced_ws_logger.log_tx(response)
            ws_rxtx_logger.info(f"TX -> {response_json}")
            await self.websocket.send_json(response)

    async def _handle_set_progression_stage(self, data: Dict[str, Any]) -> None:
        """Handle setProgressionStage message."""
        payload = data.get("payload", {})
        display_id = payload.get("displayIdentifier")
        progression_stage = payload.get("progressionStage")

        logger.info(f"🏁 PROGRESSION STAGE RECEIVED: displayId={display_id}, stage={progression_stage}")

        if not display_id:
            logger.warning("Missing displayIdentifier in setProgressionStage")
            return

        # Find all cameras for this display
        affected_cameras = []
        with self.stream_manager.streams_lock:
            for camera_id, stream in self.stream_manager.streams.items():
                if stream["subscriptionIdentifier"].startswith(display_id + ";"):
                    affected_cameras.append(camera_id)

        logger.debug(f"🎯 Found {len(affected_cameras)} cameras for display {display_id}: {affected_cameras}")

        # Update progression stage for each camera
        for camera_id in affected_cameras:
            pipeline_state = self.session_cache.get_or_init_session_pipeline_state(camera_id)
            current_mode = pipeline_state.get("mode", "validation_detecting")

            if progression_stage == "car_fueling":
                # Stop YOLO inference during fueling
                if current_mode == "lightweight":
                    pipeline_state["yolo_inference_enabled"] = False
                    pipeline_state["progression_stage"] = "car_fueling"
                    logger.info(f"⏸️ Camera {camera_id}: YOLO inference DISABLED for car_fueling stage")
                else:
                    logger.debug(f"📊 Camera {camera_id}: car_fueling received but not in lightweight mode (mode: {current_mode})")

            elif progression_stage == "car_waitpayment":
                # Resume YOLO inference for absence counter
                pipeline_state["yolo_inference_enabled"] = True
                pipeline_state["progression_stage"] = "car_waitpayment"
                logger.info(f"▶️ Camera {camera_id}: YOLO inference RE-ENABLED for car_waitpayment stage")

            elif progression_stage == "welcome":
                # Ignore welcome messages during car_waitpayment
                current_progression = pipeline_state.get("progression_stage")
                if current_progression == "car_waitpayment":
                    logger.info(f"🚫 Camera {camera_id}: IGNORING welcome stage (currently in car_waitpayment)")
                else:
                    pipeline_state["progression_stage"] = "welcome"
                    logger.info(f"🎉 Camera {camera_id}: Progression stage set to welcome")

            elif progression_stage in ["car_wait_staff"]:
                pipeline_state["progression_stage"] = progression_stage
                logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}")

        # Store progression stage for this display
        if display_id and progression_stage is not None:
            if progression_stage:
                self.progression_stages[display_id] = progression_stage
            else:
                # Clear progression stage if null
                self.progression_stages.pop(display_id, None)

    async def _handle_patch_session_result(self, data: Dict[str, Any]) -> None:
        """Handle patchSessionResult message from backend."""
        payload = data.get("payload", {})
        session_id = payload.get("sessionId")
        success = payload.get("success", False)
        message = payload.get("message", "")

        if success:
            logger.info(f"Patch session {session_id} successful: {message}")
        else:
            logger.warning(f"Patch session {session_id} failed: {message}")

    async def _send_detection_result(
        self,
        camera_id: str,
        stream_info: Dict[str, Any],
        detection_result: DetectionResult
    ) -> None:
        """Send detection result over WebSocket."""
        # Get session ID for this display
        subscription_id = getattr(stream_info, "subscriptionIdentifier", "")
        display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
        session_id = self.session_ids.get(display_id)

        detection_data = {
            "type": "imageDetection",
            "subscriptionIdentifier": subscription_id,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "sessionId": session_id,  # Required by protocol
            "data": {
                "detection": detection_result.to_dict(),
                "modelId": getattr(stream_info, "modelId", 0),
                "modelName": getattr(stream_info, "modelName", "Unknown Model")
            }
        }

        try:
            # BULLETPROOF TX LOGGING for detection results
            detection_json = json.dumps(detection_data, separators=(',', ':'))
            print(f"\n🟢 WEBSOCKET TX -> {detection_json}")  # Print to console (always visible)
            enhanced_ws_logger.log_tx(detection_data)
            ws_rxtx_logger.info(f"TX -> {detection_json}")
            await self.websocket.send_json(detection_data)
        except RuntimeError as e:
            if "websocket.close" in str(e):
                logger.warning(f"WebSocket closed - cannot send detection for camera {camera_id}")
            else:
                raise

    async def _send_disconnection_notification(
        self,
        camera_id: str,
        stream_info: Dict[str, Any]
    ) -> None:
        """Send camera disconnection notification."""
        logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null")

        # Clear cached data (SessionCacheManager handles this automatically via cleanup)
        # Note: clear_session method not available, cleanup happens automatically

        # Send null detection
        detection_data = {
            "type": "imageDetection",
            "subscriptionIdentifier": getattr(stream_info, "subscriptionIdentifier", ""),
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "data": {
                "detection": None,
                "modelId": getattr(stream_info, "modelId", 0),
                "modelName": getattr(stream_info, "modelName", "Unknown Model")
            }
        }

        try:
            # BULLETPROOF TX LOGGING for detection results
            detection_json = json.dumps(detection_data, separators=(',', ':'))
            print(f"\n🟢 WEBSOCKET TX -> {detection_json}")  # Print to console (always visible)
            enhanced_ws_logger.log_tx(detection_data)
            ws_rxtx_logger.info(f"TX -> {detection_json}")
            await self.websocket.send_json(detection_data)
        except RuntimeError as e:
            if "websocket.close" in str(e):
                logger.warning(f"WebSocket closed - cannot send disconnection signal for camera {camera_id}")
            else:
                raise

        self.camera_monitor.mark_disconnection_notified(camera_id)
        logger.info(f"📡 SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session")

    async def _handle_subscribe(self, data: Dict[str, Any]) -> None:
        """Handle individual subscription message - often initial null data from CMS."""
        try:
            payload = data.get("payload", {})
            subscription_id = payload.get("subscriptionIdentifier")

            print(f"📥 SUBSCRIBE MESSAGE RECEIVED - subscriptionIdentifier: '{subscription_id}'")

            # CMS often sends initial "null" subscribe messages during startup/verification
            # These should be ignored as they contain no useful data
            if not subscription_id or subscription_id == "null" or subscription_id == "None":
                print("🔍 IGNORING initial subscribe message with null/empty subscriptionIdentifier")
                print("📋 This is normal - CMS will send proper setSubscriptionList later")
                return

            # If we get a valid subscription ID, convert to setSubscriptionList format
            subscription_list_data = {
                "type": "setSubscriptionList",
                "subscriptions": [payload]
            }

            print(f"✅ Processing valid subscribe message: {subscription_id}")
            await self._handle_set_subscription_list(subscription_list_data)

        except Exception as e:
            print(f"💥 Error handling subscribe message: {e}")
            logger.error(f"Error handling subscribe: {e}")
            traceback.print_exc()

    async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
        """Handle individual unsubscription message."""
        try:
            payload = data.get("payload", {})
            subscription_id = payload.get("subscriptionIdentifier")

            if not subscription_id:
                logger.error("Unsubscribe message missing subscriptionIdentifier")
                return

            # Stop stream and clean up
            camera_id = subscription_to_camera.get(subscription_id)
            if camera_id:
                await self.stream_manager.stop_stream(camera_id)
                self.model_manager.unload_models(camera_id)
                del subscription_to_camera[subscription_id]
                logger.info(f"Unsubscribed from {subscription_id}")
            else:
                logger.warning(f"Unknown subscription ID: {subscription_id}")

        except Exception as e:
            logger.error(f"Error handling unsubscribe: {e}")
            traceback.print_exc()

    def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any:
        """Apply crop to frame if crop coordinates are specified."""
        crop_coords = [
            getattr(stream_info, "cropX1", None),
            getattr(stream_info, "cropY1", None),
            getattr(stream_info, "cropX2", None),
            getattr(stream_info, "cropY2", None)
        ]

        if all(coord is not None for coord in crop_coords):
            x1, y1, x2, y2 = crop_coords
            return frame[y1:y2, x1:x2]

        return frame


# Convenience function for backward compatibility
async def handle_websocket_connection(
    websocket: WebSocket,
    stream_manager: StreamManager,
    model_manager: ModelManager,
    pipeline_executor: PipelineExecutor,
    session_cache: SessionCacheManager,
    redis_client: Optional[RedisClientManager] = None
) -> None:
    """
    Handle a WebSocket connection using the WebSocketHandler.

    This is a convenience function that creates a handler instance
    and processes the connection.
    """
    handler = WebSocketHandler(
        stream_manager,
        model_manager,
        pipeline_executor,
        session_cache,
        redis_client
    )
    await handler.handle_connection(websocket)
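

# Example wiring into a FastAPI app (a sketch; how the manager instances are
# constructed and shared is an assumption, not shown in this module):
#
#   from fastapi import FastAPI, WebSocket
#
#   app = FastAPI()
#
#   @app.websocket("/")
#   async def websocket_endpoint(websocket: WebSocket):
#       await handle_websocket_connection(
#           websocket,
#           stream_manager,
#           model_manager,
#           pipeline_executor,
#           session_cache,
#           redis_client,
#       )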