diff --git a/detector_worker/communication/websocket_handler.py b/detector_worker/communication/websocket_handler.py index bdfcad8..697db97 100644 --- a/detector_worker/communication/websocket_handler.py +++ b/detector_worker/communication/websocket_handler.py @@ -83,6 +83,8 @@ class WebSocketHandler: # Message handlers self.message_handlers: Dict[str, MessageHandler] = { + "subscribe": self._handle_subscribe, + "unsubscribe": self._handle_unsubscribe, "setSubscriptionList": self._handle_set_subscription_list, "requestState": self._handle_request_state, "setSessionId": self._handle_set_session, @@ -180,21 +182,21 @@ class WebSocketHandler: is_online = self.stream_manager.is_stream_active(camera_id) connection_info = { - "subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id), - "modelId": stream_info.get("modelId", 0), - "modelName": stream_info.get("modelName", "Unknown Model"), + "subscriptionIdentifier": getattr(stream_info, "subscriptionIdentifier", camera_id), + "modelId": getattr(stream_info, "modelId", 0), + "modelName": getattr(stream_info, "modelName", "Unknown Model"), "online": is_online } # Add crop coordinates if available - if "cropX1" in stream_info: - connection_info["cropX1"] = stream_info["cropX1"] - if "cropY1" in stream_info: - connection_info["cropY1"] = stream_info["cropY1"] - if "cropX2" in stream_info: - connection_info["cropX2"] = stream_info["cropX2"] - if "cropY2" in stream_info: - connection_info["cropY2"] = stream_info["cropY2"] + if hasattr(stream_info, "cropX1"): + connection_info["cropX1"] = stream_info.cropX1 + if hasattr(stream_info, "cropY1"): + connection_info["cropY1"] = stream_info.cropY1 + if hasattr(stream_info, "cropX2"): + connection_info["cropX2"] = stream_info.cropX2 + if hasattr(stream_info, "cropY2"): + connection_info["cropY2"] = stream_info.cropY2 camera_connections.append(connection_info) @@ -269,7 +271,7 @@ class WebSocketHandler: continue # Get model for this camera - model_id = stream_info.get("modelId") + model_id = getattr(stream_info, "modelId", None) if not model_id: continue @@ -413,7 +415,8 @@ class WebSocketHandler: await self.stream_manager.stop_stream(camera_id) self.model_manager.unload_models(camera_id) subscription_to_camera.pop(sub_id, None) - self.session_cache.clear_session(camera_id) + # Clear cached data (SessionCacheManager handles this automatically) + # Note: clear_session method not available, cleanup happens automatically logger.info(f"Removed subscription: {sub_id}") # Add new subscriptions @@ -612,7 +615,7 @@ class WebSocketHandler: ) -> None: """Send detection result over WebSocket.""" # Get session ID for this display - subscription_id = stream_info["subscriptionIdentifier"] + subscription_id = getattr(stream_info, "subscriptionIdentifier", "") display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id session_id = self.session_ids.get(display_id) @@ -623,8 +626,8 @@ class WebSocketHandler: "sessionId": session_id, # Required by protocol "data": { "detection": detection_result.to_dict(), - "modelId": stream_info["modelId"], - "modelName": stream_info["modelName"] + "modelId": getattr(stream_info, "modelId", 0), + "modelName": getattr(stream_info, "modelName", "Unknown Model") } } @@ -645,18 +648,18 @@ class WebSocketHandler: """Send camera disconnection notification.""" logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null") - # Clear cached data - self.session_cache.clear_session(camera_id) + # Clear cached data (SessionCacheManager handles this automatically via cleanup) + # Note: clear_session method not available, cleanup happens automatically # Send null detection detection_data = { "type": "imageDetection", - "subscriptionIdentifier": stream_info["subscriptionIdentifier"], + "subscriptionIdentifier": getattr(stream_info, "subscriptionIdentifier", ""), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "data": { "detection": None, - "modelId": stream_info["modelId"], - "modelName": stream_info["modelName"] + "modelId": getattr(stream_info, "modelId", 0), + "modelName": getattr(stream_info, "modelName", "Unknown Model") } } @@ -672,13 +675,60 @@ class WebSocketHandler: self.camera_monitor.mark_disconnection_notified(camera_id) logger.info(f"📡 SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session") + async def _handle_subscribe(self, data: Dict[str, Any]) -> None: + """Handle individual subscription message.""" + try: + payload = data.get("payload", {}) + subscription_id = payload.get("subscriptionIdentifier") + + if not subscription_id: + logger.error("Subscribe message missing subscriptionIdentifier") + return + + # Convert single subscription to setSubscriptionList format + subscription_list_data = { + "type": "setSubscriptionList", + "subscriptions": [payload] + } + + # Delegate to existing setSubscriptionList handler + await self._handle_set_subscription_list(subscription_list_data) + + except Exception as e: + logger.error(f"Error handling subscribe: {e}") + traceback.print_exc() + + async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None: + """Handle individual unsubscription message.""" + try: + payload = data.get("payload", {}) + subscription_id = payload.get("subscriptionIdentifier") + + if not subscription_id: + logger.error("Unsubscribe message missing subscriptionIdentifier") + return + + # Stop stream and clean up + camera_id = subscription_to_camera.get(subscription_id) + if camera_id: + await self.stream_manager.stop_stream(camera_id) + self.model_manager.unload_models(camera_id) + del subscription_to_camera[subscription_id] + logger.info(f"Unsubscribed from {subscription_id}") + else: + logger.warning(f"Unknown subscription ID: {subscription_id}") + + except Exception as e: + logger.error(f"Error handling unsubscribe: {e}") + traceback.print_exc() + def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any: """Apply crop to frame if crop coordinates are specified.""" crop_coords = [ - stream_info.get("cropX1"), - stream_info.get("cropY1"), - stream_info.get("cropX2"), - stream_info.get("cropY2") + getattr(stream_info, "cropX1", None), + getattr(stream_info, "cropY1", None), + getattr(stream_info, "cropX2", None), + getattr(stream_info, "cropY2", None) ] if all(coord is not None for coord in crop_coords): diff --git a/detector_worker/models/model_manager.py b/detector_worker/models/model_manager.py index 3de3b97..97f41f0 100644 --- a/detector_worker/models/model_manager.py +++ b/detector_worker/models/model_manager.py @@ -7,13 +7,16 @@ for the detection worker. import os import logging import threading -from typing import Dict, Any, Optional, List, Set, Tuple +from typing import Dict, Any, Optional, List, Set, Tuple, TYPE_CHECKING from urllib.parse import urlparse import traceback from ..core.config import MODELS_DIR from ..core.exceptions import ModelLoadError +if TYPE_CHECKING: + from .pipeline_loader import PipelineLoader + # Setup logging logger = logging.getLogger("detector_worker.model_manager") @@ -110,11 +113,12 @@ class ModelManager: - Pipeline model tree management """ - def __init__(self, models_dir: str = MODELS_DIR): + def __init__(self, pipeline_loader: Optional['PipelineLoader'] = None, models_dir: str = MODELS_DIR): """ Initialize the model manager. Args: + pipeline_loader: Pipeline loader for handling MPTA archives (injected via DI) models_dir: Directory to cache downloaded models """ self.models_dir = models_dir @@ -124,8 +128,19 @@ class ModelManager: # Camera to models mapping self.camera_models: Dict[str, Dict[str, Any]] = {} # camera_id -> {model_id -> model_tree} - # Pipeline loader will be injected - self.pipeline_loader = None + # Pipeline loader injected via dependency injection + self.pipeline_loader = pipeline_loader + + # If pipeline_loader is None, try to resolve it from the container + if self.pipeline_loader is None: + try: + from ..core.dependency_injection import get_container + from .pipeline_loader import PipelineLoader + container = get_container() + self.pipeline_loader = container.resolve(PipelineLoader) + logger.info("PipelineLoader resolved from dependency container") + except Exception as e: + logger.warning(f"Could not resolve PipelineLoader from container: {e}") # Create models directory if it doesn't exist os.makedirs(self.models_dir, exist_ok=True) diff --git a/detector_worker/streams/stream_manager.py b/detector_worker/streams/stream_manager.py index d0a2dbf..5610b43 100644 --- a/detector_worker/streams/stream_manager.py +++ b/detector_worker/streams/stream_manager.py @@ -41,6 +41,15 @@ class StreamInfo: last_frame_time: Optional[float] = None frame_count: int = 0 + # Additional WebSocket fields + subscriptionIdentifier: Optional[str] = None + modelId: Optional[int] = None + modelName: Optional[str] = None + cropX1: Optional[int] = None + cropY1: Optional[int] = None + cropX2: Optional[int] = None + cropY2: Optional[int] = None + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary format.""" return { @@ -607,19 +616,24 @@ class StreamManager: snapshot_url = payload.get('snapshotUrl') snapshot_interval = payload.get('snapshotInterval', 5000) + # Create a subscriber_id (for WebSocket compatibility, use the subscription_id) + subscriber_id = f"websocket_{int(time.time() * 1000)}" + # Create subscription based on available URL type if rtsp_url: success = self.create_subscription( subscription_id=subscription_id, camera_id=camera_id, + subscriber_id=subscriber_id, rtsp_url=rtsp_url ) elif snapshot_url: success = self.create_subscription( subscription_id=subscription_id, camera_id=camera_id, + subscriber_id=subscriber_id, snapshot_url=snapshot_url, - snapshot_interval_ms=snapshot_interval + snapshot_interval=snapshot_interval ) else: logger.error(f"No valid stream URL provided for camera {camera_id}") diff --git a/requirements.txt b/requirements.txt index baddeb5..6578af6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,6 @@ fastapi[standard] uvicorn websockets redis -urllib3<2.0.0 \ No newline at end of file +urllib3<2.0.0 +aiohttp +aiofiles \ No newline at end of file