From dd57b5a24673805f01d1a93a7ead9b3b04fbd915 Mon Sep 17 00:00:00 2001 From: Siwat Sirichai Date: Sun, 9 Nov 2025 18:43:56 +0700 Subject: [PATCH] batch processing/event driven --- EVENT_DRIVEN_DESIGN.md | 1108 +++++++++++++++++++++++++ services/__init__.py | 8 + services/model_controller.py | 499 +++++++++++ services/stream_connection_manager.py | 566 +++++++++++++ test_event_driven.py | 373 +++++++++ test_event_driven_quick.py | 117 +++ test_tracking_realtime.py | 4 +- 7 files changed, 2673 insertions(+), 2 deletions(-) create mode 100644 EVENT_DRIVEN_DESIGN.md create mode 100644 services/model_controller.py create mode 100644 services/stream_connection_manager.py create mode 100644 test_event_driven.py create mode 100755 test_event_driven_quick.py diff --git a/EVENT_DRIVEN_DESIGN.md b/EVENT_DRIVEN_DESIGN.md new file mode 100644 index 0000000..45d908a --- /dev/null +++ b/EVENT_DRIVEN_DESIGN.md @@ -0,0 +1,1108 @@ +# Event-Driven Stream Processing Architecture with Batching + +## Overview + +This document describes the AsyncIO-based event-driven architecture for connecting stream decoders to models and tracking, with support for batched inference using ping-pong circular buffers. + +## Architecture Diagram + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ StreamConnectionManager │ +│ - Manages multiple stream connections │ +│ - Routes events to user callbacks/generators │ +│ - Coordinates ModelController and TrackingController │ +└─────────────────────────────────────────────────────────────────┘ + │ + ├──────────────────┬──────────────────┐ + ▼ ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ + │ StreamConnection│ │ StreamConnection│ │ StreamConnection│ + │ (Stream 1) │ │ (Stream 2) │ │ (Stream N) │ + │ │ │ │ │ │ + │ - StreamDecoder │ │ - StreamDecoder │ │ - StreamDecoder │ + │ - Frame Poller │ │ - Frame Poller │ │ - Frame Poller │ + │ - Event Emitter │ │ - Event Emitter │ │ - Event Emitter │ + └─────────────────┘ └─────────────────┘ └─────────────────┘ + │ │ │ + └──────────────────┴──────────────────┘ + │ + ▼ + ┌─────────────────────────────────────┐ + │ ModelController │ + │ ┌────────────┐ ┌────────────┐ │ + │ │ Buffer A │ │ Buffer B │ │ + │ │ (Active) │ │(Processing)│ │ + │ │ [frame1] │ │ [frame9] │ │ + │ │ [frame2] │ │ [frame10] │ │ + │ │ [frame3] │ │ [...] │ │ + │ └────────────┘ └────────────┘ │ + │ │ + │ - Batch accumulation │ + │ - Force timeout monitor │ + │ - Ping-pong switching │ + └─────────────────────────────────────┘ + │ + ┌────────────┴────────────┐ + ▼ ▼ + ┌─────────────────────┐ ┌─────────────────────┐ + │ TensorRTModelRepo │ │ TrackingController │ + │ - Batched inference │ │ - Track association │ + │ - Context pooling │ │ - Track management │ + └─────────────────────┘ └─────────────────────┘ + │ + ▼ + ┌─────────────────────────┐ + │ User Callbacks/Queues │ + │ - on_tracking_result │ + │ - on_detections │ + │ - on_error │ + └─────────────────────────┘ +``` + +## Component Details + +### 1. ModelController (Async Batching Layer) + +**Responsibilities:** +- Accumulate frames from multiple streams into batches +- Manage ping-pong buffers (BufferA/BufferB) +- Monitor force-switch timeout +- Execute batched inference +- Route results back to streams + +**Ping-Pong Buffer Logic:** +- **BufferA (Active)**: Accumulates incoming frames +- **BufferB (Processing)**: Being processed through inference +- **Switch Triggers:** + 1. Active buffer reaches `batch_size` → immediate swap + 2. 
`force_timeout` expires AND processing buffer is idle → force swap + 3. Never switch if processing buffer is busy + +### 2. StreamConnectionManager + +**Responsibilities:** +- Create and manage stream connections +- Coordinate ModelController and TrackingController +- Route tracking results to user callbacks/generators +- Handle stream lifecycle (connect, disconnect, errors) + +### 3. StreamConnection + +**Responsibilities:** +- Wrap a single StreamDecoder +- Poll frames from threaded decoder (bridge to async) +- Submit frames to ModelController +- Emit events to user code + +--- + +## Pseudo Code Implementation + +### 1. ModelController with Ping-Pong Buffers + +```python +import asyncio +import torch +from typing import Dict, List, Tuple, Optional, Callable +from dataclasses import dataclass +from enum import Enum +import time + +@dataclass +class BatchFrame: + """Represents a frame in the batch buffer""" + stream_id: str + frame: torch.Tensor # GPU tensor (3, H, W) + timestamp: float + metadata: Dict = None + +class BufferState(Enum): + IDLE = "idle" + FILLING = "filling" + PROCESSING = "processing" + +class ModelController: + """ + Manages batched inference with ping-pong buffers and force-switch timeout. + """ + + def __init__( + self, + model_repository, + model_id: str, + batch_size: int = 16, + force_timeout: float = 0.05, # 50ms + preprocess_fn: Callable = None, + postprocess_fn: Callable = None, + ): + self.model_repository = model_repository + self.model_id = model_id + self.batch_size = batch_size + self.force_timeout = force_timeout + self.preprocess_fn = preprocess_fn + self.postprocess_fn = postprocess_fn + + # Ping-pong buffers + self.buffer_a: List[BatchFrame] = [] + self.buffer_b: List[BatchFrame] = [] + + # Buffer states + self.active_buffer = "A" # Which buffer is currently active (filling) + self.buffer_a_state = BufferState.IDLE + self.buffer_b_state = BufferState.IDLE + + # Async coordination + self.buffer_lock = asyncio.Lock() + self.last_submit_time = time.time() + + # Tasks + self.timeout_task: Optional[asyncio.Task] = None + self.processor_task: Optional[asyncio.Task] = None + self.running = False + + # Result callbacks (stream_id -> callback) + self.result_callbacks: Dict[str, Callable] = {} + + async def start(self): + """Start the controller background tasks""" + self.running = True + self.timeout_task = asyncio.create_task(self._timeout_monitor()) + self.processor_task = asyncio.create_task(self._batch_processor()) + + async def stop(self): + """Stop the controller and cleanup""" + self.running = False + if self.timeout_task: + self.timeout_task.cancel() + if self.processor_task: + self.processor_task.cancel() + + # Process any remaining frames + await self._process_remaining_buffers() + + def register_callback(self, stream_id: str, callback: Callable): + """Register a callback for inference results from a stream""" + self.result_callbacks[stream_id] = callback + + def unregister_callback(self, stream_id: str): + """Unregister a stream callback""" + self.result_callbacks.pop(stream_id, None) + + async def submit_frame(self, stream_id: str, frame: torch.Tensor, metadata: Dict = None): + """ + Submit a frame for batched inference. 
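+
+        The frame is appended to whichever ping-pong buffer is currently
+        active; once that buffer holds batch_size frames, a swap to the
+        idle buffer is attempted immediately (smaller batches are flushed
+        by the timeout monitor).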
+ + Args: + stream_id: Unique stream identifier + frame: GPU tensor (3, H, W) + metadata: Optional metadata to attach + """ + async with self.buffer_lock: + batch_frame = BatchFrame( + stream_id=stream_id, + frame=frame, + timestamp=time.time(), + metadata=metadata or {} + ) + + # Add to active buffer + if self.active_buffer == "A": + self.buffer_a.append(batch_frame) + self.buffer_a_state = BufferState.FILLING + buffer_size = len(self.buffer_a) + else: + self.buffer_b.append(batch_frame) + self.buffer_b_state = BufferState.FILLING + buffer_size = len(self.buffer_b) + + self.last_submit_time = time.time() + + # Check if we should immediately swap (batch full) + if buffer_size >= self.batch_size: + await self._try_swap_buffers() + + async def _timeout_monitor(self): + """Monitor force-switch timeout""" + while self.running: + await asyncio.sleep(0.01) # Check every 10ms + + async with self.buffer_lock: + time_since_submit = time.time() - self.last_submit_time + + # Check if timeout expired and we have frames waiting + if time_since_submit >= self.force_timeout: + active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b + if len(active_buffer) > 0: + await self._try_swap_buffers() + + async def _try_swap_buffers(self): + """ + Attempt to swap ping-pong buffers. + Only swaps if the inactive buffer is not currently processing. + + This method should be called with buffer_lock held. + """ + # Check if inactive buffer is available + inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state + + if inactive_state != BufferState.PROCESSING: + # Swap active buffer + old_active = self.active_buffer + self.active_buffer = "B" if old_active == "A" else "A" + + # Mark old active buffer as ready for processing + if old_active == "A": + self.buffer_a_state = BufferState.PROCESSING + else: + self.buffer_b_state = BufferState.PROCESSING + + # Signal processor that there's work to do + # (The processor task is already running and will pick it up) + + async def _batch_processor(self): + """Background task that processes batches when available""" + while self.running: + await asyncio.sleep(0.001) # Check every 1ms + + # Check if buffer A needs processing + if self.buffer_a_state == BufferState.PROCESSING: + await self._process_buffer("A") + + # Check if buffer B needs processing + if self.buffer_b_state == BufferState.PROCESSING: + await self._process_buffer("B") + + async def _process_buffer(self, buffer_name: str): + """ + Process a buffer through inference. 
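+
+        The batch is drained from the buffer under buffer_lock, but
+        inference itself runs with the lock released so submissions to the
+        other buffer are never blocked.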
+ + Args: + buffer_name: "A" or "B" + """ + async with self.buffer_lock: + # Get buffer to process + if buffer_name == "A": + batch = self.buffer_a.copy() + self.buffer_a.clear() + else: + batch = self.buffer_b.copy() + self.buffer_b.clear() + + if len(batch) == 0: + # Mark as idle and return + async with self.buffer_lock: + if buffer_name == "A": + self.buffer_a_state = BufferState.IDLE + else: + self.buffer_b_state = BufferState.IDLE + return + + # Process batch (outside lock to allow concurrent submissions) + try: + results = await self._run_batch_inference(batch) + + # Emit results to callbacks + for batch_frame, result in zip(batch, results): + callback = self.result_callbacks.get(batch_frame.stream_id) + if callback: + # Schedule callback asynchronously + if asyncio.iscoroutinefunction(callback): + asyncio.create_task(callback(result)) + else: + callback(result) + + except Exception as e: + print(f"Error processing batch: {e}") + # TODO: Emit error events + + finally: + # Mark buffer as idle + async with self.buffer_lock: + if buffer_name == "A": + self.buffer_a_state = BufferState.IDLE + else: + self.buffer_b_state = BufferState.IDLE + + async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict]: + """ + Run inference on a batch of frames. + + Args: + batch: List of BatchFrame objects + + Returns: + List of detection results (one per frame) + """ + # Preprocess frames (on GPU) + preprocessed = [] + for batch_frame in batch: + if self.preprocess_fn: + processed = self.preprocess_fn(batch_frame.frame) + else: + processed = batch_frame.frame + preprocessed.append(processed) + + # Stack into batch tensor: (N, C, H, W) + batch_tensor = torch.stack(preprocessed, dim=0) + + # Run inference (TensorRT model repository is sync, so run in executor) + loop = asyncio.get_event_loop() + outputs = await loop.run_in_executor( + None, + lambda: self.model_repository.infer( + self.model_id, + {"images": batch_tensor}, + synchronize=True + ) + ) + + # Postprocess results (split batch back to individual results) + results = [] + for i, batch_frame in enumerate(batch): + # Extract single frame output from batch + frame_output = {k: v[i:i+1] for k, v in outputs.items()} + + if self.postprocess_fn: + detections = self.postprocess_fn(frame_output) + else: + detections = frame_output + + result = { + "stream_id": batch_frame.stream_id, + "timestamp": batch_frame.timestamp, + "detections": detections, + "metadata": batch_frame.metadata, + } + results.append(result) + + return results + + async def _process_remaining_buffers(self): + """Process any remaining frames in buffers during shutdown""" + if len(self.buffer_a) > 0: + await self._process_buffer("A") + if len(self.buffer_b) > 0: + await self._process_buffer("B") + + def get_stats(self) -> Dict: + """Get current buffer statistics""" + return { + "active_buffer": self.active_buffer, + "buffer_a_size": len(self.buffer_a), + "buffer_b_size": len(self.buffer_b), + "buffer_a_state": self.buffer_a_state.value, + "buffer_b_state": self.buffer_b_state.value, + "registered_streams": len(self.result_callbacks), + } +``` + +### 2. 
StreamConnectionManager + +```python +import asyncio +from typing import Dict, Optional, Callable, AsyncIterator +from dataclasses import dataclass +from enum import Enum + +class ConnectionStatus(Enum): + CONNECTING = "connecting" + CONNECTED = "connected" + DISCONNECTED = "disconnected" + ERROR = "error" + +@dataclass +class TrackingResult: + """Result emitted to user callbacks""" + stream_id: str + timestamp: float + tracked_objects: List # List of TrackedObject from TrackingController + detections: List # Raw detections + frame_shape: Tuple[int, int, int] + metadata: Dict + +class StreamConnection: + """Represents a single stream connection with event emission""" + + def __init__( + self, + stream_id: str, + decoder, + model_controller: ModelController, + tracking_controller, + poll_interval: float = 0.01, + ): + self.stream_id = stream_id + self.decoder = decoder + self.model_controller = model_controller + self.tracking_controller = tracking_controller + self.poll_interval = poll_interval + + self.status = ConnectionStatus.CONNECTING + self.frame_count = 0 + self.last_frame_time = 0.0 + + # Event emission + self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue() + self.error_queue: asyncio.Queue[Exception] = asyncio.Queue() + + # Tasks + self.poller_task: Optional[asyncio.Task] = None + self.running = False + + async def start(self): + """Start the connection (decoder and frame polling)""" + # Start decoder (runs in background thread) + self.decoder.start() + + # Wait for initial connection + await asyncio.sleep(2.0) + + if self.decoder.is_connected(): + self.status = ConnectionStatus.CONNECTED + else: + self.status = ConnectionStatus.ERROR + raise ConnectionError(f"Failed to connect to stream {self.stream_id}") + + # Start frame polling task + self.running = True + self.poller_task = asyncio.create_task(self._frame_poller()) + + async def stop(self): + """Stop the connection and cleanup""" + self.running = False + + if self.poller_task: + self.poller_task.cancel() + try: + await self.poller_task + except asyncio.CancelledError: + pass + + # Stop decoder + self.decoder.stop() + + # Unregister from model controller + self.model_controller.unregister_callback(self.stream_id) + + self.status = ConnectionStatus.DISCONNECTED + + async def _frame_poller(self): + """Poll frames from threaded decoder and submit to model controller""" + last_frame_ptr = None + + while self.running: + try: + # Poll frame from decoder (runs in thread) + frame = self.decoder.get_latest_frame(rgb=True) + + # Check if we got a new frame (avoid reprocessing same frame) + if frame is not None and frame.data_ptr() != last_frame_ptr: + last_frame_ptr = frame.data_ptr() + self.last_frame_time = time.time() + self.frame_count += 1 + + # Submit to model controller for batched inference + await self.model_controller.submit_frame( + stream_id=self.stream_id, + frame=frame, + metadata={ + "frame_number": self.frame_count, + "shape": frame.shape, + } + ) + + # Check decoder status + if not self.decoder.is_connected(): + self.status = ConnectionStatus.DISCONNECTED + # Decoder will auto-reconnect, just update status + await asyncio.sleep(1.0) + if self.decoder.is_connected(): + self.status = ConnectionStatus.CONNECTED + + except Exception as e: + await self.error_queue.put(e) + self.status = ConnectionStatus.ERROR + + # Sleep until next poll + await asyncio.sleep(self.poll_interval) + + async def _handle_inference_result(self, result: Dict): + """ + Callback invoked by ModelController when inference is done. 
+ Runs tracking and emits final result. + """ + try: + # Extract detections + detections = result["detections"] + + # Run tracking (this is sync, so run in executor) + loop = asyncio.get_event_loop() + tracked_objects = await loop.run_in_executor( + None, + lambda: self.tracking_controller.update_tracks(detections) + ) + + # Create tracking result + tracking_result = TrackingResult( + stream_id=self.stream_id, + timestamp=result["timestamp"], + tracked_objects=tracked_objects, + detections=detections, + frame_shape=result["metadata"].get("shape"), + metadata=result["metadata"], + ) + + # Emit to result queue + await self.result_queue.put(tracking_result) + + except Exception as e: + await self.error_queue.put(e) + + async def tracking_results(self) -> AsyncIterator[TrackingResult]: + """ + Async generator for tracking results. + + Usage: + async for result in connection.tracking_results(): + print(result.tracked_objects) + """ + while self.running or not self.result_queue.empty(): + try: + result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0) + yield result + except asyncio.TimeoutError: + continue + + async def errors(self) -> AsyncIterator[Exception]: + """Async generator for errors""" + while self.running or not self.error_queue.empty(): + try: + error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0) + yield error + except asyncio.TimeoutError: + continue + + def get_stats(self) -> Dict: + """Get connection statistics""" + return { + "stream_id": self.stream_id, + "status": self.status.value, + "frame_count": self.frame_count, + "last_frame_time": self.last_frame_time, + "decoder_connected": self.decoder.is_connected(), + "decoder_buffer_size": self.decoder.get_buffer_size(), + } + + +class StreamConnectionManager: + """ + High-level manager for stream connections with batched inference. + """ + + def __init__( + self, + gpu_id: int = 0, + batch_size: int = 16, + force_timeout: float = 0.05, + poll_interval: float = 0.01, + ): + self.gpu_id = gpu_id + self.batch_size = batch_size + self.force_timeout = force_timeout + self.poll_interval = poll_interval + + # Factories + from services import StreamDecoderFactory, TrackingFactory + from services.model_repository import TensorRTModelRepository + + self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id) + self.tracking_factory = TrackingFactory(gpu_id=gpu_id) + self.model_repository = TensorRTModelRepository(gpu_id=gpu_id) + + # Controllers + self.model_controller: Optional[ModelController] = None + self.tracking_controller = None + + # Connections + self.connections: Dict[str, StreamConnection] = {} + + # State + self.initialized = False + + async def initialize( + self, + model_path: str, + model_id: str = "detector", + preprocess_fn: Callable = None, + postprocess_fn: Callable = None, + ): + """ + Initialize the manager with a model. 
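+
+        The blocking model load is dispatched to an executor so the event
+        loop stays responsive; the ModelController is then started and the
+        TrackingController is created.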
+ + Args: + model_path: Path to TensorRT model file + model_id: Model identifier + preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess) + postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess) + """ + # Load model + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, + lambda: self.model_repository.load_model(model_id, model_path, num_contexts=4) + ) + + # Create model controller + self.model_controller = ModelController( + model_repository=self.model_repository, + model_id=model_id, + batch_size=self.batch_size, + force_timeout=self.force_timeout, + preprocess_fn=preprocess_fn, + postprocess_fn=postprocess_fn, + ) + await self.model_controller.start() + + # Create tracking controller + self.tracking_controller = self.tracking_factory.create_controller( + model_repository=self.model_repository, + model_id=model_id, + tracker_type="iou", + ) + + self.initialized = True + + async def connect_stream( + self, + rtsp_url: str, + stream_id: Optional[str] = None, + on_tracking_result: Optional[Callable] = None, + on_error: Optional[Callable] = None, + ) -> StreamConnection: + """ + Connect to a stream and start processing. + + Args: + rtsp_url: RTSP stream URL + stream_id: Optional stream identifier (auto-generated if not provided) + on_tracking_result: Optional callback for tracking results + on_error: Optional callback for errors + + Returns: + StreamConnection object for this stream + """ + if not self.initialized: + raise RuntimeError("Manager not initialized. Call initialize() first.") + + # Generate stream ID if not provided + if stream_id is None: + stream_id = f"stream_{len(self.connections)}" + + # Create decoder + decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=30) + + # Create connection + connection = StreamConnection( + stream_id=stream_id, + decoder=decoder, + model_controller=self.model_controller, + tracking_controller=self.tracking_controller, + poll_interval=self.poll_interval, + ) + + # Register callback with model controller + self.model_controller.register_callback( + stream_id, + connection._handle_inference_result + ) + + # Start connection + await connection.start() + + # Store connection + self.connections[stream_id] = connection + + # Set up user callbacks if provided + if on_tracking_result: + asyncio.create_task(self._forward_results(connection, on_tracking_result)) + + if on_error: + asyncio.create_task(self._forward_errors(connection, on_error)) + + return connection + + async def disconnect_stream(self, stream_id: str): + """Disconnect and cleanup a stream""" + connection = self.connections.get(stream_id) + if connection: + await connection.stop() + del self.connections[stream_id] + + async def disconnect_all(self): + """Disconnect all streams""" + stream_ids = list(self.connections.keys()) + for stream_id in stream_ids: + await self.disconnect_stream(stream_id) + + async def shutdown(self): + """Shutdown the manager and cleanup all resources""" + # Disconnect all streams + await self.disconnect_all() + + # Stop model controller + if self.model_controller: + await self.model_controller.stop() + + # Cleanup (model repository cleanup is sync) + # Note: May need to handle cleanup carefully to avoid segfaults + + async def _forward_results(self, connection: StreamConnection, callback: Callable): + """Forward results from connection to user callback""" + async for result in connection.tracking_results(): + if asyncio.iscoroutinefunction(callback): + await callback(result) + else: + callback(result) + 
+ async def _forward_errors(self, connection: StreamConnection, callback: Callable): + """Forward errors from connection to user callback""" + async for error in connection.errors(): + if asyncio.iscoroutinefunction(callback): + await callback(error) + else: + callback(error) + + def get_stats(self) -> Dict: + """Get statistics for all connections""" + return { + "manager": { + "initialized": self.initialized, + "num_connections": len(self.connections), + "batch_size": self.batch_size, + "force_timeout": self.force_timeout, + }, + "model_controller": self.model_controller.get_stats() if self.model_controller else {}, + "connections": { + stream_id: conn.get_stats() + for stream_id, conn in self.connections.items() + }, + } +``` + +### 3. User API Examples + +#### Example 1: Simple Callback Pattern + +```python +import asyncio +from services import StreamConnectionManager +from services.yolo import YOLOv8Utils + +async def main(): + # Create manager + manager = StreamConnectionManager( + gpu_id=0, + batch_size=16, + force_timeout=0.05, # 50ms + ) + + # Initialize with model + await manager.initialize( + model_path="models/yolov8n.trt", + model_id="yolo", + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + ) + + # Define callback for tracking results + def on_tracking_result(result): + print(f"Stream: {result.stream_id}") + print(f"Timestamp: {result.timestamp}") + print(f"Tracked objects: {len(result.tracked_objects)}") + for obj in result.tracked_objects: + print(f" - Track ID {obj.track_id}: class={obj.class_id}, conf={obj.confidence:.2f}") + + def on_error(error): + print(f"Error: {error}") + + # Connect to stream + connection = await manager.connect_stream( + rtsp_url="rtsp://camera1.example.com/stream", + stream_id="camera1", + on_tracking_result=on_tracking_result, + on_error=on_error, + ) + + # Let it run for 60 seconds + await asyncio.sleep(60) + + # Get statistics + stats = manager.get_stats() + print(f"Stats: {stats}") + + # Cleanup + await manager.shutdown() + +if __name__ == "__main__": + asyncio.run(main()) +``` + +#### Example 2: Async Generator Pattern (Multiple Streams) + +```python +import asyncio +from services import StreamConnectionManager +from services.yolo import YOLOv8Utils + +async def process_stream(connection, stream_name): + """Process results from a single stream""" + async for result in connection.tracking_results(): + print(f"[{stream_name}] Frame {result.metadata['frame_number']}: {len(result.tracked_objects)} objects") + + # Do something with tracked objects + for obj in result.tracked_objects: + if obj.class_id == 0: # Person class + print(f" Person detected! 
Track ID: {obj.track_id}, Conf: {obj.confidence:.2f}") + +async def main(): + manager = StreamConnectionManager( + gpu_id=0, + batch_size=32, # Larger batch for multiple streams + force_timeout=0.05, + ) + + await manager.initialize( + model_path="models/yolov8n.trt", + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + ) + + # Connect to multiple streams + camera_urls = [ + ("rtsp://camera1.example.com/stream", "Front Door"), + ("rtsp://camera2.example.com/stream", "Parking Lot"), + ("rtsp://camera3.example.com/stream", "Warehouse"), + ("rtsp://camera4.example.com/stream", "Loading Bay"), + ] + + tasks = [] + for url, name in camera_urls: + connection = await manager.connect_stream( + rtsp_url=url, + stream_id=name.lower().replace(" ", "_"), + ) + + # Create task to process this stream + task = asyncio.create_task(process_stream(connection, name)) + tasks.append(task) + + # Run all streams concurrently + try: + await asyncio.gather(*tasks) + except KeyboardInterrupt: + print("Shutting down...") + + await manager.shutdown() + +if __name__ == "__main__": + asyncio.run(main()) +``` + +#### Example 3: Queue-Based Pattern (for integration with other systems) + +```python +import asyncio +from services import StreamConnectionManager +from services.yolo import YOLOv8Utils + +async def main(): + manager = StreamConnectionManager(gpu_id=0, batch_size=16) + + await manager.initialize( + model_path="models/yolov8n.trt", + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + ) + + # Connect to stream (no callback) + connection = await manager.connect_stream( + rtsp_url="rtsp://camera.example.com/stream", + stream_id="main_camera", + ) + + # Use the built-in queue + result_queue = connection.result_queue + + # Process results from queue + while True: + result = await result_queue.get() + + # Send to external system (e.g., message queue, database, API) + await send_to_kafka(result) + await save_to_database(result) + + # Or do real-time processing + if has_person_alert(result.tracked_objects): + await send_alert("Person detected in restricted area!") + +async def send_to_kafka(result): + # Your Kafka producer code + pass + +async def save_to_database(result): + # Your database code + pass + +def has_person_alert(tracked_objects): + # Your alert logic + return any(obj.class_id == 0 for obj in tracked_objects) + +async def send_alert(message): + print(f"ALERT: {message}") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +#### Example 4: Async Callback with Error Handling + +```python +import asyncio +from services import StreamConnectionManager +from services.yolo import YOLOv8Utils + +async def main(): + manager = StreamConnectionManager(gpu_id=0, batch_size=16) + + await manager.initialize( + model_path="models/yolov8n.trt", + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + ) + + # Async callback (can do I/O operations) + async def on_tracking_result(result): + # Can use async operations in callback + await save_to_database(result) + + # Check for alerts + for obj in result.tracked_objects: + if obj.class_id == 0 and obj.confidence > 0.8: + await send_notification(f"High confidence person detection: {obj.track_id}") + + async def on_error(error): + await log_error_to_monitoring_system(error) + + # Connect with async callbacks + connection = await manager.connect_stream( + rtsp_url="rtsp://camera.example.com/stream", + on_tracking_result=on_tracking_result, + on_error=on_error, + ) + + # Monitor stats 
periodically + while True: + await asyncio.sleep(10) + stats = manager.get_stats() + print(f"Buffer stats: {stats['model_controller']}") + print(f"Connection stats: {stats['connections']}") + +async def save_to_database(result): + # Simulate async database operation + await asyncio.sleep(0.01) + +async def send_notification(message): + print(f"NOTIFICATION: {message}") + +async def log_error_to_monitoring_system(error): + print(f"ERROR: {error}") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +--- + +## Configuration Examples + +### Performance Tuning + +```python +# Low latency (small batches, quick timeout) +manager = StreamConnectionManager( + gpu_id=0, + batch_size=4, + force_timeout=0.02, # 20ms + poll_interval=0.005, # 200 FPS +) + +# High throughput (large batches, longer timeout) +manager = StreamConnectionManager( + gpu_id=0, + batch_size=32, + force_timeout=0.1, # 100ms + poll_interval=0.02, # 50 FPS +) + +# Balanced (default) +manager = StreamConnectionManager( + gpu_id=0, + batch_size=16, + force_timeout=0.05, # 50ms + poll_interval=0.01, # 100 FPS +) +``` + +### Multiple GPUs + +```python +# Create manager per GPU +manager_gpu0 = StreamConnectionManager(gpu_id=0, batch_size=16) +manager_gpu1 = StreamConnectionManager(gpu_id=1, batch_size=16) + +# Initialize both +await manager_gpu0.initialize(model_path="models/yolov8n.trt", ...) +await manager_gpu1.initialize(model_path="models/yolov8n.trt", ...) + +# Distribute streams across GPUs +await manager_gpu0.connect_stream(url1, ...) +await manager_gpu0.connect_stream(url2, ...) +await manager_gpu1.connect_stream(url3, ...) +await manager_gpu1.connect_stream(url4, ...) +``` + +--- + +## Key Features Summary + +1. **Ping-Pong Buffers**: Efficient batching with minimal latency +2. **Force Timeout**: Prevents starvation of small batches +3. **AsyncIO**: Clean event-driven architecture +4. **Multiple Patterns**: Callbacks, generators, queues +5. **Thread-Async Bridge**: Integrates with existing threaded decoders +6. **Zero-Copy**: All processing stays on GPU +7. **Auto-Reconnection**: Inherits from StreamDecoder +8. **Statistics**: Real-time monitoring of buffers and connections + +--- + +## Performance Characteristics + +- **Latency**: `force_timeout + inference_time` +- **Throughput**: Maximized by batching +- **VRAM**: 60MB per stream + batch buffer overhead +- **CPU**: Minimal (async event loop + thread polling) + +--- + +## Next Steps + +To implement this design: + +1. Create `services/model_controller.py` with `ModelController` class +2. Create `services/stream_connection_manager.py` with `StreamConnectionManager` and `StreamConnection` classes +3. Update `services/__init__.py` to export new classes +4. Create `test_event_driven.py` to test the system +5. Add monitoring/logging throughout +6. 
Handle edge cases (reconnection, cleanup, errors) diff --git a/services/__init__.py b/services/__init__.py index f0df9d6..497e777 100644 --- a/services/__init__.py +++ b/services/__init__.py @@ -8,6 +8,8 @@ from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionC from .tracking_controller import TrackingController, TrackedObject from .tracking_factory import TrackingFactory from .yolo import YOLOv8Utils, COCO_CLASSES +from .model_controller import ModelController, BatchFrame, BufferState +from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult __all__ = [ 'StreamDecoderFactory', @@ -24,4 +26,10 @@ __all__ = [ 'TrackingFactory', 'YOLOv8Utils', 'COCO_CLASSES', + 'ModelController', + 'BatchFrame', + 'BufferState', + 'StreamConnectionManager', + 'StreamConnection', + 'TrackingResult', ] diff --git a/services/model_controller.py b/services/model_controller.py new file mode 100644 index 0000000..2873506 --- /dev/null +++ b/services/model_controller.py @@ -0,0 +1,499 @@ +""" +ModelController - Async batching layer with ping-pong buffers for inference. + +This module provides batched inference coordination using ping-pong circular buffers +with force-switch timeout mechanism. +""" + +import asyncio +import torch +from typing import Dict, List, Optional, Callable, Any +from dataclasses import dataclass, field +from enum import Enum +import time +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class BatchFrame: + """Represents a frame in the batch buffer""" + stream_id: str + frame: torch.Tensor # GPU tensor (3, H, W) + timestamp: float + metadata: Dict = field(default_factory=dict) + + +class BufferState(Enum): + """State of a ping-pong buffer""" + IDLE = "idle" + FILLING = "filling" + PROCESSING = "processing" + + +class ModelController: + """ + Manages batched inference with ping-pong buffers and force-switch timeout. + + This controller accumulates frames from multiple streams into batches, + processes them through a model repository, and routes results back to + stream-specific callbacks. + + Features: + - Ping-pong circular buffers (BufferA/BufferB) + - Force-switch timeout to prevent batch starvation + - Async event-driven processing + - Thread-safe frame submission + + Args: + model_repository: TensorRT model repository for inference + model_id: Model identifier in the repository + batch_size: Maximum frames per batch (default: 16) + force_timeout: Max wait time before forcing buffer switch in seconds (default: 0.05) + preprocess_fn: Optional preprocessing function for frames + postprocess_fn: Optional postprocessing function for model outputs + """ + + def __init__( + self, + model_repository, + model_id: str, + batch_size: int = 16, + force_timeout: float = 0.05, + preprocess_fn: Optional[Callable] = None, + postprocess_fn: Optional[Callable] = None, + ): + self.model_repository = model_repository + self.model_id = model_id + self.batch_size = batch_size + self.force_timeout = force_timeout + self.preprocess_fn = preprocess_fn + self.postprocess_fn = postprocess_fn + + # Detect model's actual batch size from input shape + self.model_batch_size = self._detect_model_batch_size() + if self.model_batch_size == 1: + logger.warning( + f"Model '{model_id}' has fixed batch_size=1. " + f"Will process frames sequentially. Consider rebuilding model with dynamic batching." 
+ ) + else: + logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}") + + # Ping-pong buffers + self.buffer_a: List[BatchFrame] = [] + self.buffer_b: List[BatchFrame] = [] + + # Buffer states + self.active_buffer = "A" # Which buffer is currently active (filling) + self.buffer_a_state = BufferState.IDLE + self.buffer_b_state = BufferState.IDLE + + # Async coordination + self.buffer_lock = asyncio.Lock() + self.last_submit_time = time.time() + + # Tasks + self.timeout_task: Optional[asyncio.Task] = None + self.processor_task: Optional[asyncio.Task] = None + self.running = False + + # Result callbacks (stream_id -> callback) + self.result_callbacks: Dict[str, Callable] = {} + + # Statistics + self.total_frames_processed = 0 + self.total_batches_processed = 0 + + def _detect_model_batch_size(self) -> int: + """ + Detect the model's batch size from its input shape. + + Returns: + Maximum batch size supported by the model (1 for fixed batch size models) + """ + try: + metadata = self.model_repository.get_metadata(self.model_id) + # Get first input tensor shape + first_input = list(metadata.inputs.values())[0] + batch_dim = first_input["shape"][0] + + # batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size) + if batch_dim == -1: + # Dynamic batch size - use user-specified batch_size + return self.batch_size + else: + # Fixed batch size + return batch_dim + except Exception as e: + logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1") + return 1 + + async def start(self): + """Start the controller background tasks""" + if self.running: + logger.warning("ModelController already running") + return + + self.running = True + self.timeout_task = asyncio.create_task(self._timeout_monitor()) + self.processor_task = asyncio.create_task(self._batch_processor()) + logger.info("ModelController started") + + async def stop(self): + """Stop the controller and cleanup""" + if not self.running: + return + + logger.info("Stopping ModelController...") + self.running = False + + # Cancel tasks + if self.timeout_task: + self.timeout_task.cancel() + try: + await self.timeout_task + except asyncio.CancelledError: + pass + + if self.processor_task: + self.processor_task.cancel() + try: + await self.processor_task + except asyncio.CancelledError: + pass + + # Process any remaining frames + await self._process_remaining_buffers() + logger.info("ModelController stopped") + + def register_callback(self, stream_id: str, callback: Callable): + """ + Register a callback for inference results from a stream. + + Args: + stream_id: Unique stream identifier + callback: Callback function to receive results (can be sync or async) + """ + self.result_callbacks[stream_id] = callback + logger.debug(f"Registered callback for stream: {stream_id}") + + def unregister_callback(self, stream_id: str): + """ + Unregister a stream callback. + + Args: + stream_id: Stream identifier to unregister + """ + self.result_callbacks.pop(stream_id, None) + logger.debug(f"Unregistered callback for stream: {stream_id}") + + async def submit_frame( + self, + stream_id: str, + frame: torch.Tensor, + metadata: Optional[Dict] = None + ): + """ + Submit a frame for batched inference. 
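+
+        The frame lands in the currently active ping-pong buffer. A buffer
+        swap is attempted as soon as batch_size frames have accumulated;
+        partially filled buffers are flushed by the timeout monitor.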
+ + Args: + stream_id: Unique stream identifier + frame: GPU tensor (3, H, W) or (C, H, W) + metadata: Optional metadata to attach to the frame + """ + async with self.buffer_lock: + batch_frame = BatchFrame( + stream_id=stream_id, + frame=frame, + timestamp=time.time(), + metadata=metadata or {} + ) + + # Add to active buffer + if self.active_buffer == "A": + self.buffer_a.append(batch_frame) + self.buffer_a_state = BufferState.FILLING + buffer_size = len(self.buffer_a) + else: + self.buffer_b.append(batch_frame) + self.buffer_b_state = BufferState.FILLING + buffer_size = len(self.buffer_b) + + self.last_submit_time = time.time() + + # Check if we should immediately swap (batch full) + if buffer_size >= self.batch_size: + await self._try_swap_buffers() + + async def _timeout_monitor(self): + """Monitor force-switch timeout""" + while self.running: + await asyncio.sleep(0.01) # Check every 10ms + + async with self.buffer_lock: + time_since_submit = time.time() - self.last_submit_time + + # Check if timeout expired and we have frames waiting + if time_since_submit >= self.force_timeout: + active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b + if len(active_buffer) > 0: + await self._try_swap_buffers() + + async def _try_swap_buffers(self): + """ + Attempt to swap ping-pong buffers. + Only swaps if the inactive buffer is not currently processing. + + This method should be called with buffer_lock held. + """ + # Check if inactive buffer is available + inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state + + if inactive_state != BufferState.PROCESSING: + # Swap active buffer + old_active = self.active_buffer + self.active_buffer = "B" if old_active == "A" else "A" + + # Mark old active buffer as ready for processing + if old_active == "A": + self.buffer_a_state = BufferState.PROCESSING + buffer_size = len(self.buffer_a) + else: + self.buffer_b_state = BufferState.PROCESSING + buffer_size = len(self.buffer_b) + + logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})") + + async def _batch_processor(self): + """Background task that processes batches when available""" + while self.running: + await asyncio.sleep(0.001) # Check every 1ms + + # Check if buffer A needs processing + if self.buffer_a_state == BufferState.PROCESSING: + await self._process_buffer("A") + + # Check if buffer B needs processing + if self.buffer_b_state == BufferState.PROCESSING: + await self._process_buffer("B") + + async def _process_buffer(self, buffer_name: str): + """ + Process a buffer through inference. 
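+
+        The buffer contents are copied out and cleared while holding
+        buffer_lock; inference then runs outside the lock so the other
+        buffer can keep filling, and the buffer is returned to IDLE in the
+        finally block regardless of success or failure.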
+ + Args: + buffer_name: "A" or "B" + """ + # Extract buffer to process + async with self.buffer_lock: + if buffer_name == "A": + batch = self.buffer_a.copy() + self.buffer_a.clear() + else: + batch = self.buffer_b.copy() + self.buffer_b.clear() + + if len(batch) == 0: + # Mark as idle and return + async with self.buffer_lock: + if buffer_name == "A": + self.buffer_a_state = BufferState.IDLE + else: + self.buffer_b_state = BufferState.IDLE + return + + # Process batch (outside lock to allow concurrent submissions) + try: + start_time = time.time() + results = await self._run_batch_inference(batch) + inference_time = time.time() - start_time + + # Update statistics + self.total_frames_processed += len(batch) + self.total_batches_processed += 1 + + logger.debug( + f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms " + f"({inference_time*1000/len(batch):.2f}ms per frame)" + ) + + # Emit results to callbacks + for batch_frame, result in zip(batch, results): + callback = self.result_callbacks.get(batch_frame.stream_id) + if callback: + # Schedule callback asynchronously + if asyncio.iscoroutinefunction(callback): + asyncio.create_task(callback(result)) + else: + # Run sync callback in executor to avoid blocking + loop = asyncio.get_event_loop() + loop.call_soon(lambda cb=callback, r=result: cb(r)) + + except Exception as e: + logger.error(f"Error processing batch: {e}", exc_info=True) + # TODO: Emit error events to streams + + finally: + # Mark buffer as idle + async with self.buffer_lock: + if buffer_name == "A": + self.buffer_a_state = BufferState.IDLE + else: + self.buffer_b_state = BufferState.IDLE + + async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]: + """ + Run inference on a batch of frames. 
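+
+        Dispatches to sequential per-frame inference for engines fixed at
+        batch_size=1, or to true batched inference otherwise, based on the
+        model batch size detected at construction time.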
+ + Args: + batch: List of BatchFrame objects + + Returns: + List of detection results (one per frame) + """ + loop = asyncio.get_event_loop() + + # Check if model supports batching + if self.model_batch_size == 1: + # Process frames one at a time for batch_size=1 models + return await self._run_sequential_inference(batch, loop) + else: + # Use true batching for models that support it + return await self._run_batched_inference(batch, loop) + + async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]: + """Run inference sequentially for batch_size=1 models""" + results = [] + + for batch_frame in batch: + # Preprocess frame + if self.preprocess_fn: + processed = self.preprocess_fn(batch_frame.frame) + else: + # Ensure we have batch dimension + processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame + + # Run inference for this frame + outputs = await loop.run_in_executor( + None, + lambda p=processed: self.model_repository.infer( + self.model_id, + {"images": p}, + synchronize=True + ) + ) + + # Postprocess + if self.postprocess_fn: + try: + detections = self.postprocess_fn(outputs) + except Exception as e: + logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}") + # Return empty detections on error + detections = torch.zeros((0, 6), device=list(outputs.values())[0].device) + else: + detections = outputs + + result = { + "stream_id": batch_frame.stream_id, + "timestamp": batch_frame.timestamp, + "detections": detections, + "metadata": batch_frame.metadata, + } + results.append(result) + + return results + + async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]: + """Run true batched inference for models that support it""" + # Preprocess frames (on GPU) + preprocessed = [] + for batch_frame in batch: + if self.preprocess_fn: + processed = self.preprocess_fn(batch_frame.frame) + # Preprocess may return (1, C, H, W), squeeze to (C, H, W) + if processed.dim() == 4 and processed.shape[0] == 1: + processed = processed.squeeze(0) + else: + processed = batch_frame.frame + preprocessed.append(processed) + + # Stack into batch tensor: (N, C, H, W) + batch_tensor = torch.stack(preprocessed, dim=0) + + # Limit batch size to model's max batch size + if batch_tensor.shape[0] > self.model_batch_size: + logger.warning( + f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}, " + f"will split into sub-batches" + ) + # TODO: Handle splitting into sub-batches + batch_tensor = batch_tensor[:self.model_batch_size] + batch = batch[:self.model_batch_size] + + # Run inference + outputs = await loop.run_in_executor( + None, + lambda: self.model_repository.infer( + self.model_id, + {"images": batch_tensor}, + synchronize=True + ) + ) + + # Postprocess results (split batch back to individual results) + results = [] + for i, batch_frame in enumerate(batch): + # Extract single frame output from batch + frame_output = {} + for k, v in outputs.items(): + # v has shape (N, ...), extract index i and keep batch dimension + frame_output[k] = v[i:i+1] # Shape: (1, ...) 
+ + if self.postprocess_fn: + try: + detections = self.postprocess_fn(frame_output) + except Exception as e: + logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}") + # Return empty detections on error + detections = torch.zeros((0, 6), device=list(outputs.values())[0].device) + else: + detections = frame_output + + result = { + "stream_id": batch_frame.stream_id, + "timestamp": batch_frame.timestamp, + "detections": detections, + "metadata": batch_frame.metadata, + } + results.append(result) + + return results + + async def _process_remaining_buffers(self): + """Process any remaining frames in buffers during shutdown""" + if len(self.buffer_a) > 0: + logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A") + await self._process_buffer("A") + if len(self.buffer_b) > 0: + logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B") + await self._process_buffer("B") + + def get_stats(self) -> Dict[str, Any]: + """Get current buffer statistics""" + return { + "active_buffer": self.active_buffer, + "buffer_a_size": len(self.buffer_a), + "buffer_b_size": len(self.buffer_b), + "buffer_a_state": self.buffer_a_state.value, + "buffer_b_state": self.buffer_b_state.value, + "registered_streams": len(self.result_callbacks), + "total_frames_processed": self.total_frames_processed, + "total_batches_processed": self.total_batches_processed, + "avg_batch_size": ( + self.total_frames_processed / self.total_batches_processed + if self.total_batches_processed > 0 else 0 + ), + } diff --git a/services/stream_connection_manager.py b/services/stream_connection_manager.py new file mode 100644 index 0000000..7c9960e --- /dev/null +++ b/services/stream_connection_manager.py @@ -0,0 +1,566 @@ +""" +StreamConnectionManager - Async orchestration for stream processing with batched inference. + +This module provides high-level connection management for multiple RTSP streams, +coordinating decoders, batched inference, and tracking with an event-driven API. +""" + +import asyncio +import time +from typing import Dict, Optional, Callable, AsyncIterator, Tuple, Any, List +from dataclasses import dataclass +from enum import Enum +import logging + +import torch + +from .model_controller import ModelController +from .stream_decoder import StreamDecoderFactory +from .tracking_factory import TrackingFactory +from .model_repository import TensorRTModelRepository + +logger = logging.getLogger(__name__) + + +class ConnectionStatus(Enum): + """Status of a stream connection""" + CONNECTING = "connecting" + CONNECTED = "connected" + DISCONNECTED = "disconnected" + ERROR = "error" + + +@dataclass +class TrackingResult: + """Result emitted to user callbacks""" + stream_id: str + timestamp: float + tracked_objects: List # List of TrackedObject from TrackingController + detections: List # Raw detections + frame_shape: Tuple[int, int, int] + metadata: Dict + + +class StreamConnection: + """ + Represents a single stream connection with event emission. + + This class wraps a StreamDecoder, polls frames asynchronously, submits them + to the ModelController for batched inference, runs tracking, and emits results + via queues or callbacks. 
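+
+    Instances are normally created, registered with the ModelController,
+    and started by StreamConnectionManager.connect_stream() rather than
+    constructed directly.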
+ + Args: + stream_id: Unique identifier for this stream + decoder: StreamDecoder instance + model_controller: ModelController for batched inference + tracking_controller: TrackingController for object tracking + poll_interval: Frame polling interval in seconds (default: 0.01) + """ + + def __init__( + self, + stream_id: str, + decoder, + model_controller: ModelController, + tracking_controller, + poll_interval: float = 0.01, + ): + self.stream_id = stream_id + self.decoder = decoder + self.model_controller = model_controller + self.tracking_controller = tracking_controller + self.poll_interval = poll_interval + + self.status = ConnectionStatus.CONNECTING + self.frame_count = 0 + self.last_frame_time = 0.0 + + # Event emission + self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue() + self.error_queue: asyncio.Queue[Exception] = asyncio.Queue() + + # Tasks + self.poller_task: Optional[asyncio.Task] = None + self.running = False + + async def start(self): + """Start the connection (decoder and frame polling)""" + # Start decoder (runs in background thread) + self.decoder.start() + + # Wait for initial connection (try for up to 10 seconds) + max_wait = 10.0 + wait_interval = 0.5 + elapsed = 0.0 + + while elapsed < max_wait: + await asyncio.sleep(wait_interval) + elapsed += wait_interval + + if self.decoder.is_connected(): + self.status = ConnectionStatus.CONNECTED + logger.info(f"Stream {self.stream_id} connected after {elapsed:.1f}s") + break + else: + # Timeout - but don't fail hard, let it try to connect in background + logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...") + self.status = ConnectionStatus.CONNECTING + + # Start frame polling task + self.running = True + self.poller_task = asyncio.create_task(self._frame_poller()) + + async def stop(self): + """Stop the connection and cleanup""" + logger.info(f"Stopping stream {self.stream_id}...") + self.running = False + + if self.poller_task: + self.poller_task.cancel() + try: + await self.poller_task + except asyncio.CancelledError: + pass + + # Stop decoder + self.decoder.stop() + + # Unregister from model controller + self.model_controller.unregister_callback(self.stream_id) + + self.status = ConnectionStatus.DISCONNECTED + logger.info(f"Stream {self.stream_id} stopped") + + async def _frame_poller(self): + """Poll frames from threaded decoder and submit to model controller""" + last_frame_ptr = None + + while self.running: + try: + # Poll frame from decoder (runs in thread) + frame = self.decoder.get_latest_frame(rgb=True) + + # Check if we got a new frame (avoid reprocessing same frame) + if frame is not None and frame.data_ptr() != last_frame_ptr: + last_frame_ptr = frame.data_ptr() + self.last_frame_time = time.time() + self.frame_count += 1 + + # Submit to model controller for batched inference + await self.model_controller.submit_frame( + stream_id=self.stream_id, + frame=frame, + metadata={ + "frame_number": self.frame_count, + "shape": tuple(frame.shape), + } + ) + + # Check decoder status + if not self.decoder.is_connected(): + if self.status == ConnectionStatus.CONNECTED: + logger.warning(f"Stream {self.stream_id} disconnected") + self.status = ConnectionStatus.DISCONNECTED + # Decoder will auto-reconnect, just update status + await asyncio.sleep(1.0) + if self.decoder.is_connected(): + logger.info(f"Stream {self.stream_id} reconnected") + self.status = ConnectionStatus.CONNECTED + + except Exception as e: + logger.error(f"Error in frame poller for {self.stream_id}: 
{e}", exc_info=True) + await self.error_queue.put(e) + self.status = ConnectionStatus.ERROR + + # Sleep until next poll + await asyncio.sleep(self.poll_interval) + + async def _handle_inference_result(self, result: Dict[str, Any]): + """ + Callback invoked by ModelController when inference is done. + Runs tracking and emits final result. + + Args: + result: Inference result dictionary + """ + try: + # Extract detections + detections = result["detections"] + + # Run tracking (this is sync, so run in executor) + loop = asyncio.get_event_loop() + tracked_objects = await loop.run_in_executor( + None, + lambda: self._run_tracking_sync(detections) + ) + + # Create tracking result + tracking_result = TrackingResult( + stream_id=self.stream_id, + timestamp=result["timestamp"], + tracked_objects=tracked_objects, + detections=detections, + frame_shape=result["metadata"].get("shape"), + metadata=result["metadata"], + ) + + # Emit to result queue + await self.result_queue.put(tracking_result) + + except Exception as e: + logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True) + await self.error_queue.put(e) + + def _run_tracking_sync(self, detections): + """ + Run tracking synchronously (called from executor). + + Args: + detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id] + + Returns: + List of TrackedObject instances + """ + # Use the TrackingController's internal tracking with detections + # We need to manually update tracks since we already have detections + import torch + + with self.tracking_controller._lock: + self.tracking_controller._frame_count += 1 + + # If no detections, just cleanup and return current tracks + if len(detections) == 0: + self.tracking_controller._cleanup_stale_tracks() + return list(self.tracking_controller._tracks.values()) + + # Run IoU tracking to associate detections with existing tracks + associations = self.tracking_controller._iou_tracking(detections) + + # Update or create tracks + for (det_idx, track_id), detection in zip(associations, detections): + bbox = detection[:4].cpu().tolist() + confidence = float(detection[4]) + class_id = int(detection[5]) if detection.shape[0] > 5 else 0 + + if track_id == -1: + # Create new track + new_track = self.tracking_controller._create_track( + bbox, confidence, class_id, self.tracking_controller._frame_count + ) + self.tracking_controller._tracks[new_track.track_id] = new_track + else: + # Update existing track + self.tracking_controller._tracks[track_id].update( + bbox, confidence, self.tracking_controller._frame_count + ) + + # Cleanup stale tracks + self.tracking_controller._cleanup_stale_tracks() + + return list(self.tracking_controller._tracks.values()) + + async def tracking_results(self) -> AsyncIterator[TrackingResult]: + """ + Async generator for tracking results. + + Usage: + async for result in connection.tracking_results(): + print(result.tracked_objects) + + Yields: + TrackingResult objects as they become available + """ + while self.running or not self.result_queue.empty(): + try: + result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0) + yield result + except asyncio.TimeoutError: + continue + + async def errors(self) -> AsyncIterator[Exception]: + """ + Async generator for errors. 
+ + Yields: + Exception objects as they occur + """ + while self.running or not self.error_queue.empty(): + try: + error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0) + yield error + except asyncio.TimeoutError: + continue + + def get_stats(self) -> Dict[str, Any]: + """Get connection statistics""" + return { + "stream_id": self.stream_id, + "status": self.status.value, + "frame_count": self.frame_count, + "last_frame_time": self.last_frame_time, + "decoder_connected": self.decoder.is_connected(), + "decoder_buffer_size": self.decoder.get_buffer_size(), + "result_queue_size": self.result_queue.qsize(), + "error_queue_size": self.error_queue.qsize(), + } + + +class StreamConnectionManager: + """ + High-level manager for stream connections with batched inference. + + This manager coordinates multiple RTSP streams, batched model inference, + and object tracking through an async event-driven API. + + Args: + gpu_id: GPU device ID (default: 0) + batch_size: Maximum batch size for inference (default: 16) + force_timeout: Force buffer switch timeout in seconds (default: 0.05) + poll_interval: Frame polling interval in seconds (default: 0.01) + + Example: + manager = StreamConnectionManager(gpu_id=0, batch_size=16) + await manager.initialize(model_path="yolov8n.trt", ...) + connection = await manager.connect_stream(rtsp_url, on_tracking_result=callback) + await asyncio.sleep(60) + await manager.shutdown() + """ + + def __init__( + self, + gpu_id: int = 0, + batch_size: int = 16, + force_timeout: float = 0.05, + poll_interval: float = 0.01, + ): + self.gpu_id = gpu_id + self.batch_size = batch_size + self.force_timeout = force_timeout + self.poll_interval = poll_interval + + # Factories + self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id) + self.tracking_factory = TrackingFactory(gpu_id=gpu_id) + self.model_repository = TensorRTModelRepository(gpu_id=gpu_id) + + # Controllers + self.model_controller: Optional[ModelController] = None + self.tracking_controller = None + + # Connections + self.connections: Dict[str, StreamConnection] = {} + + # State + self.initialized = False + + async def initialize( + self, + model_path: str, + model_id: str = "detector", + preprocess_fn: Optional[Callable] = None, + postprocess_fn: Optional[Callable] = None, + num_contexts: int = 4, + ): + """ + Initialize the manager with a model. 
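+
+        Loads the TensorRT engine in an executor so the event loop is not
+        blocked, starts the ModelController, and creates the
+        TrackingController shared by all subsequent connections.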
+ + Args: + model_path: Path to TensorRT model file + model_id: Model identifier (default: "detector") + preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess) + postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess) + num_contexts: Number of TensorRT execution contexts (default: 4) + """ + logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}") + + # Load model + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, + lambda: self.model_repository.load_model(model_id, model_path, num_contexts=num_contexts) + ) + logger.info(f"Loaded model {model_id} from {model_path}") + + # Create model controller + self.model_controller = ModelController( + model_repository=self.model_repository, + model_id=model_id, + batch_size=self.batch_size, + force_timeout=self.force_timeout, + preprocess_fn=preprocess_fn, + postprocess_fn=postprocess_fn, + ) + await self.model_controller.start() + + # Create tracking controller + self.tracking_controller = self.tracking_factory.create_controller( + model_repository=self.model_repository, + model_id=model_id, + tracker_type="iou", + ) + logger.info("TrackingController created") + + self.initialized = True + logger.info("StreamConnectionManager initialized successfully") + + async def connect_stream( + self, + rtsp_url: str, + stream_id: Optional[str] = None, + on_tracking_result: Optional[Callable] = None, + on_error: Optional[Callable] = None, + buffer_size: int = 30, + ) -> StreamConnection: + """ + Connect to a stream and start processing. + + Args: + rtsp_url: RTSP stream URL + stream_id: Optional stream identifier (auto-generated if not provided) + on_tracking_result: Optional callback for tracking results (sync or async) + on_error: Optional callback for errors (sync or async) + buffer_size: Decoder buffer size (default: 30) + + Returns: + StreamConnection object for this stream + + Raises: + RuntimeError: If manager is not initialized + ConnectionError: If stream connection fails + """ + if not self.initialized: + raise RuntimeError("Manager not initialized. Call initialize() first.") + + # Generate stream ID if not provided + if stream_id is None: + stream_id = f"stream_{len(self.connections)}" + + logger.info(f"Connecting to stream {stream_id}: {rtsp_url}") + + # Create decoder + decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size) + + # Create connection + connection = StreamConnection( + stream_id=stream_id, + decoder=decoder, + model_controller=self.model_controller, + tracking_controller=self.tracking_controller, + poll_interval=self.poll_interval, + ) + + # Register callback with model controller + self.model_controller.register_callback( + stream_id, + connection._handle_inference_result + ) + + # Start connection + await connection.start() + + # Store connection + self.connections[stream_id] = connection + + # Set up user callbacks if provided + if on_tracking_result: + asyncio.create_task(self._forward_results(connection, on_tracking_result)) + + if on_error: + asyncio.create_task(self._forward_errors(connection, on_error)) + + logger.info(f"Stream {stream_id} connected successfully") + return connection + + async def disconnect_stream(self, stream_id: str): + """ + Disconnect and cleanup a stream. 
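+
+        Example:
+            await manager.disconnect_stream("camera1")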
+ + Args: + stream_id: Stream identifier to disconnect + """ + connection = self.connections.get(stream_id) + if connection: + await connection.stop() + del self.connections[stream_id] + logger.info(f"Stream {stream_id} disconnected") + + async def disconnect_all(self): + """Disconnect all streams""" + logger.info("Disconnecting all streams...") + stream_ids = list(self.connections.keys()) + for stream_id in stream_ids: + await self.disconnect_stream(stream_id) + + async def shutdown(self): + """Shutdown the manager and cleanup all resources""" + logger.info("Shutting down StreamConnectionManager...") + + # Disconnect all streams + await self.disconnect_all() + + # Stop model controller + if self.model_controller: + await self.model_controller.stop() + + # Note: Model repository cleanup is sync and may cause segfaults + # Leaving cleanup to garbage collection for now + + self.initialized = False + logger.info("StreamConnectionManager shutdown complete") + + async def _forward_results(self, connection: StreamConnection, callback: Callable): + """ + Forward results from connection to user callback. + + Args: + connection: StreamConnection to listen to + callback: User callback (sync or async) + """ + try: + async for result in connection.tracking_results(): + if asyncio.iscoroutinefunction(callback): + await callback(result) + else: + callback(result) + except Exception as e: + logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True) + + async def _forward_errors(self, connection: StreamConnection, callback: Callable): + """ + Forward errors from connection to user callback. + + Args: + connection: StreamConnection to listen to + callback: User callback (sync or async) + """ + try: + async for error in connection.errors(): + if asyncio.iscoroutinefunction(callback): + await callback(error) + else: + callback(error) + except Exception as e: + logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True) + + def get_stats(self) -> Dict[str, Any]: + """ + Get statistics for all connections. + + Returns: + Dictionary with manager and connection statistics + """ + return { + "manager": { + "initialized": self.initialized, + "gpu_id": self.gpu_id, + "num_connections": len(self.connections), + "batch_size": self.batch_size, + "force_timeout": self.force_timeout, + "poll_interval": self.poll_interval, + }, + "model_controller": self.model_controller.get_stats() if self.model_controller else {}, + "connections": { + stream_id: conn.get_stats() + for stream_id, conn in self.connections.items() + }, + } diff --git a/test_event_driven.py b/test_event_driven.py new file mode 100644 index 0000000..3c7f6e3 --- /dev/null +++ b/test_event_driven.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +""" +Test script for event-driven stream processing with batched inference. + +This demonstrates the new AsyncIO-based API for connecting to RTSP streams, +processing frames through batched inference, and receiving tracking results +via callbacks and async generators. 
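+
+Usage (pick an example via the EXAMPLE environment variable, see main() below):
+    EXAMPLE=1 python test_event_driven.py      # 1, 2, 3, 4, or 'all'
+
+Requires CAMERA_URL_1 (and optionally CAMERA_URL_2..CAMERA_URL_4) in .env and a
+TensorRT engine at models/yolov8n.trt.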
+""" + +import asyncio +import os +import time +import logging +from dotenv import load_dotenv + +from services import StreamConnectionManager, YOLOv8Utils, COCO_CLASSES + +# Setup logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +# Example 1: Simple callback pattern +async def example_callback_pattern(): + """Demonstrates the simple callback pattern for a single stream""" + logger.info("=== Example 1: Callback Pattern ===") + + # Load environment variables + load_dotenv() + camera_url = os.getenv('CAMERA_URL_1') + if not camera_url: + logger.error("CAMERA_URL_1 not found in .env file") + return + + # Create manager + manager = StreamConnectionManager( + gpu_id=0, + batch_size=16, + force_timeout=0.05, # 50ms + poll_interval=0.01, # 100 FPS + ) + + # Initialize with YOLOv8 model + model_path = "models/yolov8n.trt" # Adjust path as needed + if not os.path.exists(model_path): + logger.error(f"Model file not found: {model_path}") + return + + await manager.initialize( + model_path=model_path, + model_id="yolo", + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + ) + + # Define callback for tracking results + def on_tracking_result(result): + logger.info(f"[{result.stream_id}] Frame {result.metadata.get('frame_number', 0)}") + logger.info(f" Timestamp: {result.timestamp:.3f}") + logger.info(f" Tracked objects: {len(result.tracked_objects)}") + + for obj in result.tracked_objects[:5]: # Show first 5 + class_name = COCO_CLASSES.get(obj.class_id, f"Class {obj.class_id}") + logger.info( + f" Track ID {obj.track_id}: {class_name}, " + f"conf={obj.confidence:.2f}, bbox={obj.bbox}" + ) + + def on_error(error): + logger.error(f"Stream error: {error}") + + # Connect to stream + connection = await manager.connect_stream( + rtsp_url=camera_url, + stream_id="camera1", + on_tracking_result=on_tracking_result, + on_error=on_error, + ) + + # Let it run for 30 seconds + logger.info("Processing stream for 30 seconds...") + await asyncio.sleep(30) + + # Get statistics + stats = manager.get_stats() + logger.info("=== Statistics ===") + logger.info(f"Manager: {stats['manager']}") + logger.info(f"Model Controller: {stats['model_controller']}") + logger.info(f"Connection: {stats['connections']['camera1']}") + + # Cleanup + await manager.shutdown() + logger.info("Example 1 complete\n") + + +# Example 2: Async generator pattern with multiple streams +async def example_async_generator_pattern(): + """Demonstrates async generator pattern for multiple streams""" + logger.info("=== Example 2: Async Generator Pattern (Multiple Streams) ===") + + # Load environment variables + load_dotenv() + camera_urls = [] + for i in range(1, 5): # Try to load 4 cameras + url = os.getenv(f'CAMERA_URL_{i}') + if url: + camera_urls.append((url, f"camera{i}")) + + if not camera_urls: + logger.error("No camera URLs found in .env file") + return + + logger.info(f"Found {len(camera_urls)} camera(s)") + + # Create manager with larger batch for multiple streams + manager = StreamConnectionManager( + gpu_id=0, + batch_size=32, # Larger batch for multiple streams + force_timeout=0.05, + ) + + # Initialize + model_path = "models/yolov8n.trt" + if not os.path.exists(model_path): + logger.error(f"Model file not found: {model_path}") + return + + await manager.initialize( + model_path=model_path, + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + ) + + # Connect to all streams 
+ connections = [] + for url, stream_id in camera_urls: + try: + connection = await manager.connect_stream( + rtsp_url=url, + stream_id=stream_id, + ) + connections.append((connection, stream_id)) + logger.info(f"Connected to {stream_id}") + except Exception as e: + logger.error(f"Failed to connect to {stream_id}: {e}") + + # Process each stream with async generator + async def process_stream(connection, stream_name): + """Process results from a single stream""" + frame_count = 0 + person_detections = 0 + + async for result in connection.tracking_results(): + frame_count += 1 + + # Count person detections (class_id 0 in COCO) + for obj in result.tracked_objects: + if obj.class_id == 0: + person_detections += 1 + + # Log every 10th frame + if frame_count % 10 == 0: + logger.info( + f"[{stream_name}] Processed {frame_count} frames, " + f"{person_detections} person detections" + ) + + # Stop after 100 frames + if frame_count >= 100: + break + + # Run all streams concurrently + tasks = [ + asyncio.create_task(process_stream(conn, name)) + for conn, name in connections + ] + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + # Get final statistics + stats = manager.get_stats() + logger.info("\n=== Final Statistics ===") + logger.info(f"Total connections: {stats['manager']['num_connections']}") + logger.info(f"Frames processed: {stats['model_controller']['total_frames_processed']}") + logger.info(f"Batches processed: {stats['model_controller']['total_batches_processed']}") + logger.info(f"Avg batch size: {stats['model_controller']['avg_batch_size']:.2f}") + + # Cleanup + await manager.shutdown() + logger.info("Example 2 complete\n") + + +# Example 3: Queue-based pattern +async def example_queue_pattern(): + """Demonstrates direct queue access for custom processing""" + logger.info("=== Example 3: Queue-Based Pattern ===") + + # Load environment + load_dotenv() + camera_url = os.getenv('CAMERA_URL_1') + if not camera_url: + logger.error("CAMERA_URL_1 not found in .env file") + return + + # Create manager + manager = StreamConnectionManager(gpu_id=0, batch_size=16) + + # Initialize + model_path = "models/yolov8n.trt" + if not os.path.exists(model_path): + logger.error(f"Model file not found: {model_path}") + return + + await manager.initialize( + model_path=model_path, + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + ) + + # Connect to stream (no callback) + connection = await manager.connect_stream( + rtsp_url=camera_url, + stream_id="main_camera", + ) + + # Use the built-in queue directly + result_queue = connection.result_queue + + # Process results from queue + processed_count = 0 + while processed_count < 50: # Process 50 frames + try: + result = await asyncio.wait_for(result_queue.get(), timeout=5.0) + processed_count += 1 + + # Custom processing + has_person = any(obj.class_id == 0 for obj in result.tracked_objects) + has_car = any(obj.class_id == 2 for obj in result.tracked_objects) + + if has_person or has_car: + logger.info( + f"Frame {processed_count}: " + f"Person={'Yes' if has_person else 'No'}, " + f"Car={'Yes' if has_car else 'No'}" + ) + + except asyncio.TimeoutError: + logger.warning("Timeout waiting for result") + break + + # Cleanup + await manager.shutdown() + logger.info("Example 3 complete\n") + + +# Example 4: Performance monitoring +async def example_performance_monitoring(): + """Demonstrates real-time performance monitoring""" + logger.info("=== Example 4: Performance Monitoring ===") + + # Load environment + 
load_dotenv() + camera_url = os.getenv('CAMERA_URL_1') + if not camera_url: + logger.error("CAMERA_URL_1 not found in .env file") + return + + # Create manager + manager = StreamConnectionManager( + gpu_id=0, + batch_size=16, + force_timeout=0.05, + ) + + # Initialize + model_path = "models/yolov8n.trt" + if not os.path.exists(model_path): + logger.error(f"Model file not found: {model_path}") + return + + await manager.initialize( + model_path=model_path, + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + ) + + # Track performance metrics + frame_times = [] + last_frame_time = None + + def on_tracking_result(result): + nonlocal last_frame_time + current_time = time.time() + + if last_frame_time is not None: + frame_interval = current_time - last_frame_time + frame_times.append(frame_interval) + + last_frame_time = current_time + + # Connect + connection = await manager.connect_stream( + rtsp_url=camera_url, + on_tracking_result=on_tracking_result, + ) + + # Monitor stats periodically + for i in range(6): # Monitor for 60 seconds + await asyncio.sleep(10) + + stats = manager.get_stats() + model_stats = stats['model_controller'] + conn_stats = stats['connections'].get('stream_0', {}) + + logger.info(f"\n=== Stats Update {i+1} ===") + logger.info(f"Buffer A: {model_stats['buffer_a_size']} ({model_stats['buffer_a_state']})") + logger.info(f"Buffer B: {model_stats['buffer_b_size']} ({model_stats['buffer_b_state']})") + logger.info(f"Active buffer: {model_stats['active_buffer']}") + logger.info(f"Total frames: {model_stats['total_frames_processed']}") + logger.info(f"Total batches: {model_stats['total_batches_processed']}") + logger.info(f"Avg batch size: {model_stats['avg_batch_size']:.2f}") + logger.info(f"Decoder frames: {conn_stats.get('frame_count', 0)}") + + if frame_times: + avg_fps = 1.0 / (sum(frame_times) / len(frame_times)) + logger.info(f"Processing FPS: {avg_fps:.2f}") + + # Cleanup + await manager.shutdown() + logger.info("Example 4 complete\n") + + +async def main(): + """Run all examples""" + logger.info("Starting event-driven stream processing tests\n") + + # Choose which example to run + choice = os.getenv('EXAMPLE', '1') + + if choice == '1': + await example_callback_pattern() + elif choice == '2': + await example_async_generator_pattern() + elif choice == '3': + await example_queue_pattern() + elif choice == '4': + await example_performance_monitoring() + elif choice == 'all': + await example_callback_pattern() + await asyncio.sleep(2) + await example_async_generator_pattern() + await asyncio.sleep(2) + await example_queue_pattern() + await asyncio.sleep(2) + await example_performance_monitoring() + else: + logger.error(f"Invalid choice: {choice}") + logger.info("Set EXAMPLE env var to 1, 2, 3, 4, or 'all'") + + logger.info("All tests complete!") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("\nInterrupted by user") + except Exception as e: + logger.error(f"Error: {e}", exc_info=True) diff --git a/test_event_driven_quick.py b/test_event_driven_quick.py new file mode 100755 index 0000000..ca90494 --- /dev/null +++ b/test_event_driven_quick.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +""" +Quick test for event-driven stream processing - runs for 20 seconds. 
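+
+Requires CAMERA_URL_1 in .env and a TensorRT engine at models/yolov8n.trt.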
+""" + +import asyncio +import os +import logging +from dotenv import load_dotenv + +from services import StreamConnectionManager, YOLOv8Utils, COCO_CLASSES + +# Setup logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +async def main(): + """Quick test with callback pattern""" + logger.info("=== Quick Event-Driven Test (20 seconds) ===") + + # Load environment variables + load_dotenv() + camera_url = os.getenv('CAMERA_URL_1') + if not camera_url: + logger.error("CAMERA_URL_1 not found in .env file") + return + + # Create manager + manager = StreamConnectionManager( + gpu_id=0, + batch_size=16, + force_timeout=0.05, # 50ms + poll_interval=0.01, # 100 FPS + ) + + # Initialize with YOLOv8 model + model_path = "models/yolov8n.trt" + logger.info(f"Initializing with model: {model_path}") + + await manager.initialize( + model_path=model_path, + model_id="yolo", + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + ) + + result_count = 0 + + # Define callback for tracking results + def on_tracking_result(result): + nonlocal result_count + result_count += 1 + + if result_count % 5 == 0: # Log every 5th result + logger.info(f"[{result.stream_id}] Frame {result.metadata.get('frame_number', 0)}") + logger.info(f" Tracked objects: {len(result.tracked_objects)}") + + for obj in result.tracked_objects[:3]: # Show first 3 + class_name = COCO_CLASSES.get(obj.class_id, f"Class {obj.class_id}") + logger.info( + f" Track ID {obj.track_id}: {class_name}, " + f"conf={obj.confidence:.2f}" + ) + + def on_error(error): + logger.error(f"Stream error: {error}") + + # Connect to stream + logger.info(f"Connecting to stream...") + connection = await manager.connect_stream( + rtsp_url=camera_url, + stream_id="test_camera", + on_tracking_result=on_tracking_result, + on_error=on_error, + ) + + # Monitor for 20 seconds with stats updates + for i in range(4): # 4 x 5 seconds = 20 seconds + await asyncio.sleep(5) + + stats = manager.get_stats() + model_stats = stats['model_controller'] + + logger.info(f"\n=== Stats Update {i+1}/4 ===") + logger.info(f"Results received: {result_count}") + logger.info(f"Buffer A: {model_stats['buffer_a_size']} ({model_stats['buffer_a_state']})") + logger.info(f"Buffer B: {model_stats['buffer_b_size']} ({model_stats['buffer_b_state']})") + logger.info(f"Active buffer: {model_stats['active_buffer']}") + logger.info(f"Total frames processed: {model_stats['total_frames_processed']}") + logger.info(f"Total batches: {model_stats['total_batches_processed']}") + logger.info(f"Avg batch size: {model_stats['avg_batch_size']:.2f}") + + # Final statistics + stats = manager.get_stats() + logger.info("\n=== Final Statistics ===") + logger.info(f"Total results received: {result_count}") + logger.info(f"Manager: {stats['manager']}") + logger.info(f"Model Controller: {stats['model_controller']}") + logger.info(f"Connection: {stats['connections']['test_camera']}") + + # Cleanup + logger.info("\nShutting down...") + await manager.shutdown() + logger.info("Test complete!") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("\nInterrupted by user") + except Exception as e: + logger.error(f"Error: {e}", exc_info=True) diff --git a/test_tracking_realtime.py b/test_tracking_realtime.py index f6ba766..e5bc57c 100644 --- a/test_tracking_realtime.py +++ b/test_tracking_realtime.py @@ -509,7 +509,7 @@ def main_multi_window(): if 
__name__ == "__main__": # Run single camera visualization - main() + # main() # Uncomment to run multi-window visualization - # main_multi_window() + main_multi_window()