new buffer paradigm

parent fdaeb9981c
commit a519dea130

6 changed files with 341 additions and 327 deletions
@@ -1,16 +1,19 @@
 """
-Base Model Controller - Abstract base class for batched inference controllers.
+Base Model Controller - Simple circular buffer with continuous batch processing.
 
-Provides ping-pong buffer architecture with force-switch timeout mechanism.
-Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
+Replaces the complex ping-pong buffer architecture with a simple queue:
+- Frames arrive and go into a single deque (circular buffer)
+- When batch_size frames are ready, process them
+- Continue consuming batches until queue is empty
+- Drop oldest frames if queue is full
 """
 
 import logging
 import threading
 import time
 from abc import ABC, abstractmethod
+from collections import deque
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
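As a side note (not part of the commit): the "drop oldest frames if queue is full" policy described in the new docstring is exactly the built-in behaviour of `collections.deque` with a `maxlen`. A minimal sketch:

from collections import deque

# A deque created with maxlen silently discards the oldest item when full,
# which is the "drop oldest frames if queue is full" policy described above.
frame_queue = deque(maxlen=3)

for frame_id in range(5):
    if len(frame_queue) >= 3:
        # The controller only counts the drop; the deque does the dropping.
        print(f"queue full, frame {frame_queue[0]} will be dropped")
    frame_queue.append(frame_id)

print(list(frame_queue))  # [2, 3, 4] -- frames 0 and 1 were dropped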
@@ -28,62 +31,37 @@ class BatchFrame:
     metadata: Dict = field(default_factory=dict)
 
 
-class BufferState(Enum):
-    """State of a ping-pong buffer"""
-
-    IDLE = "idle"
-    FILLING = "filling"
-    PROCESSING = "processing"
-
-
 class BaseModelController(ABC):
     """
-    Abstract base class for batched inference with ping-pong buffers.
+    Simple batched inference with circular buffer.
 
-    This controller accumulates frames from multiple streams into batches,
-    processes them through an inference backend, and routes results back to
-    stream-specific callbacks.
-
-    Features:
-    - Ping-pong circular buffers (BufferA/BufferB)
-    - Force-switch timeout to prevent batch starvation
-    - Event-driven processing with callbacks
-    - Thread-safe frame submission
-
-    Subclasses must implement:
-    - _run_batch_inference(): Backend-specific inference logic
+    Architecture:
+    - Single deque (circular buffer) for incoming frames
+    - Batch processor thread continuously consumes batches
+    - Frames come in fast, batches go out as fast as inference allows
+    - Automatic frame dropping when queue is full
     """
 
     def __init__(
         self,
         model_id: str,
         batch_size: int = 16,
-        force_timeout: float = 0.05,
+        max_queue_size: int = 100,
         preprocess_fn: Optional[Callable] = None,
         postprocess_fn: Optional[Callable] = None,
     ):
         self.model_id = model_id
         self.batch_size = batch_size
-        self.force_timeout = force_timeout
+        self.max_queue_size = max_queue_size
         self.preprocess_fn = preprocess_fn
         self.postprocess_fn = postprocess_fn
 
-        # Ping-pong buffers
-        self.buffer_a: List[BatchFrame] = []
-        self.buffer_b: List[BatchFrame] = []
+        # Single circular buffer
+        self.frame_queue = deque(maxlen=max_queue_size)
+        self.queue_lock = threading.Lock()
 
-        # Buffer states
-        self.active_buffer = "A"
-        self.buffer_a_state = BufferState.IDLE
-        self.buffer_b_state = BufferState.IDLE
-
-        # Threading coordination
-        self.buffer_lock = threading.RLock()
-        self.last_submit_time = time.time()
-
-        # Threads
-        self.timeout_thread: Optional[threading.Thread] = None
-        self.processor_threads: Dict[str, threading.Thread] = {}
+        # Processing thread
+        self.processor_thread: Optional[threading.Thread] = None
         self.running = False
         self.stop_event = threading.Event()
 
@@ -93,33 +71,24 @@ class BaseModelController(ABC):
         # Statistics
         self.total_frames_processed = 0
         self.total_batches_processed = 0
+        self.total_frames_dropped = 0
 
     def start(self):
-        """Start the controller background threads"""
+        """Start the controller background thread"""
         if self.running:
-            logger.warning("ModelController already running")
+            logger.warning(f"{self.__class__.__name__} already running")
             return
 
         self.running = True
         self.stop_event.clear()
 
-        # Start timeout monitor thread
-        self.timeout_thread = threading.Thread(
-            target=self._timeout_monitor, daemon=True
+        # Start single processor thread
+        self.processor_thread = threading.Thread(
+            target=self._batch_processor, daemon=True
         )
-        self.timeout_thread.start()
+        self.processor_thread.start()
 
-        # Start processor threads for each buffer
-        self.processor_threads["A"] = threading.Thread(
-            target=self._batch_processor, args=("A",), daemon=True
-        )
-        self.processor_threads["B"] = threading.Thread(
-            target=self._batch_processor, args=("B",), daemon=True
-        )
-        self.processor_threads["A"].start()
-        self.processor_threads["B"].start()
-
-        logger.info(f"{self.__class__.__name__} started")
+        logger.info(f"{self.__class__.__name__} started (batch_size={self.batch_size})")
 
     def stop(self):
         """Stop the controller and cleanup"""
@@ -130,16 +99,12 @@ class BaseModelController(ABC):
         self.running = False
         self.stop_event.set()
 
-        # Wait for threads to finish
-        if self.timeout_thread and self.timeout_thread.is_alive():
-            self.timeout_thread.join(timeout=2.0)
-
-        for thread in self.processor_threads.values():
-            if thread and thread.is_alive():
-                thread.join(timeout=2.0)
-
-        # Process any remaining frames
-        self._process_remaining_buffers()
+        # Wait for thread to finish
+        if self.processor_thread and self.processor_thread.is_alive():
+            self.processor_thread.join(timeout=2.0)
+
+        # Process remaining frames
+        self._process_remaining_frames()
 
         logger.info(f"{self.__class__.__name__} stopped")
 
     def register_callback(self, stream_id: str, callback: Callable):
@@ -156,98 +121,48 @@ class BaseModelController(ABC):
         self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
     ):
         """Submit a frame for batched inference"""
-        with self.buffer_lock:
+        with self.queue_lock:
+            # If queue is full, oldest frame is automatically dropped (deque with maxlen)
+            if len(self.frame_queue) >= self.max_queue_size:
+                self.total_frames_dropped += 1
 
             batch_frame = BatchFrame(
                 stream_id=stream_id,
                 frame=frame,
                 timestamp=time.time(),
                 metadata=metadata or {},
             )
+            self.frame_queue.append(batch_frame)
 
-            # Add to active buffer
-            if self.active_buffer == "A":
-                self.buffer_a.append(batch_frame)
-                self.buffer_a_state = BufferState.FILLING
-                buffer_size = len(self.buffer_a)
-            else:
-                self.buffer_b.append(batch_frame)
-                self.buffer_b_state = BufferState.FILLING
-                buffer_size = len(self.buffer_b)
-
-            self.last_submit_time = time.time()
-
-            # Check if we should immediately swap (batch full)
-            if buffer_size >= self.batch_size:
-                self._try_swap_buffers()
-
-    def _timeout_monitor(self):
-        """Monitor force-switch timeout"""
-        while self.running and not self.stop_event.wait(0.01):
-            with self.buffer_lock:
-                time_since_submit = time.time() - self.last_submit_time
-
-                if time_since_submit >= self.force_timeout:
-                    active_buffer = (
-                        self.buffer_a if self.active_buffer == "A" else self.buffer_b
-                    )
-                    if len(active_buffer) > 0:
-                        self._try_swap_buffers()
-
-    def _try_swap_buffers(self):
-        """Attempt to swap ping-pong buffers (called with buffer_lock held)"""
-        inactive_state = (
-            self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
-        )
-
-        if inactive_state != BufferState.PROCESSING:
-            old_active = self.active_buffer
-            self.active_buffer = "B" if old_active == "A" else "A"
-
-            if old_active == "A":
-                self.buffer_a_state = BufferState.PROCESSING
-                buffer_size = len(self.buffer_a)
-            else:
-                self.buffer_b_state = BufferState.PROCESSING
-                buffer_size = len(self.buffer_b)
-
-            logger.debug(
-                f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
-            )
-
-    def _batch_processor(self, buffer_name: str):
-        """Background thread that processes a specific buffer when available"""
+    def _batch_processor(self):
+        """Background thread that continuously processes batches"""
+        logger.info(f"{self.__class__.__name__} batch processor started")
         while self.running and not self.stop_event.is_set():
-            time.sleep(0.001)
-
-            with self.buffer_lock:
-                if buffer_name == "A":
-                    should_process = self.buffer_a_state == BufferState.PROCESSING
+            # Check if we have enough frames for a batch
+            with self.queue_lock:
+                queue_size = len(self.frame_queue)
+
+            if queue_size > 0 and queue_size % 10 == 0:
+                logger.info(f"Queue size: {queue_size}/{self.batch_size}")
+
+            if queue_size >= self.batch_size:
+                # Extract batch
+                with self.queue_lock:
+                    batch = []
+                    for _ in range(min(self.batch_size, len(self.frame_queue))):
+                        if self.frame_queue:
+                            batch.append(self.frame_queue.popleft())
+
+                if batch:
+                    logger.info(f"Processing batch of {len(batch)} frames")
+                    self._process_batch(batch)
             else:
-                should_process = self.buffer_b_state == BufferState.PROCESSING
-
-            if should_process:
-                self._process_buffer(buffer_name)
-
-    def _process_buffer(self, buffer_name: str):
-        """Process a buffer through inference"""
-        # Extract buffer to process
-        with self.buffer_lock:
-            if buffer_name == "A":
-                batch = self.buffer_a.copy()
-                self.buffer_a.clear()
-            else:
-                batch = self.buffer_b.copy()
-                self.buffer_b.clear()
-
-        if len(batch) == 0:
-            with self.buffer_lock:
-                if buffer_name == "A":
-                    self.buffer_a_state = BufferState.IDLE
-                else:
-                    self.buffer_b_state = BufferState.IDLE
-            return
-
-        # Process batch (outside lock to allow concurrent submissions)
+                # Not enough frames, sleep briefly
+                time.sleep(0.001)  # 1ms
+
+    def _process_batch(self, batch: List[BatchFrame]):
+        """Process a batch through inference"""
         try:
             start_time = time.time()
             results = self._run_batch_inference(batch)
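For reference, the drain step introduced above in isolation — a hedged, self-contained sketch (not from the repository) of pulling at most `batch_size` items out of a locked deque, the same pattern `_batch_processor` uses:

import threading
from collections import deque

queue_lock = threading.Lock()
frame_queue = deque(range(10), maxlen=100)  # stand-in for queued BatchFrames
batch_size = 4

def drain_one_batch():
    """Pop up to batch_size items while holding the lock; process them outside it."""
    with queue_lock:
        batch = []
        for _ in range(min(batch_size, len(frame_queue))):
            if frame_queue:
                batch.append(frame_queue.popleft())
    return batch

print(drain_one_batch())  # [0, 1, 2, 3]
print(drain_one_batch())  # [4, 5, 6, 7]
print(drain_one_batch())  # [8, 9] -- partial batch once the queue runs low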
@@ -258,7 +173,6 @@ class BaseModelController(ABC):
 
             logger.debug(
                 f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms"
-                f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
             )
 
             # Emit results to callbacks
@@ -276,49 +190,58 @@ class BaseModelController(ABC):
         except Exception as e:
             logger.error(f"Error processing batch: {e}", exc_info=True)
 
-        finally:
-            with self.buffer_lock:
-                if buffer_name == "A":
-                    self.buffer_a_state = BufferState.IDLE
-                else:
-                    self.buffer_b_state = BufferState.IDLE
-
     @abstractmethod
     def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
-        """
-        Run inference on a batch of frames (backend-specific).
-
-        Args:
-            batch: List of BatchFrame objects
-
-        Returns:
-            List of detection results (one per frame)
-        """
+        """Run inference on a batch of frames (backend-specific)"""
         pass
 
-    def _process_remaining_buffers(self):
-        """Process any remaining frames in buffers during shutdown"""
-        if len(self.buffer_a) > 0:
-            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
-            self._process_buffer("A")
-        if len(self.buffer_b) > 0:
-            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
-            self._process_buffer("B")
+    def _process_remaining_frames(self):
+        """Process any remaining frames in queue during shutdown"""
+        with self.queue_lock:
+            remaining = len(self.frame_queue)
+
+        if remaining > 0:
+            logger.info(f"Processing remaining {remaining} frames")
+            while True:
+                with self.queue_lock:
+                    if not self.frame_queue:
+                        break
+                    batch = []
+                    for _ in range(min(self.batch_size, len(self.frame_queue))):
+                        if self.frame_queue:
+                            batch.append(self.frame_queue.popleft())
+
+                if batch:
+                    self._process_batch(batch)
 
     def get_stats(self) -> Dict[str, Any]:
-        """Get current buffer statistics"""
+        """Get current statistics"""
+        with self.queue_lock:
+            queue_size = len(self.frame_queue)
+
         return {
-            "active_buffer": self.active_buffer,
-            "buffer_a_size": len(self.buffer_a),
-            "buffer_b_size": len(self.buffer_b),
-            "buffer_a_state": self.buffer_a_state.value,
-            "buffer_b_state": self.buffer_b_state.value,
+            "queue_size": queue_size,
+            "max_queue_size": self.max_queue_size,
+            "batch_size": self.batch_size,
             "registered_streams": len(self.result_callbacks),
             "total_frames_processed": self.total_frames_processed,
             "total_batches_processed": self.total_batches_processed,
+            "total_frames_dropped": self.total_frames_dropped,
             "avg_batch_size": (
                 self.total_frames_processed / self.total_batches_processed
                 if self.total_batches_processed > 0
                 else 0
             ),
         }
+
+
+# Keep old BufferState enum for backwards compatibility
+from enum import Enum
+
+
+class BufferState(Enum):
+    """Deprecated - kept for backwards compatibility"""
+
+    IDLE = "idle"
+    FILLING = "filling"
+    PROCESSING = "processing"
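A hedged sketch of what a backend now has to provide — a hypothetical subclass (`DummyModelController`, not in this commit, assumed to live alongside `BaseModelController` and `BatchFrame`) that returns one result dict per frame, mirroring the keys the real controllers emit:

from typing import Any, Dict, List
import torch


class DummyModelController(BaseModelController):
    """Hypothetical backend used only to illustrate the abstract contract."""

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        results = []
        for batch_frame in batch:
            results.append(
                {
                    "stream_id": batch_frame.stream_id,
                    "timestamp": batch_frame.timestamp,
                    # (0, 6) tensor: no detections, format [x1, y1, x2, y2, conf, cls]
                    "detections": torch.zeros((0, 6)),
                    "frame": batch_frame.frame,
                    "metadata": batch_frame.metadata,
                }
            )
        return results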
@@ -9,6 +9,7 @@ Provides a unified interface for different inference backends:
 All engines support zero-copy GPU tensor inference where possible.
 """
 
+import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
@@ -17,6 +18,8 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 
+logger = logging.getLogger(__name__)
+
 
 class BackendType(Enum):
     """Supported inference backend types"""
@@ -423,9 +426,18 @@ class UltralyticsEngine(IInferenceEngine):
                 final_model_path = engine_path
                 print(f"Using TensorRT engine: {engine_path}")
 
+                # CRITICAL: Update _model_path to point to the .engine file for metadata extraction
+                self._model_path = engine_path
+
         # Load model (Ultralytics handles .engine files natively)
         self._model = YOLO(final_model_path)
 
+        logger.info(f"Loaded Ultralytics model: {type(self._model)}")
+        if hasattr(self._model, "predictor"):
+            logger.info(
+                f"Model has predictor: {type(self._model.predictor) if self._model.predictor else None}"
+            )
+
         # Move to device if needed (only for .pt models, .engine already on specific device)
         if hasattr(self._model, "model") and self._model.model is not None:
             # Check if it's actually a torch model (not a string path for .engine files)
@@ -437,6 +449,39 @@ class UltralyticsEngine(IInferenceEngine):
 
         return self._metadata
 
+    def _read_batch_size_from_engine_file(self, engine_path: str) -> int:
+        """
+        Read batch size from the metadata JSON file saved next to the engine.
+
+        Much simpler than parsing TensorRT engine!
+        """
+        try:
+            import json
+            from pathlib import Path
+
+            # The metadata file is named: <engine_path_without_extension>_metadata.json
+            engine_file = Path(engine_path)
+            metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json")
+
+            print(f"[UltralyticsEngine] Looking for metadata file: {metadata_file}")
+
+            if metadata_file.exists():
+                with open(metadata_file, "r") as f:
+                    metadata = json.load(f)
+                batch_size = metadata.get("batch", -1)
+                print(
+                    f"[UltralyticsEngine] Found metadata: batch={batch_size}, imgsz={metadata.get('imgsz')}"
+                )
+                return batch_size
+            else:
+                print(f"[UltralyticsEngine] Metadata file not found: {metadata_file}")
+        except Exception as e:
+            print(
+                f"[UltralyticsEngine] Could not read batch size from metadata file: {e}"
+            )
+
+        return -1  # Default to dynamic
+
     def _extract_metadata(self) -> EngineMetadata:
         """Extract metadata from Ultralytics model"""
         # Ultralytics models typically expect (B, 3, H, W) input
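A small sketch of how the sidecar lookup above resolves — the path and JSON contents here are made up purely for illustration; only the `batch` key is actually required by the reader:

import json
from pathlib import Path

engine_file = Path("models/yolov8n_batch4.engine")  # hypothetical path
metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json")
print(metadata_file)  # models/yolov8n_batch4_metadata.json

# Hypothetical sidecar contents: {"batch": 4, "imgsz": [640, 640]}
if metadata_file.exists():
    batch = json.loads(metadata_file.read_text()).get("batch", -1)
else:
    batch = -1  # fall back to dynamic batching
print(batch)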
@@ -447,6 +492,17 @@ class UltralyticsEngine(IInferenceEngine):
         imgsz = 640
         input_shape = (batch_size, 3, imgsz, imgsz)
 
+        # CRITICAL: For .engine files, read batch size directly from the TensorRT engine file
+        print(f"[UltralyticsEngine] _model_path={self._model_path}")
+        if self._model_path.endswith(".engine"):
+            print(f"[UltralyticsEngine] Reading batch size from engine file...")
+            batch_size = self._read_batch_size_from_engine_file(self._model_path)
+            print(f"[UltralyticsEngine] Read batch_size={batch_size} from .engine file")
+            if batch_size > 0:
+                input_shape = (batch_size, 3, imgsz, imgsz)
+        else:
+            print(f"[UltralyticsEngine] Not an .engine file, skipping direct read")
+
         if hasattr(self._model, "model") and self._model.model is not None:
             # Try to get actual input shape from model
             try:
@@ -508,6 +564,10 @@ class UltralyticsEngine(IInferenceEngine):
             logger.warning(f"Could not extract full metadata: {e}")
             pass
 
+        logger.info(
+            f"Extracted Ultralytics metadata: batch_size={batch_size}, imgsz={imgsz}, input_shape={input_shape}"
+        )
+
         return EngineMetadata(
             engine_type="ultralytics",
             model_path=self._model_path,
@@ -42,6 +42,7 @@ class TrackingResult:
     tracked_objects: List  # List of TrackedObject from TrackingController
     detections: List  # Raw detections
     frame_shape: Tuple[int, int, int]
+    frame_tensor: Optional[torch.Tensor]  # GPU tensor of the frame (C, H, W)
     metadata: Dict
 
 
@@ -158,6 +159,9 @@ class StreamConnection:
 
         # Submit to model controller for batched inference
         # Pass the FrameReference in metadata so we can free it later
+        logger.debug(
+            f"[{self.stream_id}] Submitting frame {self.frame_count} to model controller"
+        )
         self.model_controller.submit_frame(
             stream_id=self.stream_id,
             frame=cloned_tensor,  # Use cloned tensor, not original
@@ -167,6 +171,9 @@ class StreamConnection:
                 "frame_ref": frame_ref,  # Store reference for later cleanup
             },
         )
+        logger.debug(
+            f"[{self.stream_id}] Frame {self.frame_count} submitted, queue size: {len(self.model_controller.frame_queue)}"
+        )
 
         # Update connection status based on decoder status
         if (
@@ -211,6 +218,12 @@ class StreamConnection:
         frame_shape = result["metadata"].get("shape")
         tracked_objects = self._run_tracking_sync(detections, frame_shape)
 
+        # Get ORIGINAL frame tensor from metadata (not the preprocessed one in result["frame"])
+        # The frame in result["frame"] is preprocessed (resized, normalized)
+        # We need the original frame for visualization
+        frame_ref = result["metadata"].get("frame_ref")
+        frame_tensor = frame_ref.rgb_tensor if frame_ref else None
+
         # Create tracking result
         tracking_result = TrackingResult(
             stream_id=self.stream_id,
@@ -218,6 +231,7 @@ class StreamConnection:
             tracked_objects=tracked_objects,
             detections=detections,
             frame_shape=result["metadata"].get("shape"),
+            frame_tensor=frame_tensor,  # Original frame, not preprocessed
             metadata=result["metadata"],
         )
 
@@ -328,7 +342,7 @@ class StreamConnectionManager:
     Args:
         gpu_id: GPU device ID (default: 0)
         batch_size: Maximum batch size for inference (default: 16)
-        force_timeout: Force buffer switch timeout in seconds (default: 0.05)
+        max_queue_size: Maximum frames in queue before dropping (default: 100)
         poll_interval: Frame polling interval in seconds (default: 0.01)
 
     Example:
@@ -343,14 +357,14 @@ class StreamConnectionManager:
         self,
         gpu_id: int = 0,
         batch_size: int = 16,
-        force_timeout: float = 0.05,
+        max_queue_size: int = 100,
         poll_interval: float = 0.01,
         enable_pt_conversion: bool = True,
         backend: str = "tensorrt",  # "tensorrt" or "ultralytics"
     ):
         self.gpu_id = gpu_id
         self.batch_size = batch_size
-        self.force_timeout = force_timeout
+        self.max_queue_size = max_queue_size
         self.poll_interval = poll_interval
         self.backend = backend.lower()
 
@@ -449,7 +463,7 @@ class StreamConnectionManager:
                 inference_engine=self.inference_engine,
                 model_id=model_id,
                 batch_size=self.batch_size,
-                force_timeout=self.force_timeout,
+                max_queue_size=self.max_queue_size,
                 preprocess_fn=preprocess_fn,
                 postprocess_fn=postprocess_fn,
             )
@@ -473,7 +487,7 @@ class StreamConnectionManager:
                 model_repository=self.model_repository,
                 model_id=model_id,
                 batch_size=self.batch_size,
-                force_timeout=self.force_timeout,
+                max_queue_size=self.max_queue_size,
                 preprocess_fn=preprocess_fn,
                 postprocess_fn=postprocess_fn,
             )
@@ -656,7 +670,7 @@ class StreamConnectionManager:
                 "gpu_id": self.gpu_id,
                 "num_connections": len(self.connections),
                 "batch_size": self.batch_size,
-                "force_timeout": self.force_timeout,
+                "max_queue_size": self.max_queue_size,
                 "poll_interval": self.poll_interval,
             },
             "model_controller": self.model_controller.get_stats()
@@ -25,14 +25,14 @@ class TensorRTModelController(BaseModelController):
         model_repository,
         model_id: str,
         batch_size: int = 16,
-        force_timeout: float = 0.05,
+        max_queue_size: int = 100,
         preprocess_fn: Optional[Callable] = None,
         postprocess_fn: Optional[Callable] = None,
     ):
         super().__init__(
             model_id=model_id,
             batch_size=batch_size,
-            force_timeout=force_timeout,
+            max_queue_size=max_queue_size,
             preprocess_fn=preprocess_fn,
             postprocess_fn=postprocess_fn,
         )
@@ -115,6 +115,7 @@ class TensorRTModelController(BaseModelController):
                 "stream_id": batch_frame.stream_id,
                 "timestamp": batch_frame.timestamp,
                 "detections": detections,
+                "frame": batch_frame.frame,  # Include original frame tensor
                 "metadata": batch_frame.metadata,
             }
             results.append(result)
@@ -175,6 +176,7 @@ class TensorRTModelController(BaseModelController):
                 "stream_id": batch_frame.stream_id,
                 "timestamp": batch_frame.timestamp,
                 "detections": detections,
+                "frame": batch_frame.frame,  # Include original frame tensor
                 "metadata": batch_frame.metadata,
             }
             results.append(result)
@@ -25,20 +25,27 @@ class UltralyticsModelController(BaseModelController):
         inference_engine,
         model_id: str,
         batch_size: int = 16,
-        force_timeout: float = 0.05,
+        max_queue_size: int = 100,
         preprocess_fn: Optional[Callable] = None,
         postprocess_fn: Optional[Callable] = None,
     ):
         # Auto-detect actual batch size from the YOLO engine
+        print(f"[UltralyticsModelController] Detecting batch size from engine...")
         engine_batch_size = self._detect_engine_batch_size(inference_engine)
+        print(
+            f"[UltralyticsModelController] Detected engine_batch_size={engine_batch_size}"
+        )
 
         # If engine has fixed batch size, use it. Otherwise use user's batch_size
         actual_batch_size = engine_batch_size if engine_batch_size > 0 else batch_size
+        print(
+            f"[UltralyticsModelController] Using actual_batch_size={actual_batch_size}"
+        )
 
         super().__init__(
             model_id=model_id,
             batch_size=actual_batch_size,
-            force_timeout=force_timeout,
+            max_queue_size=max_queue_size,
             preprocess_fn=preprocess_fn,
             postprocess_fn=postprocess_fn,
         )
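The pad/truncate step mentioned in the next hunk is not itself shown in this commit; purely as a hedged illustration (not the repository's actual helper), padding a short batch up to a fixed engine batch size could look like:

import torch

def pad_to_fixed_batch(frames: list, engine_batch_size: int) -> torch.Tensor:
    """Hypothetical helper: stack frames, repeating the last one (or truncating) to a fixed size."""
    frames = list(frames)[:engine_batch_size]        # truncate if too many
    while len(frames) < engine_batch_size:           # pad if too few
        frames.append(frames[-1].clone())
    return torch.stack(frames)  # (engine_batch_size, C, H, W)

batch = [torch.zeros(3, 640, 640) for _ in range(2)]
print(pad_to_fixed_batch(batch, 4).shape)  # torch.Size([4, 3, 640, 640])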
@@ -46,11 +53,23 @@ class UltralyticsModelController(BaseModelController):
         self.engine_batch_size = engine_batch_size  # Store for padding logic
 
         if engine_batch_size > 0:
+            print(f"✓ Ultralytics engine has FIXED batch_size={engine_batch_size}")
+            print(
+                f"  Will pad/truncate all batches to exactly {engine_batch_size} frames"
+            )
             logger.info(
                 f"Ultralytics engine has fixed batch_size={engine_batch_size}, "
                 f"will pad batches to match"
             )
+            # CRITICAL: Override the parent's batch_size to match engine's fixed size
+            # This prevents buffer accumulation beyond the engine's capacity
+            self.batch_size = engine_batch_size
+            print(f"  Controller self.batch_size is now: {self.batch_size}")
+            print(f"  Buffer will swap when size >= {self.batch_size}")
         else:
+            print(
+                f"✓ Ultralytics engine supports DYNAMIC batching, max={actual_batch_size}"
+            )
             logger.info(
                 f"Ultralytics engine supports dynamic batching, "
                 f"using max batch_size={actual_batch_size}"
@@ -67,16 +86,22 @@ class UltralyticsModelController(BaseModelController):
         # Get engine metadata
         metadata = inference_engine.get_metadata()
 
+        logger.info(f"Detecting batch size from engine metadata: {metadata}")
+
         # Check input shape for batch dimension
         if "images" in metadata.input_shapes:
             input_shape = metadata.input_shapes["images"]
             batch_dim = input_shape[0]
 
+            logger.info(f"Found batch dimension in metadata: {batch_dim}")
+
             if batch_dim > 0:
                 # Fixed batch size
+                logger.info(f"Using fixed batch size from engine: {batch_dim}")
                 return batch_dim
             else:
                 # Dynamic batch size (-1)
+                logger.info("Engine supports dynamic batching (batch_dim=-1)")
                 return -1
 
         # Fallback: try to get from model directly
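The detection rule above reduces to reading the first element of the `images` input shape; a hedged, standalone sketch with made-up metadata values:

def batch_dim_from_input_shapes(input_shapes: dict) -> int:
    """Return a fixed batch size (>0) or -1 for dynamic, as the controller expects."""
    if "images" in input_shapes:
        batch_dim = input_shapes["images"][0]
        return batch_dim if batch_dim > 0 else -1
    return -1  # unknown layout: treat as dynamic

print(batch_dim_from_input_shapes({"images": (4, 3, 640, 640)}))   # 4  -> fixed
print(batch_dim_from_input_shapes({"images": (-1, 3, 640, 640)}))  # -1 -> dynamic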
@@ -187,28 +212,16 @@ class UltralyticsModelController(BaseModelController):
                 # No detections
                 detections = torch.zeros((0, 6), device=batch_tensor.device)
 
-            # Apply custom postprocessing if provided
-            if self.postprocess_fn:
-                try:
-                    # For Ultralytics, postprocess_fn might do additional filtering
-                    # Pass the raw boxes tensor in the same format as TensorRT output
-                    detections = self.postprocess_fn(
-                        {
-                            "output0": detections.unsqueeze(
-                                0
-                            )  # Add batch dim for compatibility
-                        }
-                    )
-                except Exception as e:
-                    logger.error(
-                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
-                    )
-                    detections = torch.zeros((0, 6), device=batch_tensor.device)
+            # NOTE: Skip postprocess_fn for Ultralytics backend!
+            # Ultralytics already does confidence filtering, NMS, and format conversion.
+            # The detections are already in final format: [x1, y1, x2, y2, conf, cls]
+            # Any custom postprocess_fn would expect raw TensorRT output and will fail.
 
             result = {
                 "stream_id": batch_frame.stream_id,
                 "timestamp": batch_frame.timestamp,
                 "detections": detections,
+                "frame": batch_frame.frame,  # Include original frame tensor
                 "metadata": batch_frame.metadata,
                 "yolo_result": yolo_result,  # Keep original Results object for debugging
             }
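Since the Ultralytics path now hands back final `[x1, y1, x2, y2, conf, cls]` rows, a consumer can unpack them directly; a minimal sketch with fake values:

import torch

# One fake detection row in the format emitted above.
detections = torch.tensor([[100.0, 50.0, 220.0, 180.0, 0.91, 2.0]])

for x1, y1, x2, y2, conf, cls in detections.tolist():
    print(f"class={int(cls)} conf={conf:.2f} box=({x1:.0f},{y1:.0f})-({x2:.0f},{y2:.0f})")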
@@ -4,11 +4,11 @@ Real-time object tracking with event-driven batching architecture.
 This script demonstrates:
 - Event-driven stream processing with StreamConnectionManager
 - Batched GPU inference with ModelController
-- Ping-pong buffer architecture for optimal throughput
 - Callback-based event-driven pattern for RTSP streams
 - Automatic PT to TensorRT conversion
 """
 
+import logging
 import os
 import threading
 import time
@@ -28,6 +28,11 @@ from services import (
 # Load environment variables
 load_dotenv()
 
+# Enable debug logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+
 
 def main_multi_stream():
     """Multi-stream example with batched inference."""
@@ -41,8 +46,8 @@ def main_multi_stream():
     USE_ULTRALYTICS = (
         os.getenv("USE_ULTRALYTICS", "true").lower() == "true"
     )  # Use Ultralytics engine for YOLO
-    BATCH_SIZE = 2  # Reduced to 2 to avoid GPU memory issues
-    FORCE_TIMEOUT = 0.05
+    BATCH_SIZE = 2  # Must match engine's fixed batch size
+    MAX_QUEUE_SIZE = 50  # Drop frames if queue gets too long
     ENABLE_DISPLAY = os.getenv("ENABLE_DISPLAY", "true").lower() == "true"
 
     # Load camera URLs
@@ -73,10 +78,11 @@ def main_multi_stream():
     manager = StreamConnectionManager(
         gpu_id=GPU_ID,
         batch_size=BATCH_SIZE,
-        force_timeout=FORCE_TIMEOUT,
+        max_queue_size=MAX_QUEUE_SIZE,
         enable_pt_conversion=True,
         backend=backend,
     )
+
     print("✓ Manager created")
 
     # Initialize model (transparent loading)
@@ -86,7 +92,6 @@ def main_multi_stream():
         model_path=MODEL_PATH,
         model_id="detector",
         preprocess_fn=YOLOv8Utils.preprocess,
-        postprocess_fn=YOLOv8Utils.postprocess,
         num_contexts=1,  # Single context to minimize GPU memory usage
         # Note: No pt_input_shapes or pt_precision needed for YOLO models!
     )
@@ -98,6 +103,109 @@ def main_multi_stream():
         traceback.print_exc()
         return
 
+    # Track stats (initialize before callback definition)
+    stream_stats = {sid: {"count": 0, "start": time.time()} for sid, _ in camera_urls}
+    total_results = 0
+    start_time = time.time()
+    stats_lock = threading.Lock()
+
+    # Create windows for each stream if display enabled
+    if ENABLE_DISPLAY:
+        for stream_id, _ in camera_urls:
+            cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
+            cv2.resizeWindow(
+                stream_id, 640, 360
+            )  # Smaller windows for multiple streams
+
+    def on_tracking_result(result):
+        """Callback for tracking results - called automatically per stream"""
+        nonlocal total_results
+
+        # Debug: Check if we have frame tensor
+        has_frame = result.frame_tensor is not None
+        frame_shape = result.frame_tensor.shape if has_frame else None
+        print(
+            f"[CALLBACK] Got result for {result.stream_id}, has_frame={has_frame}, shape={frame_shape}, detections={len(result.detections)}"
+        )
+
+        with stats_lock:
+            total_results += 1
+            stream_id = result.stream_id
+            if stream_id in stream_stats:
+                stream_stats[stream_id]["count"] += 1
+
+            # Print stats every 10 results (changed from 100 for faster feedback)
+            if total_results % 10 == 0:
+                elapsed = time.time() - start_time
+                total_fps = total_results / elapsed if elapsed > 0 else 0
+                print(
+                    f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
+                )
+                for sid, stats in stream_stats.items():
+                    s_elapsed = time.time() - stats["start"]
+                    s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
+                    print(f"  {sid}: {stats['count']} ({s_fps:.1f} FPS)")
+
+        # Display visualization if enabled
+        if ENABLE_DISPLAY and result.frame_tensor is not None:
+            # Convert GPU tensor (C, H, W) to CPU numpy (H, W, C) for OpenCV
+            frame_tensor = result.frame_tensor  # (3, 720, 1280) RGB uint8
+            frame_np = (
+                frame_tensor.cpu().permute(1, 2, 0).numpy().astype(np.uint8)
+            )  # (720, 1280, 3)
+            frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
+
+            # Draw bounding boxes
+            for obj in result.tracked_objects:
+                x1, y1, x2, y2 = map(int, obj.bbox)
+
+                # Draw box
+                cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
+
+                # Draw label with ID and class
+                label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
+                (label_w, label_h), _ = cv2.getTextSize(
+                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
+                )
+                cv2.rectangle(
+                    frame_bgr,
+                    (x1, y1 - label_h - 10),
+                    (x1 + label_w, y1),
+                    (0, 255, 0),
+                    -1,
+                )
+                cv2.putText(
+                    frame_bgr,
+                    label,
+                    (x1, y1 - 5),
+                    cv2.FONT_HERSHEY_SIMPLEX,
+                    0.5,
+                    (0, 0, 0),
+                    1,
+                )
+
+            # Show FPS on frame
+            with stats_lock:
+                s_elapsed = time.time() - stream_stats[stream_id]["start"]
+                s_fps = (
+                    stream_stats[stream_id]["count"] / s_elapsed if s_elapsed > 0 else 0
+                )
+            fps_text = (
+                f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
+            )
+            cv2.putText(
+                frame_bgr,
+                fps_text,
+                (10, 30),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.7,
+                (0, 255, 0),
+                2,
+            )
+
+            # Display
+            cv2.imshow(stream_id, frame_bgr)
+
     # Connect all streams in parallel using threads
     print(f"\n[3/3] Connecting {len(camera_urls)} streams in parallel...")
     connections = {}
@@ -108,7 +216,10 @@ def main_multi_stream():
         """Thread worker to connect a single stream"""
         try:
             conn = manager.connect_stream(
-                rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=3
+                rtsp_url=rtsp_url,
+                stream_id=stream_id,
+                buffer_size=2,
+                on_tracking_result=on_tracking_result,  # Register callback
             )
             connection_results[stream_id] = ("success", conn)
         except Exception as e:
@@ -144,124 +255,15 @@ def main_multi_stream():
     print("Press Ctrl+C to stop")
     print(f"{'=' * 80}\n")
 
-    # Track stats
-    stream_stats = {
-        sid: {"count": 0, "start": time.time()} for sid in connections.keys()
-    }
-    total_results = 0
-    start_time = time.time()
-
-    # Create windows for each stream if display enabled
-    if ENABLE_DISPLAY:
-        for stream_id in connections.keys():
-            cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
-            cv2.resizeWindow(
-                stream_id, 640, 360
-            )  # Smaller windows for multiple streams
-
     try:
-        # Merge all result queues from all connections
-        import queue as queue_module
-
-        running = True
-        while running:
-            # Poll all connection queues (non-blocking)
-            got_result = False
-            for conn in connections.values():
-                try:
-                    # Non-blocking get from each connection's queue
-                    result = conn.result_queue.get_nowait()
-                    got_result = True
-
-                    total_results += 1
-                    stream_id = result.stream_id
-
-                    if stream_id in stream_stats:
-                        stream_stats[stream_id]["count"] += 1
-
-                    # Display visualization if enabled
-                    if ENABLE_DISPLAY:
-                        # Get latest frame from decoder (already in CPU memory as numpy RGB)
-                        frame_rgb = conn.decoder.get_latest_frame_cpu(rgb=True)
-                        if frame_rgb is not None:
-                            # Convert RGB to BGR for OpenCV
-                            frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
-
-                            # Draw bounding boxes
-                            for obj in result.tracked_objects:
-                                x1, y1, x2, y2 = map(int, obj.bbox)
-
-                                # Draw box
-                                cv2.rectangle(
-                                    frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2
-                                )
-
-                                # Draw label with ID and class
-                                label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
-                                (label_w, label_h), _ = cv2.getTextSize(
-                                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
-                                )
-                                cv2.rectangle(
-                                    frame_bgr,
-                                    (x1, y1 - label_h - 10),
-                                    (x1 + label_w, y1),
-                                    (0, 255, 0),
-                                    -1,
-                                )
-                                cv2.putText(
-                                    frame_bgr,
-                                    label,
-                                    (x1, y1 - 5),
-                                    cv2.FONT_HERSHEY_SIMPLEX,
-                                    0.5,
-                                    (0, 0, 0),
-                                    1,
-                                )
-
-                            # Show FPS on frame
-                            s_elapsed = time.time() - stream_stats[stream_id]["start"]
-                            s_fps = (
-                                stream_stats[stream_id]["count"] / s_elapsed
-                                if s_elapsed > 0
-                                else 0
-                            )
-                            fps_text = f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
-                            cv2.putText(
-                                frame_bgr,
-                                fps_text,
-                                (10, 30),
-                                cv2.FONT_HERSHEY_SIMPLEX,
-                                0.7,
-                                (0, 255, 0),
-                                2,
-                            )
-
-                            # Display
-                            cv2.imshow(stream_id, frame_bgr)
-
-                    # Print stats every 100 results
-                    if total_results % 100 == 0:
-                        elapsed = time.time() - start_time
-                        total_fps = total_results / elapsed if elapsed > 0 else 0
-
-                        print(
-                            f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
-                        )
-                        for sid, stats in stream_stats.items():
-                            s_elapsed = time.time() - stats["start"]
-                            s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
-                            print(f"  {sid}: {stats['count']} ({s_fps:.1f} FPS)")
-
-                except queue_module.Empty:
-                    continue
-
-            # Process OpenCV events to keep windows responsive
-            if ENABLE_DISPLAY:
-                cv2.waitKey(1)
-
-            # Small sleep if no results to avoid busy loop
-            if not got_result:
-                time.sleep(0.01)
+        # Keep main thread alive and process OpenCV events
+        while True:
+            if ENABLE_DISPLAY:
+                # Process OpenCV events to keep windows responsive
+                if cv2.waitKey(1) & 0xFF == ord("q"):
+                    break
+            else:
+                time.sleep(0.1)
 
     except KeyboardInterrupt:
         print(f"\n✓ Interrupted")