386 lines
No EOL
15 KiB
Python
386 lines
No EOL
15 KiB
Python
"""
Unit tests for YOLO detector with tracking functionality.
"""
|
|
import pytest
|
|
import numpy as np
|
|
from unittest.mock import Mock, MagicMock, patch
|
|
import torch
|
|
|
|
from detector_worker.detection.yolo_detector import YOLODetector
|
|
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
|
from detector_worker.core.exceptions import DetectionError
|
|
|
|
|
|
class TestYOLODetector:
    """Test YOLO detection and tracking functionality.

    Every test drives ``YOLODetector`` with a mocked model. The
    ``mock_yolo_model`` and ``mock_frame`` arguments are pytest fixtures
    supplied from outside this class.
    """

    @staticmethod
    def _make_result(rows, track_ids=None):
        """Build a mock model result carrying the given boxes.

        Args:
            rows: Box rows laid out as ``x1, y1, x2, y2, conf, class``; either
                a nested list or an already-built tensor (used as-is).
            track_ids: Tracking IDs, one per row, or ``None`` to leave
                ``boxes.id`` unset (no tracking IDs).

        Returns:
            A ``Mock`` exposing ``boxes.data`` and ``boxes.id``.
        """
        result = Mock()
        result.boxes = Mock()
        result.boxes.data = rows if isinstance(rows, torch.Tensor) else torch.tensor(rows)
        result.boxes.id = None if track_ids is None else torch.tensor(track_ids)
        return result

    def test_initialization_with_valid_model(self, mock_yolo_model):
        """Test detector initialization with valid model."""
        detector = YOLODetector(mock_yolo_model)

        assert detector.model is mock_yolo_model
        assert detector.class_names == {}
        assert detector.is_tracking_enabled is True

    def test_initialization_with_class_names(self, mock_yolo_model):
        """Test detector initialization with class names."""
        class_names = {0: "car", 1: "truck", 2: "bus"}
        detector = YOLODetector(mock_yolo_model, class_names=class_names)

        assert detector.class_names == class_names

    def test_initialization_tracking_disabled(self, mock_yolo_model):
        """Test detector initialization with tracking disabled."""
        detector = YOLODetector(mock_yolo_model, enable_tracking=False)

        assert detector.is_tracking_enabled is False

    def test_detect_with_tracking(self, mock_yolo_model, mock_frame):
        """Test detection with tracking enabled."""
        mock_yolo_model.track.return_value = [
            self._make_result(
                [
                    [100, 200, 300, 400, 0.9, 0],  # x1, y1, x2, y2, conf, class
                    [150, 250, 350, 450, 0.85, 1],
                ],
                track_ids=[1001, 1002],
            )
        ]

        detector = YOLODetector(mock_yolo_model)
        detections = detector.detect(mock_frame)

        assert len(detections) == 2
        # Confidence originates from a torch tensor, so compare approximately
        # rather than bit-for-bit against the Python literal.
        assert detections[0].confidence == pytest.approx(0.9)
        assert detections[0].track_id == 1001
        assert detections[0].bbox.x1 == 100

        mock_yolo_model.track.assert_called_once_with(mock_frame, persist=True, verbose=False)

    def test_detect_without_tracking(self, mock_yolo_model, mock_frame):
        """Test detection with tracking disabled."""
        # track_ids left as None: no tracking IDs on the result.
        mock_yolo_model.predict.return_value = [
            self._make_result([[100, 200, 300, 400, 0.9, 0]])
        ]

        detector = YOLODetector(mock_yolo_model, enable_tracking=False)
        detections = detector.detect(mock_frame)

        assert len(detections) == 1
        assert detections[0].track_id is None  # No tracking ID

        mock_yolo_model.predict.assert_called_once_with(mock_frame, verbose=False)

    def test_detect_with_class_names(self, mock_yolo_model, mock_frame):
        """Test detection with class name mapping."""
        class_names = {0: "car", 1: "truck"}

        mock_yolo_model.track.return_value = [
            self._make_result(
                [
                    [100, 200, 300, 400, 0.9, 0],  # car
                    [150, 250, 350, 450, 0.85, 1],  # truck
                ],
                track_ids=[1001, 1002],
            )
        ]

        detector = YOLODetector(mock_yolo_model, class_names=class_names)
        detections = detector.detect(mock_frame)

        # Fail with a clear message instead of an IndexError below.
        assert len(detections) == 2
        assert detections[0].class_name == "car"
        assert detections[1].class_name == "truck"

    def test_detect_no_boxes(self, mock_yolo_model, mock_frame):
        """Test detection when no objects are detected."""
        mock_result = Mock()
        mock_result.boxes = None

        mock_yolo_model.track.return_value = [mock_result]

        detector = YOLODetector(mock_yolo_model)
        detections = detector.detect(mock_frame)

        assert detections == []

    def test_detect_empty_boxes(self, mock_yolo_model, mock_frame):
        """Test detection with empty boxes tensor."""
        # Pass a pre-shaped (0, 6) tensor; the helper uses tensors as-is.
        empty_boxes = torch.tensor([]).reshape(0, 6)
        mock_yolo_model.track.return_value = [self._make_result(empty_boxes)]

        detector = YOLODetector(mock_yolo_model)
        detections = detector.detect(mock_frame)

        assert detections == []

    def test_detect_with_confidence_threshold(self, mock_yolo_model, mock_frame):
        """Test detection with confidence threshold filtering."""
        mock_yolo_model.track.return_value = [
            self._make_result(
                [
                    [100, 200, 300, 400, 0.9, 0],  # Above threshold
                    [150, 250, 350, 450, 0.3, 1],  # Below threshold
                ],
                track_ids=[1001, 1002],
            )
        ]

        detector = YOLODetector(mock_yolo_model)
        detections = detector.detect(mock_frame, confidence_threshold=0.5)

        assert len(detections) == 1  # Only one above threshold
        # Tensor-derived float: approximate comparison.
        assert detections[0].confidence == pytest.approx(0.9)

    def test_detect_model_error_handling(self, mock_yolo_model, mock_frame):
        """Test error handling when model fails."""
        mock_yolo_model.track.side_effect = Exception("Model inference failed")

        detector = YOLODetector(mock_yolo_model)

        with pytest.raises(DetectionError) as exc_info:
            detector.detect(mock_frame)

        assert "Model inference failed" in str(exc_info.value)

    def test_detect_invalid_frame(self, mock_yolo_model):
        """Test detection with invalid frame input."""
        detector = YOLODetector(mock_yolo_model)

        with pytest.raises(DetectionError) as exc_info:
            detector.detect(None)

        assert "Invalid frame" in str(exc_info.value)

    def test_detect_result_validation(self, mock_yolo_model, mock_frame):
        """Test detection result validation."""
        # Mock result with invalid bounding box (x2 <= x1)
        mock_yolo_model.track.return_value = [
            self._make_result(
                [[300, 200, 100, 400, 0.9, 0]],  # Invalid: x2 < x1
                track_ids=[1001],
            )
        ]

        detector = YOLODetector(mock_yolo_model)
        detections = detector.detect(mock_frame)

        # Invalid detections should be filtered out
        assert detections == []

    def test_get_model_info(self, mock_yolo_model):
        """Test getting model information."""
        mock_yolo_model.device = "cuda:0"
        mock_yolo_model.names = {0: "car", 1: "truck"}

        detector = YOLODetector(mock_yolo_model)
        info = detector.get_model_info()

        assert info["device"] == "cuda:0"
        assert info["class_names"] == {0: "car", 1: "truck"}
        assert info["tracking_enabled"] is True

    def test_set_tracking_enabled(self, mock_yolo_model):
        """Test enabling/disabling tracking at runtime."""
        detector = YOLODetector(mock_yolo_model, enable_tracking=False)
        assert detector.is_tracking_enabled is False

        detector.set_tracking_enabled(True)
        assert detector.is_tracking_enabled is True

        detector.set_tracking_enabled(False)
        assert detector.is_tracking_enabled is False

    def test_update_class_names(self, mock_yolo_model):
        """Test updating class names at runtime."""
        detector = YOLODetector(mock_yolo_model)

        new_class_names = {0: "vehicle", 1: "person"}
        detector.update_class_names(new_class_names)

        assert detector.class_names == new_class_names

    def test_reset_tracker(self, mock_yolo_model):
        """Test resetting the tracking state."""
        detector = YOLODetector(mock_yolo_model)

        # This should not raise an error
        detector.reset_tracker()

    def test_detect_with_crop_region(self, mock_yolo_model, mock_frame):
        """Test detection with crop region specified."""
        mock_yolo_model.track.return_value = [
            self._make_result(
                [[50, 75, 150, 175, 0.9, 0]],  # Relative to cropped region
                track_ids=[1001],
            )
        ]

        detector = YOLODetector(mock_yolo_model)
        crop_region = (100, 200, 300, 400)  # x1, y1, x2, y2
        detections = detector.detect(mock_frame, crop_region=crop_region)

        # Fail with a clear message instead of an IndexError below.
        assert len(detections) == 1

        # Bounding box should be adjusted to global coordinates
        assert detections[0].bbox.x1 == 150  # 100 + 50
        assert detections[0].bbox.y1 == 275  # 200 + 75
        assert detections[0].bbox.x2 == 250  # 100 + 150
        assert detections[0].bbox.y2 == 375  # 200 + 175

    def test_detect_batch_processing(self, mock_yolo_model):
        """Test batch detection processing."""
        frames = [
            np.zeros((480, 640, 3), dtype=np.uint8),
            np.ones((480, 640, 3), dtype=np.uint8) * 255,
        ]

        # One single-result list per expected model call; x1 and the track ID
        # differ per frame so the two outputs can be told apart.
        mock_yolo_model.track.side_effect = [
            [self._make_result([[100 + i * 10, 200, 300, 400, 0.9, 0]], track_ids=[1001 + i])]
            for i in range(2)
        ]

        detector = YOLODetector(mock_yolo_model)
        batch_detections = detector.detect_batch(frames)

        assert len(batch_detections) == 2
        assert len(batch_detections[0]) == 1
        assert len(batch_detections[1]) == 1
        assert batch_detections[0][0].bbox.x1 == 100
        assert batch_detections[1][0].bbox.x1 == 110

    def test_detect_batch_empty_frames(self, mock_yolo_model):
        """Test batch detection with empty frame list."""
        detector = YOLODetector(mock_yolo_model)
        batch_detections = detector.detect_batch([])

        assert batch_detections == []

    def test_detect_performance_metrics(self, mock_yolo_model, mock_frame):
        """Test detection performance metrics collection."""
        mock_result = self._make_result([[100, 200, 300, 400, 0.9, 0]], track_ids=[1001])
        mock_result.speed = {"preprocess": 2.1, "inference": 15.3, "postprocess": 1.2}

        mock_yolo_model.track.return_value = [mock_result]

        detector = YOLODetector(mock_yolo_model)
        # The return value is not inspected; only the recorded metric is.
        detector.detect(mock_frame, return_metrics=True)

        # Check if performance metrics are available
        assert hasattr(detector, '_last_inference_time')

    @pytest.mark.parametrize("device", ["cpu", "cuda:0", "mps"])
    def test_detect_different_devices(self, device, mock_frame):
        """Test detection on different devices."""
        # A fresh Mock is used here instead of the fixture so the device
        # attribute can be set per parameter.
        mock_model = Mock()
        mock_model.device = device

        mock_model.track.return_value = [
            self._make_result([[100, 200, 300, 400, 0.9, 0]], track_ids=[1001])
        ]

        detector = YOLODetector(mock_model)
        detections = detector.detect(mock_frame)

        assert len(detections) == 1
        # Tensor-derived float: approximate comparison.
        assert detections[0].confidence == pytest.approx(0.9)
