ultralytic export
This commit is contained in:
parent
bf7b68edb1
commit
fdaeb9981c
14 changed files with 2241 additions and 507 deletions
|
|
@ -9,193 +9,25 @@ This script demonstrates:
|
|||
- Automatic PT to TensorRT conversion
|
||||
"""
|
||||
|
||||
import time
|
||||
import os
|
||||
import torch
|
||||
import threading
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from services import (
|
||||
StreamConnectionManager,
|
||||
YOLOv8Utils,
|
||||
COCO_CLASSES,
|
||||
StreamConnectionManager,
|
||||
UltralyticsExporter,
|
||||
YOLOv8Utils,
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
def _draw_tracking_overlay(frame_bgr, result, result_count, start_time):
    """Draw boxes, labels and a status line for one tracking result.

    Args:
        frame_bgr: BGR image (as produced by ``cv2.cvtColor``) that is drawn
            on in place.
        result: Tracking result; only ``result.tracked_objects`` is read.
            Each object must expose ``bbox``, ``track_id``, ``class_name``
            and ``confidence``.
        result_count: Number of results received so far (shown as "Frame").
        start_time: ``time.time()`` value taken when streaming started; used
            to derive the displayed FPS.
    """
    green = (0, 255, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX

    for obj in result.tracked_objects:
        # bbox is unpacked as (x1, y1, x2, y2) -- presumably pixel
        # coordinates of the decoded frame; confirm against the tracker.
        x1, y1, x2, y2 = map(int, obj.bbox)
        cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), green, 2)

        label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
        (label_w, label_h), _ = cv2.getTextSize(label, font, 0.5, 1)

        # Filled rectangle behind the text so the black label stays readable.
        cv2.rectangle(
            frame_bgr,
            (x1, y1 - label_h - 10),
            (x1 + label_w, y1),
            green,
            -1,
        )
        cv2.putText(frame_bgr, label, (x1, y1 - 5), font, 0.5, (0, 0, 0), 1)

    elapsed = time.time() - start_time
    fps = result_count / elapsed if elapsed > 0 else 0
    info_text = (
        f"FPS: {fps:.1f} | Objects: {len(result.tracked_objects)} "
        f"| Frame: {result_count}"
    )
    cv2.putText(frame_bgr, info_text, (10, 30), font, 0.7, green, 2)


def _print_tracking_stats(result, result_count, start_time):
    """Print throughput and a per-class object count for ``result``."""
    elapsed = time.time() - start_time
    fps = result_count / elapsed if elapsed > 0 else 0

    print(f"\nResults: {result_count} | FPS: {fps:.1f}")
    print(f"  Stream: {result.stream_id}")
    print(f"  Objects: {len(result.tracked_objects)}")

    if result.tracked_objects:
        class_counts = {}
        for obj in result.tracked_objects:
            class_counts[obj.class_name] = class_counts.get(obj.class_name, 0) + 1
        print(f"  Classes: {class_counts}")


def main_single_stream():
    """Single stream example with event-driven architecture.

    Reads ``CAMERA_URL_1``, ``ENABLE_DISPLAY`` and ``MAX_FRAMES`` from the
    environment, creates and initializes a ``StreamConnectionManager``,
    connects one stream and consumes its tracking results until the frame
    limit is reached, the user presses 'q' (display mode) or Ctrl+C.

    The connection and the manager are released in a ``finally`` block, so
    they are stopped even when the result loop raises an unexpected error.
    The manager is also shut down when the stream connection fails.
    """
    print("=" * 80)
    print("Event-Driven GPU-Accelerated Object Tracking - Single Stream")
    print("=" * 80)

    # Configuration
    gpu_id = 0
    # Transparent loading: .pt, .engine, or .trt
    model_path = "bangchak/models/frontal_detection_v5.pt"
    stream_url = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    batch_size = 4
    force_timeout = 0.05
    # Set ENABLE_DISPLAY to 'true' to enable the OpenCV display
    enable_display = os.getenv('ENABLE_DISPLAY', 'false').lower() == 'true'
    # Stop after N frames (0 = unlimited)
    max_frames = int(os.getenv('MAX_FRAMES', '300'))

    print("\nConfiguration:")
    print(f"  GPU: {gpu_id}")
    print(f"  Model: {model_path}")
    print(f"  Stream: {stream_url}")
    print(f"  Batch size: {batch_size}")
    print(f"  Force timeout: {force_timeout}s")
    print(f"  Display: {'Enabled' if enable_display else 'Disabled (inference only)'}")
    print(f"  Max frames: {max_frames if max_frames > 0 else 'Unlimited'}\n")

    # Create StreamConnectionManager with PT conversion enabled
    print("[1/3] Creating StreamConnectionManager...")
    manager = StreamConnectionManager(
        gpu_id=gpu_id,
        batch_size=batch_size,
        force_timeout=force_timeout,
        enable_pt_conversion=True,  # Enable PT conversion
    )
    print("✓ Manager created")

    # Initialize with model (transparent loading - no manual parameters needed)
    print("\n[2/3] Initializing model...")
    print("Note: YOLO models auto-convert to native TensorRT .engine (first time only)")
    print("Metadata is auto-detected from model - no manual input_shapes needed!\n")

    try:
        manager.initialize(
            model_path=model_path,
            model_id="detector",
            preprocess_fn=YOLOv8Utils.preprocess,
            postprocess_fn=YOLOv8Utils.postprocess,
            num_contexts=4,
            # Note: No pt_input_shapes or pt_precision needed for YOLO models!
        )
        print("✓ Manager initialized")
    except Exception as e:
        print(f"✗ Failed to initialize: {e}")
        import traceback

        traceback.print_exc()
        return

    # Connect stream
    print("\n[3/3] Connecting to stream...")
    try:
        connection = manager.connect_stream(
            rtsp_url=stream_url,
            stream_id="camera_1",
            buffer_size=30,
        )
        print("✓ Stream connected: camera_1")
    except Exception as e:
        print(f"✗ Failed to connect stream: {e}")
        # The manager was initialized successfully above; release it instead
        # of leaving it running when no stream could be attached.
        manager.shutdown()
        return

    print(f"\n{'=' * 80}")
    print("Event-driven tracking is running!")
    print("Press Ctrl+C to stop")
    print(f"{'=' * 80}\n")

    # Stream results with optional OpenCV visualization
    result_count = 0
    start_time = time.time()

    try:
        # Create window only if display is enabled
        if enable_display:
            cv2.namedWindow("Object Tracking", cv2.WINDOW_NORMAL)
            cv2.resizeWindow("Object Tracking", 1280, 720)

        for result in connection.tracking_results():
            result_count += 1

            # Check if we've reached max frames
            if max_frames > 0 and result_count >= max_frames:
                print(f"\n✓ Reached max frames limit ({max_frames})")
                break

            # OpenCV visualization (only if enabled)
            if enable_display:
                # Get latest frame from decoder (as CPU numpy array)
                frame = connection.decoder.get_latest_frame_cpu(rgb=True)

                if frame is not None:
                    # Convert to BGR for OpenCV
                    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                    _draw_tracking_overlay(
                        frame_bgr, result, result_count, start_time
                    )
                    cv2.imshow("Object Tracking", frame_bgr)

                    # Check for 'q' key to quit
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        print("\n✓ Quit by user (pressed 'q')")
                        break

            # Print stats every 30 results
            if result_count % 30 == 0:
                _print_tracking_stats(result, result_count, start_time)

    except KeyboardInterrupt:
        print("\n✓ Interrupted by user")
    finally:
        # Cleanup runs on every exit path, including unexpected exceptions,
        # so the stream and manager are never left running.
        print(f"\n{'=' * 80}")
        print("Cleanup")
        print(f"{'=' * 80}")

        try:
            # Close OpenCV window if it was opened
            if enable_display:
                cv2.destroyAllWindows()
        finally:
            try:
                connection.stop()
            finally:
                # Shut the manager down even if stopping the stream failed.
                manager.shutdown()
        print("✓ Stopped")

    # Final stats
    elapsed = time.time() - start_time
    avg_fps = result_count / elapsed if elapsed > 0 else 0
    print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
|
||||
|
||||
|
||||
def main_multi_stream():
|
||||
"""Multi-stream example with batched inference."""
|
||||
|
|
@ -206,14 +38,18 @@ def main_multi_stream():
|
|||
# Configuration
|
||||
GPU_ID = 0
|
||||
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # Transparent loading: .pt, .engine, or .trt
|
||||
BATCH_SIZE = 16
|
||||
USE_ULTRALYTICS = (
|
||||
os.getenv("USE_ULTRALYTICS", "true").lower() == "true"
|
||||
) # Use Ultralytics engine for YOLO
|
||||
BATCH_SIZE = 2 # Reduced to 2 to avoid GPU memory issues
|
||||
FORCE_TIMEOUT = 0.05
|
||||
ENABLE_DISPLAY = os.getenv("ENABLE_DISPLAY", "true").lower() == "true"
|
||||
|
||||
# Load camera URLs
|
||||
camera_urls = []
|
||||
i = 1
|
||||
while True:
|
||||
url = os.getenv(f'CAMERA_URL_{i}')
|
||||
url = os.getenv(f"CAMERA_URL_{i}")
|
||||
if url:
|
||||
camera_urls.append((f"camera_{i}", url))
|
||||
i += 1
|
||||
|
|
@ -230,13 +66,16 @@ def main_multi_stream():
|
|||
print(f" Streams: {len(camera_urls)}")
|
||||
print(f" Batch size: {BATCH_SIZE}\n")
|
||||
|
||||
# Create manager with PT conversion
|
||||
# Create manager with backend selection
|
||||
print("[1/3] Creating StreamConnectionManager...")
|
||||
backend = "ultralytics"
|
||||
print(f" Backend: {backend}")
|
||||
manager = StreamConnectionManager(
|
||||
gpu_id=GPU_ID,
|
||||
batch_size=BATCH_SIZE,
|
||||
force_timeout=FORCE_TIMEOUT,
|
||||
enable_pt_conversion=True
|
||||
enable_pt_conversion=True,
|
||||
backend=backend,
|
||||
)
|
||||
print("✓ Manager created")
|
||||
|
||||
|
|
@ -248,30 +87,52 @@ def main_multi_stream():
|
|||
model_id="detector",
|
||||
preprocess_fn=YOLOv8Utils.preprocess,
|
||||
postprocess_fn=YOLOv8Utils.postprocess,
|
||||
num_contexts=8
|
||||
num_contexts=1, # Single context to minimize GPU memory usage
|
||||
# Note: No pt_input_shapes or pt_precision needed for YOLO models!
|
||||
)
|
||||
print("✓ Manager initialized")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to initialize: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Connect all streams
|
||||
print(f"\n[3/3] Connecting {len(camera_urls)} streams...")
|
||||
# Connect all streams in parallel using threads
|
||||
print(f"\n[3/3] Connecting {len(camera_urls)} streams in parallel...")
|
||||
connections = {}
|
||||
for stream_id, rtsp_url in camera_urls:
|
||||
connection_threads = []
|
||||
connection_results = {}
|
||||
|
||||
def connect_stream(stream_id, rtsp_url):
|
||||
"""Thread worker to connect a single stream"""
|
||||
try:
|
||||
conn = manager.connect_stream(
|
||||
rtsp_url=rtsp_url,
|
||||
stream_id=stream_id,
|
||||
buffer_size=5
|
||||
rtsp_url=rtsp_url, stream_id=stream_id, buffer_size=3
|
||||
)
|
||||
connections[stream_id] = conn
|
||||
print(f"✓ Connected: {stream_id}")
|
||||
connection_results[stream_id] = ("success", conn)
|
||||
except Exception as e:
|
||||
print(f"✗ Failed {stream_id}: {e}")
|
||||
connection_results[stream_id] = ("error", str(e))
|
||||
|
||||
# Start all connection threads
|
||||
for stream_id, rtsp_url in camera_urls:
|
||||
thread = threading.Thread(
|
||||
target=connect_stream, args=(stream_id, rtsp_url), daemon=True
|
||||
)
|
||||
thread.start()
|
||||
connection_threads.append(thread)
|
||||
|
||||
# Wait for all connections to complete
|
||||
for thread in connection_threads:
|
||||
thread.join()
|
||||
|
||||
# Collect results
|
||||
for stream_id, (status, result) in connection_results.items():
|
||||
if status == "success":
|
||||
connections[stream_id] = result
|
||||
print(f"✓ Connected: {stream_id}")
|
||||
else:
|
||||
print(f"✗ Failed {stream_id}: {result}")
|
||||
|
||||
if not connections:
|
||||
print("No streams connected")
|
||||
|
|
@ -284,10 +145,20 @@ def main_multi_stream():
|
|||
print(f"{'=' * 80}\n")
|
||||
|
||||
# Track stats
|
||||
stream_stats = {sid: {'count': 0, 'start': time.time()} for sid in connections.keys()}
|
||||
stream_stats = {
|
||||
sid: {"count": 0, "start": time.time()} for sid in connections.keys()
|
||||
}
|
||||
total_results = 0
|
||||
start_time = time.time()
|
||||
|
||||
# Create windows for each stream if display enabled
|
||||
if ENABLE_DISPLAY:
|
||||
for stream_id in connections.keys():
|
||||
cv2.namedWindow(stream_id, cv2.WINDOW_NORMAL)
|
||||
cv2.resizeWindow(
|
||||
stream_id, 640, 360
|
||||
) # Smaller windows for multiple streams
|
||||
|
||||
try:
|
||||
# Merge all result queues from all connections
|
||||
import queue as queue_module
|
||||
|
|
@ -306,27 +177,92 @@ def main_multi_stream():
|
|||
stream_id = result.stream_id
|
||||
|
||||
if stream_id in stream_stats:
|
||||
stream_stats[stream_id]['count'] += 1
|
||||
stream_stats[stream_id]["count"] += 1
|
||||
|
||||
# Display visualization if enabled
|
||||
if ENABLE_DISPLAY:
|
||||
# Get latest frame from decoder (already in CPU memory as numpy RGB)
|
||||
frame_rgb = conn.decoder.get_latest_frame_cpu(rgb=True)
|
||||
if frame_rgb is not None:
|
||||
# Convert RGB to BGR for OpenCV
|
||||
frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# Draw bounding boxes
|
||||
for obj in result.tracked_objects:
|
||||
x1, y1, x2, y2 = map(int, obj.bbox)
|
||||
|
||||
# Draw box
|
||||
cv2.rectangle(
|
||||
frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2
|
||||
)
|
||||
|
||||
# Draw label with ID and class
|
||||
label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
|
||||
(label_w, label_h), _ = cv2.getTextSize(
|
||||
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
|
||||
)
|
||||
cv2.rectangle(
|
||||
frame_bgr,
|
||||
(x1, y1 - label_h - 10),
|
||||
(x1 + label_w, y1),
|
||||
(0, 255, 0),
|
||||
-1,
|
||||
)
|
||||
cv2.putText(
|
||||
frame_bgr,
|
||||
label,
|
||||
(x1, y1 - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 0, 0),
|
||||
1,
|
||||
)
|
||||
|
||||
# Show FPS on frame
|
||||
s_elapsed = time.time() - stream_stats[stream_id]["start"]
|
||||
s_fps = (
|
||||
stream_stats[stream_id]["count"] / s_elapsed
|
||||
if s_elapsed > 0
|
||||
else 0
|
||||
)
|
||||
fps_text = f"{stream_id}: {s_fps:.1f} FPS | {len(result.tracked_objects)} objects"
|
||||
cv2.putText(
|
||||
frame_bgr,
|
||||
fps_text,
|
||||
(10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.7,
|
||||
(0, 255, 0),
|
||||
2,
|
||||
)
|
||||
|
||||
# Display
|
||||
cv2.imshow(stream_id, frame_bgr)
|
||||
|
||||
# Print stats every 100 results
|
||||
if total_results % 100 == 0:
|
||||
elapsed = time.time() - start_time
|
||||
total_fps = total_results / elapsed if elapsed > 0 else 0
|
||||
|
||||
print(f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS")
|
||||
print(
|
||||
f"\nTotal: {total_results} | {elapsed:.1f}s | {total_fps:.1f} FPS"
|
||||
)
|
||||
for sid, stats in stream_stats.items():
|
||||
s_elapsed = time.time() - stats['start']
|
||||
s_fps = stats['count'] / s_elapsed if s_elapsed > 0 else 0
|
||||
s_elapsed = time.time() - stats["start"]
|
||||
s_fps = stats["count"] / s_elapsed if s_elapsed > 0 else 0
|
||||
print(f" {sid}: {stats['count']} ({s_fps:.1f} FPS)")
|
||||
|
||||
except queue_module.Empty:
|
||||
continue
|
||||
|
||||
# Process OpenCV events to keep windows responsive
|
||||
if ENABLE_DISPLAY:
|
||||
cv2.waitKey(1)
|
||||
|
||||
# Small sleep if no results to avoid busy loop
|
||||
if not got_result:
|
||||
time.sleep(0.01)
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n✓ Interrupted")
|
||||
|
||||
|
|
@ -335,6 +271,10 @@ def main_multi_stream():
|
|||
print("Cleanup")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Close OpenCV windows if they were opened
|
||||
if ENABLE_DISPLAY:
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
for conn in connections.values():
|
||||
conn.stop()
|
||||
manager.shutdown()
|
||||
|
|
@ -347,8 +287,4 @@ def main_multi_stream():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "single":
|
||||
main_single_stream()
|
||||
else:
|
||||
main_multi_stream()
|
||||
main_multi_stream()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue