""" Test script for TrackingController and TrackingFactory. This script demonstrates how to use the tracking system with: - TensorRT model repository (dependency injection) - TrackingFactory for controller creation - GPU-accelerated object tracking on RTSP streams - Persistent track IDs and history management """ import time import os from dotenv import load_dotenv from services import ( StreamDecoderFactory, TensorRTModelRepository, TrackingFactory, TrackedObject ) # Load environment variables load_dotenv() def main(): """ Main test function demonstrating tracking workflow. """ # Configuration GPU_ID = 0 MODEL_PATH = "models/yolov8n.trt" # Update with your model path RTSP_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test') BUFFER_SIZE = 30 # COCO class names (example for YOLOv8) COCO_CLASSES = { 0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', # Add more as needed... } print("=" * 80) print("GPU-Accelerated Object Tracking Test") print("=" * 80) # Step 1: Create model repository print("\n[1/5] Initializing TensorRT Model Repository...") model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4) # Load detection model (if file exists) model_id = "yolov8_detector" if os.path.exists(MODEL_PATH): try: metadata = model_repo.load_model( model_id=model_id, file_path=MODEL_PATH, num_contexts=4 ) print(f"✓ Model loaded successfully") print(f" Input shape: {metadata.input_shapes}") print(f" Output shape: {metadata.output_shapes}") except Exception as e: print(f"✗ Failed to load model: {e}") print(f" Please ensure {MODEL_PATH} exists") print(f" Continuing with demo (will use mock detections)...") model_id = None else: print(f"✗ Model file not found: {MODEL_PATH}") print(f" Continuing with demo (will use mock detections)...") model_id = None # Step 2: Create tracking factory print("\n[2/5] Creating TrackingFactory...") tracking_factory = TrackingFactory(gpu_id=GPU_ID) print(f"✓ Factory created: {tracking_factory}") # Step 3: Create tracking controller (only if model loaded) tracking_controller = None if model_id is not None: print("\n[3/5] Creating TrackingController...") try: tracking_controller = tracking_factory.create_controller( model_repository=model_repo, model_id=model_id, tracker_type="iou", max_age=30, min_confidence=0.5, iou_threshold=0.3, class_names=COCO_CLASSES ) print(f"✓ Controller created: {tracking_controller}") except Exception as e: print(f"✗ Failed to create controller: {e}") tracking_controller = None else: print("\n[3/5] Skipping TrackingController creation (no model loaded)") # Step 4: Create stream decoder print("\n[4/5] Creating RTSP Stream Decoder...") stream_factory = StreamDecoderFactory(gpu_id=GPU_ID) decoder = stream_factory.create_decoder( rtsp_url=RTSP_URL, buffer_size=BUFFER_SIZE ) decoder.start() print(f"✓ Decoder started for: {RTSP_URL}") print(f" Waiting for connection...") # Wait for stream connection time.sleep(5) if decoder.is_connected(): print(f"✓ Stream connected!") else: print(f"✗ Stream not connected (status: {decoder.get_status().value})") print(f" Note: This is expected if RTSP URL is not available") print(f" The tracking system will still work with valid streams") # Step 5: Run tracking loop (demo) print("\n[5/5] Running Tracking Loop...") print(f" Processing frames for 30 seconds...") print(f" Press Ctrl+C to stop early\n") try: frame_count = 0 start_time = time.time() while time.time() - start_time < 30: # Get latest frame from decoder (GPU tensor) frame = decoder.get_latest_frame(rgb=True) if frame is None: time.sleep(0.1) continue frame_count += 1 # Run tracking (if controller available) if tracking_controller is not None: try: # Track objects in frame tracked_objects = tracking_controller.track(frame) # Display tracking results every 10 frames if frame_count % 10 == 0: print(f"\n--- Frame {frame_count} ---") print(f"Active tracks: {len(tracked_objects)}") for obj in tracked_objects: print(f" Track #{obj.track_id}: {obj.class_name} " f"(conf={obj.confidence:.2f}, " f"bbox={[f'{x:.1f}' for x in obj.bbox]}, " f"age={obj.age(tracking_controller._frame_count)} frames)") # Print statistics stats = tracking_controller.get_statistics() print(f"\nStatistics:") print(f" Total frames processed: {stats['frame_count']}") print(f" Total tracks created: {stats['total_tracks_created']}") print(f" Total detections: {stats['total_detections']}") print(f" Avg detections/frame: {stats['avg_detections_per_frame']:.2f}") print(f" Class counts: {stats['class_counts']}") except Exception as e: print(f"✗ Tracking error on frame {frame_count}: {e}") # Small delay to avoid overwhelming output time.sleep(0.1) except KeyboardInterrupt: print("\n\n✓ Interrupted by user") # Cleanup print("\n" + "=" * 80) print("Cleanup") print("=" * 80) if tracking_controller is not None: print("\nTracking final statistics:") stats = tracking_controller.get_statistics() for key, value in stats.items(): print(f" {key}: {value}") print("\nExporting tracks to JSON...") try: tracks_json = tracking_controller.export_tracks(format="json") with open("tracked_objects.json", "w") as f: f.write(tracks_json) print(f"✓ Tracks exported to tracked_objects.json") except Exception as e: print(f"✗ Export failed: {e}") print("\nStopping decoder...") decoder.stop() print("✓ Decoder stopped") print("\n" + "=" * 80) print("Test completed successfully!") print("=" * 80) def test_multi_camera_tracking(): """ Example: Track objects across multiple camera streams. This demonstrates: - Shared model repository across multiple streams - Multiple tracking controllers (one per camera) - Efficient GPU resource usage """ GPU_ID = 0 MODEL_PATH = "models/yolov8n.trt" # Load multiple camera URLs camera_urls = [] i = 1 while True: url = os.getenv(f'CAMERA_URL_{i}') if url: camera_urls.append(url) i += 1 else: break if not camera_urls: print("No camera URLs found in .env file") return print(f"Testing multi-camera tracking with {len(camera_urls)} cameras") # Create shared model repository model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8) if os.path.exists(MODEL_PATH): model_repo.load_model("detector", MODEL_PATH, num_contexts=8) else: print(f"Model not found: {MODEL_PATH}") return # Create tracking factory tracking_factory = TrackingFactory(gpu_id=GPU_ID) # Create stream decoders and tracking controllers stream_factory = StreamDecoderFactory(gpu_id=GPU_ID) decoders = [] controllers = [] for i, url in enumerate(camera_urls): # Create decoder decoder = stream_factory.create_decoder(url, buffer_size=30) decoder.start() decoders.append(decoder) # Create tracking controller controller = tracking_factory.create_controller( model_repository=model_repo, model_id="detector", tracker_type="iou", max_age=30, min_confidence=0.5 ) controllers.append(controller) print(f"Camera {i+1}: {url}") print(f"\nWaiting for streams to connect...") time.sleep(10) # Track objects for 30 seconds print(f"\nTracking objects across {len(camera_urls)} cameras...") start_time = time.time() try: while time.time() - start_time < 30: for i, (decoder, controller) in enumerate(zip(decoders, controllers)): frame = decoder.get_latest_frame(rgb=True) if frame is not None: tracked_objects = controller.track(frame) # Print stats every 10 seconds if int(time.time() - start_time) % 10 == 0: stats = controller.get_statistics() print(f"Camera {i+1}: {stats['active_tracks']} tracks, " f"{stats['frame_count']} frames") time.sleep(0.1) except KeyboardInterrupt: print("\nInterrupted by user") # Cleanup print("\nCleaning up...") for decoder in decoders: decoder.stop() # Print final stats print("\nFinal Statistics:") for i, controller in enumerate(controllers): stats = controller.get_statistics() print(f"\nCamera {i+1}:") print(f" Frames: {stats['frame_count']}") print(f" Tracks created: {stats['total_tracks_created']}") print(f" Active tracks: {stats['active_tracks']}") # Print model repository stats print("\nModel Repository Stats:") repo_stats = model_repo.get_stats() for key, value in repo_stats.items(): print(f" {key}: {value}") if __name__ == "__main__": # Run single camera test main() # Uncomment to test multi-camera tracking # test_multi_camera_tracking()