new buffer paradigm

This commit is contained in:
Siwat Sirichai 2025-11-11 02:02:12 +07:00
parent fdaeb9981c
commit a519dea130
6 changed files with 341 additions and 327 deletions

View file

@ -4,11 +4,11 @@ Real-time object tracking with event-driven batching architecture.
This script demonstrates:
- Event-driven stream processing with StreamConnectionManager
- Batched GPU inference with ModelController
- Ping-pong buffer architecture for optimal throughput
- Callback-based event-driven pattern for RTSP streams
- Automatic PT to TensorRT conversion
"""
import logging
import os
import threading
import time
@ -28,6 +28,11 @@ from services import (
# Load environment variables
load_dotenv()
# Configure logging at INFO level (timestamp - logger name - level - message)
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
def main_multi_stream():
"""Multi-stream example with batched inference."""
@ -41,8 +46,8 @@ def main_multi_stream():
USE_ULTRALYTICS = (
os.getenv("USE_ULTRALYTICS", "true").lower() == "true"
) # Use Ultralytics engine for YOLO
BATCH_SIZE = 2 # Reduced to 2 to avoid GPU memory issues
FORCE_TIMEOUT = 0.05
BATCH_SIZE = 2 # Must match engine's fixed batch size
MAX_QUEUE_SIZE = 50 # Drop frames if queue gets too long
ENABLE_DISPLAY = os.getenv("ENABLE_DISPLAY", "true").lower() == "true"
# Load camera URLs
@ -73,10 +78,11 @@ def main_multi_stream():
manager = StreamConnectionManager(
gpu_id=GPU_ID,
batch_size=BATCH_SIZE,
force_timeout=FORCE_TIMEOUT,
max_queue_size=MAX_QUEUE_SIZE,
enable_pt_conversion=True,
backend=backend,
)
print("✓ Manager created")
# Initialize model (transparent loading)
@ -86,7 +92,6 @@ def main_multi_stream():
model_path=MODEL_PATH,
model_id="detector",
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
num_contexts=1, # Single context to minimize GPU memory usage
# Note: No pt_input_shapes or pt_precision needed for YOLO models!
)
@ -98,6 +103,109 @@ def main_multi_stream():
traceback.print_exc()
return
# Track stats (initialize before callback definition)
stream_stats = {sid: {"count": 0, "start": time.time()} for sid, _ in camera_urls}
total_results = 0
start_time = time.time()
stats_lock = threading.Lock()
# Create windows for each stream if display enabled
if ENABLE_DISPLAY:
for stream_id, _ in camera_urls:
cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
cv2.resizeWindow(
stream_id, 640, 360
) # Smaller windows for multiple streams
def on_tracking_result(result):
"""Callback for tracking results - called automatically per stream"""
nonlocal total_results
# Debug: Check if we have frame tensor
has_frame = result.frame_tensor is not None
frame_shape = result.frame_tensor.shape if has_frame else None
print(
f"[CALLBACK] Got result for {result.stream_id}, has_frame={has_frame}, shape={frame_shape}, detections={len(result.detections)}"
)
with stats_lock:
total_results += 1
stream_id = result.stream_id
if stream_id in stream_stats:
stream_stats[stream_id]["count"] += 1
# Print stats every 10 results (changed from 100 for faster feedback)
if total_results % 10 == 0:
elapsed = time.time() - start_time
total_fps = total_results / elapsed if elapsed > 0 else 0
print(
f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
)
for sid, stats in stream_stats.items():
s_elapsed = time.time() - stats["start"]
s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
print(f" {sid}: {stats['count']} ({s_fps:.1f} FPS)")
# Display visualization if enabled
if ENABLE_DISPLAY and result.frame_tensor is not None:
# Convert GPU tensor (C, H, W) to CPU numpy (H, W, C) for OpenCV
frame_tensor = result.frame_tensor # (3, 720, 1280) RGB uint8
frame_np = (
frame_tensor.cpu().permute(1, 2, 0).numpy().astype(np.uint8)
) # (720, 1280, 3)
frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
# Draw bounding boxes
for obj in result.tracked_objects:
x1, y1, x2, y2 = map(int, obj.bbox)
# Draw box
cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
# Draw label with ID and class
label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
(label_w, label_h), _ = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
cv2.rectangle(
frame_bgr,
(x1, y1 - label_h - 10),
(x1 + label_w, y1),
(0, 255, 0),
-1,
)
cv2.putText(
frame_bgr,
label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 0, 0),
1,
)
# Show FPS on frame
with stats_lock:
s_elapsed = time.time() - stream_stats[stream_id]["start"]
s_fps = (
stream_stats[stream_id]["count"] / s_elapsed if s_elapsed > 0 else 0
)
fps_text = (
f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
)
cv2.putText(
frame_bgr,
fps_text,
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
(0, 255, 0),
2,
)
# Display
cv2.imshow(stream_id, frame_bgr)
# Connect all streams in parallel using threads
print(f"\n[3/3] Connecting {len(camera_urls)} streams in parallel...")
connections = {}
@ -108,7 +216,10 @@ def main_multi_stream():
"""Thread worker to connect a single stream"""
try:
conn = manager.connect_stream(
rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=3
rtsp_url=rtsp_url,
stream_id=stream_id,
buffer_size=2,
on_tracking_result=on_tracking_result, # Register callback
)
connection_results[stream_id] = ("success", conn)
except Exception as e:
@ -144,124 +255,15 @@ def main_multi_stream():
print("Press Ctrl+C to stop")
print(f"{'=' * 80}\n")
# Track stats
stream_stats = {
sid: {"count": 0, "start": time.time()} for sid in connections.keys()
}
total_results = 0
start_time = time.time()
# Create windows for each stream if display enabled
if ENABLE_DISPLAY:
for stream_id in connections.keys():
cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
cv2.resizeWindow(
stream_id, 640, 360
) # Smaller windows for multiple streams
try:
# Merge all result queues from all connections
import queue as queue_module
running = True
while running:
# Poll all connection queues (non-blocking)
got_result = False
for conn in connections.values():
try:
# Non-blocking get from each connection's queue
result = conn.result_queue.get_nowait()
got_result = True
total_results += 1
stream_id = result.stream_id
if stream_id in stream_stats:
stream_stats[stream_id]["count"] += 1
# Display visualization if enabled
if ENABLE_DISPLAY:
# Get latest frame from decoder (already in CPU memory as numpy RGB)
frame_rgb = conn.decoder.get_latest_frame_cpu(rgb=True)
if frame_rgb is not None:
# Convert RGB to BGR for OpenCV
frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
# Draw bounding boxes
for obj in result.tracked_objects:
x1, y1, x2, y2 = map(int, obj.bbox)
# Draw box
cv2.rectangle(
frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2
)
# Draw label with ID and class
label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
(label_w, label_h), _ = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
)
cv2.rectangle(
frame_bgr,
(x1, y1 - label_h - 10),
(x1 + label_w, y1),
(0, 255, 0),
-1,
)
cv2.putText(
frame_bgr,
label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 0, 0),
1,
)
# Show FPS on frame
s_elapsed = time.time() - stream_stats[stream_id]["start"]
s_fps = (
stream_stats[stream_id]["count"] / s_elapsed
if s_elapsed > 0
else 0
)
fps_text = f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
cv2.putText(
frame_bgr,
fps_text,
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
(0, 255, 0),
2,
)
# Display
cv2.imshow(stream_id, frame_bgr)
# Print stats every 100 results
if total_results % 100 == 0:
elapsed = time.time() - start_time
total_fps = total_results / elapsed if elapsed > 0 else 0
print(
f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
)
for sid, stats in stream_stats.items():
s_elapsed = time.time() - stats["start"]
s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
print(f" {sid}: {stats['count']} ({s_fps:.1f} FPS)")
except queue_module.Empty:
continue
# Process OpenCV events to keep windows responsive
# Keep main thread alive and process OpenCV events
while True:
if ENABLE_DISPLAY:
cv2.waitKey(1)
# Small sleep if no results to avoid busy loop
if not got_result:
time.sleep(0.01)
# Process OpenCV events to keep windows responsive
if cv2.waitKey(1) & 0xFF == ord("q"):
break
else:
time.sleep(0.1)
except KeyboardInterrupt:
print(f"\n✓ Interrupted")