add models and update tracking system

parent fd470b3765
commit 2b0cfc4b72

9 changed files with 780 additions and 478 deletions
.gitignore (vendored, 2 changes)

@@ -3,5 +3,5 @@ __pycache__/
 *.pyc
 .env
 .claude
-models/
+/models/
 /tracked_objects.json
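A note on this change: a pattern without a leading slash, like models/, matches a directory named models at any depth, while the root-anchored /models/ ignores only the top-level models directory. Anchoring the pattern is what allows the bangchak/models/*.pt weights added in this commit to be tracked while the top-level models/ cache stays ignored.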
bangchak/bangchak_controller.py (new, empty file)

bangchak/models/car_bodytype_cls_v1.pt (new, BIN; binary file not shown)
bangchak/models/car_brand_cls_v3.pt (new, BIN; binary file not shown)
bangchak/models/car_detection_v3.pt (new, BIN; binary file not shown)
bangchak/models/frontal_detection_v5.pt (new, BIN; binary file not shown)
@@ -332,6 +332,7 @@ class StreamConnectionManager:
         batch_size: int = 16,
         force_timeout: float = 0.05,
         poll_interval: float = 0.01,
+        enable_pt_conversion: bool = True,
     ):
         self.gpu_id = gpu_id
         self.batch_size = batch_size
@@ -341,7 +342,10 @@ class StreamConnectionManager:
         # Factories
         self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
         self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
-        self.model_repository = TensorRTModelRepository(gpu_id=gpu_id)
+        self.model_repository = TensorRTModelRepository(
+            gpu_id=gpu_id,
+            enable_pt_conversion=enable_pt_conversion
+        )

         # Controllers
         self.model_controller: Optional[ModelController] = None
@@ -360,16 +364,22 @@ class StreamConnectionManager:
         preprocess_fn: Optional[Callable] = None,
         postprocess_fn: Optional[Callable] = None,
         num_contexts: int = 4,
+        pt_input_shapes: Optional[Dict] = None,
+        pt_precision: Optional[Any] = None,
+        **pt_conversion_kwargs
     ):
         """
         Initialize the manager with a model.

         Args:
-            model_path: Path to TensorRT model file
+            model_path: Path to TensorRT or PyTorch model file (.trt, .pt, .pth)
             model_id: Model identifier (default: "detector")
             preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
             postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
            num_contexts: Number of TensorRT execution contexts (default: 4)
+            pt_input_shapes: Required for PT files - dict of input shapes
+            pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
+            **pt_conversion_kwargs: Additional PT conversion arguments
         """
         logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
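The pt_input_shapes argument is documented as required for .pt files: converting a PyTorch checkpoint to a TensorRT engine generally involves tracing or exporting the network with concrete example shapes, which is why both test scripts below pass {"images": (1, 3, 640, 640)}.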
@@ -377,7 +387,14 @@ class StreamConnectionManager:
         loop = asyncio.get_event_loop()
         await loop.run_in_executor(
             None,
-            lambda: self.model_repository.load_model(model_id, model_path, num_contexts=num_contexts)
+            lambda: self.model_repository.load_model(
+                model_id,
+                model_path,
+                num_contexts=num_contexts,
+                pt_input_shapes=pt_input_shapes,
+                pt_precision=pt_precision,
+                **pt_conversion_kwargs
+            )
         )
         logger.info(f"Loaded model {model_id} from {model_path}")
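Why the lambda stays wrapped in run_in_executor matters more after this change: load_model is a blocking call, and with PT conversion enabled a first load can spend minutes building a TensorRT engine. Running it on the default thread pool keeps the asyncio event loop free to service other coroutines in the meantime. A minimal, self-contained sketch of the pattern (slow_build is a hypothetical stand-in for the repository's blocking load):

import asyncio
import time


def slow_build(path: str) -> str:
    # Stand-in for a blocking call such as a PT-to-TensorRT engine build.
    time.sleep(2)
    return f"engine({path})"


async def heartbeat():
    # Keeps printing while the build occupies a worker thread.
    while True:
        print("event loop still responsive")
        await asyncio.sleep(0.5)


async def main():
    loop = asyncio.get_running_loop()
    hb = asyncio.create_task(heartbeat())
    # Offload the blocking work so other coroutines keep running.
    engine = await loop.run_in_executor(None, slow_build, "models/yolov8n.pt")
    hb.cancel()
    print(engine)


asyncio.run(main())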
@@ -392,13 +409,10 @@ class StreamConnectionManager:
         )
         await self.model_controller.start()

-        # Create tracking controller
-        self.tracking_controller = self.tracking_factory.create_controller(
-            model_repository=self.model_repository,
-            model_id=model_id,
-            tracker_type="iou",
-        )
-        logger.info("TrackingController created")
+        # Don't create a shared tracking controller here
+        # Each stream will get its own tracking controller to avoid track accumulation
+        self.tracking_controller = None
+        self.model_id_for_tracking = model_id  # Store for later use

         self.initialized = True
         logger.info("StreamConnectionManager initialized successfully")
@@ -440,12 +454,24 @@ class StreamConnectionManager:
         # Create decoder
         decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)

+        # Create dedicated tracking controller for THIS stream
+        # This prevents track accumulation across multiple streams
+        tracking_controller = self.tracking_factory.create_controller(
+            model_repository=self.model_repository,
+            model_id=self.model_id_for_tracking,
+            tracker_type="iou",
+            max_age=30,
+            min_confidence=0.5,
+            iou_threshold=0.3,
+        )
+        logger.info(f"Created dedicated TrackingController for stream {stream_id}")
+
         # Create connection
         connection = StreamConnection(
             stream_id=stream_id,
             decoder=decoder,
             model_controller=self.model_controller,
-            tracking_controller=self.tracking_controller,
+            tracking_controller=tracking_controller,
             poll_interval=self.poll_interval,
         )
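Design note on these two hunks: with one tracker shared by every stream, detections from all cameras feed a single track pool, so IoU association can stitch boxes from unrelated streams into the same track ID and stale tracks accumulate without bound, which is the "track accumulation" the comments mention. Scoping one TrackingController to each StreamConnection keeps track IDs and lifetimes per camera, at the cost of one controller object per stream.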
@@ -1,22 +1,21 @@
 """
-Real-time object tracking visualization with OpenCV.
+Real-time object tracking with event-driven batching architecture.

 This script demonstrates:
-- GPU-accelerated decoding and tracking
-- CPU-side visualization with bounding boxes and track IDs
-- Real-time display using OpenCV
-- FPS monitoring and performance metrics
+- Event-driven stream processing with StreamConnectionManager
+- Batched GPU inference with ModelController
+- Ping-pong buffer architecture for optimal throughput
+- Async/await pattern for multiple RTSP streams
+- Automatic PT to TensorRT conversion
 """

+import asyncio
 import time
 import os
-import cv2
-import numpy as np
+import torch
 from dotenv import load_dotenv
 from services import (
-    StreamDecoderFactory,
-    TensorRTModelRepository,
-    TrackingFactory,
+    StreamConnectionManager,
     YOLOv8Utils,
     COCO_CLASSES,
 )
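The import swap in this hunk mirrors the architecture change: cv2 and numpy drop out because all OpenCV visualization moves to the preserved old script (added below), while torch comes in solely to express pt_precision=torch.float16 for the automatic PT-to-TensorRT conversion.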
@@ -25,513 +24,253 @@ from services import (
 load_dotenv()


-def draw_tracking_overlay(frame: np.ndarray, tracked_objects, frame_info: dict) -> np.ndarray:
-    """
-    Draw bounding boxes, labels, and tracking info on frame.
-
-    Args:
-        frame: Frame in (H, W, 3) RGB format
-        tracked_objects: List of TrackedObject instances
-        frame_info: Dict with frame count, FPS, etc.
-
-    Returns:
-        Frame with overlays drawn
-    """
-    # Convert RGB to BGR for OpenCV
-    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-
-    # Get frame dimensions
-    frame_height, frame_width = frame.shape[:2]
-
-    # Filter tracked objects to only show person and car
-    filtered_objects = [obj for obj in tracked_objects if obj.class_name in ['person', 'car']]
-
-    # Define colors for different track IDs (cycling through colors)
-    colors = [
-        (0, 255, 0),    # Green
-        (255, 0, 0),    # Blue
-        (0, 0, 255),    # Red
-        (255, 255, 0),  # Cyan
-        (255, 0, 255),  # Magenta
-        (0, 255, 255),  # Yellow
-        (128, 255, 0),  # Light green
-        (255, 128, 0),  # Orange
-    ]
-
-    # Draw each tracked object
-    for obj in filtered_objects:
-        # Get color based on track ID
-        color = colors[obj.track_id % len(colors)]
-
-        # Extract bounding box coordinates
-        # Boxes come from YOLOv8 in 640x640 space, need to scale to frame size
-        x1, y1, x2, y2 = obj.bbox
-
-        # Scale from 640x640 model space to actual frame size
-        # YOLOv8 output is in 640x640, but frame is 1280x720
-        scale_x = frame_width / 640.0
-        scale_y = frame_height / 640.0
-
-        x1 = int(x1 * scale_x)
-        y1 = int(y1 * scale_y)
-        x2 = int(x2 * scale_x)
-        y2 = int(y2 * scale_y)
-
-        # Draw bounding box
-        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), color, 2)
-
-        # Prepare label text
-        label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
-
-        # Get text size for background rectangle
-        (text_width, text_height), baseline = cv2.getTextSize(
-            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
-        )
-
-        # Draw label background
-        cv2.rectangle(
-            frame_bgr,
-            (x1, y1 - text_height - baseline - 5),
-            (x1 + text_width, y1),
-            color,
-            -1  # Filled
-        )
-
-        # Draw label text
-        cv2.putText(
-            frame_bgr,
-            label,
-            (x1, y1 - baseline - 2),
-            cv2.FONT_HERSHEY_SIMPLEX,
-            0.5,
-            (0, 0, 0),  # Black text
-            1,
-            cv2.LINE_AA
-        )
-
-        # Draw track history if available (trajectory)
-        if hasattr(obj, 'history') and len(obj.history) > 1:
-            points = []
-            for hist_bbox in obj.history[-10:]:  # Last 10 positions
-                # Get center point of historical bbox (in 640x640 space)
-                hx1, hy1, hx2, hy2 = hist_bbox
-
-                # Scale from 640x640 to frame size
-                cx = int(((hx1 + hx2) / 2) * scale_x)
-                cy = int(((hy1 + hy2) / 2) * scale_y)
-                points.append((cx, cy))
-
-            # Draw trajectory line
-            for i in range(1, len(points)):
-                cv2.line(frame_bgr, points[i-1], points[i], color, 2)
-
-    # Draw info panel at top
-    info_bg_height = 80
-    overlay = frame_bgr.copy()
-    cv2.rectangle(overlay, (0, 0), (frame_bgr.shape[1], info_bg_height), (0, 0, 0), -1)
-    cv2.addWeighted(overlay, 0.5, frame_bgr, 0.5, 0, frame_bgr)
-
-    # Draw statistics text
-    y_offset = 25
-    cv2.putText(
-        frame_bgr,
-        f"Frame: {frame_info.get('frame_count', 0)} | FPS: {frame_info.get('fps', 0):.1f}",
-        (10, y_offset),
-        cv2.FONT_HERSHEY_SIMPLEX,
-        0.6,
-        (255, 255, 255),
-        2,
-        cv2.LINE_AA
-    )
-
-    y_offset += 25
-    # Count persons and cars
-    person_count = sum(1 for obj in filtered_objects if obj.class_name == 'person')
-    car_count = sum(1 for obj in filtered_objects if obj.class_name == 'car')
-    cv2.putText(
-        frame_bgr,
-        f"Persons: {person_count} | Cars: {car_count} | Total Visible: {len(filtered_objects)}",
-        (10, y_offset),
-        cv2.FONT_HERSHEY_SIMPLEX,
-        0.6,
-        (255, 255, 255),
-        2,
-        cv2.LINE_AA
-    )
-
-    return frame_bgr
-
-
-def main():
-    """
-    Main function for real-time tracking visualization.
-    """
-    import torch
-
+async def main_single_stream():
+    """Single stream example with event-driven architecture."""
+    print("=" * 80)
+    print("Event-Driven GPU-Accelerated Object Tracking - Single Stream")
+    print("=" * 80)
+
     # Configuration
     GPU_ID = 0
-    MODEL_PATH = "models/yolov8n.pt"  # Changed to PT file
-    RTSP_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
-    BUFFER_SIZE = 30
-    WINDOW_NAME = "Real-time Object Tracking"
+    MODEL_PATH = "models/yolov8n.pt"  # PT file will be auto-converted
+    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
+    BATCH_SIZE = 4
+    FORCE_TIMEOUT = 0.05

-    print("=" * 80)
-    print("Real-time GPU-Accelerated Object Tracking")
-    print("=" * 80)
+    print(f"\nConfiguration:")
+    print(f"  GPU: {GPU_ID}")
+    print(f"  Model: {MODEL_PATH}")
+    print(f"  Stream: {STREAM_URL}")
+    print(f"  Batch size: {BATCH_SIZE}")
+    print(f"  Force timeout: {FORCE_TIMEOUT}s\n")

-    # Step 1: Create model repository with PT conversion enabled
-    print("\n[1/4] Initializing TensorRT Model Repository...")
-    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4, enable_pt_conversion=True)
-
-    # Load detection model (will auto-convert PT to TRT)
-    model_id = "yolov8_detector"
-    if os.path.exists(MODEL_PATH):
-        try:
-            print(f"Loading model from {MODEL_PATH}...")
-            print("Note: First load will convert PT to TensorRT (may take 3-5 minutes)")
-            print("Subsequent loads will use cached TensorRT engine")
-
-            metadata = model_repo.load_model(
-                model_id=model_id,
-                file_path=MODEL_PATH,
-                num_contexts=4,
-                pt_input_shapes={"images": (1, 3, 640, 640)},  # Required for PT conversion
-                pt_precision=torch.float16  # Use FP16 for better performance
-            )
-            print(f"✓ Model loaded successfully")
-            print(f"  Input shape: {metadata.input_shapes}")
-            print(f"  Output shape: {metadata.output_shapes}")
-        except Exception as e:
-            print(f"✗ Failed to load model: {e}")
-            print(f"  Please ensure {MODEL_PATH} exists")
-            import traceback
-            traceback.print_exc()
-            return
-    else:
-        print(f"✗ Model file not found: {MODEL_PATH}")
-        print(f"  Please provide a valid PyTorch (.pt) or TensorRT (.trt) model file")
-        return
+    # Create StreamConnectionManager with PT conversion enabled
+    print("[1/3] Creating StreamConnectionManager...")
+    manager = StreamConnectionManager(
+        gpu_id=GPU_ID,
+        batch_size=BATCH_SIZE,
+        force_timeout=FORCE_TIMEOUT,
+        enable_pt_conversion=True  # Enable PT conversion
+    )
+    print("✓ Manager created")
+
+    # Initialize with PT model (auto-conversion)
+    print("\n[2/3] Initializing with PT model...")
+    print("Note: First load will convert PT to TensorRT (3-5 minutes)")
+    print("Subsequent loads will use cached TensorRT engine\n")
+
+    try:
+        await manager.initialize(
+            model_path=MODEL_PATH,
+            model_id="detector",
+            preprocess_fn=YOLOv8Utils.preprocess,
+            postprocess_fn=YOLOv8Utils.postprocess,
+            num_contexts=4,
+            pt_input_shapes={"images": (1, 3, 640, 640)},
+            pt_precision=torch.float16
+        )
+        print("✓ Manager initialized (PT converted to TensorRT)")
+    except Exception as e:
+        print(f"✗ Failed to initialize: {e}")
+        import traceback
+        traceback.print_exc()
+        return

-    # Step 2: Create tracking controller
-    print("\n[2/4] Creating TrackingController...")
-    tracking_factory = TrackingFactory(gpu_id=GPU_ID)
-
+    # Connect stream
+    print("\n[3/3] Connecting to stream...")
     try:
-        tracking_controller = tracking_factory.create_controller(
-            model_repository=model_repo,
-            model_id=model_id,
-            tracker_type="iou",
-            max_age=30,
-            min_confidence=0.5,
-            iou_threshold=0.3,
-            class_names=COCO_CLASSES
+        connection = await manager.connect_stream(
+            rtsp_url=STREAM_URL,
+            stream_id="camera_1",
+            buffer_size=30
         )
-        print(f"✓ Controller created: {tracking_controller}")
+        print(f"✓ Stream connected: camera_1")
     except Exception as e:
-        print(f"✗ Failed to create controller: {e}")
+        print(f"✗ Failed to connect stream: {e}")
         return

-    # Step 3: Create stream decoder
-    print("\n[3/4] Creating RTSP Stream Decoder...")
-    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
-    decoder = stream_factory.create_decoder(
-        rtsp_url=RTSP_URL,
-        buffer_size=BUFFER_SIZE
-    )
-    decoder.start()
-    print(f"✓ Decoder started for: {RTSP_URL}")
-    print(f"  Waiting for connection...")
-
-    # Wait for stream connection
-    print("  Waiting up to 15 seconds for connection...")
-    connected = False
-    for i in range(15):
-        time.sleep(1)
-        if decoder.is_connected():
-            connected = True
-            break
-        print(f"  Waiting... {i+1}/15 seconds (status: {decoder.get_status().value})")
-
-    if connected:
-        print(f"✓ Stream connected!")
-    else:
-        print(f"✗ Stream not connected after 15 seconds (status: {decoder.get_status().value})")
-        print(f"  Proceeding anyway - will start displaying when frames arrive...")
-        # Don't exit - continue and wait for frames
-
-    # Step 4: Create OpenCV window
-    print("\n[4/4] Starting Real-time Visualization...")
-    cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
-    cv2.resizeWindow(WINDOW_NAME, 1280, 720)
-
     print(f"\n{'=' * 80}")
-    print("Real-time tracking started!")
-    print("Press 'q' to quit | Press 's' to save screenshot")
+    print("Event-driven tracking is running!")
+    print("Press Ctrl+C to stop")
     print(f"{'=' * 80}\n")

-    # FPS tracking
-    fps_start_time = time.time()
-    fps_frame_count = 0
-    current_fps = 0.0
-
-    frame_count = 0
-    screenshot_count = 0
+    # Stream results
+    result_count = 0
+    start_time = time.time()

     try:
-        while True:
-            # Get frame from decoder (CPU memory for OpenCV)
-            frame_cpu = decoder.get_frame_cpu(index=-1, rgb=True)
-
-            if frame_cpu is None:
-                time.sleep(0.01)
-                continue
-
-            # Get GPU frame for tracking
-            frame_gpu = decoder.get_latest_frame(rgb=True)
-
-            if frame_gpu is None:
-                time.sleep(0.01)
-                continue
-
-            frame_count += 1
-            fps_frame_count += 1
-
-            # Run tracking on GPU frame with YOLOv8 pre/postprocessing
-            tracked_objects = tracking_controller.track(
-                frame_gpu,
-                preprocess_fn=YOLOv8Utils.preprocess,
-                postprocess_fn=YOLOv8Utils.postprocess
-            )
-
-            # Calculate FPS every second
-            elapsed = time.time() - fps_start_time
-            if elapsed >= 1.0:
-                current_fps = fps_frame_count / elapsed
-                fps_frame_count = 0
-                fps_start_time = time.time()
-
-            # Get tracking statistics
-            stats = tracking_controller.get_statistics()
-
-            # Prepare frame info for overlay
-            frame_info = {
-                'frame_count': frame_count,
-                'fps': current_fps,
-                'total_tracks': stats['total_tracks_created'],
-                'class_counts': stats['class_counts']
-            }
-
-            # Draw tracking overlay on CPU frame
-            display_frame = draw_tracking_overlay(frame_cpu, tracked_objects, frame_info)
-
-            # Display frame
-            cv2.imshow(WINDOW_NAME, display_frame)
-
-            # Handle keyboard input
-            key = cv2.waitKey(1) & 0xFF
-
-            if key == ord('q'):
-                print("\n✓ Quit requested by user")
-                break
-            elif key == ord('s'):
-                # Save screenshot
-                screenshot_count += 1
-                filename = f"screenshot_{screenshot_count:04d}.jpg"
-                cv2.imwrite(filename, display_frame)
-                print(f"✓ Screenshot saved: {filename}")
+        async for result in connection.tracking_results():
+            result_count += 1
+
+            # Print stats every 30 results
+            if result_count % 30 == 0:
+                elapsed = time.time() - start_time
+                fps = result_count / elapsed if elapsed > 0 else 0
+
+                print(f"\nResults: {result_count} | FPS: {fps:.1f}")
+                print(f"  Stream: {result.stream_id}")
+                print(f"  Objects: {len(result.tracked_objects)}")
+
+                if result.tracked_objects:
+                    class_counts = {}
+                    for obj in result.tracked_objects:
+                        class_counts[obj.class_name] = class_counts.get(obj.class_name, 0) + 1
+                    print(f"  Classes: {class_counts}")

     except KeyboardInterrupt:
-        print("\n✓ Interrupted by user")
-    except Exception as e:
-        print(f"\n✗ Error during tracking: {e}")
-        import traceback
-        traceback.print_exc()
+        print(f"\n✓ Interrupted by user")

     # Cleanup
-    print("\n" + "=" * 80)
+    print(f"\n{'=' * 80}")
     print("Cleanup")
+    print(f"{'=' * 80}")
+
+    await connection.stop()
+    await manager.shutdown()
+    print("✓ Stopped")
+
+    # Final stats
+    elapsed = time.time() - start_time
+    avg_fps = result_count / elapsed if elapsed > 0 else 0
+    print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
+
+
+async def main_multi_stream():
+    """Multi-stream example with batched inference."""
+    print("=" * 80)
+    print("Event-Driven GPU-Accelerated Object Tracking - Multi-Stream")
     print("=" * 80)
-
-    # Print final statistics
-    print("\nFinal Tracking Statistics:")
-    stats = tracking_controller.get_statistics()
-    for key, value in stats.items():
-        print(f"  {key}: {value}")
-
-    # Close OpenCV window
-    cv2.destroyAllWindows()
-
-    # Stop decoder
-    print("\nStopping decoder...")
-    decoder.stop()
-    print("✓ Decoder stopped")
-
-    print("\n" + "=" * 80)
-    print("Real-time tracking completed!")
-    print("=" * 80)
-
-
-def main_multi_window():
-    """
-    Example: Display multiple camera streams in separate windows.
-
-    This demonstrates tracking on multiple RTSP streams simultaneously
-    with separate OpenCV windows for each stream.
-    """
+
+    # Configuration
     GPU_ID = 0
-    MODEL_PATH = "models/yolov8n.pt"
+    MODEL_PATH = "models/yolov8n.pt"  # PT file will be auto-converted
+    BATCH_SIZE = 16
+    FORCE_TIMEOUT = 0.05

-    # Load camera URLs from environment
+    # Load camera URLs
     camera_urls = []
     i = 1
     while True:
         url = os.getenv(f'CAMERA_URL_{i}')
         if url:
-            camera_urls.append(url)
+            camera_urls.append((f"camera_{i}", url))
             i += 1
         else:
             break

     if not camera_urls:
-        print("No camera URLs found in .env file")
+        print("No camera URLs found in .env")
         return

-    print(f"Starting multi-window tracking with {len(camera_urls)} cameras")
-
-    # Create shared model repository with PT conversion enabled
-    import torch
-    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8, enable_pt_conversion=True)
-
-    if os.path.exists(MODEL_PATH):
-        print(f"Loading model from {MODEL_PATH}...")
-        print("Note: First load will convert PT to TensorRT (may take 3-5 minutes)")
-        print("Subsequent loads will use cached TensorRT engine")
-
-        model_repo.load_model(
+    print(f"\nConfiguration:")
+    print(f"  GPU: {GPU_ID}")
+    print(f"  Model: {MODEL_PATH}")
+    print(f"  Streams: {len(camera_urls)}")
+    print(f"  Batch size: {BATCH_SIZE}\n")
+
+    # Create manager with PT conversion
+    print("[1/3] Creating StreamConnectionManager...")
+    manager = StreamConnectionManager(
+        gpu_id=GPU_ID,
+        batch_size=BATCH_SIZE,
+        force_timeout=FORCE_TIMEOUT,
+        enable_pt_conversion=True
+    )
+    print("✓ Manager created")
+
+    # Initialize with PT model
+    print("\n[2/3] Initializing with PT model...")
+    try:
+        await manager.initialize(
+            model_path=MODEL_PATH,
             model_id="detector",
-            file_path=MODEL_PATH,
+            preprocess_fn=YOLOv8Utils.preprocess,
+            postprocess_fn=YOLOv8Utils.postprocess,
             num_contexts=8,
-            pt_input_shapes={"images": (1, 3, 640, 640)},  # Required for PT conversion
-            pt_precision=torch.float16  # Use FP16 for better performance
+            pt_input_shapes={"images": (1, 3, 640, 640)},
+            pt_precision=torch.float16
         )
-        print("✓ Model loaded successfully")
-    else:
-        print(f"Model not found: {MODEL_PATH}")
+        print("✓ Manager initialized")
+    except Exception as e:
+        print(f"✗ Failed to initialize: {e}")
+        import traceback
+        traceback.print_exc()
         return

-    # Create tracking factory
-    tracking_factory = TrackingFactory(gpu_id=GPU_ID)
-
-    # Create decoders and controllers
-    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
-    decoders = []
-    controllers = []
-    window_names = []
-
-    for i, url in enumerate(camera_urls):
-        # Create decoder
-        decoder = stream_factory.create_decoder(url, buffer_size=30)
-        decoder.start()
-        decoders.append(decoder)
-
-        # Create tracking controller
-        controller = tracking_factory.create_controller(
-            model_repository=model_repo,
-            model_id="detector",
-            tracker_type="iou",
-            max_age=30,
-            min_confidence=0.5,
-            iou_threshold=0.3,
-            class_names=COCO_CLASSES
-        )
-        controllers.append(controller)
-
-        # Create window
-        window_name = f"Camera {i+1}"
-        window_names.append(window_name)
-        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
-        cv2.resizeWindow(window_name, 640, 480)
-
-        print(f"Camera {i+1}: {url}")
-
-    print("\nWaiting for streams to connect...")
-    time.sleep(10)
-
-    print("\nPress 'q' to quit")
-
-    # FPS tracking for each stream
-    fps_data = [{'start': time.time(), 'count': 0, 'fps': 0.0} for _ in camera_urls]
-    frame_counts = [0] * len(camera_urls)
+    # Connect all streams
+    print(f"\n[3/3] Connecting {len(camera_urls)} streams...")
+    connections = {}
+    for stream_id, rtsp_url in camera_urls:
+        try:
+            conn = await manager.connect_stream(
+                rtsp_url=rtsp_url,
+                stream_id=stream_id,
+                buffer_size=30
+            )
+            connections[stream_id] = conn
+            print(f"✓ Connected: {stream_id}")
+        except Exception as e:
+            print(f"✗ Failed {stream_id}: {e}")
+
+    if not connections:
+        print("No streams connected")
+        return
+
+    print(f"\n{'=' * 80}")
+    print(f"Multi-stream tracking running ({len(connections)} streams)")
+    print("Frames from all streams are batched together!")
+    print("Press Ctrl+C to stop")
+    print(f"{'=' * 80}\n")
+
+    # Track stats
+    stream_stats = {sid: {'count': 0, 'start': time.time()} for sid in connections.keys()}
+    total_results = 0
+    start_time = time.time()

     try:
-        while True:
-            for i, (decoder, controller, window_name) in enumerate(zip(decoders, controllers, window_names)):
-                # Get frames
-                frame_cpu = decoder.get_frame_cpu(index=-1, rgb=True)
-                frame_gpu = decoder.get_latest_frame(rgb=True)
-
-                if frame_cpu is None or frame_gpu is None:
-                    continue
-
-                frame_counts[i] += 1
-                fps_data[i]['count'] += 1
-
-                # Calculate FPS
-                elapsed = time.time() - fps_data[i]['start']
-                if elapsed >= 1.0:
-                    fps_data[i]['fps'] = fps_data[i]['count'] / elapsed
-                    fps_data[i]['count'] = 0
-                    fps_data[i]['start'] = time.time()
-
-                # Track objects with YOLOv8 pre/postprocessing
-                tracked_objects = controller.track(
-                    frame_gpu,
-                    preprocess_fn=YOLOv8Utils.preprocess,
-                    postprocess_fn=YOLOv8Utils.postprocess
-                )
-
-                # Get statistics
-                stats = controller.get_statistics()
-
-                # Prepare frame info
-                frame_info = {
-                    'frame_count': frame_counts[i],
-                    'fps': fps_data[i]['fps'],
-                    'total_tracks': stats['total_tracks_created'],
-                    'class_counts': stats['class_counts']
-                }
-
-                # Draw overlay and display
-                display_frame = draw_tracking_overlay(frame_cpu, tracked_objects, frame_info)
-                cv2.imshow(window_name, display_frame)
-
-                # Check for quit
-                if cv2.waitKey(1) & 0xFF == ord('q'):
-                    break
+        # Simple approach: iterate over first connection's results
+        # In production, you'd properly merge all result streams
+        for conn in connections.values():
+            async for result in conn.tracking_results():
+                total_results += 1
+                stream_id = result.stream_id
+
+                if stream_id in stream_stats:
+                    stream_stats[stream_id]['count'] += 1
+
+                # Print stats every 100 results
+                if total_results % 100 == 0:
+                    elapsed = time.time() - start_time
+                    total_fps = total_results / elapsed if elapsed > 0 else 0
+
+                    print(f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS")
+                    for sid, stats in stream_stats.items():
+                        s_elapsed = time.time() - stats['start']
+                        s_fps = stats['count'] / s_elapsed if s_elapsed > 0 else 0
+                        print(f"  {sid}: {stats['count']} ({s_fps:.1f} FPS)")

     except KeyboardInterrupt:
-        print("\nInterrupted by user")
+        print(f"\n✓ Interrupted")

     # Cleanup
-    print("\nCleaning up...")
-    cv2.destroyAllWindows()
-
-    for decoder in decoders:
-        decoder.stop()
-
-    print("\nFinal Statistics:")
-    for i, controller in enumerate(controllers):
-        stats = controller.get_statistics()
-        print(f"\nCamera {i+1}:")
-        print(f"  Frames: {stats['frame_count']}")
-        print(f"  Tracks created: {stats['total_tracks_created']}")
-        print(f"  Active tracks: {stats['active_tracks']}")
+    print(f"\n{'=' * 80}")
+    print("Cleanup")
+    print(f"{'=' * 80}")
+
+    for conn in connections.values():
+        await conn.stop()
+    await manager.shutdown()
+    print("✓ Stopped")
+
+    # Final stats
+    elapsed = time.time() - start_time
+    avg_fps = total_results / elapsed if elapsed > 0 else 0
+    print(f"\nFinal: {total_results} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")


 if __name__ == "__main__":
-    # Run single camera visualization
-    # main()
-
-    # Uncomment to run multi-window visualization
-    main_multi_window()
+    import sys
+    if len(sys.argv) > 1 and sys.argv[1] == "single":
+        asyncio.run(main_single_stream())
+    else:
+        asyncio.run(main_multi_stream())
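The multi-stream loop above drains connections one at a time, and the diff's own comment flags it ("In production, you'd properly merge all result streams"). A sketch of one way to do that merge, assuming only that each connection exposes the tracking_results() async generator used above, is to pump every stream into a shared queue:

import asyncio


async def merge_tracking_results(connections):
    # Yield results from all connections interleaved, as they arrive.
    queue = asyncio.Queue()

    async def pump(conn):
        async for result in conn.tracking_results():
            await queue.put(result)

    tasks = [asyncio.create_task(pump(conn)) for conn in connections.values()]
    try:
        while True:
            yield await queue.get()
    finally:
        for task in tasks:
            task.cancel()


# Usage inside main_multi_stream(), replacing the sequential for-loop:
#
#     async for result in merge_tracking_results(connections):
#         total_results += 1
#         ...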
test_tracking_realtime_old.py (new file, 537 additions)

@@ -0,0 +1,537 @@
(The previous OpenCV visualization script is preserved here verbatim as a new 537-line file; its full text matches the removed lines of the rewrite diff above.)