feat: tracking
This commit is contained in:
parent
cf24a172a2
commit
bea895d3d8
4 changed files with 1054 additions and 0 deletions
|
|
@ -4,6 +4,9 @@ Services package for RTSP stream processing with GPU acceleration.
|
|||
|
||||
from .stream_decoder import StreamDecoderFactory, StreamDecoder, ConnectionStatus
|
||||
from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg
|
||||
from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionContext, SharedEngine
|
||||
from .tracking_controller import TrackingController, TrackedObject
|
||||
from .tracking_factory import TrackingFactory
|
||||
|
||||
__all__ = [
|
||||
'StreamDecoderFactory',
|
||||
|
|
@ -11,4 +14,11 @@ __all__ = [
|
|||
'ConnectionStatus',
|
||||
'JPEGEncoderFactory',
|
||||
'encode_frame_to_jpeg',
|
||||
'TensorRTModelRepository',
|
||||
'ModelMetadata',
|
||||
'ExecutionContext',
|
||||
'SharedEngine',
|
||||
'TrackingController',
|
||||
'TrackedObject',
|
||||
'TrackingFactory',
|
||||
]
|
||||
|
|
|
|||
524
services/tracking_controller.py
Normal file
524
services/tracking_controller.py
Normal file
|
|
@ -0,0 +1,524 @@
|
|||
import threading
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict, deque
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
from .model_repository import TensorRTModelRepository
|
||||
|
||||
|
||||
@dataclass
class TrackedObject:
    """
    A tracked object with a persistent ID and per-track metadata.

    Attributes:
        track_id: Unique persistent tracking ID
        class_id: Object class ID from the detection model
        class_name: Object class name (if available)
        confidence: Detection confidence score of the latest detection
        bbox: Latest bounding box in format [x1, y1, x2, y2]
        last_seen_frame: Frame number when the object was last detected
        first_seen_frame: Frame number when the object was first detected
        track_history: Bounded deque (most recent 30 entries) of
            (frame_num, bbox, confidence) tuples
        state: Free-form dict for additional per-track data
    """
    track_id: int
    class_id: int
    class_name: str = "unknown"
    confidence: float = 0.0
    bbox: List[float] = field(default_factory=list)
    last_seen_frame: int = 0
    first_seen_frame: int = 0
    track_history: deque = field(default_factory=lambda: deque(maxlen=30))
    state: Dict[str, Any] = field(default_factory=dict)

    def update(self, bbox: List[float], confidence: float, frame_num: int):
        """
        Record a new detection for this track.

        Args:
            bbox: New bounding box [x1, y1, x2, y2]
            confidence: Confidence of the new detection
            frame_num: Frame number in which the detection was made
        """
        self.bbox = bbox
        self.confidence = confidence
        self.last_seen_frame = frame_num
        self.track_history.append((frame_num, bbox, confidence))

    def age(self, current_frame: int) -> int:
        """Return the number of frames elapsed since the track was last seen."""
        return current_frame - self.last_seen_frame

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the track to a plain dictionary for serialization.

        The 'age' entry is the span between first and last sighting
        (last_seen_frame - first_seen_frame); this differs from age(),
        which measures frames since the last sighting.

        Returns:
            A snapshot dictionary. 'bbox' and 'state' are shallow copies so
            that mutating the exported dict cannot alter the live track.
        """
        return {
            'track_id': self.track_id,
            'class_id': self.class_id,
            'class_name': self.class_name,
            'confidence': self.confidence,
            # Copy the mutable containers: the export is a snapshot.
            'bbox': list(self.bbox),
            'last_seen_frame': self.last_seen_frame,
            'first_seen_frame': self.first_seen_frame,
            'age': self.last_seen_frame - self.first_seen_frame,
            'history_length': len(self.track_history),
            'state': dict(self.state),
        }
|
||||
|
||||
|
||||
class TrackingController:
|
||||
"""
|
||||
GPU-accelerated object tracking controller that wraps TensorRTModelRepository.
|
||||
|
||||
Architecture:
|
||||
- Wraps model repository for dependency injection
|
||||
- Maintains CUDA state for bbox tracking operations
|
||||
- Stores persistent tracking data (track IDs, histories, states)
|
||||
- Processes GPU tensor frames directly (zero-copy pipeline)
|
||||
- Thread-safe for concurrent tracking operations
|
||||
|
||||
Tracking Flow:
|
||||
GPU Frame → Model Inference (GPU) → Detections (GPU)
|
||||
↓
|
||||
Tracking Algorithm (GPU/CPU) → Track Assignment
|
||||
↓
|
||||
Update Persistent Tracks → Return Tracked Objects
|
||||
|
||||
Features:
|
||||
- GPU-first: All tensor operations stay on GPU until final results
|
||||
- Persistent IDs: Tracks maintain consistent IDs across frames
|
||||
- Track History: Maintains trajectory history for each object
|
||||
- Configurable: Supports custom tracking algorithms via callbacks
|
||||
- Thread-safe: Mutex-based locking for concurrent access
|
||||
|
||||
Example:
|
||||
# Initialize with DI
|
||||
repo = TensorRTModelRepository(gpu_id=0)
|
||||
factory = TrackingFactory(gpu_id=0)
|
||||
controller = factory.create_controller(
|
||||
model_repository=repo,
|
||||
model_id="yolov8_detector",
|
||||
tracker_type="iou"
|
||||
)
|
||||
|
||||
# Track objects in frame
|
||||
rgb_frame = decoder.get_latest_frame() # GPU tensor
|
||||
tracked_objects = controller.track(rgb_frame)
|
||||
|
||||
# Get all tracked objects
|
||||
all_tracks = controller.get_all_tracks()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_repository: TensorRTModelRepository,
|
||||
model_id: str,
|
||||
gpu_id: int = 0,
|
||||
tracker_type: str = "iou",
|
||||
max_age: int = 30,
|
||||
min_confidence: float = 0.5,
|
||||
iou_threshold: float = 0.3,
|
||||
class_names: Optional[Dict[int, str]] = None):
|
||||
"""
|
||||
Initialize TrackingController.
|
||||
|
||||
Args:
|
||||
model_repository: TensorRT model repository (dependency injection)
|
||||
model_id: Model ID in repository to use for detection
|
||||
gpu_id: GPU device ID
|
||||
tracker_type: Tracking algorithm type ("iou", "sort", "deepsort", "bytetrack")
|
||||
max_age: Maximum frames to keep track without detection
|
||||
min_confidence: Minimum confidence threshold for detections
|
||||
iou_threshold: IoU threshold for track association
|
||||
class_names: Optional mapping of class IDs to names
|
||||
"""
|
||||
self.model_repository = model_repository
|
||||
self.model_id = model_id
|
||||
self.gpu_id = gpu_id
|
||||
self.device = torch.device(f'cuda:{gpu_id}')
|
||||
self.tracker_type = tracker_type
|
||||
self.max_age = max_age
|
||||
self.min_confidence = min_confidence
|
||||
self.iou_threshold = iou_threshold
|
||||
self.class_names = class_names or {}
|
||||
|
||||
# Tracking state
|
||||
self._tracks: Dict[int, TrackedObject] = {}
|
||||
self._next_track_id: int = 0
|
||||
self._frame_count: int = 0
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# Statistics
|
||||
self._total_detections = 0
|
||||
self._total_tracks_created = 0
|
||||
|
||||
# Verify model exists in repository
|
||||
metadata = self.model_repository.get_metadata(model_id)
|
||||
if metadata is None:
|
||||
raise ValueError(f"Model '{model_id}' not found in repository")
|
||||
|
||||
print(f"TrackingController initialized:")
|
||||
print(f" Model ID: {model_id}")
|
||||
print(f" GPU: {gpu_id}")
|
||||
print(f" Tracker: {tracker_type}")
|
||||
print(f" Max age: {max_age} frames")
|
||||
print(f" Min confidence: {min_confidence}")
|
||||
print(f" IoU threshold: {iou_threshold}")
|
||||
|
||||
def _compute_iou_gpu(self, boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute IoU between two sets of boxes on GPU.
|
||||
|
||||
Args:
|
||||
boxes1: Tensor of shape (N, 4) in format [x1, y1, x2, y2]
|
||||
boxes2: Tensor of shape (M, 4) in format [x1, y1, x2, y2]
|
||||
|
||||
Returns:
|
||||
IoU matrix of shape (N, M)
|
||||
"""
|
||||
# Ensure on GPU
|
||||
boxes1 = boxes1.to(self.device)
|
||||
boxes2 = boxes2.to(self.device)
|
||||
|
||||
# Compute intersection
|
||||
x1_max = torch.max(boxes1[:, None, 0], boxes2[:, 0]) # (N, M)
|
||||
y1_max = torch.max(boxes1[:, None, 1], boxes2[:, 1]) # (N, M)
|
||||
x2_min = torch.min(boxes1[:, None, 2], boxes2[:, 2]) # (N, M)
|
||||
y2_min = torch.min(boxes1[:, None, 3], boxes2[:, 3]) # (N, M)
|
||||
|
||||
intersection_width = torch.clamp(x2_min - x1_max, min=0)
|
||||
intersection_height = torch.clamp(y2_min - y1_max, min=0)
|
||||
intersection_area = intersection_width * intersection_height
|
||||
|
||||
# Compute areas
|
||||
boxes1_area = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
|
||||
boxes2_area = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
|
||||
|
||||
# Compute union
|
||||
union_area = boxes1_area[:, None] + boxes2_area - intersection_area
|
||||
|
||||
# Compute IoU
|
||||
iou = intersection_area / (union_area + 1e-6)
|
||||
|
||||
return iou
|
||||
|
||||
def _iou_tracking(self, detections: torch.Tensor) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Simple IoU-based tracking algorithm (on GPU).
|
||||
|
||||
Args:
|
||||
detections: Tensor of shape (N, 6) with [x1, y1, x2, y2, conf, class_id]
|
||||
|
||||
Returns:
|
||||
List of (detection_idx, track_id) associations
|
||||
"""
|
||||
if len(self._tracks) == 0:
|
||||
# No existing tracks, create new ones for all detections
|
||||
return [(-1, -1) for _ in range(len(detections))]
|
||||
|
||||
# Get existing track bboxes
|
||||
track_ids = list(self._tracks.keys())
|
||||
track_bboxes = torch.tensor(
|
||||
[self._tracks[tid].bbox for tid in track_ids],
|
||||
dtype=torch.float32,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
# Extract detection bboxes
|
||||
det_bboxes = detections[:, :4] # (N, 4)
|
||||
|
||||
# Compute IoU matrix (GPU)
|
||||
iou_matrix = self._compute_iou_gpu(det_bboxes, track_bboxes) # (N, M)
|
||||
|
||||
# Greedy matching: assign each detection to best matching track
|
||||
associations = []
|
||||
matched_tracks = set()
|
||||
|
||||
# Convert to CPU for matching logic (small matrix)
|
||||
iou_cpu = iou_matrix.cpu().numpy()
|
||||
|
||||
for det_idx in range(len(detections)):
|
||||
best_iou = self.iou_threshold
|
||||
best_track_idx = -1
|
||||
|
||||
for track_idx, track_id in enumerate(track_ids):
|
||||
if track_idx in matched_tracks:
|
||||
continue
|
||||
|
||||
if iou_cpu[det_idx, track_idx] > best_iou:
|
||||
best_iou = iou_cpu[det_idx, track_idx]
|
||||
best_track_idx = track_idx
|
||||
|
||||
if best_track_idx >= 0:
|
||||
associations.append((det_idx, track_ids[best_track_idx]))
|
||||
matched_tracks.add(best_track_idx)
|
||||
else:
|
||||
associations.append((det_idx, -1)) # New track
|
||||
|
||||
return associations
|
||||
|
||||
def _create_track(self, bbox: List[float], confidence: float,
|
||||
class_id: int, frame_num: int) -> TrackedObject:
|
||||
"""Create a new tracked object"""
|
||||
track_id = self._next_track_id
|
||||
self._next_track_id += 1
|
||||
self._total_tracks_created += 1
|
||||
|
||||
class_name = self.class_names.get(class_id, f"class_{class_id}")
|
||||
|
||||
track = TrackedObject(
|
||||
track_id=track_id,
|
||||
class_id=class_id,
|
||||
class_name=class_name,
|
||||
confidence=confidence,
|
||||
bbox=bbox,
|
||||
last_seen_frame=frame_num,
|
||||
first_seen_frame=frame_num
|
||||
)
|
||||
track.track_history.append((frame_num, bbox, confidence))
|
||||
|
||||
return track
|
||||
|
||||
def _cleanup_stale_tracks(self):
|
||||
"""Remove tracks that haven't been seen for max_age frames"""
|
||||
stale_track_ids = [
|
||||
tid for tid, track in self._tracks.items()
|
||||
if track.age(self._frame_count) > self.max_age
|
||||
]
|
||||
|
||||
for tid in stale_track_ids:
|
||||
del self._tracks[tid]
|
||||
|
||||
def track(self, frame: torch.Tensor,
|
||||
preprocess_fn: Optional[callable] = None,
|
||||
postprocess_fn: Optional[callable] = None) -> List[TrackedObject]:
|
||||
"""
|
||||
Track objects in a GPU tensor frame.
|
||||
|
||||
Args:
|
||||
frame: RGB frame as GPU tensor, shape (3, H, W) or (1, 3, H, W)
|
||||
preprocess_fn: Optional preprocessing function (frame -> model_input)
|
||||
postprocess_fn: Optional postprocessing function (model_output -> detections)
|
||||
Should return tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id]
|
||||
|
||||
Returns:
|
||||
List of currently tracked objects
|
||||
"""
|
||||
with self._lock:
|
||||
self._frame_count += 1
|
||||
|
||||
# Ensure frame is on correct device
|
||||
if not frame.is_cuda:
|
||||
frame = frame.to(self.device)
|
||||
elif frame.device != self.device:
|
||||
frame = frame.to(self.device)
|
||||
|
||||
# Preprocess frame for model
|
||||
if preprocess_fn is not None:
|
||||
model_input = preprocess_fn(frame)
|
||||
else:
|
||||
# Default: add batch dimension if needed
|
||||
if frame.dim() == 3:
|
||||
model_input = frame.unsqueeze(0) # (1, 3, H, W)
|
||||
else:
|
||||
model_input = frame
|
||||
|
||||
# Run inference (GPU-to-GPU)
|
||||
# Assuming model expects input named "images" or "input"
|
||||
metadata = self.model_repository.get_metadata(self.model_id)
|
||||
input_name = metadata.input_names[0] if metadata else "images"
|
||||
|
||||
outputs = self.model_repository.infer(
|
||||
model_id=self.model_id,
|
||||
inputs={input_name: model_input},
|
||||
synchronize=True
|
||||
)
|
||||
|
||||
# Postprocess model output to get detections
|
||||
if postprocess_fn is not None:
|
||||
detections = postprocess_fn(outputs)
|
||||
else:
|
||||
# Default: assume output is already in correct format
|
||||
# Get first output tensor
|
||||
output_name = list(outputs.keys())[0]
|
||||
detections = outputs[output_name]
|
||||
|
||||
# Reshape if needed: (1, N, 6) -> (N, 6)
|
||||
if detections.dim() == 3:
|
||||
detections = detections.squeeze(0)
|
||||
|
||||
# Filter by confidence
|
||||
if detections.dim() == 2 and detections.shape[1] >= 5:
|
||||
conf_mask = detections[:, 4] >= self.min_confidence
|
||||
detections = detections[conf_mask]
|
||||
|
||||
self._total_detections += len(detections)
|
||||
|
||||
# Track objects
|
||||
if len(detections) == 0:
|
||||
# No detections, just cleanup stale tracks
|
||||
self._cleanup_stale_tracks()
|
||||
return list(self._tracks.values())
|
||||
|
||||
# Run tracking algorithm
|
||||
if self.tracker_type == "iou":
|
||||
associations = self._iou_tracking(detections)
|
||||
else:
|
||||
raise NotImplementedError(f"Tracker type '{self.tracker_type}' not implemented")
|
||||
|
||||
# Update tracks based on associations
|
||||
for det_idx, track_id in associations:
|
||||
detection = detections[det_idx]
|
||||
bbox = detection[:4].cpu().tolist()
|
||||
confidence = float(detection[4])
|
||||
class_id = int(detection[5]) if detection.shape[0] > 5 else 0
|
||||
|
||||
if track_id == -1:
|
||||
# Create new track
|
||||
new_track = self._create_track(bbox, confidence, class_id, self._frame_count)
|
||||
self._tracks[new_track.track_id] = new_track
|
||||
else:
|
||||
# Update existing track
|
||||
self._tracks[track_id].update(bbox, confidence, self._frame_count)
|
||||
|
||||
# Cleanup stale tracks
|
||||
self._cleanup_stale_tracks()
|
||||
|
||||
return list(self._tracks.values())
|
||||
|
||||
def get_all_tracks(self, active_only: bool = True) -> List[TrackedObject]:
|
||||
"""
|
||||
Get all tracked objects.
|
||||
|
||||
Args:
|
||||
active_only: If True, only return tracks seen in recent frames
|
||||
|
||||
Returns:
|
||||
List of tracked objects
|
||||
"""
|
||||
with self._lock:
|
||||
if active_only:
|
||||
return [
|
||||
track for track in self._tracks.values()
|
||||
if track.age(self._frame_count) <= self.max_age
|
||||
]
|
||||
else:
|
||||
return list(self._tracks.values())
|
||||
|
||||
def get_track_by_id(self, track_id: int) -> Optional[TrackedObject]:
|
||||
"""
|
||||
Get a specific track by ID.
|
||||
|
||||
Args:
|
||||
track_id: Track ID to retrieve
|
||||
|
||||
Returns:
|
||||
TrackedObject or None if not found
|
||||
"""
|
||||
with self._lock:
|
||||
return self._tracks.get(track_id)
|
||||
|
||||
def get_tracks_by_class(self, class_id: int, active_only: bool = True) -> List[TrackedObject]:
|
||||
"""
|
||||
Get all tracks of a specific class.
|
||||
|
||||
Args:
|
||||
class_id: Class ID to filter by
|
||||
active_only: If True, only return active tracks
|
||||
|
||||
Returns:
|
||||
List of tracked objects
|
||||
"""
|
||||
all_tracks = self.get_all_tracks(active_only=active_only)
|
||||
return [track for track in all_tracks if track.class_id == class_id]
|
||||
|
||||
def get_track_count(self, active_only: bool = True) -> int:
|
||||
"""
|
||||
Get number of tracked objects.
|
||||
|
||||
Args:
|
||||
active_only: If True, only count active tracks
|
||||
|
||||
Returns:
|
||||
Number of tracks
|
||||
"""
|
||||
return len(self.get_all_tracks(active_only=active_only))
|
||||
|
||||
def get_class_counts(self, active_only: bool = True) -> Dict[int, int]:
|
||||
"""
|
||||
Get count of tracked objects per class.
|
||||
|
||||
Args:
|
||||
active_only: If True, only count active tracks
|
||||
|
||||
Returns:
|
||||
Dictionary mapping class_id to count
|
||||
"""
|
||||
tracks = self.get_all_tracks(active_only=active_only)
|
||||
counts = defaultdict(int)
|
||||
for track in tracks:
|
||||
counts[track.class_id] += 1
|
||||
return dict(counts)
|
||||
|
||||
def reset_tracks(self):
|
||||
"""Clear all tracking state"""
|
||||
with self._lock:
|
||||
self._tracks.clear()
|
||||
self._next_track_id = 0
|
||||
self._frame_count = 0
|
||||
print("Tracking state reset")
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get tracking statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with tracking stats
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
'frame_count': self._frame_count,
|
||||
'active_tracks': len(self._tracks),
|
||||
'total_tracks_created': self._total_tracks_created,
|
||||
'total_detections': self._total_detections,
|
||||
'avg_detections_per_frame': self._total_detections / max(self._frame_count, 1),
|
||||
'model_id': self.model_id,
|
||||
'tracker_type': self.tracker_type,
|
||||
'class_counts': self.get_class_counts(active_only=True)
|
||||
}
|
||||
|
||||
def export_tracks(self, format: str = "dict") -> Any:
|
||||
"""
|
||||
Export all tracks in specified format.
|
||||
|
||||
Args:
|
||||
format: Export format ("dict", "json", "numpy")
|
||||
|
||||
Returns:
|
||||
Tracks in specified format
|
||||
"""
|
||||
with self._lock:
|
||||
tracks = self.get_all_tracks(active_only=False)
|
||||
|
||||
if format == "dict":
|
||||
return {track.track_id: track.to_dict() for track in tracks}
|
||||
elif format == "json":
|
||||
import json
|
||||
return json.dumps(
|
||||
{track.track_id: track.to_dict() for track in tracks},
|
||||
indent=2
|
||||
)
|
||||
elif format == "numpy":
|
||||
# Export as numpy array: [track_id, class_id, x1, y1, x2, y2, conf]
|
||||
data = []
|
||||
for track in tracks:
|
||||
data.append([
|
||||
track.track_id,
|
||||
track.class_id,
|
||||
*track.bbox,
|
||||
track.confidence
|
||||
])
|
||||
return np.array(data) if data else np.array([])
|
||||
else:
|
||||
raise ValueError(f"Unknown export format: {format}")
|
||||
|
||||
def __repr__(self):
|
||||
with self._lock:
|
||||
return (f"TrackingController(model={self.model_id}, "
|
||||
f"tracker={self.tracker_type}, "
|
||||
f"frame={self._frame_count}, "
|
||||
f"tracks={len(self._tracks)})")
|
||||
202
services/tracking_factory.py
Normal file
202
services/tracking_factory.py
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
import threading
|
||||
from typing import Optional, Dict
|
||||
from .tracking_controller import TrackingController
|
||||
from .model_repository import TensorRTModelRepository
|
||||
|
||||
|
||||
class TrackingFactory:
    """
    Factory for creating TrackingController instances, one factory per GPU.

    This factory follows the same pattern as StreamDecoderFactory for consistency:
        - Singleton pattern per GPU
        - Provides centralized controller creation
        - Thread-safe controller fabrication

    The factory does not manage any CUDA context itself; it only validates
    the injected TensorRTModelRepository and wires it into new controllers.

    Example:
        # Get factory instance for GPU 0
        factory = TrackingFactory(gpu_id=0)

        # Create model repository
        repo = TensorRTModelRepository(gpu_id=0)
        repo.load_model("detector", "yolov8.trt")

        # Create tracking controller
        controller = factory.create_controller(
            model_repository=repo,
            model_id="detector",
            tracker_type="iou"
        )

        # Multiple controllers share the same model repository
        controller2 = factory.create_controller(
            model_repository=repo,
            model_id="detector",
            tracker_type="iou"
        )
    """

    # gpu_id -> factory instance; guarded by _lock
    _instances: Dict[int, 'TrackingFactory'] = {}
    _lock = threading.Lock()

    def __new__(cls, gpu_id: int = 0):
        """
        Return the factory for gpu_id, creating it on first use.

        Each GPU gets its own factory instance.
        """
        # Re-checked under the lock so concurrent first calls create one instance.
        if gpu_id not in cls._instances:
            with cls._lock:
                if gpu_id not in cls._instances:
                    instance = super(TrackingFactory, cls).__new__(cls)
                    instance._initialized = False
                    cls._instances[gpu_id] = instance
        return cls._instances[gpu_id]

    def __init__(self, gpu_id: int = 0):
        """
        Initialize the tracking factory (only on the first construction).

        Args:
            gpu_id: GPU device ID to use
        """
        # __init__ runs on every TrackingFactory(...) call, including calls
        # that return the existing singleton; serialize the one-time setup so
        # a concurrent second call cannot reset the counters.
        with type(self)._lock:
            if self._initialized:
                return

            self.gpu_id = gpu_id
            self._controller_count = 0
            self._controller_lock = threading.Lock()

            self._initialized = True
        print(f"TrackingFactory initialized for GPU {gpu_id}")

    def create_controller(self,
                          model_repository: "TensorRTModelRepository",
                          model_id: str,
                          tracker_type: str = "iou",
                          max_age: int = 30,
                          min_confidence: float = 0.5,
                          iou_threshold: float = 0.3,
                          class_names: Optional[Dict[int, str]] = None) -> "TrackingController":
        """
        Create a new TrackingController instance.

        Args:
            model_repository: TensorRT model repository (dependency injection)
            model_id: Model ID in repository to use for detection
            tracker_type: Tracking algorithm type, passed through to the controller
            max_age: Maximum frames to keep track without detection
            min_confidence: Minimum confidence threshold for detections
            iou_threshold: IoU threshold for track association
            class_names: Optional mapping of class IDs to names

        Returns:
            TrackingController instance

        Raises:
            ValueError: If model_repository GPU doesn't match factory GPU
            ValueError: If model_id not found in repository
        """
        if model_repository.gpu_id != self.gpu_id:
            raise ValueError(
                f"Model repository GPU ({model_repository.gpu_id}) doesn't match "
                f"factory GPU ({self.gpu_id})"
            )

        if model_repository.get_metadata(model_id) is None:
            # _model_to_hash is a private attribute of the repository; fall
            # back to an empty listing so a missing attribute cannot mask the
            # ValueError with an AttributeError.
            available = list(getattr(model_repository, '_model_to_hash', {}).keys())
            raise ValueError(
                f"Model '{model_id}' not found in repository. "
                f"Available models: {available}"
            )

        controller = TrackingController(
            model_repository=model_repository,
            model_id=model_id,
            gpu_id=self.gpu_id,
            tracker_type=tracker_type,
            max_age=max_age,
            min_confidence=min_confidence,
            iou_threshold=iou_threshold,
            class_names=class_names
        )

        # Count only controllers that were actually constructed, and capture
        # the number under the lock so the log line matches this controller.
        with self._controller_lock:
            self._controller_count += 1
            controller_number = self._controller_count

        print(f"Created TrackingController #{controller_number} (model: {model_id})")

        return controller

    def create_multi_model_controller(self,
                                      model_repository: "TensorRTModelRepository",
                                      model_configs: Dict[str, Dict],
                                      ensemble_strategy: str = "nms") -> 'MultiModelTrackingController':
        """
        Create a multi-model tracking controller that combines multiple detectors.

        Args:
            model_repository: TensorRT model repository
            model_configs: Dict mapping model_id to config dict with keys:
                - tracker_type, max_age, min_confidence, iou_threshold, class_names
            ensemble_strategy: How to combine detections ("nms", "vote", "union")

        Returns:
            MultiModelTrackingController instance

        Raises:
            NotImplementedError: Always; this is a placeholder for future
                multi-model support.
        """
        raise NotImplementedError(
            "Multi-model tracking controller not yet implemented. "
            "Use create_controller for single-model tracking."
        )

    def get_stats(self) -> Dict:
        """
        Get factory statistics.

        Returns:
            Dictionary with 'gpu_id' and 'controllers_created'
        """
        with self._controller_lock:
            return {
                'gpu_id': self.gpu_id,
                'controllers_created': self._controller_count,
            }

    @classmethod
    def get_factory(cls, gpu_id: int = 0) -> 'TrackingFactory':
        """
        Get or create factory instance for specified GPU.

        Args:
            gpu_id: GPU device ID

        Returns:
            TrackingFactory instance for the GPU
        """
        return cls(gpu_id=gpu_id)

    @classmethod
    def list_factories(cls) -> Dict[int, 'TrackingFactory']:
        """
        List all factory instances.

        Returns:
            A copy of the mapping from gpu_id to factory instance
        """
        with cls._lock:
            return cls._instances.copy()

    def __repr__(self):
        with self._controller_lock:
            return (f"TrackingFactory(gpu={self.gpu_id}, "
                    f"controllers_created={self._controller_count})")
|
||||
318
test_tracking.py
Normal file
318
test_tracking.py
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
"""
|
||||
Test script for TrackingController and TrackingFactory.
|
||||
|
||||
This script demonstrates how to use the tracking system with:
|
||||
- TensorRT model repository (dependency injection)
|
||||
- TrackingFactory for controller creation
|
||||
- GPU-accelerated object tracking on RTSP streams
|
||||
- Persistent track IDs and history management
|
||||
"""
|
||||
|
||||
import time
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from services import (
|
||||
StreamDecoderFactory,
|
||||
TensorRTModelRepository,
|
||||
TrackingFactory,
|
||||
TrackedObject
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def main():
    """
    Demonstrate the tracking workflow end to end.

    Steps: create the model repository and try to load the detector, create
    the tracking factory and (if the model loaded) a controller, start an
    RTSP decoder, run a 30-second tracking loop, then print statistics and
    export the tracks to tracked_objects.json. The decoder is always stopped
    once it has been started, even if the loop fails unexpectedly.
    """
    # Configuration
    GPU_ID = 0
    MODEL_PATH = "models/yolov8n.trt"  # Update with your model path
    RTSP_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
    BUFFER_SIZE = 30

    # COCO class names (example for YOLOv8)
    COCO_CLASSES = {
        0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
        5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
        # Add more as needed...
    }

    print("=" * 80)
    print("GPU-Accelerated Object Tracking Test")
    print("=" * 80)

    # Step 1: Create model repository
    print("\n[1/5] Initializing TensorRT Model Repository...")
    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4)

    # Load detection model (if file exists); model_id becomes None on failure
    model_id = "yolov8_detector"
    if os.path.exists(MODEL_PATH):
        try:
            metadata = model_repo.load_model(
                model_id=model_id,
                file_path=MODEL_PATH,
                num_contexts=4
            )
            print("✓ Model loaded successfully")
            print(f" Input shape: {metadata.input_shapes}")
            print(f" Output shape: {metadata.output_shapes}")
        except Exception as e:
            print(f"✗ Failed to load model: {e}")
            print(f" Please ensure {MODEL_PATH} exists")
            print(" Continuing with demo (will use mock detections)...")
            model_id = None
    else:
        print(f"✗ Model file not found: {MODEL_PATH}")
        print(" Continuing with demo (will use mock detections)...")
        model_id = None

    # Step 2: Create tracking factory
    print("\n[2/5] Creating TrackingFactory...")
    tracking_factory = TrackingFactory(gpu_id=GPU_ID)
    print(f"✓ Factory created: {tracking_factory}")

    # Step 3: Create tracking controller (only if model loaded)
    tracking_controller = None
    if model_id is not None:
        print("\n[3/5] Creating TrackingController...")
        try:
            tracking_controller = tracking_factory.create_controller(
                model_repository=model_repo,
                model_id=model_id,
                tracker_type="iou",
                max_age=30,
                min_confidence=0.5,
                iou_threshold=0.3,
                class_names=COCO_CLASSES
            )
            print(f"✓ Controller created: {tracking_controller}")
        except Exception as e:
            print(f"✗ Failed to create controller: {e}")
            tracking_controller = None
    else:
        print("\n[3/5] Skipping TrackingController creation (no model loaded)")

    # Step 4: Create stream decoder
    print("\n[4/5] Creating RTSP Stream Decoder...")
    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoder = stream_factory.create_decoder(
        rtsp_url=RTSP_URL,
        buffer_size=BUFFER_SIZE
    )
    decoder.start()

    # From here on the decoder is running; the finally clause guarantees it
    # is stopped no matter how the rest of the demo ends.
    try:
        print(f"✓ Decoder started for: {RTSP_URL}")
        print(" Waiting for connection...")

        # Wait for stream connection
        time.sleep(5)

        if decoder.is_connected():
            print("✓ Stream connected!")
        else:
            print(f"✗ Stream not connected (status: {decoder.get_status().value})")
            print(" Note: This is expected if RTSP URL is not available")
            print(" The tracking system will still work with valid streams")

        # Step 5: Run tracking loop (demo)
        print("\n[5/5] Running Tracking Loop...")
        print(" Processing frames for 30 seconds...")
        print(" Press Ctrl+C to stop early\n")

        try:
            frame_count = 0
            start_time = time.time()

            while time.time() - start_time < 30:
                # Get latest frame from decoder (GPU tensor)
                frame = decoder.get_latest_frame(rgb=True)

                if frame is None:
                    time.sleep(0.1)
                    continue

                frame_count += 1

                # Run tracking (if controller available)
                if tracking_controller is not None:
                    try:
                        tracked_objects = tracking_controller.track(frame)

                        # Display tracking results every 10 frames
                        if frame_count % 10 == 0:
                            # Public statistics supply the controller's frame
                            # counter used for the per-track age below.
                            stats = tracking_controller.get_statistics()

                            print(f"\n--- Frame {frame_count} ---")
                            print(f"Active tracks: {len(tracked_objects)}")

                            for obj in tracked_objects:
                                print(f" Track #{obj.track_id}: {obj.class_name} "
                                      f"(conf={obj.confidence:.2f}, "
                                      f"bbox={[f'{x:.1f}' for x in obj.bbox]}, "
                                      f"age={obj.age(stats['frame_count'])} frames)")

                            print("\nStatistics:")
                            print(f" Total frames processed: {stats['frame_count']}")
                            print(f" Total tracks created: {stats['total_tracks_created']}")
                            print(f" Total detections: {stats['total_detections']}")
                            print(f" Avg detections/frame: {stats['avg_detections_per_frame']:.2f}")
                            print(f" Class counts: {stats['class_counts']}")

                    except Exception as e:
                        # Best effort: report the failure and keep the demo running.
                        print(f"✗ Tracking error on frame {frame_count}: {e}")

                # Small delay to avoid overwhelming output
                time.sleep(0.1)

        except KeyboardInterrupt:
            print("\n\n✓ Interrupted by user")

        # Cleanup
        print("\n" + "=" * 80)
        print("Cleanup")
        print("=" * 80)

        if tracking_controller is not None:
            print("\nTracking final statistics:")
            stats = tracking_controller.get_statistics()
            for key, value in stats.items():
                print(f" {key}: {value}")

            print("\nExporting tracks to JSON...")
            try:
                tracks_json = tracking_controller.export_tracks(format="json")
                with open("tracked_objects.json", "w") as f:
                    f.write(tracks_json)
                print("✓ Tracks exported to tracked_objects.json")
            except Exception as e:
                print(f"✗ Export failed: {e}")
    finally:
        print("\nStopping decoder...")
        decoder.stop()
        print("✓ Decoder stopped")

    print("\n" + "=" * 80)
    print("Test completed successfully!")
    print("=" * 80)
|
||||
|
||||
|
||||
def test_multi_camera_tracking():
    """
    Example: Track objects across multiple camera streams.

    This demonstrates:
    - Shared model repository across multiple streams
    - Multiple tracking controllers (one per camera)
    - Efficient GPU resource usage

    Camera URLs are read from the environment variables CAMERA_URL_1,
    CAMERA_URL_2, ...; the scan stops at the first unset or empty variable.
    The function prints a message and returns early when no URL is
    configured or when the model file does not exist.

    Every decoder that was started is stopped on the way out, including
    when setup or tracking raises. Ctrl-C ends the run early and still
    prints the final statistics.
    """
    GPU_ID = 0
    MODEL_PATH = "models/yolov8n.trt"
    TRACKING_SECONDS = 30
    REPORT_INTERVAL_SECONDS = 10

    # Load multiple camera URLs
    camera_urls = []
    while True:
        url = os.getenv(f'CAMERA_URL_{len(camera_urls) + 1}')
        if not url:
            break
        camera_urls.append(url)

    if not camera_urls:
        print("No camera URLs found in .env file")
        return

    print(f"Testing multi-camera tracking with {len(camera_urls)} cameras")

    # Create shared model repository
    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=8)

    if not os.path.exists(MODEL_PATH):
        print(f"Model not found: {MODEL_PATH}")
        return
    model_repo.load_model("detector", MODEL_PATH, num_contexts=8)

    # Create tracking factory
    tracking_factory = TrackingFactory(gpu_id=GPU_ID)

    # Create stream decoders and tracking controllers
    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoders = []
    controllers = []

    # Setup, the connect wait and the tracking loop all sit inside one try so
    # that the finally clause stops every decoder already started, whatever
    # the reason for leaving (normal end, Ctrl-C, or an unexpected error).
    try:
        for camera_number, url in enumerate(camera_urls, start=1):
            # Create decoder
            decoder = stream_factory.create_decoder(url, buffer_size=30)
            decoder.start()
            decoders.append(decoder)

            # Create tracking controller
            controller = tracking_factory.create_controller(
                model_repository=model_repo,
                model_id="detector",
                tracker_type="iou",
                max_age=30,
                min_confidence=0.5,
            )
            controllers.append(controller)

            # NOTE(review): the URL is printed verbatim; if camera URLs embed
            # credentials they end up in the console output - confirm this is
            # acceptable for this demo.
            print(f"Camera {camera_number}: {url}")

        print("\nWaiting for streams to connect...")
        time.sleep(10)

        # Track objects for 30 seconds
        print(f"\nTracking objects across {len(camera_urls)} cameras...")
        # A monotonic clock keeps the run length and report cadence immune to
        # wall-clock adjustments.
        start_time = time.monotonic()
        next_report = start_time + REPORT_INTERVAL_SECONDS

        while time.monotonic() - start_time < TRACKING_SECONDS:
            # Decide once per pass whether a report is due, so each camera is
            # reported exactly once per interval instead of on every pass
            # that happens to fall inside a "multiple of ten" second.
            report_due = time.monotonic() >= next_report
            if report_due:
                next_report += REPORT_INTERVAL_SECONDS

            cameras = zip(decoders, controllers)
            for camera_number, (decoder, controller) in enumerate(cameras, start=1):
                frame = decoder.get_latest_frame(rgb=True)

                if frame is not None:
                    controller.track(frame)

                # Print stats every 10 seconds (also for cameras that did not
                # deliver a frame on this pass, so stalled streams are visible)
                if report_due:
                    stats = controller.get_statistics()
                    print(f"Camera {camera_number}: {stats['active_tracks']} tracks, "
                          f"{stats['frame_count']} frames")

            time.sleep(0.1)

    except KeyboardInterrupt:
        print("\nInterrupted by user")
    finally:
        # Cleanup
        print("\nCleaning up...")
        for decoder in decoders:
            decoder.stop()

    # Print final stats
    print("\nFinal Statistics:")
    for camera_number, controller in enumerate(controllers, start=1):
        stats = controller.get_statistics()
        print(f"\nCamera {camera_number}:")
        print(f" Frames: {stats['frame_count']}")
        print(f" Tracks created: {stats['total_tracks_created']}")
        print(f" Active tracks: {stats['active_tracks']}")

    # Print model repository stats
    print("\nModel Repository Stats:")
    repo_stats = model_repo.get_stats()
    for key, value in repo_stats.items():
        print(f" {key}: {value}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: by default only the single-camera demo runs.
    main()

    # The multi-camera demo is opt-in; uncomment the call below to run it.
    # test_multi_camera_tracking()
|
||||
Loading…
Add table
Add a link
Reference in a new issue