# Listing metadata: 149 lines, 4.5 KiB, Python
"""
|
|
Add timing instrumentation to track where time is spent in the pipeline.
|
|
"""
|
|
|
|
import asyncio
import functools
import logging
import os
import sys
import time

import torch
from dotenv import load_dotenv

# Make the parent directory importable so the lazy `from services import ...`
# statements inside the functions below can resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Monkey patch to add timing.
# These hold the un-instrumented callables once patch_timing() has run;
# they stay None until then.
original_handle_result = None
original_run_tracking = None
original_infer = None

# Flat log of (operation_label, duration_in_milliseconds) tuples appended by
# the timed wrappers and summarised at the end of instrumented_main().
timings = []


def patch_timing():
    """Add timing instrumentation to key functions.

    Replaces three callables with wrappers that measure each call with
    ``time.perf_counter`` and append ``(label, milliseconds)`` to the
    module-level ``timings`` list:

    * ``StreamConnection._handle_inference_result`` -> ``'handle_result'``
    * ``StreamConnection._run_tracking_sync``       -> ``'tracking'``
    * ``TensorRTModelRepository.infer``             -> ``'infer'``

    The un-instrumented callables are stored in the module-level
    ``original_handle_result``, ``original_run_tracking`` and
    ``original_infer`` globals.

    Calling this function again while the patch is installed is a no-op, so
    the wrappers are never stacked on top of each other.
    """
    from services import stream_connection_manager, model_repository

    global original_handle_result, original_run_tracking, original_infer

    stream_cls = stream_connection_manager.StreamConnection
    repo_cls = model_repository.TensorRTModelRepository

    # Guard against double patching: wrapping a wrapper would double-count
    # every call and overwrite the saved originals with instrumented versions.
    if getattr(stream_cls._handle_inference_result, '_timing_patched', False):
        return

    # Bind the originals to locals so each wrapper always calls the exact
    # callable it replaced, regardless of later changes to the globals.
    handle_result_impl = stream_cls._handle_inference_result
    run_tracking_impl = stream_cls._run_tracking_sync
    infer_impl = repo_cls.infer

    original_handle_result = handle_result_impl
    original_run_tracking = run_tracking_impl
    original_infer = infer_impl

    # Patch _handle_inference_result
    @functools.wraps(handle_result_impl)
    async def timed_handle_result(self, result):
        t0 = time.perf_counter()
        try:
            return await handle_result_impl(self, result)
        finally:
            # Record the duration even when the wrapped call raises.
            timings.append(('handle_result', (time.perf_counter() - t0) * 1000))

    timed_handle_result._timing_patched = True
    stream_cls._handle_inference_result = timed_handle_result

    # Patch _run_tracking_sync
    @functools.wraps(run_tracking_impl)
    def timed_run_tracking(self, detections, min_confidence=0.7):
        t0 = time.perf_counter()
        try:
            return run_tracking_impl(self, detections, min_confidence)
        finally:
            timings.append(('tracking', (time.perf_counter() - t0) * 1000))

    timed_run_tracking._timing_patched = True
    stream_cls._run_tracking_sync = timed_run_tracking

    # Patch infer
    @functools.wraps(infer_impl)
    def timed_infer(self, model_id, inputs, synchronize=True):
        t0 = time.perf_counter()
        try:
            return infer_impl(self, model_id, inputs, synchronize)
        finally:
            timings.append(('infer', (time.perf_counter() - t0) * 1000))

    timed_infer._timing_patched = True
    repo_cls.infer = timed_infer


async def instrumented_main():
    """Instrumented profiling.

    Installs the timing patches, builds and initializes a
    ``StreamConnectionManager``, connects to the stream URL taken from the
    ``CAMERA_URL_1`` environment variable (falling back to
    ``rtsp://localhost:8554/test``), consumes tracking results until
    ``MAX_FRAMES`` have been received, then prints the throughput and a
    per-operation breakdown of the collected ``timings``.

    The stream connection and the manager are shut down in ``finally``
    blocks, so they are released even if processing fails or the task is
    cancelled. The reported elapsed time covers only the frame-processing
    loop, not the shutdown.
    """
    from collections import defaultdict

    from services import StreamConnectionManager, YOLOv8Utils

    load_dotenv()

    print("=" * 80)
    print("Timing Instrumentation")
    print("=" * 80)

    # Patch before creating manager
    patch_timing()

    # Configuration
    GPU_ID = 0
    MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
    STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BATCH_SIZE = 4
    FORCE_TIMEOUT = 0.05
    MAX_FRAMES = 30

    print(f"\nConfiguration: GPU={GPU_ID}, BATCH={BATCH_SIZE}, MAX={MAX_FRAMES}\n")

    # Create and initialize manager
    manager = StreamConnectionManager(
        gpu_id=GPU_ID,
        batch_size=BATCH_SIZE,
        force_timeout=FORCE_TIMEOUT,
        enable_pt_conversion=True
    )

    await manager.initialize(
        model_path=MODEL_PATH,
        model_id="detector",
        preprocess_fn=YOLOv8Utils.preprocess,
        postprocess_fn=YOLOv8Utils.postprocess,
        num_contexts=4,
        pt_input_shapes={"images": (1, 3, 640, 640)},
        pt_precision=torch.float16
    )

    connection = None
    result_count = 0
    elapsed = 0.0

    try:
        connection = await manager.connect_stream(
            rtsp_url=STREAM_URL,
            stream_id="camera_1",
            buffer_size=30
        )
        print("✓ Connected\n")

        print(f"{'=' * 80}")
        print(f"Processing {MAX_FRAMES} frames with timing...")
        print(f"{'=' * 80}\n")

        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments while frames are being processed.
        start_time = time.perf_counter()

        try:
            async for _ in connection.tracking_results():
                result_count += 1
                if result_count >= MAX_FRAMES:
                    break
        except KeyboardInterrupt:
            pass

        # Stop the clock before cleanup so shutdown time does not deflate FPS.
        elapsed = time.perf_counter() - start_time
    finally:
        # Cleanup: always release the stream, then the manager, even if
        # stopping the stream itself fails.
        try:
            if connection is not None:
                await connection.stop()
        finally:
            await manager.shutdown()

    # Analysis
    fps = result_count / elapsed if elapsed > 0 else 0.0
    print(f"\nProcessed {result_count} frames in {elapsed:.1f}s ({fps:.2f} FPS)\n")

    # Analyze timings: group the flat (operation, duration) log by operation.
    timing_stats = defaultdict(list)
    for operation, duration in timings:
        timing_stats[operation].append(duration)

    print("=" * 80)
    print("TIMING BREAKDOWN")
    print("=" * 80)
    for operation in ['infer', 'tracking', 'handle_result']:
        # Membership test (not indexing) so the defaultdict does not create
        # empty entries for operations that were never called.
        if operation in timing_stats:
            times = timing_stats[operation]
            total = sum(times)
            print(f"\n{operation}:")
            print(f" Calls: {len(times)}")
            print(f" Min: {min(times):.2f}ms")
            print(f" Max: {max(times):.2f}ms")
            print(f" Avg: {total/len(times):.2f}ms")
            print(f" Total: {total:.2f}ms")


# Script entry point: run the instrumented profiling coroutine to completion.
if __name__ == "__main__":
    asyncio.run(instrumented_main())