new buffer paradigm

Siwat Sirichai 2025-11-11 02:02:12 +07:00
parent fdaeb9981c
commit a519dea130
6 changed files with 341 additions and 327 deletions
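For readers skimming the diff below: the core of the new paradigm is a bounded deque that silently drops the oldest frame on overflow, plus a single background thread that pops batch_size frames at a time and runs inference outside the lock. A minimal, self-contained sketch of that pattern (illustrative only; `submit`, `drain_once`, and `process` are placeholder names, not the project's API):

    from collections import deque
    from threading import Lock

    queue = deque(maxlen=100)   # oldest frame is discarded automatically when full
    lock = Lock()
    BATCH_SIZE = 16

    def submit(frame):
        with lock:
            queue.append(frame)  # O(1); evicts the oldest entry once at capacity

    def drain_once(process):
        """Pop up to BATCH_SIZE frames and hand them to `process` (a stand-in for inference)."""
        with lock:
            batch = [queue.popleft() for _ in range(min(BATCH_SIZE, len(queue)))]
        if batch:
            process(batch)  # runs outside the lock so producers are never blocked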

View file

@@ -1,16 +1,19 @@
 """
-Base Model Controller - Abstract base class for batched inference controllers.
-Provides ping-pong buffer architecture with force-switch timeout mechanism.
-Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
+Base Model Controller - Simple circular buffer with continuous batch processing.
+Replaces the complex ping-pong buffer architecture with a simple queue:
+- Frames arrive and go into a single deque (circular buffer)
+- When batch_size frames are ready, process them
+- Continue consuming batches until queue is empty
+- Drop oldest frames if queue is full
 """
 import logging
 import threading
 import time
 from abc import ABC, abstractmethod
+from collections import deque
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import Any, Callable, Dict, List, Optional
 import torch
@@ -28,62 +31,37 @@ class BatchFrame:
     metadata: Dict = field(default_factory=dict)
-class BufferState(Enum):
-    """State of a ping-pong buffer"""
-    IDLE = "idle"
-    FILLING = "filling"
-    PROCESSING = "processing"
 class BaseModelController(ABC):
     """
-    Abstract base class for batched inference with ping-pong buffers.
-    This controller accumulates frames from multiple streams into batches,
-    processes them through an inference backend, and routes results back to
-    stream-specific callbacks.
-    Features:
-    - Ping-pong circular buffers (BufferA/BufferB)
-    - Force-switch timeout to prevent batch starvation
-    - Event-driven processing with callbacks
-    - Thread-safe frame submission
-    Subclasses must implement:
-    - _run_batch_inference(): Backend-specific inference logic
+    Simple batched inference with circular buffer.
+    Architecture:
+    - Single deque (circular buffer) for incoming frames
+    - Batch processor thread continuously consumes batches
+    - Frames come in fast, batches go out as fast as inference allows
+    - Automatic frame dropping when queue is full
     """
     def __init__(
         self,
         model_id: str,
         batch_size: int = 16,
-        force_timeout: float = 0.05,
+        max_queue_size: int = 100,
         preprocess_fn: Optional[Callable] = None,
         postprocess_fn: Optional[Callable] = None,
     ):
         self.model_id = model_id
         self.batch_size = batch_size
-        self.force_timeout = force_timeout
+        self.max_queue_size = max_queue_size
         self.preprocess_fn = preprocess_fn
         self.postprocess_fn = postprocess_fn
-        # Ping-pong buffers
-        self.buffer_a: List[BatchFrame] = []
-        self.buffer_b: List[BatchFrame] = []
-        # Buffer states
-        self.active_buffer = "A"
-        self.buffer_a_state = BufferState.IDLE
-        self.buffer_b_state = BufferState.IDLE
-        # Threading coordination
-        self.buffer_lock = threading.RLock()
-        self.last_submit_time = time.time()
-        # Threads
-        self.timeout_thread: Optional[threading.Thread] = None
-        self.processor_threads: Dict[str, threading.Thread] = {}
+        # Single circular buffer
+        self.frame_queue = deque(maxlen=max_queue_size)
+        self.queue_lock = threading.Lock()
+        # Processing thread
+        self.processor_thread: Optional[threading.Thread] = None
         self.running = False
         self.stop_event = threading.Event()
@@ -93,33 +71,24 @@ class BaseModelController(ABC):
         # Statistics
         self.total_frames_processed = 0
         self.total_batches_processed = 0
+        self.total_frames_dropped = 0
     def start(self):
-        """Start the controller background threads"""
+        """Start the controller background thread"""
         if self.running:
-            logger.warning("ModelController already running")
+            logger.warning(f"{self.__class__.__name__} already running")
             return
         self.running = True
         self.stop_event.clear()
-        # Start timeout monitor thread
-        self.timeout_thread = threading.Thread(
-            target=self._timeout_monitor, daemon=True
-        )
-        self.timeout_thread.start()
-        # Start processor threads for each buffer
-        self.processor_threads["A"] = threading.Thread(
-            target=self._batch_processor, args=("A",), daemon=True
-        )
-        self.processor_threads["B"] = threading.Thread(
-            target=self._batch_processor, args=("B",), daemon=True
-        )
-        self.processor_threads["A"].start()
-        self.processor_threads["B"].start()
-        logger.info(f"{self.__class__.__name__} started")
+        # Start single processor thread
+        self.processor_thread = threading.Thread(
+            target=self._batch_processor, daemon=True
+        )
+        self.processor_thread.start()
+        logger.info(f"{self.__class__.__name__} started (batch_size={self.batch_size})")
     def stop(self):
         """Stop the controller and cleanup"""
@@ -130,16 +99,12 @@ class BaseModelController(ABC):
         self.running = False
         self.stop_event.set()
-        # Wait for threads to finish
-        if self.timeout_thread and self.timeout_thread.is_alive():
-            self.timeout_thread.join(timeout=2.0)
-        for thread in self.processor_threads.values():
-            if thread and thread.is_alive():
-                thread.join(timeout=2.0)
-        # Process any remaining frames
-        self._process_remaining_buffers()
+        # Wait for thread to finish
+        if self.processor_thread and self.processor_thread.is_alive():
+            self.processor_thread.join(timeout=2.0)
+        # Process remaining frames
+        self._process_remaining_frames()
         logger.info(f"{self.__class__.__name__} stopped")
     def register_callback(self, stream_id: str, callback: Callable):
@@ -156,98 +121,48 @@ class BaseModelController(ABC):
         self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
     ):
         """Submit a frame for batched inference"""
-        with self.buffer_lock:
+        with self.queue_lock:
+            # If queue is full, oldest frame is automatically dropped (deque with maxlen)
+            if len(self.frame_queue) >= self.max_queue_size:
+                self.total_frames_dropped += 1
             batch_frame = BatchFrame(
                 stream_id=stream_id,
                 frame=frame,
                 timestamp=time.time(),
                 metadata=metadata or {},
             )
-            # Add to active buffer
-            if self.active_buffer == "A":
-                self.buffer_a.append(batch_frame)
-                self.buffer_a_state = BufferState.FILLING
-                buffer_size = len(self.buffer_a)
-            else:
-                self.buffer_b.append(batch_frame)
-                self.buffer_b_state = BufferState.FILLING
-                buffer_size = len(self.buffer_b)
-            self.last_submit_time = time.time()
-            # Check if we should immediately swap (batch full)
-            if buffer_size >= self.batch_size:
-                self._try_swap_buffers()
-    def _timeout_monitor(self):
-        """Monitor force-switch timeout"""
-        while self.running and not self.stop_event.wait(0.01):
-            with self.buffer_lock:
-                time_since_submit = time.time() - self.last_submit_time
-                if time_since_submit >= self.force_timeout:
-                    active_buffer = (
-                        self.buffer_a if self.active_buffer == "A" else self.buffer_b
-                    )
-                    if len(active_buffer) > 0:
-                        self._try_swap_buffers()
-    def _try_swap_buffers(self):
-        """Attempt to swap ping-pong buffers (called with buffer_lock held)"""
-        inactive_state = (
-            self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
-        )
-        if inactive_state != BufferState.PROCESSING:
-            old_active = self.active_buffer
-            self.active_buffer = "B" if old_active == "A" else "A"
-            if old_active == "A":
-                self.buffer_a_state = BufferState.PROCESSING
-                buffer_size = len(self.buffer_a)
-            else:
-                self.buffer_b_state = BufferState.PROCESSING
-                buffer_size = len(self.buffer_b)
-            logger.debug(
-                f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
-            )
-    def _batch_processor(self, buffer_name: str):
-        """Background thread that processes a specific buffer when available"""
+            self.frame_queue.append(batch_frame)
+    def _batch_processor(self):
+        """Background thread that continuously processes batches"""
+        logger.info(f"{self.__class__.__name__} batch processor started")
         while self.running and not self.stop_event.is_set():
-            time.sleep(0.001)
-            with self.buffer_lock:
-                if buffer_name == "A":
-                    should_process = self.buffer_a_state == BufferState.PROCESSING
-                else:
-                    should_process = self.buffer_b_state == BufferState.PROCESSING
-            if should_process:
-                self._process_buffer(buffer_name)
-    def _process_buffer(self, buffer_name: str):
-        """Process a buffer through inference"""
-        # Extract buffer to process
-        with self.buffer_lock:
-            if buffer_name == "A":
-                batch = self.buffer_a.copy()
-                self.buffer_a.clear()
-            else:
-                batch = self.buffer_b.copy()
-                self.buffer_b.clear()
-        if len(batch) == 0:
-            with self.buffer_lock:
-                if buffer_name == "A":
-                    self.buffer_a_state = BufferState.IDLE
-                else:
-                    self.buffer_b_state = BufferState.IDLE
-            return
-        # Process batch (outside lock to allow concurrent submissions)
+            # Check if we have enough frames for a batch
+            with self.queue_lock:
+                queue_size = len(self.frame_queue)
+            if queue_size > 0 and queue_size % 10 == 0:
+                logger.info(f"Queue size: {queue_size}/{self.batch_size}")
+            if queue_size >= self.batch_size:
+                # Extract batch
+                with self.queue_lock:
+                    batch = []
+                    for _ in range(min(self.batch_size, len(self.frame_queue))):
+                        if self.frame_queue:
+                            batch.append(self.frame_queue.popleft())
+                if batch:
+                    logger.info(f"Processing batch of {len(batch)} frames")
+                    self._process_batch(batch)
+            else:
+                # Not enough frames, sleep briefly
+                time.sleep(0.001)  # 1ms
+    def _process_batch(self, batch: List[BatchFrame]):
+        """Process a batch through inference"""
         try:
             start_time = time.time()
             results = self._run_batch_inference(batch)
@@ -258,7 +173,6 @@ class BaseModelController(ABC):
             logger.debug(
                 f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms"
-                f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
             )
             # Emit results to callbacks
@@ -276,49 +190,58 @@ class BaseModelController(ABC):
         except Exception as e:
             logger.error(f"Error processing batch: {e}", exc_info=True)
-        finally:
-            with self.buffer_lock:
-                if buffer_name == "A":
-                    self.buffer_a_state = BufferState.IDLE
-                else:
-                    self.buffer_b_state = BufferState.IDLE
     @abstractmethod
     def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
-        """
-        Run inference on a batch of frames (backend-specific).
-        Args:
-            batch: List of BatchFrame objects
-        Returns:
-            List of detection results (one per frame)
-        """
+        """Run inference on a batch of frames (backend-specific)"""
         pass
-    def _process_remaining_buffers(self):
-        """Process any remaining frames in buffers during shutdown"""
-        if len(self.buffer_a) > 0:
-            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
-            self._process_buffer("A")
-        if len(self.buffer_b) > 0:
-            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
-            self._process_buffer("B")
+    def _process_remaining_frames(self):
+        """Process any remaining frames in queue during shutdown"""
+        with self.queue_lock:
+            remaining = len(self.frame_queue)
+        if remaining > 0:
+            logger.info(f"Processing remaining {remaining} frames")
+        while True:
+            with self.queue_lock:
+                if not self.frame_queue:
+                    break
+                batch = []
+                for _ in range(min(self.batch_size, len(self.frame_queue))):
+                    if self.frame_queue:
+                        batch.append(self.frame_queue.popleft())
+            if batch:
+                self._process_batch(batch)
     def get_stats(self) -> Dict[str, Any]:
-        """Get current buffer statistics"""
+        """Get current statistics"""
+        with self.queue_lock:
+            queue_size = len(self.frame_queue)
         return {
-            "active_buffer": self.active_buffer,
-            "buffer_a_size": len(self.buffer_a),
-            "buffer_b_size": len(self.buffer_b),
-            "buffer_a_state": self.buffer_a_state.value,
-            "buffer_b_state": self.buffer_b_state.value,
+            "queue_size": queue_size,
+            "max_queue_size": self.max_queue_size,
+            "batch_size": self.batch_size,
             "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
             "total_batches_processed": self.total_batches_processed,
+            "total_frames_dropped": self.total_frames_dropped,
             "avg_batch_size": (
                 self.total_frames_processed / self.total_batches_processed
                 if self.total_batches_processed > 0
                 else 0
            ),
         }
+# Keep old BufferState enum for backwards compatibility
+from enum import Enum
+class BufferState(Enum):
+    """Deprecated - kept for backwards compatibility"""
+    IDLE = "idle"
+    FILLING = "filling"
+    PROCESSING = "processing"

View file

@@ -9,6 +9,7 @@ Provides a unified interface for different inference backends:
 All engines support zero-copy GPU tensor inference where possible.
 """
+import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
@@ -17,6 +18,8 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
+logger = logging.getLogger(__name__)
 class BackendType(Enum):
     """Supported inference backend types"""
@@ -423,9 +426,18 @@ class UltralyticsEngine(IInferenceEngine):
             final_model_path = engine_path
             print(f"Using TensorRT engine: {engine_path}")
+            # CRITICAL: Update _model_path to point to the .engine file for metadata extraction
+            self._model_path = engine_path
         # Load model (Ultralytics handles .engine files natively)
         self._model = YOLO(final_model_path)
+        logger.info(f"Loaded Ultralytics model: {type(self._model)}")
+        if hasattr(self._model, "predictor"):
+            logger.info(
+                f"Model has predictor: {type(self._model.predictor) if self._model.predictor else None}"
+            )
         # Move to device if needed (only for .pt models, .engine already on specific device)
         if hasattr(self._model, "model") and self._model.model is not None:
             # Check if it's actually a torch model (not a string path for .engine files)
@@ -437,6 +449,39 @@ class UltralyticsEngine(IInferenceEngine):
         return self._metadata
+    def _read_batch_size_from_engine_file(self, engine_path: str) -> int:
+        """
+        Read batch size from the metadata JSON file saved next to the engine.
+        Much simpler than parsing the TensorRT engine!
+        """
+        try:
+            import json
+            from pathlib import Path
+            # The metadata file is named: <engine_path_without_extension>_metadata.json
+            engine_file = Path(engine_path)
+            metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json")
+            print(f"[UltralyticsEngine] Looking for metadata file: {metadata_file}")
+            if metadata_file.exists():
+                with open(metadata_file, "r") as f:
+                    metadata = json.load(f)
+                batch_size = metadata.get("batch", -1)
+                print(
+                    f"[UltralyticsEngine] Found metadata: batch={batch_size}, imgsz={metadata.get('imgsz')}"
+                )
+                return batch_size
+            else:
+                print(f"[UltralyticsEngine] Metadata file not found: {metadata_file}")
+        except Exception as e:
+            print(
+                f"[UltralyticsEngine] Could not read batch size from metadata file: {e}"
+            )
+        return -1  # Default to dynamic
     def _extract_metadata(self) -> EngineMetadata:
         """Extract metadata from Ultralytics model"""
         # Ultralytics models typically expect (B, 3, H, W) input
@@ -447,6 +492,17 @@ class UltralyticsEngine(IInferenceEngine):
         imgsz = 640
         input_shape = (batch_size, 3, imgsz, imgsz)
+        # CRITICAL: For .engine files, read batch size directly from the metadata file
+        print(f"[UltralyticsEngine] _model_path={self._model_path}")
+        if self._model_path.endswith(".engine"):
+            print(f"[UltralyticsEngine] Reading batch size from engine file...")
+            batch_size = self._read_batch_size_from_engine_file(self._model_path)
+            print(f"[UltralyticsEngine] Read batch_size={batch_size} from .engine file")
+            if batch_size > 0:
+                input_shape = (batch_size, 3, imgsz, imgsz)
+        else:
+            print(f"[UltralyticsEngine] Not an .engine file, skipping direct read")
         if hasattr(self._model, "model") and self._model.model is not None:
             # Try to get actual input shape from model
             try:
@@ -508,6 +564,10 @@ class UltralyticsEngine(IInferenceEngine):
                 logger.warning(f"Could not extract full metadata: {e}")
                 pass
+        logger.info(
+            f"Extracted Ultralytics metadata: batch_size={batch_size}, imgsz={imgsz}, input_shape={input_shape}"
+        )
         return EngineMetadata(
             engine_type="ultralytics",
             model_path=self._model_path,
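The new _read_batch_size_from_engine_file() relies on a sidecar JSON named <engine_stem>_metadata.json with "batch" and "imgsz" keys. A sketch of producing such a file at export time, so the reader above finds it (the writer side is an assumption, not shown in this commit):

    import json
    from pathlib import Path

    def write_engine_metadata(engine_path: str, batch: int, imgsz: int) -> Path:
        """Write the sidecar file that _read_batch_size_from_engine_file() looks for.
        Keys ("batch", "imgsz") match the reader; creating the file this way at
        export time is assumed, not shown in the commit."""
        engine_file = Path(engine_path)
        metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json")
        metadata_file.write_text(json.dumps({"batch": batch, "imgsz": imgsz}))
        return metadata_file

    # e.g. write_engine_metadata("models/yolov8n.engine", batch=2, imgsz=640)
    # produces models/yolov8n_metadata.json, and the engine loader reports batch_size=2.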

View file

@@ -42,6 +42,7 @@ class TrackingResult:
     tracked_objects: List  # List of TrackedObject from TrackingController
     detections: List  # Raw detections
     frame_shape: Tuple[int, int, int]
+    frame_tensor: Optional[torch.Tensor]  # GPU tensor of the frame (C, H, W)
     metadata: Dict
@@ -158,6 +159,9 @@ class StreamConnection:
         # Submit to model controller for batched inference
         # Pass the FrameReference in metadata so we can free it later
+        logger.debug(
+            f"[{self.stream_id}] Submitting frame {self.frame_count} to model controller"
+        )
         self.model_controller.submit_frame(
             stream_id=self.stream_id,
             frame=cloned_tensor,  # Use cloned tensor, not original
@@ -167,6 +171,9 @@ class StreamConnection:
                 "frame_ref": frame_ref,  # Store reference for later cleanup
             },
         )
+        logger.debug(
+            f"[{self.stream_id}] Frame {self.frame_count} submitted, queue size: {len(self.model_controller.frame_queue)}"
+        )
         # Update connection status based on decoder status
         if (
@@ -211,6 +218,12 @@ class StreamConnection:
             frame_shape = result["metadata"].get("shape")
             tracked_objects = self._run_tracking_sync(detections, frame_shape)
+            # Get ORIGINAL frame tensor from metadata (not the preprocessed one in result["frame"])
+            # The frame in result["frame"] is preprocessed (resized, normalized)
+            # We need the original frame for visualization
+            frame_ref = result["metadata"].get("frame_ref")
+            frame_tensor = frame_ref.rgb_tensor if frame_ref else None
             # Create tracking result
             tracking_result = TrackingResult(
                 stream_id=self.stream_id,
@@ -218,6 +231,7 @@ class StreamConnection:
                 tracked_objects=tracked_objects,
                 detections=detections,
                 frame_shape=result["metadata"].get("shape"),
+                frame_tensor=frame_tensor,  # Original frame, not preprocessed
                 metadata=result["metadata"],
             )
@@ -328,7 +342,7 @@ class StreamConnectionManager:
     Args:
         gpu_id: GPU device ID (default: 0)
         batch_size: Maximum batch size for inference (default: 16)
-        force_timeout: Force buffer switch timeout in seconds (default: 0.05)
+        max_queue_size: Maximum frames in queue before dropping (default: 100)
         poll_interval: Frame polling interval in seconds (default: 0.01)
     Example:
@@ -343,14 +357,14 @@ class StreamConnectionManager:
         self,
         gpu_id: int = 0,
         batch_size: int = 16,
-        force_timeout: float = 0.05,
+        max_queue_size: int = 100,
         poll_interval: float = 0.01,
         enable_pt_conversion: bool = True,
         backend: str = "tensorrt",  # "tensorrt" or "ultralytics"
     ):
         self.gpu_id = gpu_id
         self.batch_size = batch_size
-        self.force_timeout = force_timeout
+        self.max_queue_size = max_queue_size
         self.poll_interval = poll_interval
         self.backend = backend.lower()
@@ -449,7 +463,7 @@ class StreamConnectionManager:
                 inference_engine=self.inference_engine,
                 model_id=model_id,
                 batch_size=self.batch_size,
-                force_timeout=self.force_timeout,
+                max_queue_size=self.max_queue_size,
                 preprocess_fn=preprocess_fn,
                 postprocess_fn=postprocess_fn,
             )
@@ -473,7 +487,7 @@ class StreamConnectionManager:
                 model_repository=self.model_repository,
                 model_id=model_id,
                 batch_size=self.batch_size,
-                force_timeout=self.force_timeout,
+                max_queue_size=self.max_queue_size,
                 preprocess_fn=preprocess_fn,
                 postprocess_fn=postprocess_fn,
             )
@@ -656,7 +670,7 @@ class StreamConnectionManager:
                 "gpu_id": self.gpu_id,
                 "num_connections": len(self.connections),
                 "batch_size": self.batch_size,
-                "force_timeout": self.force_timeout,
+                "max_queue_size": self.max_queue_size,
                 "poll_interval": self.poll_interval,
             },
             "model_controller": self.model_controller.get_stats()

View file

@@ -25,14 +25,14 @@ class TensorRTModelController(BaseModelController):
         model_repository,
         model_id: str,
         batch_size: int = 16,
-        force_timeout: float = 0.05,
+        max_queue_size: int = 100,
         preprocess_fn: Optional[Callable] = None,
         postprocess_fn: Optional[Callable] = None,
     ):
         super().__init__(
             model_id=model_id,
             batch_size=batch_size,
-            force_timeout=force_timeout,
+            max_queue_size=max_queue_size,
             preprocess_fn=preprocess_fn,
             postprocess_fn=postprocess_fn,
         )
@@ -115,6 +115,7 @@ class TensorRTModelController(BaseModelController):
                 "stream_id": batch_frame.stream_id,
                 "timestamp": batch_frame.timestamp,
                 "detections": detections,
+                "frame": batch_frame.frame,  # Include original frame tensor
                 "metadata": batch_frame.metadata,
             }
             results.append(result)
@@ -175,6 +176,7 @@ class TensorRTModelController(BaseModelController):
                 "stream_id": batch_frame.stream_id,
                 "timestamp": batch_frame.timestamp,
                 "detections": detections,
+                "frame": batch_frame.frame,  # Include original frame tensor
                 "metadata": batch_frame.metadata,
             }
             results.append(result)

View file

@@ -25,20 +25,27 @@ class UltralyticsModelController(BaseModelController):
         inference_engine,
         model_id: str,
         batch_size: int = 16,
-        force_timeout: float = 0.05,
+        max_queue_size: int = 100,
         preprocess_fn: Optional[Callable] = None,
         postprocess_fn: Optional[Callable] = None,
     ):
         # Auto-detect actual batch size from the YOLO engine
+        print(f"[UltralyticsModelController] Detecting batch size from engine...")
         engine_batch_size = self._detect_engine_batch_size(inference_engine)
+        print(
+            f"[UltralyticsModelController] Detected engine_batch_size={engine_batch_size}"
+        )
         # If engine has fixed batch size, use it. Otherwise use user's batch_size
         actual_batch_size = engine_batch_size if engine_batch_size > 0 else batch_size
+        print(
+            f"[UltralyticsModelController] Using actual_batch_size={actual_batch_size}"
+        )
         super().__init__(
             model_id=model_id,
             batch_size=actual_batch_size,
-            force_timeout=force_timeout,
+            max_queue_size=max_queue_size,
             preprocess_fn=preprocess_fn,
             postprocess_fn=postprocess_fn,
         )
@@ -46,11 +53,23 @@ class UltralyticsModelController(BaseModelController):
         self.engine_batch_size = engine_batch_size  # Store for padding logic
         if engine_batch_size > 0:
+            print(f"✓ Ultralytics engine has FIXED batch_size={engine_batch_size}")
+            print(
+                f"  Will pad/truncate all batches to exactly {engine_batch_size} frames"
+            )
             logger.info(
                 f"Ultralytics engine has fixed batch_size={engine_batch_size}, "
                 f"will pad batches to match"
             )
+            # CRITICAL: Override the parent's batch_size to match engine's fixed size
+            # This prevents buffer accumulation beyond the engine's capacity
+            self.batch_size = engine_batch_size
+            print(f"  Controller self.batch_size is now: {self.batch_size}")
+            print(f"  Buffer will swap when size >= {self.batch_size}")
         else:
+            print(
+                f"✓ Ultralytics engine supports DYNAMIC batching, max={actual_batch_size}"
+            )
             logger.info(
                 f"Ultralytics engine supports dynamic batching, "
                 f"using max batch_size={actual_batch_size}"
@@ -67,16 +86,22 @@ class UltralyticsModelController(BaseModelController):
         # Get engine metadata
         metadata = inference_engine.get_metadata()
+        logger.info(f"Detecting batch size from engine metadata: {metadata}")
         # Check input shape for batch dimension
         if "images" in metadata.input_shapes:
             input_shape = metadata.input_shapes["images"]
             batch_dim = input_shape[0]
+            logger.info(f"Found batch dimension in metadata: {batch_dim}")
             if batch_dim > 0:
                 # Fixed batch size
+                logger.info(f"Using fixed batch size from engine: {batch_dim}")
                 return batch_dim
             else:
                 # Dynamic batch size (-1)
+                logger.info("Engine supports dynamic batching (batch_dim=-1)")
                 return -1
         # Fallback: try to get from model directly
@@ -187,28 +212,16 @@ class UltralyticsModelController(BaseModelController):
                 # No detections
                 detections = torch.zeros((0, 6), device=batch_tensor.device)
-            # Apply custom postprocessing if provided
-            if self.postprocess_fn:
-                try:
-                    # For Ultralytics, postprocess_fn might do additional filtering
-                    # Pass the raw boxes tensor in the same format as TensorRT output
-                    detections = self.postprocess_fn(
-                        {
-                            "output0": detections.unsqueeze(
-                                0
-                            )  # Add batch dim for compatibility
-                        }
-                    )
-                except Exception as e:
-                    logger.error(
-                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
-                    )
-                    detections = torch.zeros((0, 6), device=batch_tensor.device)
+            # NOTE: Skip postprocess_fn for Ultralytics backend!
+            # Ultralytics already does confidence filtering, NMS, and format conversion.
+            # The detections are already in final format: [x1, y1, x2, y2, conf, cls]
+            # Any custom postprocess_fn would expect raw TensorRT output and will fail.
             result = {
                 "stream_id": batch_frame.stream_id,
                 "timestamp": batch_frame.timestamp,
                 "detections": detections,
+                "frame": batch_frame.frame,  # Include original frame tensor
                 "metadata": batch_frame.metadata,
                 "yolo_result": yolo_result,  # Keep original Results object for debugging
             }
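The controller stores engine_batch_size "for padding logic" and logs that batches will be padded or truncated to the engine's fixed size, but that step is outside the hunks shown here. A hedged sketch of what such a pad/truncate step could look like for a stacked batch tensor (illustrative, not the repository's implementation):

    import torch

    def pad_or_truncate(batch_tensor: torch.Tensor, engine_batch_size: int) -> torch.Tensor:
        """Make a (N, C, H, W) batch exactly engine_batch_size frames long."""
        n = batch_tensor.shape[0]
        if n == engine_batch_size:
            return batch_tensor
        if n > engine_batch_size:
            return batch_tensor[:engine_batch_size]           # drop the extras
        pad = batch_tensor[-1:].expand(engine_batch_size - n, *batch_tensor.shape[1:])
        return torch.cat([batch_tensor, pad], dim=0)          # repeat last frame as filler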

View file

@@ -4,11 +4,11 @@ Real-time object tracking with event-driven batching architecture.
 This script demonstrates:
 - Event-driven stream processing with StreamConnectionManager
 - Batched GPU inference with ModelController
-- Ping-pong buffer architecture for optimal throughput
 - Callback-based event-driven pattern for RTSP streams
 - Automatic PT to TensorRT conversion
 """
+import logging
 import os
 import threading
 import time
@@ -28,6 +28,11 @@ from services import (
 # Load environment variables
 load_dotenv()
+# Enable debug logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
 def main_multi_stream():
     """Multi-stream example with batched inference."""
@@ -41,8 +46,8 @@ def main_multi_stream():
     USE_ULTRALYTICS = (
         os.getenv("USE_ULTRALYTICS", "true").lower() == "true"
     )  # Use Ultralytics engine for YOLO
-    BATCH_SIZE = 2  # Reduced to 2 to avoid GPU memory issues
-    FORCE_TIMEOUT = 0.05
+    BATCH_SIZE = 2  # Must match engine's fixed batch size
+    MAX_QUEUE_SIZE = 50  # Drop frames if queue gets too long
     ENABLE_DISPLAY = os.getenv("ENABLE_DISPLAY", "true").lower() == "true"
     # Load camera URLs
@@ -73,10 +78,11 @@ def main_multi_stream():
     manager = StreamConnectionManager(
         gpu_id=GPU_ID,
         batch_size=BATCH_SIZE,
-        force_timeout=FORCE_TIMEOUT,
+        max_queue_size=MAX_QUEUE_SIZE,
         enable_pt_conversion=True,
         backend=backend,
     )
     print("✓ Manager created")
     # Initialize model (transparent loading)
@@ -86,7 +92,6 @@ def main_multi_stream():
             model_path=MODEL_PATH,
             model_id="detector",
             preprocess_fn=YOLOv8Utils.preprocess,
-            postprocess_fn=YOLOv8Utils.postprocess,
             num_contexts=1,  # Single context to minimize GPU memory usage
             # Note: No pt_input_shapes or pt_precision needed for YOLO models!
         )
@@ -98,6 +103,109 @@ def main_multi_stream():
         traceback.print_exc()
         return
+    # Track stats (initialize before callback definition)
+    stream_stats = {sid: {"count": 0, "start": time.time()} for sid, _ in camera_urls}
+    total_results = 0
+    start_time = time.time()
+    stats_lock = threading.Lock()
+    # Create windows for each stream if display enabled
+    if ENABLE_DISPLAY:
+        for stream_id, _ in camera_urls:
+            cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
+            cv2.resizeWindow(
+                stream_id, 640, 360
+            )  # Smaller windows for multiple streams
+    def on_tracking_result(result):
+        """Callback for tracking results - called automatically per stream"""
+        nonlocal total_results
+        # Debug: Check if we have frame tensor
+        has_frame = result.frame_tensor is not None
+        frame_shape = result.frame_tensor.shape if has_frame else None
+        print(
+            f"[CALLBACK] Got result for {result.stream_id}, has_frame={has_frame}, shape={frame_shape}, detections={len(result.detections)}"
+        )
+        with stats_lock:
+            total_results += 1
+            stream_id = result.stream_id
+            if stream_id in stream_stats:
+                stream_stats[stream_id]["count"] += 1
+            # Print stats every 10 results (changed from 100 for faster feedback)
+            if total_results % 10 == 0:
+                elapsed = time.time() - start_time
+                total_fps = total_results / elapsed if elapsed > 0 else 0
+                print(
+                    f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
+                )
+                for sid, stats in stream_stats.items():
+                    s_elapsed = time.time() - stats["start"]
+                    s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
+                    print(f"  {sid}: {stats['count']} ({s_fps:.1f} FPS)")
+        # Display visualization if enabled
+        if ENABLE_DISPLAY and result.frame_tensor is not None:
+            # Convert GPU tensor (C, H, W) to CPU numpy (H, W, C) for OpenCV
+            frame_tensor = result.frame_tensor  # (3, 720, 1280) RGB uint8
+            frame_np = (
+                frame_tensor.cpu().permute(1, 2, 0).numpy().astype(np.uint8)
+            )  # (720, 1280, 3)
+            frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
+            # Draw bounding boxes
+            for obj in result.tracked_objects:
+                x1, y1, x2, y2 = map(int, obj.bbox)
+                # Draw box
+                cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
+                # Draw label with ID and class
+                label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
+                (label_w, label_h), _ = cv2.getTextSize(
+                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
+                )
+                cv2.rectangle(
+                    frame_bgr,
+                    (x1, y1 - label_h - 10),
+                    (x1 + label_w, y1),
+                    (0, 255, 0),
+                    -1,
+                )
+                cv2.putText(
+                    frame_bgr,
+                    label,
+                    (x1, y1 - 5),
+                    cv2.FONT_HERSHEY_SIMPLEX,
+                    0.5,
+                    (0, 0, 0),
+                    1,
+                )
+            # Show FPS on frame
+            with stats_lock:
+                s_elapsed = time.time() - stream_stats[stream_id]["start"]
+                s_fps = (
+                    stream_stats[stream_id]["count"] / s_elapsed if s_elapsed > 0 else 0
+                )
+            fps_text = (
+                f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
+            )
+            cv2.putText(
+                frame_bgr,
+                fps_text,
+                (10, 30),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.7,
+                (0, 255, 0),
+                2,
+            )
+            # Display
+            cv2.imshow(stream_id, frame_bgr)
     # Connect all streams in parallel using threads
     print(f"\n[3/3] Connecting {len(camera_urls)} streams in parallel...")
     connections = {}
@@ -108,7 +216,10 @@ def main_multi_stream():
         """Thread worker to connect a single stream"""
         try:
             conn = manager.connect_stream(
-                rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=3
+                rtsp_url=rtsp_url,
+                stream_id=stream_id,
+                buffer_size=2,
+                on_tracking_result=on_tracking_result,  # Register callback
             )
             connection_results[stream_id] = ("success", conn)
         except Exception as e:
@@ -144,124 +255,15 @@ def main_multi_stream():
     print("Press Ctrl+C to stop")
     print(f"{'=' * 80}\n")
-    # Track stats
-    stream_stats = {
-        sid: {"count": 0, "start": time.time()} for sid in connections.keys()
-    }
-    total_results = 0
-    start_time = time.time()
-    # Create windows for each stream if display enabled
-    if ENABLE_DISPLAY:
-        for stream_id in connections.keys():
-            cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
-            cv2.resizeWindow(
-                stream_id, 640, 360
-            )  # Smaller windows for multiple streams
     try:
-        # Merge all result queues from all connections
-        import queue as queue_module
-        running = True
-        while running:
-            # Poll all connection queues (non-blocking)
-            got_result = False
-            for conn in connections.values():
-                try:
-                    # Non-blocking get from each connection's queue
-                    result = conn.result_queue.get_nowait()
-                    got_result = True
-                    total_results += 1
-                    stream_id = result.stream_id
-                    if stream_id in stream_stats:
-                        stream_stats[stream_id]["count"] += 1
-                    # Display visualization if enabled
-                    if ENABLE_DISPLAY:
-                        # Get latest frame from decoder (already in CPU memory as numpy RGB)
-                        frame_rgb = conn.decoder.get_latest_frame_cpu(rgb=True)
-                        if frame_rgb is not None:
-                            # Convert RGB to BGR for OpenCV
-                            frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
-                            # Draw bounding boxes
-                            for obj in result.tracked_objects:
-                                x1, y1, x2, y2 = map(int, obj.bbox)
-                                # Draw box
-                                cv2.rectangle(
-                                    frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2
-                                )
-                                # Draw label with ID and class
-                                label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
-                                (label_w, label_h), _ = cv2.getTextSize(
-                                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
-                                )
-                                cv2.rectangle(
-                                    frame_bgr,
-                                    (x1, y1 - label_h - 10),
-                                    (x1 + label_w, y1),
-                                    (0, 255, 0),
-                                    -1,
-                                )
-                                cv2.putText(
-                                    frame_bgr,
-                                    label,
-                                    (x1, y1 - 5),
-                                    cv2.FONT_HERSHEY_SIMPLEX,
-                                    0.5,
-                                    (0, 0, 0),
-                                    1,
-                                )
-                            # Show FPS on frame
-                            s_elapsed = time.time() - stream_stats[stream_id]["start"]
-                            s_fps = (
-                                stream_stats[stream_id]["count"] / s_elapsed
-                                if s_elapsed > 0
-                                else 0
-                            )
-                            fps_text = f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
-                            cv2.putText(
-                                frame_bgr,
-                                fps_text,
-                                (10, 30),
-                                cv2.FONT_HERSHEY_SIMPLEX,
-                                0.7,
-                                (0, 255, 0),
-                                2,
-                            )
-                            # Display
-                            cv2.imshow(stream_id, frame_bgr)
-                    # Print stats every 100 results
-                    if total_results % 100 == 0:
-                        elapsed = time.time() - start_time
-                        total_fps = total_results / elapsed if elapsed > 0 else 0
-                        print(
-                            f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
-                        )
-                        for sid, stats in stream_stats.items():
-                            s_elapsed = time.time() - stats["start"]
-                            s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
-                            print(f"  {sid}: {stats['count']} ({s_fps:.1f} FPS)")
-                except queue_module.Empty:
-                    continue
-            # Process OpenCV events to keep windows responsive
-            if ENABLE_DISPLAY:
-                cv2.waitKey(1)
-            else:
-                # Small sleep if no results to avoid busy loop
-                if not got_result:
-                    time.sleep(0.01)
+        # Keep main thread alive and process OpenCV events
+        while True:
+            if ENABLE_DISPLAY:
+                # Process OpenCV events to keep windows responsive
+                if cv2.waitKey(1) & 0xFF == ord("q"):
+                    break
+            else:
+                time.sleep(0.1)
     except KeyboardInterrupt:
         print(f"\n✓ Interrupted")