Fix: several bug fixes
This commit is contained in:
parent
96ecc321ec
commit
9967bff6dc
4 changed files with 112 additions and 31 deletions
|
@ -83,6 +83,8 @@ class WebSocketHandler:
|
||||||
|
|
||||||
# Message handlers
|
# Message handlers
|
||||||
self.message_handlers: Dict[str, MessageHandler] = {
|
self.message_handlers: Dict[str, MessageHandler] = {
|
||||||
|
"subscribe": self._handle_subscribe,
|
||||||
|
"unsubscribe": self._handle_unsubscribe,
|
||||||
"setSubscriptionList": self._handle_set_subscription_list,
|
"setSubscriptionList": self._handle_set_subscription_list,
|
||||||
"requestState": self._handle_request_state,
|
"requestState": self._handle_request_state,
|
||||||
"setSessionId": self._handle_set_session,
|
"setSessionId": self._handle_set_session,
|
||||||
|
@ -180,21 +182,21 @@ class WebSocketHandler:
|
||||||
is_online = self.stream_manager.is_stream_active(camera_id)
|
is_online = self.stream_manager.is_stream_active(camera_id)
|
||||||
|
|
||||||
connection_info = {
|
connection_info = {
|
||||||
"subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
|
"subscriptionIdentifier": getattr(stream_info, "subscriptionIdentifier", camera_id),
|
||||||
"modelId": stream_info.get("modelId", 0),
|
"modelId": getattr(stream_info, "modelId", 0),
|
||||||
"modelName": stream_info.get("modelName", "Unknown Model"),
|
"modelName": getattr(stream_info, "modelName", "Unknown Model"),
|
||||||
"online": is_online
|
"online": is_online
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add crop coordinates if available
|
# Add crop coordinates if available
|
||||||
if "cropX1" in stream_info:
|
if hasattr(stream_info, "cropX1"):
|
||||||
connection_info["cropX1"] = stream_info["cropX1"]
|
connection_info["cropX1"] = stream_info.cropX1
|
||||||
if "cropY1" in stream_info:
|
if hasattr(stream_info, "cropY1"):
|
||||||
connection_info["cropY1"] = stream_info["cropY1"]
|
connection_info["cropY1"] = stream_info.cropY1
|
||||||
if "cropX2" in stream_info:
|
if hasattr(stream_info, "cropX2"):
|
||||||
connection_info["cropX2"] = stream_info["cropX2"]
|
connection_info["cropX2"] = stream_info.cropX2
|
||||||
if "cropY2" in stream_info:
|
if hasattr(stream_info, "cropY2"):
|
||||||
connection_info["cropY2"] = stream_info["cropY2"]
|
connection_info["cropY2"] = stream_info.cropY2
|
||||||
|
|
||||||
camera_connections.append(connection_info)
|
camera_connections.append(connection_info)
|
||||||
|
|
||||||
|
@ -269,7 +271,7 @@ class WebSocketHandler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get model for this camera
|
# Get model for this camera
|
||||||
model_id = stream_info.get("modelId")
|
model_id = getattr(stream_info, "modelId", None)
|
||||||
if not model_id:
|
if not model_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -413,7 +415,8 @@ class WebSocketHandler:
|
||||||
await self.stream_manager.stop_stream(camera_id)
|
await self.stream_manager.stop_stream(camera_id)
|
||||||
self.model_manager.unload_models(camera_id)
|
self.model_manager.unload_models(camera_id)
|
||||||
subscription_to_camera.pop(sub_id, None)
|
subscription_to_camera.pop(sub_id, None)
|
||||||
self.session_cache.clear_session(camera_id)
|
# Clear cached data (SessionCacheManager handles this automatically)
|
||||||
|
# Note: clear_session method not available, cleanup happens automatically
|
||||||
logger.info(f"Removed subscription: {sub_id}")
|
logger.info(f"Removed subscription: {sub_id}")
|
||||||
|
|
||||||
# Add new subscriptions
|
# Add new subscriptions
|
||||||
|
@ -612,7 +615,7 @@ class WebSocketHandler:
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send detection result over WebSocket."""
|
"""Send detection result over WebSocket."""
|
||||||
# Get session ID for this display
|
# Get session ID for this display
|
||||||
subscription_id = stream_info["subscriptionIdentifier"]
|
subscription_id = getattr(stream_info, "subscriptionIdentifier", "")
|
||||||
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
|
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
|
||||||
session_id = self.session_ids.get(display_id)
|
session_id = self.session_ids.get(display_id)
|
||||||
|
|
||||||
|
@ -623,8 +626,8 @@ class WebSocketHandler:
|
||||||
"sessionId": session_id, # Required by protocol
|
"sessionId": session_id, # Required by protocol
|
||||||
"data": {
|
"data": {
|
||||||
"detection": detection_result.to_dict(),
|
"detection": detection_result.to_dict(),
|
||||||
"modelId": stream_info["modelId"],
|
"modelId": getattr(stream_info, "modelId", 0),
|
||||||
"modelName": stream_info["modelName"]
|
"modelName": getattr(stream_info, "modelName", "Unknown Model")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -645,18 +648,18 @@ class WebSocketHandler:
|
||||||
"""Send camera disconnection notification."""
|
"""Send camera disconnection notification."""
|
||||||
logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null")
|
logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null")
|
||||||
|
|
||||||
# Clear cached data
|
# Clear cached data (SessionCacheManager handles this automatically via cleanup)
|
||||||
self.session_cache.clear_session(camera_id)
|
# Note: clear_session method not available, cleanup happens automatically
|
||||||
|
|
||||||
# Send null detection
|
# Send null detection
|
||||||
detection_data = {
|
detection_data = {
|
||||||
"type": "imageDetection",
|
"type": "imageDetection",
|
||||||
"subscriptionIdentifier": stream_info["subscriptionIdentifier"],
|
"subscriptionIdentifier": getattr(stream_info, "subscriptionIdentifier", ""),
|
||||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||||
"data": {
|
"data": {
|
||||||
"detection": None,
|
"detection": None,
|
||||||
"modelId": stream_info["modelId"],
|
"modelId": getattr(stream_info, "modelId", 0),
|
||||||
"modelName": stream_info["modelName"]
|
"modelName": getattr(stream_info, "modelName", "Unknown Model")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -672,13 +675,60 @@ class WebSocketHandler:
|
||||||
self.camera_monitor.mark_disconnection_notified(camera_id)
|
self.camera_monitor.mark_disconnection_notified(camera_id)
|
||||||
logger.info(f"📡 SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session")
|
logger.info(f"📡 SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session")
|
||||||
|
|
||||||
|
async def _handle_subscribe(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle individual subscription message."""
|
||||||
|
try:
|
||||||
|
payload = data.get("payload", {})
|
||||||
|
subscription_id = payload.get("subscriptionIdentifier")
|
||||||
|
|
||||||
|
if not subscription_id:
|
||||||
|
logger.error("Subscribe message missing subscriptionIdentifier")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert single subscription to setSubscriptionList format
|
||||||
|
subscription_list_data = {
|
||||||
|
"type": "setSubscriptionList",
|
||||||
|
"subscriptions": [payload]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Delegate to existing setSubscriptionList handler
|
||||||
|
await self._handle_set_subscription_list(subscription_list_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling subscribe: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Handle individual unsubscription message."""
|
||||||
|
try:
|
||||||
|
payload = data.get("payload", {})
|
||||||
|
subscription_id = payload.get("subscriptionIdentifier")
|
||||||
|
|
||||||
|
if not subscription_id:
|
||||||
|
logger.error("Unsubscribe message missing subscriptionIdentifier")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Stop stream and clean up
|
||||||
|
camera_id = subscription_to_camera.get(subscription_id)
|
||||||
|
if camera_id:
|
||||||
|
await self.stream_manager.stop_stream(camera_id)
|
||||||
|
self.model_manager.unload_models(camera_id)
|
||||||
|
del subscription_to_camera[subscription_id]
|
||||||
|
logger.info(f"Unsubscribed from {subscription_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unknown subscription ID: {subscription_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling unsubscribe: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any:
|
def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any:
|
||||||
"""Apply crop to frame if crop coordinates are specified."""
|
"""Apply crop to frame if crop coordinates are specified."""
|
||||||
crop_coords = [
|
crop_coords = [
|
||||||
stream_info.get("cropX1"),
|
getattr(stream_info, "cropX1", None),
|
||||||
stream_info.get("cropY1"),
|
getattr(stream_info, "cropY1", None),
|
||||||
stream_info.get("cropX2"),
|
getattr(stream_info, "cropX2", None),
|
||||||
stream_info.get("cropY2")
|
getattr(stream_info, "cropY2", None)
|
||||||
]
|
]
|
||||||
|
|
||||||
if all(coord is not None for coord in crop_coords):
|
if all(coord is not None for coord in crop_coords):
|
||||||
|
|
|
@ -7,13 +7,16 @@ for the detection worker.
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Dict, Any, Optional, List, Set, Tuple
|
from typing import Dict, Any, Optional, List, Set, Tuple, TYPE_CHECKING
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from ..core.config import MODELS_DIR
|
from ..core.config import MODELS_DIR
|
||||||
from ..core.exceptions import ModelLoadError
|
from ..core.exceptions import ModelLoadError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .pipeline_loader import PipelineLoader
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logger = logging.getLogger("detector_worker.model_manager")
|
logger = logging.getLogger("detector_worker.model_manager")
|
||||||
|
|
||||||
|
@ -110,11 +113,12 @@ class ModelManager:
|
||||||
- Pipeline model tree management
|
- Pipeline model tree management
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, models_dir: str = MODELS_DIR):
|
def __init__(self, pipeline_loader: Optional['PipelineLoader'] = None, models_dir: str = MODELS_DIR):
|
||||||
"""
|
"""
|
||||||
Initialize the model manager.
|
Initialize the model manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
pipeline_loader: Pipeline loader for handling MPTA archives (injected via DI)
|
||||||
models_dir: Directory to cache downloaded models
|
models_dir: Directory to cache downloaded models
|
||||||
"""
|
"""
|
||||||
self.models_dir = models_dir
|
self.models_dir = models_dir
|
||||||
|
@ -124,8 +128,19 @@ class ModelManager:
|
||||||
# Camera to models mapping
|
# Camera to models mapping
|
||||||
self.camera_models: Dict[str, Dict[str, Any]] = {} # camera_id -> {model_id -> model_tree}
|
self.camera_models: Dict[str, Dict[str, Any]] = {} # camera_id -> {model_id -> model_tree}
|
||||||
|
|
||||||
# Pipeline loader will be injected
|
# Pipeline loader injected via dependency injection
|
||||||
self.pipeline_loader = None
|
self.pipeline_loader = pipeline_loader
|
||||||
|
|
||||||
|
# If pipeline_loader is None, try to resolve it from the container
|
||||||
|
if self.pipeline_loader is None:
|
||||||
|
try:
|
||||||
|
from ..core.dependency_injection import get_container
|
||||||
|
from .pipeline_loader import PipelineLoader
|
||||||
|
container = get_container()
|
||||||
|
self.pipeline_loader = container.resolve(PipelineLoader)
|
||||||
|
logger.info("PipelineLoader resolved from dependency container")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not resolve PipelineLoader from container: {e}")
|
||||||
|
|
||||||
# Create models directory if it doesn't exist
|
# Create models directory if it doesn't exist
|
||||||
os.makedirs(self.models_dir, exist_ok=True)
|
os.makedirs(self.models_dir, exist_ok=True)
|
||||||
|
|
|
@ -41,6 +41,15 @@ class StreamInfo:
|
||||||
last_frame_time: Optional[float] = None
|
last_frame_time: Optional[float] = None
|
||||||
frame_count: int = 0
|
frame_count: int = 0
|
||||||
|
|
||||||
|
# Additional WebSocket fields
|
||||||
|
subscriptionIdentifier: Optional[str] = None
|
||||||
|
modelId: Optional[int] = None
|
||||||
|
modelName: Optional[str] = None
|
||||||
|
cropX1: Optional[int] = None
|
||||||
|
cropY1: Optional[int] = None
|
||||||
|
cropX2: Optional[int] = None
|
||||||
|
cropY2: Optional[int] = None
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""Convert to dictionary format."""
|
"""Convert to dictionary format."""
|
||||||
return {
|
return {
|
||||||
|
@ -607,19 +616,24 @@ class StreamManager:
|
||||||
snapshot_url = payload.get('snapshotUrl')
|
snapshot_url = payload.get('snapshotUrl')
|
||||||
snapshot_interval = payload.get('snapshotInterval', 5000)
|
snapshot_interval = payload.get('snapshotInterval', 5000)
|
||||||
|
|
||||||
|
# Create a subscriber_id (for WebSocket compatibility, use the subscription_id)
|
||||||
|
subscriber_id = f"websocket_{int(time.time() * 1000)}"
|
||||||
|
|
||||||
# Create subscription based on available URL type
|
# Create subscription based on available URL type
|
||||||
if rtsp_url:
|
if rtsp_url:
|
||||||
success = self.create_subscription(
|
success = self.create_subscription(
|
||||||
subscription_id=subscription_id,
|
subscription_id=subscription_id,
|
||||||
camera_id=camera_id,
|
camera_id=camera_id,
|
||||||
|
subscriber_id=subscriber_id,
|
||||||
rtsp_url=rtsp_url
|
rtsp_url=rtsp_url
|
||||||
)
|
)
|
||||||
elif snapshot_url:
|
elif snapshot_url:
|
||||||
success = self.create_subscription(
|
success = self.create_subscription(
|
||||||
subscription_id=subscription_id,
|
subscription_id=subscription_id,
|
||||||
camera_id=camera_id,
|
camera_id=camera_id,
|
||||||
|
subscriber_id=subscriber_id,
|
||||||
snapshot_url=snapshot_url,
|
snapshot_url=snapshot_url,
|
||||||
snapshot_interval_ms=snapshot_interval
|
snapshot_interval=snapshot_interval
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f"No valid stream URL provided for camera {camera_id}")
|
logger.error(f"No valid stream URL provided for camera {camera_id}")
|
||||||
|
|
|
@ -2,4 +2,6 @@ fastapi[standard]
|
||||||
uvicorn
|
uvicorn
|
||||||
websockets
|
websockets
|
||||||
redis
|
redis
|
||||||
urllib3<2.0.0
|
urllib3<2.0.0
|
||||||
|
aiohttp
|
||||||
|
aiofiles
|
Loading…
Add table
Add a link
Reference in a new issue