profiling
This commit is contained in:
parent
7044b1e588
commit
c0ffa3967b
9 changed files with 354 additions and 1298 deletions
|
|
@ -13,6 +13,8 @@ import asyncio
|
|||
import time
|
||||
import os
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from services import (
|
||||
StreamConnectionManager,
|
||||
|
|
@ -32,17 +34,21 @@ async def main_single_stream():
|
|||
|
||||
# Configuration
|
||||
GPU_ID = 0
|
||||
MODEL_PATH = "models/yolov8n.pt" # PT file will be auto-converted
|
||||
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # PT file will be auto-converted
|
||||
STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
|
||||
BATCH_SIZE = 4
|
||||
FORCE_TIMEOUT = 0.05
|
||||
ENABLE_DISPLAY = os.getenv('ENABLE_DISPLAY', 'false').lower() == 'true' # Set to 'true' to enable OpenCV display
|
||||
MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300')) # Stop after N frames (0 = unlimited)
|
||||
|
||||
print(f"\nConfiguration:")
|
||||
print(f" GPU: {GPU_ID}")
|
||||
print(f" Model: {MODEL_PATH}")
|
||||
print(f" Stream: {STREAM_URL}")
|
||||
print(f" Batch size: {BATCH_SIZE}")
|
||||
print(f" Force timeout: {FORCE_TIMEOUT}s\n")
|
||||
print(f" Force timeout: {FORCE_TIMEOUT}s")
|
||||
print(f" Display: {'Enabled' if ENABLE_DISPLAY else 'Disabled (inference only)'}")
|
||||
print(f" Max frames: {MAX_FRAMES if MAX_FRAMES > 0 else 'Unlimited'}\n")
|
||||
|
||||
# Create StreamConnectionManager with PT conversion enabled
|
||||
print("[1/3] Creating StreamConnectionManager...")
|
||||
|
|
@ -94,14 +100,68 @@ async def main_single_stream():
|
|||
print("Press Ctrl+C to stop")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
# Stream results
|
||||
# Stream results with optional OpenCV visualization
|
||||
result_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
# Create window only if display is enabled
|
||||
if ENABLE_DISPLAY:
|
||||
cv2.namedWindow("Object Tracking", cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow("Object Tracking", 1280, 720)
|
||||
|
||||
try:
|
||||
async for result in connection.tracking_results():
|
||||
result_count += 1
|
||||
|
||||
# Check if we've reached max frames
|
||||
if MAX_FRAMES > 0 and result_count >= MAX_FRAMES:
|
||||
print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
|
||||
break
|
||||
|
||||
# OpenCV visualization (only if enabled)
|
||||
if ENABLE_DISPLAY:
|
||||
# Get latest frame from decoder (as CPU numpy array)
|
||||
frame = connection.decoder.get_latest_frame_cpu(rgb=True)
|
||||
|
||||
if frame is not None:
|
||||
# Convert to BGR for OpenCV
|
||||
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# Draw tracked objects
|
||||
for obj in result.tracked_objects:
|
||||
# Get bbox coordinates
|
||||
x1, y1, x2, y2 = map(int, obj.bbox)
|
||||
|
||||
# Draw bounding box
|
||||
cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# Draw track ID and class name
|
||||
label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
|
||||
label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||||
|
||||
# Draw label background
|
||||
cv2.rectangle(frame_bgr, (x1, y1 - label_size[1] - 10),
|
||||
(x1 + label_size[0], y1), (0, 255, 0), -1)
|
||||
|
||||
# Draw label text
|
||||
cv2.putText(frame_bgr, label, (x1, y1 - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
|
||||
|
||||
# Draw FPS and object count
|
||||
elapsed = time.time() - start_time
|
||||
fps = result_count / elapsed if elapsed > 0 else 0
|
||||
info_text = f"FPS: {fps:.1f} | Objects: {len(result.tracked_objects)} | Frame: {result_count}"
|
||||
cv2.putText(frame_bgr, info_text, (10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
||||
|
||||
# Display frame
|
||||
cv2.imshow("Object Tracking", frame_bgr)
|
||||
|
||||
# Check for 'q' key to quit
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
print(f"\n✓ Quit by user (pressed 'q')")
|
||||
break
|
||||
|
||||
# Print stats every 30 results
|
||||
if result_count % 30 == 0:
|
||||
elapsed = time.time() - start_time
|
||||
|
|
@ -110,7 +170,7 @@ async def main_single_stream():
|
|||
print(f"\nResults: {result_count} | FPS: {fps:.1f}")
|
||||
print(f" Stream: {result.stream_id}")
|
||||
print(f" Objects: {len(result.tracked_objects)}")
|
||||
|
||||
|
||||
if result.tracked_objects:
|
||||
class_counts = {}
|
||||
for obj in result.tracked_objects:
|
||||
|
|
@ -125,6 +185,10 @@ async def main_single_stream():
|
|||
print("Cleanup")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Close OpenCV window if it was opened
|
||||
if ENABLE_DISPLAY:
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
await connection.stop()
|
||||
await manager.shutdown()
|
||||
print("✓ Stopped")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue