feat: trac king

This commit is contained in:
Siwat Sirichai 2025-11-09 01:49:52 +07:00
parent cf24a172a2
commit bea895d3d8
4 changed files with 1054 additions and 0 deletions

View file

@ -0,0 +1,524 @@
import threading
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
from collections import defaultdict, deque
import time
import torch
import numpy as np
from .model_repository import TensorRTModelRepository
@dataclass
class TrackedObject:
"""
Represents a tracked object with persistent ID and metadata.
Attributes:
track_id: Unique persistent tracking ID
class_id: Object class ID from detection model
class_name: Object class name (if available)
confidence: Detection confidence score (0-1)
bbox: Bounding box in format [x1, y1, x2, y2] (normalized or absolute)
last_seen_frame: Frame number when object was last detected
first_seen_frame: Frame number when object was first detected
track_history: Deque of historical bboxes for trajectory tracking
state: Custom state dict for additional tracking data
"""
track_id: int
class_id: int
class_name: str = "unknown"
confidence: float = 0.0
bbox: List[float] = field(default_factory=list)
last_seen_frame: int = 0
first_seen_frame: int = 0
track_history: deque = field(default_factory=lambda: deque(maxlen=30))
state: Dict[str, Any] = field(default_factory=dict)
def update(self, bbox: List[float], confidence: float, frame_num: int):
"""Update tracked object with new detection"""
self.bbox = bbox
self.confidence = confidence
self.last_seen_frame = frame_num
self.track_history.append((frame_num, bbox, confidence))
def age(self, current_frame: int) -> int:
"""Get age of track in frames since last seen"""
return current_frame - self.last_seen_frame
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for serialization"""
return {
'track_id': self.track_id,
'class_id': self.class_id,
'class_name': self.class_name,
'confidence': self.confidence,
'bbox': self.bbox,
'last_seen_frame': self.last_seen_frame,
'first_seen_frame': self.first_seen_frame,
'age': self.last_seen_frame - self.first_seen_frame,
'history_length': len(self.track_history),
'state': self.state
}
class TrackingController:
"""
GPU-accelerated object tracking controller that wraps TensorRTModelRepository.
Architecture:
- Wraps model repository for dependency injection
- Maintains CUDA state for bbox tracking operations
- Stores persistent tracking data (track IDs, histories, states)
- Processes GPU tensor frames directly (zero-copy pipeline)
- Thread-safe for concurrent tracking operations
Tracking Flow:
GPU Frame Model Inference (GPU) Detections (GPU)
Tracking Algorithm (GPU/CPU) Track Assignment
Update Persistent Tracks Return Tracked Objects
Features:
- GPU-first: All tensor operations stay on GPU until final results
- Persistent IDs: Tracks maintain consistent IDs across frames
- Track History: Maintains trajectory history for each object
- Configurable: Supports custom tracking algorithms via callbacks
- Thread-safe: Mutex-based locking for concurrent access
Example:
# Initialize with DI
repo = TensorRTModelRepository(gpu_id=0)
factory = TrackingFactory(gpu_id=0)
controller = factory.create_controller(
model_repository=repo,
model_id="yolov8_detector",
tracker_type="iou"
)
# Track objects in frame
rgb_frame = decoder.get_latest_frame() # GPU tensor
tracked_objects = controller.track(rgb_frame)
# Get all tracked objects
all_tracks = controller.get_all_tracks()
"""
def __init__(self,
model_repository: TensorRTModelRepository,
model_id: str,
gpu_id: int = 0,
tracker_type: str = "iou",
max_age: int = 30,
min_confidence: float = 0.5,
iou_threshold: float = 0.3,
class_names: Optional[Dict[int, str]] = None):
"""
Initialize TrackingController.
Args:
model_repository: TensorRT model repository (dependency injection)
model_id: Model ID in repository to use for detection
gpu_id: GPU device ID
tracker_type: Tracking algorithm type ("iou", "sort", "deepsort", "bytetrack")
max_age: Maximum frames to keep track without detection
min_confidence: Minimum confidence threshold for detections
iou_threshold: IoU threshold for track association
class_names: Optional mapping of class IDs to names
"""
self.model_repository = model_repository
self.model_id = model_id
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.tracker_type = tracker_type
self.max_age = max_age
self.min_confidence = min_confidence
self.iou_threshold = iou_threshold
self.class_names = class_names or {}
# Tracking state
self._tracks: Dict[int, TrackedObject] = {}
self._next_track_id: int = 0
self._frame_count: int = 0
self._lock = threading.RLock()
# Statistics
self._total_detections = 0
self._total_tracks_created = 0
# Verify model exists in repository
metadata = self.model_repository.get_metadata(model_id)
if metadata is None:
raise ValueError(f"Model '{model_id}' not found in repository")
print(f"TrackingController initialized:")
print(f" Model ID: {model_id}")
print(f" GPU: {gpu_id}")
print(f" Tracker: {tracker_type}")
print(f" Max age: {max_age} frames")
print(f" Min confidence: {min_confidence}")
print(f" IoU threshold: {iou_threshold}")
def _compute_iou_gpu(self, boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
"""
Compute IoU between two sets of boxes on GPU.
Args:
boxes1: Tensor of shape (N, 4) in format [x1, y1, x2, y2]
boxes2: Tensor of shape (M, 4) in format [x1, y1, x2, y2]
Returns:
IoU matrix of shape (N, M)
"""
# Ensure on GPU
boxes1 = boxes1.to(self.device)
boxes2 = boxes2.to(self.device)
# Compute intersection
x1_max = torch.max(boxes1[:, None, 0], boxes2[:, 0]) # (N, M)
y1_max = torch.max(boxes1[:, None, 1], boxes2[:, 1]) # (N, M)
x2_min = torch.min(boxes1[:, None, 2], boxes2[:, 2]) # (N, M)
y2_min = torch.min(boxes1[:, None, 3], boxes2[:, 3]) # (N, M)
intersection_width = torch.clamp(x2_min - x1_max, min=0)
intersection_height = torch.clamp(y2_min - y1_max, min=0)
intersection_area = intersection_width * intersection_height
# Compute areas
boxes1_area = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
boxes2_area = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
# Compute union
union_area = boxes1_area[:, None] + boxes2_area - intersection_area
# Compute IoU
iou = intersection_area / (union_area + 1e-6)
return iou
def _iou_tracking(self, detections: torch.Tensor) -> List[Tuple[int, int]]:
"""
Simple IoU-based tracking algorithm (on GPU).
Args:
detections: Tensor of shape (N, 6) with [x1, y1, x2, y2, conf, class_id]
Returns:
List of (detection_idx, track_id) associations
"""
if len(self._tracks) == 0:
# No existing tracks, create new ones for all detections
return [(-1, -1) for _ in range(len(detections))]
# Get existing track bboxes
track_ids = list(self._tracks.keys())
track_bboxes = torch.tensor(
[self._tracks[tid].bbox for tid in track_ids],
dtype=torch.float32,
device=self.device
)
# Extract detection bboxes
det_bboxes = detections[:, :4] # (N, 4)
# Compute IoU matrix (GPU)
iou_matrix = self._compute_iou_gpu(det_bboxes, track_bboxes) # (N, M)
# Greedy matching: assign each detection to best matching track
associations = []
matched_tracks = set()
# Convert to CPU for matching logic (small matrix)
iou_cpu = iou_matrix.cpu().numpy()
for det_idx in range(len(detections)):
best_iou = self.iou_threshold
best_track_idx = -1
for track_idx, track_id in enumerate(track_ids):
if track_idx in matched_tracks:
continue
if iou_cpu[det_idx, track_idx] > best_iou:
best_iou = iou_cpu[det_idx, track_idx]
best_track_idx = track_idx
if best_track_idx >= 0:
associations.append((det_idx, track_ids[best_track_idx]))
matched_tracks.add(best_track_idx)
else:
associations.append((det_idx, -1)) # New track
return associations
def _create_track(self, bbox: List[float], confidence: float,
class_id: int, frame_num: int) -> TrackedObject:
"""Create a new tracked object"""
track_id = self._next_track_id
self._next_track_id += 1
self._total_tracks_created += 1
class_name = self.class_names.get(class_id, f"class_{class_id}")
track = TrackedObject(
track_id=track_id,
class_id=class_id,
class_name=class_name,
confidence=confidence,
bbox=bbox,
last_seen_frame=frame_num,
first_seen_frame=frame_num
)
track.track_history.append((frame_num, bbox, confidence))
return track
def _cleanup_stale_tracks(self):
"""Remove tracks that haven't been seen for max_age frames"""
stale_track_ids = [
tid for tid, track in self._tracks.items()
if track.age(self._frame_count) > self.max_age
]
for tid in stale_track_ids:
del self._tracks[tid]
def track(self, frame: torch.Tensor,
preprocess_fn: Optional[callable] = None,
postprocess_fn: Optional[callable] = None) -> List[TrackedObject]:
"""
Track objects in a GPU tensor frame.
Args:
frame: RGB frame as GPU tensor, shape (3, H, W) or (1, 3, H, W)
preprocess_fn: Optional preprocessing function (frame -> model_input)
postprocess_fn: Optional postprocessing function (model_output -> detections)
Should return tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id]
Returns:
List of currently tracked objects
"""
with self._lock:
self._frame_count += 1
# Ensure frame is on correct device
if not frame.is_cuda:
frame = frame.to(self.device)
elif frame.device != self.device:
frame = frame.to(self.device)
# Preprocess frame for model
if preprocess_fn is not None:
model_input = preprocess_fn(frame)
else:
# Default: add batch dimension if needed
if frame.dim() == 3:
model_input = frame.unsqueeze(0) # (1, 3, H, W)
else:
model_input = frame
# Run inference (GPU-to-GPU)
# Assuming model expects input named "images" or "input"
metadata = self.model_repository.get_metadata(self.model_id)
input_name = metadata.input_names[0] if metadata else "images"
outputs = self.model_repository.infer(
model_id=self.model_id,
inputs={input_name: model_input},
synchronize=True
)
# Postprocess model output to get detections
if postprocess_fn is not None:
detections = postprocess_fn(outputs)
else:
# Default: assume output is already in correct format
# Get first output tensor
output_name = list(outputs.keys())[0]
detections = outputs[output_name]
# Reshape if needed: (1, N, 6) -> (N, 6)
if detections.dim() == 3:
detections = detections.squeeze(0)
# Filter by confidence
if detections.dim() == 2 and detections.shape[1] >= 5:
conf_mask = detections[:, 4] >= self.min_confidence
detections = detections[conf_mask]
self._total_detections += len(detections)
# Track objects
if len(detections) == 0:
# No detections, just cleanup stale tracks
self._cleanup_stale_tracks()
return list(self._tracks.values())
# Run tracking algorithm
if self.tracker_type == "iou":
associations = self._iou_tracking(detections)
else:
raise NotImplementedError(f"Tracker type '{self.tracker_type}' not implemented")
# Update tracks based on associations
for det_idx, track_id in associations:
detection = detections[det_idx]
bbox = detection[:4].cpu().tolist()
confidence = float(detection[4])
class_id = int(detection[5]) if detection.shape[0] > 5 else 0
if track_id == -1:
# Create new track
new_track = self._create_track(bbox, confidence, class_id, self._frame_count)
self._tracks[new_track.track_id] = new_track
else:
# Update existing track
self._tracks[track_id].update(bbox, confidence, self._frame_count)
# Cleanup stale tracks
self._cleanup_stale_tracks()
return list(self._tracks.values())
def get_all_tracks(self, active_only: bool = True) -> List[TrackedObject]:
"""
Get all tracked objects.
Args:
active_only: If True, only return tracks seen in recent frames
Returns:
List of tracked objects
"""
with self._lock:
if active_only:
return [
track for track in self._tracks.values()
if track.age(self._frame_count) <= self.max_age
]
else:
return list(self._tracks.values())
def get_track_by_id(self, track_id: int) -> Optional[TrackedObject]:
"""
Get a specific track by ID.
Args:
track_id: Track ID to retrieve
Returns:
TrackedObject or None if not found
"""
with self._lock:
return self._tracks.get(track_id)
def get_tracks_by_class(self, class_id: int, active_only: bool = True) -> List[TrackedObject]:
"""
Get all tracks of a specific class.
Args:
class_id: Class ID to filter by
active_only: If True, only return active tracks
Returns:
List of tracked objects
"""
all_tracks = self.get_all_tracks(active_only=active_only)
return [track for track in all_tracks if track.class_id == class_id]
def get_track_count(self, active_only: bool = True) -> int:
"""
Get number of tracked objects.
Args:
active_only: If True, only count active tracks
Returns:
Number of tracks
"""
return len(self.get_all_tracks(active_only=active_only))
def get_class_counts(self, active_only: bool = True) -> Dict[int, int]:
"""
Get count of tracked objects per class.
Args:
active_only: If True, only count active tracks
Returns:
Dictionary mapping class_id to count
"""
tracks = self.get_all_tracks(active_only=active_only)
counts = defaultdict(int)
for track in tracks:
counts[track.class_id] += 1
return dict(counts)
def reset_tracks(self):
"""Clear all tracking state"""
with self._lock:
self._tracks.clear()
self._next_track_id = 0
self._frame_count = 0
print("Tracking state reset")
def get_statistics(self) -> Dict[str, Any]:
"""
Get tracking statistics.
Returns:
Dictionary with tracking stats
"""
with self._lock:
return {
'frame_count': self._frame_count,
'active_tracks': len(self._tracks),
'total_tracks_created': self._total_tracks_created,
'total_detections': self._total_detections,
'avg_detections_per_frame': self._total_detections / max(self._frame_count, 1),
'model_id': self.model_id,
'tracker_type': self.tracker_type,
'class_counts': self.get_class_counts(active_only=True)
}
def export_tracks(self, format: str = "dict") -> Any:
"""
Export all tracks in specified format.
Args:
format: Export format ("dict", "json", "numpy")
Returns:
Tracks in specified format
"""
with self._lock:
tracks = self.get_all_tracks(active_only=False)
if format == "dict":
return {track.track_id: track.to_dict() for track in tracks}
elif format == "json":
import json
return json.dumps(
{track.track_id: track.to_dict() for track in tracks},
indent=2
)
elif format == "numpy":
# Export as numpy array: [track_id, class_id, x1, y1, x2, y2, conf]
data = []
for track in tracks:
data.append([
track.track_id,
track.class_id,
*track.bbox,
track.confidence
])
return np.array(data) if data else np.array([])
else:
raise ValueError(f"Unknown export format: {format}")
def __repr__(self):
with self._lock:
return (f"TrackingController(model={self.model_id}, "
f"tracker={self.tracker_type}, "
f"frame={self._frame_count}, "
f"tracks={len(self._tracks)})")