Fix: websocket error
This commit is contained in:
parent
aacc5145d4
commit
f617025e01
4 changed files with 207 additions and 5 deletions
2
app.py
2
app.py
|
@ -102,7 +102,7 @@ app = FastAPI(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws")
|
@app.websocket("/")
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
"""Main WebSocket endpoint for real-time communication."""
|
"""Main WebSocket endpoint for real-time communication."""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -95,7 +95,7 @@ class WebSocketHandler:
|
||||||
self.display_identifiers: Set[str] = set()
|
self.display_identifiers: Set[str] = set()
|
||||||
|
|
||||||
# Camera monitor
|
# Camera monitor
|
||||||
self.camera_monitor = CameraConnectionMonitor()
|
self.camera_monitor = CameraMonitor()
|
||||||
|
|
||||||
async def handle_connection(self, websocket: WebSocket) -> None:
|
async def handle_connection(self, websocket: WebSocket) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -188,7 +188,8 @@ class ServiceContainer:
|
||||||
"""Internal service resolution."""
|
"""Internal service resolution."""
|
||||||
# Check if service is registered
|
# Check if service is registered
|
||||||
if service_type not in self._services:
|
if service_type not in self._services:
|
||||||
raise DependencyInjectionError(f"Service {service_type.__name__} is not registered")
|
service_name = getattr(service_type, '__name__', str(service_type))
|
||||||
|
raise DependencyInjectionError(f"Service {service_name} is not registered")
|
||||||
|
|
||||||
descriptor = self._services[service_type]
|
descriptor = self._services[service_type]
|
||||||
|
|
||||||
|
@ -263,15 +264,26 @@ class ServiceContainer:
|
||||||
if param.annotation != inspect.Parameter.empty:
|
if param.annotation != inspect.Parameter.empty:
|
||||||
# Try to resolve the parameter type
|
# Try to resolve the parameter type
|
||||||
try:
|
try:
|
||||||
dependency = self.resolve(param.annotation)
|
# Handle Optional[T] types by extracting the inner type
|
||||||
|
param_type = param.annotation
|
||||||
|
if hasattr(param_type, '__origin__') and param_type.__origin__ is Union:
|
||||||
|
# This is Optional[T] which is Union[T, None]
|
||||||
|
union_args = getattr(param_type, '__args__', ())
|
||||||
|
if len(union_args) == 2 and type(None) in union_args:
|
||||||
|
# Extract the non-None type
|
||||||
|
param_type = next(arg for arg in union_args if arg is not type(None))
|
||||||
|
|
||||||
|
dependency = self.resolve(param_type)
|
||||||
dependencies.append(dependency)
|
dependencies.append(dependency)
|
||||||
except DependencyInjectionError:
|
except DependencyInjectionError:
|
||||||
# If dependency cannot be resolved and has default, use default
|
# If dependency cannot be resolved and has default, use default
|
||||||
if param.default != inspect.Parameter.empty:
|
if param.default != inspect.Parameter.empty:
|
||||||
dependencies.append(param.default)
|
dependencies.append(param.default)
|
||||||
else:
|
else:
|
||||||
|
param_name = getattr(param.annotation, '__name__', str(param.annotation))
|
||||||
|
service_name = getattr(service_type, '__name__', str(service_type))
|
||||||
raise DependencyInjectionError(
|
raise DependencyInjectionError(
|
||||||
f"Cannot resolve dependency {param.annotation.__name__} for {service_type.__name__}"
|
f"Cannot resolve dependency {param_name} for {service_name}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Parameter without type annotation, use default if available
|
# Parameter without type annotation, use default if available
|
||||||
|
@ -423,6 +435,9 @@ class DetectorWorkerContainer:
|
||||||
|
|
||||||
# Register other core services
|
# Register other core services
|
||||||
self._register_detection_services()
|
self._register_detection_services()
|
||||||
|
self._register_stream_services()
|
||||||
|
self._register_model_services()
|
||||||
|
self._register_pipeline_services()
|
||||||
self._register_communication_services()
|
self._register_communication_services()
|
||||||
self._register_storage_services()
|
self._register_storage_services()
|
||||||
|
|
||||||
|
@ -443,6 +458,47 @@ class DetectorWorkerContainer:
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Could not register detection services: {e}")
|
logger.warning(f"Could not register detection services: {e}")
|
||||||
|
|
||||||
|
def _register_stream_services(self) -> None:
|
||||||
|
"""Register stream-related services."""
|
||||||
|
try:
|
||||||
|
from ..streams.stream_manager import StreamManager
|
||||||
|
from ..streams.frame_reader import RTSPFrameReader, SnapshotFrameReader
|
||||||
|
from ..streams.camera_monitor import CameraMonitor
|
||||||
|
|
||||||
|
self.container.register_singleton(StreamManager)
|
||||||
|
self.container.register_transient(RTSPFrameReader)
|
||||||
|
self.container.register_transient(SnapshotFrameReader)
|
||||||
|
self.container.register_singleton(CameraMonitor)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Could not register stream services: {e}")
|
||||||
|
|
||||||
|
def _register_model_services(self) -> None:
|
||||||
|
"""Register model-related services."""
|
||||||
|
try:
|
||||||
|
from ..models.model_manager import ModelManager
|
||||||
|
from ..models.pipeline_loader import PipelineLoader
|
||||||
|
|
||||||
|
self.container.register_singleton(ModelManager)
|
||||||
|
self.container.register_singleton(PipelineLoader)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Could not register model services: {e}")
|
||||||
|
|
||||||
|
def _register_pipeline_services(self) -> None:
|
||||||
|
"""Register pipeline-related services."""
|
||||||
|
try:
|
||||||
|
from ..pipeline.pipeline_executor import PipelineExecutor
|
||||||
|
from ..pipeline.action_executor import ActionExecutor
|
||||||
|
from ..pipeline.field_mapper import FieldMapper
|
||||||
|
|
||||||
|
self.container.register_singleton(PipelineExecutor)
|
||||||
|
self.container.register_transient(ActionExecutor)
|
||||||
|
self.container.register_singleton(FieldMapper)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Could not register pipeline services: {e}")
|
||||||
|
|
||||||
def _register_communication_services(self) -> None:
|
def _register_communication_services(self) -> None:
|
||||||
"""Register communication-related services."""
|
"""Register communication-related services."""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -100,6 +100,17 @@ class StreamManager:
|
||||||
self._subscriptions: Dict[str, StreamSubscription] = {}
|
self._subscriptions: Dict[str, StreamSubscription] = {}
|
||||||
self._lock = None
|
self._lock = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def streams(self) -> Dict[str, StreamInfo]:
|
||||||
|
"""Public access to streams dictionary."""
|
||||||
|
return self._streams
|
||||||
|
|
||||||
|
@property
|
||||||
|
def streams_lock(self):
|
||||||
|
"""Public access to streams lock."""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
return self._lock
|
||||||
|
|
||||||
def _ensure_thread_safety(self):
|
def _ensure_thread_safety(self):
|
||||||
"""Initialize thread safety if not already done."""
|
"""Initialize thread safety if not already done."""
|
||||||
if self._lock is None:
|
if self._lock is None:
|
||||||
|
@ -526,6 +537,141 @@ class StreamManager:
|
||||||
|
|
||||||
logger.info("All streams shut down successfully")
|
logger.info("All streams shut down successfully")
|
||||||
|
|
||||||
|
# ===== WEBSOCKET HANDLER COMPATIBILITY METHODS =====
|
||||||
|
# These methods provide compatibility with the WebSocketHandler interface
|
||||||
|
|
||||||
|
def get_active_streams(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get all active streams for WebSocket handler compatibility.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of active streams with their information
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
with self._lock:
|
||||||
|
active_streams = {}
|
||||||
|
for camera_id, stream_info in self._streams.items():
|
||||||
|
if stream_info.thread and stream_info.thread.is_alive():
|
||||||
|
active_streams[camera_id] = {
|
||||||
|
'camera_id': camera_id,
|
||||||
|
'status': 'active',
|
||||||
|
'stream_type': stream_info.stream_type,
|
||||||
|
'subscribers': len([sub_id for sub_id, sub in self._subscriptions.items()
|
||||||
|
if sub.camera_id == camera_id]),
|
||||||
|
'last_frame_time': getattr(stream_info, 'last_frame_time', None),
|
||||||
|
'error_count': getattr(stream_info, 'error_count', 0)
|
||||||
|
}
|
||||||
|
return active_streams
|
||||||
|
|
||||||
|
def get_latest_frame(self, camera_id: str) -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
Get the latest frame for a camera for WebSocket handler compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Latest frame data or None if not available
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
with self._lock:
|
||||||
|
stream_info = self._streams.get(camera_id)
|
||||||
|
if stream_info and hasattr(stream_info, 'latest_frame'):
|
||||||
|
return stream_info.latest_frame
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def cleanup_all_streams(self) -> None:
|
||||||
|
"""
|
||||||
|
Cleanup all streams asynchronously for WebSocket handler compatibility.
|
||||||
|
"""
|
||||||
|
# This is an async wrapper around shutdown_all for compatibility
|
||||||
|
self.shutdown_all()
|
||||||
|
|
||||||
|
async def start_stream(self, camera_id: str, payload: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Start a stream for WebSocket handler compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Camera identifier
|
||||||
|
payload: Stream configuration payload
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if stream started successfully, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Create a subscription ID for this stream
|
||||||
|
subscription_id = f"ws_{camera_id}_{int(time.time() * 1000)}"
|
||||||
|
|
||||||
|
# Extract stream parameters from payload
|
||||||
|
rtsp_url = payload.get('rtspUrl')
|
||||||
|
snapshot_url = payload.get('snapshotUrl')
|
||||||
|
snapshot_interval = payload.get('snapshotInterval', 5000)
|
||||||
|
|
||||||
|
# Create subscription based on available URL type
|
||||||
|
if rtsp_url:
|
||||||
|
success = self.create_subscription(
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
camera_id=camera_id,
|
||||||
|
rtsp_url=rtsp_url
|
||||||
|
)
|
||||||
|
elif snapshot_url:
|
||||||
|
success = self.create_subscription(
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
camera_id=camera_id,
|
||||||
|
snapshot_url=snapshot_url,
|
||||||
|
snapshot_interval_ms=snapshot_interval
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"No valid stream URL provided for camera {camera_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info(f"Started stream for camera {camera_id} with subscription {subscription_id}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to start stream for camera {camera_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error starting stream for camera {camera_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def stop_stream(self, camera_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Stop a stream for WebSocket handler compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if stream stopped successfully, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Find and remove subscriptions for this camera
|
||||||
|
subscriptions_to_remove = [
|
||||||
|
sub_id for sub_id, sub in self._subscriptions.items()
|
||||||
|
if sub.camera_id == camera_id
|
||||||
|
]
|
||||||
|
|
||||||
|
success = True
|
||||||
|
for sub_id in subscriptions_to_remove:
|
||||||
|
if not self.remove_subscription(sub_id):
|
||||||
|
success = False
|
||||||
|
|
||||||
|
if success and subscriptions_to_remove:
|
||||||
|
logger.info(f"Stopped stream for camera {camera_id}")
|
||||||
|
return True
|
||||||
|
elif not subscriptions_to_remove:
|
||||||
|
logger.warning(f"No active subscriptions found for camera {camera_id}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to stop some subscriptions for camera {camera_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error stopping stream for camera {camera_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# Global stream manager instance
|
# Global stream manager instance
|
||||||
stream_manager = StreamManager()
|
stream_manager = StreamManager()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue