profiling
This commit is contained in:
parent
7044b1e588
commit
c0ffa3967b
9 changed files with 354 additions and 1298 deletions
165
scripts/profiling.py
Normal file
165
scripts/profiling.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
"""
|
||||
Profiling script for the real-time object tracking pipeline.
|
||||
|
||||
This script runs the single-stream example from test_tracking_realtime.py
|
||||
under the Python profiler (cProfile) to identify performance bottlenecks.
|
||||
|
||||
Usage:
|
||||
python scripts/profiling.py
|
||||
|
||||
The script will print a summary of the most time-consuming functions
|
||||
at the end of the run.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import cProfile
|
||||
import pstats
|
||||
import io
|
||||
import time
|
||||
import os
|
||||
import torch
|
||||
import cv2
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Add project root to path to allow imports from services
|
||||
import sys
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from services import (
|
||||
StreamConnectionManager,
|
||||
YOLOv8Utils,
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def profiled_main():
    """
    Single stream example with event-driven architecture, adapted for profiling.

    This is a modified version of main_single_stream from
    test_tracking_realtime.py: display is disabled and the run is capped at
    MAX_FRAMES so the cProfile output reflects pipeline work only.

    Side effects: connects to an RTSP stream, loads a PT model on the GPU,
    and prints progress/summary information to stdout. Returns None.
    """
    print("=" * 80)
    print("Profiling: Event-Driven GPU-Accelerated Object Tracking")
    print("=" * 80)

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BATCH_SIZE = 4
    FORCE_TIMEOUT = 0.05
    # NOTE: Display is disabled for profiling to isolate pipeline performance.
    # Run for a limited number of frames to get a representative profile;
    # overridable via the MAX_FRAMES environment variable.
    MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300'))

    print("\nConfiguration:")
    print(f" GPU: {GPU_ID}")
    print(f" Model: {MODEL_PATH}")
    print(f" Stream: {STREAM_URL}")
    print(f" Batch size: {BATCH_SIZE}")
    print(f" Force timeout: {FORCE_TIMEOUT}s")
    print(" Display: Disabled for profiling")
    print(f" Max frames: {MAX_FRAMES}\n")

    # Create StreamConnectionManager
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True
    )
    print("✓ Manager created")

    # Initialize with PT model
    print("\n[2/3] Initializing with PT model...")
    try:
        await manager.initialize(
            model_path=MODEL_PATH,
            model_id="detector",
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=4,
            pt_input_shapes={"images": (1, 3, 640, 640)},
            pt_precision=torch.float16
        )
        print("✓ Manager initialized")
    except Exception as e:
        print(f"✗ Failed to initialize: {e}")
        return

    # Connect stream; the manager must be shut down if connection fails,
    # otherwise its worker contexts would be leaked.
    print("\n[3/3] Connecting to stream...")
    try:
        connection = await manager.connect_stream(
            rtsp_url=STREAM_URL,
            stream_id="camera_1",
            buffer_size=30
        )
        print("✓ Stream connected: camera_1")
    except Exception as e:
        print(f"✗ Failed to connect stream: {e}")
        await manager.shutdown()
        return

    print(f"\n{'=' * 80}")
    print(f"Profiling is running for {MAX_FRAMES} frames...")
    print(f"{'=' * 80}\n")

    result_count = 0
    start_time = time.time()
    elapsed = 0.0

    try:
        async for result in connection.tracking_results():
            result_count += 1

            if result_count >= MAX_FRAMES:
                print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
                break

            if result_count % 50 == 0:
                print(f" Processed {result_count}/{MAX_FRAMES} frames...")

    except KeyboardInterrupt:
        print("\n✓ Interrupted by user")
    finally:
        # Measure elapsed time BEFORE cleanup so the reported FPS reflects
        # frame processing only, not stream/manager shutdown cost.
        elapsed = time.time() - start_time

        # Cleanup runs in `finally` so the stream and manager are released
        # even if the result loop raises an unexpected exception.
        print(f"\n{'=' * 80}")
        print("Cleanup")
        print(f"{'=' * 80}")

        await connection.stop()
        await manager.shutdown()
        print("✓ Stopped")

    # Final stats
    avg_fps = result_count / elapsed if elapsed > 0 else 0
    print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create a profiler object
|
||||
profiler = cProfile.Profile()
|
||||
|
||||
# Run the async main function under the profiler
|
||||
print("Starting profiler...")
|
||||
profiler.enable()
|
||||
|
||||
asyncio.run(profiled_main())
|
||||
|
||||
profiler.disable()
|
||||
print("Profiling complete.")
|
||||
|
||||
# Print the stats
|
||||
s = io.StringIO()
|
||||
# Sort stats by cumulative time
|
||||
sortby = pstats.SortKey.CUMULATIVE
|
||||
ps = pstats.Stats(profiler, stream=s).sort_stats(sortby)
|
||||
ps.print_stats(30) # Print top 30 functions
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("PROFILING RESULTS (Top 30, sorted by cumulative time)")
|
||||
print("="*80)
|
||||
print(s.getvalue())
|
||||
Loading…
Add table
Add a link
Reference in a new issue