python-rtsp-worker/scripts/profiling.py
2025-11-10 00:10:53 +07:00

165 lines
4.6 KiB
Python

"""
Profiling script for the real-time object tracking pipeline.
This script runs the single-stream example from test_tracking_realtime.py
under the Python profiler (cProfile) to identify performance bottlenecks.
Usage:
python scripts/profiling.py
The script will print a summary of the most time-consuming functions
at the end of the run.
"""
import asyncio
import cProfile
import pstats
import io
import time
import os
import torch
import cv2
from dotenv import load_dotenv
# Add project root to path to allow imports from services
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from services import (
StreamConnectionManager,
YOLOv8Utils,
)
# Load environment variables
load_dotenv()
async def profiled_main():
    """
    Single stream example with event-driven architecture, adapted for profiling.

    This function is a modified version of main_single_stream from
    test_tracking_realtime.py. It creates a StreamConnectionManager, initializes
    it with a PT model, connects one RTSP stream (CAMERA_URL_1 env var), consumes
    up to MAX_FRAMES tracking results (env var, default 300), then stops the
    stream, shuts the manager down and prints the achieved throughput.

    Display is disabled so the profile reflects the pipeline only.

    If model initialization or stream connection fails, the error is printed
    and the function returns without raising. Once the stream is connected,
    cleanup runs in a ``finally`` so it also happens on errors or cancellation.
    """
    print("=" * 80)
    print("Profiling: Event-Driven GPU-Accelerated Object Tracking")
    print("=" * 80)

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BATCH_SIZE = 4
    FORCE_TIMEOUT = 0.05
    # NOTE: Display is disabled for profiling to isolate pipeline performance;
    # no frames are rendered anywhere in this function.
    # Run for a limited number of frames to get a representative profile
    MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300'))

    print("\nConfiguration:")
    print(f" GPU: {GPU_ID}")
    print(f" Model: {MODEL_PATH}")
    print(f" Stream: {STREAM_URL}")
    print(f" Batch size: {BATCH_SIZE}")
    print(f" Force timeout: {FORCE_TIMEOUT}s")
    print(" Display: Disabled for profiling")
    print(f" Max frames: {MAX_FRAMES}\n")

    # Create StreamConnectionManager
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True
    )
    print("✓ Manager created")

    # Initialize with PT model
    print("\n[2/3] Initializing with PT model...")
    try:
        await manager.initialize(
            model_path=MODEL_PATH,
            model_id="detector",
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=4,
            pt_input_shapes={"images": (1, 3, 640, 640)},
            pt_precision=torch.float16
        )
        print("✓ Manager initialized")
    except Exception as e:
        print(f"✗ Failed to initialize: {e}")
        # NOTE(review): manager.shutdown() is not called on this path; confirm
        # that a failed initialize() leaves nothing that needs releasing.
        return

    # Connect stream
    print("\n[3/3] Connecting to stream...")
    try:
        connection = await manager.connect_stream(
            rtsp_url=STREAM_URL,
            stream_id="camera_1",
            buffer_size=30
        )
        print("✓ Stream connected: camera_1")
    except Exception as e:
        print(f"✗ Failed to connect stream: {e}")
        await manager.shutdown()
        return

    print(f"\n{'=' * 80}")
    print(f"Profiling is running for {MAX_FRAMES} frames...")
    print(f"{'=' * 80}\n")

    result_count = 0
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the FPS figure.
    start_time = time.perf_counter()
    try:
        async for _result in connection.tracking_results():
            result_count += 1
            if result_count >= MAX_FRAMES:
                print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
                break
            if result_count % 50 == 0:
                print(f" Processed {result_count}/{MAX_FRAMES} frames...")
    except KeyboardInterrupt:
        # Only seen on Python < 3.11. On 3.11+ asyncio.run() turns Ctrl+C into
        # cancellation of this task, so CancelledError propagates instead; the
        # finally below still performs cleanup in that case.
        print("\n✓ Interrupted by user")
    finally:
        # Stop the clock before cleanup so shutdown latency does not deflate FPS.
        elapsed = time.perf_counter() - start_time

        # Cleanup
        print(f"\n{'=' * 80}")
        print("Cleanup")
        print("=" * 80)
        try:
            await connection.stop()
        finally:
            # Release the manager even if stopping the stream fails.
            await manager.shutdown()
        print("✓ Stopped")

        # Final stats
        avg_fps = result_count / elapsed if elapsed > 0 else 0
        print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
if __name__ == "__main__":
    # Optional knobs, in the same env-var style as MAX_FRAMES:
    #   PROFILE_TOP_N  - number of functions to print (default 30)
    #   PROFILE_OUTPUT - if set, path to write raw pstats data (e.g. for snakeviz)
    # Parsed before the run so a bad value fails fast instead of discarding a profile.
    top_n = int(os.getenv('PROFILE_TOP_N', '30'))
    dump_path = os.getenv('PROFILE_OUTPUT')

    # Create a profiler object.
    # NOTE(review): cProfile only records the thread it is enabled on. If the
    # services layer decodes or runs inference on worker threads, that time will
    # not appear below; confirm, and consider a sampling profiler for those.
    profiler = cProfile.Profile()

    # Run the async main function under the profiler
    print("Starting profiler...")
    profiler.enable()
    try:
        asyncio.run(profiled_main())
    except KeyboardInterrupt:
        # On Python 3.11+ asyncio.run() re-raises Ctrl+C here after cancelling the
        # main task. Keep going so the partial profile is still reported.
        print("\nRun interrupted; reporting partial profile.")
    finally:
        # Always stop collecting, even if the run raised.
        profiler.disable()
    print("Profiling complete.")

    # Build the stats once, sorted by cumulative time
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer).sort_stats(pstats.SortKey.CUMULATIVE)
    if dump_path:
        stats.dump_stats(dump_path)
        print(f"Raw profile written to {dump_path}")
    stats.print_stats(top_n)

    print("\n" + "=" * 80)
    print(f"PROFILING RESULTS (Top {top_n}, sorted by cumulative time)")
    print("=" * 80)
    print(buffer.getvalue())