""" Real-time object tracking with event-driven batching architecture. This script demonstrates: - Event-driven stream processing with StreamConnectionManager - Batched GPU inference with ModelController - Ping-pong buffer architecture for optimal throughput - Callback-based event-driven pattern for RTSP streams - Automatic PT to TensorRT conversion """ import time import os import torch import cv2 import numpy as np from dotenv import load_dotenv from services import ( StreamConnectionManager, YOLOv8Utils, COCO_CLASSES, ) # Load environment variables load_dotenv() def main_single_stream(): """Single stream example with event-driven architecture.""" print("=" * 80) print("Event-Driven GPU-Accelerated Object Tracking - Single Stream") print("=" * 80) # Configuration GPU_ID = 0 MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # Transparent loading: .pt, .engine, or .trt STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test') BATCH_SIZE = 4 FORCE_TIMEOUT = 0.05 ENABLE_DISPLAY = os.getenv('ENABLE_DISPLAY', 'false').lower() == 'true' # Set to 'true' to enable OpenCV display MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300')) # Stop after N frames (0 = unlimited) print(f"\nConfiguration:") print(f" GPU: {GPU_ID}") print(f" Model: {MODEL_PATH}") print(f" Stream: {STREAM_URL}") print(f" Batch size: {BATCH_SIZE}") print(f" Force timeout: {FORCE_TIMEOUT}s") print(f" Display: {'Enabled' if ENABLE_DISPLAY else 'Disabled (inference only)'}") print(f" Max frames: {MAX_FRAMES if MAX_FRAMES > 0 else 'Unlimited'}\n") # Create StreamConnectionManager with PT conversion enabled print("[1/3] Creating StreamConnectionManager...") manager = StreamConnectionManager( gpu_id=GPU_ID, batch_size=BATCH_SIZE, force_timeout=FORCE_TIMEOUT, enable_pt_conversion=True # Enable PT conversion ) print("✓ Manager created") # Initialize with model (transparent loading - no manual parameters needed) print("\n[2/3] Initializing model...") print("Note: YOLO models auto-convert to native TensorRT .engine (first time only)") print("Metadata is auto-detected from model - no manual input_shapes needed!\n") try: manager.initialize( model_path=MODEL_PATH, model_id="detector", preprocess_fn=YOLOv8Utils.preprocess, postprocess_fn=YOLOv8Utils.postprocess, num_contexts=4 # Note: No pt_input_shapes or pt_precision needed for YOLO models! ) print("✓ Manager initialized") except Exception as e: print(f"✗ Failed to initialize: {e}") import traceback traceback.print_exc() return # Connect stream print("\n[3/3] Connecting to stream...") try: connection = manager.connect_stream( rtsp_url=STREAM_URL, stream_id="camera_1", buffer_size=30 ) print(f"✓ Stream connected: camera_1") except Exception as e: print(f"✗ Failed to connect stream: {e}") return print(f"\n{'=' * 80}") print("Event-driven tracking is running!") print("Press Ctrl+C to stop") print(f"{'=' * 80}\n") # Stream results with optional OpenCV visualization result_count = 0 start_time = time.time() # Create window only if display is enabled if ENABLE_DISPLAY: cv2.namedWindow("Object Tracking", cv2.WINDOW_NORMAL) cv2.resizeWindow("Object Tracking", 1280, 720) try: for result in connection.tracking_results(): result_count += 1 # Check if we've reached max frames if MAX_FRAMES > 0 and result_count >= MAX_FRAMES: print(f"\n✓ Reached max frames limit ({MAX_FRAMES})") break # OpenCV visualization (only if enabled) if ENABLE_DISPLAY: # Get latest frame from decoder (as CPU numpy array) frame = connection.decoder.get_latest_frame_cpu(rgb=True) if frame is not None: # Convert to BGR for OpenCV frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) # Draw tracked objects for obj in result.tracked_objects: # Get bbox coordinates x1, y1, x2, y2 = map(int, obj.bbox) # Draw bounding box cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2) # Draw track ID and class name label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}" label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) # Draw label background cv2.rectangle(frame_bgr, (x1, y1 - label_size[1] - 10), (x1 + label_size[0], y1), (0, 255, 0), -1) # Draw label text cv2.putText(frame_bgr, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) # Draw FPS and object count elapsed = time.time() - start_time fps = result_count / elapsed if elapsed > 0 else 0 info_text = f"FPS: {fps:.1f} | Objects: {len(result.tracked_objects)} | Frame: {result_count}" cv2.putText(frame_bgr, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) # Display frame cv2.imshow("Object Tracking", frame_bgr) # Check for 'q' key to quit if cv2.waitKey(1) & 0xFF == ord('q'): print(f"\n✓ Quit by user (pressed 'q')") break # Print stats every 30 results if result_count % 30 == 0: elapsed = time.time() - start_time fps = result_count / elapsed if elapsed > 0 else 0 print(f"\nResults: {result_count} | FPS: {fps:.1f}") print(f" Stream: {result.stream_id}") print(f" Objects: {len(result.tracked_objects)}") if result.tracked_objects: class_counts = {} for obj in result.tracked_objects: class_counts[obj.class_name] = class_counts.get(obj.class_name, 0) + 1 print(f" Classes: {class_counts}") except KeyboardInterrupt: print(f"\n✓ Interrupted by user") # Cleanup print(f"\n{'=' * 80}") print("Cleanup") print(f"{'=' * 80}") # Close OpenCV window if it was opened if ENABLE_DISPLAY: cv2.destroyAllWindows() connection.stop() manager.shutdown() print("✓ Stopped") # Final stats elapsed = time.time() - start_time avg_fps = result_count / elapsed if elapsed > 0 else 0 print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)") def main_multi_stream(): """Multi-stream example with batched inference.""" print("=" * 80) print("Event-Driven GPU-Accelerated Object Tracking - Multi-Stream") print("=" * 80) # Configuration GPU_ID = 0 MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # Transparent loading: .pt, .engine, or .trt BATCH_SIZE = 16 FORCE_TIMEOUT = 0.05 # Load camera URLs camera_urls = [] i = 1 while True: url = os.getenv(f'CAMERA_URL_{i}') if url: camera_urls.append((f"camera_{i}", url)) i += 1 else: break if not camera_urls: print("No camera URLs found in .env") return print(f"\nConfiguration:") print(f" GPU: {GPU_ID}") print(f" Model: {MODEL_PATH}") print(f" Streams: {len(camera_urls)}") print(f" Batch size: {BATCH_SIZE}\n") # Create manager with PT conversion print("[1/3] Creating StreamConnectionManager...") manager = StreamConnectionManager( gpu_id=GPU_ID, batch_size=BATCH_SIZE, force_timeout=FORCE_TIMEOUT, enable_pt_conversion=True ) print("✓ Manager created") # Initialize model (transparent loading) print("\n[2/3] Initializing model...") try: manager.initialize( model_path=MODEL_PATH, model_id="detector", preprocess_fn=YOLOv8Utils.preprocess, postprocess_fn=YOLOv8Utils.postprocess, num_contexts=8 # Note: No pt_input_shapes or pt_precision needed for YOLO models! ) print("✓ Manager initialized") except Exception as e: print(f"✗ Failed to initialize: {e}") import traceback traceback.print_exc() return # Connect all streams print(f"\n[3/3] Connecting {len(camera_urls)} streams...") connections = {} for stream_id, rtsp_url in camera_urls: try: conn = manager.connect_stream( rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=5 ) connections[stream_id] = conn print(f"✓ Connected: {stream_id}") except Exception as e: print(f"✗ Failed {stream_id}: {e}") if not connections: print("No streams connected") return print(f"\n{'=' * 80}") print(f"Multi-stream tracking running ({len(connections)} streams)") print("Frames from all streams are batched together!") print("Press Ctrl+C to stop") print(f"{'=' * 80}\n") # Track stats stream_stats = {sid: {'count': 0, 'start': time.time()} for sid in connections.keys()} total_results = 0 start_time = time.time() try: # Merge all result queues from all connections import queue as queue_module running = True while running: # Poll all connection queues (non-blocking) got_result = False for conn in connections.values(): try: # Non-blocking get from each connection's queue result = conn.result_queue.get_nowait() got_result = True total_results += 1 stream_id = result.stream_id if stream_id in stream_stats: stream_stats[stream_id]['count'] += 1 # Print stats every 100 results if total_results % 100 == 0: elapsed = time.time() - start_time total_fps = total_results / elapsed if elapsed > 0 else 0 print(f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS") for sid, stats in stream_stats.items(): s_elapsed = time.time() - stats['start'] s_fps = stats['count'] / s_elapsed if s_elapsed > 0 else 0 print(f" {sid}: {stats['count']} ({s_fps:.1f} FPS)") except queue_module.Empty: continue # Small sleep if no results to avoid busy loop if not got_result: time.sleep(0.01) except KeyboardInterrupt: print(f"\n✓ Interrupted") # Cleanup print(f"\n{'=' * 80}") print("Cleanup") print(f"{'=' * 80}") for conn in connections.values(): conn.stop() manager.shutdown() print("✓ Stopped") # Final stats elapsed = time.time() - start_time avg_fps = total_results / elapsed if elapsed > 0 else 0 print(f"\nFinal: {total_results} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)") if __name__ == "__main__": import sys if len(sys.argv) > 1 and sys.argv[1] == "single": main_single_stream() else: main_multi_stream()