feat: custom bot-sort based tracker
This commit is contained in:
parent
bd201acac1
commit
791f611f7d
8 changed files with 649 additions and 282 deletions
|
@ -60,6 +60,8 @@ class YOLOWrapper:
|
|||
|
||||
self.model = None
|
||||
self._class_names = []
|
||||
|
||||
|
||||
self._load_model()
|
||||
|
||||
logger.info(f"Initialized YOLO wrapper for {model_id} on {self.device}")
|
||||
|
@ -115,6 +117,7 @@ class YOLOWrapper:
|
|||
logger.error(f"Failed to extract class names: {str(e)}")
|
||||
self._class_names = {}
|
||||
|
||||
|
||||
def infer(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
|
@ -222,55 +225,30 @@ class YOLOWrapper:
|
|||
|
||||
return detections
|
||||
|
||||
|
||||
def track(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
confidence_threshold: float = 0.5,
|
||||
trigger_classes: Optional[List[str]] = None,
|
||||
persist: bool = True
|
||||
persist: bool = True,
|
||||
camera_id: Optional[str] = None
|
||||
) -> InferenceResult:
|
||||
"""
|
||||
Run tracking on an image
|
||||
Run detection (tracking will be handled by external tracker)
|
||||
|
||||
Args:
|
||||
image: Input image as numpy array (BGR format)
|
||||
confidence_threshold: Minimum confidence for detections
|
||||
trigger_classes: List of class names to filter
|
||||
persist: Whether to persist tracks across frames
|
||||
persist: Ignored - tracking handled externally
|
||||
camera_id: Ignored - tracking handled externally
|
||||
|
||||
Returns:
|
||||
InferenceResult containing detections with track IDs
|
||||
InferenceResult containing detections (no track IDs from YOLO)
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError(f"Model {self.model_id} not loaded")
|
||||
|
||||
try:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Run tracking
|
||||
results = self.model.track(
|
||||
image,
|
||||
conf=confidence_threshold,
|
||||
persist=persist,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
inference_time = time.time() - start_time
|
||||
|
||||
# Parse results
|
||||
detections = self._parse_results(results[0], trigger_classes)
|
||||
|
||||
return InferenceResult(
|
||||
detections=detections,
|
||||
image_shape=(image.shape[0], image.shape[1]),
|
||||
inference_time=inference_time,
|
||||
model_id=self.model_id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tracking failed for model {self.model_id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
# Just do detection - no YOLO tracking
|
||||
return self.infer(image, confidence_threshold, trigger_classes)
|
||||
|
||||
def predict_classification(
|
||||
self,
|
||||
|
@ -350,6 +328,7 @@ class YOLOWrapper:
|
|||
"""Get the number of classes the model can detect"""
|
||||
return len(self._class_names)
|
||||
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the model cache"""
|
||||
with self._cache_lock:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue