""" StreamConnectionManager - Event-driven orchestration for stream processing with batched inference. This module provides high-level connection management for multiple RTSP streams, coordinating decoders, batched inference, and tracking with callbacks and threading. """ import logging import queue import threading import time from dataclasses import dataclass from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple import torch from .base_model_controller import BaseModelController from .model_repository import TensorRTModelRepository from .stream_decoder import StreamDecoderFactory from .tensorrt_model_controller import TensorRTModelController from .ultralytics_model_controller import UltralyticsModelController logger = logging.getLogger(__name__) class ConnectionStatus(Enum): """Status of a stream connection""" CONNECTING = "connecting" CONNECTED = "connected" DISCONNECTED = "disconnected" ERROR = "error" @dataclass class TrackingResult: """Result emitted to user callbacks""" stream_id: str timestamp: float tracked_objects: List # List of TrackedObject from TrackingController detections: List # Raw detections frame_shape: Tuple[int, int, int] frame_tensor: Optional[torch.Tensor] # GPU tensor of the frame (C, H, W) metadata: Dict class StreamConnection: """ Represents a single stream connection with event emission. This class wraps a StreamDecoder, polls frames in a thread, submits them to the ModelController for batched inference, runs tracking, and emits results via queues or callbacks. Args: stream_id: Unique identifier for this stream decoder: StreamDecoder instance model_controller: ModelController for batched inference tracking_controller: TrackingController for object tracking poll_interval: Frame polling interval in seconds (default: 0.01) """ def __init__( self, stream_id: str, decoder, model_controller: BaseModelController, tracking_controller, poll_interval: float = 0.01, ): self.stream_id = stream_id self.decoder = decoder self.model_controller = model_controller self.tracking_controller = tracking_controller self.poll_interval = poll_interval self.status = ConnectionStatus.CONNECTING self.frame_count = 0 self.last_frame_time = 0.0 # Event emission self.result_queue: queue.Queue[TrackingResult] = queue.Queue() self.error_queue: queue.Queue[Exception] = queue.Queue() # Event-driven state self.running = False def start(self): """Start the connection (decoder with frame callback)""" self.running = True # Register callback for frame events from decoder self.decoder.register_frame_callback(self._on_frame_decoded) # Start decoder (runs in background thread) self.decoder.start() # Wait for initial connection (try for up to 10 seconds) max_wait = 10.0 wait_interval = 0.5 elapsed = 0.0 while elapsed < max_wait: time.sleep(wait_interval) elapsed += wait_interval if self.decoder.is_connected(): self.status = ConnectionStatus.CONNECTED logger.info(f"Stream {self.stream_id} connected after {elapsed:.1f}s") break else: # Timeout - but don't fail hard, let it try to connect in background logger.warning( f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying..." 
            )
            self.status = ConnectionStatus.CONNECTING

    def stop(self):
        """Stop the connection and cleanup"""
        logger.info(f"Stopping stream {self.stream_id}...")
        self.running = False

        # Unregister frame callback
        self.decoder.unregister_frame_callback(self._on_frame_decoded)

        # Stop decoder
        self.decoder.stop()

        # Unregister from model controller
        self.model_controller.unregister_callback(self.stream_id)

        self.status = ConnectionStatus.DISCONNECTED
        logger.info(f"Stream {self.stream_id} stopped")

    def _on_frame_decoded(self, frame_ref):
        """
        Event handler called by decoder when a new frame is decoded.

        This is the event-driven replacement for polling.

        Args:
            frame_ref: FrameReference object containing the RGB frame tensor
        """
        if not self.running:
            # If not running, free the frame immediately
            frame_ref.free()
            return

        try:
            self.last_frame_time = time.time()
            self.frame_count += 1

            # CRITICAL: Clone the GPU tensor to decouple from decoder's frame buffer.
            # The decoder reuses frame buffer memory, so we must copy the tensor
            # before submitting to async batched inference to prevent race conditions
            # where the decoder overwrites memory while inference is still reading it.
            cloned_tensor = frame_ref.rgb_tensor.clone()

            # Submit to model controller for batched inference.
            # Pass the FrameReference in metadata so we can free it later.
            logger.debug(
                f"[{self.stream_id}] Submitting frame {self.frame_count} to model controller"
            )
            self.model_controller.submit_frame(
                stream_id=self.stream_id,
                frame=cloned_tensor,  # Use cloned tensor, not original
                metadata={
                    "frame_number": self.frame_count,
                    "shape": tuple(cloned_tensor.shape),
                    "frame_ref": frame_ref,  # Store reference for later cleanup
                },
            )
            logger.debug(
                f"[{self.stream_id}] Frame {self.frame_count} submitted, queue size: {len(self.model_controller.frame_queue)}"
            )

            # Update connection status based on decoder status
            if (
                self.decoder.is_connected()
                and self.status != ConnectionStatus.CONNECTED
            ):
                logger.info(f"Stream {self.stream_id} reconnected")
                self.status = ConnectionStatus.CONNECTED
            elif (
                not self.decoder.is_connected()
                and self.status == ConnectionStatus.CONNECTED
            ):
                logger.warning(f"Stream {self.stream_id} disconnected")
                self.status = ConnectionStatus.DISCONNECTED

        except Exception as e:
            logger.error(
                f"Error processing frame for {self.stream_id}: {e}", exc_info=True
            )
            self.error_queue.put(e)
            self.status = ConnectionStatus.ERROR
            # Free the frame on error
            frame_ref.free()

    def _handle_inference_result(self, result: Dict[str, Any]):
        """
        Callback invoked by ModelController when inference is done.

        Runs tracking and emits final result.

        Args:
            result: Inference result dictionary
        """
        frame_ref = None
        try:
            # Extract detections
            detections = result["detections"]

            # Get FrameReference from metadata (if present)
            frame_ref = result["metadata"].get("frame_ref")

            # Run tracking (synchronous) with frame shape for bbox scaling
            frame_shape = result["metadata"].get("shape")
            tracked_objects = self._run_tracking_sync(detections, frame_shape)

            # Get ORIGINAL frame tensor from the FrameReference (not the preprocessed
            # one in result["frame"]). The frame in result["frame"] is preprocessed
            # (resized, normalized); we need the original frame for visualization.
            frame_tensor = frame_ref.rgb_tensor if frame_ref else None

            # Create tracking result
            tracking_result = TrackingResult(
                stream_id=self.stream_id,
                timestamp=result["timestamp"],
                tracked_objects=tracked_objects,
                detections=detections,
                frame_shape=result["metadata"].get("shape"),
                frame_tensor=frame_tensor,  # Original frame, not preprocessed
                metadata=result["metadata"],
            )

            # Emit to result queue
            self.result_queue.put(tracking_result)

        except Exception as e:
            logger.error(
                f"Error handling inference result for {self.stream_id}: {e}",
                exc_info=True,
            )
            self.error_queue.put(e)
        finally:
            # Free the frame reference - this is the last point in the pipeline
            if frame_ref is not None:
                frame_ref.free()

    def _run_tracking_sync(self, detections, frame_shape=None, min_confidence=0.7):
        """
        Run tracking synchronously (called from the inference result callback).

        Args:
            detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
            frame_shape: Original frame shape (C, H, W) for scaling bboxes
            min_confidence: Minimum confidence threshold for detections

        Returns:
            List of TrackedObject instances
        """
        # Convert tensor detections to Detection objects, filtering by confidence
        from .tracking_controller import Detection

        detection_list = []
        for det in detections:
            confidence = float(det[4])

            # Filter by confidence threshold (prevents track accumulation)
            if confidence < min_confidence:
                continue

            detection_list.append(
                Detection(
                    bbox=det[:4].cpu().tolist(),
                    confidence=confidence,
                    class_id=int(det[5]) if det.shape[0] > 5 else 0,
                    class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown",
                )
            )

        # Update tracker with detections (will scale bboxes to frame space)
        return self.tracking_controller.update(detection_list, frame_shape=frame_shape)

    def tracking_results(self):
        """
        Generator for tracking results (blocking iterator).

        Usage:
            for result in connection.tracking_results():
                print(result.tracked_objects)

        Yields:
            TrackingResult objects as they become available
        """
        while self.running or not self.result_queue.empty():
            try:
                result = self.result_queue.get(timeout=1.0)
                yield result
            except queue.Empty:
                continue

    def errors(self):
        """
        Generator for errors (blocking iterator).

        Yields:
            Exception objects as they occur
        """
        while self.running or not self.error_queue.empty():
            try:
                error = self.error_queue.get(timeout=1.0)
                yield error
            except queue.Empty:
                continue

    def get_stats(self) -> Dict[str, Any]:
        """Get connection statistics"""
        return {
            "stream_id": self.stream_id,
            "status": self.status.value,
            "frame_count": self.frame_count,
            "last_frame_time": self.last_frame_time,
            "decoder_connected": self.decoder.is_connected(),
            "decoder_buffer_size": self.decoder.get_buffer_size(),
            "result_queue_size": self.result_queue.qsize(),
            "error_queue_size": self.error_queue.qsize(),
        }


class StreamConnectionManager:
    """
    High-level manager for stream connections with batched inference.

    This manager coordinates multiple RTSP streams, batched model inference,
    and object tracking through a threaded, event-driven API.

    Args:
        gpu_id: GPU device ID (default: 0)
        batch_size: Maximum batch size for inference (default: 16)
        max_queue_size: Maximum frames in queue before dropping (default: 100)
        poll_interval: Frame polling interval in seconds (default: 0.01)
        enable_pt_conversion: Enable automatic .pt to TensorRT conversion (default: True)
        backend: Inference backend, "tensorrt" or "ultralytics" (default: "tensorrt")

    Example:
        manager = StreamConnectionManager(gpu_id=0, batch_size=16)
        manager.initialize(model_path="yolov8n.trt", ...)
        connection = manager.connect_stream(rtsp_url, on_tracking_result=callback)
        time.sleep(60)
        manager.shutdown()
    """

    def __init__(
        self,
        gpu_id: int = 0,
        batch_size: int = 16,
        max_queue_size: int = 100,
        poll_interval: float = 0.01,
        enable_pt_conversion: bool = True,
        backend: str = "tensorrt",  # "tensorrt" or "ultralytics"
    ):
        self.gpu_id = gpu_id
        self.batch_size = batch_size
        self.max_queue_size = max_queue_size
        self.poll_interval = poll_interval
        self.backend = backend.lower()

        # Factories
        self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)

        # Initialize inference engine based on backend
        self.inference_engine = None
        self.model_repository = None  # Legacy - will be removed

        if self.backend == "ultralytics":
            # Use Ultralytics native YOLO inference
            from .inference_engine import UltralyticsEngine

            self.inference_engine = UltralyticsEngine()
            logger.info("Using Ultralytics inference engine")
        else:
            # Use native TensorRT inference
            self.model_repository = TensorRTModelRepository(
                gpu_id=gpu_id, enable_pt_conversion=enable_pt_conversion
            )
            logger.info("Using native TensorRT inference engine")

        # Controllers
        self.model_controller = (
            None  # Will be TensorRTModelController or UltralyticsModelController
        )

        # Connections
        self.connections: Dict[str, StreamConnection] = {}

        # State
        self.initialized = False

    def initialize(
        self,
        model_path: str,
        model_id: str = "detector",
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
        num_contexts: int = 4,
        pt_input_shapes: Optional[Dict] = None,
        pt_precision: Optional[Any] = None,
        **pt_conversion_kwargs,
    ):
        """
        Initialize the manager with a model.

        Supports transparent loading of .pt (YOLO), .engine, and .trt files.
        For Ultralytics YOLO models (.pt), metadata is auto-detected - no manual
        input_shapes or precision needed! Non-YOLO models still require input_shapes.

        Args:
            model_path: Path to model file (.trt, .engine, .pt, .pth)
                - .engine: Ultralytics native format (recommended)
                - .pt: Auto-converts to .engine (YOLO models only)
                - .trt: Raw TensorRT engine
            model_id: Model identifier (default: "detector")
            preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
            postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
            num_contexts: Number of TensorRT execution contexts (default: 4)
            pt_input_shapes: [Optional] Only required for non-YOLO PyTorch models;
                YOLO models auto-detect from embedded metadata
            pt_precision: [Optional] Precision for PT conversion (auto-detected for YOLO)
            **pt_conversion_kwargs: Additional PT conversion arguments

        Example:
            # YOLO model - no manual parameters needed:
            manager.initialize(
                model_path="model.pt",  # or .engine
                preprocess_fn=YOLOv8Utils.preprocess,
                postprocess_fn=YOLOv8Utils.postprocess
            )
        """
        logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
        logger.info(f"Backend: {self.backend}")

        # Initialize engine based on backend
        if self.backend == "ultralytics":
            # Use Ultralytics native inference
            logger.info("Initializing Ultralytics YOLO engine...")
            device = torch.device(f"cuda:{self.gpu_id}")

            metadata = self.inference_engine.initialize(
                model_path=model_path,
                device=device,
                batch=self.batch_size,
                half=False,  # Use FP32 for now
                imgsz=640,
                **pt_conversion_kwargs,
            )
            logger.info(f"Ultralytics engine initialized: {metadata}")

            # Create Ultralytics model controller
            self.model_controller = UltralyticsModelController(
                inference_engine=self.inference_engine,
                model_id=model_id,
                batch_size=self.batch_size,
                max_queue_size=self.max_queue_size,
                preprocess_fn=preprocess_fn,
                postprocess_fn=postprocess_fn,
            )
            self.model_controller.start()
        else:
            # Use native TensorRT with model repository
            logger.info("Initializing TensorRT engine...")
            self.model_repository.load_model(
                model_id,
                model_path,
                num_contexts=num_contexts,
                pt_input_shapes=pt_input_shapes,
                pt_precision=pt_precision,
                **pt_conversion_kwargs,
            )
            logger.info(f"Loaded model {model_id} from {model_path}")

            # Create TensorRT model controller
            self.model_controller = TensorRTModelController(
                model_repository=self.model_repository,
                model_id=model_id,
                batch_size=self.batch_size,
                max_queue_size=self.max_queue_size,
                preprocess_fn=preprocess_fn,
                postprocess_fn=postprocess_fn,
            )
            self.model_controller.start()

        # Don't create a shared tracking controller here.
        # Each stream will get its own tracking controller to avoid track accumulation.
        self.tracking_controller = None
        self.model_id_for_tracking = model_id  # Store for later use

        self.initialized = True
        logger.info("StreamConnectionManager initialized successfully")

    def connect_stream(
        self,
        rtsp_url: str,
        stream_id: Optional[str] = None,
        on_tracking_result: Optional[Callable] = None,
        on_error: Optional[Callable] = None,
        buffer_size: int = 30,
    ) -> StreamConnection:
        """
        Connect to a stream and start processing.

        Args:
            rtsp_url: RTSP stream URL
            stream_id: Optional stream identifier (auto-generated if not provided)
            on_tracking_result: Optional callback for tracking results (synchronous)
            on_error: Optional callback for errors (synchronous)
            buffer_size: Decoder buffer size (default: 30)

        Returns:
            StreamConnection object for this stream

        Raises:
            RuntimeError: If manager is not initialized
            ConnectionError: If stream connection fails
        """
        if not self.initialized:
            raise RuntimeError("Manager not initialized. Call initialize() first.")
Call initialize() first.") # Generate stream ID if not provided if stream_id is None: stream_id = f"stream_{len(self.connections)}" logger.info(f"Connecting to stream {stream_id}: {rtsp_url}") # Create decoder decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size) # Create lightweight tracker (NO model_repository dependency!) from .tracking_controller import ObjectTracker tracking_controller = ObjectTracker( gpu_id=self.gpu_id, tracker_type="iou", max_age=30, iou_threshold=0.3, class_names=None, # TODO: pass class names if available ) logger.info(f"Created lightweight ObjectTracker for stream {stream_id}") # Create connection connection = StreamConnection( stream_id=stream_id, decoder=decoder, model_controller=self.model_controller, tracking_controller=tracking_controller, poll_interval=self.poll_interval, ) # Register callback with model controller self.model_controller.register_callback( stream_id, connection._handle_inference_result ) # Start connection connection.start() # Store connection self.connections[stream_id] = connection # Set up user callbacks if provided (run in separate threads) if on_tracking_result: threading.Thread( target=self._forward_results, args=(connection, on_tracking_result), daemon=True, ).start() if on_error: threading.Thread( target=self._forward_errors, args=(connection, on_error), daemon=True ).start() logger.info(f"Stream {stream_id} connected successfully") return connection def disconnect_stream(self, stream_id: str): """ Disconnect and cleanup a stream. Args: stream_id: Stream identifier to disconnect """ connection = self.connections.get(stream_id) if connection: connection.stop() del self.connections[stream_id] logger.info(f"Stream {stream_id} disconnected") def disconnect_all(self): """Disconnect all streams""" logger.info("Disconnecting all streams...") stream_ids = list(self.connections.keys()) for stream_id in stream_ids: self.disconnect_stream(stream_id) def shutdown(self): """Shutdown the manager and cleanup all resources""" logger.info("Shutting down StreamConnectionManager...") # Disconnect all streams self.disconnect_all() # Stop model controller if self.model_controller: self.model_controller.stop() # Note: Model repository cleanup is sync and may cause segfaults # Leaving cleanup to garbage collection for now self.initialized = False logger.info("StreamConnectionManager shutdown complete") def _forward_results(self, connection: StreamConnection, callback: Callable): """ Forward results from connection to user callback. Args: connection: StreamConnection to listen to callback: User callback (synchronous) """ try: for result in connection.tracking_results(): callback(result) except Exception as e: logger.error( f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True, ) def _forward_errors(self, connection: StreamConnection, callback: Callable): """ Forward errors from connection to user callback. Args: connection: StreamConnection to listen to callback: User callback (synchronous) """ try: for error in connection.errors(): callback(error) except Exception as e: logger.error( f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True, ) def get_stats(self) -> Dict[str, Any]: """ Get statistics for all connections. 

        Returns:
            Dictionary with manager and connection statistics
        """
        return {
            "manager": {
                "initialized": self.initialized,
                "gpu_id": self.gpu_id,
                "num_connections": len(self.connections),
                "batch_size": self.batch_size,
                "max_queue_size": self.max_queue_size,
                "poll_interval": self.poll_interval,
            },
            "model_controller": self.model_controller.get_stats()
            if self.model_controller
            else {},
            "connections": {
                stream_id: conn.get_stats()
                for stream_id, conn in self.connections.items()
            },
        }
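

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library's public API):
# shows how the manager, a stream connection, and the result callback above
# fit together. The model path and RTSP URL below are hypothetical
# placeholders, and the preprocess/postprocess helpers (e.g. YOLOv8Utils)
# mentioned in the docstrings are assumed to come from the surrounding
# project rather than this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = StreamConnectionManager(gpu_id=0, batch_size=16, backend="tensorrt")
    manager.initialize(
        model_path="models/yolov8n.pt",  # hypothetical path
        model_id="detector",
        # preprocess_fn / postprocess_fn would normally be supplied here,
        # e.g. YOLOv8Utils.preprocess / YOLOv8Utils.postprocess
    )

    def on_result(result: TrackingResult):
        # Invoked from a daemon forwarding thread for each batched inference result
        print(
            f"[{result.stream_id}] frame={result.metadata.get('frame_number')} "
            f"tracks={len(result.tracked_objects)}"
        )

    connection = manager.connect_stream(
        rtsp_url="rtsp://example.com/stream1",  # hypothetical URL
        stream_id="cam0",
        on_tracking_result=on_result,
        on_error=lambda e: print(f"[cam0] error: {e}"),
    )

    try:
        # Let the pipeline run for a minute, then report statistics
        time.sleep(60)
        print(manager.get_stats())
    finally:
        manager.shutdown()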