"""
|
|
Profiling script for the real-time object tracking pipeline.
|
|
|
|
This script runs the single-stream example from test_tracking_realtime.py
|
|
under the Python profiler (cProfile) to identify performance bottlenecks.
|
|
|
|
Usage:
|
|
python scripts/profiling.py
|
|
|
|
The script will print a summary of the most time-consuming functions
|
|
at the end of the run.
|
|
"""

import asyncio
import cProfile
import pstats
import io
import time
import os
import torch
import cv2
from dotenv import load_dotenv

# Add project root to path to allow imports from services
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from services import (
    StreamConnectionManager,
    YOLOv8Utils,
)

# Load environment variables
load_dotenv()
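
# Environment variables read by this script (typically provided via a .env file):
#   CAMERA_URL_1 - RTSP URL of the camera stream (falls back to rtsp://localhost:8554/test)
#   MAX_FRAMES   - number of tracking results to process before stopping (default: 300)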


async def profiled_main():
    """
    Single-stream example with event-driven architecture, adapted for profiling.

    This function is a modified version of main_single_stream from
    test_tracking_realtime.py.
    """
    print("=" * 80)
    print("Profiling: Event-Driven GPU-Accelerated Object Tracking")
    print("=" * 80)

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BATCH_SIZE = 4
    FORCE_TIMEOUT = 0.05
    # NOTE: Display is disabled for profiling to isolate pipeline performance
    ENABLE_DISPLAY = False
    # Run for a limited number of frames to get a representative profile
    MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300'))

    print(f"\nConfiguration:")
    print(f"  GPU: {GPU_ID}")
    print(f"  Model: {MODEL_PATH}")
    print(f"  Stream: {STREAM_URL}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Force timeout: {FORCE_TIMEOUT}s")
    print(f"  Display: Disabled for profiling")
    print(f"  Max frames: {MAX_FRAMES}\n")

    # Create StreamConnectionManager
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True
    )
    print("✓ Manager created")

    # Initialize with PT model
    print("\n[2/3] Initializing with PT model...")
    try:
        await manager.initialize(
            model_path=MODEL_PATH,
            model_id="detector",
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=4,
            pt_input_shapes={"images": (1, 3, 640, 640)},
            pt_precision=torch.float16
        )
        print("✓ Manager initialized")
    except Exception as e:
        print(f"✗ Failed to initialize: {e}")
        return

    # Connect stream
    print("\n[3/3] Connecting to stream...")
    try:
        connection = await manager.connect_stream(
            rtsp_url=STREAM_URL,
            stream_id="camera_1",
            buffer_size=30
        )
        print(f"✓ Stream connected: camera_1")
    except Exception as e:
        print(f"✗ Failed to connect stream: {e}")
        await manager.shutdown()
        return

    print(f"\n{'=' * 80}")
    print(f"Profiling is running for {MAX_FRAMES} frames...")
    print(f"{'=' * 80}\n")

    result_count = 0
    start_time = time.time()

    try:
        async for result in connection.tracking_results():
            result_count += 1
            if result_count >= MAX_FRAMES:
                print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
                break

            if result_count % 50 == 0:
                print(f"  Processed {result_count}/{MAX_FRAMES} frames...")

    except KeyboardInterrupt:
        print(f"\n✓ Interrupted by user")

    # Cleanup
    print(f"\n{'=' * 80}")
    print("Cleanup")
    print(f"{'=' * 80}")

    await connection.stop()
    await manager.shutdown()
    print("✓ Stopped")

    # Final stats
    elapsed = time.time() - start_time
    avg_fps = result_count / elapsed if elapsed > 0 else 0
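    # Note: this throughput includes cProfile's instrumentation overhead,
    # so it will read somewhat lower than an unprofiled run.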
    print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")


if __name__ == "__main__":
    # Create a profiler object
    profiler = cProfile.Profile()

    # Run the async main function under the profiler
    print("Starting profiler...")
    profiler.enable()
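    # Caveat: cProfile only instruments the thread it is enabled in, so work done
    # in background threads or subprocesses spawned by the pipeline will not
    # appear in these stats.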

    asyncio.run(profiled_main())

    profiler.disable()
    print("Profiling complete.")

    # Print the stats
    s = io.StringIO()
    # Sort stats by cumulative time
    sortby = pstats.SortKey.CUMULATIVE
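    # (pstats.SortKey.TIME is a useful alternative: it ranks functions by self
    # time, which highlights hot leaf functions rather than call roots.)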
    ps = pstats.Stats(profiler, stream=s).sort_stats(sortby)
    ps.print_stats(30)  # Print top 30 functions
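
    # Optionally dump the raw stats to disk for later inspection with a viewer
    # such as snakeviz; the output path below is just an example.
    # ps.dump_stats("profiling_results.prof")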

    print("\n" + "=" * 80)
    print("PROFILING RESULTS (Top 30, sorted by cumulative time)")
    print("=" * 80)
    print(s.getvalue())