diff --git a/EVENT_DRIVEN_DESIGN.md b/EVENT_DRIVEN_DESIGN.md deleted file mode 100644 index 45d908a..0000000 --- a/EVENT_DRIVEN_DESIGN.md +++ /dev/null @@ -1,1108 +0,0 @@ -# Event-Driven Stream Processing Architecture with Batching - -## Overview - -This document describes the AsyncIO-based event-driven architecture for connecting stream decoders to models and tracking, with support for batched inference using ping-pong circular buffers. - -## Architecture Diagram - -``` -┌─────────────────────────────────────────────────────────────────┐ -│ StreamConnectionManager │ -│ - Manages multiple stream connections │ -│ - Routes events to user callbacks/generators │ -│ - Coordinates ModelController and TrackingController │ -└─────────────────────────────────────────────────────────────────┘ - │ - ├──────────────────┬──────────────────┐ - ▼ ▼ ▼ - ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ - │ StreamConnection│ │ StreamConnection│ │ StreamConnection│ - │ (Stream 1) │ │ (Stream 2) │ │ (Stream N) │ - │ │ │ │ │ │ - │ - StreamDecoder │ │ - StreamDecoder │ │ - StreamDecoder │ - │ - Frame Poller │ │ - Frame Poller │ │ - Frame Poller │ - │ - Event Emitter │ │ - Event Emitter │ │ - Event Emitter │ - └─────────────────┘ └─────────────────┘ └─────────────────┘ - │ │ │ - └──────────────────┴──────────────────┘ - │ - ▼ - ┌─────────────────────────────────────┐ - │ ModelController │ - │ ┌────────────┐ ┌────────────┐ │ - │ │ Buffer A │ │ Buffer B │ │ - │ │ (Active) │ │(Processing)│ │ - │ │ [frame1] │ │ [frame9] │ │ - │ │ [frame2] │ │ [frame10] │ │ - │ │ [frame3] │ │ [...] │ │ - │ └────────────┘ └────────────┘ │ - │ │ - │ - Batch accumulation │ - │ - Force timeout monitor │ - │ - Ping-pong switching │ - └─────────────────────────────────────┘ - │ - ┌────────────┴────────────┐ - ▼ ▼ - ┌─────────────────────┐ ┌─────────────────────┐ - │ TensorRTModelRepo │ │ TrackingController │ - │ - Batched inference │ │ - Track association │ - │ - Context pooling │ │ - Track management │ - └─────────────────────┘ └─────────────────────┘ - │ - ▼ - ┌─────────────────────────┐ - │ User Callbacks/Queues │ - │ - on_tracking_result │ - │ - on_detections │ - │ - on_error │ - └─────────────────────────┘ -``` - -## Component Details - -### 1. ModelController (Async Batching Layer) - -**Responsibilities:** -- Accumulate frames from multiple streams into batches -- Manage ping-pong buffers (BufferA/BufferB) -- Monitor force-switch timeout -- Execute batched inference -- Route results back to streams - -**Ping-Pong Buffer Logic:** -- **BufferA (Active)**: Accumulates incoming frames -- **BufferB (Processing)**: Being processed through inference -- **Switch Triggers:** - 1. Active buffer reaches `batch_size` → immediate swap - 2. `force_timeout` expires AND processing buffer is idle → force swap - 3. Never switch if processing buffer is busy - -### 2. StreamConnectionManager - -**Responsibilities:** -- Create and manage stream connections -- Coordinate ModelController and TrackingController -- Route tracking results to user callbacks/generators -- Handle stream lifecycle (connect, disconnect, errors) - -### 3. StreamConnection - -**Responsibilities:** -- Wrap a single StreamDecoder -- Poll frames from threaded decoder (bridge to async) -- Submit frames to ModelController -- Emit events to user code - ---- - -## Pseudo Code Implementation - -### 1. 
ModelController with Ping-Pong Buffers - -```python -import asyncio -import torch -from typing import Dict, List, Tuple, Optional, Callable -from dataclasses import dataclass -from enum import Enum -import time - -@dataclass -class BatchFrame: - """Represents a frame in the batch buffer""" - stream_id: str - frame: torch.Tensor # GPU tensor (3, H, W) - timestamp: float - metadata: Dict = None - -class BufferState(Enum): - IDLE = "idle" - FILLING = "filling" - PROCESSING = "processing" - -class ModelController: - """ - Manages batched inference with ping-pong buffers and force-switch timeout. - """ - - def __init__( - self, - model_repository, - model_id: str, - batch_size: int = 16, - force_timeout: float = 0.05, # 50ms - preprocess_fn: Callable = None, - postprocess_fn: Callable = None, - ): - self.model_repository = model_repository - self.model_id = model_id - self.batch_size = batch_size - self.force_timeout = force_timeout - self.preprocess_fn = preprocess_fn - self.postprocess_fn = postprocess_fn - - # Ping-pong buffers - self.buffer_a: List[BatchFrame] = [] - self.buffer_b: List[BatchFrame] = [] - - # Buffer states - self.active_buffer = "A" # Which buffer is currently active (filling) - self.buffer_a_state = BufferState.IDLE - self.buffer_b_state = BufferState.IDLE - - # Async coordination - self.buffer_lock = asyncio.Lock() - self.last_submit_time = time.time() - - # Tasks - self.timeout_task: Optional[asyncio.Task] = None - self.processor_task: Optional[asyncio.Task] = None - self.running = False - - # Result callbacks (stream_id -> callback) - self.result_callbacks: Dict[str, Callable] = {} - - async def start(self): - """Start the controller background tasks""" - self.running = True - self.timeout_task = asyncio.create_task(self._timeout_monitor()) - self.processor_task = asyncio.create_task(self._batch_processor()) - - async def stop(self): - """Stop the controller and cleanup""" - self.running = False - if self.timeout_task: - self.timeout_task.cancel() - if self.processor_task: - self.processor_task.cancel() - - # Process any remaining frames - await self._process_remaining_buffers() - - def register_callback(self, stream_id: str, callback: Callable): - """Register a callback for inference results from a stream""" - self.result_callbacks[stream_id] = callback - - def unregister_callback(self, stream_id: str): - """Unregister a stream callback""" - self.result_callbacks.pop(stream_id, None) - - async def submit_frame(self, stream_id: str, frame: torch.Tensor, metadata: Dict = None): - """ - Submit a frame for batched inference. 
- - Args: - stream_id: Unique stream identifier - frame: GPU tensor (3, H, W) - metadata: Optional metadata to attach - """ - async with self.buffer_lock: - batch_frame = BatchFrame( - stream_id=stream_id, - frame=frame, - timestamp=time.time(), - metadata=metadata or {} - ) - - # Add to active buffer - if self.active_buffer == "A": - self.buffer_a.append(batch_frame) - self.buffer_a_state = BufferState.FILLING - buffer_size = len(self.buffer_a) - else: - self.buffer_b.append(batch_frame) - self.buffer_b_state = BufferState.FILLING - buffer_size = len(self.buffer_b) - - self.last_submit_time = time.time() - - # Check if we should immediately swap (batch full) - if buffer_size >= self.batch_size: - await self._try_swap_buffers() - - async def _timeout_monitor(self): - """Monitor force-switch timeout""" - while self.running: - await asyncio.sleep(0.01) # Check every 10ms - - async with self.buffer_lock: - time_since_submit = time.time() - self.last_submit_time - - # Check if timeout expired and we have frames waiting - if time_since_submit >= self.force_timeout: - active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b - if len(active_buffer) > 0: - await self._try_swap_buffers() - - async def _try_swap_buffers(self): - """ - Attempt to swap ping-pong buffers. - Only swaps if the inactive buffer is not currently processing. - - This method should be called with buffer_lock held. - """ - # Check if inactive buffer is available - inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state - - if inactive_state != BufferState.PROCESSING: - # Swap active buffer - old_active = self.active_buffer - self.active_buffer = "B" if old_active == "A" else "A" - - # Mark old active buffer as ready for processing - if old_active == "A": - self.buffer_a_state = BufferState.PROCESSING - else: - self.buffer_b_state = BufferState.PROCESSING - - # Signal processor that there's work to do - # (The processor task is already running and will pick it up) - - async def _batch_processor(self): - """Background task that processes batches when available""" - while self.running: - await asyncio.sleep(0.001) # Check every 1ms - - # Check if buffer A needs processing - if self.buffer_a_state == BufferState.PROCESSING: - await self._process_buffer("A") - - # Check if buffer B needs processing - if self.buffer_b_state == BufferState.PROCESSING: - await self._process_buffer("B") - - async def _process_buffer(self, buffer_name: str): - """ - Process a buffer through inference. 
- - Args: - buffer_name: "A" or "B" - """ - async with self.buffer_lock: - # Get buffer to process - if buffer_name == "A": - batch = self.buffer_a.copy() - self.buffer_a.clear() - else: - batch = self.buffer_b.copy() - self.buffer_b.clear() - - if len(batch) == 0: - # Mark as idle and return - async with self.buffer_lock: - if buffer_name == "A": - self.buffer_a_state = BufferState.IDLE - else: - self.buffer_b_state = BufferState.IDLE - return - - # Process batch (outside lock to allow concurrent submissions) - try: - results = await self._run_batch_inference(batch) - - # Emit results to callbacks - for batch_frame, result in zip(batch, results): - callback = self.result_callbacks.get(batch_frame.stream_id) - if callback: - # Schedule callback asynchronously - if asyncio.iscoroutinefunction(callback): - asyncio.create_task(callback(result)) - else: - callback(result) - - except Exception as e: - print(f"Error processing batch: {e}") - # TODO: Emit error events - - finally: - # Mark buffer as idle - async with self.buffer_lock: - if buffer_name == "A": - self.buffer_a_state = BufferState.IDLE - else: - self.buffer_b_state = BufferState.IDLE - - async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict]: - """ - Run inference on a batch of frames. - - Args: - batch: List of BatchFrame objects - - Returns: - List of detection results (one per frame) - """ - # Preprocess frames (on GPU) - preprocessed = [] - for batch_frame in batch: - if self.preprocess_fn: - processed = self.preprocess_fn(batch_frame.frame) - else: - processed = batch_frame.frame - preprocessed.append(processed) - - # Stack into batch tensor: (N, C, H, W) - batch_tensor = torch.stack(preprocessed, dim=0) - - # Run inference (TensorRT model repository is sync, so run in executor) - loop = asyncio.get_event_loop() - outputs = await loop.run_in_executor( - None, - lambda: self.model_repository.infer( - self.model_id, - {"images": batch_tensor}, - synchronize=True - ) - ) - - # Postprocess results (split batch back to individual results) - results = [] - for i, batch_frame in enumerate(batch): - # Extract single frame output from batch - frame_output = {k: v[i:i+1] for k, v in outputs.items()} - - if self.postprocess_fn: - detections = self.postprocess_fn(frame_output) - else: - detections = frame_output - - result = { - "stream_id": batch_frame.stream_id, - "timestamp": batch_frame.timestamp, - "detections": detections, - "metadata": batch_frame.metadata, - } - results.append(result) - - return results - - async def _process_remaining_buffers(self): - """Process any remaining frames in buffers during shutdown""" - if len(self.buffer_a) > 0: - await self._process_buffer("A") - if len(self.buffer_b) > 0: - await self._process_buffer("B") - - def get_stats(self) -> Dict: - """Get current buffer statistics""" - return { - "active_buffer": self.active_buffer, - "buffer_a_size": len(self.buffer_a), - "buffer_b_size": len(self.buffer_b), - "buffer_a_state": self.buffer_a_state.value, - "buffer_b_state": self.buffer_b_state.value, - "registered_streams": len(self.result_callbacks), - } -``` - -### 2. 
StreamConnectionManager - -```python -import asyncio -from typing import Dict, Optional, Callable, AsyncIterator -from dataclasses import dataclass -from enum import Enum - -class ConnectionStatus(Enum): - CONNECTING = "connecting" - CONNECTED = "connected" - DISCONNECTED = "disconnected" - ERROR = "error" - -@dataclass -class TrackingResult: - """Result emitted to user callbacks""" - stream_id: str - timestamp: float - tracked_objects: List # List of TrackedObject from TrackingController - detections: List # Raw detections - frame_shape: Tuple[int, int, int] - metadata: Dict - -class StreamConnection: - """Represents a single stream connection with event emission""" - - def __init__( - self, - stream_id: str, - decoder, - model_controller: ModelController, - tracking_controller, - poll_interval: float = 0.01, - ): - self.stream_id = stream_id - self.decoder = decoder - self.model_controller = model_controller - self.tracking_controller = tracking_controller - self.poll_interval = poll_interval - - self.status = ConnectionStatus.CONNECTING - self.frame_count = 0 - self.last_frame_time = 0.0 - - # Event emission - self.result_queue: asyncio.Queue[TrackingResult] = asyncio.Queue() - self.error_queue: asyncio.Queue[Exception] = asyncio.Queue() - - # Tasks - self.poller_task: Optional[asyncio.Task] = None - self.running = False - - async def start(self): - """Start the connection (decoder and frame polling)""" - # Start decoder (runs in background thread) - self.decoder.start() - - # Wait for initial connection - await asyncio.sleep(2.0) - - if self.decoder.is_connected(): - self.status = ConnectionStatus.CONNECTED - else: - self.status = ConnectionStatus.ERROR - raise ConnectionError(f"Failed to connect to stream {self.stream_id}") - - # Start frame polling task - self.running = True - self.poller_task = asyncio.create_task(self._frame_poller()) - - async def stop(self): - """Stop the connection and cleanup""" - self.running = False - - if self.poller_task: - self.poller_task.cancel() - try: - await self.poller_task - except asyncio.CancelledError: - pass - - # Stop decoder - self.decoder.stop() - - # Unregister from model controller - self.model_controller.unregister_callback(self.stream_id) - - self.status = ConnectionStatus.DISCONNECTED - - async def _frame_poller(self): - """Poll frames from threaded decoder and submit to model controller""" - last_frame_ptr = None - - while self.running: - try: - # Poll frame from decoder (runs in thread) - frame = self.decoder.get_latest_frame(rgb=True) - - # Check if we got a new frame (avoid reprocessing same frame) - if frame is not None and frame.data_ptr() != last_frame_ptr: - last_frame_ptr = frame.data_ptr() - self.last_frame_time = time.time() - self.frame_count += 1 - - # Submit to model controller for batched inference - await self.model_controller.submit_frame( - stream_id=self.stream_id, - frame=frame, - metadata={ - "frame_number": self.frame_count, - "shape": frame.shape, - } - ) - - # Check decoder status - if not self.decoder.is_connected(): - self.status = ConnectionStatus.DISCONNECTED - # Decoder will auto-reconnect, just update status - await asyncio.sleep(1.0) - if self.decoder.is_connected(): - self.status = ConnectionStatus.CONNECTED - - except Exception as e: - await self.error_queue.put(e) - self.status = ConnectionStatus.ERROR - - # Sleep until next poll - await asyncio.sleep(self.poll_interval) - - async def _handle_inference_result(self, result: Dict): - """ - Callback invoked by ModelController when inference is done. 
- Runs tracking and emits final result. - """ - try: - # Extract detections - detections = result["detections"] - - # Run tracking (this is sync, so run in executor) - loop = asyncio.get_event_loop() - tracked_objects = await loop.run_in_executor( - None, - lambda: self.tracking_controller.update_tracks(detections) - ) - - # Create tracking result - tracking_result = TrackingResult( - stream_id=self.stream_id, - timestamp=result["timestamp"], - tracked_objects=tracked_objects, - detections=detections, - frame_shape=result["metadata"].get("shape"), - metadata=result["metadata"], - ) - - # Emit to result queue - await self.result_queue.put(tracking_result) - - except Exception as e: - await self.error_queue.put(e) - - async def tracking_results(self) -> AsyncIterator[TrackingResult]: - """ - Async generator for tracking results. - - Usage: - async for result in connection.tracking_results(): - print(result.tracked_objects) - """ - while self.running or not self.result_queue.empty(): - try: - result = await asyncio.wait_for(self.result_queue.get(), timeout=1.0) - yield result - except asyncio.TimeoutError: - continue - - async def errors(self) -> AsyncIterator[Exception]: - """Async generator for errors""" - while self.running or not self.error_queue.empty(): - try: - error = await asyncio.wait_for(self.error_queue.get(), timeout=1.0) - yield error - except asyncio.TimeoutError: - continue - - def get_stats(self) -> Dict: - """Get connection statistics""" - return { - "stream_id": self.stream_id, - "status": self.status.value, - "frame_count": self.frame_count, - "last_frame_time": self.last_frame_time, - "decoder_connected": self.decoder.is_connected(), - "decoder_buffer_size": self.decoder.get_buffer_size(), - } - - -class StreamConnectionManager: - """ - High-level manager for stream connections with batched inference. - """ - - def __init__( - self, - gpu_id: int = 0, - batch_size: int = 16, - force_timeout: float = 0.05, - poll_interval: float = 0.01, - ): - self.gpu_id = gpu_id - self.batch_size = batch_size - self.force_timeout = force_timeout - self.poll_interval = poll_interval - - # Factories - from services import StreamDecoderFactory, TrackingFactory - from services.model_repository import TensorRTModelRepository - - self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id) - self.tracking_factory = TrackingFactory(gpu_id=gpu_id) - self.model_repository = TensorRTModelRepository(gpu_id=gpu_id) - - # Controllers - self.model_controller: Optional[ModelController] = None - self.tracking_controller = None - - # Connections - self.connections: Dict[str, StreamConnection] = {} - - # State - self.initialized = False - - async def initialize( - self, - model_path: str, - model_id: str = "detector", - preprocess_fn: Callable = None, - postprocess_fn: Callable = None, - ): - """ - Initialize the manager with a model. 
- - Args: - model_path: Path to TensorRT model file - model_id: Model identifier - preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess) - postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess) - """ - # Load model - loop = asyncio.get_event_loop() - await loop.run_in_executor( - None, - lambda: self.model_repository.load_model(model_id, model_path, num_contexts=4) - ) - - # Create model controller - self.model_controller = ModelController( - model_repository=self.model_repository, - model_id=model_id, - batch_size=self.batch_size, - force_timeout=self.force_timeout, - preprocess_fn=preprocess_fn, - postprocess_fn=postprocess_fn, - ) - await self.model_controller.start() - - # Create tracking controller - self.tracking_controller = self.tracking_factory.create_controller( - model_repository=self.model_repository, - model_id=model_id, - tracker_type="iou", - ) - - self.initialized = True - - async def connect_stream( - self, - rtsp_url: str, - stream_id: Optional[str] = None, - on_tracking_result: Optional[Callable] = None, - on_error: Optional[Callable] = None, - ) -> StreamConnection: - """ - Connect to a stream and start processing. - - Args: - rtsp_url: RTSP stream URL - stream_id: Optional stream identifier (auto-generated if not provided) - on_tracking_result: Optional callback for tracking results - on_error: Optional callback for errors - - Returns: - StreamConnection object for this stream - """ - if not self.initialized: - raise RuntimeError("Manager not initialized. Call initialize() first.") - - # Generate stream ID if not provided - if stream_id is None: - stream_id = f"stream_{len(self.connections)}" - - # Create decoder - decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=30) - - # Create connection - connection = StreamConnection( - stream_id=stream_id, - decoder=decoder, - model_controller=self.model_controller, - tracking_controller=self.tracking_controller, - poll_interval=self.poll_interval, - ) - - # Register callback with model controller - self.model_controller.register_callback( - stream_id, - connection._handle_inference_result - ) - - # Start connection - await connection.start() - - # Store connection - self.connections[stream_id] = connection - - # Set up user callbacks if provided - if on_tracking_result: - asyncio.create_task(self._forward_results(connection, on_tracking_result)) - - if on_error: - asyncio.create_task(self._forward_errors(connection, on_error)) - - return connection - - async def disconnect_stream(self, stream_id: str): - """Disconnect and cleanup a stream""" - connection = self.connections.get(stream_id) - if connection: - await connection.stop() - del self.connections[stream_id] - - async def disconnect_all(self): - """Disconnect all streams""" - stream_ids = list(self.connections.keys()) - for stream_id in stream_ids: - await self.disconnect_stream(stream_id) - - async def shutdown(self): - """Shutdown the manager and cleanup all resources""" - # Disconnect all streams - await self.disconnect_all() - - # Stop model controller - if self.model_controller: - await self.model_controller.stop() - - # Cleanup (model repository cleanup is sync) - # Note: May need to handle cleanup carefully to avoid segfaults - - async def _forward_results(self, connection: StreamConnection, callback: Callable): - """Forward results from connection to user callback""" - async for result in connection.tracking_results(): - if asyncio.iscoroutinefunction(callback): - await callback(result) - else: - callback(result) - 
- async def _forward_errors(self, connection: StreamConnection, callback: Callable): - """Forward errors from connection to user callback""" - async for error in connection.errors(): - if asyncio.iscoroutinefunction(callback): - await callback(error) - else: - callback(error) - - def get_stats(self) -> Dict: - """Get statistics for all connections""" - return { - "manager": { - "initialized": self.initialized, - "num_connections": len(self.connections), - "batch_size": self.batch_size, - "force_timeout": self.force_timeout, - }, - "model_controller": self.model_controller.get_stats() if self.model_controller else {}, - "connections": { - stream_id: conn.get_stats() - for stream_id, conn in self.connections.items() - }, - } -``` - -### 3. User API Examples - -#### Example 1: Simple Callback Pattern - -```python -import asyncio -from services import StreamConnectionManager -from services.yolo import YOLOv8Utils - -async def main(): - # Create manager - manager = StreamConnectionManager( - gpu_id=0, - batch_size=16, - force_timeout=0.05, # 50ms - ) - - # Initialize with model - await manager.initialize( - model_path="models/yolov8n.trt", - model_id="yolo", - preprocess_fn=YOLOv8Utils.preprocess, - postprocess_fn=YOLOv8Utils.postprocess, - ) - - # Define callback for tracking results - def on_tracking_result(result): - print(f"Stream: {result.stream_id}") - print(f"Timestamp: {result.timestamp}") - print(f"Tracked objects: {len(result.tracked_objects)}") - for obj in result.tracked_objects: - print(f" - Track ID {obj.track_id}: class={obj.class_id}, conf={obj.confidence:.2f}") - - def on_error(error): - print(f"Error: {error}") - - # Connect to stream - connection = await manager.connect_stream( - rtsp_url="rtsp://camera1.example.com/stream", - stream_id="camera1", - on_tracking_result=on_tracking_result, - on_error=on_error, - ) - - # Let it run for 60 seconds - await asyncio.sleep(60) - - # Get statistics - stats = manager.get_stats() - print(f"Stats: {stats}") - - # Cleanup - await manager.shutdown() - -if __name__ == "__main__": - asyncio.run(main()) -``` - -#### Example 2: Async Generator Pattern (Multiple Streams) - -```python -import asyncio -from services import StreamConnectionManager -from services.yolo import YOLOv8Utils - -async def process_stream(connection, stream_name): - """Process results from a single stream""" - async for result in connection.tracking_results(): - print(f"[{stream_name}] Frame {result.metadata['frame_number']}: {len(result.tracked_objects)} objects") - - # Do something with tracked objects - for obj in result.tracked_objects: - if obj.class_id == 0: # Person class - print(f" Person detected! 
Track ID: {obj.track_id}, Conf: {obj.confidence:.2f}") - -async def main(): - manager = StreamConnectionManager( - gpu_id=0, - batch_size=32, # Larger batch for multiple streams - force_timeout=0.05, - ) - - await manager.initialize( - model_path="models/yolov8n.trt", - preprocess_fn=YOLOv8Utils.preprocess, - postprocess_fn=YOLOv8Utils.postprocess, - ) - - # Connect to multiple streams - camera_urls = [ - ("rtsp://camera1.example.com/stream", "Front Door"), - ("rtsp://camera2.example.com/stream", "Parking Lot"), - ("rtsp://camera3.example.com/stream", "Warehouse"), - ("rtsp://camera4.example.com/stream", "Loading Bay"), - ] - - tasks = [] - for url, name in camera_urls: - connection = await manager.connect_stream( - rtsp_url=url, - stream_id=name.lower().replace(" ", "_"), - ) - - # Create task to process this stream - task = asyncio.create_task(process_stream(connection, name)) - tasks.append(task) - - # Run all streams concurrently - try: - await asyncio.gather(*tasks) - except KeyboardInterrupt: - print("Shutting down...") - - await manager.shutdown() - -if __name__ == "__main__": - asyncio.run(main()) -``` - -#### Example 3: Queue-Based Pattern (for integration with other systems) - -```python -import asyncio -from services import StreamConnectionManager -from services.yolo import YOLOv8Utils - -async def main(): - manager = StreamConnectionManager(gpu_id=0, batch_size=16) - - await manager.initialize( - model_path="models/yolov8n.trt", - preprocess_fn=YOLOv8Utils.preprocess, - postprocess_fn=YOLOv8Utils.postprocess, - ) - - # Connect to stream (no callback) - connection = await manager.connect_stream( - rtsp_url="rtsp://camera.example.com/stream", - stream_id="main_camera", - ) - - # Use the built-in queue - result_queue = connection.result_queue - - # Process results from queue - while True: - result = await result_queue.get() - - # Send to external system (e.g., message queue, database, API) - await send_to_kafka(result) - await save_to_database(result) - - # Or do real-time processing - if has_person_alert(result.tracked_objects): - await send_alert("Person detected in restricted area!") - -async def send_to_kafka(result): - # Your Kafka producer code - pass - -async def save_to_database(result): - # Your database code - pass - -def has_person_alert(tracked_objects): - # Your alert logic - return any(obj.class_id == 0 for obj in tracked_objects) - -async def send_alert(message): - print(f"ALERT: {message}") - -if __name__ == "__main__": - asyncio.run(main()) -``` - -#### Example 4: Async Callback with Error Handling - -```python -import asyncio -from services import StreamConnectionManager -from services.yolo import YOLOv8Utils - -async def main(): - manager = StreamConnectionManager(gpu_id=0, batch_size=16) - - await manager.initialize( - model_path="models/yolov8n.trt", - preprocess_fn=YOLOv8Utils.preprocess, - postprocess_fn=YOLOv8Utils.postprocess, - ) - - # Async callback (can do I/O operations) - async def on_tracking_result(result): - # Can use async operations in callback - await save_to_database(result) - - # Check for alerts - for obj in result.tracked_objects: - if obj.class_id == 0 and obj.confidence > 0.8: - await send_notification(f"High confidence person detection: {obj.track_id}") - - async def on_error(error): - await log_error_to_monitoring_system(error) - - # Connect with async callbacks - connection = await manager.connect_stream( - rtsp_url="rtsp://camera.example.com/stream", - on_tracking_result=on_tracking_result, - on_error=on_error, - ) - - # Monitor stats 
periodically - while True: - await asyncio.sleep(10) - stats = manager.get_stats() - print(f"Buffer stats: {stats['model_controller']}") - print(f"Connection stats: {stats['connections']}") - -async def save_to_database(result): - # Simulate async database operation - await asyncio.sleep(0.01) - -async def send_notification(message): - print(f"NOTIFICATION: {message}") - -async def log_error_to_monitoring_system(error): - print(f"ERROR: {error}") - -if __name__ == "__main__": - asyncio.run(main()) -``` - ---- - -## Configuration Examples - -### Performance Tuning - -```python -# Low latency (small batches, quick timeout) -manager = StreamConnectionManager( - gpu_id=0, - batch_size=4, - force_timeout=0.02, # 20ms - poll_interval=0.005, # 200 FPS -) - -# High throughput (large batches, longer timeout) -manager = StreamConnectionManager( - gpu_id=0, - batch_size=32, - force_timeout=0.1, # 100ms - poll_interval=0.02, # 50 FPS -) - -# Balanced (default) -manager = StreamConnectionManager( - gpu_id=0, - batch_size=16, - force_timeout=0.05, # 50ms - poll_interval=0.01, # 100 FPS -) -``` - -### Multiple GPUs - -```python -# Create manager per GPU -manager_gpu0 = StreamConnectionManager(gpu_id=0, batch_size=16) -manager_gpu1 = StreamConnectionManager(gpu_id=1, batch_size=16) - -# Initialize both -await manager_gpu0.initialize(model_path="models/yolov8n.trt", ...) -await manager_gpu1.initialize(model_path="models/yolov8n.trt", ...) - -# Distribute streams across GPUs -await manager_gpu0.connect_stream(url1, ...) -await manager_gpu0.connect_stream(url2, ...) -await manager_gpu1.connect_stream(url3, ...) -await manager_gpu1.connect_stream(url4, ...) -``` - ---- - -## Key Features Summary - -1. **Ping-Pong Buffers**: Efficient batching with minimal latency -2. **Force Timeout**: Prevents starvation of small batches -3. **AsyncIO**: Clean event-driven architecture -4. **Multiple Patterns**: Callbacks, generators, queues -5. **Thread-Async Bridge**: Integrates with existing threaded decoders -6. **Zero-Copy**: All processing stays on GPU -7. **Auto-Reconnection**: Inherits from StreamDecoder -8. **Statistics**: Real-time monitoring of buffers and connections - ---- - -## Performance Characteristics - -- **Latency**: `force_timeout + inference_time` -- **Throughput**: Maximized by batching -- **VRAM**: 60MB per stream + batch buffer overhead -- **CPU**: Minimal (async event loop + thread polling) - ---- - -## Next Steps - -To implement this design: - -1. Create `services/model_controller.py` with `ModelController` class -2. Create `services/stream_connection_manager.py` with `StreamConnectionManager` and `StreamConnection` classes -3. Update `services/__init__.py` to export new classes -4. Create `test_event_driven.py` to test the system -5. Add monitoring/logging throughout -6. Handle edge cases (reconnection, cleanup, errors) diff --git a/ISSUES.md b/ISSUES.md index f565f42..adbba6c 100644 --- a/ISSUES.md +++ b/ISSUES.md @@ -2,4 +2,8 @@ - It doesn't really care what pt file is included and it always use YOLO's model id, for example if id 1 is apple, it still say person. maybe extract class list from yolo's .pt somehow? +- It reads frames a bit too fast: it reports inferring at ~20 FPS, but the actual camera is only ~5 FPS or so - Potential race condition issue when multiple camera try to init with the same unconverted model.
+ +- Blurry asyncio archtecture, require documentations \ No newline at end of file diff --git a/scripts/profiling.py b/scripts/profiling.py new file mode 100644 index 0000000..7b760d6 --- /dev/null +++ b/scripts/profiling.py @@ -0,0 +1,165 @@ +""" +Profiling script for the real-time object tracking pipeline. + +This script runs the single-stream example from test_tracking_realtime.py +under the Python profiler (cProfile) to identify performance bottlenecks. + +Usage: + python scripts/profiling.py + +The script will print a summary of the most time-consuming functions +at the end of the run. +""" + +import asyncio +import cProfile +import pstats +import io +import time +import os +import torch +import cv2 +from dotenv import load_dotenv + +# Add project root to path to allow imports from services +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from services import ( + StreamConnectionManager, + YOLOv8Utils, +) + +# Load environment variables +load_dotenv() + + +async def profiled_main(): + """ + Single stream example with event-driven architecture, adapted for profiling. + This function is a modified version of main_single_stream from test_tracking_realtime.py + """ + print("=" * 80) + print("Profiling: Event-Driven GPU-Accelerated Object Tracking") + print("=" * 80) + + # Configuration + GPU_ID = 0 + MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" + STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test') + BATCH_SIZE = 4 + FORCE_TIMEOUT = 0.05 + # NOTE: Display is disabled for profiling to isolate pipeline performance + ENABLE_DISPLAY = False + # Run for a limited number of frames to get a representative profile + MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300')) + + print(f"\nConfiguration:") + print(f" GPU: {GPU_ID}") + print(f" Model: {MODEL_PATH}") + print(f" Stream: {STREAM_URL}") + print(f" Batch size: {BATCH_SIZE}") + print(f" Force timeout: {FORCE_TIMEOUT}s") + print(f" Display: Disabled for profiling") + print(f" Max frames: {MAX_FRAMES}\n") + + # Create StreamConnectionManager + print("[1/3] Creating StreamConnectionManager...") + manager = StreamConnectionManager( + gpu_id=GPU_ID, + batch_size=BATCH_SIZE, + force_timeout=FORCE_TIMEOUT, + enable_pt_conversion=True + ) + print("✓ Manager created") + + # Initialize with PT model + print("\n[2/3] Initializing with PT model...") + try: + await manager.initialize( + model_path=MODEL_PATH, + model_id="detector", + preprocess_fn=YOLOv8Utils.preprocess, + postprocess_fn=YOLOv8Utils.postprocess, + num_contexts=4, + pt_input_shapes={"images": (1, 3, 640, 640)}, + pt_precision=torch.float16 + ) + print("✓ Manager initialized") + except Exception as e: + print(f"✗ Failed to initialize: {e}") + return + + # Connect stream + print("\n[3/3] Connecting to stream...") + try: + connection = await manager.connect_stream( + rtsp_url=STREAM_URL, + stream_id="camera_1", + buffer_size=30 + ) + print(f"✓ Stream connected: camera_1") + except Exception as e: + print(f"✗ Failed to connect stream: {e}") + await manager.shutdown() + return + + print(f"\n{'=' * 80}") + print(f"Profiling is running for {MAX_FRAMES} frames...") + print(f"{ '=' * 80}\n") + + result_count = 0 + start_time = time.time() + + try: + async for result in connection.tracking_results(): + result_count += 1 + if result_count >= MAX_FRAMES: + print(f"\n✓ Reached max frames limit ({MAX_FRAMES})") + break + + if result_count % 50 == 0: + print(f" Processed {result_count}/{MAX_FRAMES} frames...") + + except 
KeyboardInterrupt: + print(f"\n✓ Interrupted by user") + + # Cleanup + print(f"\n{'=' * 80}") + print("Cleanup") + print(f"{ '=' * 80}") + + await connection.stop() + await manager.shutdown() + print("✓ Stopped") + + # Final stats + elapsed = time.time() - start_time + avg_fps = result_count / elapsed if elapsed > 0 else 0 + print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)") + + +if __name__ == "__main__": + # Create a profiler object + profiler = cProfile.Profile() + + # Run the async main function under the profiler + print("Starting profiler...") + profiler.enable() + + asyncio.run(profiled_main()) + + profiler.disable() + print("Profiling complete.") + + # Print the stats + s = io.StringIO() + # Sort stats by cumulative time + sortby = pstats.SortKey.CUMULATIVE + ps = pstats.Stats(profiler, stream=s).sort_stats(sortby) + ps.print_stats(30) # Print top 30 functions + + print("\n" + "="*80) + print("PROFILING RESULTS (Top 30, sorted by cumulative time)") + print("="*80) + print(s.getvalue()) diff --git a/services/__init__.py b/services/__init__.py index 7510f61..a4008b3 100644 --- a/services/__init__.py +++ b/services/__init__.py @@ -5,8 +5,7 @@ Services package for RTSP stream processing with GPU acceleration. from .stream_decoder import StreamDecoderFactory, StreamDecoder, ConnectionStatus from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionContext, SharedEngine -from .tracking_controller import TrackingController, TrackedObject -from .tracking_factory import TrackingFactory +from .tracking_controller import ObjectTracker, TrackedObject, Detection from .yolo import YOLOv8Utils, COCO_CLASSES from .model_controller import ModelController, BatchFrame, BufferState from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult @@ -23,9 +22,9 @@ __all__ = [ 'ModelMetadata', 'ExecutionContext', 'SharedEngine', - 'TrackingController', + 'ObjectTracker', 'TrackedObject', - 'TrackingFactory', + 'Detection', 'YOLOv8Utils', 'COCO_CLASSES', 'ModelController', diff --git a/services/stream_connection_manager.py b/services/stream_connection_manager.py index 0324957..a8d6b8a 100644 --- a/services/stream_connection_manager.py +++ b/services/stream_connection_manager.py @@ -16,7 +16,6 @@ import torch from .model_controller import ModelController from .stream_decoder import StreamDecoderFactory -from .tracking_factory import TrackingFactory from .model_repository import TensorRTModelRepository logger = logging.getLogger(__name__) @@ -133,28 +132,32 @@ class StreamConnection: async def _frame_poller(self): """Poll frames from threaded decoder and submit to model controller""" - last_frame_ptr = None + last_decoder_frame_count = -1 while self.running: try: - # Poll frame from decoder (runs in thread) - frame = self.decoder.get_latest_frame(rgb=True) + # Get current decoder frame count (no data transfer, just counter) + decoder_frame_count = self.decoder.get_frame_count() - # Check if we got a new frame (avoid reprocessing same frame) - if frame is not None and frame.data_ptr() != last_frame_ptr: - last_frame_ptr = frame.data_ptr() - self.last_frame_time = time.time() - self.frame_count += 1 + # Check if decoder has a new frame (avoid reprocessing same frame) + if decoder_frame_count > last_decoder_frame_count: + # Poll frame from decoder (zero-copy - stays in VRAM) + frame = self.decoder.get_latest_frame(rgb=True) - # Submit to model controller for 
batched inference - await self.model_controller.submit_frame( - stream_id=self.stream_id, - frame=frame, - metadata={ - "frame_number": self.frame_count, - "shape": tuple(frame.shape), - } - ) + if frame is not None: + last_decoder_frame_count = decoder_frame_count + self.last_frame_time = time.time() + self.frame_count += 1 + + # Submit to model controller for batched inference + await self.model_controller.submit_frame( + stream_id=self.stream_id, + frame=frame, + metadata={ + "frame_number": self.frame_count, + "shape": tuple(frame.shape), + } + ) # Check decoder status if not self.decoder.is_connected(): @@ -211,53 +214,37 @@ class StreamConnection: logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True) await self.error_queue.put(e) - def _run_tracking_sync(self, detections): + def _run_tracking_sync(self, detections, min_confidence=0.7): """ Run tracking synchronously (called from executor). Args: detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id] + min_confidence: Minimum confidence threshold for detections Returns: List of TrackedObject instances """ - # Use the TrackingController's internal tracking with detections - # We need to manually update tracks since we already have detections - import torch + # Convert tensor detections to Detection objects, filtering by confidence + from .tracking_controller import Detection - with self.tracking_controller._lock: - self.tracking_controller._frame_count += 1 + detection_list = [] + for det in detections: + confidence = float(det[4]) - # If no detections, just cleanup and return current tracks - if len(detections) == 0: - self.tracking_controller._cleanup_stale_tracks() - return list(self.tracking_controller._tracks.values()) + # Filter by confidence threshold (prevents track accumulation) + if confidence < min_confidence: + continue - # Run IoU tracking to associate detections with existing tracks - associations = self.tracking_controller._iou_tracking(detections) + detection_list.append(Detection( + bbox=det[:4].cpu().tolist(), + confidence=confidence, + class_id=int(det[5]) if det.shape[0] > 5 else 0, + class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown" + )) - # Update or create tracks - for (det_idx, track_id), detection in zip(associations, detections): - bbox = detection[:4].cpu().tolist() - confidence = float(detection[4]) - class_id = int(detection[5]) if detection.shape[0] > 5 else 0 - - if track_id == -1: - # Create new track - new_track = self.tracking_controller._create_track( - bbox, confidence, class_id, self.tracking_controller._frame_count - ) - self.tracking_controller._tracks[new_track.track_id] = new_track - else: - # Update existing track - self.tracking_controller._tracks[track_id].update( - bbox, confidence, self.tracking_controller._frame_count - ) - - # Cleanup stale tracks - self.tracking_controller._cleanup_stale_tracks() - - return list(self.tracking_controller._tracks.values()) + # Update tracker with detections (lightweight, no model dependency!) 
+ return self.tracking_controller.update(detection_list) async def tracking_results(self) -> AsyncIterator[TrackingResult]: """ @@ -341,7 +328,6 @@ class StreamConnectionManager: # Factories self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id) - self.tracking_factory = TrackingFactory(gpu_id=gpu_id) self.model_repository = TensorRTModelRepository( gpu_id=gpu_id, enable_pt_conversion=enable_pt_conversion @@ -349,7 +335,6 @@ class StreamConnectionManager: # Controllers self.model_controller: Optional[ModelController] = None - self.tracking_controller = None # Connections self.connections: Dict[str, StreamConnection] = {} @@ -454,17 +439,16 @@ class StreamConnectionManager: # Create decoder decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size) - # Create dedicated tracking controller for THIS stream - # This prevents track accumulation across multiple streams - tracking_controller = self.tracking_factory.create_controller( - model_repository=self.model_repository, - model_id=self.model_id_for_tracking, + # Create lightweight tracker (NO model_repository dependency!) + from .tracking_controller import ObjectTracker + tracking_controller = ObjectTracker( + gpu_id=self.gpu_id, tracker_type="iou", max_age=30, - min_confidence=0.5, iou_threshold=0.3, + class_names=None # TODO: pass class names if available ) - logger.info(f"Created dedicated TrackingController for stream {stream_id}") + logger.info(f"Created lightweight ObjectTracker for stream {stream_id}") # Create connection connection = StreamConnection( diff --git a/services/stream_decoder.py b/services/stream_decoder.py index 327ab9a..55c2eb7 100644 --- a/services/stream_decoder.py +++ b/services/stream_decoder.py @@ -448,6 +448,10 @@ class StreamDecoder: with self._buffer_lock: return len(self.frame_buffer) + def get_frame_count(self) -> int: + """Get total number of frames decoded since start""" + return self.frame_count + def is_connected(self) -> bool: """Check if stream is actively connected""" return self.get_status() == ConnectionStatus.CONNECTED diff --git a/services/tracking_controller.py b/services/tracking_controller.py index 73b81ba..6c8cca8 100644 --- a/services/tracking_controller.py +++ b/services/tracking_controller.py @@ -5,7 +5,6 @@ from collections import defaultdict, deque import time import torch import numpy as np -from .model_repository import TensorRTModelRepository @dataclass @@ -61,78 +60,81 @@ class TrackedObject: } -class TrackingController: +@dataclass +class Detection: """ - GPU-accelerated object tracking controller that wraps TensorRTModelRepository. + Represents a single detection from object detection model. - Architecture: - - Wraps model repository for dependency injection - - Maintains CUDA state for bbox tracking operations - - Stores persistent tracking data (track IDs, histories, states) - - Processes GPU tensor frames directly (zero-copy pipeline) - - Thread-safe for concurrent tracking operations + Attributes: + bbox: Bounding box [x1, y1, x2, y2] + confidence: Detection confidence (0-1) + class_id: Object class ID + class_name: Object class name (optional) + """ + bbox: List[float] + confidence: float + class_id: int + class_name: str = "unknown" + + +class ObjectTracker: + """ + Lightweight GPU-accelerated object tracker (decoupled from inference). + + This class only handles tracking logic - associating detections with existing tracks, + maintaining track IDs, and managing track lifecycle. It does NOT perform inference. 
+ + Architecture (Event-Driven Mode): + - Receives pre-computed detections (from ModelController) + - Maintains persistent tracking state (track IDs, histories) + - GPU-accelerated IoU computation for track association + - Thread-safe for concurrent operations Tracking Flow: - GPU Frame → Model Inference (GPU) → Detections (GPU) - ↓ - Tracking Algorithm (GPU/CPU) → Track Assignment - ↓ - Update Persistent Tracks → Return Tracked Objects + Detections → Track Association (GPU IoU) → Update Tracks → Return Tracked Objects Features: - - GPU-first: All tensor operations stay on GPU until final results + - Lightweight: No model_repository dependency (zero VRAM overhead) + - GPU-accelerated: IoU computation on GPU for performance - Persistent IDs: Tracks maintain consistent IDs across frames - Track History: Maintains trajectory history for each object - - Configurable: Supports custom tracking algorithms via callbacks - Thread-safe: Mutex-based locking for concurrent access Example: - # Initialize with DI - repo = TensorRTModelRepository(gpu_id=0) - factory = TrackingFactory(gpu_id=0) - controller = factory.create_controller( - model_repository=repo, - model_id="yolov8_detector", - tracker_type="iou" + # Event-driven mode (no model dependency) + tracker = ObjectTracker( + gpu_id=0, + tracker_type="iou", + max_age=30, + iou_threshold=0.3, + class_names=COCO_CLASSES ) - # Track objects in frame - rgb_frame = decoder.get_latest_frame() # GPU tensor - tracked_objects = controller.track(rgb_frame) - - # Get all tracked objects - all_tracks = controller.get_all_tracks() + # Update with pre-computed detections + detections = [Detection(bbox=[x1,y1,x2,y2], confidence=0.9, class_id=0)] + tracked_objects = tracker.update(detections) """ def __init__(self, - model_repository: TensorRTModelRepository, - model_id: str, gpu_id: int = 0, tracker_type: str = "iou", max_age: int = 30, - min_confidence: float = 0.5, iou_threshold: float = 0.3, class_names: Optional[Dict[int, str]] = None): """ - Initialize TrackingController. + Initialize ObjectTracker (no model dependency). 
Args: - model_repository: TensorRT model repository (dependency injection) - model_id: Model ID in repository to use for detection - gpu_id: GPU device ID - tracker_type: Tracking algorithm type ("iou", "sort", "deepsort", "bytetrack") + gpu_id: GPU device ID for IoU computation + tracker_type: Tracking algorithm type ("iou") max_age: Maximum frames to keep track without detection - min_confidence: Minimum confidence threshold for detections iou_threshold: IoU threshold for track association class_names: Optional mapping of class IDs to names """ - self.model_repository = model_repository - self.model_id = model_id self.gpu_id = gpu_id self.device = torch.device(f'cuda:{gpu_id}') self.tracker_type = tracker_type self.max_age = max_age - self.min_confidence = min_confidence self.iou_threshold = iou_threshold self.class_names = class_names or {} @@ -146,19 +148,6 @@ class TrackingController: self._total_detections = 0 self._total_tracks_created = 0 - # Verify model exists in repository - metadata = self.model_repository.get_metadata(model_id) - if metadata is None: - raise ValueError(f"Model '{model_id}' not found in repository") - - print(f"TrackingController initialized:") - print(f" Model ID: {model_id}") - print(f" GPU: {gpu_id}") - print(f" Tracker: {tracker_type}") - print(f" Max age: {max_age} frames") - print(f" Min confidence: {min_confidence}") - print(f" IoU threshold: {iou_threshold}") - def _compute_iou_gpu(self, boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor: """ Compute IoU between two sets of boxes on GPU. @@ -283,97 +272,51 @@ class TrackingController: for tid in stale_track_ids: del self._tracks[tid] - def track(self, frame: torch.Tensor, - preprocess_fn: Optional[callable] = None, - postprocess_fn: Optional[callable] = None) -> List[TrackedObject]: + def update(self, detections: List[Detection]) -> List[TrackedObject]: """ - Track objects in a GPU tensor frame. + Update tracker with new detections (decoupled from inference). 
Args: - frame: RGB frame as GPU tensor, shape (3, H, W) or (1, 3, H, W) - preprocess_fn: Optional preprocessing function (frame -> model_input) - postprocess_fn: Optional postprocessing function (model_output -> detections) - Should return tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id] + detections: List of Detection objects from model inference Returns: List of currently tracked objects """ with self._lock: self._frame_count += 1 - - # Ensure frame is on correct device - if not frame.is_cuda: - frame = frame.to(self.device) - elif frame.device != self.device: - frame = frame.to(self.device) - - # Preprocess frame for model - if preprocess_fn is not None: - model_input = preprocess_fn(frame) - else: - # Default: add batch dimension if needed - if frame.dim() == 3: - model_input = frame.unsqueeze(0) # (1, 3, H, W) - else: - model_input = frame - - # Run inference (GPU-to-GPU) - # Assuming model expects input named "images" or "input" - metadata = self.model_repository.get_metadata(self.model_id) - input_name = metadata.input_names[0] if metadata else "images" - - outputs = self.model_repository.infer( - model_id=self.model_id, - inputs={input_name: model_input}, - synchronize=True - ) - - # Postprocess model output to get detections - if postprocess_fn is not None: - detections = postprocess_fn(outputs) - else: - # Default: assume output is already in correct format - # Get first output tensor - output_name = list(outputs.keys())[0] - detections = outputs[output_name] - - # Reshape if needed: (1, N, 6) -> (N, 6) - if detections.dim() == 3: - detections = detections.squeeze(0) - - # Filter by confidence - if detections.dim() == 2 and detections.shape[1] >= 5: - conf_mask = detections[:, 4] >= self.min_confidence - detections = detections[conf_mask] - self._total_detections += len(detections) - # Track objects + # No detections, just cleanup stale tracks if len(detections) == 0: - # No detections, just cleanup stale tracks self._cleanup_stale_tracks() return list(self._tracks.values()) + # Convert detections to tensor for GPU processing + det_tensor = torch.tensor( + [[*det.bbox, det.confidence, det.class_id] for det in detections], + dtype=torch.float32, + device=self.device + ) + # Run tracking algorithm if self.tracker_type == "iou": - associations = self._iou_tracking(detections) + associations = self._iou_tracking(det_tensor) else: raise NotImplementedError(f"Tracker type '{self.tracker_type}' not implemented") # Update tracks based on associations for det_idx, track_id in associations: - detection = detections[det_idx] - bbox = detection[:4].cpu().tolist() - confidence = float(detection[4]) - class_id = int(detection[5]) if detection.shape[0] > 5 else 0 + det = detections[det_idx] if track_id == -1: # Create new track - new_track = self._create_track(bbox, confidence, class_id, self._frame_count) + new_track = self._create_track( + det.bbox, det.confidence, det.class_id, self._frame_count + ) self._tracks[new_track.track_id] = new_track else: # Update existing track - self._tracks[track_id].update(bbox, confidence, self._frame_count) + self._tracks[track_id].update(det.bbox, det.confidence, self._frame_count) # Cleanup stale tracks self._cleanup_stale_tracks() @@ -476,7 +419,6 @@ class TrackingController: 'total_tracks_created': self._total_tracks_created, 'total_detections': self._total_detections, 'avg_detections_per_frame': self._total_detections / max(self._frame_count, 1), - 'model_id': self.model_id, 'tracker_type': self.tracker_type, 'class_counts': 
self.get_class_counts(active_only=True) } @@ -518,7 +460,6 @@ class TrackingController: def __repr__(self): with self._lock: - return (f"TrackingController(model={self.model_id}, " - f"tracker={self.tracker_type}, " + return (f"ObjectTracker(tracker={self.tracker_type}, " f"frame={self._frame_count}, " f"tracks={len(self._tracks)})") diff --git a/services/tracking_factory.py b/services/tracking_factory.py index cb2b403..fe464db 100644 --- a/services/tracking_factory.py +++ b/services/tracking_factory.py @@ -1,8 +1,11 @@ import threading from typing import Optional, Dict -from .tracking_controller import TrackingController +from .tracking_controller import ObjectTracker from .model_repository import TensorRTModelRepository +# Backward compatibility alias (TrackingFactory is deprecated in event-driven mode) +TrackingController = ObjectTracker + class TrackingFactory: """ diff --git a/test_tracking_realtime.py b/test_tracking_realtime.py index d8e2254..863de8e 100644 --- a/test_tracking_realtime.py +++ b/test_tracking_realtime.py @@ -13,6 +13,8 @@ import asyncio import time import os import torch +import cv2 +import numpy as np from dotenv import load_dotenv from services import ( StreamConnectionManager, @@ -32,17 +34,21 @@ async def main_single_stream(): # Configuration GPU_ID = 0 - MODEL_PATH = "models/yolov8n.pt" # PT file will be auto-converted + MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # PT file will be auto-converted STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test') BATCH_SIZE = 4 FORCE_TIMEOUT = 0.05 + ENABLE_DISPLAY = os.getenv('ENABLE_DISPLAY', 'false').lower() == 'true' # Set to 'true' to enable OpenCV display + MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300')) # Stop after N frames (0 = unlimited) print(f"\nConfiguration:") print(f" GPU: {GPU_ID}") print(f" Model: {MODEL_PATH}") print(f" Stream: {STREAM_URL}") print(f" Batch size: {BATCH_SIZE}") - print(f" Force timeout: {FORCE_TIMEOUT}s\n") + print(f" Force timeout: {FORCE_TIMEOUT}s") + print(f" Display: {'Enabled' if ENABLE_DISPLAY else 'Disabled (inference only)'}") + print(f" Max frames: {MAX_FRAMES if MAX_FRAMES > 0 else 'Unlimited'}\n") # Create StreamConnectionManager with PT conversion enabled print("[1/3] Creating StreamConnectionManager...") @@ -94,14 +100,68 @@ async def main_single_stream(): print("Press Ctrl+C to stop") print(f"{'=' * 80}\n") - # Stream results + # Stream results with optional OpenCV visualization result_count = 0 start_time = time.time() + # Create window only if display is enabled + if ENABLE_DISPLAY: + cv2.namedWindow("Object Tracking", cv2.WINDOW_NORMAL) + cv2.resizeWindow("Object Tracking", 1280, 720) + try: async for result in connection.tracking_results(): result_count += 1 + # Check if we've reached max frames + if MAX_FRAMES > 0 and result_count >= MAX_FRAMES: + print(f"\n✓ Reached max frames limit ({MAX_FRAMES})") + break + + # OpenCV visualization (only if enabled) + if ENABLE_DISPLAY: + # Get latest frame from decoder (as CPU numpy array) + frame = connection.decoder.get_latest_frame_cpu(rgb=True) + + if frame is not None: + # Convert to BGR for OpenCV + frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + # Draw tracked objects + for obj in result.tracked_objects: + # Get bbox coordinates + x1, y1, x2, y2 = map(int, obj.bbox) + + # Draw bounding box + cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2) + + # Draw track ID and class name + label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}" + label_size, _ = cv2.getTextSize(label, 
cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) + + # Draw label background + cv2.rectangle(frame_bgr, (x1, y1 - label_size[1] - 10), + (x1 + label_size[0], y1), (0, 255, 0), -1) + + # Draw label text + cv2.putText(frame_bgr, label, (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) + + # Draw FPS and object count + elapsed = time.time() - start_time + fps = result_count / elapsed if elapsed > 0 else 0 + info_text = f"FPS: {fps:.1f} | Objects: {len(result.tracked_objects)} | Frame: {result_count}" + cv2.putText(frame_bgr, info_text, (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) + + # Display frame + cv2.imshow("Object Tracking", frame_bgr) + + # Check for 'q' key to quit + if cv2.waitKey(1) & 0xFF == ord('q'): + print(f"\n✓ Quit by user (pressed 'q')") + break + # Print stats every 30 results if result_count % 30 == 0: elapsed = time.time() - start_time @@ -110,7 +170,7 @@ async def main_single_stream(): print(f"\nResults: {result_count} | FPS: {fps:.1f}") print(f" Stream: {result.stream_id}") print(f" Objects: {len(result.tracked_objects)}") - + if result.tracked_objects: class_counts = {} for obj in result.tracked_objects: @@ -125,6 +185,10 @@ async def main_single_stream(): print("Cleanup") print(f"{'=' * 80}") + # Close OpenCV window if it was opened + if ENABLE_DISPLAY: + cv2.destroyAllWindows() + await connection.stop() await manager.shutdown() print("✓ Stopped")
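With TrackingController and TrackingFactory removed from the public exports above, per-stream tracking now goes through ObjectTracker and the Detection dataclass. A minimal usage sketch, assuming only the signatures shown in this diff (the detection values are illustrative, and a CUDA device is assumed for the tracker's IoU step):

```python
import torch
from services import ObjectTracker, Detection

# One tracker per stream; no model_repository dependency, so no extra VRAM is held.
tracker = ObjectTracker(gpu_id=0, tracker_type="iou", max_age=30, iou_threshold=0.3)

# In the pipeline, detections arrive from ModelController's postprocess step as a
# (N, 6) tensor of [x1, y1, x2, y2, conf, class_id]; the values below are made up.
raw = torch.tensor([[100.0, 150.0, 220.0, 380.0, 0.91, 0.0],
                    [400.0, 120.0, 520.0, 300.0, 0.55, 2.0]])

# Apply the same confidence gate used in _run_tracking_sync (min_confidence=0.7)
# before handing detections to the tracker.
detections = [
    Detection(bbox=det[:4].tolist(), confidence=float(det[4]), class_id=int(det[5]))
    for det in raw
    if float(det[4]) >= 0.7
]

tracked = tracker.update(detections)  # returns the current list of TrackedObject
for obj in tracked:
    print(obj.track_id, obj.class_id, f"{obj.confidence:.2f}")
```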
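The ObjectTracker docstring refers to GPU-accelerated IoU computation for track association. A conceptual sketch of pairwise IoU between detection boxes and track boxes follows; it is a standard formulation shown for illustration, not the repository's _compute_iou_gpu implementation:

```python
import torch

def pairwise_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """IoU between (N, 4) and (M, 4) boxes in [x1, y1, x2, y2] format -> (N, M) matrix."""
    area1 = (boxes1[:, 2] - boxes1[:, 0]).clamp(min=0) * (boxes1[:, 3] - boxes1[:, 1]).clamp(min=0)
    area2 = (boxes2[:, 2] - boxes2[:, 0]).clamp(min=0) * (boxes2[:, 3] - boxes2[:, 1]).clamp(min=0)

    # Intersection corners, broadcast to shape (N, M, 2)
    top_left = torch.maximum(boxes1[:, None, :2], boxes2[None, :, :2])
    bottom_right = torch.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:])
    wh = (bottom_right - top_left).clamp(min=0)

    intersection = wh[..., 0] * wh[..., 1]
    union = area1[:, None] + area2[None, :] - intersection
    return intersection / union.clamp(min=1e-6)

# Example: one detection vs. two existing track boxes (values are illustrative).
dets = torch.tensor([[100.0, 150.0, 220.0, 380.0]])
tracks = torch.tensor([[105.0, 160.0, 225.0, 390.0], [400.0, 120.0, 520.0, 300.0]])
print(pairwise_iou(dets, tracks))  # high IoU with the first track, ~0 with the second
```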