python-rtsp-worker/services/tracking_controller.py
2025-11-09 01:49:52 +07:00

524 lines
18 KiB
Python

import threading
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
from collections import defaultdict, deque
import time
import torch
import numpy as np
from .model_repository import TensorRTModelRepository
@dataclass
class TrackedObject:
    """
    Represents a tracked object with persistent ID and metadata.

    Attributes:
        track_id: Unique persistent tracking ID
        class_id: Object class ID from detection model
        class_name: Object class name (if available)
        confidence: Detection confidence score (0-1)
        bbox: Bounding box in format [x1, y1, x2, y2] (normalized or absolute)
        last_seen_frame: Frame number when object was last detected
        first_seen_frame: Frame number when object was first detected
        track_history: Deque of (frame_num, bbox, confidence) tuples, bounded to
            the 30 most recent entries, for trajectory tracking
        state: Custom state dict for additional tracking data
    """
    track_id: int
    class_id: int
    class_name: str = "unknown"
    confidence: float = 0.0
    bbox: List[float] = field(default_factory=list)
    last_seen_frame: int = 0
    first_seen_frame: int = 0
    # maxlen bounds memory per track; oldest entries are dropped automatically
    track_history: deque = field(default_factory=lambda: deque(maxlen=30))
    state: Dict[str, Any] = field(default_factory=dict)

    def update(self, bbox: List[float], confidence: float, frame_num: int):
        """Update the tracked object with a new detection.

        Args:
            bbox: New bounding box [x1, y1, x2, y2].
            confidence: Detection confidence for this frame.
            frame_num: Frame number of the detection.
        """
        self.bbox = bbox
        self.confidence = confidence
        self.last_seen_frame = frame_num
        self.track_history.append((frame_num, bbox, confidence))

    def age(self, current_frame: int) -> int:
        """Return the number of frames since this track was last seen."""
        return current_frame - self.last_seen_frame

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for serialization.

        ``bbox`` and ``state`` are shallow-copied so the returned snapshot does
        not alias this track's mutable fields (editing the dict must not edit
        the track). Note that the ``'age'`` key is the track's lifetime
        (``last_seen_frame - first_seen_frame``), which differs from
        :meth:`age` (frames since last seen).
        """
        return {
            'track_id': self.track_id,
            'class_id': self.class_id,
            'class_name': self.class_name,
            'confidence': self.confidence,
            'bbox': list(self.bbox),
            'last_seen_frame': self.last_seen_frame,
            'first_seen_frame': self.first_seen_frame,
            'age': self.last_seen_frame - self.first_seen_frame,
            'history_length': len(self.track_history),
            'state': dict(self.state),
        }
class TrackingController:
    """
    GPU-accelerated object tracking controller that wraps TensorRTModelRepository.

    Architecture:
        - Wraps model repository for dependency injection
        - Maintains CUDA state for bbox tracking operations
        - Stores persistent tracking data (track IDs, histories, states)
        - Processes GPU tensor frames directly (zero-copy pipeline)
        - Thread-safe for concurrent tracking operations (re-entrant lock)

    Tracking Flow:
        GPU Frame → Model Inference (GPU) → Detections (GPU)
            → Tracking Algorithm (GPU/CPU) → Track Assignment
            → Update Persistent Tracks → Return Tracked Objects

    Features:
        - GPU-first: inference and IoU math run on the device; only the IoU
          matrix and the final detection rows are copied to the host
        - Persistent IDs: Tracks maintain consistent IDs across frames
        - Track History: Maintains trajectory history for each object
        - Configurable: Supports custom pre/postprocessing via callbacks
        - Thread-safe: Mutex-based locking for concurrent access

    Example:
        # Initialize with DI
        repo = TensorRTModelRepository(gpu_id=0)
        factory = TrackingFactory(gpu_id=0)
        controller = factory.create_controller(
            model_repository=repo,
            model_id="yolov8_detector",
            tracker_type="iou"
        )

        # Track objects in frame
        rgb_frame = decoder.get_latest_frame()  # GPU tensor
        tracked_objects = controller.track(rgb_frame)

        # Get all tracked objects
        all_tracks = controller.get_all_tracks()
    """

    def __init__(self,
                 model_repository: "TensorRTModelRepository",
                 model_id: str,
                 gpu_id: int = 0,
                 tracker_type: str = "iou",
                 max_age: int = 30,
                 min_confidence: float = 0.5,
                 iou_threshold: float = 0.3,
                 class_names: Optional[Dict[int, str]] = None):
        """
        Initialize TrackingController.

        Args:
            model_repository: TensorRT model repository (dependency injection)
            model_id: Model ID in repository to use for detection
            gpu_id: GPU device ID
            tracker_type: Tracking algorithm type ("iou", "sort", "deepsort",
                "bytetrack"); only "iou" is currently implemented in track()
            max_age: Maximum frames to keep track without detection
            min_confidence: Minimum confidence threshold for detections
            iou_threshold: IoU threshold for track association
            class_names: Optional mapping of class IDs to names

        Raises:
            ValueError: If the repository has no metadata for ``model_id``.
        """
        self.model_repository = model_repository
        self.model_id = model_id
        self.gpu_id = gpu_id
        self.device = torch.device(f'cuda:{gpu_id}')
        self.tracker_type = tracker_type
        self.max_age = max_age
        self.min_confidence = min_confidence
        self.iou_threshold = iou_threshold
        self.class_names = class_names or {}

        # Tracking state
        self._tracks: Dict[int, "TrackedObject"] = {}
        self._next_track_id: int = 0
        self._frame_count: int = 0
        self._lock = threading.RLock()

        # Statistics
        self._total_detections = 0
        self._total_tracks_created = 0

        # Verify model exists in repository
        metadata = self.model_repository.get_metadata(model_id)
        if metadata is None:
            raise ValueError(f"Model '{model_id}' not found in repository")

        print(f"TrackingController initialized:")
        print(f" Model ID: {model_id}")
        print(f" GPU: {gpu_id}")
        print(f" Tracker: {tracker_type}")
        print(f" Max age: {max_age} frames")
        print(f" Min confidence: {min_confidence}")
        print(f" IoU threshold: {iou_threshold}")

    def _compute_iou_gpu(self, boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """
        Compute pairwise IoU between two sets of boxes on ``self.device``.

        Args:
            boxes1: Tensor of shape (N, 4) in format [x1, y1, x2, y2]
            boxes2: Tensor of shape (M, 4) in format [x1, y1, x2, y2]

        Returns:
            float32 IoU matrix of shape (N, M)
        """
        # Compute in float32: FP16 detector outputs in pixel coordinates would
        # overflow when squared into areas (fp16 max is 65504), turning IoU into 0.
        boxes1 = boxes1.to(device=self.device, dtype=torch.float32)
        boxes2 = boxes2.to(device=self.device, dtype=torch.float32)

        # Intersection corners via broadcasting: (N, 1) against (M,) -> (N, M)
        x1_max = torch.max(boxes1[:, None, 0], boxes2[:, 0])
        y1_max = torch.max(boxes1[:, None, 1], boxes2[:, 1])
        x2_min = torch.min(boxes1[:, None, 2], boxes2[:, 2])
        y2_min = torch.min(boxes1[:, None, 3], boxes2[:, 3])

        intersection_width = torch.clamp(x2_min - x1_max, min=0)
        intersection_height = torch.clamp(y2_min - y1_max, min=0)
        intersection_area = intersection_width * intersection_height

        boxes1_area = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        boxes2_area = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        union_area = boxes1_area[:, None] + boxes2_area - intersection_area
        # Epsilon guards against division by zero for degenerate boxes
        return intersection_area / (union_area + 1e-6)

    def _iou_tracking(self, detections: torch.Tensor) -> List[Tuple[int, int]]:
        """
        Simple greedy IoU-based association.

        Detections are processed in order; each one takes the not-yet-matched
        track with the highest IoU strictly above ``self.iou_threshold``
        (ties go to the earliest track).

        Args:
            detections: Tensor of shape (N, >=4) whose first four columns are
                [x1, y1, x2, y2]

        Returns:
            List of (detection_idx, track_id) pairs, one per detection in order;
            track_id is -1 when the detection should start a new track.
        """
        if not self._tracks:
            # No existing tracks: every detection starts a new track. The real
            # detection index must be returned, otherwise track() would index
            # detections[-1] for every entry and clone the last detection.
            return [(det_idx, -1) for det_idx in range(len(detections))]

        track_ids = list(self._tracks.keys())
        track_bboxes = torch.tensor(
            [self._tracks[tid].bbox for tid in track_ids],
            dtype=torch.float32,
            device=self.device
        )

        iou_matrix = self._compute_iou_gpu(detections[:, :4], track_bboxes)  # (N, M)
        # Greedy matching runs on the host; the (N, M) matrix is small
        iou_cpu = iou_matrix.detach().cpu().numpy()

        associations = []
        matched_tracks = set()
        for det_idx, det_ious in enumerate(iou_cpu):
            best_iou = self.iou_threshold
            best_track_idx = -1
            for track_idx, iou in enumerate(det_ious):
                if track_idx not in matched_tracks and iou > best_iou:
                    best_iou = iou
                    best_track_idx = track_idx

            if best_track_idx >= 0:
                associations.append((det_idx, track_ids[best_track_idx]))
                matched_tracks.add(best_track_idx)
            else:
                associations.append((det_idx, -1))  # New track

        return associations

    def _create_track(self, bbox: List[float], confidence: float,
                      class_id: int, frame_num: int) -> "TrackedObject":
        """Create a new tracked object with the next sequential track ID."""
        track_id = self._next_track_id
        self._next_track_id += 1
        self._total_tracks_created += 1

        class_name = self.class_names.get(class_id, f"class_{class_id}")
        track = TrackedObject(
            track_id=track_id,
            class_id=class_id,
            class_name=class_name,
            confidence=confidence,
            bbox=bbox,
            last_seen_frame=frame_num,
            first_seen_frame=frame_num
        )
        track.track_history.append((frame_num, bbox, confidence))
        return track

    def _cleanup_stale_tracks(self):
        """Remove tracks that haven't been seen for more than max_age frames."""
        stale_track_ids = [
            tid for tid, track in self._tracks.items()
            if track.age(self._frame_count) > self.max_age
        ]
        for tid in stale_track_ids:
            del self._tracks[tid]

    def _detect(self, frame: torch.Tensor,
                preprocess_fn: Optional[callable],
                postprocess_fn: Optional[callable]) -> torch.Tensor:
        """Run the detector on ``frame`` and return confidence-filtered detections.

        Called by track() with ``self._lock`` held. The expected detection
        layout is (N, 6): [x1, y1, x2, y2, conf, class_id].
        """
        # Move the frame only if it is not already on the controller's device
        if frame.device != self.device:
            frame = frame.to(self.device)

        # Preprocess frame for model
        if preprocess_fn is not None:
            model_input = preprocess_fn(frame)
        elif frame.dim() == 3:
            model_input = frame.unsqueeze(0)  # (3, H, W) -> (1, 3, H, W)
        else:
            model_input = frame

        # Feed the model's first declared input; fall back to "images"
        metadata = self.model_repository.get_metadata(self.model_id)
        input_name = metadata.input_names[0] if metadata else "images"
        outputs = self.model_repository.infer(
            model_id=self.model_id,
            inputs={input_name: model_input},
            synchronize=True
        )

        if postprocess_fn is not None:
            detections = postprocess_fn(outputs)
        else:
            # Default: assume the first output is already in detection format
            output_name = list(outputs.keys())[0]
            detections = outputs[output_name]
            # Reshape if needed: (1, N, 6) -> (N, 6)
            if detections.dim() == 3:
                detections = detections.squeeze(0)

        # Filter by confidence (column 4) when the layout has one
        if detections.dim() == 2 and detections.shape[1] >= 5:
            detections = detections[detections[:, 4] >= self.min_confidence]

        return detections

    def track(self, frame: torch.Tensor,
              preprocess_fn: Optional[callable] = None,
              postprocess_fn: Optional[callable] = None) -> List["TrackedObject"]:
        """
        Track objects in a GPU tensor frame.

        Args:
            frame: RGB frame as GPU tensor, shape (3, H, W) or (1, 3, H, W)
            preprocess_fn: Optional preprocessing function (frame -> model_input)
            postprocess_fn: Optional postprocessing function (model_output -> detections)
                Should return tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id]

        Returns:
            List of currently tracked objects

        Raises:
            NotImplementedError: If detections are present and tracker_type is
                not "iou".
        """
        with self._lock:
            self._frame_count += 1

            detections = self._detect(frame, preprocess_fn, postprocess_fn)
            self._total_detections += len(detections)

            if len(detections) == 0:
                # No detections, just cleanup stale tracks
                self._cleanup_stale_tracks()
                return list(self._tracks.values())

            if self.tracker_type == "iou":
                associations = self._iou_tracking(detections)
            else:
                raise NotImplementedError(f"Tracker type '{self.tracker_type}' not implemented")

            # One device->host copy for all rows instead of several syncs per detection
            det_rows = detections.detach().cpu().tolist()
            for det_idx, track_id in associations:
                row = det_rows[det_idx]
                bbox = row[:4]
                confidence = float(row[4])
                class_id = int(row[5]) if len(row) > 5 else 0

                if track_id == -1:
                    new_track = self._create_track(bbox, confidence, class_id, self._frame_count)
                    self._tracks[new_track.track_id] = new_track
                else:
                    self._tracks[track_id].update(bbox, confidence, self._frame_count)

            self._cleanup_stale_tracks()
            return list(self._tracks.values())

    def get_all_tracks(self, active_only: bool = True) -> List["TrackedObject"]:
        """
        Get all tracked objects.

        Args:
            active_only: If True, only return tracks seen within max_age frames

        Returns:
            List of tracked objects
        """
        with self._lock:
            if active_only:
                return [
                    track for track in self._tracks.values()
                    if track.age(self._frame_count) <= self.max_age
                ]
            return list(self._tracks.values())

    def get_track_by_id(self, track_id: int) -> Optional["TrackedObject"]:
        """
        Get a specific track by ID.

        Args:
            track_id: Track ID to retrieve

        Returns:
            TrackedObject or None if not found
        """
        with self._lock:
            return self._tracks.get(track_id)

    def get_tracks_by_class(self, class_id: int, active_only: bool = True) -> List["TrackedObject"]:
        """
        Get all tracks of a specific class.

        Args:
            class_id: Class ID to filter by
            active_only: If True, only return active tracks

        Returns:
            List of tracked objects
        """
        all_tracks = self.get_all_tracks(active_only=active_only)
        return [track for track in all_tracks if track.class_id == class_id]

    def get_track_count(self, active_only: bool = True) -> int:
        """
        Get number of tracked objects.

        Args:
            active_only: If True, only count active tracks

        Returns:
            Number of tracks
        """
        return len(self.get_all_tracks(active_only=active_only))

    def get_class_counts(self, active_only: bool = True) -> Dict[int, int]:
        """
        Get count of tracked objects per class.

        Args:
            active_only: If True, only count active tracks

        Returns:
            Dictionary mapping class_id to count
        """
        counts = defaultdict(int)
        for track in self.get_all_tracks(active_only=active_only):
            counts[track.class_id] += 1
        return dict(counts)

    def reset_tracks(self):
        """Clear all tracking state, including detection/track statistics."""
        with self._lock:
            self._tracks.clear()
            self._next_track_id = 0
            self._frame_count = 0
            # Reset counters too; otherwise get_statistics() would divide
            # pre-reset detection totals by post-reset frame counts.
            self._total_detections = 0
            self._total_tracks_created = 0
        print("Tracking state reset")

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get tracking statistics.

        Returns:
            Dictionary with tracking stats
        """
        with self._lock:
            return {
                'frame_count': self._frame_count,
                'active_tracks': len(self._tracks),
                'total_tracks_created': self._total_tracks_created,
                'total_detections': self._total_detections,
                'avg_detections_per_frame': self._total_detections / max(self._frame_count, 1),
                'model_id': self.model_id,
                'tracker_type': self.tracker_type,
                'class_counts': self.get_class_counts(active_only=True)
            }

    def export_tracks(self, format: str = "dict") -> Any:
        """
        Export all tracks in specified format.

        Args:
            format: Export format ("dict", "json", "numpy")

        Returns:
            Tracks in specified format

        Raises:
            ValueError: For an unknown format.
        """
        with self._lock:
            tracks = self.get_all_tracks(active_only=False)

            if format == "dict":
                return {track.track_id: track.to_dict() for track in tracks}
            elif format == "json":
                import json
                return json.dumps(
                    {track.track_id: track.to_dict() for track in tracks},
                    indent=2
                )
            elif format == "numpy":
                # Rows: [track_id, class_id, x1, y1, x2, y2, conf]
                data = [
                    [track.track_id, track.class_id, *track.bbox, track.confidence]
                    for track in tracks
                ]
                return np.array(data) if data else np.array([])
            else:
                raise ValueError(f"Unknown export format: {format}")

    def __repr__(self):
        with self._lock:
            return (f"TrackingController(model={self.model_id}, "
                    f"tracker={self.tracker_type}, "
                    f"frame={self._frame_count}, "
                    f"tracks={len(self._tracks)})")