From a519dea130cc315c9680cd6afeae162483cf7bbd Mon Sep 17 00:00:00 2001 From: Siwat Sirichai Date: Tue, 11 Nov 2025 02:02:12 +0700 Subject: [PATCH] new buffer paradigm --- services/base_model_controller.py | 279 ++++++++--------------- services/inference_engine.py | 60 +++++ services/stream_connection_manager.py | 26 ++- services/tensorrt_model_controller.py | 6 +- services/ultralytics_model_controller.py | 51 +++-- test_tracking_realtime.py | 246 ++++++++++---------- 6 files changed, 341 insertions(+), 327 deletions(-) diff --git a/services/base_model_controller.py b/services/base_model_controller.py index c8d81f9..0175702 100644 --- a/services/base_model_controller.py +++ b/services/base_model_controller.py @@ -1,16 +1,19 @@ """ -Base Model Controller - Abstract base class for batched inference controllers. +Base Model Controller - Simple circular buffer with continuous batch processing. -Provides ping-pong buffer architecture with force-switch timeout mechanism. -Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.). +Replaces the complex ping-pong buffer architecture with a simple queue: +- Frames arrive and go into a single deque (circular buffer) +- When batch_size frames are ready, process them +- Continue consuming batches until queue is empty +- Drop oldest frames if queue is full """ import logging import threading import time from abc import ABC, abstractmethod +from collections import deque from dataclasses import dataclass, field -from enum import Enum from typing import Any, Callable, Dict, List, Optional import torch @@ -28,62 +31,37 @@ class BatchFrame: metadata: Dict = field(default_factory=dict) -class BufferState(Enum): - """State of a ping-pong buffer""" - - IDLE = "idle" - FILLING = "filling" - PROCESSING = "processing" - - class BaseModelController(ABC): """ - Abstract base class for batched inference with ping-pong buffers. + Simple batched inference with circular buffer. - This controller accumulates frames from multiple streams into batches, - processes them through an inference backend, and routes results back to - stream-specific callbacks. 
- - Features: - - Ping-pong circular buffers (BufferA/BufferB) - - Force-switch timeout to prevent batch starvation - - Event-driven processing with callbacks - - Thread-safe frame submission - - Subclasses must implement: - - _run_batch_inference(): Backend-specific inference logic + Architecture: + - Single deque (circular buffer) for incoming frames + - Batch processor thread continuously consumes batches + - Frames come in fast, batches go out as fast as inference allows + - Automatic frame dropping when queue is full """ def __init__( self, model_id: str, batch_size: int = 16, - force_timeout: float = 0.05, + max_queue_size: int = 100, preprocess_fn: Optional[Callable] = None, postprocess_fn: Optional[Callable] = None, ): self.model_id = model_id self.batch_size = batch_size - self.force_timeout = force_timeout + self.max_queue_size = max_queue_size self.preprocess_fn = preprocess_fn self.postprocess_fn = postprocess_fn - # Ping-pong buffers - self.buffer_a: List[BatchFrame] = [] - self.buffer_b: List[BatchFrame] = [] + # Single circular buffer + self.frame_queue = deque(maxlen=max_queue_size) + self.queue_lock = threading.Lock() - # Buffer states - self.active_buffer = "A" - self.buffer_a_state = BufferState.IDLE - self.buffer_b_state = BufferState.IDLE - - # Threading coordination - self.buffer_lock = threading.RLock() - self.last_submit_time = time.time() - - # Threads - self.timeout_thread: Optional[threading.Thread] = None - self.processor_threads: Dict[str, threading.Thread] = {} + # Processing thread + self.processor_thread: Optional[threading.Thread] = None self.running = False self.stop_event = threading.Event() @@ -93,33 +71,24 @@ class BaseModelController(ABC): # Statistics self.total_frames_processed = 0 self.total_batches_processed = 0 + self.total_frames_dropped = 0 def start(self): - """Start the controller background threads""" + """Start the controller background thread""" if self.running: - logger.warning("ModelController already running") + logger.warning(f"{self.__class__.__name__} already running") return self.running = True self.stop_event.clear() - # Start timeout monitor thread - self.timeout_thread = threading.Thread( - target=self._timeout_monitor, daemon=True + # Start single processor thread + self.processor_thread = threading.Thread( + target=self._batch_processor, daemon=True ) - self.timeout_thread.start() + self.processor_thread.start() - # Start processor threads for each buffer - self.processor_threads["A"] = threading.Thread( - target=self._batch_processor, args=("A",), daemon=True - ) - self.processor_threads["B"] = threading.Thread( - target=self._batch_processor, args=("B",), daemon=True - ) - self.processor_threads["A"].start() - self.processor_threads["B"].start() - - logger.info(f"{self.__class__.__name__} started") + logger.info(f"{self.__class__.__name__} started (batch_size={self.batch_size})") def stop(self): """Stop the controller and cleanup""" @@ -130,16 +99,12 @@ class BaseModelController(ABC): self.running = False self.stop_event.set() - # Wait for threads to finish - if self.timeout_thread and self.timeout_thread.is_alive(): - self.timeout_thread.join(timeout=2.0) + # Wait for thread to finish + if self.processor_thread and self.processor_thread.is_alive(): + self.processor_thread.join(timeout=2.0) - for thread in self.processor_threads.values(): - if thread and thread.is_alive(): - thread.join(timeout=2.0) - - # Process any remaining frames - self._process_remaining_buffers() + # Process remaining frames + 
self._process_remaining_frames() logger.info(f"{self.__class__.__name__} stopped") def register_callback(self, stream_id: str, callback: Callable): @@ -156,98 +121,48 @@ class BaseModelController(ABC): self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None ): """Submit a frame for batched inference""" - with self.buffer_lock: + with self.queue_lock: + # If queue is full, oldest frame is automatically dropped (deque with maxlen) + if len(self.frame_queue) >= self.max_queue_size: + self.total_frames_dropped += 1 + batch_frame = BatchFrame( stream_id=stream_id, frame=frame, timestamp=time.time(), metadata=metadata or {}, ) + self.frame_queue.append(batch_frame) - # Add to active buffer - if self.active_buffer == "A": - self.buffer_a.append(batch_frame) - self.buffer_a_state = BufferState.FILLING - buffer_size = len(self.buffer_a) - else: - self.buffer_b.append(batch_frame) - self.buffer_b_state = BufferState.FILLING - buffer_size = len(self.buffer_b) + def _batch_processor(self): + """Background thread that continuously processes batches""" + logger.info(f"{self.__class__.__name__} batch processor started") - self.last_submit_time = time.time() - - # Check if we should immediately swap (batch full) - if buffer_size >= self.batch_size: - self._try_swap_buffers() - - def _timeout_monitor(self): - """Monitor force-switch timeout""" - while self.running and not self.stop_event.wait(0.01): - with self.buffer_lock: - time_since_submit = time.time() - self.last_submit_time - - if time_since_submit >= self.force_timeout: - active_buffer = ( - self.buffer_a if self.active_buffer == "A" else self.buffer_b - ) - if len(active_buffer) > 0: - self._try_swap_buffers() - - def _try_swap_buffers(self): - """Attempt to swap ping-pong buffers (called with buffer_lock held)""" - inactive_state = ( - self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state - ) - - if inactive_state != BufferState.PROCESSING: - old_active = self.active_buffer - self.active_buffer = "B" if old_active == "A" else "A" - - if old_active == "A": - self.buffer_a_state = BufferState.PROCESSING - buffer_size = len(self.buffer_a) - else: - self.buffer_b_state = BufferState.PROCESSING - buffer_size = len(self.buffer_b) - - logger.debug( - f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})" - ) - - def _batch_processor(self, buffer_name: str): - """Background thread that processes a specific buffer when available""" while self.running and not self.stop_event.is_set(): - time.sleep(0.001) + # Check if we have enough frames for a batch + with self.queue_lock: + queue_size = len(self.frame_queue) - with self.buffer_lock: - if buffer_name == "A": - should_process = self.buffer_a_state == BufferState.PROCESSING - else: - should_process = self.buffer_b_state == BufferState.PROCESSING + if queue_size > 0 and queue_size % 10 == 0: + logger.info(f"Queue size: {queue_size}/{self.batch_size}") - if should_process: - self._process_buffer(buffer_name) + if queue_size >= self.batch_size: + # Extract batch + with self.queue_lock: + batch = [] + for _ in range(min(self.batch_size, len(self.frame_queue))): + if self.frame_queue: + batch.append(self.frame_queue.popleft()) - def _process_buffer(self, buffer_name: str): - """Process a buffer through inference""" - # Extract buffer to process - with self.buffer_lock: - if buffer_name == "A": - batch = self.buffer_a.copy() - self.buffer_a.clear() + if batch: + logger.info(f"Processing batch of {len(batch)} frames") + self._process_batch(batch) else: 
- batch = self.buffer_b.copy() - self.buffer_b.clear() + # Not enough frames, sleep briefly + time.sleep(0.001) # 1ms - if len(batch) == 0: - with self.buffer_lock: - if buffer_name == "A": - self.buffer_a_state = BufferState.IDLE - else: - self.buffer_b_state = BufferState.IDLE - return - - # Process batch (outside lock to allow concurrent submissions) + def _process_batch(self, batch: List[BatchFrame]): + """Process a batch through inference""" try: start_time = time.time() results = self._run_batch_inference(batch) @@ -257,8 +172,7 @@ class BaseModelController(ABC): self.total_batches_processed += 1 logger.debug( - f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms " - f"({inference_time * 1000 / len(batch):.2f}ms per frame)" + f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms" ) # Emit results to callbacks @@ -276,49 +190,58 @@ class BaseModelController(ABC): except Exception as e: logger.error(f"Error processing batch: {e}", exc_info=True) - finally: - with self.buffer_lock: - if buffer_name == "A": - self.buffer_a_state = BufferState.IDLE - else: - self.buffer_b_state = BufferState.IDLE - @abstractmethod def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]: - """ - Run inference on a batch of frames (backend-specific). - - Args: - batch: List of BatchFrame objects - - Returns: - List of detection results (one per frame) - """ + """Run inference on a batch of frames (backend-specific)""" pass - def _process_remaining_buffers(self): - """Process any remaining frames in buffers during shutdown""" - if len(self.buffer_a) > 0: - logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A") - self._process_buffer("A") - if len(self.buffer_b) > 0: - logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B") - self._process_buffer("B") + def _process_remaining_frames(self): + """Process any remaining frames in queue during shutdown""" + with self.queue_lock: + remaining = len(self.frame_queue) + + if remaining > 0: + logger.info(f"Processing remaining {remaining} frames") + while True: + with self.queue_lock: + if not self.frame_queue: + break + batch = [] + for _ in range(min(self.batch_size, len(self.frame_queue))): + if self.frame_queue: + batch.append(self.frame_queue.popleft()) + + if batch: + self._process_batch(batch) def get_stats(self) -> Dict[str, Any]: - """Get current buffer statistics""" + """Get current statistics""" + with self.queue_lock: + queue_size = len(self.frame_queue) + return { - "active_buffer": self.active_buffer, - "buffer_a_size": len(self.buffer_a), - "buffer_b_size": len(self.buffer_b), - "buffer_a_state": self.buffer_a_state.value, - "buffer_b_state": self.buffer_b_state.value, + "queue_size": queue_size, + "max_queue_size": self.max_queue_size, + "batch_size": self.batch_size, "registered_streams": len(self.result_callbacks), "total_frames_processed": self.total_frames_processed, "total_batches_processed": self.total_batches_processed, + "total_frames_dropped": self.total_frames_dropped, "avg_batch_size": ( self.total_frames_processed / self.total_batches_processed if self.total_batches_processed > 0 else 0 ), } + + +# Keep old BufferState enum for backwards compatibility +from enum import Enum + + +class BufferState(Enum): + """Deprecated - kept for backwards compatibility""" + + IDLE = "idle" + FILLING = "filling" + PROCESSING = "processing" diff --git a/services/inference_engine.py b/services/inference_engine.py index 191ab82..e03a2d3 100644 --- 
a/services/inference_engine.py +++ b/services/inference_engine.py @@ -9,6 +9,7 @@ Provides a unified interface for different inference backends: All engines support zero-copy GPU tensor inference where possible. """ +import logging from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum @@ -17,6 +18,8 @@ from typing import Any, Dict, List, Optional, Tuple import torch +logger = logging.getLogger(__name__) + class BackendType(Enum): """Supported inference backend types""" @@ -423,9 +426,18 @@ class UltralyticsEngine(IInferenceEngine): final_model_path = engine_path print(f"Using TensorRT engine: {engine_path}") + # CRITICAL: Update _model_path to point to the .engine file for metadata extraction + self._model_path = engine_path + # Load model (Ultralytics handles .engine files natively) self._model = YOLO(final_model_path) + logger.info(f"Loaded Ultralytics model: {type(self._model)}") + if hasattr(self._model, "predictor"): + logger.info( + f"Model has predictor: {type(self._model.predictor) if self._model.predictor else None}" + ) + # Move to device if needed (only for .pt models, .engine already on specific device) if hasattr(self._model, "model") and self._model.model is not None: # Check if it's actually a torch model (not a string path for .engine files) @@ -437,6 +449,39 @@ class UltralyticsEngine(IInferenceEngine): return self._metadata + def _read_batch_size_from_engine_file(self, engine_path: str) -> int: + """ + Read batch size from the metadata JSON file saved next to the engine. + + Much simpler than parsing TensorRT engine! + """ + try: + import json + from pathlib import Path + + # The metadata file is named: _metadata.json + engine_file = Path(engine_path) + metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json") + + print(f"[UltralyticsEngine] Looking for metadata file: {metadata_file}") + + if metadata_file.exists(): + with open(metadata_file, "r") as f: + metadata = json.load(f) + batch_size = metadata.get("batch", -1) + print( + f"[UltralyticsEngine] Found metadata: batch={batch_size}, imgsz={metadata.get('imgsz')}" + ) + return batch_size + else: + print(f"[UltralyticsEngine] Metadata file not found: {metadata_file}") + except Exception as e: + print( + f"[UltralyticsEngine] Could not read batch size from metadata file: {e}" + ) + + return -1 # Default to dynamic + def _extract_metadata(self) -> EngineMetadata: """Extract metadata from Ultralytics model""" # Ultralytics models typically expect (B, 3, H, W) input @@ -447,6 +492,17 @@ class UltralyticsEngine(IInferenceEngine): imgsz = 640 input_shape = (batch_size, 3, imgsz, imgsz) + # CRITICAL: For .engine files, read batch size directly from the TensorRT engine file + print(f"[UltralyticsEngine] _model_path={self._model_path}") + if self._model_path.endswith(".engine"): + print(f"[UltralyticsEngine] Reading batch size from engine file...") + batch_size = self._read_batch_size_from_engine_file(self._model_path) + print(f"[UltralyticsEngine] Read batch_size={batch_size} from .engine file") + if batch_size > 0: + input_shape = (batch_size, 3, imgsz, imgsz) + else: + print(f"[UltralyticsEngine] Not an .engine file, skipping direct read") + if hasattr(self._model, "model") and self._model.model is not None: # Try to get actual input shape from model try: @@ -508,6 +564,10 @@ class UltralyticsEngine(IInferenceEngine): logger.warning(f"Could not extract full metadata: {e}") pass + logger.info( + f"Extracted Ultralytics metadata: batch_size={batch_size}, imgsz={imgsz}, 
input_shape={input_shape}" + ) + return EngineMetadata( engine_type="ultralytics", model_path=self._model_path, diff --git a/services/stream_connection_manager.py b/services/stream_connection_manager.py index 1af1ce3..cabbad7 100644 --- a/services/stream_connection_manager.py +++ b/services/stream_connection_manager.py @@ -42,6 +42,7 @@ class TrackingResult: tracked_objects: List # List of TrackedObject from TrackingController detections: List # Raw detections frame_shape: Tuple[int, int, int] + frame_tensor: Optional[torch.Tensor] # GPU tensor of the frame (C, H, W) metadata: Dict @@ -158,6 +159,9 @@ class StreamConnection: # Submit to model controller for batched inference # Pass the FrameReference in metadata so we can free it later + logger.debug( + f"[{self.stream_id}] Submitting frame {self.frame_count} to model controller" + ) self.model_controller.submit_frame( stream_id=self.stream_id, frame=cloned_tensor, # Use cloned tensor, not original @@ -167,6 +171,9 @@ class StreamConnection: "frame_ref": frame_ref, # Store reference for later cleanup }, ) + logger.debug( + f"[{self.stream_id}] Frame {self.frame_count} submitted, queue size: {len(self.model_controller.frame_queue)}" + ) # Update connection status based on decoder status if ( @@ -211,6 +218,12 @@ class StreamConnection: frame_shape = result["metadata"].get("shape") tracked_objects = self._run_tracking_sync(detections, frame_shape) + # Get ORIGINAL frame tensor from metadata (not the preprocessed one in result["frame"]) + # The frame in result["frame"] is preprocessed (resized, normalized) + # We need the original frame for visualization + frame_ref = result["metadata"].get("frame_ref") + frame_tensor = frame_ref.rgb_tensor if frame_ref else None + # Create tracking result tracking_result = TrackingResult( stream_id=self.stream_id, @@ -218,6 +231,7 @@ class StreamConnection: tracked_objects=tracked_objects, detections=detections, frame_shape=result["metadata"].get("shape"), + frame_tensor=frame_tensor, # Original frame, not preprocessed metadata=result["metadata"], ) @@ -328,7 +342,7 @@ class StreamConnectionManager: Args: gpu_id: GPU device ID (default: 0) batch_size: Maximum batch size for inference (default: 16) - force_timeout: Force buffer switch timeout in seconds (default: 0.05) + max_queue_size: Maximum frames in queue before dropping (default: 100) poll_interval: Frame polling interval in seconds (default: 0.01) Example: @@ -343,14 +357,14 @@ class StreamConnectionManager: self, gpu_id: int = 0, batch_size: int = 16, - force_timeout: float = 0.05, + max_queue_size: int = 100, poll_interval: float = 0.01, enable_pt_conversion: bool = True, backend: str = "tensorrt", # "tensorrt" or "ultralytics" ): self.gpu_id = gpu_id self.batch_size = batch_size - self.force_timeout = force_timeout + self.max_queue_size = max_queue_size self.poll_interval = poll_interval self.backend = backend.lower() @@ -449,7 +463,7 @@ class StreamConnectionManager: inference_engine=self.inference_engine, model_id=model_id, batch_size=self.batch_size, - force_timeout=self.force_timeout, + max_queue_size=self.max_queue_size, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn, ) @@ -473,7 +487,7 @@ class StreamConnectionManager: model_repository=self.model_repository, model_id=model_id, batch_size=self.batch_size, - force_timeout=self.force_timeout, + max_queue_size=self.max_queue_size, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn, ) @@ -656,7 +670,7 @@ class StreamConnectionManager: "gpu_id": self.gpu_id, "num_connections": 
len(self.connections), "batch_size": self.batch_size, - "force_timeout": self.force_timeout, + "max_queue_size": self.max_queue_size, "poll_interval": self.poll_interval, }, "model_controller": self.model_controller.get_stats() diff --git a/services/tensorrt_model_controller.py b/services/tensorrt_model_controller.py index 47b17e4..23b0bf2 100644 --- a/services/tensorrt_model_controller.py +++ b/services/tensorrt_model_controller.py @@ -25,14 +25,14 @@ class TensorRTModelController(BaseModelController): model_repository, model_id: str, batch_size: int = 16, - force_timeout: float = 0.05, + max_queue_size: int = 100, preprocess_fn: Optional[Callable] = None, postprocess_fn: Optional[Callable] = None, ): super().__init__( model_id=model_id, batch_size=batch_size, - force_timeout=force_timeout, + max_queue_size=max_queue_size, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn, ) @@ -115,6 +115,7 @@ class TensorRTModelController(BaseModelController): "stream_id": batch_frame.stream_id, "timestamp": batch_frame.timestamp, "detections": detections, + "frame": batch_frame.frame, # Include original frame tensor "metadata": batch_frame.metadata, } results.append(result) @@ -175,6 +176,7 @@ class TensorRTModelController(BaseModelController): "stream_id": batch_frame.stream_id, "timestamp": batch_frame.timestamp, "detections": detections, + "frame": batch_frame.frame, # Include original frame tensor "metadata": batch_frame.metadata, } results.append(result) diff --git a/services/ultralytics_model_controller.py b/services/ultralytics_model_controller.py index dbf1406..3ef26fc 100644 --- a/services/ultralytics_model_controller.py +++ b/services/ultralytics_model_controller.py @@ -25,20 +25,27 @@ class UltralyticsModelController(BaseModelController): inference_engine, model_id: str, batch_size: int = 16, - force_timeout: float = 0.05, + max_queue_size: int = 100, preprocess_fn: Optional[Callable] = None, postprocess_fn: Optional[Callable] = None, ): # Auto-detect actual batch size from the YOLO engine + print(f"[UltralyticsModelController] Detecting batch size from engine...") engine_batch_size = self._detect_engine_batch_size(inference_engine) + print( + f"[UltralyticsModelController] Detected engine_batch_size={engine_batch_size}" + ) # If engine has fixed batch size, use it. 
Otherwise use user's batch_size actual_batch_size = engine_batch_size if engine_batch_size > 0 else batch_size + print( + f"[UltralyticsModelController] Using actual_batch_size={actual_batch_size}" + ) super().__init__( model_id=model_id, batch_size=actual_batch_size, - force_timeout=force_timeout, + max_queue_size=max_queue_size, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn, ) @@ -46,11 +53,23 @@ class UltralyticsModelController(BaseModelController): self.engine_batch_size = engine_batch_size # Store for padding logic if engine_batch_size > 0: + print(f"āœ“ Ultralytics engine has FIXED batch_size={engine_batch_size}") + print( + f" Will pad/truncate all batches to exactly {engine_batch_size} frames" + ) logger.info( f"Ultralytics engine has fixed batch_size={engine_batch_size}, " f"will pad batches to match" ) + # CRITICAL: Override the parent's batch_size to match engine's fixed size + # This prevents buffer accumulation beyond the engine's capacity + self.batch_size = engine_batch_size + print(f" Controller self.batch_size is now: {self.batch_size}") + print(f" Buffer will swap when size >= {self.batch_size}") else: + print( + f"āœ“ Ultralytics engine supports DYNAMIC batching, max={actual_batch_size}" + ) logger.info( f"Ultralytics engine supports dynamic batching, " f"using max batch_size={actual_batch_size}" @@ -67,16 +86,22 @@ class UltralyticsModelController(BaseModelController): # Get engine metadata metadata = inference_engine.get_metadata() + logger.info(f"Detecting batch size from engine metadata: {metadata}") + # Check input shape for batch dimension if "images" in metadata.input_shapes: input_shape = metadata.input_shapes["images"] batch_dim = input_shape[0] + logger.info(f"Found batch dimension in metadata: {batch_dim}") + if batch_dim > 0: # Fixed batch size + logger.info(f"Using fixed batch size from engine: {batch_dim}") return batch_dim else: # Dynamic batch size (-1) + logger.info("Engine supports dynamic batching (batch_dim=-1)") return -1 # Fallback: try to get from model directly @@ -187,28 +212,16 @@ class UltralyticsModelController(BaseModelController): # No detections detections = torch.zeros((0, 6), device=batch_tensor.device) - # Apply custom postprocessing if provided - if self.postprocess_fn: - try: - # For Ultralytics, postprocess_fn might do additional filtering - # Pass the raw boxes tensor in the same format as TensorRT output - detections = self.postprocess_fn( - { - "output0": detections.unsqueeze( - 0 - ) # Add batch dim for compatibility - } - ) - except Exception as e: - logger.error( - f"Error in postprocess for stream {batch_frame.stream_id}: {e}" - ) - detections = torch.zeros((0, 6), device=batch_tensor.device) + # NOTE: Skip postprocess_fn for Ultralytics backend! + # Ultralytics already does confidence filtering, NMS, and format conversion. + # The detections are already in final format: [x1, y1, x2, y2, conf, cls] + # Any custom postprocess_fn would expect raw TensorRT output and will fail. result = { "stream_id": batch_frame.stream_id, "timestamp": batch_frame.timestamp, "detections": detections, + "frame": batch_frame.frame, # Include original frame tensor "metadata": batch_frame.metadata, "yolo_result": yolo_result, # Keep original Results object for debugging } diff --git a/test_tracking_realtime.py b/test_tracking_realtime.py index 4e5c2be..8c0e58e 100644 --- a/test_tracking_realtime.py +++ b/test_tracking_realtime.py @@ -4,11 +4,11 @@ Real-time object tracking with event-driven batching architecture. 
This script demonstrates: - Event-driven stream processing with StreamConnectionManager - Batched GPU inference with ModelController -- Ping-pong buffer architecture for optimal throughput - Callback-based event-driven pattern for RTSP streams - Automatic PT to TensorRT conversion """ +import logging import os import threading import time @@ -28,6 +28,11 @@ from services import ( # Load environment variables load_dotenv() +# Enable debug logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + def main_multi_stream(): """Multi-stream example with batched inference.""" @@ -41,8 +46,8 @@ def main_multi_stream(): USE_ULTRALYTICS = ( os.getenv("USE_ULTRALYTICS", "true").lower() == "true" ) # Use Ultralytics engine for YOLO - BATCH_SIZE = 2 # Reduced to 2 to avoid GPU memory issues - FORCE_TIMEOUT = 0.05 + BATCH_SIZE = 2 # Must match engine's fixed batch size + MAX_QUEUE_SIZE = 50 # Drop frames if queue gets too long ENABLE_DISPLAY = os.getenv("ENABLE_DISPLAY", "true").lower() == "true" # Load camera URLs @@ -73,10 +78,11 @@ def main_multi_stream(): manager = StreamConnectionManager( gpu_id=GPU_ID, batch_size=BATCH_SIZE, - force_timeout=FORCE_TIMEOUT, + max_queue_size=MAX_QUEUE_SIZE, enable_pt_conversion=True, backend=backend, ) + print("āœ“ Manager created") # Initialize model (transparent loading) @@ -86,7 +92,6 @@ def main_multi_stream(): model_path=MODEL_PATH, model_id="detector", preprocess_fn=YOLOv8Utils.preprocess, - postprocess_fn=YOLOv8Utils.postprocess, num_contexts=1, # Single context to minimize GPU memory usage # Note: No pt_input_shapes or pt_precision needed for YOLO models! ) @@ -98,6 +103,109 @@ def main_multi_stream(): traceback.print_exc() return + # Track stats (initialize before callback definition) + stream_stats = {sid: {"count": 0, "start": time.time()} for sid, _ in camera_urls} + total_results = 0 + start_time = time.time() + stats_lock = threading.Lock() + + # Create windows for each stream if display enabled + if ENABLE_DISPLAY: + for stream_id, _ in camera_urls: + cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL) + cv2.resizeWindow( + stream_id, 640, 360 + ) # Smaller windows for multiple streams + + def on_tracking_result(result): + """Callback for tracking results - called automatically per stream""" + nonlocal total_results + + # Debug: Check if we have frame tensor + has_frame = result.frame_tensor is not None + frame_shape = result.frame_tensor.shape if has_frame else None + print( + f"[CALLBACK] Got result for {result.stream_id}, has_frame={has_frame}, shape={frame_shape}, detections={len(result.detections)}" + ) + + with stats_lock: + total_results += 1 + stream_id = result.stream_id + if stream_id in stream_stats: + stream_stats[stream_id]["count"] += 1 + + # Print stats every 10 results (changed from 100 for faster feedback) + if total_results % 10 == 0: + elapsed = time.time() - start_time + total_fps = total_results / elapsed if elapsed > 0 else 0 + print( + f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS" + ) + for sid, stats in stream_stats.items(): + s_elapsed = time.time() - stats["start"] + s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0 + print(f" {sid}: {stats['count']} ({s_fps:.1f} FPS)") + + # Display visualization if enabled + if ENABLE_DISPLAY and result.frame_tensor is not None: + # Convert GPU tensor (C, H, W) to CPU numpy (H, W, C) for OpenCV + frame_tensor = result.frame_tensor # (3, 720, 1280) RGB uint8 + frame_np = ( + frame_tensor.cpu().permute(1, 2, 
0).numpy().astype(np.uint8) + ) # (720, 1280, 3) + frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR) + + # Draw bounding boxes + for obj in result.tracked_objects: + x1, y1, x2, y2 = map(int, obj.bbox) + + # Draw box + cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2) + + # Draw label with ID and class + label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}" + (label_w, label_h), _ = cv2.getTextSize( + label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1 + ) + cv2.rectangle( + frame_bgr, + (x1, y1 - label_h - 10), + (x1 + label_w, y1), + (0, 255, 0), + -1, + ) + cv2.putText( + frame_bgr, + label, + (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 0, 0), + 1, + ) + + # Show FPS on frame + with stats_lock: + s_elapsed = time.time() - stream_stats[stream_id]["start"] + s_fps = ( + stream_stats[stream_id]["count"] / s_elapsed if s_elapsed > 0 else 0 + ) + fps_text = ( + f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects" + ) + cv2.putText( + frame_bgr, + fps_text, + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (0, 255, 0), + 2, + ) + + # Display + cv2.imshow(stream_id, frame_bgr) + # Connect all streams in parallel using threads print(f"\n[3/3] Connecting {len(camera_urls)} streams in parallel...") connections = {} @@ -108,7 +216,10 @@ def main_multi_stream(): """Thread worker to connect a single stream""" try: conn = manager.connect_stream( - rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=3 + rtsp_url=rtsp_url, + stream_id=stream_id, + buffer_size=2, + on_tracking_result=on_tracking_result, # Register callback ) connection_results[stream_id] = ("success", conn) except Exception as e: @@ -144,124 +255,15 @@ def main_multi_stream(): print("Press Ctrl+C to stop") print(f"{'=' * 80}\n") - # Track stats - stream_stats = { - sid: {"count": 0, "start": time.time()} for sid in connections.keys() - } - total_results = 0 - start_time = time.time() - - # Create windows for each stream if display enabled - if ENABLE_DISPLAY: - for stream_id in connections.keys(): - cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL) - cv2.resizeWindow( - stream_id, 640, 360 - ) # Smaller windows for multiple streams - try: - # Merge all result queues from all connections - import queue as queue_module - - running = True - while running: - # Poll all connection queues (non-blocking) - got_result = False - for conn in connections.values(): - try: - # Non-blocking get from each connection's queue - result = conn.result_queue.get_nowait() - got_result = True - - total_results += 1 - stream_id = result.stream_id - - if stream_id in stream_stats: - stream_stats[stream_id]["count"] += 1 - - # Display visualization if enabled - if ENABLE_DISPLAY: - # Get latest frame from decoder (already in CPU memory as numpy RGB) - frame_rgb = conn.decoder.get_latest_frame_cpu(rgb=True) - if frame_rgb is not None: - # Convert RGB to BGR for OpenCV - frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) - - # Draw bounding boxes - for obj in result.tracked_objects: - x1, y1, x2, y2 = map(int, obj.bbox) - - # Draw box - cv2.rectangle( - frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2 - ) - - # Draw label with ID and class - label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}" - (label_w, label_h), _ = cv2.getTextSize( - label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1 - ) - cv2.rectangle( - frame_bgr, - (x1, y1 - label_h - 10), - (x1 + label_w, y1), - (0, 255, 0), - -1, - ) - cv2.putText( - frame_bgr, - label, - (x1, y1 - 5), - cv2.FONT_HERSHEY_SIMPLEX, - 0.5, - (0, 0, 0), - 1, - ) - - # Show FPS on 
frame - s_elapsed = time.time() - stream_stats[stream_id]["start"] - s_fps = ( - stream_stats[stream_id]["count"] / s_elapsed - if s_elapsed > 0 - else 0 - ) - fps_text = f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects" - cv2.putText( - frame_bgr, - fps_text, - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 0.7, - (0, 255, 0), - 2, - ) - - # Display - cv2.imshow(stream_id, frame_bgr) - - # Print stats every 100 results - if total_results % 100 == 0: - elapsed = time.time() - start_time - total_fps = total_results / elapsed if elapsed > 0 else 0 - - print( - f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS" - ) - for sid, stats in stream_stats.items(): - s_elapsed = time.time() - stats["start"] - s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0 - print(f" {sid}: {stats['count']} ({s_fps:.1f} FPS)") - - except queue_module.Empty: - continue - - # Process OpenCV events to keep windows responsive + # Keep main thread alive and process OpenCV events + while True: if ENABLE_DISPLAY: - cv2.waitKey(1) - - # Small sleep if no results to avoid busy loop - if not got_result: - time.sleep(0.01) + # Process OpenCV events to keep windows responsive + if cv2.waitKey(1) & 0xFF == ord("q"): + break + else: + time.sleep(0.1) except KeyboardInterrupt: print(f"\nāœ“ Interrupted")
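# --- Illustrative sketch (editor's note, not part of the patch) ---
# A minimal, self-contained version of the circular-buffer batching loop that
# this patch introduces in services/base_model_controller.py: frames go into a
# single deque with a maxlen (oldest entries drop automatically when full), and
# one consumer thread pops batch_size frames at a time for inference. The class
# and method names below are simplified assumptions, not the patch's API.
import threading
import time
from collections import deque
from typing import Callable, List


class CircularBatchBuffer:
    def __init__(self, batch_size: int = 16, max_queue_size: int = 100):
        self.batch_size = batch_size
        # deque with maxlen silently discards the oldest item when full
        self.queue = deque(maxlen=max_queue_size)
        self.lock = threading.Lock()
        self.running = False

    def submit(self, frame) -> None:
        with self.lock:
            self.queue.append(frame)

    def _consume(self, process_batch: Callable[[List], None]) -> None:
        # Mirrors _batch_processor: only fire when a full batch is available,
        # otherwise sleep ~1ms to avoid a busy loop.
        while self.running:
            with self.lock:
                if len(self.queue) >= self.batch_size:
                    batch = [self.queue.popleft() for _ in range(self.batch_size)]
                else:
                    batch = []
            if batch:
                process_batch(batch)
            else:
                time.sleep(0.001)

    def start(self, process_batch: Callable[[List], None]) -> None:
        self.running = True
        threading.Thread(
            target=self._consume, args=(process_batch,), daemon=True
        ).start()

    def stop(self) -> None:
        self.running = False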