"""
Real-time object tracking with an event-driven batching architecture.

This script demonstrates:

- Event-driven stream processing with StreamConnectionManager
- Batched GPU inference with ModelController
- Ping-pong buffer architecture for optimal throughput
- Callback-based event-driven pattern for RTSP streams
- Automatic PT to TensorRT conversion

Cameras are configured through environment variables (e.g. in a ``.env``
file): ``CAMERA_URL_1``, ``CAMERA_URL_2``, ... plus the optional
``ENABLE_DISPLAY`` and ``USE_ULTRALYTICS`` flags.
"""


import os
import queue
import threading
import time
import traceback

import cv2
import numpy as np
import torch
from dotenv import load_dotenv

from services import (
    COCO_CLASSES,
    StreamConnectionManager,
    UltralyticsExporter,
    YOLOv8Utils,
)

# Load variables from a .env file into the process environment at import
# time, so main_multi_stream() can read CAMERA_URL_<n>, USE_ULTRALYTICS and
# ENABLE_DISPLAY through os.getenv().
load_dotenv()


def _load_camera_urls():
    """Return ``(stream_id, url)`` pairs read from CAMERA_URL_1, CAMERA_URL_2, ...

    Scanning stops at the first missing or empty variable, so the numbering
    must be contiguous. Stream ids are ``camera_<n>``.
    """
    camera_urls = []
    index = 1
    while True:
        url = os.getenv(f"CAMERA_URL_{index}")
        if not url:
            return camera_urls
        camera_urls.append((f"camera_{index}", url))
        index += 1


