feat: tracking
This commit is contained in:
parent
cf24a172a2
commit
bea895d3d8
4 changed files with 1054 additions and 0 deletions
202
services/tracking_factory.py
Normal file
@@ -0,0 +1,202 @@
import threading
from typing import Optional, Dict

from .tracking_controller import TrackingController
from .model_repository import TensorRTModelRepository


class TrackingFactory:
    """
    Factory for creating TrackingController instances with shared GPU resources.

    This factory follows the same pattern as StreamDecoderFactory for consistency:
    - Singleton pattern per GPU
    - Manages shared CUDA state if needed
    - Provides centralized controller creation
    - Thread-safe controller creation

    The factory doesn't need to manage the CUDA context directly since the
    TensorRTModelRepository already handles GPU resource management.
    Instead, it provides a clean interface for creating controllers with
    proper configuration and dependency injection.

    Example:
        # Get factory instance for GPU 0
        factory = TrackingFactory(gpu_id=0)

        # Create model repository
        repo = TensorRTModelRepository(gpu_id=0)
        repo.load_model("detector", "yolov8.trt")

        # Create tracking controller
        controller = factory.create_controller(
            model_repository=repo,
            model_id="detector",
            tracker_type="iou"
        )

        # Multiple controllers share the same model repository
        controller2 = factory.create_controller(
            model_repository=repo,
            model_id="detector",
            tracker_type="iou"
        )
    """

    _instances: Dict[int, 'TrackingFactory'] = {}
    _lock = threading.Lock()

    def __new__(cls, gpu_id: int = 0):
        """
        Singleton pattern per GPU.
        Each GPU gets its own factory instance.
        """
        if gpu_id not in cls._instances:
            with cls._lock:
                if gpu_id not in cls._instances:
                    instance = super(TrackingFactory, cls).__new__(cls)
                    instance._initialized = False
                    cls._instances[gpu_id] = instance
        return cls._instances[gpu_id]

    def __init__(self, gpu_id: int = 0):
        """
        Initialize the tracking factory.

        Args:
            gpu_id: GPU device ID to use
        """
        if self._initialized:
            return

        self.gpu_id = gpu_id
        self._controller_count = 0
        self._controller_lock = threading.Lock()

        self._initialized = True
        print(f"TrackingFactory initialized for GPU {gpu_id}")

    def create_controller(self,
                          model_repository: TensorRTModelRepository,
                          model_id: str,
                          tracker_type: str = "iou",
                          max_age: int = 30,
                          min_confidence: float = 0.5,
                          iou_threshold: float = 0.3,
                          class_names: Optional[Dict[int, str]] = None) -> TrackingController:
        """
        Create a new TrackingController instance.

        Args:
            model_repository: TensorRT model repository (dependency injection)
            model_id: Model ID in repository to use for detection
            tracker_type: Tracking algorithm type ("iou", "sort", "deepsort", "bytetrack")
            max_age: Maximum number of frames to keep a track alive without detections
            min_confidence: Minimum confidence threshold for detections
            iou_threshold: IoU threshold for track association
            class_names: Optional mapping of class IDs to names

        Returns:
            TrackingController instance

        Raises:
            ValueError: If model_repository GPU doesn't match factory GPU
            ValueError: If model_id not found in repository
        """
        # Validate GPU ID matches
        if model_repository.gpu_id != self.gpu_id:
            raise ValueError(
                f"Model repository GPU ({model_repository.gpu_id}) doesn't match "
                f"factory GPU ({self.gpu_id})"
            )

        # Verify model exists
        if model_repository.get_metadata(model_id) is None:
            raise ValueError(
                f"Model '{model_id}' not found in repository. "
                f"Available models: {list(model_repository._model_to_hash.keys())}"
            )

        with self._controller_lock:
            self._controller_count += 1

        controller = TrackingController(
            model_repository=model_repository,
            model_id=model_id,
            gpu_id=self.gpu_id,
            tracker_type=tracker_type,
            max_age=max_age,
            min_confidence=min_confidence,
            iou_threshold=iou_threshold,
            class_names=class_names
        )

        print(f"Created TrackingController #{self._controller_count} (model: {model_id})")

        return controller

    def create_multi_model_controller(self,
                                      model_repository: TensorRTModelRepository,
                                      model_configs: Dict[str, Dict],
                                      ensemble_strategy: str = "nms") -> 'MultiModelTrackingController':
        """
        Create a multi-model tracking controller that combines multiple detectors.

        Args:
            model_repository: TensorRT model repository
            model_configs: Dict mapping model_id to config dict with keys:
                - tracker_type, max_age, min_confidence, iou_threshold, class_names
            ensemble_strategy: How to combine detections ("nms", "vote", "union")

        Returns:
            MultiModelTrackingController instance

        Note:
            This is a placeholder for future multi-model support.
            Currently raises NotImplementedError.
        """
        raise NotImplementedError(
            "Multi-model tracking controller not yet implemented. "
            "Use create_controller for single-model tracking."
        )

    def get_stats(self) -> Dict:
        """
        Get factory statistics.

        Returns:
            Dictionary with factory stats
        """
        with self._controller_lock:
            return {
                'gpu_id': self.gpu_id,
                'controllers_created': self._controller_count,
            }

    @classmethod
    def get_factory(cls, gpu_id: int = 0) -> 'TrackingFactory':
        """
        Get or create factory instance for specified GPU.

        Args:
            gpu_id: GPU device ID

        Returns:
            TrackingFactory instance for the GPU
        """
        return cls(gpu_id=gpu_id)

    @classmethod
    def list_factories(cls) -> Dict[int, 'TrackingFactory']:
        """
        List all factory instances.

        Returns:
            Dictionary mapping gpu_id to factory instance
        """
        with cls._lock:
            return cls._instances.copy()

    def __repr__(self):
        with self._controller_lock:
            return (f"TrackingFactory(gpu={self.gpu_id}, "
                    f"controllers_created={self._controller_count})")
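For reference, a minimal usage sketch built from the docstring example above. It assumes the services package added in this commit is importable as a package, and it reuses the yolov8.trt engine name from the docstring; both are illustrative only and not part of the commit. The multi-model path is omitted since create_multi_model_controller currently raises NotImplementedError.

from services.model_repository import TensorRTModelRepository
from services.tracking_factory import TrackingFactory

# The factory is a per-GPU singleton: both handles refer to the same instance.
factory = TrackingFactory(gpu_id=0)
assert factory is TrackingFactory.get_factory(gpu_id=0)

# One repository can back several controllers (dependency injection).
repo = TensorRTModelRepository(gpu_id=0)
repo.load_model("detector", "yolov8.trt")  # engine file name taken from the docstring

controller = factory.create_controller(
    model_repository=repo,
    model_id="detector",
    tracker_type="iou",
    min_confidence=0.5,
)

print(factory.get_stats())               # {'gpu_id': 0, 'controllers_created': 1}
print(TrackingFactory.list_factories())  # {0: TrackingFactory(gpu=0, controllers_created=1)}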