# python-rtsp-worker/test_tracking_realtime.py

"""
Real-time object tracking with event-driven batching architecture.
This script demonstrates:
- Event-driven stream processing with StreamConnectionManager
- Batched GPU inference with ModelController
- Ping-pong buffer architecture for optimal throughput
- Callback-based event-driven pattern for RTSP streams
- Automatic PT to TensorRT conversion
"""
import time
import os
import torch
import cv2
import numpy as np
from dotenv import load_dotenv
from services import (
    StreamConnectionManager,
    YOLOv8Utils,
    COCO_CLASSES,
)
# Load environment variables
load_dotenv()
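# Expected .env layout, with illustrative placeholder values:
#
#   CAMERA_URL_1=rtsp://user:pass@192.168.1.10:554/stream1
#   CAMERA_URL_2=rtsp://user:pass@192.168.1.11:554/stream1
#   ENABLE_DISPLAY=false
#   MAX_FRAMES=300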
def main_single_stream():
    """Single stream example with event-driven architecture."""
    print("=" * 80)
    print("Event-Driven GPU-Accelerated Object Tracking - Single Stream")
    print("=" * 80)

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"  # Transparent loading: .pt, .engine, or .trt
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BATCH_SIZE = 4
    FORCE_TIMEOUT = 0.05
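    # Batching knobs: frames accumulate until BATCH_SIZE is reached; FORCE_TIMEOUT
    # (seconds) presumably flushes a partial batch that has waited too long, so
    # larger batches favor throughput while the timeout bounds per-frame latency.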
    ENABLE_DISPLAY = os.getenv('ENABLE_DISPLAY', 'false').lower() == 'true'  # Set to 'true' to enable OpenCV display
    MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300'))  # Stop after N frames (0 = unlimited)

    print("\nConfiguration:")
    print(f"  GPU: {GPU_ID}")
    print(f"  Model: {MODEL_PATH}")
    print(f"  Stream: {STREAM_URL}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Force timeout: {FORCE_TIMEOUT}s")
    print(f"  Display: {'Enabled' if ENABLE_DISPLAY else 'Disabled (inference only)'}")
    print(f"  Max frames: {MAX_FRAMES if MAX_FRAMES > 0 else 'Unlimited'}\n")
    # Create StreamConnectionManager with PT conversion enabled
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True,  # Enable PT -> TensorRT conversion
    )
    print("✓ Manager created")
    # Initialize with model (transparent loading - no manual parameters needed)
    print("\n[2/3] Initializing model...")
    print("Note: YOLO models auto-convert to native TensorRT .engine (first time only)")
    print("Metadata is auto-detected from the model - no manual input_shapes needed!\n")
    try:
        manager.initialize(
            model_path=MODEL_PATH,
            model_id="detector",
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=4,
            # Note: no pt_input_shapes or pt_precision needed for YOLO models!
        )
        print("✓ Manager initialized")
    except Exception as e:
        print(f"✗ Failed to initialize: {e}")
        import traceback
        traceback.print_exc()
        return
    # Connect stream
    print("\n[3/3] Connecting to stream...")
    try:
        connection = manager.connect_stream(
            rtsp_url=STREAM_URL,
            stream_id="camera_1",
            buffer_size=30,
        )
        print("✓ Stream connected: camera_1")
    except Exception as e:
        print(f"✗ Failed to connect stream: {e}")
        return
print(f"\n{'=' * 80}")
print("Event-driven tracking is running!")
print("Press Ctrl+C to stop")
print(f"{'=' * 80}\n")
    # Stream results with optional OpenCV visualization
    result_count = 0
    start_time = time.time()

    # Create window only if display is enabled
    if ENABLE_DISPLAY:
        cv2.namedWindow("Object Tracking", cv2.WINDOW_NORMAL)
        cv2.resizeWindow("Object Tracking", 1280, 720)
    try:
        for result in connection.tracking_results():
            result_count += 1

            # Check if we've reached max frames
            if MAX_FRAMES > 0 and result_count >= MAX_FRAMES:
                print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
                break

            # OpenCV visualization (only if enabled)
            if ENABLE_DISPLAY:
                # Get latest frame from decoder (as CPU numpy array)
                frame = connection.decoder.get_latest_frame_cpu(rgb=True)
                if frame is not None:
                    # Convert to BGR for OpenCV
                    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                    # Draw tracked objects
                    for obj in result.tracked_objects:
                        # Get bbox coordinates
                        x1, y1, x2, y2 = map(int, obj.bbox)

                        # Draw bounding box
                        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)

                        # Draw track ID and class name
                        label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
                        label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

                        # Draw label background
                        cv2.rectangle(frame_bgr, (x1, y1 - label_size[1] - 10),
                                      (x1 + label_size[0], y1), (0, 255, 0), -1)

                        # Draw label text
                        cv2.putText(frame_bgr, label, (x1, y1 - 5),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

                    # Draw FPS and object count
                    elapsed = time.time() - start_time
                    fps = result_count / elapsed if elapsed > 0 else 0
                    info_text = f"FPS: {fps:.1f} | Objects: {len(result.tracked_objects)} | Frame: {result_count}"
                    cv2.putText(frame_bgr, info_text, (10, 30),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                    # Display frame
                    cv2.imshow("Object Tracking", frame_bgr)

                    # Check for 'q' key to quit
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        print("\n✓ Quit by user (pressed 'q')")
                        break

            # Print stats every 30 results
            if result_count % 30 == 0:
                elapsed = time.time() - start_time
                fps = result_count / elapsed if elapsed > 0 else 0
                print(f"\nResults: {result_count} | FPS: {fps:.1f}")
                print(f"  Stream: {result.stream_id}")
                print(f"  Objects: {len(result.tracked_objects)}")
                if result.tracked_objects:
                    class_counts = {}
                    for obj in result.tracked_objects:
                        class_counts[obj.class_name] = class_counts.get(obj.class_name, 0) + 1
                    print(f"  Classes: {class_counts}")
    except KeyboardInterrupt:
        print("\n✓ Interrupted by user")
    # Cleanup
    print(f"\n{'=' * 80}")
    print("Cleanup")
    print(f"{'=' * 80}")

    # Close OpenCV window if it was opened
    if ENABLE_DISPLAY:
        cv2.destroyAllWindows()

    connection.stop()
    manager.shutdown()
    print("✓ Stopped")

    # Final stats
    elapsed = time.time() - start_time
    avg_fps = result_count / elapsed if elapsed > 0 else 0
    print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
def main_multi_stream():
    """Multi-stream example with batched inference."""
    print("=" * 80)
    print("Event-Driven GPU-Accelerated Object Tracking - Multi-Stream")
    print("=" * 80)

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"  # Transparent loading: .pt, .engine, or .trt
    BATCH_SIZE = 16
    FORCE_TIMEOUT = 0.05

    # Load camera URLs
    camera_urls = []
    i = 1
    while True:
        url = os.getenv(f'CAMERA_URL_{i}')
        if url:
            camera_urls.append((f"camera_{i}", url))
            i += 1
        else:
            break

    if not camera_urls:
        print("No camera URLs found in .env")
        return

    print("\nConfiguration:")
    print(f"  GPU: {GPU_ID}")
    print(f"  Model: {MODEL_PATH}")
    print(f"  Streams: {len(camera_urls)}")
    print(f"  Batch size: {BATCH_SIZE}\n")
    # Create manager with PT conversion
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True,
    )
    print("✓ Manager created")

    # Initialize model (transparent loading)
    print("\n[2/3] Initializing model...")
    try:
        manager.initialize(
            model_path=MODEL_PATH,
            model_id="detector",
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=8,
            # Note: no pt_input_shapes or pt_precision needed for YOLO models!
        )
        print("✓ Manager initialized")
    except Exception as e:
        print(f"✗ Failed to initialize: {e}")
        import traceback
        traceback.print_exc()
        return
    # Connect all streams
    print(f"\n[3/3] Connecting {len(camera_urls)} streams...")
    connections = {}
    for stream_id, rtsp_url in camera_urls:
        try:
            conn = manager.connect_stream(
                rtsp_url=rtsp_url,
                stream_id=stream_id,
                buffer_size=5,
            )
            connections[stream_id] = conn
            print(f"✓ Connected: {stream_id}")
        except Exception as e:
            print(f"✗ Failed {stream_id}: {e}")

    if not connections:
        print("No streams connected")
        return
print(f"\n{'=' * 80}")
print(f"Multi-stream tracking running ({len(connections)} streams)")
print("Frames from all streams are batched together!")
print("Press Ctrl+C to stop")
print(f"{'=' * 80}\n")
# Track stats
stream_stats = {sid: {'count': 0, 'start': time.time()} for sid in connections.keys()}
total_results = 0
start_time = time.time()
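    # Design note: each connection exposes its own result_queue, so the loop below
    # round-robins over them with non-blocking reads instead of blocking on any one
    # stream; a short sleep on idle passes avoids a busy loop. A blocking fan-in
    # alternative is sketched in merged_results() below.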
    try:
        # Merge all result queues from all connections
        import queue as queue_module
        running = True
        while running:
            # Poll all connection queues (non-blocking)
            got_result = False
            for conn in connections.values():
                try:
                    # Non-blocking get from each connection's queue
                    result = conn.result_queue.get_nowait()
                    got_result = True
                    total_results += 1

                    stream_id = result.stream_id
                    if stream_id in stream_stats:
                        stream_stats[stream_id]['count'] += 1

                    # Print stats every 100 results
                    if total_results % 100 == 0:
                        elapsed = time.time() - start_time
                        total_fps = total_results / elapsed if elapsed > 0 else 0
                        print(f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS")
                        for sid, stats in stream_stats.items():
                            s_elapsed = time.time() - stats['start']
                            s_fps = stats['count'] / s_elapsed if s_elapsed > 0 else 0
                            print(f"  {sid}: {stats['count']} ({s_fps:.1f} FPS)")
                except queue_module.Empty:
                    continue

            # Small sleep if no results to avoid a busy loop
            if not got_result:
                time.sleep(0.01)
    except KeyboardInterrupt:
        print("\n✓ Interrupted")
    # Cleanup
    print(f"\n{'=' * 80}")
    print("Cleanup")
    print(f"{'=' * 80}")
    for conn in connections.values():
        conn.stop()
    manager.shutdown()
    print("✓ Stopped")

    # Final stats
    elapsed = time.time() - start_time
    avg_fps = total_results / elapsed if elapsed > 0 else 0
    print(f"\nFinal: {total_results} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "single":
        main_single_stream()
    else:
        main_multi_stream()
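# Usage:
#   python test_tracking_realtime.py          # multi-stream (all CAMERA_URL_* entries)
#   python test_tracking_realtime.py single   # single stream (CAMERA_URL_1)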