Refactor: done phase 4

This commit is contained in:
ziesorx 2025-09-23 17:56:40 +07:00
parent 7e8034c6e5
commit 9e4c23c75c
8 changed files with 1533 additions and 37 deletions

View file

@ -11,6 +11,7 @@ from collections import defaultdict
from .readers import RTSPReader, HTTPSnapshotReader
from .buffers import shared_cache_buffer, save_frame_for_testing
from ..tracking.integration import TrackingPipelineIntegration
logger = logging.getLogger(__name__)
@ -35,6 +36,9 @@ class SubscriptionInfo:
stream_config: StreamConfig
created_at: float
crop_coords: Optional[tuple] = None
model_id: Optional[str] = None
model_url: Optional[str] = None
tracking_integration: Optional[TrackingPipelineIntegration] = None
class StreamManager:
@ -48,7 +52,10 @@ class StreamManager:
self._lock = threading.RLock()
def add_subscription(self, subscription_id: str, stream_config: StreamConfig,
crop_coords: Optional[tuple] = None) -> bool:
crop_coords: Optional[tuple] = None,
model_id: Optional[str] = None,
model_url: Optional[str] = None,
tracking_integration: Optional[TrackingPipelineIntegration] = None) -> bool:
"""Add a new subscription. Returns True if successful."""
with self._lock:
if subscription_id in self._subscriptions:
@ -63,7 +70,10 @@ class StreamManager:
camera_id=camera_id,
stream_config=stream_config,
created_at=time.time(),
crop_coords=crop_coords
crop_coords=crop_coords,
model_id=model_id,
model_url=model_url,
tracking_integration=tracking_integration
)
self._subscriptions[subscription_id] = subscription_info
@ -175,9 +185,64 @@ class StreamManager:
save_frame_for_testing(camera_id, frame)
break # Only save once per frame
# Process tracking for subscriptions with tracking integration
self._process_tracking_for_camera(camera_id, frame)
except Exception as e:
logger.error(f"Error in frame callback for camera {camera_id}: {e}")
def _process_tracking_for_camera(self, camera_id: str, frame):
"""Process tracking for all subscriptions of a camera."""
try:
with self._lock:
for subscription_id in self._camera_subscribers[camera_id]:
subscription_info = self._subscriptions[subscription_id]
# Skip if no tracking integration
if not subscription_info.tracking_integration:
continue
# Extract display_id from subscription_id
display_id = subscription_id.split(';')[0] if ';' in subscription_id else subscription_id
# Process frame through tracking asynchronously
# Note: This is synchronous for now, can be made async in future
try:
# Create a simple asyncio event loop for this frame
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
subscription_info.tracking_integration.process_frame(
frame, display_id, subscription_id
)
)
# Log tracking results
if result:
tracked_count = len(result.get('tracked_vehicles', []))
validated_vehicle = result.get('validated_vehicle')
pipeline_result = result.get('pipeline_result')
if tracked_count > 0:
logger.info(f"[Tracking] {camera_id}: {tracked_count} vehicles tracked")
if validated_vehicle:
logger.info(f"[Tracking] {camera_id}: Vehicle {validated_vehicle['track_id']} "
f"validated as {validated_vehicle['state']} "
f"(confidence: {validated_vehicle['confidence']:.2f})")
if pipeline_result:
logger.info(f"[Pipeline] {camera_id}: {pipeline_result.get('status', 'unknown')} - "
f"{pipeline_result.get('message', 'no message')}")
finally:
loop.close()
except Exception as track_e:
logger.error(f"Error in tracking for {subscription_id}: {track_e}")
except Exception as e:
logger.error(f"Error processing tracking for camera {camera_id}: {e}")
def get_frame(self, camera_id: str, crop_coords: Optional[tuple] = None):
"""Get the latest frame for a camera with optional cropping."""
return shared_cache_buffer.get_frame(camera_id, crop_coords)
@ -280,7 +345,13 @@ class StreamManager:
save_test_frames=True # Enable for testing
)
return self.add_subscription(subscription_id, stream_config, crop_coords)
return self.add_subscription(
subscription_id,
stream_config,
crop_coords,
model_id=payload.get('modelId'),
model_url=payload.get('modelUrl')
)
except Exception as e:
logger.error(f"Error adding subscription from payload {subscription_id}: {e}")
@ -300,10 +371,38 @@ class StreamManager:
logger.info("Stopped all streams and cleared all subscriptions")
def set_session_id(self, display_id: str, session_id: str):
"""Set session ID for tracking integration."""
with self._lock:
for subscription_info in self._subscriptions.values():
# Check if this subscription matches the display_id
subscription_display_id = subscription_info.subscription_id.split(';')[0]
if subscription_display_id == display_id and subscription_info.tracking_integration:
subscription_info.tracking_integration.set_session_id(display_id, session_id)
logger.debug(f"Set session {session_id} for display {display_id}")
def clear_session_id(self, session_id: str):
"""Clear session ID from tracking integrations."""
with self._lock:
for subscription_info in self._subscriptions.values():
if subscription_info.tracking_integration:
subscription_info.tracking_integration.clear_session_id(session_id)
logger.debug(f"Cleared session {session_id}")
def get_tracking_stats(self) -> Dict[str, Any]:
"""Get tracking statistics from all subscriptions."""
stats = {}
with self._lock:
for subscription_id, subscription_info in self._subscriptions.items():
if subscription_info.tracking_integration:
stats[subscription_id] = subscription_info.tracking_integration.get_statistics()
return stats
def get_stats(self) -> Dict[str, Any]:
"""Get comprehensive streaming statistics."""
with self._lock:
buffer_stats = shared_cache_buffer.get_stats()
tracking_stats = self.get_tracking_stats()
return {
'active_subscriptions': len(self._subscriptions),
@ -314,7 +413,8 @@ class StreamManager:
camera_id: len(subscribers)
for camera_id, subscribers in self._camera_subscribers.items()
},
'buffer_stats': buffer_stats
'buffer_stats': buffer_stats,
'tracking_stats': tracking_stats
}