new buffer paradigm

parent fdaeb9981c
commit a519dea130

6 changed files with 341 additions and 327 deletions
@ -1,16 +1,19 @@
"""
Base Model Controller - Abstract base class for batched inference controllers.
Base Model Controller - Simple circular buffer with continuous batch processing.

Provides ping-pong buffer architecture with force-switch timeout mechanism.
Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
Replaces the complex ping-pong buffer architecture with a simple queue:
- Frames arrive and go into a single deque (circular buffer)
- When batch_size frames are ready, process them
- Continue consuming batches until queue is empty
- Drop oldest frames if queue is full
"""

import logging
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

import torch
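For orientation, here is a minimal, self-contained sketch of the queue behaviour the new docstring above describes: a collections.deque with maxlen acts as the circular buffer, the oldest frame is dropped when it is full, and batches are popped off the front. The names and sizes are illustrative only, not the committed API.

from collections import deque

# Illustrative sketch only: deque(maxlen=...) silently drops the oldest item
# when full, which is the "drop oldest frames if queue is full" behaviour.
frame_queue = deque(maxlen=4)
for frame_id in range(6):
    frame_queue.append(frame_id)  # frames 0 and 1 fall off the front
batch_size = 2
batch = [frame_queue.popleft() for _ in range(min(batch_size, len(frame_queue)))]
print(batch, list(frame_queue))  # [2, 3] [4, 5]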
@ -28,62 +31,37 @@ class BatchFrame:
metadata: Dict = field(default_factory=dict)


class BufferState(Enum):
"""State of a ping-pong buffer"""

IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"


class BaseModelController(ABC):
"""
Abstract base class for batched inference with ping-pong buffers.
Simple batched inference with circular buffer.

This controller accumulates frames from multiple streams into batches,
processes them through an inference backend, and routes results back to
stream-specific callbacks.

Features:
- Ping-pong circular buffers (BufferA/BufferB)
- Force-switch timeout to prevent batch starvation
- Event-driven processing with callbacks
- Thread-safe frame submission

Subclasses must implement:
- _run_batch_inference(): Backend-specific inference logic
Architecture:
- Single deque (circular buffer) for incoming frames
- Batch processor thread continuously consumes batches
- Frames come in fast, batches go out as fast as inference allows
- Automatic frame dropping when queue is full
"""

def __init__(
self,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
max_queue_size: int = 100,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
self.model_id = model_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.max_queue_size = max_queue_size
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn

# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
self.buffer_b: List[BatchFrame] = []
# Single circular buffer
self.frame_queue = deque(maxlen=max_queue_size)
self.queue_lock = threading.Lock()

# Buffer states
self.active_buffer = "A"
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE

# Threading coordination
self.buffer_lock = threading.RLock()
self.last_submit_time = time.time()

# Threads
self.timeout_thread: Optional[threading.Thread] = None
self.processor_threads: Dict[str, threading.Thread] = {}
# Processing thread
self.processor_thread: Optional[threading.Thread] = None
self.running = False
self.stop_event = threading.Event()
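As a usage sketch (not part of this commit), a concrete controller only has to override _run_batch_inference(); the result-dict keys below mirror the ones the TensorRT and Ultralytics controllers in this commit return, and the class name is hypothetical.

class EchoModelController(BaseModelController):
    # Hypothetical subclass for illustration: returns one empty result per frame,
    # honouring the "one result dict per frame" contract used elsewhere in the diff.
    def _run_batch_inference(self, batch):
        return [
            {
                "stream_id": bf.stream_id,
                "timestamp": bf.timestamp,
                "detections": [],
                "frame": bf.frame,
                "metadata": bf.metadata,
            }
            for bf in batch
        ]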
@ -93,33 +71,24 @@ class BaseModelController(ABC):
# Statistics
self.total_frames_processed = 0
self.total_batches_processed = 0
self.total_frames_dropped = 0

def start(self):
"""Start the controller background threads"""
"""Start the controller background thread"""
if self.running:
logger.warning("ModelController already running")
logger.warning(f"{self.__class__.__name__} already running")
return

self.running = True
self.stop_event.clear()

# Start timeout monitor thread
self.timeout_thread = threading.Thread(
target=self._timeout_monitor, daemon=True
# Start single processor thread
self.processor_thread = threading.Thread(
target=self._batch_processor, daemon=True
)
self.timeout_thread.start()
self.processor_thread.start()

# Start processor threads for each buffer
self.processor_threads["A"] = threading.Thread(
target=self._batch_processor, args=("A",), daemon=True
)
self.processor_threads["B"] = threading.Thread(
target=self._batch_processor, args=("B",), daemon=True
)
self.processor_threads["A"].start()
self.processor_threads["B"].start()

logger.info(f"{self.__class__.__name__} started")
logger.info(f"{self.__class__.__name__} started (batch_size={self.batch_size})")

def stop(self):
"""Stop the controller and cleanup"""

@ -130,16 +99,12 @@ class BaseModelController(ABC):
self.running = False
self.stop_event.set()

# Wait for threads to finish
if self.timeout_thread and self.timeout_thread.is_alive():
self.timeout_thread.join(timeout=2.0)
# Wait for thread to finish
if self.processor_thread and self.processor_thread.is_alive():
self.processor_thread.join(timeout=2.0)

for thread in self.processor_threads.values():
if thread and thread.is_alive():
thread.join(timeout=2.0)

# Process any remaining frames
self._process_remaining_buffers()
# Process remaining frames
self._process_remaining_frames()
logger.info(f"{self.__class__.__name__} stopped")

def register_callback(self, stream_id: str, callback: Callable):

@ -156,98 +121,48 @@ class BaseModelController(ABC):
self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
):
"""Submit a frame for batched inference"""
with self.buffer_lock:
with self.queue_lock:
# If queue is full, oldest frame is automatically dropped (deque with maxlen)
if len(self.frame_queue) >= self.max_queue_size:
self.total_frames_dropped += 1

batch_frame = BatchFrame(
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
metadata=metadata or {},
)
self.frame_queue.append(batch_frame)

# Add to active buffer
if self.active_buffer == "A":
self.buffer_a.append(batch_frame)
self.buffer_a_state = BufferState.FILLING
buffer_size = len(self.buffer_a)
else:
self.buffer_b.append(batch_frame)
self.buffer_b_state = BufferState.FILLING
buffer_size = len(self.buffer_b)
def _batch_processor(self):
"""Background thread that continuously processes batches"""
logger.info(f"{self.__class__.__name__} batch processor started")

self.last_submit_time = time.time()

# Check if we should immediately swap (batch full)
if buffer_size >= self.batch_size:
self._try_swap_buffers()

def _timeout_monitor(self):
"""Monitor force-switch timeout"""
while self.running and not self.stop_event.wait(0.01):
with self.buffer_lock:
time_since_submit = time.time() - self.last_submit_time

if time_since_submit >= self.force_timeout:
active_buffer = (
self.buffer_a if self.active_buffer == "A" else self.buffer_b
)
if len(active_buffer) > 0:
self._try_swap_buffers()

def _try_swap_buffers(self):
"""Attempt to swap ping-pong buffers (called with buffer_lock held)"""
inactive_state = (
self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
)

if inactive_state != BufferState.PROCESSING:
old_active = self.active_buffer
self.active_buffer = "B" if old_active == "A" else "A"

if old_active == "A":
self.buffer_a_state = BufferState.PROCESSING
buffer_size = len(self.buffer_a)
else:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)

logger.debug(
f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
)

def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""
while self.running and not self.stop_event.is_set():
time.sleep(0.001)
# Check if we have enough frames for a batch
with self.queue_lock:
queue_size = len(self.frame_queue)

with self.buffer_lock:
if buffer_name == "A":
should_process = self.buffer_a_state == BufferState.PROCESSING
else:
should_process = self.buffer_b_state == BufferState.PROCESSING
if queue_size > 0 and queue_size % 10 == 0:
logger.info(f"Queue size: {queue_size}/{self.batch_size}")

if should_process:
self._process_buffer(buffer_name)
if queue_size >= self.batch_size:
# Extract batch
with self.queue_lock:
batch = []
for _ in range(min(self.batch_size, len(self.frame_queue))):
if self.frame_queue:
batch.append(self.frame_queue.popleft())

def _process_buffer(self, buffer_name: str):
"""Process a buffer through inference"""
# Extract buffer to process
with self.buffer_lock:
if buffer_name == "A":
batch = self.buffer_a.copy()
self.buffer_a.clear()
if batch:
logger.info(f"Processing batch of {len(batch)} frames")
self._process_batch(batch)
else:
batch = self.buffer_b.copy()
self.buffer_b.clear()
# Not enough frames, sleep briefly
time.sleep(0.001)  # 1ms

if len(batch) == 0:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
return

# Process batch (outside lock to allow concurrent submissions)
def _process_batch(self, batch: List[BatchFrame]):
"""Process a batch through inference"""
try:
start_time = time.time()
results = self._run_batch_inference(batch)
@ -257,8 +172,7 @@ class BaseModelController(ABC):
self.total_batches_processed += 1

logger.debug(
f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms"
)

# Emit results to callbacks

@ -276,49 +190,58 @@ class BaseModelController(ABC):
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)

finally:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE

@abstractmethod
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run inference on a batch of frames (backend-specific).

Args:
batch: List of BatchFrame objects

Returns:
List of detection results (one per frame)
"""
"""Run inference on a batch of frames (backend-specific)"""
pass

def _process_remaining_buffers(self):
"""Process any remaining frames in buffers during shutdown"""
if len(self.buffer_a) > 0:
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
self._process_buffer("A")
if len(self.buffer_b) > 0:
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
self._process_buffer("B")
def _process_remaining_frames(self):
"""Process any remaining frames in queue during shutdown"""
with self.queue_lock:
remaining = len(self.frame_queue)

if remaining > 0:
logger.info(f"Processing remaining {remaining} frames")
while True:
with self.queue_lock:
if not self.frame_queue:
break
batch = []
for _ in range(min(self.batch_size, len(self.frame_queue))):
if self.frame_queue:
batch.append(self.frame_queue.popleft())

if batch:
self._process_batch(batch)

def get_stats(self) -> Dict[str, Any]:
"""Get current buffer statistics"""
"""Get current statistics"""
with self.queue_lock:
queue_size = len(self.frame_queue)

return {
"active_buffer": self.active_buffer,
"buffer_a_size": len(self.buffer_a),
"buffer_b_size": len(self.buffer_b),
"buffer_a_state": self.buffer_a_state.value,
"buffer_b_state": self.buffer_b_state.value,
"queue_size": queue_size,
"max_queue_size": self.max_queue_size,
"batch_size": self.batch_size,
"registered_streams": len(self.result_callbacks),
"total_frames_processed": self.total_frames_processed,
"total_batches_processed": self.total_batches_processed,
"total_frames_dropped": self.total_frames_dropped,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
if self.total_batches_processed > 0
else 0
),
}


# Keep old BufferState enum for backwards compatibility
from enum import Enum


class BufferState(Enum):
"""Deprecated - kept for backwards compatibility"""

IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
@ -9,6 +9,7 @@ Provides a unified interface for different inference backends:
All engines support zero-copy GPU tensor inference where possible.
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum

@ -17,6 +18,8 @@ from typing import Any, Dict, List, Optional, Tuple

import torch

logger = logging.getLogger(__name__)


class BackendType(Enum):
"""Supported inference backend types"""

@ -423,9 +426,18 @@ class UltralyticsEngine(IInferenceEngine):
final_model_path = engine_path
print(f"Using TensorRT engine: {engine_path}")

# CRITICAL: Update _model_path to point to the .engine file for metadata extraction
self._model_path = engine_path

# Load model (Ultralytics handles .engine files natively)
self._model = YOLO(final_model_path)

logger.info(f"Loaded Ultralytics model: {type(self._model)}")
if hasattr(self._model, "predictor"):
logger.info(
f"Model has predictor: {type(self._model.predictor) if self._model.predictor else None}"
)

# Move to device if needed (only for .pt models, .engine already on specific device)
if hasattr(self._model, "model") and self._model.model is not None:
# Check if it's actually a torch model (not a string path for .engine files)

@ -437,6 +449,39 @@ class UltralyticsEngine(IInferenceEngine):

return self._metadata

def _read_batch_size_from_engine_file(self, engine_path: str) -> int:
"""
Read batch size from the metadata JSON file saved next to the engine.

Much simpler than parsing TensorRT engine!
"""
try:
import json
from pathlib import Path

# The metadata file is named: <engine_path_without_extension>_metadata.json
engine_file = Path(engine_path)
metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json")

print(f"[UltralyticsEngine] Looking for metadata file: {metadata_file}")

if metadata_file.exists():
with open(metadata_file, "r") as f:
metadata = json.load(f)
batch_size = metadata.get("batch", -1)
print(
f"[UltralyticsEngine] Found metadata: batch={batch_size}, imgsz={metadata.get('imgsz')}"
)
return batch_size
else:
print(f"[UltralyticsEngine] Metadata file not found: {metadata_file}")
except Exception as e:
print(
f"[UltralyticsEngine] Could not read batch size from metadata file: {e}"
)

return -1  # Default to dynamic

def _extract_metadata(self) -> EngineMetadata:
"""Extract metadata from Ultralytics model"""
# Ultralytics models typically expect (B, 3, H, W) input
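The method above only assumes a small JSON sidecar next to the .engine file carrying a "batch" key (and optionally "imgsz"). A hypothetical example of creating and reading such a file; the file name yolov8n.engine and the values are placeholders, not taken from this commit:

import json
from pathlib import Path

# Hypothetical sidecar: <engine stem>_metadata.json next to the .engine file.
engine_file = Path("yolov8n.engine")  # placeholder name
metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json")
metadata_file.write_text(json.dumps({"batch": 2, "imgsz": [640, 640]}))
print(json.loads(metadata_file.read_text()).get("batch", -1))  # -> 2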
@ -447,6 +492,17 @@ class UltralyticsEngine(IInferenceEngine):
imgsz = 640
input_shape = (batch_size, 3, imgsz, imgsz)

# CRITICAL: For .engine files, read batch size directly from the TensorRT engine file
print(f"[UltralyticsEngine] _model_path={self._model_path}")
if self._model_path.endswith(".engine"):
print(f"[UltralyticsEngine] Reading batch size from engine file...")
batch_size = self._read_batch_size_from_engine_file(self._model_path)
print(f"[UltralyticsEngine] Read batch_size={batch_size} from .engine file")
if batch_size > 0:
input_shape = (batch_size, 3, imgsz, imgsz)
else:
print(f"[UltralyticsEngine] Not an .engine file, skipping direct read")

if hasattr(self._model, "model") and self._model.model is not None:
# Try to get actual input shape from model
try:

@ -508,6 +564,10 @@ class UltralyticsEngine(IInferenceEngine):
logger.warning(f"Could not extract full metadata: {e}")
pass

logger.info(
f"Extracted Ultralytics metadata: batch_size={batch_size}, imgsz={imgsz}, input_shape={input_shape}"
)

return EngineMetadata(
engine_type="ultralytics",
model_path=self._model_path,

@ -42,6 +42,7 @@ class TrackingResult:
tracked_objects: List  # List of TrackedObject from TrackingController
detections: List  # Raw detections
frame_shape: Tuple[int, int, int]
frame_tensor: Optional[torch.Tensor]  # GPU tensor of the frame (C, H, W)
metadata: Dict

@ -158,6 +159,9 @@ class StreamConnection:

# Submit to model controller for batched inference
# Pass the FrameReference in metadata so we can free it later
logger.debug(
f"[{self.stream_id}] Submitting frame {self.frame_count} to model controller"
)
self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=cloned_tensor,  # Use cloned tensor, not original

@ -167,6 +171,9 @@ class StreamConnection:
"frame_ref": frame_ref,  # Store reference for later cleanup
},
)
logger.debug(
f"[{self.stream_id}] Frame {self.frame_count} submitted, queue size: {len(self.model_controller.frame_queue)}"
)

# Update connection status based on decoder status
if (

@ -211,6 +218,12 @@ class StreamConnection:
frame_shape = result["metadata"].get("shape")
tracked_objects = self._run_tracking_sync(detections, frame_shape)

# Get ORIGINAL frame tensor from metadata (not the preprocessed one in result["frame"])
# The frame in result["frame"] is preprocessed (resized, normalized)
# We need the original frame for visualization
frame_ref = result["metadata"].get("frame_ref")
frame_tensor = frame_ref.rgb_tensor if frame_ref else None

# Create tracking result
tracking_result = TrackingResult(
stream_id=self.stream_id,

@ -218,6 +231,7 @@ class StreamConnection:
tracked_objects=tracked_objects,
detections=detections,
frame_shape=result["metadata"].get("shape"),
frame_tensor=frame_tensor,  # Original frame, not preprocessed
metadata=result["metadata"],
)

@ -328,7 +342,7 @@ class StreamConnectionManager:
Args:
gpu_id: GPU device ID (default: 0)
batch_size: Maximum batch size for inference (default: 16)
force_timeout: Force buffer switch timeout in seconds (default: 0.05)
max_queue_size: Maximum frames in queue before dropping (default: 100)
poll_interval: Frame polling interval in seconds (default: 0.01)

Example:
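A hedged construction example matching the Args block above; the values are placeholders, the import location is not shown in this hunk, and the diff interleaves old and new parameters, so the exact surviving set should be checked against the signature that follows:

# Assumed usage sketch only; argument names follow the Args block above.
manager = StreamConnectionManager(
    gpu_id=0,
    batch_size=16,
    force_timeout=0.05,
    max_queue_size=100,
    poll_interval=0.01,
    enable_pt_conversion=True,
    backend="ultralytics",  # or "tensorrt"
)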
@ -343,14 +357,14 @@
self,
gpu_id: int = 0,
batch_size: int = 16,
force_timeout: float = 0.05,
max_queue_size: int = 100,
poll_interval: float = 0.01,
enable_pt_conversion: bool = True,
backend: str = "tensorrt",  # "tensorrt" or "ultralytics"
):
self.gpu_id = gpu_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.max_queue_size = max_queue_size
self.poll_interval = poll_interval
self.backend = backend.lower()

@ -449,7 +463,7 @@ class StreamConnectionManager:
inference_engine=self.inference_engine,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
max_queue_size=self.max_queue_size,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)

@ -473,7 +487,7 @@ class StreamConnectionManager:
model_repository=self.model_repository,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
max_queue_size=self.max_queue_size,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)

@ -656,7 +670,7 @@ class StreamConnectionManager:
"gpu_id": self.gpu_id,
"num_connections": len(self.connections),
"batch_size": self.batch_size,
"force_timeout": self.force_timeout,
"max_queue_size": self.max_queue_size,
"poll_interval": self.poll_interval,
},
"model_controller": self.model_controller.get_stats()

@ -25,14 +25,14 @@ class TensorRTModelController(BaseModelController):
model_repository,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
max_queue_size: int = 100,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
super().__init__(
model_id=model_id,
batch_size=batch_size,
force_timeout=force_timeout,
max_queue_size=max_queue_size,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)

@ -115,6 +115,7 @@ class TensorRTModelController(BaseModelController):
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"frame": batch_frame.frame,  # Include original frame tensor
"metadata": batch_frame.metadata,
}
results.append(result)

@ -175,6 +176,7 @@ class TensorRTModelController(BaseModelController):
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"frame": batch_frame.frame,  # Include original frame tensor
"metadata": batch_frame.metadata,
}
results.append(result)

@ -25,20 +25,27 @@ class UltralyticsModelController(BaseModelController):
inference_engine,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
max_queue_size: int = 100,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
# Auto-detect actual batch size from the YOLO engine
print(f"[UltralyticsModelController] Detecting batch size from engine...")
engine_batch_size = self._detect_engine_batch_size(inference_engine)
print(
f"[UltralyticsModelController] Detected engine_batch_size={engine_batch_size}"
)

# If engine has fixed batch size, use it. Otherwise use user's batch_size
actual_batch_size = engine_batch_size if engine_batch_size > 0 else batch_size
print(
f"[UltralyticsModelController] Using actual_batch_size={actual_batch_size}"
)

super().__init__(
model_id=model_id,
batch_size=actual_batch_size,
force_timeout=force_timeout,
max_queue_size=max_queue_size,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)

@ -46,11 +53,23 @@ class UltralyticsModelController(BaseModelController):
self.engine_batch_size = engine_batch_size  # Store for padding logic

if engine_batch_size > 0:
print(f"✓ Ultralytics engine has FIXED batch_size={engine_batch_size}")
print(
f" Will pad/truncate all batches to exactly {engine_batch_size} frames"
)
logger.info(
f"Ultralytics engine has fixed batch_size={engine_batch_size}, "
f"will pad batches to match"
)
# CRITICAL: Override the parent's batch_size to match engine's fixed size
# This prevents buffer accumulation beyond the engine's capacity
self.batch_size = engine_batch_size
print(f" Controller self.batch_size is now: {self.batch_size}")
print(f" Buffer will swap when size >= {self.batch_size}")
else:
print(
f"✓ Ultralytics engine supports DYNAMIC batching, max={actual_batch_size}"
)
logger.info(
f"Ultralytics engine supports dynamic batching, "
f"using max batch_size={actual_batch_size}"
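The fixed-batch branch above promises to pad or truncate every batch to exactly engine_batch_size frames. A minimal sketch of that idea; repeating the last frame as filler is an assumption, since the padding strategy itself is not shown in this hunk:

def pad_to_engine_batch(frames, engine_batch_size):
    # Sketch only: truncate when over, repeat the last frame when under (assumed filler).
    if len(frames) >= engine_batch_size:
        return frames[:engine_batch_size]
    return frames + [frames[-1]] * (engine_batch_size - len(frames))

print(len(pad_to_engine_batch([1, 2, 3], 4)))  # -> 4
print(len(pad_to_engine_batch([1, 2, 3], 2)))  # -> 2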
@ -67,16 +86,22 @@ class UltralyticsModelController(BaseModelController):
# Get engine metadata
metadata = inference_engine.get_metadata()

logger.info(f"Detecting batch size from engine metadata: {metadata}")

# Check input shape for batch dimension
if "images" in metadata.input_shapes:
input_shape = metadata.input_shapes["images"]
batch_dim = input_shape[0]

logger.info(f"Found batch dimension in metadata: {batch_dim}")

if batch_dim > 0:
# Fixed batch size
logger.info(f"Using fixed batch size from engine: {batch_dim}")
return batch_dim
else:
# Dynamic batch size (-1)
logger.info("Engine supports dynamic batching (batch_dim=-1)")
return -1

# Fallback: try to get from model directly

@ -187,28 +212,16 @@ class UltralyticsModelController(BaseModelController):
# No detections
detections = torch.zeros((0, 6), device=batch_tensor.device)

# Apply custom postprocessing if provided
if self.postprocess_fn:
try:
# For Ultralytics, postprocess_fn might do additional filtering
# Pass the raw boxes tensor in the same format as TensorRT output
detections = self.postprocess_fn(
{
"output0": detections.unsqueeze(
0
)  # Add batch dim for compatibility
}
)
except Exception as e:
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
detections = torch.zeros((0, 6), device=batch_tensor.device)
# NOTE: Skip postprocess_fn for Ultralytics backend!
# Ultralytics already does confidence filtering, NMS, and format conversion.
# The detections are already in final format: [x1, y1, x2, y2, conf, cls]
# Any custom postprocess_fn would expect raw TensorRT output and will fail.

result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"frame": batch_frame.frame,  # Include original frame tensor
"metadata": batch_frame.metadata,
"yolo_result": yolo_result,  # Keep original Results object for debugging
}

@ -4,11 +4,11 @@ Real-time object tracking with event-driven batching architecture.
This script demonstrates:
- Event-driven stream processing with StreamConnectionManager
- Batched GPU inference with ModelController
- Ping-pong buffer architecture for optimal throughput
- Callback-based event-driven pattern for RTSP streams
- Automatic PT to TensorRT conversion
"""

import logging
import os
import threading
import time

@ -28,6 +28,11 @@ from services import (
# Load environment variables
load_dotenv()

# Enable debug logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)


def main_multi_stream():
"""Multi-stream example with batched inference."""

@ -41,8 +46,8 @@ def main_multi_stream():
USE_ULTRALYTICS = (
os.getenv("USE_ULTRALYTICS", "true").lower() == "true"
)  # Use Ultralytics engine for YOLO
BATCH_SIZE = 2  # Reduced to 2 to avoid GPU memory issues
FORCE_TIMEOUT = 0.05
BATCH_SIZE = 2  # Must match engine's fixed batch size
MAX_QUEUE_SIZE = 50  # Drop frames if queue gets too long
ENABLE_DISPLAY = os.getenv("ENABLE_DISPLAY", "true").lower() == "true"

# Load camera URLs

@ -73,10 +78,11 @@ def main_multi_stream():
manager = StreamConnectionManager(
gpu_id=GPU_ID,
batch_size=BATCH_SIZE,
force_timeout=FORCE_TIMEOUT,
max_queue_size=MAX_QUEUE_SIZE,
enable_pt_conversion=True,
backend=backend,
)

print("✓ Manager created")

# Initialize model (transparent loading)

@ -86,7 +92,6 @@ def main_multi_stream():
model_path=MODEL_PATH,
model_id="detector",
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
num_contexts=1,  # Single context to minimize GPU memory usage
# Note: No pt_input_shapes or pt_precision needed for YOLO models!
)

@ -98,6 +103,109 @@ def main_multi_stream():
traceback.print_exc()
return

# Track stats (initialize before callback definition)
stream_stats = {sid: {"count": 0, "start": time.time()} for sid, _ in camera_urls}
total_results = 0
start_time = time.time()
stats_lock = threading.Lock()

# Create windows for each stream if display enabled
if ENABLE_DISPLAY:
for stream_id, _ in camera_urls:
cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
cv2.resizeWindow(
stream_id, 640, 360
)  # Smaller windows for multiple streams

def on_tracking_result(result):
"""Callback for tracking results - called automatically per stream"""
nonlocal total_results

# Debug: Check if we have frame tensor
has_frame = result.frame_tensor is not None
frame_shape = result.frame_tensor.shape if has_frame else None
print(
f"[CALLBACK] Got result for {result.stream_id}, has_frame={has_frame}, shape={frame_shape}, detections={len(result.detections)}"
)

with stats_lock:
total_results += 1
stream_id = result.stream_id
if stream_id in stream_stats:
stream_stats[stream_id]["count"] += 1

# Print stats every 10 results (changed from 100 for faster feedback)
if total_results % 10 == 0:
elapsed = time.time() - start_time
total_fps = total_results / elapsed if elapsed > 0 else 0
print(
f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
)
for sid, stats in stream_stats.items():
s_elapsed = time.time() - stats["start"]
s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
print(f" {sid}: {stats['count']} ({s_fps:.1f} FPS)")

# Display visualization if enabled
if ENABLE_DISPLAY and result.frame_tensor is not None:
# Convert GPU tensor (C, H, W) to CPU numpy (H, W, C) for OpenCV
frame_tensor = result.frame_tensor  # (3, 720, 1280) RGB uint8
frame_np = (
frame_tensor.cpu().permute(1, 2, 0).numpy().astype(np.uint8)
)  # (720, 1280, 3)
frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)

# Draw bounding boxes
for obj in result.tracked_objects:
x1, y1, x2, y2 = map(int, obj.bbox)

# Draw box
cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)

# Draw label with ID and class
label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
(label_w, label_h), _ = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
cv2.rectangle(
frame_bgr,
(x1, y1 - label_h - 10),
(x1 + label_w, y1),
(0, 255, 0),
-1,
)
cv2.putText(
frame_bgr,
label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 0, 0),
1,
)

# Show FPS on frame
with stats_lock:
s_elapsed = time.time() - stream_stats[stream_id]["start"]
s_fps = (
stream_stats[stream_id]["count"] / s_elapsed if s_elapsed > 0 else 0
)
fps_text = (
f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
)
cv2.putText(
frame_bgr,
fps_text,
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
(0, 255, 0),
2,
)

# Display
cv2.imshow(stream_id, frame_bgr)

# Connect all streams in parallel using threads
print(f"\n[3/3] Connecting {len(camera_urls)} streams in parallel...")
connections = {}

@ -108,7 +216,10 @@ def main_multi_stream():
"""Thread worker to connect a single stream"""
try:
conn = manager.connect_stream(
rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=3
rtsp_url=rtsp_url,
stream_id=stream_id,
buffer_size=2,
on_tracking_result=on_tracking_result,  # Register callback
)
connection_results[stream_id] = ("success", conn)
except Exception as e:

@ -144,124 +255,15 @@ def main_multi_stream():
print("Press Ctrl+C to stop")
print(f"{'=' * 80}\n")

# Track stats
stream_stats = {
sid: {"count": 0, "start": time.time()} for sid in connections.keys()
}
total_results = 0
start_time = time.time()

# Create windows for each stream if display enabled
if ENABLE_DISPLAY:
for stream_id in connections.keys():
cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
cv2.resizeWindow(
stream_id, 640, 360
)  # Smaller windows for multiple streams

try:
# Merge all result queues from all connections
import queue as queue_module

running = True
while running:
# Poll all connection queues (non-blocking)
got_result = False
for conn in connections.values():
try:
# Non-blocking get from each connection's queue
result = conn.result_queue.get_nowait()
got_result = True

total_results += 1
stream_id = result.stream_id

if stream_id in stream_stats:
stream_stats[stream_id]["count"] += 1

# Display visualization if enabled
if ENABLE_DISPLAY:
# Get latest frame from decoder (already in CPU memory as numpy RGB)
frame_rgb = conn.decoder.get_latest_frame_cpu(rgb=True)
if frame_rgb is not None:
# Convert RGB to BGR for OpenCV
frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)

# Draw bounding boxes
for obj in result.tracked_objects:
x1, y1, x2, y2 = map(int, obj.bbox)

# Draw box
cv2.rectangle(
frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2
)

# Draw label with ID and class
label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
(label_w, label_h), _ = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
cv2.rectangle(
frame_bgr,
(x1, y1 - label_h - 10),
(x1 + label_w, y1),
(0, 255, 0),
-1,
)
cv2.putText(
frame_bgr,
label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 0, 0),
1,
)

# Show FPS on frame
s_elapsed = time.time() - stream_stats[stream_id]["start"]
s_fps = (
stream_stats[stream_id]["count"] / s_elapsed
if s_elapsed > 0
else 0
)
fps_text = f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
cv2.putText(
frame_bgr,
fps_text,
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
(0, 255, 0),
2,
)

# Display
cv2.imshow(stream_id, frame_bgr)

# Print stats every 100 results
if total_results % 100 == 0:
elapsed = time.time() - start_time
total_fps = total_results / elapsed if elapsed > 0 else 0

print(
f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
)
for sid, stats in stream_stats.items():
s_elapsed = time.time() - stats["start"]
s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
print(f" {sid}: {stats['count']} ({s_fps:.1f} FPS)")

except queue_module.Empty:
continue

# Process OpenCV events to keep windows responsive
# Keep main thread alive and process OpenCV events
while True:
if ENABLE_DISPLAY:
cv2.waitKey(1)

# Small sleep if no results to avoid busy loop
if not got_result:
time.sleep(0.01)
# Process OpenCV events to keep windows responsive
if cv2.waitKey(1) & 0xFF == ord("q"):
break
else:
time.sleep(0.1)

except KeyboardInterrupt:
print(f"\n✓ Interrupted")