add models and update tracking system
parent fd470b3765
commit 2b0cfc4b72
9 changed files with 780 additions and 478 deletions
@@ -1,22 +1,21 @@
"""
|
||||
Real-time object tracking visualization with OpenCV.
|
||||
Real-time object tracking with event-driven batching architecture.
|
||||
|
||||
This script demonstrates:
|
||||
- GPU-accelerated decoding and tracking
|
||||
- CPU-side visualization with bounding boxes and track IDs
|
||||
- Real-time display using OpenCV
|
||||
- FPS monitoring and performance metrics
|
||||
- Event-driven stream processing with StreamConnectionManager
|
||||
- Batched GPU inference with ModelController
|
||||
- Ping-pong buffer architecture for optimal throughput
|
||||
- Async/await pattern for multiple RTSP streams
|
||||
- Automatic PT to TensorRT conversion
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from dotenv import load_dotenv
|
||||
from services import (
|
||||
StreamDecoderFactory,
|
||||
TensorRTModelRepository,
|
||||
TrackingFactory,
|
||||
StreamConnectionManager,
|
||||
YOLOv8Utils,
|
||||
COCO_CLASSES,
|
||||
)
|
||||
|
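# The "ping-pong buffer" bullet above refers to two buffers that alternate
# roles: the decoder fills one while inference drains the other, then they
# swap, so neither side blocks the other. A minimal illustrative sketch of
# that idea (an assumption about the intent, not the services implementation;
# locking and CUDA streams are omitted):
class PingPongBufferSketch:
    def __init__(self):
        self._buffers = ([], [])
        self._fill_idx = 0

    def push(self, frame):
        # Producer side: decoded frames land in the currently-filling slot
        self._buffers[self._fill_idx].append(frame)

    def swap(self):
        # Consumer side: take the filled slot for batched inference and
        # redirect the producer to the other slot
        ready = self._buffers[self._fill_idx]
        self._fill_idx ^= 1
        self._buffers[self._fill_idx].clear()
        return ready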
@@ -25,513 +24,253 @@ from services import (
load_dotenv()


def draw_tracking_overlay(frame: np.ndarray, tracked_objects, frame_info: dict) -> np.ndarray:
    """
    Draw bounding boxes, labels, and tracking info on frame.

    Args:
        frame: Frame in (H, W, 3) RGB format
        tracked_objects: List of TrackedObject instances
        frame_info: Dict with frame count, FPS, etc.

    Returns:
        Frame with overlays drawn
    """
    # Convert RGB to BGR for OpenCV
    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    # Get frame dimensions
    frame_height, frame_width = frame.shape[:2]

    # Filter tracked objects to only show person and car
    filtered_objects = [obj for obj in tracked_objects if obj.class_name in ['person', 'car']]

    # Define colors for different track IDs (cycling through colors)
    colors = [
        (0, 255, 0),    # Green
        (255, 0, 0),    # Blue
        (0, 0, 255),    # Red
        (255, 255, 0),  # Cyan
        (255, 0, 255),  # Magenta
        (0, 255, 255),  # Yellow
        (128, 255, 0),  # Light green
        (255, 128, 0),  # Orange
    ]

    # Draw each tracked object
    for obj in filtered_objects:
        # Get color based on track ID
        color = colors[obj.track_id % len(colors)]

        # Extract bounding box coordinates
        # Boxes come from YOLOv8 in 640x640 space and must be scaled to the frame size
        x1, y1, x2, y2 = obj.bbox

        # Scale from 640x640 model space to the actual frame size
        # (the display frame may be any resolution, e.g. 1280x720)
        scale_x = frame_width / 640.0
        scale_y = frame_height / 640.0

        x1 = int(x1 * scale_x)
        y1 = int(y1 * scale_y)
        x2 = int(x2 * scale_x)
        y2 = int(y2 * scale_y)

        # Draw bounding box
        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), color, 2)

        # Prepare label text
        label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"

        # Get text size for background rectangle
        (text_width, text_height), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )

        # Draw label background
        cv2.rectangle(
            frame_bgr,
            (x1, y1 - text_height - baseline - 5),
            (x1 + text_width, y1),
            color,
            -1  # Filled
        )

        # Draw label text
        cv2.putText(
            frame_bgr,
            label,
            (x1, y1 - baseline - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),  # Black text
            1,
            cv2.LINE_AA
        )

        # Draw track history if available (trajectory)
        if hasattr(obj, 'history') and len(obj.history) > 1:
            points = []
            for hist_bbox in obj.history[-10:]:  # Last 10 positions
                # Get center point of historical bbox (in 640x640 space)
                hx1, hy1, hx2, hy2 = hist_bbox

                # Scale from 640x640 to frame size
                cx = int(((hx1 + hx2) / 2) * scale_x)
                cy = int(((hy1 + hy2) / 2) * scale_y)
                points.append((cx, cy))

            # Draw trajectory line
            for i in range(1, len(points)):
                cv2.line(frame_bgr, points[i-1], points[i], color, 2)

    # Draw info panel at top
    info_bg_height = 80
    overlay = frame_bgr.copy()
    cv2.rectangle(overlay, (0, 0), (frame_bgr.shape[1], info_bg_height), (0, 0, 0), -1)
    cv2.addWeighted(overlay, 0.5, frame_bgr, 0.5, 0, frame_bgr)

    # Draw statistics text
    y_offset = 25
    cv2.putText(
        frame_bgr,
        f"Frame: {frame_info.get('frame_count', 0)} | FPS: {frame_info.get('fps', 0):.1f}",
        (10, y_offset),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.6,
        (255, 255, 255),
        2,
        cv2.LINE_AA
    )

    y_offset += 25
    # Count persons and cars
    person_count = sum(1 for obj in filtered_objects if obj.class_name == 'person')
    car_count = sum(1 for obj in filtered_objects if obj.class_name == 'car')
    cv2.putText(
        frame_bgr,
        f"Persons: {person_count} | Cars: {car_count} | Total Visible: {len(filtered_objects)}",
        (10, y_offset),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.6,
        (255, 255, 255),
        2,
        cv2.LINE_AA
    )

    return frame_bgr
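# Quick smoke test for draw_tracking_overlay. The stub below is hypothetical
# scaffolding, not a real TrackedObject; it only assumes the attributes the
# function actually reads (track_id, class_name, confidence, bbox in 640x640
# model space, optional history):
def _overlay_smoke_test():
    from types import SimpleNamespace
    frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # synthetic 720p RGB frame
    obj = SimpleNamespace(
        track_id=1, class_name='person', confidence=0.91,
        bbox=(100, 100, 200, 300),  # 640x640 model space
        history=[(90, 95, 190, 295), (100, 100, 200, 300)],
    )
    out = draw_tracking_overlay(frame, [obj], {'frame_count': 1, 'fps': 30.0})
    cv2.imwrite("overlay_smoke_test.jpg", out)  # output is already BGR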


def main():
    """
    Main function for real-time tracking visualization.
    """
    import torch


async def main_single_stream():
    """Single stream example with event-driven architecture."""
    print("=" * 80)
    print("Event-Driven GPU-Accelerated Object Tracking - Single Stream")
    print("=" * 80)

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "models/yolov8n.pt"  # Changed to PT file
    RTSP_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BUFFER_SIZE = 30
    WINDOW_NAME = "Real-time Object Tracking"
    MODEL_PATH = "models/yolov8n.pt"  # PT file will be auto-converted
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BATCH_SIZE = 4
    FORCE_TIMEOUT = 0.05

    print("=" * 80)
    print("Real-time GPU-Accelerated Object Tracking")
    print("=" * 80)
    print(f"\nConfiguration:")
    print(f"  GPU: {GPU_ID}")
    print(f"  Model: {MODEL_PATH}")
    print(f"  Stream: {STREAM_URL}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Force timeout: {FORCE_TIMEOUT}s\n")

    # Step 1: Create model repository with PT conversion enabled
    print("\n[1/4] Initializing TensorRT Model Repository...")
    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4, enable_pt_conversion=True)

    # Load detection model (will auto-convert PT to TRT)
    model_id = "yolov8_detector"
    if os.path.exists(MODEL_PATH):
        try:
            print(f"Loading model from {MODEL_PATH}...")
            print("Note: First load will convert PT to TensorRT (may take 3-5 minutes)")
            print("Subsequent loads will use cached TensorRT engine")

            metadata = model_repo.load_model(
                model_id=model_id,
                file_path=MODEL_PATH,
                num_contexts=4,
                pt_input_shapes={"images": (1, 3, 640, 640)},  # Required for PT conversion
                pt_precision=torch.float16  # Use FP16 for better performance
            )
            print(f"✓ Model loaded successfully")
            print(f"  Input shape: {metadata.input_shapes}")
            print(f"  Output shape: {metadata.output_shapes}")
        except Exception as e:
            print(f"✗ Failed to load model: {e}")
            print(f"  Please ensure {MODEL_PATH} exists")
            import traceback
            traceback.print_exc()
            return
    else:
        print(f"✗ Model file not found: {MODEL_PATH}")
        print(f"  Please provide a valid PyTorch (.pt) or TensorRT (.trt) model file")
        return

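# The two "Note:" lines above describe a convert-once, reuse-forever flow. A
# sketch of that caching pattern, with an assumed content-hash naming scheme
# (the real cache location and key live inside TensorRTModelRepository):
def _cached_engine_path(pt_path):
    import hashlib
    from pathlib import Path
    # Key the cache on file contents rather than mtime, so an identical
    # re-downloaded model still hits the cache
    digest = hashlib.sha256(Path(pt_path).read_bytes()).hexdigest()[:16]
    return Path(".trt_cache") / f"{Path(pt_path).stem}_{digest}.trt"
# First load: engine file missing, so export PT to TensorRT and serialize it
# there (the slow 3-5 minute step). Later loads: deserialize the .trt directly.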
    # Step 2: Create tracking controller
    print("\n[2/4] Creating TrackingController...")
    tracking_factory = TrackingFactory(gpu_id=GPU_ID)

    try:
        tracking_controller = tracking_factory.create_controller(
            model_repository=model_repo,
            model_id=model_id,
            tracker_type="iou",
            max_age=30,
            min_confidence=0.5,
            iou_threshold=0.3,
            class_names=COCO_CLASSES
        )
        print(f"✓ Controller created: {tracking_controller}")
    except Exception as e:
        print(f"✗ Failed to create controller: {e}")
        return

    # Step 3: Create stream decoder
    print("\n[3/4] Creating RTSP Stream Decoder...")
    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoder = stream_factory.create_decoder(
        rtsp_url=RTSP_URL,
        buffer_size=BUFFER_SIZE
    # Create StreamConnectionManager with PT conversion enabled
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True  # Enable PT conversion
    )
    decoder.start()
    print(f"✓ Decoder started for: {RTSP_URL}")
    print(f"  Waiting for connection...")

    # Wait for stream connection
    print("  Waiting up to 15 seconds for connection...")
    connected = False
    for i in range(15):
        time.sleep(1)
        if decoder.is_connected():
            connected = True
            break
        print(f"  Waiting... {i+1}/15 seconds (status: {decoder.get_status().value})")

    if connected:
        print(f"✓ Stream connected!")
    else:
        print(f"✗ Stream not connected after 15 seconds (status: {decoder.get_status().value})")
        print(f"  Proceeding anyway - will start displaying when frames arrive...")
        # Don't exit - continue and wait for frames

    # Step 4: Create OpenCV window
    print("\n[4/4] Starting Real-time Visualization...")
    cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(WINDOW_NAME, 1280, 720)

    print(f"\n{'=' * 80}")
    print("Real-time tracking started!")
    print("Press 'q' to quit | Press 's' to save screenshot")
    print(f"{'=' * 80}\n")

    # FPS tracking
    fps_start_time = time.time()
    fps_frame_count = 0
    current_fps = 0.0

    frame_count = 0
    screenshot_count = 0
    print("✓ Manager created")

    # Initialize with PT model (auto-conversion)
    print("\n[2/3] Initializing with PT model...")
    print("Note: First load will convert PT to TensorRT (3-5 minutes)")
    print("Subsequent loads will use cached TensorRT engine\n")

    try:
        while True:
            # Get frame from decoder (CPU memory for OpenCV)
            frame_cpu = decoder.get_frame_cpu(index=-1, rgb=True)

            if frame_cpu is None:
                time.sleep(0.01)
                continue

            # Get GPU frame for tracking
            frame_gpu = decoder.get_latest_frame(rgb=True)

            if frame_gpu is None:
                time.sleep(0.01)
                continue

            frame_count += 1
            fps_frame_count += 1

            # Run tracking on GPU frame with YOLOv8 pre/postprocessing
            tracked_objects = tracking_controller.track(
                frame_gpu,
                preprocess_fn=YOLOv8Utils.preprocess,
                postprocess_fn=YOLOv8Utils.postprocess
            )

            # Calculate FPS every second
            elapsed = time.time() - fps_start_time
            if elapsed >= 1.0:
                current_fps = fps_frame_count / elapsed
                fps_frame_count = 0
                fps_start_time = time.time()

            # Get tracking statistics
            stats = tracking_controller.get_statistics()

            # Prepare frame info for overlay
            frame_info = {
                'frame_count': frame_count,
                'fps': current_fps,
                'total_tracks': stats['total_tracks_created'],
                'class_counts': stats['class_counts']
            }

            # Draw tracking overlay on CPU frame
            display_frame = draw_tracking_overlay(frame_cpu, tracked_objects, frame_info)

            # Display frame
            cv2.imshow(WINDOW_NAME, display_frame)

            # Handle keyboard input
            key = cv2.waitKey(1) & 0xFF

            if key == ord('q'):
                print("\n✓ Quit requested by user")
                break
            elif key == ord('s'):
                # Save screenshot
                screenshot_count += 1
                filename = f"screenshot_{screenshot_count:04d}.jpg"
                cv2.imwrite(filename, display_frame)
                print(f"✓ Screenshot saved: {filename}")

    except KeyboardInterrupt:
        print("\n✓ Interrupted by user")
        await manager.initialize(
            model_path=MODEL_PATH,
            model_id="detector",
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=4,
            pt_input_shapes={"images": (1, 3, 640, 640)},
            pt_precision=torch.float16
        )
        print("✓ Manager initialized (PT converted to TensorRT)")
    except Exception as e:
        print(f"\n✗ Error during tracking: {e}")
        print(f"✗ Failed to initialize: {e}")
        import traceback
        traceback.print_exc()
        return

    # Connect stream
    print("\n[3/3] Connecting to stream...")
    try:
        connection = await manager.connect_stream(
            rtsp_url=STREAM_URL,
            stream_id="camera_1",
            buffer_size=30
        )
        print(f"✓ Stream connected: camera_1")
    except Exception as e:
        print(f"✗ Failed to connect stream: {e}")
        return

    print(f"\n{'=' * 80}")
    print("Event-driven tracking is running!")
    print("Press Ctrl+C to stop")
    print(f"{'=' * 80}\n")

    # Stream results
    result_count = 0
    start_time = time.time()

    try:
        async for result in connection.tracking_results():
            result_count += 1

            # Print stats every 30 results
            if result_count % 30 == 0:
                elapsed = time.time() - start_time
                fps = result_count / elapsed if elapsed > 0 else 0

                print(f"\nResults: {result_count} | FPS: {fps:.1f}")
                print(f"  Stream: {result.stream_id}")
                print(f"  Objects: {len(result.tracked_objects)}")

                if result.tracked_objects:
                    class_counts = {}
                    for obj in result.tracked_objects:
                        class_counts[obj.class_name] = class_counts.get(obj.class_name, 0) + 1
                    print(f"  Classes: {class_counts}")

    except KeyboardInterrupt:
        print(f"\n✓ Interrupted by user")

    # Cleanup
    print("\n" + "=" * 80)
    print(f"\n{'=' * 80}")
    print("Cleanup")
    print(f"{'=' * 80}")

    await connection.stop()
    await manager.shutdown()
    print("✓ Stopped")

    # Final stats
    elapsed = time.time() - start_time
    avg_fps = result_count / elapsed if elapsed > 0 else 0
    print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")

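# What batch_size=4 with force_timeout=0.05 means operationally: wait for up
# to 4 frames, but never hold a partial batch longer than 50 ms. A standalone
# sketch of that flush policy (an illustration of the semantics only, not the
# ModelController implementation):
async def _batching_policy_sketch(frame_queue, batch_size=4, force_timeout=0.05):
    loop = asyncio.get_running_loop()
    while True:
        batch = [await frame_queue.get()]  # block until the first frame arrives
        deadline = loop.time() + force_timeout
        while len(batch) < batch_size:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break  # timeout expired: force-flush the partial batch
            try:
                batch.append(await asyncio.wait_for(frame_queue.get(), remaining))
            except asyncio.TimeoutError:
                break
        yield batch  # a full or force-flushed batch, ready for GPU inference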

async def main_multi_stream():
    """Multi-stream example with batched inference."""
    print("=" * 80)
    print("Event-Driven GPU-Accelerated Object Tracking - Multi-Stream")
    print("=" * 80)

    # Print final statistics
    print("\nFinal Tracking Statistics:")
    stats = tracking_controller.get_statistics()
    for key, value in stats.items():
        print(f"  {key}: {value}")

    # Close OpenCV window
    cv2.destroyAllWindows()

    # Stop decoder
    print("\nStopping decoder...")
    decoder.stop()
    print("✓ Decoder stopped")

    print("\n" + "=" * 80)
    print("Real-time tracking completed!")
    print("=" * 80)


def main_multi_window():
    """
    Example: Display multiple camera streams in separate windows.

    This demonstrates tracking on multiple RTSP streams simultaneously
    with separate OpenCV windows for each stream.
    """
    # Configuration
    GPU_ID = 0
    MODEL_PATH = "models/yolov8n.pt"
    MODEL_PATH = "models/yolov8n.pt"  # PT file will be auto-converted
    BATCH_SIZE = 16
    FORCE_TIMEOUT = 0.05

    # Load camera URLs from environment
    # Load camera URLs
    camera_urls = []
    i = 1
    while True:
        url = os.getenv(f'CAMERA_URL_{i}')
        if url:
            camera_urls.append(url)
            camera_urls.append((f"camera_{i}", url))
            i += 1
        else:
            break

    if not camera_urls:
        print("No camera URLs found in .env file")
        print("No camera URLs found in .env")
        return

    print(f"Starting multi-window tracking with {len(camera_urls)} cameras")
    print(f"\nConfiguration:")
    print(f"  GPU: {GPU_ID}")
    print(f"  Model: {MODEL_PATH}")
    print(f"  Streams: {len(camera_urls)}")
    print(f"  Batch size: {BATCH_SIZE}\n")

    # Create shared model repository with PT conversion enabled
    import torch
    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8, enable_pt_conversion=True)
    # Create manager with PT conversion
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True
    )
    print("✓ Manager created")

    if os.path.exists(MODEL_PATH):
        print(f"Loading model from {MODEL_PATH}...")
        print("Note: First load will convert PT to TensorRT (may take 3-5 minutes)")
        print("Subsequent loads will use cached TensorRT engine")

        model_repo.load_model(
    # Initialize with PT model
    print("\n[2/3] Initializing with PT model...")
    try:
        await manager.initialize(
            model_path=MODEL_PATH,
            model_id="detector",
            file_path=MODEL_PATH,
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=8,
            pt_input_shapes={"images": (1, 3, 640, 640)},  # Required for PT conversion
            pt_precision=torch.float16  # Use FP16 for better performance
            pt_input_shapes={"images": (1, 3, 640, 640)},
            pt_precision=torch.float16
        )
        print("✓ Model loaded successfully")
    else:
        print(f"Model not found: {MODEL_PATH}")
        print("✓ Manager initialized")
    except Exception as e:
        print(f"✗ Failed to initialize: {e}")
        import traceback
        traceback.print_exc()
        return

    # Create tracking factory
    tracking_factory = TrackingFactory(gpu_id=GPU_ID)
    # Connect all streams
    print(f"\n[3/3] Connecting {len(camera_urls)} streams...")
    connections = {}
    for stream_id, rtsp_url in camera_urls:
        try:
            conn = await manager.connect_stream(
                rtsp_url=rtsp_url,
                stream_id=stream_id,
                buffer_size=30
            )
            connections[stream_id] = conn
            print(f"✓ Connected: {stream_id}")
        except Exception as e:
            print(f"✗ Failed {stream_id}: {e}")

    # Create decoders and controllers
    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoders = []
    controllers = []
    window_names = []
    if not connections:
        print("No streams connected")
        return

    for i, url in enumerate(camera_urls):
        # Create decoder
        decoder = stream_factory.create_decoder(url, buffer_size=30)
        decoder.start()
        decoders.append(decoder)
    print(f"\n{'=' * 80}")
    print(f"Multi-stream tracking running ({len(connections)} streams)")
    print("Frames from all streams are batched together!")
    print("Press Ctrl+C to stop")
    print(f"{'=' * 80}\n")

        # Create tracking controller
        controller = tracking_factory.create_controller(
            model_repository=model_repo,
            model_id="detector",
            tracker_type="iou",
            max_age=30,
            min_confidence=0.5,
            iou_threshold=0.3,
            class_names=COCO_CLASSES
        )
        controllers.append(controller)

        # Create window
        window_name = f"Camera {i+1}"
        window_names.append(window_name)
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(window_name, 640, 480)

        print(f"Camera {i+1}: {url}")

    print("\nWaiting for streams to connect...")
    time.sleep(10)

    print("\nPress 'q' to quit")

    # FPS tracking for each stream
    fps_data = [{'start': time.time(), 'count': 0, 'fps': 0.0} for _ in camera_urls]
    frame_counts = [0] * len(camera_urls)
    # Track stats
    stream_stats = {sid: {'count': 0, 'start': time.time()} for sid in connections.keys()}
    total_results = 0
    start_time = time.time()

    try:
        while True:
            for i, (decoder, controller, window_name) in enumerate(zip(decoders, controllers, window_names)):
                # Get frames
                frame_cpu = decoder.get_frame_cpu(index=-1, rgb=True)
                frame_gpu = decoder.get_latest_frame(rgb=True)
        # Simple approach: iterate over first connection's results
        # In production, you'd properly merge all result streams
        # (see the fan-in sketch after this function)
        for conn in connections.values():
            async for result in conn.tracking_results():
                total_results += 1
                stream_id = result.stream_id

                if frame_cpu is None or frame_gpu is None:
                    continue
                if stream_id in stream_stats:
                    stream_stats[stream_id]['count'] += 1

                frame_counts[i] += 1
                fps_data[i]['count'] += 1
                # Print stats every 100 results
                if total_results % 100 == 0:
                    elapsed = time.time() - start_time
                    total_fps = total_results / elapsed if elapsed > 0 else 0

                # Calculate FPS
                elapsed = time.time() - fps_data[i]['start']
                if elapsed >= 1.0:
                    fps_data[i]['fps'] = fps_data[i]['count'] / elapsed
                    fps_data[i]['count'] = 0
                    fps_data[i]['start'] = time.time()

                # Track objects with YOLOv8 pre/postprocessing
                tracked_objects = controller.track(
                    frame_gpu,
                    preprocess_fn=YOLOv8Utils.preprocess,
                    postprocess_fn=YOLOv8Utils.postprocess
                )

                # Get statistics
                stats = controller.get_statistics()

                # Prepare frame info
                frame_info = {
                    'frame_count': frame_counts[i],
                    'fps': fps_data[i]['fps'],
                    'total_tracks': stats['total_tracks_created'],
                    'class_counts': stats['class_counts']
                }

                # Draw overlay and display
                display_frame = draw_tracking_overlay(frame_cpu, tracked_objects, frame_info)
                cv2.imshow(window_name, display_frame)

                # Check for quit
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
                    print(f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS")
                    for sid, stats in stream_stats.items():
                        s_elapsed = time.time() - stats['start']
                        s_fps = stats['count'] / s_elapsed if s_elapsed > 0 else 0
                        print(f"  {sid}: {stats['count']} ({s_fps:.1f} FPS)")

    except KeyboardInterrupt:
        print("\nInterrupted by user")
        print(f"\n✓ Interrupted")

    # Cleanup
    print("\nCleaning up...")
    cv2.destroyAllWindows()
    print(f"\n{'=' * 80}")
    print("Cleanup")
    print(f"{'=' * 80}")

    for decoder in decoders:
        decoder.stop()
    for conn in connections.values():
        await conn.stop()
    await manager.shutdown()
    print("✓ Stopped")

    print("\nFinal Statistics:")
    for i, controller in enumerate(controllers):
        stats = controller.get_statistics()
        print(f"\nCamera {i+1}:")
        print(f"  Frames: {stats['frame_count']}")
        print(f"  Tracks created: {stats['total_tracks_created']}")
        print(f"  Active tracks: {stats['active_tracks']}")
    # Final stats
    elapsed = time.time() - start_time
    avg_fps = total_results / elapsed if elapsed > 0 else 0
    print(f"\nFinal: {total_results} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
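# As noted above, looping over connections.values() in sequence only ever
# drains the first connection's iterator. A fan-in sketch that merges every
# stream's results as they arrive, assuming nothing beyond the async-iterator
# protocol on tracking_results():
async def _merge_tracking_results(connections):
    queue = asyncio.Queue()

    async def _pump(conn):
        async for result in conn.tracking_results():
            await queue.put(result)

    tasks = [asyncio.create_task(_pump(conn)) for conn in connections.values()]
    try:
        while True:
            yield await queue.get()  # results interleave across all streams
    finally:
        for task in tasks:
            task.cancel()
# Usage: async for result in _merge_tracking_results(connections): ...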


if __name__ == "__main__":
    # Run single camera visualization
    # main()

    # Uncomment to run multi-window visualization
    main_multi_window()
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == "single":
        asyncio.run(main_single_stream())
    else:
        asyncio.run(main_multi_stream())