yolo util class

commit 81bbb0074e (parent bea895d3d8)
5 changed files with 1058 additions and 1 deletion
.gitignore (vendored)
@@ -3,4 +3,5 @@ __pycache__/
 *.pyc
 .env
 .claude
 models/
+/tracked_objects.json
services/__init__.py
@@ -7,6 +7,7 @@ from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg
 from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionContext, SharedEngine
 from .tracking_controller import TrackingController, TrackedObject
 from .tracking_factory import TrackingFactory
+from .yolo import YOLOv8Utils, COCO_CLASSES

 __all__ = [
     'StreamDecoderFactory',
@@ -21,4 +22,6 @@ __all__ = [
     'TrackingController',
     'TrackedObject',
     'TrackingFactory',
+    'YOLOv8Utils',
+    'COCO_CLASSES',
 ]
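The two hunks above re-export the new utilities at package level. A quick smoke test of the new surface (a hypothetical check, not part of the commit):

# Hypothetical smoke test for the re-exports added above (not in the commit).
from services import YOLOv8Utils, COCO_CLASSES

assert callable(YOLOv8Utils.preprocess) and callable(YOLOv8Utils.postprocess)
assert len(COCO_CLASSES) == 80 and COCO_CLASSES[0] == 'person'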
services/yolo.py (new file, 198 lines)
@@ -0,0 +1,198 @@
"""
YOLOv8 Model Utilities

This module provides static utility functions for YOLOv8 model preprocessing
and postprocessing, compatible with TensorRT inference.

Features:
- Preprocessing: Resize and normalize frames for YOLOv8 inference
- Postprocessing: Parse YOLOv8 output format to detection boxes
- Format conversion: (cx, cy, w, h) to (x1, y1, x2, y2)
- Confidence filtering and NMS applied during postprocessing

Usage:
    from services.yolo import YOLOv8Utils

    # Preprocess frame
    model_input = YOLOv8Utils.preprocess(frame_gpu, input_size=640)

    # Run inference
    outputs = model_repo.infer(model_id="yolov8", inputs={"images": model_input})

    # Postprocess detections
    detections = YOLOv8Utils.postprocess(outputs, conf_threshold=0.25, nms_threshold=0.45)
"""

import torch
from typing import Tuple


class YOLOv8Utils:
    """Static utility class for YOLOv8 model operations."""

    @staticmethod
    def preprocess(
        frame: torch.Tensor,
        input_size: int = 640
    ) -> torch.Tensor:
        """
        Preprocess frame for YOLOv8 inference.

        Args:
            frame: RGB frame as GPU tensor, shape (3, H, W) uint8
            input_size: Model input size (default: 640 for YOLOv8)

        Returns:
            Preprocessed frame ready for the model, shape (1, 3, input_size, input_size) float32

        Example:
            >>> frame_gpu = decoder.get_latest_frame(rgb=True)  # (3, 720, 1280)
            >>> model_input = YOLOv8Utils.preprocess(frame_gpu)  # (1, 3, 640, 640)
        """
        # Add batch dimension and convert to float
        frame_batch = frame.unsqueeze(0).float()  # (1, 3, H, W)

        # Resize to model input size
        frame_resized = torch.nn.functional.interpolate(
            frame_batch,
            size=(input_size, input_size),
            mode='bilinear',
            align_corners=False
        )

        # Normalize to [0, 1] (YOLOv8 expects normalized input)
        frame_normalized = frame_resized / 255.0

        return frame_normalized

    @staticmethod
    def postprocess(
        outputs: dict,
        conf_threshold: float = 0.25,
        nms_threshold: float = 0.45
    ) -> torch.Tensor:
        """
        Postprocess YOLOv8 TensorRT output to detection format.

        YOLOv8 output format: (1, 84, 8400)
        - 84 channels = 4 bbox coords (cx, cy, w, h) + 80 class scores
        - 8400 anchor points

        Args:
            outputs: Dictionary of model outputs from TensorRT inference
            conf_threshold: Confidence threshold for filtering detections (default: 0.25)
            nms_threshold: IoU threshold for Non-Maximum Suppression (default: 0.45)

        Returns:
            Tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id]
            - Coordinates are in model input space (0-640 for default YOLOv8)
            - N is the number of detections after NMS

        Example:
            >>> outputs = model_repo.infer(model_id="yolov8", inputs={"images": frame})
            >>> detections = YOLOv8Utils.postprocess(outputs, conf_threshold=0.5)
            >>> # detections: [[x1, y1, x2, y2, conf, class_id], ...]
        """
        from torchvision.ops import nms

        # Get output tensor (first and only output)
        output_name = list(outputs.keys())[0]
        output = outputs[output_name]  # (1, 84, 8400)

        # Transpose to (1, 8400, 84) for easier processing
        output = output.transpose(1, 2)

        # Process first batch (batch size is always 1 for single-image inference)
        detections = []
        for detection in output[0]:  # Iterate over 8400 anchor points
            # Split bbox coordinates and class scores
            bbox = detection[:4]  # (cx, cy, w, h)
            class_scores = detection[4:]  # 80 class scores

            # Get max class score and corresponding class ID
            max_score, class_id = torch.max(class_scores, 0)

            # Filter by confidence threshold
            if max_score > conf_threshold:
                # Convert from (cx, cy, w, h) to (x1, y1, x2, y2)
                cx, cy, w, h = bbox
                x1 = cx - w / 2
                y1 = cy - h / 2
                x2 = cx + w / 2
                y2 = cy + h / 2

                # Append detection: [x1, y1, x2, y2, conf, class_id]
                detections.append([
                    x1.item(), y1.item(), x2.item(), y2.item(),
                    max_score.item(), class_id.item()
                ])

        # Return empty tensor if no detections
        if not detections:
            return torch.zeros((0, 6), device=output.device)

        # Convert list to tensor
        detections_tensor = torch.tensor(detections, device=output.device)

        # Apply Non-Maximum Suppression (NMS)
        boxes = detections_tensor[:, :4]  # (N, 4)
        scores = detections_tensor[:, 4]  # (N,)

        # NMS returns indices of boxes to keep
        keep_indices = nms(boxes, scores, iou_threshold=nms_threshold)

        # Return filtered detections
        return detections_tensor[keep_indices]

    @staticmethod
    def scale_boxes(
        boxes: torch.Tensor,
        from_size: Tuple[int, int],
        to_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Scale bounding boxes from one coordinate space to another.

        Args:
            boxes: Tensor of boxes, shape (N, 4) in format [x1, y1, x2, y2]
            from_size: Source size (width, height), e.g. (640, 640) for model output
            to_size: Target size (width, height), e.g. (1280, 720) for display

        Returns:
            Scaled boxes tensor, same shape as input

        Example:
            >>> detections = YOLOv8Utils.postprocess(outputs)  # boxes in 640x640 space
            >>> boxes = detections[:, :4]  # Extract boxes
            >>> scaled_boxes = YOLOv8Utils.scale_boxes(boxes, (640, 640), (1280, 720))
        """
        scale_x = to_size[0] / from_size[0]
        scale_y = to_size[1] / from_size[1]

        # Clone to avoid modifying the original
        scaled = boxes.clone()
        scaled[:, [0, 2]] *= scale_x  # Scale x coordinates
        scaled[:, [1, 3]] *= scale_y  # Scale y coordinates

        return scaled


# COCO class names for YOLOv8 (80 classes)
COCO_CLASSES = {
    0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
    5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
    10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
    14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
    20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
    25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
    30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
    35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket',
    39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon',
    45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange',
    50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut',
    55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed',
    60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse',
    65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven',
    70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock',
    75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'
}
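Note that preprocess above stretches the frame to a square, so a 1280x720 input is distorted. A letterboxing sketch that keeps aspect ratio instead — an assumption-laden alternative, not part of the commit: the exported .trt engine would need to expect letterboxed input, and output boxes would need the inverse un-padding.

# Hypothetical letterbox variant of preprocess (not in the commit).
import torch
import torch.nn.functional as F

def preprocess_letterbox(frame: torch.Tensor, input_size: int = 640) -> torch.Tensor:
    """Resize keeping aspect ratio, pad the remainder with gray (114/255)."""
    _, h, w = frame.shape
    scale = input_size / max(h, w)
    new_h, new_w = int(round(h * scale)), int(round(w * scale))

    resized = F.interpolate(frame.unsqueeze(0).float(), size=(new_h, new_w),
                            mode='bilinear', align_corners=False) / 255.0

    # Pad right/bottom to reach (input_size, input_size)
    pad_w, pad_h = input_size - new_w, input_size - new_h
    return F.pad(resized, (0, pad_w, 0, pad_h), value=114 / 255.0)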
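The per-anchor Python loop in postprocess runs ~8400 interpreter iterations per frame and calls .item() per value, each a GPU-to-CPU sync. A minimal vectorized sketch with the same (N, 6) contract — an alternative, not the commit's implementation — assuming the same (1, 84, 8400) output layout:

# Hypothetical vectorized postprocess (not in the commit); same I/O contract.
import torch
from torchvision.ops import nms

def postprocess_vectorized(outputs: dict, conf_threshold: float = 0.25,
                           nms_threshold: float = 0.45) -> torch.Tensor:
    output = next(iter(outputs.values()))   # (1, 84, 8400)
    output = output.transpose(1, 2)[0]      # (8400, 84)

    bboxes = output[:, :4]                  # (cx, cy, w, h)
    scores, class_ids = output[:, 4:].max(dim=1)

    # Confidence filter as one boolean mask instead of a Python loop
    keep = scores > conf_threshold
    bboxes, scores, class_ids = bboxes[keep], scores[keep], class_ids[keep]
    if bboxes.numel() == 0:
        return torch.zeros((0, 6), device=output.device)

    # (cx, cy, w, h) -> (x1, y1, x2, y2) in one shot
    half = bboxes[:, 2:] / 2
    boxes = torch.cat([bboxes[:, :2] - half, bboxes[:, :2] + half], dim=1)

    keep_idx = nms(boxes, scores, iou_threshold=nms_threshold)
    return torch.cat([boxes, scores.unsqueeze(1),
                      class_ids.unsqueeze(1).float()], dim=1)[keep_idx]

Everything stays on the GPU until NMS, which avoids the per-detection synchronization of the loop-based version.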
test_fps_benchmark.py (new file, 340 lines)
@@ -0,0 +1,340 @@
"""
FPS Benchmark Test for Single vs Multi-Camera Tracking

This script benchmarks the FPS performance of:
1. Single camera tracking
2. Multi-camera tracking (2+ cameras)

Usage:
    python test_fps_benchmark.py
"""

import time
import os
from dotenv import load_dotenv
from services import (
    StreamDecoderFactory,
    TensorRTModelRepository,
    TrackingFactory,
    YOLOv8Utils,
    COCO_CLASSES,
)

load_dotenv()


def benchmark_single_camera(duration=30):
    """
    Benchmark single-camera tracking performance.

    Args:
        duration: Test duration in seconds

    Returns:
        Dictionary with FPS statistics
    """
    print("\n" + "=" * 80)
    print("SINGLE CAMERA BENCHMARK")
    print("=" * 80)

    GPU_ID = 0
    MODEL_PATH = "models/yolov8n.trt"
    RTSP_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')

    # Initialize components
    print("\nInitializing...")
    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4)
    model_repo.load_model("detector", MODEL_PATH, num_contexts=4)

    tracking_factory = TrackingFactory(gpu_id=GPU_ID)
    controller = tracking_factory.create_controller(
        model_repository=model_repo,
        model_id="detector",
        tracker_type="iou",
        max_age=30,
        min_confidence=0.5,
        iou_threshold=0.3,
        class_names=COCO_CLASSES
    )

    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoder = stream_factory.create_decoder(RTSP_URL, buffer_size=30)
    decoder.start()

    print("Waiting for stream connection...")
    time.sleep(5)

    if not decoder.is_connected():
        print("⚠ Stream not connected, results may be inaccurate")

    # Benchmark
    print(f"\nRunning benchmark for {duration} seconds...")
    frame_count = 0
    start_time = time.time()

    fps_samples = []
    sample_start = time.time()
    sample_frames = 0

    try:
        while time.time() - start_time < duration:
            frame_gpu = decoder.get_latest_frame(rgb=True)

            if frame_gpu is None:
                time.sleep(0.001)
                continue

            # Run tracking
            tracked_objects = controller.track(
                frame_gpu,
                preprocess_fn=YOLOv8Utils.preprocess,
                postprocess_fn=YOLOv8Utils.postprocess
            )

            frame_count += 1
            sample_frames += 1

            # Sample FPS every second
            if time.time() - sample_start >= 1.0:
                fps = sample_frames / (time.time() - sample_start)
                fps_samples.append(fps)
                sample_frames = 0
                sample_start = time.time()
                print(f"  Current FPS: {fps:.2f}")

    except KeyboardInterrupt:
        print("\nBenchmark interrupted")

    # Calculate statistics
    total_time = time.time() - start_time
    avg_fps = frame_count / total_time

    # Cleanup
    decoder.stop()

    stats = {
        'total_frames': frame_count,
        'total_time': total_time,
        'avg_fps': avg_fps,
        'min_fps': min(fps_samples) if fps_samples else 0,
        'max_fps': max(fps_samples) if fps_samples else 0,
        'samples': fps_samples
    }

    print("\n" + "-" * 80)
    print(f"Total Frames: {stats['total_frames']}")
    print(f"Total Time: {stats['total_time']:.2f} seconds")
    print(f"Average FPS: {stats['avg_fps']:.2f}")
    print(f"Min FPS: {stats['min_fps']:.2f}")
    print(f"Max FPS: {stats['max_fps']:.2f}")
    print("-" * 80)

    return stats


def benchmark_multi_camera(duration=30):
    """
    Benchmark multi-camera tracking performance.

    Args:
        duration: Test duration in seconds

    Returns:
        Dictionary with FPS statistics per camera
    """
    print("\n" + "=" * 80)
    print("MULTI-CAMERA BENCHMARK")
    print("=" * 80)

    GPU_ID = 0
    MODEL_PATH = "models/yolov8n.trt"

    # Load camera URLs (CAMERA_URL_1, CAMERA_URL_2, ... from .env)
    camera_urls = []
    i = 1
    while True:
        url = os.getenv(f'CAMERA_URL_{i}')
        if url:
            camera_urls.append(url)
            i += 1
        else:
            break

    if len(camera_urls) < 2:
        print("⚠ Need at least 2 cameras for multi-camera test")
        print(f"  Found only {len(camera_urls)} camera(s) in .env")
        return None

    print(f"\nTesting with {len(camera_urls)} cameras")

    # Initialize components
    print("\nInitializing...")
    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8)
    model_repo.load_model("detector", MODEL_PATH, num_contexts=8)

    tracking_factory = TrackingFactory(gpu_id=GPU_ID)
    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)

    decoders = []
    controllers = []

    for i, url in enumerate(camera_urls):
        # Create decoder
        decoder = stream_factory.create_decoder(url, buffer_size=30)
        decoder.start()
        decoders.append(decoder)

        # Create controller
        controller = tracking_factory.create_controller(
            model_repository=model_repo,
            model_id="detector",
            tracker_type="iou",
            max_age=30,
            min_confidence=0.5,
            iou_threshold=0.3,
            class_names=COCO_CLASSES
        )
        controllers.append(controller)

        print(f"  Camera {i+1}: {url}")

    print("\nWaiting for streams to connect...")
    time.sleep(10)

    # Benchmark
    print(f"\nRunning benchmark for {duration} seconds...")

    frame_counts = [0] * len(camera_urls)
    fps_samples = [[] for _ in camera_urls]
    sample_starts = [time.time()] * len(camera_urls)
    sample_frames = [0] * len(camera_urls)

    start_time = time.time()

    try:
        while time.time() - start_time < duration:
            for i, (decoder, controller) in enumerate(zip(decoders, controllers)):
                frame_gpu = decoder.get_latest_frame(rgb=True)

                if frame_gpu is None:
                    continue

                # Run tracking
                tracked_objects = controller.track(
                    frame_gpu,
                    preprocess_fn=YOLOv8Utils.preprocess,
                    postprocess_fn=YOLOv8Utils.postprocess
                )

                frame_counts[i] += 1
                sample_frames[i] += 1

                # Sample FPS every second
                if time.time() - sample_starts[i] >= 1.0:
                    fps = sample_frames[i] / (time.time() - sample_starts[i])
                    fps_samples[i].append(fps)
                    sample_frames[i] = 0
                    sample_starts[i] = time.time()

    except KeyboardInterrupt:
        print("\nBenchmark interrupted")

    # Calculate statistics
    total_time = time.time() - start_time

    # Cleanup
    for decoder in decoders:
        decoder.stop()

    # Compile results
    results = {}
    total_frames = 0

    print("\n" + "-" * 80)
    for i in range(len(camera_urls)):
        avg_fps = frame_counts[i] / total_time if total_time > 0 else 0
        total_frames += frame_counts[i]

        cam_stats = {
            'total_frames': frame_counts[i],
            'avg_fps': avg_fps,
            'min_fps': min(fps_samples[i]) if fps_samples[i] else 0,
            'max_fps': max(fps_samples[i]) if fps_samples[i] else 0,
        }

        results[f'camera_{i+1}'] = cam_stats

        print(f"Camera {i+1}:")
        print(f"  Total Frames: {cam_stats['total_frames']}")
        print(f"  Average FPS: {cam_stats['avg_fps']:.2f}")
        print(f"  Min FPS: {cam_stats['min_fps']:.2f}")
        print(f"  Max FPS: {cam_stats['max_fps']:.2f}")
        print()

    # Combined stats
    combined_avg_fps = total_frames / total_time if total_time > 0 else 0

    print("-" * 80)
    print("COMBINED:")
    print(f"  Total Frames (all cameras): {total_frames}")
    print(f"  Total Time: {total_time:.2f} seconds")
    print(f"  Combined Throughput: {combined_avg_fps:.2f} FPS")
    print(f"  Per-Camera Average: {combined_avg_fps / len(camera_urls):.2f} FPS")
    print("-" * 80)

    results['combined'] = {
        'total_frames': total_frames,
        'total_time': total_time,
        'combined_fps': combined_avg_fps,
        'per_camera_avg': combined_avg_fps / len(camera_urls)
    }

    return results


def main():
    """Run both benchmarks and compare."""
    print("=" * 80)
    print("FPS BENCHMARK: Single vs Multi-Camera Tracking")
    print("=" * 80)

    # Run single camera benchmark
    single_stats = benchmark_single_camera(duration=30)

    # Run multi-camera benchmark
    multi_stats = benchmark_multi_camera(duration=30)

    # Comparison
    if multi_stats:
        print("\n" + "=" * 80)
        print("COMPARISON")
        print("=" * 80)

        print("\nSingle Camera Performance:")
        print(f"  Average FPS: {single_stats['avg_fps']:.2f}")

        print("\nMulti-Camera Performance:")
        print(f"  Per-Camera Average: {multi_stats['combined']['per_camera_avg']:.2f} FPS")
        print(f"  Combined Throughput: {multi_stats['combined']['combined_fps']:.2f} FPS")

        # Calculate performance drop
        fps_drop = ((single_stats['avg_fps'] - multi_stats['combined']['per_camera_avg'])
                    / single_stats['avg_fps'] * 100)

        print("\nPerformance Analysis:")
        print(f"  FPS Drop per Camera: {fps_drop:.1f}%")

        if fps_drop < 10:
            print("  ✓ Excellent - Minimal performance impact")
        elif fps_drop < 25:
            print("  ✓ Good - Acceptable performance scaling")
        elif fps_drop < 50:
            print("  ⚠ Moderate - Some performance degradation")
        else:
            print("  ⚠ Significant - Consider optimizations")

        print("=" * 80)


if __name__ == "__main__":
    main()
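The multi-camera loop above polls every decoder round-robin on a single thread, so one stalled stream delays the rest. A per-camera-thread sketch, given the decoders and controllers lists built in benchmark_multi_camera — an assumption-laden alternative, not part of the commit; it presumes decoder.get_latest_frame and controller.track are safe to call from separate threads, which the commit does not state:

# Hypothetical threaded variant of the round-robin loop (not in the commit).
import threading
import time

def run_camera(decoder, controller, counts, idx, stop_event, duration=30):
    """Track one camera until the duration elapses or stop_event is set."""
    start = time.time()
    while not stop_event.is_set() and time.time() - start < duration:
        frame_gpu = decoder.get_latest_frame(rgb=True)
        if frame_gpu is None:
            time.sleep(0.001)
            continue
        controller.track(frame_gpu,
                         preprocess_fn=YOLOv8Utils.preprocess,
                         postprocess_fn=YOLOv8Utils.postprocess)
        counts[idx] += 1  # per-camera frame counter

stop_event = threading.Event()
counts = [0] * len(decoders)
threads = [
    threading.Thread(target=run_camera, args=(d, c, counts, i, stop_event))
    for i, (d, c) in enumerate(zip(decoders, controllers))
]
for t in threads:
    t.start()
for t in threads:
    t.join()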
test_tracking_realtime.py (new file, 515 lines)
@@ -0,0 +1,515 @@
"""
Real-time object tracking visualization with OpenCV.

This script demonstrates:
- GPU-accelerated decoding and tracking
- CPU-side visualization with bounding boxes and track IDs
- Real-time display using OpenCV
- FPS monitoring and performance metrics
"""

import time
import os
import cv2
import numpy as np
from dotenv import load_dotenv
from services import (
    StreamDecoderFactory,
    TensorRTModelRepository,
    TrackingFactory,
    YOLOv8Utils,
    COCO_CLASSES,
)

# Load environment variables
load_dotenv()


def draw_tracking_overlay(frame: np.ndarray, tracked_objects, frame_info: dict) -> np.ndarray:
    """
    Draw bounding boxes, labels, and tracking info on a frame.

    Args:
        frame: Frame in (H, W, 3) RGB format
        tracked_objects: List of TrackedObject instances
        frame_info: Dict with frame count, FPS, etc.

    Returns:
        Frame with overlays drawn
    """
    # Convert RGB to BGR for OpenCV
    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    # Get frame dimensions
    frame_height, frame_width = frame.shape[:2]

    # Filter tracked objects to only show person and car
    filtered_objects = [obj for obj in tracked_objects if obj.class_name in ['person', 'car']]

    # Define colors for different track IDs (cycling through colors)
    colors = [
        (0, 255, 0),    # Green
        (255, 0, 0),    # Blue
        (0, 0, 255),    # Red
        (255, 255, 0),  # Cyan
        (255, 0, 255),  # Magenta
        (0, 255, 255),  # Yellow
        (128, 255, 0),  # Light green
        (255, 128, 0),  # Orange
    ]

    # Draw each tracked object
    for obj in filtered_objects:
        # Get color based on track ID
        color = colors[obj.track_id % len(colors)]

        # Extract bounding box coordinates.
        # Boxes come from YOLOv8 in 640x640 space and must be scaled to the frame size.
        x1, y1, x2, y2 = obj.bbox

        # Scale from 640x640 model space to the actual frame size (e.g. 1280x720)
        scale_x = frame_width / 640.0
        scale_y = frame_height / 640.0

        x1 = int(x1 * scale_x)
        y1 = int(y1 * scale_y)
        x2 = int(x2 * scale_x)
        y2 = int(y2 * scale_y)

        # Draw bounding box
        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), color, 2)

        # Prepare label text
        label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"

        # Get text size for background rectangle
        (text_width, text_height), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )

        # Draw label background
        cv2.rectangle(
            frame_bgr,
            (x1, y1 - text_height - baseline - 5),
            (x1 + text_width, y1),
            color,
            -1  # Filled
        )

        # Draw label text
        cv2.putText(
            frame_bgr,
            label,
            (x1, y1 - baseline - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),  # Black text
            1,
            cv2.LINE_AA
        )

        # Draw track history if available (trajectory)
        if hasattr(obj, 'history') and len(obj.history) > 1:
            points = []
            for hist_bbox in obj.history[-10:]:  # Last 10 positions
                # Get center point of historical bbox (in 640x640 space)
                hx1, hy1, hx2, hy2 = hist_bbox

                # Scale from 640x640 to frame size
                cx = int(((hx1 + hx2) / 2) * scale_x)
                cy = int(((hy1 + hy2) / 2) * scale_y)
                points.append((cx, cy))

            # Draw trajectory line
            for i in range(1, len(points)):
                cv2.line(frame_bgr, points[i-1], points[i], color, 2)

    # Draw info panel at top
    info_bg_height = 80
    overlay = frame_bgr.copy()
    cv2.rectangle(overlay, (0, 0), (frame_bgr.shape[1], info_bg_height), (0, 0, 0), -1)
    cv2.addWeighted(overlay, 0.5, frame_bgr, 0.5, 0, frame_bgr)

    # Draw statistics text
    y_offset = 25
    cv2.putText(
        frame_bgr,
        f"Frame: {frame_info.get('frame_count', 0)} | FPS: {frame_info.get('fps', 0):.1f}",
        (10, y_offset),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.6,
        (255, 255, 255),
        2,
        cv2.LINE_AA
    )

    y_offset += 25
    # Count persons and cars
    person_count = sum(1 for obj in filtered_objects if obj.class_name == 'person')
    car_count = sum(1 for obj in filtered_objects if obj.class_name == 'car')
    cv2.putText(
        frame_bgr,
        f"Persons: {person_count} | Cars: {car_count} | Total Visible: {len(filtered_objects)}",
        (10, y_offset),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.6,
        (255, 255, 255),
        2,
        cv2.LINE_AA
    )

    return frame_bgr


def main():
    """
    Main function for real-time tracking visualization.
    """
    # Configuration
    GPU_ID = 0
    MODEL_PATH = "models/yolov8n.trt"
    RTSP_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BUFFER_SIZE = 30
    WINDOW_NAME = "Real-time Object Tracking"

    print("=" * 80)
    print("Real-time GPU-Accelerated Object Tracking")
    print("=" * 80)

    # Step 1: Create model repository
    print("\n[1/4] Initializing TensorRT Model Repository...")
    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4)

    # Load detection model
    model_id = "yolov8_detector"
    if os.path.exists(MODEL_PATH):
        try:
            metadata = model_repo.load_model(
                model_id=model_id,
                file_path=MODEL_PATH,
                num_contexts=4
            )
            print("✓ Model loaded successfully")
            print(f"  Input shape: {metadata.input_shapes}")
            print(f"  Output shape: {metadata.output_shapes}")
        except Exception as e:
            print(f"✗ Failed to load model: {e}")
            print(f"  Please ensure {MODEL_PATH} exists")
            return
    else:
        print(f"✗ Model file not found: {MODEL_PATH}")
        print("  Please provide a valid TensorRT model file")
        return

    # Step 2: Create tracking controller
    print("\n[2/4] Creating TrackingController...")
    tracking_factory = TrackingFactory(gpu_id=GPU_ID)

    try:
        tracking_controller = tracking_factory.create_controller(
            model_repository=model_repo,
            model_id=model_id,
            tracker_type="iou",
            max_age=30,
            min_confidence=0.5,
            iou_threshold=0.3,
            class_names=COCO_CLASSES
        )
        print(f"✓ Controller created: {tracking_controller}")
    except Exception as e:
        print(f"✗ Failed to create controller: {e}")
        return

    # Step 3: Create stream decoder
    print("\n[3/4] Creating RTSP Stream Decoder...")
    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoder = stream_factory.create_decoder(
        rtsp_url=RTSP_URL,
        buffer_size=BUFFER_SIZE
    )
    decoder.start()
    print(f"✓ Decoder started for: {RTSP_URL}")

    # Wait for stream connection
    print("  Waiting up to 15 seconds for connection...")
    connected = False
    for i in range(15):
        time.sleep(1)
        if decoder.is_connected():
            connected = True
            break
        print(f"  Waiting... {i+1}/15 seconds (status: {decoder.get_status().value})")

    if connected:
        print("✓ Stream connected!")
    else:
        print(f"✗ Stream not connected after 15 seconds (status: {decoder.get_status().value})")
        print("  Proceeding anyway - will start displaying when frames arrive...")
        # Don't exit - continue and wait for frames

    # Step 4: Create OpenCV window
    print("\n[4/4] Starting Real-time Visualization...")
    cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(WINDOW_NAME, 1280, 720)

    print(f"\n{'=' * 80}")
    print("Real-time tracking started!")
    print("Press 'q' to quit | Press 's' to save screenshot")
    print(f"{'=' * 80}\n")

    # FPS tracking
    fps_start_time = time.time()
    fps_frame_count = 0
    current_fps = 0.0

    frame_count = 0
    screenshot_count = 0

    try:
        while True:
            # Get frame from decoder (CPU memory for OpenCV)
            frame_cpu = decoder.get_frame_cpu(index=-1, rgb=True)

            if frame_cpu is None:
                time.sleep(0.01)
                continue

            # Get GPU frame for tracking
            frame_gpu = decoder.get_latest_frame(rgb=True)

            if frame_gpu is None:
                time.sleep(0.01)
                continue

            frame_count += 1
            fps_frame_count += 1

            # Run tracking on the GPU frame with YOLOv8 pre/postprocessing
            tracked_objects = tracking_controller.track(
                frame_gpu,
                preprocess_fn=YOLOv8Utils.preprocess,
                postprocess_fn=YOLOv8Utils.postprocess
            )

            # Calculate FPS every second
            elapsed = time.time() - fps_start_time
            if elapsed >= 1.0:
                current_fps = fps_frame_count / elapsed
                fps_frame_count = 0
                fps_start_time = time.time()

            # Get tracking statistics
            stats = tracking_controller.get_statistics()

            # Prepare frame info for overlay
            frame_info = {
                'frame_count': frame_count,
                'fps': current_fps,
                'total_tracks': stats['total_tracks_created'],
                'class_counts': stats['class_counts']
            }

            # Draw tracking overlay on the CPU frame
            display_frame = draw_tracking_overlay(frame_cpu, tracked_objects, frame_info)

            # Display frame
            cv2.imshow(WINDOW_NAME, display_frame)

            # Handle keyboard input
            key = cv2.waitKey(1) & 0xFF

            if key == ord('q'):
                print("\n✓ Quit requested by user")
                break
            elif key == ord('s'):
                # Save screenshot
                screenshot_count += 1
                filename = f"screenshot_{screenshot_count:04d}.jpg"
                cv2.imwrite(filename, display_frame)
                print(f"✓ Screenshot saved: {filename}")

    except KeyboardInterrupt:
        print("\n✓ Interrupted by user")
    except Exception as e:
        print(f"\n✗ Error during tracking: {e}")
        import traceback
        traceback.print_exc()

    # Cleanup
    print("\n" + "=" * 80)
    print("Cleanup")
    print("=" * 80)

    # Print final statistics
    print("\nFinal Tracking Statistics:")
    stats = tracking_controller.get_statistics()
    for key, value in stats.items():
        print(f"  {key}: {value}")

    # Close OpenCV window
    cv2.destroyAllWindows()

    # Stop decoder
    print("\nStopping decoder...")
    decoder.stop()
    print("✓ Decoder stopped")

    print("\n" + "=" * 80)
    print("Real-time tracking completed!")
    print("=" * 80)


def main_multi_window():
    """
    Example: Display multiple camera streams in separate windows.

    This demonstrates tracking on multiple RTSP streams simultaneously
    with a separate OpenCV window for each stream.
    """
    GPU_ID = 0
    MODEL_PATH = "models/yolov8n.trt"

    # Load camera URLs from environment (CAMERA_URL_1, CAMERA_URL_2, ...)
    camera_urls = []
    i = 1
    while True:
        url = os.getenv(f'CAMERA_URL_{i}')
        if url:
            camera_urls.append(url)
            i += 1
        else:
            break

    if not camera_urls:
        print("No camera URLs found in .env file")
        return

    print(f"Starting multi-window tracking with {len(camera_urls)} cameras")

    # Create shared model repository
    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8)

    if os.path.exists(MODEL_PATH):
        model_repo.load_model("detector", MODEL_PATH, num_contexts=8)
    else:
        print(f"Model not found: {MODEL_PATH}")
        return

    # Create tracking factory
    tracking_factory = TrackingFactory(gpu_id=GPU_ID)

    # Create decoders and controllers
    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoders = []
    controllers = []
    window_names = []

    for i, url in enumerate(camera_urls):
        # Create decoder
        decoder = stream_factory.create_decoder(url, buffer_size=30)
        decoder.start()
        decoders.append(decoder)

        # Create tracking controller
        controller = tracking_factory.create_controller(
            model_repository=model_repo,
            model_id="detector",
            tracker_type="iou",
            max_age=30,
            min_confidence=0.5,
            iou_threshold=0.3,
            class_names=COCO_CLASSES
        )
        controllers.append(controller)

        # Create window
        window_name = f"Camera {i+1}"
        window_names.append(window_name)
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(window_name, 640, 480)

        print(f"Camera {i+1}: {url}")

    print("\nWaiting for streams to connect...")
    time.sleep(10)

    print("\nPress 'q' to quit")

    # FPS tracking for each stream
    fps_data = [{'start': time.time(), 'count': 0, 'fps': 0.0} for _ in camera_urls]
    frame_counts = [0] * len(camera_urls)

    try:
        while True:
            for i, (decoder, controller, window_name) in enumerate(zip(decoders, controllers, window_names)):
                # Get frames
                frame_cpu = decoder.get_frame_cpu(index=-1, rgb=True)
                frame_gpu = decoder.get_latest_frame(rgb=True)

                if frame_cpu is None or frame_gpu is None:
                    continue

                frame_counts[i] += 1
                fps_data[i]['count'] += 1

                # Calculate FPS
                elapsed = time.time() - fps_data[i]['start']
                if elapsed >= 1.0:
                    fps_data[i]['fps'] = fps_data[i]['count'] / elapsed
                    fps_data[i]['count'] = 0
                    fps_data[i]['start'] = time.time()

                # Track objects with YOLOv8 pre/postprocessing
                tracked_objects = controller.track(
                    frame_gpu,
                    preprocess_fn=YOLOv8Utils.preprocess,
                    postprocess_fn=YOLOv8Utils.postprocess
                )

                # Get statistics
                stats = controller.get_statistics()

                # Prepare frame info
                frame_info = {
                    'frame_count': frame_counts[i],
                    'fps': fps_data[i]['fps'],
                    'total_tracks': stats['total_tracks_created'],
                    'class_counts': stats['class_counts']
                }

                # Draw overlay and display
                display_frame = draw_tracking_overlay(frame_cpu, tracked_objects, frame_info)
                cv2.imshow(window_name, display_frame)

            # Check for quit
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    except KeyboardInterrupt:
        print("\nInterrupted by user")

    # Cleanup
    print("\nCleaning up...")
    cv2.destroyAllWindows()

    for decoder in decoders:
        decoder.stop()

    print("\nFinal Statistics:")
    for i, controller in enumerate(controllers):
        stats = controller.get_statistics()
        print(f"\nCamera {i+1}:")
        print(f"  Frames: {stats['frame_count']}")
        print(f"  Tracks created: {stats['total_tracks_created']}")
        print(f"  Active tracks: {stats['active_tracks']}")


if __name__ == "__main__":
    # Run single camera visualization
    main()

    # Uncomment to run multi-window visualization
    # main_multi_window()
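draw_tracking_overlay above rescales each box by hand with hard-coded 640 factors; the same mapping can go through the scale_boxes utility this commit adds. A minimal sketch, assuming obj.bbox is in 640x640 model space as the comments above state:

# Hypothetical helper reusing YOLOv8Utils.scale_boxes for the per-object
# scaling done by hand in draw_tracking_overlay (not in the commit).
import torch
from services import YOLOv8Utils

def bboxes_in_frame_space(tracked_objects, frame_width, frame_height, model_size=640):
    """Batch-scale all track boxes from model space to display space."""
    if not tracked_objects:
        return torch.zeros((0, 4))
    boxes = torch.tensor([list(obj.bbox) for obj in tracked_objects], dtype=torch.float32)
    return YOLOv8Utils.scale_boxes(boxes, (model_size, model_size),
                                   (frame_width, frame_height))

Batching the scaling also keeps the drawing loop free of coordinate math, so the 640 constant lives in one place instead of three.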