# python-rtsp-worker/services/stream_connection_manager.py
#
# 592 lines
# 21 KiB
# Python

"""
StreamConnectionManager - Async orchestration for stream processing with batched inference.
This module provides high-level connection management for multiple RTSP streams,
coordinating decoders, batched inference, and tracking with an event-driven API.
"""
import asyncio
import inspect
import logging
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List

import torch

from .model_controller import ModelController
from .stream_decoder import StreamDecoderFactory
from .tracking_factory import TrackingFactory
from .model_repository import TensorRTModelRepository
# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)
class ConnectionStatus(Enum):
    """Lifecycle state of one stream connection.

    The ``value`` of each member is the lowercase string reported under the
    ``"status"`` key of a connection's statistics.
    """

    # Decoder started but not (yet) confirmed connected.
    CONNECTING = "connecting"
    # Decoder reports an active connection.
    CONNECTED = "connected"
    # Connection lost, or the connection was stopped.
    DISCONNECTED = "disconnected"
    # The frame-polling loop raised an exception.
    ERROR = "error"
@dataclass
class TrackingResult:
    """Result emitted to user callbacks for one processed frame.

    Attributes:
        stream_id: Identifier of the stream that produced the frame.
        timestamp: Timestamp copied from the inference result.
        tracked_objects: Tracks returned by the stream's tracking controller
            (a list of TrackedObject instances).
        detections: Raw model detections, before tracking. Presumably an
            (N, 6) tensor [x1, y1, x2, y2, conf, class_id]; confirm against
            the postprocess function in use.
        frame_shape: ``tuple(frame.shape)`` of the submitted frame, or
            ``None`` when the inference metadata carries no ``"shape"`` entry.
        metadata: Metadata dictionary attached to the frame on submission.
    """

    stream_id: str
    timestamp: float
    tracked_objects: List  # List of TrackedObject from TrackingController
    detections: Any  # Raw detections (tensor-like; see docstring)
    frame_shape: Optional[Tuple[int, ...]]
    metadata: Dict
class StreamConnection:
    """
    Represents a single stream connection with event emission.

    This class wraps a StreamDecoder, polls frames asynchronously, submits them
    to the ModelController for batched inference, runs tracking, and emits results
    via queues (consumed through ``tracking_results()`` and ``errors()``).

    Args:
        stream_id: Unique identifier for this stream
        decoder: StreamDecoder instance
        model_controller: ModelController for batched inference
        tracking_controller: TrackingController for object tracking
        poll_interval: Frame polling interval in seconds (default: 0.01)

    Note:
        ``result_queue`` and ``error_queue`` are unbounded ``asyncio.Queue``
        instances; if nothing drains them they grow for as long as the
        stream keeps producing results/errors.
    """

    def __init__(
        self,
        stream_id: str,
        decoder,
        model_controller: ModelController,
        tracking_controller,
        poll_interval: float = 0.01,
    ):
        self.stream_id = stream_id
        self.decoder = decoder
        self.model_controller = model_controller
        self.tracking_controller = tracking_controller
        self.poll_interval = poll_interval
        self.status = ConnectionStatus.CONNECTING
        # Number of distinct frames submitted for inference.
        self.frame_count = 0
        # time.time() of the most recent frame submission (0.0 = none yet).
        self.last_frame_time = 0.0
        # Event emission
        self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
        self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()
        # Tasks
        self.poller_task: Optional[asyncio.Task] = None
        self.running = False

    async def start(self):
        """Start the decoder, wait briefly for a connection, then start polling.

        Waits up to 10 s (checking every 0.5 s) for ``decoder.is_connected()``.
        A timeout is not an error: the status stays CONNECTING and the frame
        poller switches it to CONNECTED once the decoder reports a connection.
        """
        # Start decoder (runs in background thread)
        self.decoder.start()

        max_wait = 10.0
        wait_interval = 0.5
        elapsed = 0.0
        while elapsed < max_wait:
            await asyncio.sleep(wait_interval)
            elapsed += wait_interval
            if self.decoder.is_connected():
                self.status = ConnectionStatus.CONNECTED
                logger.info(f"Stream {self.stream_id} connected after {elapsed:.1f}s")
                break
        else:
            # Timeout - but don't fail hard, let it try to connect in background
            logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
            self.status = ConnectionStatus.CONNECTING

        # Start frame polling task
        self.running = True
        self.poller_task = asyncio.create_task(self._frame_poller())

    async def stop(self):
        """Stop polling and the decoder, and unregister from the model controller.

        The callback is unregistered and the status set to DISCONNECTED even
        if ``decoder.stop()`` raises; that exception still propagates.
        """
        logger.info(f"Stopping stream {self.stream_id}...")
        self.running = False

        if self.poller_task:
            self.poller_task.cancel()
            try:
                await self.poller_task
            except asyncio.CancelledError:
                pass

        try:
            self.decoder.stop()
        finally:
            # Always detach from the shared ModelController so no further
            # inference results are routed to this (stopped) connection.
            self.model_controller.unregister_callback(self.stream_id)
            self.status = ConnectionStatus.DISCONNECTED

        logger.info(f"Stream {self.stream_id} stopped")

    async def _frame_poller(self):
        """Poll the decoder and submit each new frame to the model controller.

        Runs until ``self.running`` is cleared or the task is cancelled, and
        keeps ``self.status`` in sync with ``decoder.is_connected()``.
        Exceptions are logged, pushed to ``error_queue`` and set status ERROR;
        polling then continues.
        """
        # Hold the last submitted tensor itself, not just its address: while it
        # is alive the allocator cannot give the same address to a newly
        # allocated frame, so an equal data_ptr() really means "same storage".
        last_frame = None

        while self.running:
            try:
                frame = self.decoder.get_latest_frame(rgb=True)

                # Check if we got a new frame (avoid reprocessing same frame)
                is_new_frame = frame is not None and (
                    last_frame is None or frame.data_ptr() != last_frame.data_ptr()
                )
                if is_new_frame:
                    last_frame = frame
                    self.last_frame_time = time.time()
                    self.frame_count += 1

                    # Submit to model controller for batched inference
                    await self.model_controller.submit_frame(
                        stream_id=self.stream_id,
                        frame=frame,
                        metadata={
                            "frame_number": self.frame_count,
                            "shape": tuple(frame.shape),
                        },
                    )

                # Check decoder status
                if not self.decoder.is_connected():
                    if self.status == ConnectionStatus.CONNECTED:
                        logger.warning(f"Stream {self.stream_id} disconnected")
                        self.status = ConnectionStatus.DISCONNECTED
                    # Decoder will auto-reconnect, just update status
                    await asyncio.sleep(1.0)
                    if self.decoder.is_connected():
                        logger.info(f"Stream {self.stream_id} reconnected")
                        self.status = ConnectionStatus.CONNECTED
                elif self.status != ConnectionStatus.CONNECTED:
                    # The decoder became connected between two checks (e.g. after
                    # start() timed out, after the 1 s back-off above, or after a
                    # transient error). Without this the status would stay stale.
                    logger.info(f"Stream {self.stream_id} connected")
                    self.status = ConnectionStatus.CONNECTED

            except Exception as e:
                logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
                await self.error_queue.put(e)
                self.status = ConnectionStatus.ERROR

            # Sleep until next poll
            await asyncio.sleep(self.poll_interval)

    async def _handle_inference_result(self, result: Dict[str, Any]):
        """
        Callback invoked by ModelController when inference is done.

        Runs tracking in the default executor and puts a TrackingResult on
        ``result_queue``. Any exception is logged and put on ``error_queue``.

        Args:
            result: Inference result dictionary; this method reads the keys
                "detections", "timestamp" and "metadata".
        """
        try:
            detections = result["detections"]

            # Tracking is synchronous; keep it off the event loop.
            loop = asyncio.get_running_loop()
            tracked_objects = await loop.run_in_executor(
                None, self._run_tracking_sync, detections
            )

            tracking_result = TrackingResult(
                stream_id=self.stream_id,
                timestamp=result["timestamp"],
                tracked_objects=tracked_objects,
                detections=detections,
                frame_shape=result["metadata"].get("shape"),
                metadata=result["metadata"],
            )

            await self.result_queue.put(tracking_result)

        except Exception as e:
            logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
            await self.error_queue.put(e)

    def _run_tracking_sync(self, detections):
        """
        Run tracking synchronously (called from executor).

        Drives the TrackingController's private IoU-tracking primitives
        directly, because detections are already available. Everything runs
        under the controller's ``_lock``.

        Args:
            detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]

        Returns:
            List of TrackedObject instances
        """
        tracker = self.tracking_controller

        with tracker._lock:
            tracker._frame_count += 1
            frame_no = tracker._frame_count

            # If no detections, just cleanup and return current tracks
            if len(detections) == 0:
                tracker._cleanup_stale_tracks()
                return list(tracker._tracks.values())

            # Associate detections with existing tracks: (det_idx, track_id)
            # pairs, track_id == -1 meaning "no match".
            associations = tracker._iou_tracking(detections)

            for det_idx, track_id in associations:
                # Look the detection up by its index instead of zipping with
                # `detections`, so the pairing does not depend on the order or
                # count of associations. NOTE(review): assumes det_idx indexes
                # `detections` directly — confirm against _iou_tracking.
                detection = detections[det_idx]
                bbox = detection[:4].cpu().tolist()
                confidence = float(detection[4])
                class_id = int(detection[5]) if detection.shape[0] > 5 else 0

                if track_id == -1:
                    new_track = tracker._create_track(bbox, confidence, class_id, frame_no)
                    tracker._tracks[new_track.track_id] = new_track
                else:
                    tracker._tracks[track_id].update(bbox, confidence, frame_no)

            tracker._cleanup_stale_tracks()
            return list(tracker._tracks.values())

    async def tracking_results(self) -> AsyncIterator[TrackingResult]:
        """
        Async generator for tracking results.

        Usage:
            async for result in connection.tracking_results():
                print(result.tracked_objects)

        Yields:
            TrackingResult objects as they become available. Iteration ends
            once the connection is not running and the queue is drained
            (so it ends immediately if started before ``start()``).
        """
        while self.running or not self.result_queue.empty():
            try:
                result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
                yield result
            except asyncio.TimeoutError:
                continue

    async def errors(self) -> AsyncIterator[Exception]:
        """
        Async generator for errors.

        Yields:
            Exception objects as they occur, until the connection is not
            running and the error queue is drained.
        """
        while self.running or not self.error_queue.empty():
            try:
                error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
                yield error
            except asyncio.TimeoutError:
                continue

    def get_stats(self) -> Dict[str, Any]:
        """Get connection statistics"""
        return {
            "stream_id": self.stream_id,
            "status": self.status.value,
            "frame_count": self.frame_count,
            "last_frame_time": self.last_frame_time,
            "decoder_connected": self.decoder.is_connected(),
            "decoder_buffer_size": self.decoder.get_buffer_size(),
            "result_queue_size": self.result_queue.qsize(),
            "error_queue_size": self.error_queue.qsize(),
        }
class StreamConnectionManager:
    """
    High-level manager for stream connections with batched inference.

    This manager coordinates multiple RTSP streams, batched model inference,
    and object tracking through an async event-driven API. All streams share
    one ModelController; each stream gets its own decoder and its own
    tracking controller.

    Args:
        gpu_id: GPU device ID (default: 0)
        batch_size: Maximum batch size for inference (default: 16)
        force_timeout: Force buffer switch timeout in seconds (default: 0.05)
        poll_interval: Frame polling interval in seconds (default: 0.01)
        enable_pt_conversion: Forwarded to TensorRTModelRepository (default: True)

    Example:
        manager = StreamConnectionManager(gpu_id=0, batch_size=16)
        await manager.initialize(model_path="yolov8n.trt", ...)
        connection = await manager.connect_stream(rtsp_url, on_tracking_result=callback)
        await asyncio.sleep(60)
        await manager.shutdown()
    """

    def __init__(
        self,
        gpu_id: int = 0,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        poll_interval: float = 0.01,
        enable_pt_conversion: bool = True,
    ):
        self.gpu_id = gpu_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.poll_interval = poll_interval

        # Factories
        self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
        self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
        self.model_repository = TensorRTModelRepository(
            gpu_id=gpu_id,
            enable_pt_conversion=enable_pt_conversion
        )

        # Controllers
        self.model_controller: Optional[ModelController] = None
        self.tracking_controller = None
        # Model id used when creating per-stream tracking controllers; set by initialize().
        self.model_id_for_tracking: Optional[str] = None

        # Connections
        self.connections: Dict[str, StreamConnection] = {}
        # IDs whose connect_stream() is still awaiting connection.start().
        # Reserved so concurrent calls cannot claim the same ID.
        self._connecting = set()
        # Strong references to fire-and-forget forwarding tasks: the event loop
        # keeps only weak references, so unreferenced tasks may be collected.
        self._background_tasks = set()

        # State
        self.initialized = False

    async def initialize(
        self,
        model_path: str,
        model_id: str = "detector",
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
        num_contexts: int = 4,
        pt_input_shapes: Optional[Dict] = None,
        pt_precision: Optional[Any] = None,
        **pt_conversion_kwargs
    ):
        """
        Initialize the manager with a model.

        Loads the model (in the default executor, off the event loop), then
        creates and starts the shared ModelController.

        Args:
            model_path: Path to TensorRT or PyTorch model file (.trt, .pt, .pth)
            model_id: Model identifier (default: "detector")
            preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
            postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
            num_contexts: Number of TensorRT execution contexts (default: 4)
            pt_input_shapes: Required for PT files - dict of input shapes
            pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
            **pt_conversion_kwargs: Additional PT conversion arguments
        """
        logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self.model_repository.load_model(
                model_id,
                model_path,
                num_contexts=num_contexts,
                pt_input_shapes=pt_input_shapes,
                pt_precision=pt_precision,
                **pt_conversion_kwargs
            )
        )
        logger.info(f"Loaded model {model_id} from {model_path}")

        self.model_controller = ModelController(
            model_repository=self.model_repository,
            model_id=model_id,
            batch_size=self.batch_size,
            force_timeout=self.force_timeout,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
        await self.model_controller.start()

        # Don't create a shared tracking controller here
        # Each stream will get its own tracking controller to avoid track accumulation
        self.tracking_controller = None
        self.model_id_for_tracking = model_id

        self.initialized = True
        logger.info("StreamConnectionManager initialized successfully")

    def _generate_stream_id(self) -> str:
        """Return ``stream_<n>`` not used by any live or connecting stream.

        Starts at ``len(self.connections)`` (the historical naming) and skips
        forward past any ID already taken.
        """
        index = len(self.connections)
        while f"stream_{index}" in self.connections or f"stream_{index}" in self._connecting:
            index += 1
        return f"stream_{index}"

    def _spawn_background(self, coro) -> asyncio.Task:
        """Schedule *coro* as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def connect_stream(
        self,
        rtsp_url: str,
        stream_id: Optional[str] = None,
        on_tracking_result: Optional[Callable] = None,
        on_error: Optional[Callable] = None,
        buffer_size: int = 30,
    ) -> StreamConnection:
        """
        Connect to a stream and start processing.

        Args:
            rtsp_url: RTSP stream URL
            stream_id: Optional stream identifier (auto-generated if not provided)
            on_tracking_result: Optional callback for tracking results (sync or async)
            on_error: Optional callback for errors (sync or async)
            buffer_size: Decoder buffer size (default: 30)

        Returns:
            StreamConnection object for this stream

        Raises:
            RuntimeError: If manager is not initialized
            ValueError: If ``stream_id`` is already in use by another stream
            Exception: Any error raised while creating or starting the decoder
                is propagated after the partially-created connection is
                cleaned up. A connection timeout is NOT an error (the stream
                keeps trying in the background).
        """
        if not self.initialized:
            raise RuntimeError("Manager not initialized. Call initialize() first.")

        if stream_id is None:
            stream_id = self._generate_stream_id()
        elif stream_id in self.connections or stream_id in self._connecting:
            # Re-using an ID would overwrite the dict entry and the model
            # controller callback, leaking the old (still running) connection.
            raise ValueError(f"Stream ID {stream_id!r} is already in use")

        # Reserve the ID before the first await so concurrent calls skip it.
        self._connecting.add(stream_id)
        try:
            logger.info(f"Connecting to stream {stream_id}: {rtsp_url}")

            decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)

            # Create dedicated tracking controller for THIS stream
            # This prevents track accumulation across multiple streams
            tracking_controller = self.tracking_factory.create_controller(
                model_repository=self.model_repository,
                model_id=self.model_id_for_tracking,
                tracker_type="iou",
                max_age=30,
                min_confidence=0.5,
                iou_threshold=0.3,
            )
            logger.info(f"Created dedicated TrackingController for stream {stream_id}")

            connection = StreamConnection(
                stream_id=stream_id,
                decoder=decoder,
                model_controller=self.model_controller,
                tracking_controller=tracking_controller,
                poll_interval=self.poll_interval,
            )

            self.model_controller.register_callback(
                stream_id,
                connection._handle_inference_result
            )

            try:
                await connection.start()
            except BaseException:
                # start() failed or was cancelled: stop the decoder and drop the
                # callback so nothing is left running for this stream.
                try:
                    await connection.stop()
                except Exception:
                    logger.error(
                        f"Cleanup after failed start of stream {stream_id} failed",
                        exc_info=True,
                    )
                raise

            self.connections[stream_id] = connection
        finally:
            self._connecting.discard(stream_id)

        # Set up user callbacks if provided
        if on_tracking_result is not None:
            self._spawn_background(self._forward_results(connection, on_tracking_result))
        if on_error is not None:
            self._spawn_background(self._forward_errors(connection, on_error))

        logger.info(f"Stream {stream_id} connected successfully")
        return connection

    async def disconnect_stream(self, stream_id: str):
        """
        Disconnect and cleanup a stream. Unknown IDs are ignored.

        The stream is removed from ``connections`` before it is stopped, so a
        failing ``stop()`` cannot leave a half-stopped entry behind.

        Args:
            stream_id: Stream identifier to disconnect
        """
        connection = self.connections.pop(stream_id, None)
        if connection is None:
            return
        await connection.stop()
        logger.info(f"Stream {stream_id} disconnected")

    async def disconnect_all(self):
        """Disconnect all streams; a failure on one stream does not stop the rest."""
        logger.info("Disconnecting all streams...")
        for stream_id in list(self.connections):
            try:
                await self.disconnect_stream(stream_id)
            except Exception as e:
                logger.error(f"Error disconnecting stream {stream_id}: {e}", exc_info=True)

    async def shutdown(self):
        """Shutdown the manager and cleanup all resources"""
        logger.info("Shutting down StreamConnectionManager...")

        await self.disconnect_all()

        if self.model_controller:
            await self.model_controller.stop()

        # Note: Model repository cleanup is sync and may cause segfaults
        # Leaving cleanup to garbage collection for now

        self.initialized = False
        logger.info("StreamConnectionManager shutdown complete")

    @staticmethod
    async def _invoke_callback(callback: Callable, value) -> None:
        """Call *callback* with *value*; await the return value if it is awaitable.

        Handles plain functions, ``async def`` functions, and callables such
        as partials or objects whose call returns a coroutine.
        """
        outcome = callback(value)
        if inspect.isawaitable(outcome):
            await outcome

    async def _forward_results(self, connection: StreamConnection, callback: Callable):
        """
        Forward results from connection to user callback.

        A callback exception is logged and forwarding continues with the next
        result, so one failure does not silence the stream.

        Args:
            connection: StreamConnection to listen to
            callback: User callback (sync or async)
        """
        async for result in connection.tracking_results():
            try:
                await self._invoke_callback(callback, result)
            except Exception as e:
                logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)

    async def _forward_errors(self, connection: StreamConnection, callback: Callable):
        """
        Forward errors from connection to user callback.

        A callback exception is logged and forwarding continues with the next
        error.

        Args:
            connection: StreamConnection to listen to
            callback: User callback (sync or async)
        """
        async for error in connection.errors():
            try:
                await self._invoke_callback(callback, error)
            except Exception as e:
                logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics for all connections.

        Returns:
            Dictionary with manager and connection statistics
        """
        return {
            "manager": {
                "initialized": self.initialized,
                "gpu_id": self.gpu_id,
                "num_connections": len(self.connections),
                "batch_size": self.batch_size,
                "force_timeout": self.force_timeout,
                "poll_interval": self.poll_interval,
            },
            "model_controller": self.model_controller.get_stats() if self.model_controller else {},
            "connections": {
                stream_id: conn.get_stats()
                for stream_id, conn in self.connections.items()
            },
        }