batch processing/event driven
This commit is contained in:
parent
e71316ef3d
commit
dd57b5a246
7 changed files with 2673 additions and 2 deletions
566 services/stream_connection_manager.py Normal file

@@ -0,0 +1,566 @@
"""
|
||||
StreamConnectionManager - Async orchestration for stream processing with batched inference.
|
||||
|
||||
This module provides high-level connection management for multiple RTSP streams,
|
||||
coordinating decoders, batched inference, and tracking with an event-driven API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from .model_controller import ModelController
|
||||
from .stream_decoder import StreamDecoderFactory
|
||||
from .tracking_factory import TrackingFactory
|
||||
from .model_repository import TensorRTModelRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
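# Data flow per stream (a summary of the classes below):
#   StreamDecoder (background thread) -> StreamConnection._frame_poller()
#   -> ModelController.submit_frame() for batched inference
#   -> StreamConnection._handle_inference_result() -> tracking in an executor
#   -> TrackingResult on result_queue / user callbacks
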
class ConnectionStatus(Enum):
    """Status of a stream connection"""
    CONNECTING = "connecting"
    CONNECTED = "connected"
    DISCONNECTED = "disconnected"
    ERROR = "error"


@dataclass
class TrackingResult:
    """Result emitted to user callbacks"""
    stream_id: str
    timestamp: float
    tracked_objects: List  # List of TrackedObject from TrackingController
    detections: List  # Raw detections
    frame_shape: Tuple[int, int, int]
    metadata: Dict

class StreamConnection:
    """
    Represents a single stream connection with event emission.

    This class wraps a StreamDecoder, polls frames asynchronously, submits them
    to the ModelController for batched inference, runs tracking, and emits results
    via queues or callbacks.

    Args:
        stream_id: Unique identifier for this stream
        decoder: StreamDecoder instance
        model_controller: ModelController for batched inference
        tracking_controller: TrackingController for object tracking
        poll_interval: Frame polling interval in seconds (default: 0.01)
    """

    def __init__(
        self,
        stream_id: str,
        decoder,
        model_controller: ModelController,
        tracking_controller,
        poll_interval: float = 0.01,
    ):
        self.stream_id = stream_id
        self.decoder = decoder
        self.model_controller = model_controller
        self.tracking_controller = tracking_controller
        self.poll_interval = poll_interval

        self.status = ConnectionStatus.CONNECTING
        self.frame_count = 0
        self.last_frame_time = 0.0

        # Event emission
        self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue()
        self.error_queue: asyncio.Queue[Exception] = asyncio.Queue()

        # Tasks
        self.poller_task: Optional[asyncio.Task] = None
        self.running = False

    async def start(self):
        """Start the connection (decoder and frame polling)"""
        # Start decoder (runs in background thread)
        self.decoder.start()

        # Wait for initial connection (try for up to 10 seconds)
        max_wait = 10.0
        wait_interval = 0.5
        elapsed = 0.0

        while elapsed < max_wait:
            await asyncio.sleep(wait_interval)
            elapsed += wait_interval

            if self.decoder.is_connected():
                self.status = ConnectionStatus.CONNECTED
                logger.info(f"Stream {self.stream_id} connected after {elapsed:.1f}s")
                break
        else:
            # Timeout - but don't fail hard, let it try to connect in background
            logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
            self.status = ConnectionStatus.CONNECTING

        # Start frame polling task
        self.running = True
        self.poller_task = asyncio.create_task(self._frame_poller())

    async def stop(self):
        """Stop the connection and clean up"""
        logger.info(f"Stopping stream {self.stream_id}...")
        self.running = False

        if self.poller_task:
            self.poller_task.cancel()
            try:
                await self.poller_task
            except asyncio.CancelledError:
                pass

        # Stop decoder
        self.decoder.stop()

        # Unregister from model controller
        self.model_controller.unregister_callback(self.stream_id)

        self.status = ConnectionStatus.DISCONNECTED
        logger.info(f"Stream {self.stream_id} stopped")

    async def _frame_poller(self):
        """Poll frames from the threaded decoder and submit them to the model controller"""
        last_frame_ptr = None

        while self.running:
            try:
                # Poll frame from decoder (runs in thread)
                frame = self.decoder.get_latest_frame(rgb=True)

                # Only process new frames: comparing data_ptr() (the address of
                # the tensor's underlying buffer) avoids re-submitting the same
                # buffered frame on consecutive polls
                if frame is not None and frame.data_ptr() != last_frame_ptr:
                    last_frame_ptr = frame.data_ptr()
                    self.last_frame_time = time.time()
                    self.frame_count += 1

                    # Submit to model controller for batched inference
                    await self.model_controller.submit_frame(
                        stream_id=self.stream_id,
                        frame=frame,
                        metadata={
                            "frame_number": self.frame_count,
                            "shape": tuple(frame.shape),
                        }
                    )

                # Check decoder status
                if not self.decoder.is_connected():
                    if self.status == ConnectionStatus.CONNECTED:
                        logger.warning(f"Stream {self.stream_id} disconnected")
                        self.status = ConnectionStatus.DISCONNECTED
                    # Decoder will auto-reconnect, just update status
                    await asyncio.sleep(1.0)
                    if self.decoder.is_connected():
                        logger.info(f"Stream {self.stream_id} reconnected")
                        self.status = ConnectionStatus.CONNECTED

            except Exception as e:
                logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
                await self.error_queue.put(e)
                self.status = ConnectionStatus.ERROR

            # Sleep until next poll
            await asyncio.sleep(self.poll_interval)

    async def _handle_inference_result(self, result: Dict[str, Any]):
        """
        Callback invoked by the ModelController when inference is done.
        Runs tracking and emits the final result.

        Args:
            result: Inference result dictionary; this handler reads its
                "detections", "timestamp", and "metadata" keys
        """
        try:
            # Extract detections
            detections = result["detections"]

            # Tracking is sync, so run it in an executor to avoid blocking
            # the event loop
            loop = asyncio.get_running_loop()
            tracked_objects = await loop.run_in_executor(
                None,
                lambda: self._run_tracking_sync(detections)
            )

            # Create tracking result
            tracking_result = TrackingResult(
                stream_id=self.stream_id,
                timestamp=result["timestamp"],
                tracked_objects=tracked_objects,
                detections=detections,
                frame_shape=result["metadata"].get("shape"),
                metadata=result["metadata"],
            )

            # Emit to result queue
            await self.result_queue.put(tracking_result)

        except Exception as e:
            logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
            await self.error_queue.put(e)

    def _run_tracking_sync(self, detections):
        """
        Run tracking synchronously (called from executor).

        Args:
            detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]

        Returns:
            List of TrackedObject instances
        """
        # Use the TrackingController's internal tracking machinery directly:
        # we already have detections from batched inference, so we update the
        # tracks manually instead of letting the controller run the model.
        with self.tracking_controller._lock:
            self.tracking_controller._frame_count += 1

            # If no detections, just clean up stale tracks and return the current ones
            if len(detections) == 0:
                self.tracking_controller._cleanup_stale_tracks()
                return list(self.tracking_controller._tracks.values())

            # Run IoU tracking to associate detections with existing tracks;
            # each association is a (det_idx, track_id) pair, where
            # track_id == -1 means the detection matched no existing track
            associations = self.tracking_controller._iou_tracking(detections)

            # Update or create tracks
            for (det_idx, track_id), detection in zip(associations, detections):
                bbox = detection[:4].cpu().tolist()
                confidence = float(detection[4])
                class_id = int(detection[5]) if detection.shape[0] > 5 else 0

                if track_id == -1:
                    # Create new track
                    new_track = self.tracking_controller._create_track(
                        bbox, confidence, class_id, self.tracking_controller._frame_count
                    )
                    self.tracking_controller._tracks[new_track.track_id] = new_track
                else:
                    # Update existing track
                    self.tracking_controller._tracks[track_id].update(
                        bbox, confidence, self.tracking_controller._frame_count
                    )

            # Cleanup stale tracks
            self.tracking_controller._cleanup_stale_tracks()

            return list(self.tracking_controller._tracks.values())

    async def tracking_results(self) -> AsyncIterator[TrackingResult]:
        """
        Async generator for tracking results.

        Usage:
            async for result in connection.tracking_results():
                print(result.tracked_objects)

        Yields:
            TrackingResult objects as they become available
        """
        while self.running or not self.result_queue.empty():
            try:
                result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
                yield result
            except asyncio.TimeoutError:
                continue

    async def errors(self) -> AsyncIterator[Exception]:
        """
        Async generator for errors.

        Yields:
            Exception objects as they occur
        """
        while self.running or not self.error_queue.empty():
            try:
                error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
                yield error
            except asyncio.TimeoutError:
                continue

    def get_stats(self) -> Dict[str, Any]:
        """Get connection statistics"""
        return {
            "stream_id": self.stream_id,
            "status": self.status.value,
            "frame_count": self.frame_count,
            "last_frame_time": self.last_frame_time,
            "decoder_connected": self.decoder.is_connected(),
            "decoder_buffer_size": self.decoder.get_buffer_size(),
            "result_queue_size": self.result_queue.qsize(),
            "error_queue_size": self.error_queue.qsize(),
        }

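# Usage note (a sketch of the pull-style alternative to the callback wiring
# done by StreamConnectionManager below):
#
#     async for result in connection.tracking_results():
#         print(result.stream_id, len(result.tracked_objects))
#
# Both generators keep draining their queues after stop() is called, so
# results buffered at shutdown are not lost.
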
class StreamConnectionManager:
    """
    High-level manager for stream connections with batched inference.

    This manager coordinates multiple RTSP streams, batched model inference,
    and object tracking through an async event-driven API.

    Args:
        gpu_id: GPU device ID (default: 0)
        batch_size: Maximum batch size for inference (default: 16)
        force_timeout: Force buffer switch timeout in seconds (default: 0.05)
        poll_interval: Frame polling interval in seconds (default: 0.01)

    Example:
        manager = StreamConnectionManager(gpu_id=0, batch_size=16)
        await manager.initialize(model_path="yolov8n.trt", ...)
        connection = await manager.connect_stream(rtsp_url, on_tracking_result=callback)
        await asyncio.sleep(60)
        await manager.shutdown()
    """

    def __init__(
        self,
        gpu_id: int = 0,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        poll_interval: float = 0.01,
    ):
        self.gpu_id = gpu_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.poll_interval = poll_interval

        # Factories
        self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
        self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
        self.model_repository = TensorRTModelRepository(gpu_id=gpu_id)

        # Controllers
        self.model_controller: Optional[ModelController] = None
        self.tracking_controller = None

        # Connections
        self.connections: Dict[str, StreamConnection] = {}

        # State
        self.initialized = False

    async def initialize(
        self,
        model_path: str,
        model_id: str = "detector",
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
        num_contexts: int = 4,
    ):
        """
        Initialize the manager with a model.

        Args:
            model_path: Path to TensorRT model file
            model_id: Model identifier (default: "detector")
            preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
            postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
            num_contexts: Number of TensorRT execution contexts (default: 4)
        """
        logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")

        # Load model (blocking, so run it in an executor)
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self.model_repository.load_model(model_id, model_path, num_contexts=num_contexts)
        )
        logger.info(f"Loaded model {model_id} from {model_path}")

        # Create model controller
        self.model_controller = ModelController(
            model_repository=self.model_repository,
            model_id=model_id,
            batch_size=self.batch_size,
            force_timeout=self.force_timeout,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
        await self.model_controller.start()

        # Create tracking controller
        self.tracking_controller = self.tracking_factory.create_controller(
            model_repository=self.model_repository,
            model_id=model_id,
            tracker_type="iou",
        )
        logger.info("TrackingController created")

        self.initialized = True
        logger.info("StreamConnectionManager initialized successfully")

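    # Note on the processing hooks (an inference from usage, not a documented
    # contract): preprocess_fn is expected to map a decoded frame tensor to a
    # model input, and postprocess_fn to map raw model outputs to an (N, 6)
    # detection tensor [x1, y1, x2, y2, conf, class_id], the shape that
    # StreamConnection._run_tracking_sync consumes.
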
    async def connect_stream(
        self,
        rtsp_url: str,
        stream_id: Optional[str] = None,
        on_tracking_result: Optional[Callable] = None,
        on_error: Optional[Callable] = None,
        buffer_size: int = 30,
    ) -> StreamConnection:
        """
        Connect to a stream and start processing.

        Args:
            rtsp_url: RTSP stream URL
            stream_id: Optional stream identifier (auto-generated if not provided)
            on_tracking_result: Optional callback for tracking results (sync or async)
            on_error: Optional callback for errors (sync or async)
            buffer_size: Decoder buffer size (default: 30)

        Returns:
            StreamConnection object for this stream

        Raises:
            RuntimeError: If manager is not initialized
            ConnectionError: If stream connection fails
        """
        if not self.initialized:
            raise RuntimeError("Manager not initialized. Call initialize() first.")

        # Generate stream ID if not provided
        if stream_id is None:
            stream_id = f"stream_{len(self.connections)}"

        logger.info(f"Connecting to stream {stream_id}: {rtsp_url}")

        # Create decoder
        decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)

        # Create connection
        connection = StreamConnection(
            stream_id=stream_id,
            decoder=decoder,
            model_controller=self.model_controller,
            tracking_controller=self.tracking_controller,
            poll_interval=self.poll_interval,
        )

        # Register callback with model controller
        self.model_controller.register_callback(
            stream_id,
            connection._handle_inference_result
        )

        # Start connection
        await connection.start()

        # Store connection
        self.connections[stream_id] = connection

        # Set up user callbacks if provided
        if on_tracking_result:
            asyncio.create_task(self._forward_results(connection, on_tracking_result))

        if on_error:
            asyncio.create_task(self._forward_errors(connection, on_error))

        logger.info(f"Stream {stream_id} connected successfully")
        return connection

    async def disconnect_stream(self, stream_id: str):
        """
        Disconnect and clean up a stream.

        Args:
            stream_id: Stream identifier to disconnect
        """
        connection = self.connections.get(stream_id)
        if connection:
            await connection.stop()
            del self.connections[stream_id]
            logger.info(f"Stream {stream_id} disconnected")

    async def disconnect_all(self):
        """Disconnect all streams"""
        logger.info("Disconnecting all streams...")
        stream_ids = list(self.connections.keys())
        for stream_id in stream_ids:
            await self.disconnect_stream(stream_id)

    async def shutdown(self):
        """Shut down the manager and clean up all resources"""
        logger.info("Shutting down StreamConnectionManager...")

        # Disconnect all streams
        await self.disconnect_all()

        # Stop model controller
        if self.model_controller:
            await self.model_controller.stop()

        # Note: Model repository cleanup is sync and may cause segfaults;
        # leaving cleanup to garbage collection for now

        self.initialized = False
        logger.info("StreamConnectionManager shutdown complete")

    async def _forward_results(self, connection: StreamConnection, callback: Callable):
        """
        Forward results from connection to user callback.

        Args:
            connection: StreamConnection to listen to
            callback: User callback (sync or async)
        """
        try:
            async for result in connection.tracking_results():
                if asyncio.iscoroutinefunction(callback):
                    await callback(result)
                else:
                    callback(result)
        except Exception as e:
            logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)

    async def _forward_errors(self, connection: StreamConnection, callback: Callable):
        """
        Forward errors from connection to user callback.

        Args:
            connection: StreamConnection to listen to
            callback: User callback (sync or async)
        """
        try:
            async for error in connection.errors():
                if asyncio.iscoroutinefunction(callback):
                    await callback(error)
                else:
                    callback(error)
        except Exception as e:
            logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics for all connections.

        Returns:
            Dictionary with manager and connection statistics
        """
        return {
            "manager": {
                "initialized": self.initialized,
                "gpu_id": self.gpu_id,
                "num_connections": len(self.connections),
                "batch_size": self.batch_size,
                "force_timeout": self.force_timeout,
                "poll_interval": self.poll_interval,
            },
            "model_controller": self.model_controller.get_stats() if self.model_controller else {},
            "connections": {
                stream_id: conn.get_stats()
                for stream_id, conn in self.connections.items()
            },
        }
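A minimal end-to-end sketch of the API this commit adds, assuming the module lives in a services package as imported above; the engine path, stream URL, and callback below are illustrative placeholders, not part of the diff:

import asyncio

from services.stream_connection_manager import StreamConnectionManager

async def main():
    manager = StreamConnectionManager(gpu_id=0, batch_size=16)

    # Hypothetical engine path; pre/postprocess hooks are optional (default None)
    await manager.initialize(model_path="models/yolov8n.trt", model_id="detector")

    def on_result(result):
        # result is a TrackingResult dataclass
        print(f"[{result.stream_id}] {len(result.tracked_objects)} tracked objects")

    connection = await manager.connect_stream(
        "rtsp://example.com/stream",  # illustrative URL
        stream_id="cam0",
        on_tracking_result=on_result,
    )

    await asyncio.sleep(60)
    print(manager.get_stats())
    await manager.shutdown()

asyncio.run(main())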