feat: custom bot-sort based tracker

This commit is contained in:
ziesorx 2025-09-26 14:22:38 +07:00
parent bd201acac1
commit 791f611f7d
8 changed files with 649 additions and 282 deletions

View file

@ -60,6 +60,8 @@ class YOLOWrapper:
self.model = None
self._class_names = []
self._load_model()
logger.info(f"Initialized YOLO wrapper for {model_id} on {self.device}")
@ -115,6 +117,7 @@ class YOLOWrapper:
logger.error(f"Failed to extract class names: {str(e)}")
self._class_names = {}
def infer(
self,
image: np.ndarray,
@ -222,55 +225,30 @@ class YOLOWrapper:
return detections
def track(
self,
image: np.ndarray,
confidence_threshold: float = 0.5,
trigger_classes: Optional[List[str]] = None,
persist: bool = True
persist: bool = True,
camera_id: Optional[str] = None
) -> InferenceResult:
"""
Run tracking on an image
Run detection (tracking will be handled by external tracker)
Args:
image: Input image as numpy array (BGR format)
confidence_threshold: Minimum confidence for detections
trigger_classes: List of class names to filter
persist: Whether to persist tracks across frames
persist: Ignored - tracking handled externally
camera_id: Ignored - tracking handled externally
Returns:
InferenceResult containing detections with track IDs
InferenceResult containing detections (no track IDs from YOLO)
"""
if self.model is None:
raise RuntimeError(f"Model {self.model_id} not loaded")
try:
import time
start_time = time.time()
# Run tracking
results = self.model.track(
image,
conf=confidence_threshold,
persist=persist,
verbose=False
)
inference_time = time.time() - start_time
# Parse results
detections = self._parse_results(results[0], trigger_classes)
return InferenceResult(
detections=detections,
image_shape=(image.shape[0], image.shape[1]),
inference_time=inference_time,
model_id=self.model_id
)
except Exception as e:
logger.error(f"Tracking failed for model {self.model_id}: {str(e)}", exc_info=True)
raise
# Just do detection - no YOLO tracking
return self.infer(image, confidence_threshold, trigger_classes)
def predict_classification(
self,
@ -350,6 +328,7 @@ class YOLOWrapper:
"""Get the number of classes the model can detect"""
return len(self._class_names)
def clear_cache(self) -> None:
"""Clear the model cache"""
with self._cache_lock: