profiling
This commit is contained in:
parent
7044b1e588
commit
c0ffa3967b
9 changed files with 354 additions and 1298 deletions
File diff suppressed because it is too large
Load diff
|
|
@ -2,4 +2,8 @@
|
||||||
|
|
||||||
- It doesn't really care what pt file is included and it always use YOLO's model id, for example if id 1 is apple, it still say person. maybe extract class list from yolo's .pt somehow?
|
- It doesn't really care what pt file is included and it always use YOLO's model id, for example if id 1 is apple, it still say person. maybe extract class list from yolo's .pt somehow?
|
||||||
|
|
||||||
|
- It read frame a bit too fast. it say it's infering at 20-ish fps but the actual camera is only 5 fps or so
|
||||||
|
|
||||||
- Potential race condition issue when multiple camera try to init with the same unconverted model.
|
- Potential race condition issue when multiple camera try to init with the same unconverted model.
|
||||||
|
|
||||||
|
- Blurry asyncio archtecture, require documentations
|
||||||
165
scripts/profiling.py
Normal file
165
scripts/profiling.py
Normal file
|
|
@ -0,0 +1,165 @@
|
||||||
|
"""
|
||||||
|
Profiling script for the real-time object tracking pipeline.
|
||||||
|
|
||||||
|
This script runs the single-stream example from test_tracking_realtime.py
|
||||||
|
under the Python profiler (cProfile) to identify performance bottlenecks.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/profiling.py
|
||||||
|
|
||||||
|
The script will print a summary of the most time-consuming functions
|
||||||
|
at the end of the run.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import cProfile
|
||||||
|
import pstats
|
||||||
|
import io
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Add project root to path to allow imports from services
|
||||||
|
import sys
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
from services import (
|
||||||
|
StreamConnectionManager,
|
||||||
|
YOLOv8Utils,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
async def profiled_main():
|
||||||
|
"""
|
||||||
|
Single stream example with event-driven architecture, adapted for profiling.
|
||||||
|
This function is a modified version of main_single_stream from test_tracking_realtime.py
|
||||||
|
"""
|
||||||
|
print("=" * 80)
|
||||||
|
print("Profiling: Event-Driven GPU-Accelerated Object Tracking")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
GPU_ID = 0
|
||||||
|
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
|
||||||
|
STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
|
||||||
|
BATCH_SIZE = 4
|
||||||
|
FORCE_TIMEOUT = 0.05
|
||||||
|
# NOTE: Display is disabled for profiling to isolate pipeline performance
|
||||||
|
ENABLE_DISPLAY = False
|
||||||
|
# Run for a limited number of frames to get a representative profile
|
||||||
|
MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300'))
|
||||||
|
|
||||||
|
print(f"\nConfiguration:")
|
||||||
|
print(f" GPU: {GPU_ID}")
|
||||||
|
print(f" Model: {MODEL_PATH}")
|
||||||
|
print(f" Stream: {STREAM_URL}")
|
||||||
|
print(f" Batch size: {BATCH_SIZE}")
|
||||||
|
print(f" Force timeout: {FORCE_TIMEOUT}s")
|
||||||
|
print(f" Display: Disabled for profiling")
|
||||||
|
print(f" Max frames: {MAX_FRAMES}\n")
|
||||||
|
|
||||||
|
# Create StreamConnectionManager
|
||||||
|
print("[1/3] Creating StreamConnectionManager...")
|
||||||
|
manager = StreamConnectionManager(
|
||||||
|
gpu_id=GPU_ID,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
force_timeout=FORCE_TIMEOUT,
|
||||||
|
enable_pt_conversion=True
|
||||||
|
)
|
||||||
|
print("✓ Manager created")
|
||||||
|
|
||||||
|
# Initialize with PT model
|
||||||
|
print("\n[2/3] Initializing with PT model...")
|
||||||
|
try:
|
||||||
|
await manager.initialize(
|
||||||
|
model_path=MODEL_PATH,
|
||||||
|
model_id="detector",
|
||||||
|
preprocess_fn=YOLOv8Utils.preprocess,
|
||||||
|
postprocess_fn=YOLOv8Utils.postprocess,
|
||||||
|
num_contexts=4,
|
||||||
|
pt_input_shapes={"images": (1, 3, 640, 640)},
|
||||||
|
pt_precision=torch.float16
|
||||||
|
)
|
||||||
|
print("✓ Manager initialized")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Failed to initialize: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Connect stream
|
||||||
|
print("\n[3/3] Connecting to stream...")
|
||||||
|
try:
|
||||||
|
connection = await manager.connect_stream(
|
||||||
|
rtsp_url=STREAM_URL,
|
||||||
|
stream_id="camera_1",
|
||||||
|
buffer_size=30
|
||||||
|
)
|
||||||
|
print(f"✓ Stream connected: camera_1")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Failed to connect stream: {e}")
|
||||||
|
await manager.shutdown()
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Profiling is running for {MAX_FRAMES} frames...")
|
||||||
|
print(f"{ '=' * 80}\n")
|
||||||
|
|
||||||
|
result_count = 0
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for result in connection.tracking_results():
|
||||||
|
result_count += 1
|
||||||
|
if result_count >= MAX_FRAMES:
|
||||||
|
print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
|
||||||
|
break
|
||||||
|
|
||||||
|
if result_count % 50 == 0:
|
||||||
|
print(f" Processed {result_count}/{MAX_FRAMES} frames...")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print(f"\n✓ Interrupted by user")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("Cleanup")
|
||||||
|
print(f"{ '=' * 80}")
|
||||||
|
|
||||||
|
await connection.stop()
|
||||||
|
await manager.shutdown()
|
||||||
|
print("✓ Stopped")
|
||||||
|
|
||||||
|
# Final stats
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
avg_fps = result_count / elapsed if elapsed > 0 else 0
|
||||||
|
print(f"\nFinal: {result_count} results in {elapsed:.1f}s ({avg_fps:.1f} FPS)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Create a profiler object
|
||||||
|
profiler = cProfile.Profile()
|
||||||
|
|
||||||
|
# Run the async main function under the profiler
|
||||||
|
print("Starting profiler...")
|
||||||
|
profiler.enable()
|
||||||
|
|
||||||
|
asyncio.run(profiled_main())
|
||||||
|
|
||||||
|
profiler.disable()
|
||||||
|
print("Profiling complete.")
|
||||||
|
|
||||||
|
# Print the stats
|
||||||
|
s = io.StringIO()
|
||||||
|
# Sort stats by cumulative time
|
||||||
|
sortby = pstats.SortKey.CUMULATIVE
|
||||||
|
ps = pstats.Stats(profiler, stream=s).sort_stats(sortby)
|
||||||
|
ps.print_stats(30) # Print top 30 functions
|
||||||
|
|
||||||
|
print("\n" + "="*80)
|
||||||
|
print("PROFILING RESULTS (Top 30, sorted by cumulative time)")
|
||||||
|
print("="*80)
|
||||||
|
print(s.getvalue())
|
||||||
|
|
@ -5,8 +5,7 @@ Services package for RTSP stream processing with GPU acceleration.
|
||||||
from .stream_decoder import StreamDecoderFactory, StreamDecoder, ConnectionStatus
|
from .stream_decoder import StreamDecoderFactory, StreamDecoder, ConnectionStatus
|
||||||
from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg
|
from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg
|
||||||
from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionContext, SharedEngine
|
from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionContext, SharedEngine
|
||||||
from .tracking_controller import TrackingController, TrackedObject
|
from .tracking_controller import ObjectTracker, TrackedObject, Detection
|
||||||
from .tracking_factory import TrackingFactory
|
|
||||||
from .yolo import YOLOv8Utils, COCO_CLASSES
|
from .yolo import YOLOv8Utils, COCO_CLASSES
|
||||||
from .model_controller import ModelController, BatchFrame, BufferState
|
from .model_controller import ModelController, BatchFrame, BufferState
|
||||||
from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult
|
from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult
|
||||||
|
|
@ -23,9 +22,9 @@ __all__ = [
|
||||||
'ModelMetadata',
|
'ModelMetadata',
|
||||||
'ExecutionContext',
|
'ExecutionContext',
|
||||||
'SharedEngine',
|
'SharedEngine',
|
||||||
'TrackingController',
|
'ObjectTracker',
|
||||||
'TrackedObject',
|
'TrackedObject',
|
||||||
'TrackingFactory',
|
'Detection',
|
||||||
'YOLOv8Utils',
|
'YOLOv8Utils',
|
||||||
'COCO_CLASSES',
|
'COCO_CLASSES',
|
||||||
'ModelController',
|
'ModelController',
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ import torch
|
||||||
|
|
||||||
from .model_controller import ModelController
|
from .model_controller import ModelController
|
||||||
from .stream_decoder import StreamDecoderFactory
|
from .stream_decoder import StreamDecoderFactory
|
||||||
from .tracking_factory import TrackingFactory
|
|
||||||
from .model_repository import TensorRTModelRepository
|
from .model_repository import TensorRTModelRepository
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -133,16 +132,20 @@ class StreamConnection:
|
||||||
|
|
||||||
async def _frame_poller(self):
|
async def _frame_poller(self):
|
||||||
"""Poll frames from threaded decoder and submit to model controller"""
|
"""Poll frames from threaded decoder and submit to model controller"""
|
||||||
last_frame_ptr = None
|
last_decoder_frame_count = -1
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
# Poll frame from decoder (runs in thread)
|
# Get current decoder frame count (no data transfer, just counter)
|
||||||
|
decoder_frame_count = self.decoder.get_frame_count()
|
||||||
|
|
||||||
|
# Check if decoder has a new frame (avoid reprocessing same frame)
|
||||||
|
if decoder_frame_count > last_decoder_frame_count:
|
||||||
|
# Poll frame from decoder (zero-copy - stays in VRAM)
|
||||||
frame = self.decoder.get_latest_frame(rgb=True)
|
frame = self.decoder.get_latest_frame(rgb=True)
|
||||||
|
|
||||||
# Check if we got a new frame (avoid reprocessing same frame)
|
if frame is not None:
|
||||||
if frame is not None and frame.data_ptr() != last_frame_ptr:
|
last_decoder_frame_count = decoder_frame_count
|
||||||
last_frame_ptr = frame.data_ptr()
|
|
||||||
self.last_frame_time = time.time()
|
self.last_frame_time = time.time()
|
||||||
self.frame_count += 1
|
self.frame_count += 1
|
||||||
|
|
||||||
|
|
@ -211,53 +214,37 @@ class StreamConnection:
|
||||||
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
|
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
|
||||||
await self.error_queue.put(e)
|
await self.error_queue.put(e)
|
||||||
|
|
||||||
def _run_tracking_sync(self, detections):
|
def _run_tracking_sync(self, detections, min_confidence=0.7):
|
||||||
"""
|
"""
|
||||||
Run tracking synchronously (called from executor).
|
Run tracking synchronously (called from executor).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
|
detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
|
||||||
|
min_confidence: Minimum confidence threshold for detections
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of TrackedObject instances
|
List of TrackedObject instances
|
||||||
"""
|
"""
|
||||||
# Use the TrackingController's internal tracking with detections
|
# Convert tensor detections to Detection objects, filtering by confidence
|
||||||
# We need to manually update tracks since we already have detections
|
from .tracking_controller import Detection
|
||||||
import torch
|
|
||||||
|
|
||||||
with self.tracking_controller._lock:
|
detection_list = []
|
||||||
self.tracking_controller._frame_count += 1
|
for det in detections:
|
||||||
|
confidence = float(det[4])
|
||||||
|
|
||||||
# If no detections, just cleanup and return current tracks
|
# Filter by confidence threshold (prevents track accumulation)
|
||||||
if len(detections) == 0:
|
if confidence < min_confidence:
|
||||||
self.tracking_controller._cleanup_stale_tracks()
|
continue
|
||||||
return list(self.tracking_controller._tracks.values())
|
|
||||||
|
|
||||||
# Run IoU tracking to associate detections with existing tracks
|
detection_list.append(Detection(
|
||||||
associations = self.tracking_controller._iou_tracking(detections)
|
bbox=det[:4].cpu().tolist(),
|
||||||
|
confidence=confidence,
|
||||||
|
class_id=int(det[5]) if det.shape[0] > 5 else 0,
|
||||||
|
class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
|
||||||
|
))
|
||||||
|
|
||||||
# Update or create tracks
|
# Update tracker with detections (lightweight, no model dependency!)
|
||||||
for (det_idx, track_id), detection in zip(associations, detections):
|
return self.tracking_controller.update(detection_list)
|
||||||
bbox = detection[:4].cpu().tolist()
|
|
||||||
confidence = float(detection[4])
|
|
||||||
class_id = int(detection[5]) if detection.shape[0] > 5 else 0
|
|
||||||
|
|
||||||
if track_id == -1:
|
|
||||||
# Create new track
|
|
||||||
new_track = self.tracking_controller._create_track(
|
|
||||||
bbox, confidence, class_id, self.tracking_controller._frame_count
|
|
||||||
)
|
|
||||||
self.tracking_controller._tracks[new_track.track_id] = new_track
|
|
||||||
else:
|
|
||||||
# Update existing track
|
|
||||||
self.tracking_controller._tracks[track_id].update(
|
|
||||||
bbox, confidence, self.tracking_controller._frame_count
|
|
||||||
)
|
|
||||||
|
|
||||||
# Cleanup stale tracks
|
|
||||||
self.tracking_controller._cleanup_stale_tracks()
|
|
||||||
|
|
||||||
return list(self.tracking_controller._tracks.values())
|
|
||||||
|
|
||||||
async def tracking_results(self) -> AsyncIterator[TrackingResult]:
|
async def tracking_results(self) -> AsyncIterator[TrackingResult]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -341,7 +328,6 @@ class StreamConnectionManager:
|
||||||
|
|
||||||
# Factories
|
# Factories
|
||||||
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
|
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
|
||||||
self.tracking_factory = TrackingFactory(gpu_id=gpu_id)
|
|
||||||
self.model_repository = TensorRTModelRepository(
|
self.model_repository = TensorRTModelRepository(
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
enable_pt_conversion=enable_pt_conversion
|
enable_pt_conversion=enable_pt_conversion
|
||||||
|
|
@ -349,7 +335,6 @@ class StreamConnectionManager:
|
||||||
|
|
||||||
# Controllers
|
# Controllers
|
||||||
self.model_controller: Optional[ModelController] = None
|
self.model_controller: Optional[ModelController] = None
|
||||||
self.tracking_controller = None
|
|
||||||
|
|
||||||
# Connections
|
# Connections
|
||||||
self.connections: Dict[str, StreamConnection] = {}
|
self.connections: Dict[str, StreamConnection] = {}
|
||||||
|
|
@ -454,17 +439,16 @@ class StreamConnectionManager:
|
||||||
# Create decoder
|
# Create decoder
|
||||||
decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)
|
decoder = self.decoder_factory.create_decoder(rtsp_url, buffer_size=buffer_size)
|
||||||
|
|
||||||
# Create dedicated tracking controller for THIS stream
|
# Create lightweight tracker (NO model_repository dependency!)
|
||||||
# This prevents track accumulation across multiple streams
|
from .tracking_controller import ObjectTracker
|
||||||
tracking_controller = self.tracking_factory.create_controller(
|
tracking_controller = ObjectTracker(
|
||||||
model_repository=self.model_repository,
|
gpu_id=self.gpu_id,
|
||||||
model_id=self.model_id_for_tracking,
|
|
||||||
tracker_type="iou",
|
tracker_type="iou",
|
||||||
max_age=30,
|
max_age=30,
|
||||||
min_confidence=0.5,
|
|
||||||
iou_threshold=0.3,
|
iou_threshold=0.3,
|
||||||
|
class_names=None # TODO: pass class names if available
|
||||||
)
|
)
|
||||||
logger.info(f"Created dedicated TrackingController for stream {stream_id}")
|
logger.info(f"Created lightweight ObjectTracker for stream {stream_id}")
|
||||||
|
|
||||||
# Create connection
|
# Create connection
|
||||||
connection = StreamConnection(
|
connection = StreamConnection(
|
||||||
|
|
|
||||||
|
|
@ -448,6 +448,10 @@ class StreamDecoder:
|
||||||
with self._buffer_lock:
|
with self._buffer_lock:
|
||||||
return len(self.frame_buffer)
|
return len(self.frame_buffer)
|
||||||
|
|
||||||
|
def get_frame_count(self) -> int:
|
||||||
|
"""Get total number of frames decoded since start"""
|
||||||
|
return self.frame_count
|
||||||
|
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
"""Check if stream is actively connected"""
|
"""Check if stream is actively connected"""
|
||||||
return self.get_status() == ConnectionStatus.CONNECTED
|
return self.get_status() == ConnectionStatus.CONNECTED
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from collections import defaultdict, deque
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .model_repository import TensorRTModelRepository
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -61,78 +60,81 @@ class TrackedObject:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TrackingController:
|
@dataclass
|
||||||
|
class Detection:
|
||||||
"""
|
"""
|
||||||
GPU-accelerated object tracking controller that wraps TensorRTModelRepository.
|
Represents a single detection from object detection model.
|
||||||
|
|
||||||
Architecture:
|
Attributes:
|
||||||
- Wraps model repository for dependency injection
|
bbox: Bounding box [x1, y1, x2, y2]
|
||||||
- Maintains CUDA state for bbox tracking operations
|
confidence: Detection confidence (0-1)
|
||||||
- Stores persistent tracking data (track IDs, histories, states)
|
class_id: Object class ID
|
||||||
- Processes GPU tensor frames directly (zero-copy pipeline)
|
class_name: Object class name (optional)
|
||||||
- Thread-safe for concurrent tracking operations
|
"""
|
||||||
|
bbox: List[float]
|
||||||
|
confidence: float
|
||||||
|
class_id: int
|
||||||
|
class_name: str = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectTracker:
|
||||||
|
"""
|
||||||
|
Lightweight GPU-accelerated object tracker (decoupled from inference).
|
||||||
|
|
||||||
|
This class only handles tracking logic - associating detections with existing tracks,
|
||||||
|
maintaining track IDs, and managing track lifecycle. It does NOT perform inference.
|
||||||
|
|
||||||
|
Architecture (Event-Driven Mode):
|
||||||
|
- Receives pre-computed detections (from ModelController)
|
||||||
|
- Maintains persistent tracking state (track IDs, histories)
|
||||||
|
- GPU-accelerated IoU computation for track association
|
||||||
|
- Thread-safe for concurrent operations
|
||||||
|
|
||||||
Tracking Flow:
|
Tracking Flow:
|
||||||
GPU Frame → Model Inference (GPU) → Detections (GPU)
|
Detections → Track Association (GPU IoU) → Update Tracks → Return Tracked Objects
|
||||||
↓
|
|
||||||
Tracking Algorithm (GPU/CPU) → Track Assignment
|
|
||||||
↓
|
|
||||||
Update Persistent Tracks → Return Tracked Objects
|
|
||||||
|
|
||||||
Features:
|
Features:
|
||||||
- GPU-first: All tensor operations stay on GPU until final results
|
- Lightweight: No model_repository dependency (zero VRAM overhead)
|
||||||
|
- GPU-accelerated: IoU computation on GPU for performance
|
||||||
- Persistent IDs: Tracks maintain consistent IDs across frames
|
- Persistent IDs: Tracks maintain consistent IDs across frames
|
||||||
- Track History: Maintains trajectory history for each object
|
- Track History: Maintains trajectory history for each object
|
||||||
- Configurable: Supports custom tracking algorithms via callbacks
|
|
||||||
- Thread-safe: Mutex-based locking for concurrent access
|
- Thread-safe: Mutex-based locking for concurrent access
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
# Initialize with DI
|
# Event-driven mode (no model dependency)
|
||||||
repo = TensorRTModelRepository(gpu_id=0)
|
tracker = ObjectTracker(
|
||||||
factory = TrackingFactory(gpu_id=0)
|
gpu_id=0,
|
||||||
controller = factory.create_controller(
|
tracker_type="iou",
|
||||||
model_repository=repo,
|
max_age=30,
|
||||||
model_id="yolov8_detector",
|
iou_threshold=0.3,
|
||||||
tracker_type="iou"
|
class_names=COCO_CLASSES
|
||||||
)
|
)
|
||||||
|
|
||||||
# Track objects in frame
|
# Update with pre-computed detections
|
||||||
rgb_frame = decoder.get_latest_frame() # GPU tensor
|
detections = [Detection(bbox=[x1,y1,x2,y2], confidence=0.9, class_id=0)]
|
||||||
tracked_objects = controller.track(rgb_frame)
|
tracked_objects = tracker.update(detections)
|
||||||
|
|
||||||
# Get all tracked objects
|
|
||||||
all_tracks = controller.get_all_tracks()
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_repository: TensorRTModelRepository,
|
|
||||||
model_id: str,
|
|
||||||
gpu_id: int = 0,
|
gpu_id: int = 0,
|
||||||
tracker_type: str = "iou",
|
tracker_type: str = "iou",
|
||||||
max_age: int = 30,
|
max_age: int = 30,
|
||||||
min_confidence: float = 0.5,
|
|
||||||
iou_threshold: float = 0.3,
|
iou_threshold: float = 0.3,
|
||||||
class_names: Optional[Dict[int, str]] = None):
|
class_names: Optional[Dict[int, str]] = None):
|
||||||
"""
|
"""
|
||||||
Initialize TrackingController.
|
Initialize ObjectTracker (no model dependency).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_repository: TensorRT model repository (dependency injection)
|
gpu_id: GPU device ID for IoU computation
|
||||||
model_id: Model ID in repository to use for detection
|
tracker_type: Tracking algorithm type ("iou")
|
||||||
gpu_id: GPU device ID
|
|
||||||
tracker_type: Tracking algorithm type ("iou", "sort", "deepsort", "bytetrack")
|
|
||||||
max_age: Maximum frames to keep track without detection
|
max_age: Maximum frames to keep track without detection
|
||||||
min_confidence: Minimum confidence threshold for detections
|
|
||||||
iou_threshold: IoU threshold for track association
|
iou_threshold: IoU threshold for track association
|
||||||
class_names: Optional mapping of class IDs to names
|
class_names: Optional mapping of class IDs to names
|
||||||
"""
|
"""
|
||||||
self.model_repository = model_repository
|
|
||||||
self.model_id = model_id
|
|
||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
self.device = torch.device(f'cuda:{gpu_id}')
|
self.device = torch.device(f'cuda:{gpu_id}')
|
||||||
self.tracker_type = tracker_type
|
self.tracker_type = tracker_type
|
||||||
self.max_age = max_age
|
self.max_age = max_age
|
||||||
self.min_confidence = min_confidence
|
|
||||||
self.iou_threshold = iou_threshold
|
self.iou_threshold = iou_threshold
|
||||||
self.class_names = class_names or {}
|
self.class_names = class_names or {}
|
||||||
|
|
||||||
|
|
@ -146,19 +148,6 @@ class TrackingController:
|
||||||
self._total_detections = 0
|
self._total_detections = 0
|
||||||
self._total_tracks_created = 0
|
self._total_tracks_created = 0
|
||||||
|
|
||||||
# Verify model exists in repository
|
|
||||||
metadata = self.model_repository.get_metadata(model_id)
|
|
||||||
if metadata is None:
|
|
||||||
raise ValueError(f"Model '{model_id}' not found in repository")
|
|
||||||
|
|
||||||
print(f"TrackingController initialized:")
|
|
||||||
print(f" Model ID: {model_id}")
|
|
||||||
print(f" GPU: {gpu_id}")
|
|
||||||
print(f" Tracker: {tracker_type}")
|
|
||||||
print(f" Max age: {max_age} frames")
|
|
||||||
print(f" Min confidence: {min_confidence}")
|
|
||||||
print(f" IoU threshold: {iou_threshold}")
|
|
||||||
|
|
||||||
def _compute_iou_gpu(self, boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
|
def _compute_iou_gpu(self, boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute IoU between two sets of boxes on GPU.
|
Compute IoU between two sets of boxes on GPU.
|
||||||
|
|
@ -283,97 +272,51 @@ class TrackingController:
|
||||||
for tid in stale_track_ids:
|
for tid in stale_track_ids:
|
||||||
del self._tracks[tid]
|
del self._tracks[tid]
|
||||||
|
|
||||||
def track(self, frame: torch.Tensor,
|
def update(self, detections: List[Detection]) -> List[TrackedObject]:
|
||||||
preprocess_fn: Optional[callable] = None,
|
|
||||||
postprocess_fn: Optional[callable] = None) -> List[TrackedObject]:
|
|
||||||
"""
|
"""
|
||||||
Track objects in a GPU tensor frame.
|
Update tracker with new detections (decoupled from inference).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
frame: RGB frame as GPU tensor, shape (3, H, W) or (1, 3, H, W)
|
detections: List of Detection objects from model inference
|
||||||
preprocess_fn: Optional preprocessing function (frame -> model_input)
|
|
||||||
postprocess_fn: Optional postprocessing function (model_output -> detections)
|
|
||||||
Should return tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id]
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of currently tracked objects
|
List of currently tracked objects
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._frame_count += 1
|
self._frame_count += 1
|
||||||
|
|
||||||
# Ensure frame is on correct device
|
|
||||||
if not frame.is_cuda:
|
|
||||||
frame = frame.to(self.device)
|
|
||||||
elif frame.device != self.device:
|
|
||||||
frame = frame.to(self.device)
|
|
||||||
|
|
||||||
# Preprocess frame for model
|
|
||||||
if preprocess_fn is not None:
|
|
||||||
model_input = preprocess_fn(frame)
|
|
||||||
else:
|
|
||||||
# Default: add batch dimension if needed
|
|
||||||
if frame.dim() == 3:
|
|
||||||
model_input = frame.unsqueeze(0) # (1, 3, H, W)
|
|
||||||
else:
|
|
||||||
model_input = frame
|
|
||||||
|
|
||||||
# Run inference (GPU-to-GPU)
|
|
||||||
# Assuming model expects input named "images" or "input"
|
|
||||||
metadata = self.model_repository.get_metadata(self.model_id)
|
|
||||||
input_name = metadata.input_names[0] if metadata else "images"
|
|
||||||
|
|
||||||
outputs = self.model_repository.infer(
|
|
||||||
model_id=self.model_id,
|
|
||||||
inputs={input_name: model_input},
|
|
||||||
synchronize=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Postprocess model output to get detections
|
|
||||||
if postprocess_fn is not None:
|
|
||||||
detections = postprocess_fn(outputs)
|
|
||||||
else:
|
|
||||||
# Default: assume output is already in correct format
|
|
||||||
# Get first output tensor
|
|
||||||
output_name = list(outputs.keys())[0]
|
|
||||||
detections = outputs[output_name]
|
|
||||||
|
|
||||||
# Reshape if needed: (1, N, 6) -> (N, 6)
|
|
||||||
if detections.dim() == 3:
|
|
||||||
detections = detections.squeeze(0)
|
|
||||||
|
|
||||||
# Filter by confidence
|
|
||||||
if detections.dim() == 2 and detections.shape[1] >= 5:
|
|
||||||
conf_mask = detections[:, 4] >= self.min_confidence
|
|
||||||
detections = detections[conf_mask]
|
|
||||||
|
|
||||||
self._total_detections += len(detections)
|
self._total_detections += len(detections)
|
||||||
|
|
||||||
# Track objects
|
|
||||||
if len(detections) == 0:
|
|
||||||
# No detections, just cleanup stale tracks
|
# No detections, just cleanup stale tracks
|
||||||
|
if len(detections) == 0:
|
||||||
self._cleanup_stale_tracks()
|
self._cleanup_stale_tracks()
|
||||||
return list(self._tracks.values())
|
return list(self._tracks.values())
|
||||||
|
|
||||||
|
# Convert detections to tensor for GPU processing
|
||||||
|
det_tensor = torch.tensor(
|
||||||
|
[[*det.bbox, det.confidence, det.class_id] for det in detections],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
# Run tracking algorithm
|
# Run tracking algorithm
|
||||||
if self.tracker_type == "iou":
|
if self.tracker_type == "iou":
|
||||||
associations = self._iou_tracking(detections)
|
associations = self._iou_tracking(det_tensor)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Tracker type '{self.tracker_type}' not implemented")
|
raise NotImplementedError(f"Tracker type '{self.tracker_type}' not implemented")
|
||||||
|
|
||||||
# Update tracks based on associations
|
# Update tracks based on associations
|
||||||
for det_idx, track_id in associations:
|
for det_idx, track_id in associations:
|
||||||
detection = detections[det_idx]
|
det = detections[det_idx]
|
||||||
bbox = detection[:4].cpu().tolist()
|
|
||||||
confidence = float(detection[4])
|
|
||||||
class_id = int(detection[5]) if detection.shape[0] > 5 else 0
|
|
||||||
|
|
||||||
if track_id == -1:
|
if track_id == -1:
|
||||||
# Create new track
|
# Create new track
|
||||||
new_track = self._create_track(bbox, confidence, class_id, self._frame_count)
|
new_track = self._create_track(
|
||||||
|
det.bbox, det.confidence, det.class_id, self._frame_count
|
||||||
|
)
|
||||||
self._tracks[new_track.track_id] = new_track
|
self._tracks[new_track.track_id] = new_track
|
||||||
else:
|
else:
|
||||||
# Update existing track
|
# Update existing track
|
||||||
self._tracks[track_id].update(bbox, confidence, self._frame_count)
|
self._tracks[track_id].update(det.bbox, det.confidence, self._frame_count)
|
||||||
|
|
||||||
# Cleanup stale tracks
|
# Cleanup stale tracks
|
||||||
self._cleanup_stale_tracks()
|
self._cleanup_stale_tracks()
|
||||||
|
|
@ -476,7 +419,6 @@ class TrackingController:
|
||||||
'total_tracks_created': self._total_tracks_created,
|
'total_tracks_created': self._total_tracks_created,
|
||||||
'total_detections': self._total_detections,
|
'total_detections': self._total_detections,
|
||||||
'avg_detections_per_frame': self._total_detections / max(self._frame_count, 1),
|
'avg_detections_per_frame': self._total_detections / max(self._frame_count, 1),
|
||||||
'model_id': self.model_id,
|
|
||||||
'tracker_type': self.tracker_type,
|
'tracker_type': self.tracker_type,
|
||||||
'class_counts': self.get_class_counts(active_only=True)
|
'class_counts': self.get_class_counts(active_only=True)
|
||||||
}
|
}
|
||||||
|
|
@ -518,7 +460,6 @@ class TrackingController:
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return (f"TrackingController(model={self.model_id}, "
|
return (f"ObjectTracker(tracker={self.tracker_type}, "
|
||||||
f"tracker={self.tracker_type}, "
|
|
||||||
f"frame={self._frame_count}, "
|
f"frame={self._frame_count}, "
|
||||||
f"tracks={len(self._tracks)})")
|
f"tracks={len(self._tracks)})")
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional, Dict
|
from typing import Optional, Dict
|
||||||
from .tracking_controller import TrackingController
|
from .tracking_controller import ObjectTracker
|
||||||
from .model_repository import TensorRTModelRepository
|
from .model_repository import TensorRTModelRepository
|
||||||
|
|
||||||
|
# Backward compatibility alias (TrackingFactory is deprecated in event-driven mode)
|
||||||
|
TrackingController = ObjectTracker
|
||||||
|
|
||||||
|
|
||||||
class TrackingFactory:
|
class TrackingFactory:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@ import asyncio
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from services import (
|
from services import (
|
||||||
StreamConnectionManager,
|
StreamConnectionManager,
|
||||||
|
|
@ -32,17 +34,21 @@ async def main_single_stream():
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
GPU_ID = 0
|
GPU_ID = 0
|
||||||
MODEL_PATH = "models/yolov8n.pt" # PT file will be auto-converted
|
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # PT file will be auto-converted
|
||||||
STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
|
STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
|
||||||
BATCH_SIZE = 4
|
BATCH_SIZE = 4
|
||||||
FORCE_TIMEOUT = 0.05
|
FORCE_TIMEOUT = 0.05
|
||||||
|
ENABLE_DISPLAY = os.getenv('ENABLE_DISPLAY', 'false').lower() == 'true' # Set to 'true' to enable OpenCV display
|
||||||
|
MAX_FRAMES = int(os.getenv('MAX_FRAMES', '300')) # Stop after N frames (0 = unlimited)
|
||||||
|
|
||||||
print(f"\nConfiguration:")
|
print(f"\nConfiguration:")
|
||||||
print(f" GPU: {GPU_ID}")
|
print(f" GPU: {GPU_ID}")
|
||||||
print(f" Model: {MODEL_PATH}")
|
print(f" Model: {MODEL_PATH}")
|
||||||
print(f" Stream: {STREAM_URL}")
|
print(f" Stream: {STREAM_URL}")
|
||||||
print(f" Batch size: {BATCH_SIZE}")
|
print(f" Batch size: {BATCH_SIZE}")
|
||||||
print(f" Force timeout: {FORCE_TIMEOUT}s\n")
|
print(f" Force timeout: {FORCE_TIMEOUT}s")
|
||||||
|
print(f" Display: {'Enabled' if ENABLE_DISPLAY else 'Disabled (inference only)'}")
|
||||||
|
print(f" Max frames: {MAX_FRAMES if MAX_FRAMES > 0 else 'Unlimited'}\n")
|
||||||
|
|
||||||
# Create StreamConnectionManager with PT conversion enabled
|
# Create StreamConnectionManager with PT conversion enabled
|
||||||
print("[1/3] Creating StreamConnectionManager...")
|
print("[1/3] Creating StreamConnectionManager...")
|
||||||
|
|
@ -94,14 +100,68 @@ async def main_single_stream():
|
||||||
print("Press Ctrl+C to stop")
|
print("Press Ctrl+C to stop")
|
||||||
print(f"{'=' * 80}\n")
|
print(f"{'=' * 80}\n")
|
||||||
|
|
||||||
# Stream results
|
# Stream results with optional OpenCV visualization
|
||||||
result_count = 0
|
result_count = 0
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Create window only if display is enabled
|
||||||
|
if ENABLE_DISPLAY:
|
||||||
|
cv2.namedWindow("Object Tracking", cv2.WINDOW_NORMAL)
|
||||||
|
cv2.resizeWindow("Object Tracking", 1280, 720)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for result in connection.tracking_results():
|
async for result in connection.tracking_results():
|
||||||
result_count += 1
|
result_count += 1
|
||||||
|
|
||||||
|
# Check if we've reached max frames
|
||||||
|
if MAX_FRAMES > 0 and result_count >= MAX_FRAMES:
|
||||||
|
print(f"\n✓ Reached max frames limit ({MAX_FRAMES})")
|
||||||
|
break
|
||||||
|
|
||||||
|
# OpenCV visualization (only if enabled)
|
||||||
|
if ENABLE_DISPLAY:
|
||||||
|
# Get latest frame from decoder (as CPU numpy array)
|
||||||
|
frame = connection.decoder.get_latest_frame_cpu(rgb=True)
|
||||||
|
|
||||||
|
if frame is not None:
|
||||||
|
# Convert to BGR for OpenCV
|
||||||
|
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# Draw tracked objects
|
||||||
|
for obj in result.tracked_objects:
|
||||||
|
# Get bbox coordinates
|
||||||
|
x1, y1, x2, y2 = map(int, obj.bbox)
|
||||||
|
|
||||||
|
# Draw bounding box
|
||||||
|
cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||||
|
|
||||||
|
# Draw track ID and class name
|
||||||
|
label = f"ID:{obj.track_id} {obj.class_name} {obj.confidence:.2f}"
|
||||||
|
label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||||||
|
|
||||||
|
# Draw label background
|
||||||
|
cv2.rectangle(frame_bgr, (x1, y1 - label_size[1] - 10),
|
||||||
|
(x1 + label_size[0], y1), (0, 255, 0), -1)
|
||||||
|
|
||||||
|
# Draw label text
|
||||||
|
cv2.putText(frame_bgr, label, (x1, y1 - 5),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
|
||||||
|
|
||||||
|
# Draw FPS and object count
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
fps = result_count / elapsed if elapsed > 0 else 0
|
||||||
|
info_text = f"FPS: {fps:.1f} | Objects: {len(result.tracked_objects)} | Frame: {result_count}"
|
||||||
|
cv2.putText(frame_bgr, info_text, (10, 30),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
||||||
|
|
||||||
|
# Display frame
|
||||||
|
cv2.imshow("Object Tracking", frame_bgr)
|
||||||
|
|
||||||
|
# Check for 'q' key to quit
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
print(f"\n✓ Quit by user (pressed 'q')")
|
||||||
|
break
|
||||||
|
|
||||||
# Print stats every 30 results
|
# Print stats every 30 results
|
||||||
if result_count % 30 == 0:
|
if result_count % 30 == 0:
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
@ -125,6 +185,10 @@ async def main_single_stream():
|
||||||
print("Cleanup")
|
print("Cleanup")
|
||||||
print(f"{'=' * 80}")
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# Close OpenCV window if it was opened
|
||||||
|
if ENABLE_DISPLAY:
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
await connection.stop()
|
await connection.stop()
|
||||||
await manager.shutdown()
|
await manager.shutdown()
|
||||||
print("✓ Stopped")
|
print("✓ Stopped")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue