import threading
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, field
from collections import defaultdict, deque
import time
import torch
import numpy as np

from .model_repository import TensorRTModelRepository


@dataclass
class TrackedObject:
    """
    Represents a tracked object with persistent ID and metadata.

    Attributes:
        track_id: Unique persistent tracking ID
        class_id: Object class ID from detection model
        class_name: Object class name (if available)
        confidence: Detection confidence score (0-1)
        bbox: Bounding box in format [x1, y1, x2, y2] (normalized or absolute)
        last_seen_frame: Frame number when object was last detected
        first_seen_frame: Frame number when object was first detected
        track_history: Deque of historical bboxes for trajectory tracking
        state: Custom state dict for additional tracking data
    """
    track_id: int
    class_id: int
    class_name: str = "unknown"
    confidence: float = 0.0
    bbox: List[float] = field(default_factory=list)
    last_seen_frame: int = 0
    first_seen_frame: int = 0
    track_history: deque = field(default_factory=lambda: deque(maxlen=30))
    state: Dict[str, Any] = field(default_factory=dict)

    def update(self, bbox: List[float], confidence: float, frame_num: int):
        """Update tracked object with new detection"""
        self.bbox = bbox
        self.confidence = confidence
        self.last_seen_frame = frame_num
        self.track_history.append((frame_num, bbox, confidence))

    def age(self, current_frame: int) -> int:
        """Get age of track in frames since last seen"""
        return current_frame - self.last_seen_frame

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization"""
        return {
            'track_id': self.track_id,
            'class_id': self.class_id,
            'class_name': self.class_name,
            'confidence': self.confidence,
            'bbox': self.bbox,
            'last_seen_frame': self.last_seen_frame,
            'first_seen_frame': self.first_seen_frame,
            'age': self.last_seen_frame - self.first_seen_frame,
            'history_length': len(self.track_history),
            'state': self.state
        }


class TrackingController:
    """
    GPU-accelerated object tracking controller that wraps TensorRTModelRepository.
    Architecture:
        - Wraps model repository for dependency injection
        - Maintains CUDA state for bbox tracking operations
        - Stores persistent tracking data (track IDs, histories, states)
        - Processes GPU tensor frames directly (zero-copy pipeline)
        - Thread-safe for concurrent tracking operations

    Tracking Flow:
        GPU Frame → Model Inference (GPU) → Detections (GPU)
            ↓
        Tracking Algorithm (GPU/CPU) → Track Assignment
            ↓
        Update Persistent Tracks → Return Tracked Objects

    Features:
        - GPU-first: All tensor operations stay on GPU until final results
        - Persistent IDs: Tracks maintain consistent IDs across frames
        - Track History: Maintains trajectory history for each object
        - Configurable: Custom preprocessing/postprocessing via callbacks
        - Thread-safe: Mutex-based locking for concurrent access

    Example:
        # Initialize with DI
        repo = TensorRTModelRepository(gpu_id=0)
        factory = TrackingFactory(gpu_id=0)
        controller = factory.create_controller(
            model_repository=repo,
            model_id="yolov8_detector",
            tracker_type="iou"
        )

        # Track objects in frame
        rgb_frame = decoder.get_latest_frame()  # GPU tensor
        tracked_objects = controller.track(rgb_frame)

        # Get all tracked objects
        all_tracks = controller.get_all_tracks()
    """

    def __init__(self,
                 model_repository: TensorRTModelRepository,
                 model_id: str,
                 gpu_id: int = 0,
                 tracker_type: str = "iou",
                 max_age: int = 30,
                 min_confidence: float = 0.5,
                 iou_threshold: float = 0.3,
                 class_names: Optional[Dict[int, str]] = None):
        """
        Initialize TrackingController.

        Args:
            model_repository: TensorRT model repository (dependency injection)
            model_id: Model ID in repository to use for detection
            gpu_id: GPU device ID
            tracker_type: Tracking algorithm type. Currently only "iou" is
                implemented; "sort", "deepsort", and "bytetrack" are reserved
                for future trackers.
            max_age: Maximum frames to keep a track alive without a detection
            min_confidence: Minimum confidence threshold for detections
            iou_threshold: IoU threshold for track association
            class_names: Optional mapping of class IDs to names
        """
        self.model_repository = model_repository
        self.model_id = model_id
        self.gpu_id = gpu_id
        self.device = torch.device(f'cuda:{gpu_id}')

        self.tracker_type = tracker_type
        self.max_age = max_age
        self.min_confidence = min_confidence
        self.iou_threshold = iou_threshold
        self.class_names = class_names or {}

        # Tracking state
        self._tracks: Dict[int, TrackedObject] = {}
        self._next_track_id: int = 0
        self._frame_count: int = 0
        self._lock = threading.RLock()

        # Statistics
        self._total_detections = 0
        self._total_tracks_created = 0

        # Verify model exists in repository
        metadata = self.model_repository.get_metadata(model_id)
        if metadata is None:
            raise ValueError(f"Model '{model_id}' not found in repository")

        print("TrackingController initialized:")
        print(f"  Model ID: {model_id}")
        print(f"  GPU: {gpu_id}")
        print(f"  Tracker: {tracker_type}")
        print(f"  Max age: {max_age} frames")
        print(f"  Min confidence: {min_confidence}")
        print(f"  IoU threshold: {iou_threshold}")

    def _compute_iou_gpu(self, boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """
        Compute IoU between two sets of boxes on GPU.
        Args:
            boxes1: Tensor of shape (N, 4) in format [x1, y1, x2, y2]
            boxes2: Tensor of shape (M, 4) in format [x1, y1, x2, y2]

        Returns:
            IoU matrix of shape (N, M)
        """
        # Ensure on GPU
        boxes1 = boxes1.to(self.device)
        boxes2 = boxes2.to(self.device)

        # Compute intersection
        x1_max = torch.max(boxes1[:, None, 0], boxes2[:, 0])  # (N, M)
        y1_max = torch.max(boxes1[:, None, 1], boxes2[:, 1])  # (N, M)
        x2_min = torch.min(boxes1[:, None, 2], boxes2[:, 2])  # (N, M)
        y2_min = torch.min(boxes1[:, None, 3], boxes2[:, 3])  # (N, M)

        intersection_width = torch.clamp(x2_min - x1_max, min=0)
        intersection_height = torch.clamp(y2_min - y1_max, min=0)
        intersection_area = intersection_width * intersection_height

        # Compute areas
        boxes1_area = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        boxes2_area = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        # Compute union
        union_area = boxes1_area[:, None] + boxes2_area - intersection_area

        # Compute IoU (epsilon avoids division by zero for degenerate boxes)
        iou = intersection_area / (union_area + 1e-6)

        return iou

    def _iou_tracking(self, detections: torch.Tensor) -> List[Tuple[int, int]]:
        """
        Simple IoU-based tracking algorithm (IoU matrix computed on GPU).

        Args:
            detections: Tensor of shape (N, 6) with [x1, y1, x2, y2, conf, class_id]

        Returns:
            List of (detection_idx, track_id) associations; track_id == -1
            means the detection starts a new track.
        """
        if len(self._tracks) == 0:
            # No existing tracks: every detection starts a new track
            return [(det_idx, -1) for det_idx in range(len(detections))]

        # Get existing track bboxes
        track_ids = list(self._tracks.keys())
        track_bboxes = torch.tensor(
            [self._tracks[tid].bbox for tid in track_ids],
            dtype=torch.float32,
            device=self.device
        )

        # Extract detection bboxes
        det_bboxes = detections[:, :4]  # (N, 4)

        # Compute IoU matrix (GPU)
        iou_matrix = self._compute_iou_gpu(det_bboxes, track_bboxes)  # (N, M)

        # Greedy matching: assign each detection to best matching track
        associations = []
        matched_tracks = set()

        # Convert to CPU for matching logic (small matrix)
        iou_cpu = iou_matrix.cpu().numpy()

        for det_idx in range(len(detections)):
            best_iou = self.iou_threshold
            best_track_idx = -1

            for track_idx, track_id in enumerate(track_ids):
                if track_idx in matched_tracks:
                    continue

                if iou_cpu[det_idx, track_idx] > best_iou:
                    best_iou = iou_cpu[det_idx, track_idx]
                    best_track_idx = track_idx

            if best_track_idx >= 0:
                associations.append((det_idx, track_ids[best_track_idx]))
                matched_tracks.add(best_track_idx)
            else:
                associations.append((det_idx, -1))  # New track

        return associations

    def _create_track(self, bbox: List[float], confidence: float,
                      class_id: int, frame_num: int) -> TrackedObject:
        """Create a new tracked object"""
        track_id = self._next_track_id
        self._next_track_id += 1
        self._total_tracks_created += 1

        class_name = self.class_names.get(class_id, f"class_{class_id}")

        track = TrackedObject(
            track_id=track_id,
            class_id=class_id,
            class_name=class_name,
            confidence=confidence,
            bbox=bbox,
            last_seen_frame=frame_num,
            first_seen_frame=frame_num
        )
        track.track_history.append((frame_num, bbox, confidence))

        return track

    def _cleanup_stale_tracks(self):
        """Remove tracks that haven't been seen for more than max_age frames"""
        stale_track_ids = [
            tid for tid, track in self._tracks.items()
            if track.age(self._frame_count) > self.max_age
        ]
        for tid in stale_track_ids:
            del self._tracks[tid]

    def track(self,
              frame: torch.Tensor,
              preprocess_fn: Optional[callable] = None,
              postprocess_fn: Optional[callable] = None) -> List[TrackedObject]:
        """
        Track objects in a GPU tensor frame.
        Args:
            frame: RGB frame as GPU tensor, shape (3, H, W) or (1, 3, H, W)
            preprocess_fn: Optional preprocessing function (frame -> model_input)
            postprocess_fn: Optional postprocessing function (model_output -> detections).
                Should return a tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id]
                (see the illustrative sketch at the end of this module)

        Returns:
            List of currently tracked objects
        """
        with self._lock:
            self._frame_count += 1

            # Ensure frame is on the controller's device
            if frame.device != self.device:
                frame = frame.to(self.device)

            # Preprocess frame for model
            if preprocess_fn is not None:
                model_input = preprocess_fn(frame)
            else:
                # Default: add batch dimension if needed
                if frame.dim() == 3:
                    model_input = frame.unsqueeze(0)  # (1, 3, H, W)
                else:
                    model_input = frame

            # Run inference (GPU-to-GPU)
            # Assuming model expects input named "images" or "input"
            metadata = self.model_repository.get_metadata(self.model_id)
            input_name = metadata.input_names[0] if metadata else "images"

            outputs = self.model_repository.infer(
                model_id=self.model_id,
                inputs={input_name: model_input},
                synchronize=True
            )

            # Postprocess model output to get detections
            if postprocess_fn is not None:
                detections = postprocess_fn(outputs)
            else:
                # Default: assume output is already in correct format
                # Get first output tensor
                output_name = list(outputs.keys())[0]
                detections = outputs[output_name]

                # Reshape if needed: (1, N, 6) -> (N, 6)
                if detections.dim() == 3:
                    detections = detections.squeeze(0)

            # Filter by confidence
            if detections.dim() == 2 and detections.shape[1] >= 5:
                conf_mask = detections[:, 4] >= self.min_confidence
                detections = detections[conf_mask]

            self._total_detections += len(detections)

            # Track objects
            if len(detections) == 0:
                # No detections, just cleanup stale tracks
                self._cleanup_stale_tracks()
                return list(self._tracks.values())

            # Run tracking algorithm
            if self.tracker_type == "iou":
                associations = self._iou_tracking(detections)
            else:
                raise NotImplementedError(f"Tracker type '{self.tracker_type}' not implemented")

            # Update tracks based on associations
            for det_idx, track_id in associations:
                detection = detections[det_idx]
                bbox = detection[:4].cpu().tolist()
                confidence = float(detection[4])
                class_id = int(detection[5]) if detection.shape[0] > 5 else 0

                if track_id == -1:
                    # Create new track
                    new_track = self._create_track(bbox, confidence, class_id, self._frame_count)
                    self._tracks[new_track.track_id] = new_track
                else:
                    # Update existing track
                    self._tracks[track_id].update(bbox, confidence, self._frame_count)

            # Cleanup stale tracks
            self._cleanup_stale_tracks()

            return list(self._tracks.values())

    def get_all_tracks(self, active_only: bool = True) -> List[TrackedObject]:
        """
        Get all tracked objects.

        Args:
            active_only: If True, only return tracks seen within the last max_age frames

        Returns:
            List of tracked objects
        """
        with self._lock:
            if active_only:
                return [
                    track for track in self._tracks.values()
                    if track.age(self._frame_count) <= self.max_age
                ]
            else:
                return list(self._tracks.values())

    def get_track_by_id(self, track_id: int) -> Optional[TrackedObject]:
        """
        Get a specific track by ID.

        Args:
            track_id: Track ID to retrieve

        Returns:
            TrackedObject or None if not found
        """
        with self._lock:
            return self._tracks.get(track_id)

    def get_tracks_by_class(self, class_id: int, active_only: bool = True) -> List[TrackedObject]:
        """
        Get all tracks of a specific class.
        Args:
            class_id: Class ID to filter by
            active_only: If True, only return active tracks

        Returns:
            List of tracked objects
        """
        all_tracks = self.get_all_tracks(active_only=active_only)
        return [track for track in all_tracks if track.class_id == class_id]

    def get_track_count(self, active_only: bool = True) -> int:
        """
        Get number of tracked objects.

        Args:
            active_only: If True, only count active tracks

        Returns:
            Number of tracks
        """
        return len(self.get_all_tracks(active_only=active_only))

    def get_class_counts(self, active_only: bool = True) -> Dict[int, int]:
        """
        Get count of tracked objects per class.

        Args:
            active_only: If True, only count active tracks

        Returns:
            Dictionary mapping class_id to count
        """
        tracks = self.get_all_tracks(active_only=active_only)
        counts = defaultdict(int)
        for track in tracks:
            counts[track.class_id] += 1
        return dict(counts)

    def reset_tracks(self):
        """Clear all tracking state"""
        with self._lock:
            self._tracks.clear()
            self._next_track_id = 0
            self._frame_count = 0
            print("Tracking state reset")

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get tracking statistics.

        Returns:
            Dictionary with tracking stats
        """
        with self._lock:
            return {
                'frame_count': self._frame_count,
                'active_tracks': len(self._tracks),
                'total_tracks_created': self._total_tracks_created,
                'total_detections': self._total_detections,
                'avg_detections_per_frame': self._total_detections / max(self._frame_count, 1),
                'model_id': self.model_id,
                'tracker_type': self.tracker_type,
                'class_counts': self.get_class_counts(active_only=True)
            }

    def export_tracks(self, format: str = "dict") -> Any:
        """
        Export all tracks in specified format.

        Args:
            format: Export format ("dict", "json", "numpy")

        Returns:
            Tracks in specified format
        """
        with self._lock:
            tracks = self.get_all_tracks(active_only=False)

            if format == "dict":
                return {track.track_id: track.to_dict() for track in tracks}

            elif format == "json":
                import json
                return json.dumps(
                    {track.track_id: track.to_dict() for track in tracks},
                    indent=2
                )

            elif format == "numpy":
                # Export as numpy array: [track_id, class_id, x1, y1, x2, y2, conf]
                data = []
                for track in tracks:
                    data.append([
                        track.track_id,
                        track.class_id,
                        *track.bbox,
                        track.confidence
                    ])
                return np.array(data) if data else np.array([])

            else:
                raise ValueError(f"Unknown export format: {format}")

    def __repr__(self):
        with self._lock:
            return (f"TrackingController(model={self.model_id}, "
                    f"tracker={self.tracker_type}, "
                    f"frame={self._frame_count}, "
                    f"tracks={len(self._tracks)})")
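

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the controller API).
#
# `example_center_to_corner_postprocess` is a hypothetical example of the
# `postprocess_fn` callback accepted by TrackingController.track(). It assumes
# the detection model emits a single output tensor of shape (1, N, 6) with
# rows [cx, cy, w, h, conf, class_id]; a raw YOLO-style engine typically needs
# a fuller decode (confidence/NMS handling), so adapt it to your engine.
# The "__main__" block is a minimal smoke test under the same assumptions:
# an engine already registered in the repository under the hypothetical ID
# "yolov8_detector", and frames arriving as CUDA tensors.
# ---------------------------------------------------------------------------
def example_center_to_corner_postprocess(outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Convert a (1, N, 6) [cx, cy, w, h, conf, cls] output to (N, 6) corner boxes."""
    detections = next(iter(outputs.values()))   # first (and assumed only) output tensor
    if detections.dim() == 3:
        detections = detections.squeeze(0)      # (1, N, 6) -> (N, 6)

    cx, cy = detections[:, 0], detections[:, 1]
    w, h = detections[:, 2], detections[:, 3]

    corners = detections.clone()
    corners[:, 0] = cx - w / 2                  # x1
    corners[:, 1] = cy - h / 2                  # y1
    corners[:, 2] = cx + w / 2                  # x2
    corners[:, 3] = cy + h / 2                  # y2
    return corners                              # (N, 6): [x1, y1, x2, y2, conf, class_id]


if __name__ == "__main__":
    # Assumes a TensorRT engine is already loaded in the repository under
    # "yolov8_detector"; otherwise TrackingController.__init__ raises ValueError.
    repo = TensorRTModelRepository(gpu_id=0)

    controller = TrackingController(
        model_repository=repo,
        model_id="yolov8_detector",
        tracker_type="iou",
        class_names={0: "person", 2: "car"},
    )

    # Placeholder frame; in production this would be a decoded video frame
    # already resident on the GPU (zero-copy pipeline).
    frame = torch.rand(3, 640, 640, device=controller.device)

    tracked = controller.track(frame, postprocess_fn=example_center_to_corner_postprocess)
    for obj in tracked:
        print(obj.to_dict())

    print(controller.get_statistics())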