|
|
|
|
|
|
class TestYOLODetectorIntegration:
    """Integration tests for YOLO detector.

    The model is still a mock; ``mock_yolo_model`` and ``mock_frame`` are
    pytest fixtures provided from outside this class.
    """

    @staticmethod
    def _fake_result(box_rows, ids):
        """Return a Mock whose ``boxes.data`` / ``boxes.id`` are tensors of the given values."""
        fake = Mock()
        fake.boxes = Mock()
        fake.boxes.data = torch.tensor(box_rows)
        fake.boxes.id = torch.tensor(ids)
        return fake

    def test_detect_with_real_tensor_operations(self, mock_yolo_model, mock_frame):
        """Test detection with realistic tensor operations."""
        # Create more realistic box data
        realistic_rows = [
            [100.5, 200.3, 299.7, 399.8, 0.95, 0],
            [150.2, 250.1, 349.9, 449.6, 0.87, 1],
            [200.0, 300.0, 400.0, 500.0, 0.45, 0],  # Low confidence
        ]
        mock_yolo_model.track.return_value = [
            self._fake_result(realistic_rows, [2001, 2002, 2003])
        ]

        detector = YOLODetector(mock_yolo_model, class_names={0: "car", 1: "truck"})
        found = detector.detect(mock_frame, confidence_threshold=0.5)

        # Should filter out low confidence detection
        assert len(found) == 2

        # Check first detection
        first = found[0]
        assert first.class_name == "car"
        assert first.confidence == pytest.approx(0.95)
        assert first.track_id == 2001
        assert first.bbox.x1 == pytest.approx(100.5)
        assert first.bbox.y1 == pytest.approx(200.3)

        # Check second detection
        second = found[1]
        assert second.class_name == "truck"
        assert second.confidence == pytest.approx(0.87)
        assert second.track_id == 2002

    def test_multi_frame_tracking_consistency(self, mock_yolo_model, mock_frame):
        """Test that tracking IDs remain consistent across frames."""
        detector = YOLODetector(mock_yolo_model)

        per_frame_rows = (
            [100, 200, 300, 400, 0.9, 0],  # Frame 1
            [105, 205, 305, 405, 0.88, 0],  # Frame 2 - same object, slightly moved
        )

        outputs = []
        for row in per_frame_rows:
            # The mocked tracker reports the same ID for both frames.
            mock_yolo_model.track.return_value = [self._fake_result([row], [5001])]
            outputs.append(detector.detect(mock_frame))

        frame1_dets, frame2_dets = outputs

        # Should maintain same track ID
        assert frame1_dets[0].track_id == frame2_dets[0].track_id == 5001