def _connect_streams(manager, camera_urls):
    """Connect every camera concurrently and return ``{stream_id: connection}``.

    One thread per camera calls ``manager.connect_stream``. Failures are
    printed and left out of the returned dict. Outcomes are reported in the
    order of *camera_urls*, so the log does not depend on which connection
    happens to finish first.
    """
    outcomes = {}

    def connect_one(stream_id, rtsp_url):
        # Runs on a worker thread; each thread writes only its own stream_id key.
        try:
            conn = manager.connect_stream(
                rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=3
            )
        except Exception as e:
            outcomes[stream_id] = ("error", str(e))
        else:
            outcomes[stream_id] = ("success", conn)

    threads = [
        threading.Thread(target=connect_one, args=(stream_id, rtsp_url), daemon=True)
        for stream_id, rtsp_url in camera_urls
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    connections = {}
    for stream_id, _ in camera_urls:
        status, result = outcomes.get(
            stream_id, ("error", "connection thread exited without a result")
        )
        if status == "success":
            connections[stream_id] = result
            print(f"✓ Connected: {stream_id}")
        else:
            print(f"✗ Failed {stream_id}: {result}")
    return connections


def _stream_fps(stats):
    """Return average results/second since ``stats["start"]``.

    Returns 0 when *stats* is None or no time has elapsed.
    """
    if stats is None:
        return 0
    elapsed = time.time() - stats["start"]
    return stats["count"] / elapsed if elapsed > 0 else 0


def _print_stats(total_results, start_time, stream_stats):
    """Print the aggregate result rate followed by one line per stream."""
    elapsed = time.time() - start_time
    total_fps = total_results / elapsed if elapsed > 0 else 0
    print(f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS")
    for sid, stats in stream_stats.items():
        print(f" {sid}: {stats['count']} ({_stream_fps(stats):.1f} FPS)")


def _draw_tracked_objects(frame_bgr, tracked_objects):
    """Draw a box and an ``ID:<track> <class> <conf>`` label per object, in place."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    green = (0, 255, 0)
    for obj in tracked_objects:
        x1, y1, x2, y2 = map(int, obj.bbox)
        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), green, 2)

        label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
        (label_w, label_h), _ = cv2.getTextSize(label, font, 0.5, 1)
        # The label sits on top of the box; push it down when the box touches
        # the top edge so the text is not drawn outside the frame.
        label_bottom = max(y1, label_h + 10)
        cv2.rectangle(
            frame_bgr,
            (x1, label_bottom - label_h - 10),
            (x1 + label_w, label_bottom),
            green,
            -1,
        )
        cv2.putText(frame_bgr, label, (x1, label_bottom - 5), font, 0.5, (0, 0, 0), 1)


def _show_result(conn, result, stream_id, stats):
    """Render *result* over the connection's latest decoded frame and show it."""
    # NOTE(review): this is presumably the decoder's newest frame, not
    # necessarily the one *result* was computed on, so boxes may drift from
    # the image under load. Confirm against the decoder API.
    frame_rgb = conn.decoder.get_latest_frame_cpu(rgb=True)
    if frame_rgb is None:
        return
    # OpenCV draws and displays BGR images.
    frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
    _draw_tracked_objects(frame_bgr, result.tracked_objects)

    fps_text = (
        f"{stream_id}: {_stream_fps(stats):.1f} FPS | "
        f"{len(result.tracked_objects)} objects"
    )
    cv2.putText(
        frame_bgr, fps_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2
    )
    cv2.imshow(stream_id, frame_bgr)


def _poll_results(connections, display_enabled):
    """Poll every connection's result queue until Ctrl+C.

    Keeps per-stream counts for the FPS overlay and prints a report every
    100 results. Returns ``(total_results, start_time)`` for the final
    summary.
    """
    stream_stats = {sid: {"count": 0, "start": time.time()} for sid in connections}
    total_results = 0
    start_time = time.time()

    try:
        while True:
            got_result = False
            for conn in connections.values():
                # Non-blocking, so an idle stream never stalls the others.
                try:
                    result = conn.result_queue.get_nowait()
                except queue.Empty:
                    continue
                got_result = True
                total_results += 1

                stream_id = result.stream_id
                stats = stream_stats.get(stream_id)
                if stats is not None:
                    stats["count"] += 1

                if display_enabled:
                    _show_result(conn, result, stream_id, stats)

                if total_results % 100 == 0:
                    _print_stats(total_results, start_time, stream_stats)

            if display_enabled:
                # Pump the OpenCV event loop so the windows stay responsive.
                cv2.waitKey(1)

            if not got_result:
                # Nothing queued anywhere: back off instead of busy-spinning.
                time.sleep(0.01)
    except KeyboardInterrupt:
        print("\n✓ Interrupted")

    return total_results, start_time


def _shutdown(manager, connections, close_windows):
    """Close preview windows, stop every connection and shut the manager down.

    Best-effort: a failure in one step is reported, and the remaining
    resources are still released.
    """
    print(f"\n{'=' * 80}")
    print("Cleanup")
    print(f"{'=' * 80}")

    if close_windows:
        cv2.destroyAllWindows()

    for stream_id, conn in connections.items():
        try:
            conn.stop()
        except Exception as e:
            print(f"✗ Failed to stop {stream_id}: {e}")

    try:
        manager.shutdown()
    except Exception as e:
        print(f"✗ Manager shutdown failed: {e}")
    else:
        print("✓ Stopped")


def main_multi_stream():
    """Multi-stream example with batched inference.

    Reads camera URLs from ``CAMERA_URL_<n>`` environment variables, then
    creates a StreamConnectionManager and loads the detector. It connects
    all streams in parallel and polls their result queues until Ctrl+C.
    ``ENABLE_DISPLAY`` (default "true") toggles the OpenCV preview windows.

    Once the manager exists, the connections and the manager are always
    released on exit. This includes a failed model initialization, no
    stream connecting, and an exception raised from the polling loop.
    """
    print("=" * 80)
    print("Event-Driven GPU-Accelerated Object Tracking - Multi-Stream")
    print("=" * 80)

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"  # Transparent loading: .pt, .engine, or .trt
    # NOTE(review): USE_ULTRALYTICS is parsed but never used; the backend
    # below is hard-coded to "ultralytics". Confirm which backend the "false"
    # case should select before wiring it in.
    USE_ULTRALYTICS = os.getenv("USE_ULTRALYTICS", "true").lower() == "true"
    BATCH_SIZE = 2  # Reduced to 2 to avoid GPU memory issues
    FORCE_TIMEOUT = 0.05
    ENABLE_DISPLAY = os.getenv("ENABLE_DISPLAY", "true").lower() == "true"

    camera_urls = _load_camera_urls()
    if not camera_urls:
        print("No camera URLs found in .env")
        return

    print("\nConfiguration:")
    print(f" GPU: {GPU_ID}")
    print(f" Model: {MODEL_PATH}")
    print(f" Streams: {len(camera_urls)}")
    print(f" Batch size: {BATCH_SIZE}\n")

    # Create manager with backend selection
    print("[1/3] Creating StreamConnectionManager...")
    backend = "ultralytics"
    print(f" Backend: {backend}")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True,
        backend=backend,
    )
    print("✓ Manager created")

    connections = {}
    windows_open = False
    try:
        # Initialize model (transparent loading)
        print("\n[2/3] Initializing model...")
        try:
            manager.initialize(
                model_path=MODEL_PATH,
                model_id="detector",
                preprocess_fn=YOLOv8Utils.preprocess,
                postprocess_fn=YOLOv8Utils.postprocess,
                num_contexts=1,  # Single context to minimize GPU memory usage
                # Note: No pt_input_shapes or pt_precision needed for YOLO models!
            )
        except Exception as e:
            print(f"✗ Failed to initialize: {e}")
            traceback.print_exc()
            return
        print("✓ Manager initialized")

        print(f"\n[3/3] Connecting {len(camera_urls)} streams in parallel...")
        connections = _connect_streams(manager, camera_urls)
        if not connections:
            print("No streams connected")
            return

        print(f"\n{'=' * 80}")
        print(f"Multi-stream tracking running ({len(connections)} streams)")
        print("Frames from all streams are batched together!")
        print("Press Ctrl+C to stop")
        print(f"{'=' * 80}\n")

        if ENABLE_DISPLAY:
            for stream_id in connections:
                cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
                # Smaller windows for multiple streams
                cv2.resizeWindow(stream_id, 640, 360)
            # Only flag after creation succeeds, so cleanup never calls
            # highgui on a build or display where window creation failed.
            windows_open = True

        total_results, start_time = _poll_results(connections, ENABLE_DISPLAY)
    finally:
        _shutdown(manager, connections, windows_open)

    # Final stats (reached only when the polling loop ended normally)
    elapsed = time.time() - start_time
    avg_fps = total_results / elapsed if elapsed > 0 else 0
    print(f"\nFinal: {total_results} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main_multi_stream()