feat: trac king
This commit is contained in:
parent
cf24a172a2
commit
bea895d3d8
4 changed files with 1054 additions and 0 deletions
318
test_tracking.py
Normal file
318
test_tracking.py
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
"""
|
||||
Test script for TrackingController and TrackingFactory.
|
||||
|
||||
This script demonstrates how to use the tracking system with:
|
||||
- TensorRT model repository (dependency injection)
|
||||
- TrackingFactory for controller creation
|
||||
- GPU-accelerated object tracking on RTSP streams
|
||||
- Persistent track IDs and history management
|
||||
"""
|
||||
|
||||
import time
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from services import (
|
||||
StreamDecoderFactory,
|
||||
TensorRTModelRepository,
|
||||
TrackingFactory,
|
||||
TrackedObject
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main test function demonstrating tracking workflow.
|
||||
"""
|
||||
# Configuration
|
||||
GPU_ID = 0
|
||||
MODEL_PATH = "models/yolov8n.trt" # Update with your model path
|
||||
RTSP_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
|
||||
BUFFER_SIZE = 30
|
||||
|
||||
# COCO class names (example for YOLOv8)
|
||||
COCO_CLASSES = {
|
||||
0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
|
||||
5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
|
||||
# Add more as needed...
|
||||
}
|
||||
|
||||
print("=" * 80)
|
||||
print("GPU-Accelerated Object Tracking Test")
|
||||
print("=" * 80)
|
||||
|
||||
# Step 1: Create model repository
|
||||
print("\n[1/5] Initializing TensorRT Model Repository...")
|
||||
model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4)
|
||||
|
||||
# Load detection model (if file exists)
|
||||
model_id = "yolov8_detector"
|
||||
if os.path.exists(MODEL_PATH):
|
||||
try:
|
||||
metadata = model_repo.load_model(
|
||||
model_id=model_id,
|
||||
file_path=MODEL_PATH,
|
||||
num_contexts=4
|
||||
)
|
||||
print(f"✓ Model loaded successfully")
|
||||
print(f" Input shape: {metadata.input_shapes}")
|
||||
print(f" Output shape: {metadata.output_shapes}")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to load model: {e}")
|
||||
print(f" Please ensure {MODEL_PATH} exists")
|
||||
print(f" Continuing with demo (will use mock detections)...")
|
||||
model_id = None
|
||||
else:
|
||||
print(f"✗ Model file not found: {MODEL_PATH}")
|
||||
print(f" Continuing with demo (will use mock detections)...")
|
||||
model_id = None
|
||||
|
||||
# Step 2: Create tracking factory
|
||||
print("\n[2/5] Creating TrackingFactory...")
|
||||
tracking_factory = TrackingFactory(gpu_id=GPU_ID)
|
||||
print(f"✓ Factory created: {tracking_factory}")
|
||||
|
||||
# Step 3: Create tracking controller (only if model loaded)
|
||||
tracking_controller = None
|
||||
if model_id is not None:
|
||||
print("\n[3/5] Creating TrackingController...")
|
||||
try:
|
||||
tracking_controller = tracking_factory.create_controller(
|
||||
model_repository=model_repo,
|
||||
model_id=model_id,
|
||||
tracker_type="iou",
|
||||
max_age=30,
|
||||
min_confidence=0.5,
|
||||
iou_threshold=0.3,
|
||||
class_names=COCO_CLASSES
|
||||
)
|
||||
print(f"✓ Controller created: {tracking_controller}")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to create controller: {e}")
|
||||
tracking_controller = None
|
||||
else:
|
||||
print("\n[3/5] Skipping TrackingController creation (no model loaded)")
|
||||
|
||||
# Step 4: Create stream decoder
|
||||
print("\n[4/5] Creating RTSP Stream Decoder...")
|
||||
stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
|
||||
decoder = stream_factory.create_decoder(
|
||||
rtsp_url=RTSP_URL,
|
||||
buffer_size=BUFFER_SIZE
|
||||
)
|
||||
decoder.start()
|
||||
print(f"✓ Decoder started for: {RTSP_URL}")
|
||||
print(f" Waiting for connection...")
|
||||
|
||||
# Wait for stream connection
|
||||
time.sleep(5)
|
||||
|
||||
if decoder.is_connected():
|
||||
print(f"✓ Stream connected!")
|
||||
else:
|
||||
print(f"✗ Stream not connected (status: {decoder.get_status().value})")
|
||||
print(f" Note: This is expected if RTSP URL is not available")
|
||||
print(f" The tracking system will still work with valid streams")
|
||||
|
||||
# Step 5: Run tracking loop (demo)
|
||||
print("\n[5/5] Running Tracking Loop...")
|
||||
print(f" Processing frames for 30 seconds...")
|
||||
print(f" Press Ctrl+C to stop early\n")
|
||||
|
||||
try:
|
||||
frame_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < 30:
|
||||
# Get latest frame from decoder (GPU tensor)
|
||||
frame = decoder.get_latest_frame(rgb=True)
|
||||
|
||||
if frame is None:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# Run tracking (if controller available)
|
||||
if tracking_controller is not None:
|
||||
try:
|
||||
# Track objects in frame
|
||||
tracked_objects = tracking_controller.track(frame)
|
||||
|
||||
# Display tracking results every 10 frames
|
||||
if frame_count % 10 == 0:
|
||||
print(f"\n--- Frame {frame_count} ---")
|
||||
print(f"Active tracks: {len(tracked_objects)}")
|
||||
|
||||
for obj in tracked_objects:
|
||||
print(f" Track #{obj.track_id}: {obj.class_name} "
|
||||
f"(conf={obj.confidence:.2f}, "
|
||||
f"bbox={[f'{x:.1f}' for x in obj.bbox]}, "
|
||||
f"age={obj.age(tracking_controller._frame_count)} frames)")
|
||||
|
||||
# Print statistics
|
||||
stats = tracking_controller.get_statistics()
|
||||
print(f"\nStatistics:")
|
||||
print(f" Total frames processed: {stats['frame_count']}")
|
||||
print(f" Total tracks created: {stats['total_tracks_created']}")
|
||||
print(f" Total detections: {stats['total_detections']}")
|
||||
print(f" Avg detections/frame: {stats['avg_detections_per_frame']:.2f}")
|
||||
print(f" Class counts: {stats['class_counts']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Tracking error on frame {frame_count}: {e}")
|
||||
|
||||
# Small delay to avoid overwhelming output
|
||||
time.sleep(0.1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n✓ Interrupted by user")
|
||||
|
||||
# Cleanup
|
||||
print("\n" + "=" * 80)
|
||||
print("Cleanup")
|
||||
print("=" * 80)
|
||||
|
||||
if tracking_controller is not None:
|
||||
print("\nTracking final statistics:")
|
||||
stats = tracking_controller.get_statistics()
|
||||
for key, value in stats.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\nExporting tracks to JSON...")
|
||||
try:
|
||||
tracks_json = tracking_controller.export_tracks(format="json")
|
||||
with open("tracked_objects.json", "w") as f:
|
||||
f.write(tracks_json)
|
||||
print(f"✓ Tracks exported to tracked_objects.json")
|
||||
except Exception as e:
|
||||
print(f"✗ Export failed: {e}")
|
||||
|
||||
print("\nStopping decoder...")
|
||||
decoder.stop()
|
||||
print("✓ Decoder stopped")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Test completed successfully!")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
def test_multi_camera_tracking():
|
||||
"""
|
||||
Example: Track objects across multiple camera streams.
|
||||
|
||||
This demonstrates:
|
||||
- Shared model repository across multiple streams
|
||||
- Multiple tracking controllers (one per camera)
|
||||
- Efficient GPU resource usage
|
||||
"""
|
||||
GPU_ID = 0
|
||||
MODEL_PATH = "models/yolov8n.trt"
|
||||
|
||||
# Load multiple camera URLs
|
||||
camera_urls = []
|
||||
i = 1
|
||||
while True:
|
||||
url = os.getenv(f'CAMERA_URL_{i}')
|
||||
if url:
|
||||
camera_urls.append(url)
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if not camera_urls:
|
||||
print("No camera URLs found in .env file")
|
||||
return
|
||||
|
||||
print(f"Testing multi-camera tracking with {len(camera_urls)} cameras")
|
||||
|
||||
# Create shared model repository
|
||||
model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8)
|
||||
|
||||
if os.path.exists(MODEL_PATH):
|
||||
model_repo.load_model("detector", MODEL_PATH, num_contexts=8)
|
||||
else:
|
||||
print(f"Model not found: {MODEL_PATH}")
|
||||
return
|
||||
|
||||
# Create tracking factory
|
||||
tracking_factory = TrackingFactory(gpu_id=GPU_ID)
|
||||
|
||||
# Create stream decoders and tracking controllers
|
||||
stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
|
||||
decoders = []
|
||||
controllers = []
|
||||
|
||||
for i, url in enumerate(camera_urls):
|
||||
# Create decoder
|
||||
decoder = stream_factory.create_decoder(url, buffer_size=30)
|
||||
decoder.start()
|
||||
decoders.append(decoder)
|
||||
|
||||
# Create tracking controller
|
||||
controller = tracking_factory.create_controller(
|
||||
model_repository=model_repo,
|
||||
model_id="detector",
|
||||
tracker_type="iou",
|
||||
max_age=30,
|
||||
min_confidence=0.5
|
||||
)
|
||||
controllers.append(controller)
|
||||
|
||||
print(f"Camera {i+1}: {url}")
|
||||
|
||||
print(f"\nWaiting for streams to connect...")
|
||||
time.sleep(10)
|
||||
|
||||
# Track objects for 30 seconds
|
||||
print(f"\nTracking objects across {len(camera_urls)} cameras...")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
while time.time() - start_time < 30:
|
||||
for i, (decoder, controller) in enumerate(zip(decoders, controllers)):
|
||||
frame = decoder.get_latest_frame(rgb=True)
|
||||
|
||||
if frame is not None:
|
||||
tracked_objects = controller.track(frame)
|
||||
|
||||
# Print stats every 10 seconds
|
||||
if int(time.time() - start_time) % 10 == 0:
|
||||
stats = controller.get_statistics()
|
||||
print(f"Camera {i+1}: {stats['active_tracks']} tracks, "
|
||||
f"{stats['frame_count']} frames")
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
|
||||
# Cleanup
|
||||
print("\nCleaning up...")
|
||||
for decoder in decoders:
|
||||
decoder.stop()
|
||||
|
||||
# Print final stats
|
||||
print("\nFinal Statistics:")
|
||||
for i, controller in enumerate(controllers):
|
||||
stats = controller.get_statistics()
|
||||
print(f"\nCamera {i+1}:")
|
||||
print(f" Frames: {stats['frame_count']}")
|
||||
print(f" Tracks created: {stats['total_tracks_created']}")
|
||||
print(f" Active tracks: {stats['active_tracks']}")
|
||||
|
||||
# Print model repository stats
|
||||
print("\nModel Repository Stats:")
|
||||
repo_stats = model_repo.get_stats()
|
||||
for key, value in repo_stats.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run single camera test
|
||||
main()
|
||||
|
||||
# Uncomment to test multi-camera tracking
|
||||
# test_multi_camera_tracking()
|
||||
Loading…
Add table
Add a link
Reference in a new issue