From f617025e01b18938f40c3e6aed9ad7eed30d225c Mon Sep 17 00:00:00 2001 From: ziesorx Date: Fri, 12 Sep 2025 20:25:18 +0700 Subject: [PATCH] Fix: websocket error --- app.py | 2 +- .../communication/websocket_handler.py | 2 +- detector_worker/core/dependency_injection.py | 62 +++++++- detector_worker/streams/stream_manager.py | 146 ++++++++++++++++++ 4 files changed, 207 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index ac814c6..d9a571b 100644 --- a/app.py +++ b/app.py @@ -102,7 +102,7 @@ app = FastAPI( ) -@app.websocket("/ws") +@app.websocket("/") async def websocket_endpoint(websocket: WebSocket): """Main WebSocket endpoint for real-time communication.""" try: diff --git a/detector_worker/communication/websocket_handler.py b/detector_worker/communication/websocket_handler.py index 2707bea..da05d6a 100644 --- a/detector_worker/communication/websocket_handler.py +++ b/detector_worker/communication/websocket_handler.py @@ -95,7 +95,7 @@ class WebSocketHandler: self.display_identifiers: Set[str] = set() # Camera monitor - self.camera_monitor = CameraConnectionMonitor() + self.camera_monitor = CameraMonitor() async def handle_connection(self, websocket: WebSocket) -> None: """ diff --git a/detector_worker/core/dependency_injection.py b/detector_worker/core/dependency_injection.py index 60a5537..7ccf5c0 100644 --- a/detector_worker/core/dependency_injection.py +++ b/detector_worker/core/dependency_injection.py @@ -188,7 +188,8 @@ class ServiceContainer: """Internal service resolution.""" # Check if service is registered if service_type not in self._services: - raise DependencyInjectionError(f"Service {service_type.__name__} is not registered") + service_name = getattr(service_type, '__name__', str(service_type)) + raise DependencyInjectionError(f"Service {service_name} is not registered") descriptor = self._services[service_type] @@ -263,15 +264,26 @@ class ServiceContainer: if param.annotation != inspect.Parameter.empty: # Try to resolve the parameter type try: - dependency = self.resolve(param.annotation) + # Handle Optional[T] types by extracting the inner type + param_type = param.annotation + if hasattr(param_type, '__origin__') and param_type.__origin__ is Union: + # This is Optional[T] which is Union[T, None] + union_args = getattr(param_type, '__args__', ()) + if len(union_args) == 2 and type(None) in union_args: + # Extract the non-None type + param_type = next(arg for arg in union_args if arg is not type(None)) + + dependency = self.resolve(param_type) dependencies.append(dependency) except DependencyInjectionError: # If dependency cannot be resolved and has default, use default if param.default != inspect.Parameter.empty: dependencies.append(param.default) else: + param_name = getattr(param.annotation, '__name__', str(param.annotation)) + service_name = getattr(service_type, '__name__', str(service_type)) raise DependencyInjectionError( - f"Cannot resolve dependency {param.annotation.__name__} for {service_type.__name__}" + f"Cannot resolve dependency {param_name} for {service_name}" ) else: # Parameter without type annotation, use default if available @@ -423,6 +435,9 @@ class DetectorWorkerContainer: # Register other core services self._register_detection_services() + self._register_stream_services() + self._register_model_services() + self._register_pipeline_services() self._register_communication_services() self._register_storage_services() @@ -443,6 +458,47 @@ class DetectorWorkerContainer: except ImportError as e: logger.warning(f"Could not register detection services: {e}") + def _register_stream_services(self) -> None: + """Register stream-related services.""" + try: + from ..streams.stream_manager import StreamManager + from ..streams.frame_reader import RTSPFrameReader, SnapshotFrameReader + from ..streams.camera_monitor import CameraMonitor + + self.container.register_singleton(StreamManager) + self.container.register_transient(RTSPFrameReader) + self.container.register_transient(SnapshotFrameReader) + self.container.register_singleton(CameraMonitor) + + except ImportError as e: + logger.warning(f"Could not register stream services: {e}") + + def _register_model_services(self) -> None: + """Register model-related services.""" + try: + from ..models.model_manager import ModelManager + from ..models.pipeline_loader import PipelineLoader + + self.container.register_singleton(ModelManager) + self.container.register_singleton(PipelineLoader) + + except ImportError as e: + logger.warning(f"Could not register model services: {e}") + + def _register_pipeline_services(self) -> None: + """Register pipeline-related services.""" + try: + from ..pipeline.pipeline_executor import PipelineExecutor + from ..pipeline.action_executor import ActionExecutor + from ..pipeline.field_mapper import FieldMapper + + self.container.register_singleton(PipelineExecutor) + self.container.register_transient(ActionExecutor) + self.container.register_singleton(FieldMapper) + + except ImportError as e: + logger.warning(f"Could not register pipeline services: {e}") + def _register_communication_services(self) -> None: """Register communication-related services.""" try: diff --git a/detector_worker/streams/stream_manager.py b/detector_worker/streams/stream_manager.py index c6d2155..d0a2dbf 100644 --- a/detector_worker/streams/stream_manager.py +++ b/detector_worker/streams/stream_manager.py @@ -100,6 +100,17 @@ class StreamManager: self._subscriptions: Dict[str, StreamSubscription] = {} self._lock = None + @property + def streams(self) -> Dict[str, StreamInfo]: + """Public access to streams dictionary.""" + return self._streams + + @property + def streams_lock(self): + """Public access to streams lock.""" + self._ensure_thread_safety() + return self._lock + def _ensure_thread_safety(self): """Initialize thread safety if not already done.""" if self._lock is None: @@ -525,6 +536,141 @@ class StreamManager: self._subscriptions.clear() logger.info("All streams shut down successfully") + + # ===== WEBSOCKET HANDLER COMPATIBILITY METHODS ===== + # These methods provide compatibility with the WebSocketHandler interface + + def get_active_streams(self) -> Dict[str, Any]: + """ + Get all active streams for WebSocket handler compatibility. + + Returns: + Dictionary of active streams with their information + """ + self._ensure_thread_safety() + with self._lock: + active_streams = {} + for camera_id, stream_info in self._streams.items(): + if stream_info.thread and stream_info.thread.is_alive(): + active_streams[camera_id] = { + 'camera_id': camera_id, + 'status': 'active', + 'stream_type': stream_info.stream_type, + 'subscribers': len([sub_id for sub_id, sub in self._subscriptions.items() + if sub.camera_id == camera_id]), + 'last_frame_time': getattr(stream_info, 'last_frame_time', None), + 'error_count': getattr(stream_info, 'error_count', 0) + } + return active_streams + + def get_latest_frame(self, camera_id: str) -> Optional[Any]: + """ + Get the latest frame for a camera for WebSocket handler compatibility. + + Args: + camera_id: Camera identifier + + Returns: + Latest frame data or None if not available + """ + self._ensure_thread_safety() + with self._lock: + stream_info = self._streams.get(camera_id) + if stream_info and hasattr(stream_info, 'latest_frame'): + return stream_info.latest_frame + return None + + async def cleanup_all_streams(self) -> None: + """ + Cleanup all streams asynchronously for WebSocket handler compatibility. + """ + # This is an async wrapper around shutdown_all for compatibility + self.shutdown_all() + + async def start_stream(self, camera_id: str, payload: Dict[str, Any]) -> bool: + """ + Start a stream for WebSocket handler compatibility. + + Args: + camera_id: Camera identifier + payload: Stream configuration payload + + Returns: + True if stream started successfully, False otherwise + """ + try: + # Create a subscription ID for this stream + subscription_id = f"ws_{camera_id}_{int(time.time() * 1000)}" + + # Extract stream parameters from payload + rtsp_url = payload.get('rtspUrl') + snapshot_url = payload.get('snapshotUrl') + snapshot_interval = payload.get('snapshotInterval', 5000) + + # Create subscription based on available URL type + if rtsp_url: + success = self.create_subscription( + subscription_id=subscription_id, + camera_id=camera_id, + rtsp_url=rtsp_url + ) + elif snapshot_url: + success = self.create_subscription( + subscription_id=subscription_id, + camera_id=camera_id, + snapshot_url=snapshot_url, + snapshot_interval_ms=snapshot_interval + ) + else: + logger.error(f"No valid stream URL provided for camera {camera_id}") + return False + + if success: + logger.info(f"Started stream for camera {camera_id} with subscription {subscription_id}") + return True + else: + logger.error(f"Failed to start stream for camera {camera_id}") + return False + + except Exception as e: + logger.error(f"Error starting stream for camera {camera_id}: {e}") + return False + + async def stop_stream(self, camera_id: str) -> bool: + """ + Stop a stream for WebSocket handler compatibility. + + Args: + camera_id: Camera identifier + + Returns: + True if stream stopped successfully, False otherwise + """ + try: + # Find and remove subscriptions for this camera + subscriptions_to_remove = [ + sub_id for sub_id, sub in self._subscriptions.items() + if sub.camera_id == camera_id + ] + + success = True + for sub_id in subscriptions_to_remove: + if not self.remove_subscription(sub_id): + success = False + + if success and subscriptions_to_remove: + logger.info(f"Stopped stream for camera {camera_id}") + return True + elif not subscriptions_to_remove: + logger.warning(f"No active subscriptions found for camera {camera_id}") + return True + else: + logger.error(f"Failed to stop some subscriptions for camera {camera_id}") + return False + + except Exception as e: + logger.error(f"Error stopping stream for camera {camera_id}: {e}") + return False # Global stream manager instance