Refactor: done phase 4
This commit is contained in:
parent
7e8034c6e5
commit
9e4c23c75c
8 changed files with 1533 additions and 37 deletions
|
@ -11,6 +11,7 @@ from collections import defaultdict
|
|||
|
||||
from .readers import RTSPReader, HTTPSnapshotReader
|
||||
from .buffers import shared_cache_buffer, save_frame_for_testing
|
||||
from ..tracking.integration import TrackingPipelineIntegration
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -35,6 +36,9 @@ class SubscriptionInfo:
|
|||
stream_config: StreamConfig
|
||||
created_at: float
|
||||
crop_coords: Optional[tuple] = None
|
||||
model_id: Optional[str] = None
|
||||
model_url: Optional[str] = None
|
||||
tracking_integration: Optional[TrackingPipelineIntegration] = None
|
||||
|
||||
|
||||
class StreamManager:
|
||||
|
@ -48,7 +52,10 @@ class StreamManager:
|
|||
self._lock = threading.RLock()
|
||||
|
||||
def add_subscription(self, subscription_id: str, stream_config: StreamConfig,
|
||||
crop_coords: Optional[tuple] = None) -> bool:
|
||||
crop_coords: Optional[tuple] = None,
|
||||
model_id: Optional[str] = None,
|
||||
model_url: Optional[str] = None,
|
||||
tracking_integration: Optional[TrackingPipelineIntegration] = None) -> bool:
|
||||
"""Add a new subscription. Returns True if successful."""
|
||||
with self._lock:
|
||||
if subscription_id in self._subscriptions:
|
||||
|
@ -63,7 +70,10 @@ class StreamManager:
|
|||
camera_id=camera_id,
|
||||
stream_config=stream_config,
|
||||
created_at=time.time(),
|
||||
crop_coords=crop_coords
|
||||
crop_coords=crop_coords,
|
||||
model_id=model_id,
|
||||
model_url=model_url,
|
||||
tracking_integration=tracking_integration
|
||||
)
|
||||
|
||||
self._subscriptions[subscription_id] = subscription_info
|
||||
|
@ -175,9 +185,64 @@ class StreamManager:
|
|||
save_frame_for_testing(camera_id, frame)
|
||||
break # Only save once per frame
|
||||
|
||||
# Process tracking for subscriptions with tracking integration
|
||||
self._process_tracking_for_camera(camera_id, frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in frame callback for camera {camera_id}: {e}")
|
||||
|
||||
def _process_tracking_for_camera(self, camera_id: str, frame):
|
||||
"""Process tracking for all subscriptions of a camera."""
|
||||
try:
|
||||
with self._lock:
|
||||
for subscription_id in self._camera_subscribers[camera_id]:
|
||||
subscription_info = self._subscriptions[subscription_id]
|
||||
|
||||
# Skip if no tracking integration
|
||||
if not subscription_info.tracking_integration:
|
||||
continue
|
||||
|
||||
# Extract display_id from subscription_id
|
||||
display_id = subscription_id.split(';')[0] if ';' in subscription_id else subscription_id
|
||||
|
||||
# Process frame through tracking asynchronously
|
||||
# Note: This is synchronous for now, can be made async in future
|
||||
try:
|
||||
# Create a simple asyncio event loop for this frame
|
||||
import asyncio
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(
|
||||
subscription_info.tracking_integration.process_frame(
|
||||
frame, display_id, subscription_id
|
||||
)
|
||||
)
|
||||
# Log tracking results
|
||||
if result:
|
||||
tracked_count = len(result.get('tracked_vehicles', []))
|
||||
validated_vehicle = result.get('validated_vehicle')
|
||||
pipeline_result = result.get('pipeline_result')
|
||||
|
||||
if tracked_count > 0:
|
||||
logger.info(f"[Tracking] {camera_id}: {tracked_count} vehicles tracked")
|
||||
|
||||
if validated_vehicle:
|
||||
logger.info(f"[Tracking] {camera_id}: Vehicle {validated_vehicle['track_id']} "
|
||||
f"validated as {validated_vehicle['state']} "
|
||||
f"(confidence: {validated_vehicle['confidence']:.2f})")
|
||||
|
||||
if pipeline_result:
|
||||
logger.info(f"[Pipeline] {camera_id}: {pipeline_result.get('status', 'unknown')} - "
|
||||
f"{pipeline_result.get('message', 'no message')}")
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as track_e:
|
||||
logger.error(f"Error in tracking for {subscription_id}: {track_e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing tracking for camera {camera_id}: {e}")
|
||||
|
||||
def get_frame(self, camera_id: str, crop_coords: Optional[tuple] = None):
|
||||
"""Get the latest frame for a camera with optional cropping."""
|
||||
return shared_cache_buffer.get_frame(camera_id, crop_coords)
|
||||
|
@ -280,7 +345,13 @@ class StreamManager:
|
|||
save_test_frames=True # Enable for testing
|
||||
)
|
||||
|
||||
return self.add_subscription(subscription_id, stream_config, crop_coords)
|
||||
return self.add_subscription(
|
||||
subscription_id,
|
||||
stream_config,
|
||||
crop_coords,
|
||||
model_id=payload.get('modelId'),
|
||||
model_url=payload.get('modelUrl')
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding subscription from payload {subscription_id}: {e}")
|
||||
|
@ -300,10 +371,38 @@ class StreamManager:
|
|||
|
||||
logger.info("Stopped all streams and cleared all subscriptions")
|
||||
|
||||
def set_session_id(self, display_id: str, session_id: str):
|
||||
"""Set session ID for tracking integration."""
|
||||
with self._lock:
|
||||
for subscription_info in self._subscriptions.values():
|
||||
# Check if this subscription matches the display_id
|
||||
subscription_display_id = subscription_info.subscription_id.split(';')[0]
|
||||
if subscription_display_id == display_id and subscription_info.tracking_integration:
|
||||
subscription_info.tracking_integration.set_session_id(display_id, session_id)
|
||||
logger.debug(f"Set session {session_id} for display {display_id}")
|
||||
|
||||
def clear_session_id(self, session_id: str):
|
||||
"""Clear session ID from tracking integrations."""
|
||||
with self._lock:
|
||||
for subscription_info in self._subscriptions.values():
|
||||
if subscription_info.tracking_integration:
|
||||
subscription_info.tracking_integration.clear_session_id(session_id)
|
||||
logger.debug(f"Cleared session {session_id}")
|
||||
|
||||
def get_tracking_stats(self) -> Dict[str, Any]:
|
||||
"""Get tracking statistics from all subscriptions."""
|
||||
stats = {}
|
||||
with self._lock:
|
||||
for subscription_id, subscription_info in self._subscriptions.items():
|
||||
if subscription_info.tracking_integration:
|
||||
stats[subscription_id] = subscription_info.tracking_integration.get_statistics()
|
||||
return stats
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive streaming statistics."""
|
||||
with self._lock:
|
||||
buffer_stats = shared_cache_buffer.get_stats()
|
||||
tracking_stats = self.get_tracking_stats()
|
||||
|
||||
return {
|
||||
'active_subscriptions': len(self._subscriptions),
|
||||
|
@ -314,7 +413,8 @@ class StreamManager:
|
|||
camera_id: len(subscribers)
|
||||
for camera_id, subscribers in self._camera_subscribers.items()
|
||||
},
|
||||
'buffer_stats': buffer_stats
|
||||
'buffer_stats': buffer_stats,
|
||||
'tracking_stats': tracking_stats
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue