""" StreamConnectionManager - Async orchestration for stream processing with batched inference. This module provides high-level connection management for multiple RTSP streams, coordinating decoders, batched inference, and tracking with an event-driven API. """ import asyncio import time from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List from dataclasses import dataclass from enum import Enum import logging import torch from .model_controller import ModelController from .stream_decoder import StreamDecoderFactory from .model_repository import TensorRTModelRepository logger = logging.getLogger(__name__) class ConnectionStatus(Enum): """Status of a stream connection""" CONNECTING = "connecting" CONNECTED = "connected" DISCONNECTED = "disconnected" ERROR = "error" @dataclass class TrackingResult: """Result emitted to user callbacks""" stream_id: str timestamp: float tracked_objects: List # List of TrackedObject from TrackingController detections: List # Raw detections frame_shape: Tuple[int, int, int] metadata: Dict class StreamConnection: """ Represents a single stream connection with event emission. This class wraps a StreamDecoder, polls frames asynchronously, submits them to the ModelController for batched inference, runs tracking, and emits results via queues or callbacks. Args: stream_id: Unique identifier for this stream decoder: StreamDecoder instance model_controller: ModelController for batched inference tracking_controller: TrackingController for object tracking poll_interval: Frame polling interval in seconds (default: 0.01) """ def __init__( self, stream_id: str, decoder, model_controller: ModelController, tracking_controller, poll_interval: float = 0.01, ): self.stream_id = stream_id self.decoder = decoder self.model_controller = model_controller self.tracking_controller = tracking_controller self.poll_interval = poll_interval self.status = ConnectionStatus.CONNECTING self.frame_count = 0 self.last_frame_time = 0.0 # Event emission self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue() self.error_queue: asyncio.Queue[Exception] = asyncio.Queue() # Tasks self.poller_task: Optional[asyncio.Task] = None self.running = False async def start(self): """Start the connection (decoder and frame polling)""" # Start decoder (runs in background thread) self.decoder.start() # Wait for initial connection (try for up to 10 seconds) max_wait = 10.0 wait_interval = 0.5 elapsed = 0.0 while elapsed < max_wait: await asyncio.sleep(wait_interval) elapsed += wait_interval if self.decoder.is_connected(): self.status = ConnectionStatus.CONNECTED logger.info(f"Stream {self.stream_id} connected after {elapsed:.1f}s") break else: # Timeout - but don't fail hard, let it try to connect in background logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...") self.status = ConnectionStatus.CONNECTING # Start frame polling task self.running = True self.poller_task = asyncio.create_task(self._frame_poller()) async def stop(self): """Stop the connection and cleanup""" logger.info(f"Stopping stream {self.stream_id}...") self.running = False if self.poller_task: self.poller_task.cancel() try: await self.poller_task except asyncio.CancelledError: pass # Stop decoder self.decoder.stop() # Unregister from model controller self.model_controller.unregister_callback(self.stream_id) self.status = ConnectionStatus.DISCONNECTED logger.info(f"Stream {self.stream_id} stopped") async def _frame_poller(self): """Poll frames from threaded decoder and 
    async def _frame_poller(self):
        """Poll frames from the threaded decoder and submit them to the model controller"""
        last_decoder_frame_count = -1

        while self.running:
            try:
                # Get current decoder frame count (no data transfer, just counter)
                decoder_frame_count = self.decoder.get_frame_count()

                # Check if decoder has a new frame (avoid reprocessing same frame)
                if decoder_frame_count > last_decoder_frame_count:
                    # Poll frame from decoder (zero-copy - stays in VRAM)
                    frame = self.decoder.get_latest_frame(rgb=True)

                    if frame is not None:
                        last_decoder_frame_count = decoder_frame_count
                        self.last_frame_time = time.time()
                        self.frame_count += 1

                        # Submit to model controller for batched inference
                        await self.model_controller.submit_frame(
                            stream_id=self.stream_id,
                            frame=frame,
                            metadata={
                                "frame_number": self.frame_count,
                                "shape": tuple(frame.shape),
                            }
                        )

                # Check decoder status
                if not self.decoder.is_connected():
                    if self.status == ConnectionStatus.CONNECTED:
                        logger.warning(f"Stream {self.stream_id} disconnected")
                        self.status = ConnectionStatus.DISCONNECTED

                    # Decoder will auto-reconnect, just update status
                    await asyncio.sleep(1.0)
                    if self.decoder.is_connected():
                        logger.info(f"Stream {self.stream_id} reconnected")
                        self.status = ConnectionStatus.CONNECTED

            except Exception as e:
                logger.error(f"Error in frame poller for {self.stream_id}: {e}", exc_info=True)
                await self.error_queue.put(e)
                self.status = ConnectionStatus.ERROR

            # Sleep until next poll
            await asyncio.sleep(self.poll_interval)

    async def _handle_inference_result(self, result: Dict[str, Any]):
        """
        Callback invoked by ModelController when inference is done.
        Runs tracking and emits the final result.

        Args:
            result: Inference result dictionary
        """
        try:
            # Extract detections
            detections = result["detections"]

            # Run tracking (this is sync, so run in executor)
            loop = asyncio.get_running_loop()
            tracked_objects = await loop.run_in_executor(
                None,
                lambda: self._run_tracking_sync(detections)
            )

            # Create tracking result
            tracking_result = TrackingResult(
                stream_id=self.stream_id,
                timestamp=result["timestamp"],
                tracked_objects=tracked_objects,
                detections=detections,
                frame_shape=result["metadata"].get("shape"),
                metadata=result["metadata"],
            )

            # Emit to result queue
            await self.result_queue.put(tracking_result)

        except Exception as e:
            logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
            await self.error_queue.put(e)

    def _run_tracking_sync(self, detections, min_confidence=0.7):
        """
        Run tracking synchronously (called from executor).

        Args:
            detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
            min_confidence: Minimum confidence threshold for detections

        Returns:
            List of TrackedObject instances
        """
        # Convert tensor detections to Detection objects, filtering by confidence
        from .tracking_controller import Detection

        detection_list = []
        for det in detections:
            confidence = float(det[4])

            # Filter by confidence threshold (prevents track accumulation)
            if confidence < min_confidence:
                continue

            detection_list.append(Detection(
                bbox=det[:4].cpu().tolist(),
                confidence=confidence,
                class_id=int(det[5]) if det.shape[0] > 5 else 0,
                class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
            ))

        # Update tracker with detections (lightweight, no model dependency!)
        return self.tracking_controller.update(detection_list)
    async def tracking_results(self) -> AsyncIterator[TrackingResult]:
        """
        Async generator for tracking results.

        Usage:
            async for result in connection.tracking_results():
                print(result.tracked_objects)

        Yields:
            TrackingResult objects as they become available
        """
        while self.running or not self.result_queue.empty():
            try:
                result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0)
                yield result
            except asyncio.TimeoutError:
                continue

    async def errors(self) -> AsyncIterator[Exception]:
        """
        Async generator for errors.

        Yields:
            Exception objects as they occur
        """
        while self.running or not self.error_queue.empty():
            try:
                error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0)
                yield error
            except asyncio.TimeoutError:
                continue

    def get_stats(self) -> Dict[str, Any]:
        """Get connection statistics"""
        return {
            "stream_id": self.stream_id,
            "status": self.status.value,
            "frame_count": self.frame_count,
            "last_frame_time": self.last_frame_time,
            "decoder_connected": self.decoder.is_connected(),
            "decoder_buffer_size": self.decoder.get_buffer_size(),
            "result_queue_size": self.result_queue.qsize(),
            "error_queue_size": self.error_queue.qsize(),
        }


class StreamConnectionManager:
    """
    High-level manager for stream connections with batched inference.

    This manager coordinates multiple RTSP streams, batched model inference,
    and object tracking through an async event-driven API.

    Args:
        gpu_id: GPU device ID (default: 0)
        batch_size: Maximum batch size for inference (default: 16)
        force_timeout: Force buffer switch timeout in seconds (default: 0.05)
        poll_interval: Frame polling interval in seconds (default: 0.01)
        enable_pt_conversion: Enable automatic PyTorch-to-TensorRT conversion (default: True)

    Example:
        manager = StreamConnectionManager(gpu_id=0, batch_size=16)
        await manager.initialize(model_path="yolov8n.trt", ...)
        connection = await manager.connect_stream(rtsp_url, on_tracking_result=callback)
        await asyncio.sleep(60)
        await manager.shutdown()
    """

    def __init__(
        self,
        gpu_id: int = 0,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        poll_interval: float = 0.01,
        enable_pt_conversion: bool = True,
    ):
        self.gpu_id = gpu_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.poll_interval = poll_interval

        # Factories
        self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
        self.model_repository = TensorRTModelRepository(
            gpu_id=gpu_id,
            enable_pt_conversion=enable_pt_conversion
        )

        # Controllers
        self.model_controller: Optional[ModelController] = None

        # Connections
        self.connections: Dict[str, StreamConnection] = {}

        # State
        self.initialized = False
    async def initialize(
        self,
        model_path: str,
        model_id: str = "detector",
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
        num_contexts: int = 4,
        pt_input_shapes: Optional[Dict] = None,
        pt_precision: Optional[Any] = None,
        **pt_conversion_kwargs
    ):
        """
        Initialize the manager with a model.

        Args:
            model_path: Path to TensorRT or PyTorch model file (.trt, .pt, .pth)
            model_id: Model identifier (default: "detector")
            preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
            postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
            num_contexts: Number of TensorRT execution contexts (default: 4)
            pt_input_shapes: Required for PT files - dict of input shapes
            pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
            **pt_conversion_kwargs: Additional PT conversion arguments
        """
        logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")

        # Load model (blocking, so run in executor)
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self.model_repository.load_model(
                model_id,
                model_path,
                num_contexts=num_contexts,
                pt_input_shapes=pt_input_shapes,
                pt_precision=pt_precision,
                **pt_conversion_kwargs
            )
        )
        logger.info(f"Loaded model {model_id} from {model_path}")

        # Create model controller
        self.model_controller = ModelController(
            model_repository=self.model_repository,
            model_id=model_id,
            batch_size=self.batch_size,
            force_timeout=self.force_timeout,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
        await self.model_controller.start()

        # Don't create a shared tracking controller here.
        # Each stream gets its own tracking controller to avoid track accumulation.
        self.tracking_controller = None
        self.model_id_for_tracking = model_id  # Store for later use

        self.initialized = True
        logger.info("StreamConnectionManager initialized successfully")
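    # Illustrative initialize() call for a PyTorch checkpoint (a sketch: the model
    # path, the "images" input name/shape, and the YOLOv8Utils helpers are
    # assumptions drawn from the Args docs above, not guaranteed names):
    #
    #     await manager.initialize(
    #         model_path="models/yolov8n.pt",                # hypothetical .pt file
    #         pt_input_shapes={"images": (1, 3, 640, 640)},  # assumed input name/shape
    #         pt_precision=torch.float16,
    #         preprocess_fn=YOLOv8Utils.preprocess,          # example from the Args docs
    #         postprocess_fn=YOLOv8Utils.postprocess,
    #     )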
    async def connect_stream(
        self,
        rtsp_url: str,
        stream_id: Optional[str] = None,
        on_tracking_result: Optional[Callable] = None,
        on_error: Optional[Callable] = None,
        buffer_size: int = 30,
    ) -> StreamConnection:
        """
        Connect to a stream and start processing.

        Args:
            rtsp_url: RTSP stream URL
            stream_id: Optional stream identifier (auto-generated if not provided)
            on_tracking_result: Optional callback for tracking results (sync or async)
            on_error: Optional callback for errors (sync or async)
            buffer_size: Decoder buffer size (default: 30)

        Returns:
            StreamConnection object for this stream

        Raises:
            RuntimeError: If manager is not initialized
            ConnectionError: If stream connection fails
        """
        if not self.initialized:
            raise RuntimeError("Manager not initialized. Call initialize() first.")

        # Generate stream ID if not provided
        if stream_id is None:
            stream_id = f"stream_{len(self.connections)}"

        logger.info(f"Connecting to stream {stream_id}: {rtsp_url}")

        # Create decoder
        decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)

        # Create lightweight tracker (NO model_repository dependency!)
        from .tracking_controller import ObjectTracker
        tracking_controller = ObjectTracker(
            gpu_id=self.gpu_id,
            tracker_type="iou",
            max_age=30,
            iou_threshold=0.3,
            class_names=None  # TODO: pass class names if available
        )
        logger.info(f"Created lightweight ObjectTracker for stream {stream_id}")

        # Create connection
        connection = StreamConnection(
            stream_id=stream_id,
            decoder=decoder,
            model_controller=self.model_controller,
            tracking_controller=tracking_controller,
            poll_interval=self.poll_interval,
        )

        # Register callback with model controller
        self.model_controller.register_callback(
            stream_id,
            connection._handle_inference_result
        )

        # Start connection
        await connection.start()

        # Store connection
        self.connections[stream_id] = connection

        # Set up user callbacks if provided
        if on_tracking_result:
            asyncio.create_task(self._forward_results(connection, on_tracking_result))
        if on_error:
            asyncio.create_task(self._forward_errors(connection, on_error))

        logger.info(f"Stream {stream_id} connected successfully")
        return connection

    async def disconnect_stream(self, stream_id: str):
        """
        Disconnect and clean up a stream.

        Args:
            stream_id: Stream identifier to disconnect
        """
        connection = self.connections.get(stream_id)
        if connection:
            await connection.stop()
            del self.connections[stream_id]
            logger.info(f"Stream {stream_id} disconnected")

    async def disconnect_all(self):
        """Disconnect all streams"""
        logger.info("Disconnecting all streams...")
        stream_ids = list(self.connections.keys())
        for stream_id in stream_ids:
            await self.disconnect_stream(stream_id)

    async def shutdown(self):
        """Shut down the manager and clean up all resources"""
        logger.info("Shutting down StreamConnectionManager...")

        # Disconnect all streams
        await self.disconnect_all()

        # Stop model controller
        if self.model_controller:
            await self.model_controller.stop()

        # Note: Model repository cleanup is sync and may cause segfaults.
        # Leaving cleanup to garbage collection for now.

        self.initialized = False
        logger.info("StreamConnectionManager shutdown complete")

    async def _forward_results(self, connection: StreamConnection, callback: Callable):
        """
        Forward results from a connection to a user callback.

        Args:
            connection: StreamConnection to listen to
            callback: User callback (sync or async)
        """
        try:
            async for result in connection.tracking_results():
                if asyncio.iscoroutinefunction(callback):
                    await callback(result)
                else:
                    callback(result)
        except Exception as e:
            logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)

    async def _forward_errors(self, connection: StreamConnection, callback: Callable):
        """
        Forward errors from a connection to a user callback.

        Args:
            connection: StreamConnection to listen to
            callback: User callback (sync or async)
        """
        try:
            async for error in connection.errors():
                if asyncio.iscoroutinefunction(callback):
                    await callback(error)
                else:
                    callback(error)
        except Exception as e:
            logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics for all connections.

        Returns:
            Dictionary with manager and connection statistics
        """
        return {
            "manager": {
                "initialized": self.initialized,
                "gpu_id": self.gpu_id,
                "num_connections": len(self.connections),
                "batch_size": self.batch_size,
                "force_timeout": self.force_timeout,
                "poll_interval": self.poll_interval,
            },
            "model_controller": self.model_controller.get_stats() if self.model_controller else {},
            "connections": {
                stream_id: conn.get_stats()
                for stream_id, conn in self.connections.items()
            },
        }
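
# ---------------------------------------------------------------------------
# Minimal usage sketch. This mirrors the Example in the StreamConnectionManager
# docstring; the model path and RTSP URL are placeholders and the callbacks are
# illustrative. It is not part of the module's API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo():
        manager = StreamConnectionManager(gpu_id=0, batch_size=16)
        await manager.initialize(model_path="models/detector.trt")  # placeholder path

        def on_result(result: TrackingResult):
            # Sync callbacks are supported; async callbacks work as well.
            print(f"[{result.stream_id}] {len(result.tracked_objects)} tracked objects")

        def on_error(error: Exception):
            print(f"stream error: {error}")

        connection = await manager.connect_stream(
            "rtsp://example.invalid/stream",  # placeholder URL
            stream_id="demo",
            on_tracking_result=on_result,
            on_error=on_error,
        )

        await asyncio.sleep(60)  # let the pipeline run for a while
        print(connection.get_stats())
        await manager.shutdown()

    asyncio.run(_demo())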