Fix: several bug fixes

This commit is contained in:
ziesorx 2025-09-12 22:44:27 +07:00
parent 96ecc321ec
commit 9967bff6dc
4 changed files with 112 additions and 31 deletions

View file

@ -83,6 +83,8 @@ class WebSocketHandler:
# Message handlers # Message handlers
self.message_handlers: Dict[str, MessageHandler] = { self.message_handlers: Dict[str, MessageHandler] = {
"subscribe": self._handle_subscribe,
"unsubscribe": self._handle_unsubscribe,
"setSubscriptionList": self._handle_set_subscription_list, "setSubscriptionList": self._handle_set_subscription_list,
"requestState": self._handle_request_state, "requestState": self._handle_request_state,
"setSessionId": self._handle_set_session, "setSessionId": self._handle_set_session,
@ -180,21 +182,21 @@ class WebSocketHandler:
is_online = self.stream_manager.is_stream_active(camera_id) is_online = self.stream_manager.is_stream_active(camera_id)
connection_info = { connection_info = {
"subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id), "subscriptionIdentifier": getattr(stream_info, "subscriptionIdentifier", camera_id),
"modelId": stream_info.get("modelId", 0), "modelId": getattr(stream_info, "modelId", 0),
"modelName": stream_info.get("modelName", "Unknown Model"), "modelName": getattr(stream_info, "modelName", "Unknown Model"),
"online": is_online "online": is_online
} }
# Add crop coordinates if available # Add crop coordinates if available
if "cropX1" in stream_info: if hasattr(stream_info, "cropX1"):
connection_info["cropX1"] = stream_info["cropX1"] connection_info["cropX1"] = stream_info.cropX1
if "cropY1" in stream_info: if hasattr(stream_info, "cropY1"):
connection_info["cropY1"] = stream_info["cropY1"] connection_info["cropY1"] = stream_info.cropY1
if "cropX2" in stream_info: if hasattr(stream_info, "cropX2"):
connection_info["cropX2"] = stream_info["cropX2"] connection_info["cropX2"] = stream_info.cropX2
if "cropY2" in stream_info: if hasattr(stream_info, "cropY2"):
connection_info["cropY2"] = stream_info["cropY2"] connection_info["cropY2"] = stream_info.cropY2
camera_connections.append(connection_info) camera_connections.append(connection_info)
@ -269,7 +271,7 @@ class WebSocketHandler:
continue continue
# Get model for this camera # Get model for this camera
model_id = stream_info.get("modelId") model_id = getattr(stream_info, "modelId", None)
if not model_id: if not model_id:
continue continue
@ -413,7 +415,8 @@ class WebSocketHandler:
await self.stream_manager.stop_stream(camera_id) await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id) self.model_manager.unload_models(camera_id)
subscription_to_camera.pop(sub_id, None) subscription_to_camera.pop(sub_id, None)
self.session_cache.clear_session(camera_id) # Clear cached data (SessionCacheManager handles this automatically)
# Note: clear_session method not available, cleanup happens automatically
logger.info(f"Removed subscription: {sub_id}") logger.info(f"Removed subscription: {sub_id}")
# Add new subscriptions # Add new subscriptions
@ -612,7 +615,7 @@ class WebSocketHandler:
) -> None: ) -> None:
"""Send detection result over WebSocket.""" """Send detection result over WebSocket."""
# Get session ID for this display # Get session ID for this display
subscription_id = stream_info["subscriptionIdentifier"] subscription_id = getattr(stream_info, "subscriptionIdentifier", "")
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
session_id = self.session_ids.get(display_id) session_id = self.session_ids.get(display_id)
@ -623,8 +626,8 @@ class WebSocketHandler:
"sessionId": session_id, # Required by protocol "sessionId": session_id, # Required by protocol
"data": { "data": {
"detection": detection_result.to_dict(), "detection": detection_result.to_dict(),
"modelId": stream_info["modelId"], "modelId": getattr(stream_info, "modelId", 0),
"modelName": stream_info["modelName"] "modelName": getattr(stream_info, "modelName", "Unknown Model")
} }
} }
@ -645,18 +648,18 @@ class WebSocketHandler:
"""Send camera disconnection notification.""" """Send camera disconnection notification."""
logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null") logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null")
# Clear cached data # Clear cached data (SessionCacheManager handles this automatically via cleanup)
self.session_cache.clear_session(camera_id) # Note: clear_session method not available, cleanup happens automatically
# Send null detection # Send null detection
detection_data = { detection_data = {
"type": "imageDetection", "type": "imageDetection",
"subscriptionIdentifier": stream_info["subscriptionIdentifier"], "subscriptionIdentifier": getattr(stream_info, "subscriptionIdentifier", ""),
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"data": { "data": {
"detection": None, "detection": None,
"modelId": stream_info["modelId"], "modelId": getattr(stream_info, "modelId", 0),
"modelName": stream_info["modelName"] "modelName": getattr(stream_info, "modelName", "Unknown Model")
} }
} }
@ -672,13 +675,60 @@ class WebSocketHandler:
self.camera_monitor.mark_disconnection_notified(camera_id) self.camera_monitor.mark_disconnection_notified(camera_id)
logger.info(f"📡 SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session") logger.info(f"📡 SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session")
async def _handle_subscribe(self, data: Dict[str, Any]) -> None:
"""Handle individual subscription message."""
try:
payload = data.get("payload", {})
subscription_id = payload.get("subscriptionIdentifier")
if not subscription_id:
logger.error("Subscribe message missing subscriptionIdentifier")
return
# Convert single subscription to setSubscriptionList format
subscription_list_data = {
"type": "setSubscriptionList",
"subscriptions": [payload]
}
# Delegate to existing setSubscriptionList handler
await self._handle_set_subscription_list(subscription_list_data)
except Exception as e:
logger.error(f"Error handling subscribe: {e}")
traceback.print_exc()
async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
"""Handle individual unsubscription message."""
try:
payload = data.get("payload", {})
subscription_id = payload.get("subscriptionIdentifier")
if not subscription_id:
logger.error("Unsubscribe message missing subscriptionIdentifier")
return
# Stop stream and clean up
camera_id = subscription_to_camera.get(subscription_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
del subscription_to_camera[subscription_id]
logger.info(f"Unsubscribed from {subscription_id}")
else:
logger.warning(f"Unknown subscription ID: {subscription_id}")
except Exception as e:
logger.error(f"Error handling unsubscribe: {e}")
traceback.print_exc()
def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any: def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any:
"""Apply crop to frame if crop coordinates are specified.""" """Apply crop to frame if crop coordinates are specified."""
crop_coords = [ crop_coords = [
stream_info.get("cropX1"), getattr(stream_info, "cropX1", None),
stream_info.get("cropY1"), getattr(stream_info, "cropY1", None),
stream_info.get("cropX2"), getattr(stream_info, "cropX2", None),
stream_info.get("cropY2") getattr(stream_info, "cropY2", None)
] ]
if all(coord is not None for coord in crop_coords): if all(coord is not None for coord in crop_coords):

View file

@ -7,13 +7,16 @@ for the detection worker.
import os import os
import logging import logging
import threading import threading
from typing import Dict, Any, Optional, List, Set, Tuple from typing import Dict, Any, Optional, List, Set, Tuple, TYPE_CHECKING
from urllib.parse import urlparse from urllib.parse import urlparse
import traceback import traceback
from ..core.config import MODELS_DIR from ..core.config import MODELS_DIR
from ..core.exceptions import ModelLoadError from ..core.exceptions import ModelLoadError
if TYPE_CHECKING:
from .pipeline_loader import PipelineLoader
# Setup logging # Setup logging
logger = logging.getLogger("detector_worker.model_manager") logger = logging.getLogger("detector_worker.model_manager")
@ -110,11 +113,12 @@ class ModelManager:
- Pipeline model tree management - Pipeline model tree management
""" """
def __init__(self, models_dir: str = MODELS_DIR): def __init__(self, pipeline_loader: Optional['PipelineLoader'] = None, models_dir: str = MODELS_DIR):
""" """
Initialize the model manager. Initialize the model manager.
Args: Args:
pipeline_loader: Pipeline loader for handling MPTA archives (injected via DI)
models_dir: Directory to cache downloaded models models_dir: Directory to cache downloaded models
""" """
self.models_dir = models_dir self.models_dir = models_dir
@ -124,8 +128,19 @@ class ModelManager:
# Camera to models mapping # Camera to models mapping
self.camera_models: Dict[str, Dict[str, Any]] = {} # camera_id -> {model_id -> model_tree} self.camera_models: Dict[str, Dict[str, Any]] = {} # camera_id -> {model_id -> model_tree}
# Pipeline loader will be injected # Pipeline loader injected via dependency injection
self.pipeline_loader = None self.pipeline_loader = pipeline_loader
# If pipeline_loader is None, try to resolve it from the container
if self.pipeline_loader is None:
try:
from ..core.dependency_injection import get_container
from .pipeline_loader import PipelineLoader
container = get_container()
self.pipeline_loader = container.resolve(PipelineLoader)
logger.info("PipelineLoader resolved from dependency container")
except Exception as e:
logger.warning(f"Could not resolve PipelineLoader from container: {e}")
# Create models directory if it doesn't exist # Create models directory if it doesn't exist
os.makedirs(self.models_dir, exist_ok=True) os.makedirs(self.models_dir, exist_ok=True)

View file

@ -41,6 +41,15 @@ class StreamInfo:
last_frame_time: Optional[float] = None last_frame_time: Optional[float] = None
frame_count: int = 0 frame_count: int = 0
# Additional WebSocket fields
subscriptionIdentifier: Optional[str] = None
modelId: Optional[int] = None
modelName: Optional[str] = None
cropX1: Optional[int] = None
cropY1: Optional[int] = None
cropX2: Optional[int] = None
cropY2: Optional[int] = None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format.""" """Convert to dictionary format."""
return { return {
@ -607,19 +616,24 @@ class StreamManager:
snapshot_url = payload.get('snapshotUrl') snapshot_url = payload.get('snapshotUrl')
snapshot_interval = payload.get('snapshotInterval', 5000) snapshot_interval = payload.get('snapshotInterval', 5000)
# Create a subscriber_id (for WebSocket compatibility, use the subscription_id)
subscriber_id = f"websocket_{int(time.time() * 1000)}"
# Create subscription based on available URL type # Create subscription based on available URL type
if rtsp_url: if rtsp_url:
success = self.create_subscription( success = self.create_subscription(
subscription_id=subscription_id, subscription_id=subscription_id,
camera_id=camera_id, camera_id=camera_id,
subscriber_id=subscriber_id,
rtsp_url=rtsp_url rtsp_url=rtsp_url
) )
elif snapshot_url: elif snapshot_url:
success = self.create_subscription( success = self.create_subscription(
subscription_id=subscription_id, subscription_id=subscription_id,
camera_id=camera_id, camera_id=camera_id,
subscriber_id=subscriber_id,
snapshot_url=snapshot_url, snapshot_url=snapshot_url,
snapshot_interval_ms=snapshot_interval snapshot_interval=snapshot_interval
) )
else: else:
logger.error(f"No valid stream URL provided for camera {camera_id}") logger.error(f"No valid stream URL provided for camera {camera_id}")

View file

@ -2,4 +2,6 @@ fastapi[standard]
uvicorn uvicorn
websockets websockets
redis redis
urllib3<2.0.0 urllib3<2.0.0
aiohttp
aiofiles