205 lines
7 KiB
Python
205 lines
7 KiB
Python
import threading
|
|
from typing import Optional, Dict
|
|
from .tracking_controller import ObjectTracker
|
|
from .model_repository import TensorRTModelRepository
|
|
|
|
# Backward-compatibility alias: code that still refers to ``TrackingController``
# gets ``ObjectTracker``. TrackingFactory builds its controllers through this
# name. Note: TrackingFactory is deprecated in event-driven mode.
TrackingController = ObjectTracker
|
|
|
|
|
|
class TrackingFactory:
    """
    Factory for creating TrackingController instances with shared GPU resources.

    This factory follows the same pattern as StreamDecoderFactory for consistency:
    - Singleton pattern per GPU
    - Holds no CUDA state of its own (GPU resources stay with the repository)
    - Provides centralized controller creation
    - Thread-safe controller fabrication

    The factory doesn't need to manage CUDA context directly since the
    TensorRTModelRepository already handles GPU resource management.
    Instead, it provides a clean interface for creating controllers with
    proper configuration and dependency injection.

    Note: this factory is deprecated in event-driven mode.

    Example:
        # Get factory instance for GPU 0
        factory = TrackingFactory(gpu_id=0)

        # Create model repository
        repo = TensorRTModelRepository(gpu_id=0)
        repo.load_model("detector", "yolov8.trt")

        # Create tracking controller
        controller = factory.create_controller(
            model_repository=repo,
            model_id="detector",
            tracker_type="iou"
        )

        # Multiple controllers share the same model repository
        controller2 = factory.create_controller(
            model_repository=repo,
            model_id="detector",
            tracker_type="iou"
        )
    """

    # One factory per GPU id, shared process-wide.
    _instances: Dict[int, 'TrackingFactory'] = {}
    # Guards ``_instances`` and the one-time initialization in ``__init__``.
    _lock = threading.Lock()

    def __new__(cls, gpu_id: int = 0):
        """
        Return the singleton factory for ``gpu_id``, creating it on first use.

        Each GPU gets its own factory instance. The unlocked lookup is a fast
        path; it is repeated under ``_lock`` so two threads racing on the same
        GPU id cannot both create an instance.
        """
        instance = cls._instances.get(gpu_id)
        if instance is None:
            with cls._lock:
                instance = cls._instances.get(gpu_id)
                if instance is None:
                    instance = super().__new__(cls)
                    # Marks the instance as not yet set up by __init__.
                    instance._initialized = False
                    cls._instances[gpu_id] = instance
        return instance

    def __init__(self, gpu_id: int = 0):
        """
        Initialize the tracking factory (effective only once per instance).

        Python calls ``__init__`` on every ``TrackingFactory(gpu_id)`` call,
        even when ``__new__`` returned an existing instance, so later calls
        must be no-ops. The check-and-set runs under the class lock so two
        threads constructing the same GPU's factory for the first time cannot
        both run the body and replace each other's lock and counter.

        Args:
            gpu_id: GPU device ID to use
        """
        if self._initialized:
            return
        with type(self)._lock:
            if self._initialized:
                return
            self.gpu_id = gpu_id
            self._controller_count = 0
            self._controller_lock = threading.Lock()
            # Set last so the unlocked fast path only sees a fully built object.
            self._initialized = True

        print(f"TrackingFactory initialized for GPU {gpu_id}")

    def create_controller(
        self,
        model_repository: TensorRTModelRepository,
        model_id: str,
        tracker_type: str = "iou",
        max_age: int = 30,
        min_confidence: float = 0.5,
        iou_threshold: float = 0.3,
        class_names: Optional[Dict[int, str]] = None,
    ) -> TrackingController:
        """
        Create a new TrackingController instance.

        Args:
            model_repository: TensorRT model repository (dependency injection)
            model_id: Model ID in repository to use for detection
            tracker_type: Tracking algorithm type ("iou", "sort", "deepsort", "bytetrack")
            max_age: Maximum frames to keep track without detection
            min_confidence: Minimum confidence threshold for detections
            iou_threshold: IoU threshold for track association
            class_names: Optional mapping of class IDs to names

        Returns:
            TrackingController instance (``TrackingController`` is an alias of
            ``ObjectTracker``)

        Raises:
            ValueError: If model_repository GPU doesn't match factory GPU
            ValueError: If model_id not found in repository
            Any exception raised by the controller constructor propagates
            unchanged; such failed attempts are not counted in the stats.
        """
        # Validate GPU ID matches
        if model_repository.gpu_id != self.gpu_id:
            raise ValueError(
                f"Model repository GPU ({model_repository.gpu_id}) doesn't match "
                f"factory GPU ({self.gpu_id})"
            )

        # Verify model exists
        if model_repository.get_metadata(model_id) is None:
            # ``_model_to_hash`` is private to the repository; read it defensively
            # so a missing attribute cannot replace this ValueError with an
            # AttributeError.
            available = list(getattr(model_repository, "_model_to_hash", {}).keys())
            raise ValueError(
                f"Model '{model_id}' not found in repository. "
                f"Available models: {available}"
            )

        controller = TrackingController(
            model_repository=model_repository,
            model_id=model_id,
            gpu_id=self.gpu_id,
            tracker_type=tracker_type,
            max_age=max_age,
            min_confidence=min_confidence,
            iou_threshold=iou_threshold,
            class_names=class_names,
        )

        # Count only after construction succeeded. Capture the ordinal while
        # holding the lock so the log line reports this controller's own
        # number even if another thread creates one concurrently.
        with self._controller_lock:
            self._controller_count += 1
            controller_number = self._controller_count

        print(f"Created TrackingController #{controller_number} (model: {model_id})")

        return controller

    def create_multi_model_controller(
        self,
        model_repository: TensorRTModelRepository,
        model_configs: Dict[str, Dict],
        ensemble_strategy: str = "nms",
    ) -> 'MultiModelTrackingController':
        """
        Create a multi-model tracking controller that combines multiple detectors.

        Args:
            model_repository: TensorRT model repository
            model_configs: Dict mapping model_id to config dict with keys:
                - tracker_type, max_age, min_confidence, iou_threshold, class_names
            ensemble_strategy: How to combine detections ("nms", "vote", "union")

        Returns:
            MultiModelTrackingController instance (string forward reference;
            presumably a planned class, since this method is a placeholder)

        Note:
            This is a placeholder for future multi-model support.
            Currently raises NotImplementedError.
        """
        raise NotImplementedError(
            "Multi-model tracking controller not yet implemented. "
            "Use create_controller for single-model tracking."
        )

    def get_stats(self) -> Dict:
        """
        Get factory statistics.

        Returns:
            Dictionary with ``gpu_id`` and ``controllers_created`` (number of
            controllers successfully created by this factory)
        """
        with self._controller_lock:
            return {
                'gpu_id': self.gpu_id,
                'controllers_created': self._controller_count,
            }

    @classmethod
    def get_factory(cls, gpu_id: int = 0) -> 'TrackingFactory':
        """
        Get or create factory instance for specified GPU.

        Args:
            gpu_id: GPU device ID

        Returns:
            TrackingFactory instance for the GPU
        """
        return cls(gpu_id=gpu_id)

    @classmethod
    def list_factories(cls) -> Dict[int, 'TrackingFactory']:
        """
        List all factory instances.

        Returns:
            Shallow copy of the mapping from gpu_id to factory instance, so
            callers cannot mutate the registry
        """
        with cls._lock:
            return cls._instances.copy()

    def __repr__(self):
        with self._controller_lock:
            return (f"TrackingFactory(gpu={self.gpu_id}, "
                    f"controllers_created={self._controller_count})")
|