Refactor: PHASE 8: Testing & Integration
This commit is contained in:
parent
af34f4fd08
commit
9e8c6804a7
32 changed files with 17128 additions and 0 deletions
479
tests/unit/detection/test_detection_result.py
Normal file
479
tests/unit/detection/test_detection_result.py
Normal file
|
@ -0,0 +1,479 @@
|
|||
"""
|
||||
Unit tests for detection result data structures.
|
||||
"""
|
||||
import pytest
|
||||
from dataclasses import asdict
|
||||
import numpy as np
|
||||
|
||||
from detector_worker.detection.detection_result import (
|
||||
BoundingBox,
|
||||
DetectionResult,
|
||||
LightweightDetectionResult,
|
||||
DetectionSession,
|
||||
TrackValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestBoundingBox:
|
||||
"""Test BoundingBox data structure."""
|
||||
|
||||
def test_creation_from_coordinates(self):
|
||||
"""Test creating bounding box from coordinates."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
assert bbox.x1 == 100
|
||||
assert bbox.y1 == 200
|
||||
assert bbox.x2 == 300
|
||||
assert bbox.y2 == 400
|
||||
|
||||
def test_creation_from_list(self):
|
||||
"""Test creating bounding box from list."""
|
||||
coords = [100, 200, 300, 400]
|
||||
bbox = BoundingBox.from_list(coords)
|
||||
|
||||
assert bbox.x1 == 100
|
||||
assert bbox.y1 == 200
|
||||
assert bbox.x2 == 300
|
||||
assert bbox.y2 == 400
|
||||
|
||||
def test_creation_from_invalid_list(self):
|
||||
"""Test error handling for invalid list."""
|
||||
with pytest.raises(ValueError):
|
||||
BoundingBox.from_list([100, 200, 300]) # Too few elements
|
||||
|
||||
def test_to_list(self):
|
||||
"""Test converting bounding box to list."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
coords = bbox.to_list()
|
||||
|
||||
assert coords == [100, 200, 300, 400]
|
||||
|
||||
def test_area_calculation(self):
|
||||
"""Test area calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
area = bbox.area()
|
||||
|
||||
expected_area = (300 - 100) * (400 - 200) # 200 * 200 = 40000
|
||||
assert area == expected_area
|
||||
|
||||
def test_area_zero_for_invalid_bbox(self):
|
||||
"""Test area is zero for invalid bounding box."""
|
||||
# x2 <= x1
|
||||
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
|
||||
assert bbox.area() == 0
|
||||
|
||||
# y2 <= y1
|
||||
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
|
||||
assert bbox.area() == 0
|
||||
|
||||
def test_width_height(self):
|
||||
"""Test width and height properties."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
assert bbox.width() == 200
|
||||
assert bbox.height() == 200
|
||||
|
||||
def test_center_point(self):
|
||||
"""Test center point calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
center = bbox.center()
|
||||
|
||||
assert center == (200, 300) # (x1+x2)/2, (y1+y2)/2
|
||||
|
||||
def test_is_valid(self):
|
||||
"""Test bounding box validation."""
|
||||
# Valid bbox
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
assert bbox.is_valid() is True
|
||||
|
||||
# Invalid bbox (x2 <= x1)
|
||||
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
|
||||
assert bbox.is_valid() is False
|
||||
|
||||
# Invalid bbox (y2 <= y1)
|
||||
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
|
||||
assert bbox.is_valid() is False
|
||||
|
||||
def test_intersection(self):
|
||||
"""Test bounding box intersection."""
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
|
||||
|
||||
intersection = bbox1.intersection(bbox2)
|
||||
|
||||
assert intersection.x1 == 200
|
||||
assert intersection.y1 == 200
|
||||
assert intersection.x2 == 300
|
||||
assert intersection.y2 == 300
|
||||
|
||||
def test_no_intersection(self):
|
||||
"""Test no intersection between bounding boxes."""
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
|
||||
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
|
||||
|
||||
intersection = bbox1.intersection(bbox2)
|
||||
|
||||
assert intersection.is_valid() is False
|
||||
|
||||
def test_union(self):
|
||||
"""Test bounding box union."""
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
|
||||
|
||||
union = bbox1.union(bbox2)
|
||||
|
||||
assert union.x1 == 100
|
||||
assert union.y1 == 100
|
||||
assert union.x2 == 400
|
||||
assert union.y2 == 400
|
||||
|
||||
def test_iou_calculation(self):
|
||||
"""Test IoU (Intersection over Union) calculation."""
|
||||
# Perfect overlap
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
assert bbox1.iou(bbox2) == 1.0
|
||||
|
||||
# No overlap
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
|
||||
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
|
||||
assert bbox1.iou(bbox2) == 0.0
|
||||
|
||||
# Partial overlap
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
|
||||
|
||||
# Intersection area: 100x100 = 10000
|
||||
# Union area: 200x200 + 200x200 - 10000 = 30000
|
||||
# IoU = 10000/30000 = 1/3
|
||||
expected_iou = 1.0 / 3.0
|
||||
assert abs(bbox1.iou(bbox2) - expected_iou) < 1e-6
|
||||
|
||||
|
||||
class TestDetectionResult:
|
||||
"""Test DetectionResult data structure."""
|
||||
|
||||
def test_creation_with_required_fields(self):
|
||||
"""Test creating detection result with required fields."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
|
||||
assert detection.class_name == "car"
|
||||
assert detection.confidence == 0.85
|
||||
assert detection.bbox == bbox
|
||||
assert detection.track_id == 12345
|
||||
|
||||
def test_creation_with_all_fields(self):
|
||||
"""Test creating detection result with all fields."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345,
|
||||
model_id="yolo_v8",
|
||||
timestamp=1640995200000,
|
||||
branch_results={"brand": "Toyota"}
|
||||
)
|
||||
|
||||
assert detection.model_id == "yolo_v8"
|
||||
assert detection.timestamp == 1640995200000
|
||||
assert detection.branch_results == {"brand": "Toyota"}
|
||||
|
||||
def test_creation_from_dict(self):
|
||||
"""Test creating detection result from dictionary."""
|
||||
data = {
|
||||
"class": "car",
|
||||
"confidence": 0.85,
|
||||
"bbox": [100, 200, 300, 400],
|
||||
"id": 12345,
|
||||
"model_id": "yolo_v8",
|
||||
"timestamp": 1640995200000
|
||||
}
|
||||
|
||||
detection = DetectionResult.from_dict(data)
|
||||
|
||||
assert detection.class_name == "car"
|
||||
assert detection.confidence == 0.85
|
||||
assert detection.bbox.to_list() == [100, 200, 300, 400]
|
||||
assert detection.track_id == 12345
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting detection result to dictionary."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
|
||||
data = detection.to_dict()
|
||||
|
||||
assert data["class"] == "car"
|
||||
assert data["confidence"] == 0.85
|
||||
assert data["bbox"] == [100, 200, 300, 400]
|
||||
assert data["id"] == 12345
|
||||
|
||||
def test_is_valid_detection(self):
|
||||
"""Test detection validation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
# Valid detection
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
assert detection.is_valid() is True
|
||||
|
||||
# Invalid confidence (too low)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=-0.1,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
assert detection.is_valid() is False
|
||||
|
||||
# Invalid confidence (too high)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=1.5,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
assert detection.is_valid() is False
|
||||
|
||||
# Invalid bounding box
|
||||
invalid_bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=invalid_bbox,
|
||||
track_id=12345
|
||||
)
|
||||
assert detection.is_valid() is False
|
||||
|
||||
|
||||
class TestLightweightDetectionResult:
|
||||
"""Test LightweightDetectionResult data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating lightweight detection result."""
|
||||
detection = LightweightDetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox_area=40000,
|
||||
frame_width=1920,
|
||||
frame_height=1080
|
||||
)
|
||||
|
||||
assert detection.class_name == "car"
|
||||
assert detection.confidence == 0.85
|
||||
assert detection.bbox_area == 40000
|
||||
assert detection.frame_width == 1920
|
||||
assert detection.frame_height == 1080
|
||||
|
||||
def test_area_ratio_calculation(self):
|
||||
"""Test bounding box area ratio calculation."""
|
||||
detection = LightweightDetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox_area=40000,
|
||||
frame_width=1920,
|
||||
frame_height=1080
|
||||
)
|
||||
|
||||
expected_ratio = 40000 / (1920 * 1080)
|
||||
assert abs(detection.area_ratio() - expected_ratio) < 1e-6
|
||||
|
||||
def test_meets_threshold(self):
|
||||
"""Test threshold checking."""
|
||||
detection = LightweightDetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox_area=40000,
|
||||
frame_width=1920,
|
||||
frame_height=1080
|
||||
)
|
||||
|
||||
assert detection.meets_threshold(confidence=0.8, area_ratio=0.01) is True
|
||||
assert detection.meets_threshold(confidence=0.9, area_ratio=0.01) is False
|
||||
assert detection.meets_threshold(confidence=0.8, area_ratio=0.1) is False
|
||||
|
||||
|
||||
class TestDetectionSession:
|
||||
"""Test DetectionSession data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating detection session."""
|
||||
session = DetectionSession(
|
||||
session_id="session_123",
|
||||
camera_id="camera_001",
|
||||
display_id="display_001"
|
||||
)
|
||||
|
||||
assert session.session_id == "session_123"
|
||||
assert session.camera_id == "camera_001"
|
||||
assert session.display_id == "display_001"
|
||||
assert session.detections == []
|
||||
assert session.metadata == {}
|
||||
|
||||
def test_add_detection(self):
|
||||
"""Test adding detection to session."""
|
||||
session = DetectionSession(
|
||||
session_id="session_123",
|
||||
camera_id="camera_001",
|
||||
display_id="display_001"
|
||||
)
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
|
||||
session.add_detection(detection)
|
||||
|
||||
assert len(session.detections) == 1
|
||||
assert session.detections[0] == detection
|
||||
|
||||
def test_get_latest_detection(self):
|
||||
"""Test getting latest detection."""
|
||||
session = DetectionSession(
|
||||
session_id="session_123",
|
||||
camera_id="camera_001",
|
||||
display_id="display_001"
|
||||
)
|
||||
|
||||
# Add multiple detections
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox1,
|
||||
track_id=12345,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.90,
|
||||
bbox=bbox2,
|
||||
track_id=12345,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
session.add_detection(detection1)
|
||||
session.add_detection(detection2)
|
||||
|
||||
latest = session.get_latest_detection()
|
||||
assert latest == detection2 # Should be the one with later timestamp
|
||||
|
||||
def test_get_detections_by_class(self):
|
||||
"""Test filtering detections by class."""
|
||||
session = DetectionSession(
|
||||
session_id="session_123",
|
||||
camera_id="camera_001",
|
||||
display_id="display_001"
|
||||
)
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
car_detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
|
||||
truck_detection = DetectionResult(
|
||||
class_name="truck",
|
||||
confidence=0.80,
|
||||
bbox=bbox,
|
||||
track_id=54321
|
||||
)
|
||||
|
||||
session.add_detection(car_detection)
|
||||
session.add_detection(truck_detection)
|
||||
|
||||
car_detections = session.get_detections_by_class("car")
|
||||
assert len(car_detections) == 1
|
||||
assert car_detections[0] == car_detection
|
||||
|
||||
truck_detections = session.get_detections_by_class("truck")
|
||||
assert len(truck_detections) == 1
|
||||
assert truck_detections[0] == truck_detection
|
||||
|
||||
|
||||
class TestTrackValidationResult:
|
||||
"""Test TrackValidationResult data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating track validation result."""
|
||||
result = TrackValidationResult(
|
||||
stable_tracks=[101, 102, 103],
|
||||
current_tracks=[101, 102, 104, 105],
|
||||
newly_stable=[103],
|
||||
lost_tracks=[106]
|
||||
)
|
||||
|
||||
assert result.stable_tracks == [101, 102, 103]
|
||||
assert result.current_tracks == [101, 102, 104, 105]
|
||||
assert result.newly_stable == [103]
|
||||
assert result.lost_tracks == [106]
|
||||
|
||||
def test_has_stable_tracks(self):
|
||||
"""Test checking for stable tracks."""
|
||||
result = TrackValidationResult(
|
||||
stable_tracks=[101, 102],
|
||||
current_tracks=[101, 102, 103]
|
||||
)
|
||||
|
||||
assert result.has_stable_tracks() is True
|
||||
|
||||
result_empty = TrackValidationResult(
|
||||
stable_tracks=[],
|
||||
current_tracks=[101, 102, 103]
|
||||
)
|
||||
|
||||
assert result_empty.has_stable_tracks() is False
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test getting validation statistics."""
|
||||
result = TrackValidationResult(
|
||||
stable_tracks=[101, 102, 103],
|
||||
current_tracks=[101, 102, 104, 105],
|
||||
newly_stable=[103],
|
||||
lost_tracks=[106]
|
||||
)
|
||||
|
||||
stats = result.get_stats()
|
||||
|
||||
assert stats["stable_count"] == 3
|
||||
assert stats["current_count"] == 4
|
||||
assert stats["newly_stable_count"] == 1
|
||||
assert stats["lost_count"] == 1
|
||||
assert stats["stability_ratio"] == 3/4 # stable/current
|
||||
|
||||
def test_is_track_stable(self):
|
||||
"""Test checking if specific track is stable."""
|
||||
result = TrackValidationResult(
|
||||
stable_tracks=[101, 102, 103],
|
||||
current_tracks=[101, 102, 104, 105]
|
||||
)
|
||||
|
||||
assert result.is_track_stable(101) is True
|
||||
assert result.is_track_stable(102) is True
|
||||
assert result.is_track_stable(104) is False
|
||||
assert result.is_track_stable(999) is False
|
701
tests/unit/detection/test_stability_validator.py
Normal file
701
tests/unit/detection/test_stability_validator.py
Normal file
|
@ -0,0 +1,701 @@
|
|||
"""
|
||||
Unit tests for track stability validation.
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import Mock, patch
|
||||
from collections import defaultdict
|
||||
|
||||
from detector_worker.detection.stability_validator import (
|
||||
StabilityValidator,
|
||||
StabilityConfig,
|
||||
ValidationResult,
|
||||
TrackStabilityMetrics
|
||||
)
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox, TrackValidationResult
|
||||
from detector_worker.core.exceptions import ValidationError
|
||||
|
||||
|
||||
class TestStabilityConfig:
|
||||
"""Test stability configuration data structure."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default stability configuration."""
|
||||
config = StabilityConfig()
|
||||
|
||||
assert config.min_detection_frames == 10
|
||||
assert config.max_absence_frames == 30
|
||||
assert config.confidence_threshold == 0.5
|
||||
assert config.stability_window == 60.0
|
||||
assert config.iou_threshold == 0.3
|
||||
assert config.movement_threshold == 50.0
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test custom stability configuration."""
|
||||
config = StabilityConfig(
|
||||
min_detection_frames=5,
|
||||
max_absence_frames=15,
|
||||
confidence_threshold=0.8,
|
||||
stability_window=30.0,
|
||||
iou_threshold=0.5,
|
||||
movement_threshold=25.0
|
||||
)
|
||||
|
||||
assert config.min_detection_frames == 5
|
||||
assert config.max_absence_frames == 15
|
||||
assert config.confidence_threshold == 0.8
|
||||
assert config.stability_window == 30.0
|
||||
assert config.iou_threshold == 0.5
|
||||
assert config.movement_threshold == 25.0
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
config_dict = {
|
||||
"min_detection_frames": 8,
|
||||
"max_absence_frames": 25,
|
||||
"confidence_threshold": 0.75,
|
||||
"unknown_field": "ignored"
|
||||
}
|
||||
|
||||
config = StabilityConfig.from_dict(config_dict)
|
||||
|
||||
assert config.min_detection_frames == 8
|
||||
assert config.max_absence_frames == 25
|
||||
assert config.confidence_threshold == 0.75
|
||||
# Unknown fields should use defaults
|
||||
assert config.stability_window == 60.0
|
||||
|
||||
|
||||
class TestTrackStabilityMetrics:
|
||||
"""Test track stability metrics."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test metrics initialization."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
assert metrics.track_id == 1001
|
||||
assert metrics.detection_count == 0
|
||||
assert metrics.absence_count == 0
|
||||
assert metrics.total_confidence == 0.0
|
||||
assert metrics.first_detection_time is None
|
||||
assert metrics.last_detection_time is None
|
||||
assert metrics.bounding_boxes == []
|
||||
assert metrics.confidence_scores == []
|
||||
|
||||
def test_add_detection(self):
|
||||
"""Test adding detection to metrics."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection, current_time=1640995200.0)
|
||||
|
||||
assert metrics.detection_count == 1
|
||||
assert metrics.absence_count == 0
|
||||
assert metrics.total_confidence == 0.85
|
||||
assert metrics.first_detection_time == 1640995200.0
|
||||
assert metrics.last_detection_time == 1640995200.0
|
||||
assert len(metrics.bounding_boxes) == 1
|
||||
assert len(metrics.confidence_scores) == 1
|
||||
|
||||
def test_increment_absence(self):
|
||||
"""Test incrementing absence count."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
metrics.increment_absence()
|
||||
assert metrics.absence_count == 1
|
||||
|
||||
metrics.increment_absence()
|
||||
assert metrics.absence_count == 2
|
||||
|
||||
def test_reset_absence(self):
|
||||
"""Test resetting absence count."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
metrics.increment_absence()
|
||||
metrics.increment_absence()
|
||||
assert metrics.absence_count == 2
|
||||
|
||||
metrics.reset_absence()
|
||||
assert metrics.absence_count == 0
|
||||
|
||||
def test_average_confidence(self):
|
||||
"""Test average confidence calculation."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
# No detections
|
||||
assert metrics.average_confidence() == 0.0
|
||||
|
||||
# Add detections
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.9,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection1, current_time=1640995200.0)
|
||||
metrics.add_detection(detection2, current_time=1640995300.0)
|
||||
|
||||
assert metrics.average_confidence() == 0.85 # (0.8 + 0.9) / 2
|
||||
|
||||
def test_tracking_duration(self):
|
||||
"""Test tracking duration calculation."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
# No detections
|
||||
assert metrics.tracking_duration() == 0.0
|
||||
|
||||
# Add detections
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.9,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection1, current_time=1640995200.0)
|
||||
metrics.add_detection(detection2, current_time=1640995300.0)
|
||||
|
||||
assert metrics.tracking_duration() == 100.0 # 1640995300 - 1640995200
|
||||
|
||||
def test_movement_distance(self):
|
||||
"""Test movement distance calculation."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
# No movement with single detection
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox1,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection1, current_time=1640995200.0)
|
||||
assert metrics.total_movement_distance() == 0.0
|
||||
|
||||
# Add second detection with movement
|
||||
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.9,
|
||||
bbox=bbox2,
|
||||
track_id=1001,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection2, current_time=1640995300.0)
|
||||
|
||||
# Distance between centers: (200,300) to (210,310) = sqrt(100+100) ≈ 14.14
|
||||
movement = metrics.total_movement_distance()
|
||||
assert movement == pytest.approx(14.14, rel=1e-2)
|
||||
|
||||
|
||||
class TestValidationResult:
|
||||
"""Test validation result data structure."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test validation result initialization."""
|
||||
result = ValidationResult(
|
||||
track_id=1001,
|
||||
is_stable=True,
|
||||
detection_count=15,
|
||||
absence_count=2,
|
||||
average_confidence=0.85,
|
||||
tracking_duration=120.0
|
||||
)
|
||||
|
||||
assert result.track_id == 1001
|
||||
assert result.is_stable is True
|
||||
assert result.detection_count == 15
|
||||
assert result.absence_count == 2
|
||||
assert result.average_confidence == 0.85
|
||||
assert result.tracking_duration == 120.0
|
||||
assert result.reasons == []
|
||||
|
||||
def test_with_reasons(self):
|
||||
"""Test validation result with failure reasons."""
|
||||
result = ValidationResult(
|
||||
track_id=1001,
|
||||
is_stable=False,
|
||||
detection_count=5,
|
||||
absence_count=35,
|
||||
average_confidence=0.4,
|
||||
tracking_duration=30.0,
|
||||
reasons=["Insufficient detection frames", "Too many absences", "Low confidence"]
|
||||
)
|
||||
|
||||
assert result.is_stable is False
|
||||
assert len(result.reasons) == 3
|
||||
assert "Insufficient detection frames" in result.reasons
|
||||
|
||||
|
||||
class TestStabilityValidator:
|
||||
"""Test stability validation functionality."""
|
||||
|
||||
def test_initialization_default(self):
|
||||
"""Test validator initialization with default config."""
|
||||
validator = StabilityValidator()
|
||||
|
||||
assert isinstance(validator.config, StabilityConfig)
|
||||
assert validator.config.min_detection_frames == 10
|
||||
assert len(validator.track_metrics) == 0
|
||||
|
||||
def test_initialization_custom_config(self):
|
||||
"""Test validator initialization with custom config."""
|
||||
config = StabilityConfig(min_detection_frames=5, confidence_threshold=0.8)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
assert validator.config.min_detection_frames == 5
|
||||
assert validator.config.confidence_threshold == 0.8
|
||||
|
||||
def test_update_detections_new_track(self):
|
||||
"""Test updating with new track."""
|
||||
validator = StabilityValidator()
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
validator.update_detections([detection], current_time=1640995200.0)
|
||||
|
||||
assert 1001 in validator.track_metrics
|
||||
metrics = validator.track_metrics[1001]
|
||||
assert metrics.detection_count == 1
|
||||
assert metrics.absence_count == 0
|
||||
|
||||
def test_update_detections_existing_track(self):
|
||||
"""Test updating existing track."""
|
||||
validator = StabilityValidator()
|
||||
|
||||
# First detection
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox1,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
validator.update_detections([detection1], current_time=1640995200.0)
|
||||
|
||||
# Second detection
|
||||
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.9,
|
||||
bbox=bbox2,
|
||||
track_id=1001,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
validator.update_detections([detection2], current_time=1640995300.0)
|
||||
|
||||
metrics = validator.track_metrics[1001]
|
||||
assert metrics.detection_count == 2
|
||||
assert metrics.absence_count == 0
|
||||
assert metrics.average_confidence() == 0.85
|
||||
|
||||
def test_update_detections_missing_track(self):
|
||||
"""Test updating when track is missing (increment absence)."""
|
||||
validator = StabilityValidator()
|
||||
|
||||
# Add track
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
validator.update_detections([detection], current_time=1640995200.0)
|
||||
|
||||
# Update with empty detections
|
||||
validator.update_detections([], current_time=1640995300.0)
|
||||
|
||||
metrics = validator.track_metrics[1001]
|
||||
assert metrics.detection_count == 1
|
||||
assert metrics.absence_count == 1
|
||||
|
||||
def test_validate_track_stable(self):
|
||||
"""Test validating a stable track."""
|
||||
config = StabilityConfig(min_detection_frames=3, max_absence_frames=5)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Create track with sufficient detections
|
||||
track_id = 1001
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
# Add sufficient detections
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
for i in range(5):
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
metrics.add_detection(detection, current_time=1640995200.0 + i)
|
||||
|
||||
result = validator.validate_track(track_id)
|
||||
|
||||
assert result.is_stable is True
|
||||
assert result.detection_count == 5
|
||||
assert result.absence_count == 0
|
||||
assert len(result.reasons) == 0
|
||||
|
||||
def test_validate_track_insufficient_detections(self):
|
||||
"""Test validating track with insufficient detections."""
|
||||
config = StabilityConfig(min_detection_frames=10, max_absence_frames=5)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Create track with insufficient detections
|
||||
track_id = 1001
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
# Add only few detections
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
for i in range(3):
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
metrics.add_detection(detection, current_time=1640995200.0 + i)
|
||||
|
||||
result = validator.validate_track(track_id)
|
||||
|
||||
assert result.is_stable is False
|
||||
assert "Insufficient detection frames" in result.reasons
|
||||
|
||||
def test_validate_track_too_many_absences(self):
|
||||
"""Test validating track with too many absences."""
|
||||
config = StabilityConfig(min_detection_frames=3, max_absence_frames=2)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Create track with too many absences
|
||||
track_id = 1001
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
# Add detections and absences
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
for i in range(5):
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
metrics.add_detection(detection, current_time=1640995200.0 + i)
|
||||
|
||||
# Add too many absences
|
||||
for _ in range(5):
|
||||
metrics.increment_absence()
|
||||
|
||||
result = validator.validate_track(track_id)
|
||||
|
||||
assert result.is_stable is False
|
||||
assert "Too many absence frames" in result.reasons
|
||||
|
||||
def test_validate_track_low_confidence(self):
|
||||
"""Test validating track with low confidence."""
|
||||
config = StabilityConfig(
|
||||
min_detection_frames=3,
|
||||
max_absence_frames=5,
|
||||
confidence_threshold=0.8
|
||||
)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Create track with low confidence
|
||||
track_id = 1001
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
# Add detections with low confidence
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
for i in range(5):
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.5, # Below threshold
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
metrics.add_detection(detection, current_time=1640995200.0 + i)
|
||||
|
||||
result = validator.validate_track(track_id)
|
||||
|
||||
assert result.is_stable is False
|
||||
assert "Low average confidence" in result.reasons
|
||||
|
||||
def test_validate_all_tracks(self):
|
||||
"""Test validating all tracks."""
|
||||
config = StabilityConfig(min_detection_frames=3)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Add multiple tracks
|
||||
for track_id in [1001, 1002, 1003]:
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
# Make some tracks stable, others not
|
||||
detection_count = 5 if track_id == 1001 else 2
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
for i in range(detection_count):
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
metrics.add_detection(detection, current_time=1640995200.0 + i)
|
||||
|
||||
results = validator.validate_all_tracks()
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[1001].is_stable is True # 5 detections
|
||||
assert results[1002].is_stable is False # 2 detections
|
||||
assert results[1003].is_stable is False # 2 detections
|
||||
|
||||
def test_get_stable_tracks(self):
|
||||
"""Test getting stable track IDs."""
|
||||
config = StabilityConfig(min_detection_frames=3)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Add tracks with different stability
|
||||
for track_id, detection_count in [(1001, 5), (1002, 2), (1003, 4)]:
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
for i in range(detection_count):
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
metrics.add_detection(detection, current_time=1640995200.0 + i)
|
||||
|
||||
stable_tracks = validator.get_stable_tracks()
|
||||
|
||||
assert stable_tracks == [1001, 1003] # 5 and 4 detections respectively
|
||||
|
||||
def test_cleanup_expired_tracks(self):
|
||||
"""Test cleanup of expired tracks."""
|
||||
config = StabilityConfig(stability_window=10.0)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Add tracks with different last detection times
|
||||
current_time = 1640995300.0
|
||||
|
||||
for track_id, last_detection_time in [(1001, current_time - 5), (1002, current_time - 15)]:
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=int(last_detection_time * 1000)
|
||||
)
|
||||
metrics.add_detection(detection, current_time=last_detection_time)
|
||||
|
||||
removed_count = validator.cleanup_expired_tracks(current_time)
|
||||
|
||||
assert removed_count == 1 # 1002 should be removed (15 > 10 seconds)
|
||||
assert 1001 in validator.track_metrics
|
||||
assert 1002 not in validator.track_metrics
|
||||
|
||||
def test_clear_all_tracks(self):
|
||||
"""Test clearing all track metrics."""
|
||||
validator = StabilityValidator()
|
||||
|
||||
# Add some tracks
|
||||
for track_id in [1001, 1002]:
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
|
||||
assert len(validator.track_metrics) == 2
|
||||
|
||||
validator.clear_all_tracks()
|
||||
|
||||
assert len(validator.track_metrics) == 0
|
||||
|
||||
def test_get_validation_summary(self):
|
||||
"""Test getting validation summary statistics."""
|
||||
config = StabilityConfig(min_detection_frames=3)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Add tracks with different characteristics
|
||||
track_data = [
|
||||
(1001, 5, True), # Stable
|
||||
(1002, 2, False), # Unstable
|
||||
(1003, 4, True), # Stable
|
||||
(1004, 1, False) # Unstable
|
||||
]
|
||||
|
||||
for track_id, detection_count, _ in track_data:
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
for i in range(detection_count):
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
metrics.add_detection(detection, current_time=1640995200.0 + i)
|
||||
|
||||
summary = validator.get_validation_summary()
|
||||
|
||||
assert summary["total_tracks"] == 4
|
||||
assert summary["stable_tracks"] == 2
|
||||
assert summary["unstable_tracks"] == 2
|
||||
assert summary["stability_rate"] == 0.5
|
||||
|
||||
|
||||
class TestStabilityValidatorIntegration:
|
||||
"""Integration tests for stability validator."""
|
||||
|
||||
def test_full_tracking_lifecycle(self):
|
||||
"""Test complete tracking lifecycle with stability validation."""
|
||||
config = StabilityConfig(
|
||||
min_detection_frames=3,
|
||||
max_absence_frames=2,
|
||||
confidence_threshold=0.7
|
||||
)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
track_id = 1001
|
||||
|
||||
# Phase 1: Initial detections (building up)
|
||||
for i in range(5):
|
||||
bbox = BoundingBox(x1=100+i*2, y1=200+i*2, x2=300+i*2, y2=400+i*2)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
validator.update_detections([detection], current_time=1640995200.0 + i)
|
||||
|
||||
# Should be stable now
|
||||
result = validator.validate_track(track_id)
|
||||
assert result.is_stable is True
|
||||
|
||||
# Phase 2: Some absences
|
||||
for i in range(2):
|
||||
validator.update_detections([], current_time=1640995205.0 + i)
|
||||
|
||||
# Still stable (within absence threshold)
|
||||
result = validator.validate_track(track_id)
|
||||
assert result.is_stable is True
|
||||
|
||||
# Phase 3: Track reappears
|
||||
bbox = BoundingBox(x1=120, y1=220, x2=320, y2=420)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995207000
|
||||
)
|
||||
validator.update_detections([detection], current_time=1640995207.0)
|
||||
|
||||
# Should reset absence count and remain stable
|
||||
result = validator.validate_track(track_id)
|
||||
assert result.is_stable is True
|
||||
assert validator.track_metrics[track_id].absence_count == 0
|
||||
|
||||
def test_multi_track_validation(self):
|
||||
"""Test validation with multiple tracks."""
|
||||
validator = StabilityValidator()
|
||||
|
||||
# Simulate multi-track scenario
|
||||
frame_detections = [
|
||||
# Frame 1
|
||||
[
|
||||
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000),
|
||||
DetectionResult("truck", 0.8, BoundingBox(400, 200, 600, 400), 1002, 1640995200000)
|
||||
],
|
||||
# Frame 2
|
||||
[
|
||||
DetectionResult("car", 0.85, BoundingBox(105, 205, 305, 405), 1001, 1640995201000),
|
||||
DetectionResult("truck", 0.82, BoundingBox(405, 205, 605, 405), 1002, 1640995201000),
|
||||
DetectionResult("car", 0.75, BoundingBox(200, 300, 400, 500), 1003, 1640995201000)
|
||||
],
|
||||
# Frame 3 - track 1002 disappears
|
||||
[
|
||||
DetectionResult("car", 0.88, BoundingBox(110, 210, 310, 410), 1001, 1640995202000),
|
||||
DetectionResult("car", 0.78, BoundingBox(205, 305, 405, 505), 1003, 1640995202000)
|
||||
]
|
||||
]
|
||||
|
||||
# Process frames
|
||||
for i, detections in enumerate(frame_detections):
|
||||
validator.update_detections(detections, current_time=1640995200.0 + i)
|
||||
|
||||
# Get validation results
|
||||
validation_results = validator.validate_all_tracks()
|
||||
|
||||
assert len(validation_results) == 3
|
||||
|
||||
# All tracks should be unstable (insufficient frames)
|
||||
for result in validation_results.values():
|
||||
assert result.is_stable is False
|
||||
assert "Insufficient detection frames" in result.reasons
|
606
tests/unit/detection/test_tracking_manager.py
Normal file
606
tests/unit/detection/test_tracking_manager.py
Normal file
|
@ -0,0 +1,606 @@
|
|||
"""
|
||||
Unit tests for BoT-SORT tracking management.
|
||||
"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from collections import defaultdict
|
||||
|
||||
from detector_worker.detection.tracking_manager import TrackingManager, TrackInfo
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
from detector_worker.core.exceptions import TrackingError
|
||||
|
||||
|
||||
class TestTrackInfo:
|
||||
"""Test TrackInfo data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test TrackInfo creation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
|
||||
assert track.track_id == 1001
|
||||
assert track.bbox == bbox
|
||||
assert track.confidence == 0.85
|
||||
assert track.class_name == "car"
|
||||
assert track.first_seen == 1640995200.0
|
||||
assert track.last_seen == 1640995300.0
|
||||
assert track.frame_count == 1
|
||||
assert track.absence_count == 0
|
||||
|
||||
def test_update_track(self):
|
||||
"""Test updating track information."""
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox1,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995200.0
|
||||
)
|
||||
|
||||
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
|
||||
track.update(bbox2, 0.90, 1640995300.0)
|
||||
|
||||
assert track.bbox == bbox2
|
||||
assert track.confidence == 0.90
|
||||
assert track.last_seen == 1640995300.0
|
||||
assert track.frame_count == 2
|
||||
assert track.absence_count == 0
|
||||
|
||||
def test_increment_absence(self):
|
||||
"""Test incrementing absence count."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995200.0
|
||||
)
|
||||
|
||||
track.increment_absence()
|
||||
assert track.absence_count == 1
|
||||
|
||||
track.increment_absence()
|
||||
assert track.absence_count == 2
|
||||
|
||||
def test_age_calculation(self):
|
||||
"""Test track age calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
|
||||
age = track.age(current_time=1640995400.0)
|
||||
assert age == 200.0 # 1640995400 - 1640995200
|
||||
|
||||
def test_time_since_last_seen(self):
|
||||
"""Test time since last seen calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
|
||||
time_since = track.time_since_last_seen(current_time=1640995450.0)
|
||||
assert time_since == 150.0 # 1640995450 - 1640995300
|
||||
|
||||
def test_is_stable(self):
|
||||
"""Test track stability checking."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
|
||||
# Not stable initially
|
||||
assert track.is_stable(min_frames=5, max_absence=3) is False
|
||||
|
||||
# Make it stable
|
||||
track.frame_count = 10
|
||||
track.absence_count = 1
|
||||
assert track.is_stable(min_frames=5, max_absence=3) is True
|
||||
|
||||
# Too many absences
|
||||
track.absence_count = 5
|
||||
assert track.is_stable(min_frames=5, max_absence=3) is False
|
||||
|
||||
|
||||
class TestTrackingManager:
|
||||
"""Test tracking management functionality."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test tracking manager initialization."""
|
||||
manager = TrackingManager()
|
||||
|
||||
assert manager.max_absence_frames == 30
|
||||
assert manager.min_stable_frames == 10
|
||||
assert manager.track_timeout == 60.0
|
||||
assert len(manager.active_tracks) == 0
|
||||
assert len(manager.stable_tracks) == 0
|
||||
|
||||
def test_initialization_with_config(self):
|
||||
"""Test initialization with custom configuration."""
|
||||
config = {
|
||||
"max_absence_frames": 20,
|
||||
"min_stable_frames": 5,
|
||||
"track_timeout": 30.0
|
||||
}
|
||||
manager = TrackingManager(config)
|
||||
|
||||
assert manager.max_absence_frames == 20
|
||||
assert manager.min_stable_frames == 5
|
||||
assert manager.track_timeout == 30.0
|
||||
|
||||
def test_update_tracks_new_detections(self):
|
||||
"""Test updating with new detections."""
|
||||
manager = TrackingManager()
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection], current_time=1640995200.0)
|
||||
|
||||
assert len(manager.active_tracks) == 1
|
||||
assert 1001 in manager.active_tracks
|
||||
|
||||
track = manager.active_tracks[1001]
|
||||
assert track.track_id == 1001
|
||||
assert track.class_name == "car"
|
||||
assert track.confidence == 0.85
|
||||
assert track.frame_count == 1
|
||||
|
||||
def test_update_tracks_existing_detection(self):
|
||||
"""Test updating existing track."""
|
||||
manager = TrackingManager()
|
||||
|
||||
# First detection
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox1,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection1], current_time=1640995200.0)
|
||||
|
||||
# Second detection (same track, different position)
|
||||
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.90,
|
||||
bbox=bbox2,
|
||||
track_id=1001,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection2], current_time=1640995300.0)
|
||||
|
||||
assert len(manager.active_tracks) == 1
|
||||
track = manager.active_tracks[1001]
|
||||
assert track.frame_count == 2
|
||||
assert track.confidence == 0.90
|
||||
assert track.bbox == bbox2
|
||||
assert track.absence_count == 0
|
||||
|
||||
def test_update_tracks_no_detections(self):
|
||||
"""Test updating with no detections (increment absence)."""
|
||||
manager = TrackingManager()
|
||||
|
||||
# Add initial track
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection], current_time=1640995200.0)
|
||||
|
||||
# Update with no detections
|
||||
manager.update_tracks([], current_time=1640995300.0)
|
||||
|
||||
track = manager.active_tracks[1001]
|
||||
assert track.absence_count == 1
|
||||
|
||||
def test_cleanup_expired_tracks(self):
|
||||
"""Test cleanup of expired tracks."""
|
||||
manager = TrackingManager({"track_timeout": 10.0})
|
||||
|
||||
# Add track
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection], current_time=1640995200.0)
|
||||
assert len(manager.active_tracks) == 1
|
||||
|
||||
# Cleanup after timeout
|
||||
removed_count = manager.cleanup_expired_tracks(current_time=1640995220.0) # 20 seconds later
|
||||
|
||||
assert removed_count == 1
|
||||
assert len(manager.active_tracks) == 0
|
||||
|
||||
def test_cleanup_absent_tracks(self):
|
||||
"""Test cleanup of tracks with too many absences."""
|
||||
manager = TrackingManager({"max_absence_frames": 3})
|
||||
|
||||
# Add track
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection], current_time=1640995200.0)
|
||||
|
||||
# Increment absence count beyond threshold
|
||||
for i in range(5):
|
||||
manager.update_tracks([], current_time=1640995200.0 + i)
|
||||
|
||||
track = manager.active_tracks[1001]
|
||||
assert track.absence_count == 5
|
||||
|
||||
# Cleanup absent tracks
|
||||
removed_count = manager.cleanup_absent_tracks()
|
||||
|
||||
assert removed_count == 1
|
||||
assert len(manager.active_tracks) == 0
|
||||
|
||||
def test_get_stable_tracks(self):
|
||||
"""Test getting stable tracks."""
|
||||
manager = TrackingManager({"min_stable_frames": 3})
|
||||
|
||||
# Add track and make it stable
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track_info = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
track_info.frame_count = 5 # Make it stable
|
||||
|
||||
manager.active_tracks[1001] = track_info
|
||||
|
||||
stable_tracks = manager.get_stable_tracks()
|
||||
|
||||
assert len(stable_tracks) == 1
|
||||
assert 1001 in stable_tracks
|
||||
assert 1001 in manager.stable_tracks # Should be cached
|
||||
|
||||
def test_get_track_by_id(self):
|
||||
"""Test getting track by ID."""
|
||||
manager = TrackingManager()
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection], current_time=1640995200.0)
|
||||
|
||||
track = manager.get_track_by_id(1001)
|
||||
assert track is not None
|
||||
assert track.track_id == 1001
|
||||
|
||||
non_existent = manager.get_track_by_id(9999)
|
||||
assert non_existent is None
|
||||
|
||||
def test_get_tracks_by_class(self):
|
||||
"""Test getting tracks by class name."""
|
||||
manager = TrackingManager()
|
||||
|
||||
# Add different classes
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox1,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
|
||||
detection2 = DetectionResult(
|
||||
class_name="truck",
|
||||
confidence=0.80,
|
||||
bbox=bbox2,
|
||||
track_id=1002,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500)
|
||||
detection3 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.90,
|
||||
bbox=bbox3,
|
||||
track_id=1003,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection1, detection2, detection3], current_time=1640995200.0)
|
||||
|
||||
car_tracks = manager.get_tracks_by_class("car")
|
||||
assert len(car_tracks) == 2
|
||||
assert 1001 in car_tracks
|
||||
assert 1003 in car_tracks
|
||||
|
||||
truck_tracks = manager.get_tracks_by_class("truck")
|
||||
assert len(truck_tracks) == 1
|
||||
assert 1002 in truck_tracks
|
||||
|
||||
def test_get_track_count(self):
|
||||
"""Test getting track counts."""
|
||||
manager = TrackingManager()
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection], current_time=1640995200.0)
|
||||
|
||||
assert manager.get_active_track_count() == 1
|
||||
assert manager.get_track_count_by_class("car") == 1
|
||||
assert manager.get_track_count_by_class("truck") == 0
|
||||
|
||||
def test_clear_all_tracks(self):
|
||||
"""Test clearing all tracks."""
|
||||
manager = TrackingManager()
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection], current_time=1640995200.0)
|
||||
assert len(manager.active_tracks) == 1
|
||||
|
||||
manager.clear_all_tracks()
|
||||
|
||||
assert len(manager.active_tracks) == 0
|
||||
assert len(manager.stable_tracks) == 0
|
||||
|
||||
def test_get_track_statistics(self):
|
||||
"""Test getting track statistics."""
|
||||
manager = TrackingManager({"min_stable_frames": 2})
|
||||
|
||||
# Add multiple tracks
|
||||
detections = []
|
||||
for i in range(3):
|
||||
bbox = BoundingBox(x1=100+i*50, y1=200, x2=300+i*50, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001+i,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
detections.append(detection)
|
||||
|
||||
manager.update_tracks(detections, current_time=1640995200.0)
|
||||
|
||||
# Make some tracks stable
|
||||
manager.active_tracks[1001].frame_count = 5
|
||||
manager.active_tracks[1002].frame_count = 3
|
||||
# 1003 remains unstable with frame_count=1
|
||||
|
||||
stats = manager.get_track_statistics()
|
||||
|
||||
assert stats["active_tracks"] == 3
|
||||
assert stats["stable_tracks"] == 2
|
||||
assert stats["unstable_tracks"] == 1
|
||||
assert "average_track_age" in stats
|
||||
assert "average_confidence" in stats
|
||||
|
||||
def test_validate_tracks(self):
|
||||
"""Test track validation."""
|
||||
manager = TrackingManager({"min_stable_frames": 3, "max_absence_frames": 2})
|
||||
|
||||
# Add tracks with different stability
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track1 = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox1,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
track1.frame_count = 5 # Stable
|
||||
track1.absence_count = 1 # Present
|
||||
|
||||
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
|
||||
track2 = TrackInfo(
|
||||
track_id=1002,
|
||||
bbox=bbox2,
|
||||
confidence=0.80,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995250.0
|
||||
)
|
||||
track2.frame_count = 2 # Not stable
|
||||
track2.absence_count = 1
|
||||
|
||||
bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500)
|
||||
track3 = TrackInfo(
|
||||
track_id=1003,
|
||||
bbox=bbox3,
|
||||
confidence=0.90,
|
||||
class_name="car",
|
||||
first_seen=1640995100.0,
|
||||
last_seen=1640995150.0
|
||||
)
|
||||
track3.frame_count = 8 # Was stable but now absent
|
||||
track3.absence_count = 5 # Too many absences
|
||||
|
||||
manager.active_tracks = {1001: track1, 1002: track2, 1003: track3}
|
||||
manager.stable_tracks = {1001, 1003} # 1003 was previously stable
|
||||
|
||||
validation_result = manager.validate_tracks()
|
||||
|
||||
assert validation_result.stable_tracks == [1001]
|
||||
assert validation_result.current_tracks == [1001, 1002, 1003]
|
||||
assert validation_result.newly_stable == []
|
||||
assert validation_result.lost_tracks == [1003]
|
||||
|
||||
def test_track_persistence_across_frames(self):
|
||||
"""Test track persistence across multiple frames."""
|
||||
manager = TrackingManager()
|
||||
|
||||
# Frame 1
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox1,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection1], current_time=1640995200.0)
|
||||
|
||||
# Frame 2 - track moves
|
||||
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.88,
|
||||
bbox=bbox2,
|
||||
track_id=1001,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection2], current_time=1640995300.0)
|
||||
|
||||
# Frame 3 - track disappears
|
||||
manager.update_tracks([], current_time=1640995400.0)
|
||||
|
||||
# Frame 4 - track reappears
|
||||
bbox4 = BoundingBox(x1=120, y1=220, x2=320, y2=420)
|
||||
detection4 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.82,
|
||||
bbox=bbox4,
|
||||
track_id=1001,
|
||||
timestamp=1640995500000
|
||||
)
|
||||
|
||||
manager.update_tracks([detection4], current_time=1640995500.0)
|
||||
|
||||
track = manager.active_tracks[1001]
|
||||
assert track.frame_count == 3 # Seen in 3 frames
|
||||
assert track.absence_count == 0 # Reset when reappeared
|
||||
assert track.bbox == bbox4 # Latest position
|
||||
|
||||
|
||||
class TestTrackingManagerErrorHandling:
|
||||
"""Test error handling in tracking manager."""
|
||||
|
||||
def test_invalid_detection_input(self):
|
||||
"""Test handling of invalid detection input."""
|
||||
manager = TrackingManager()
|
||||
|
||||
# None detection should be handled gracefully
|
||||
with pytest.raises(TrackingError):
|
||||
manager.update_tracks([None], current_time=1640995200.0)
|
||||
|
||||
def test_negative_track_id(self):
|
||||
"""Test handling of negative track ID."""
|
||||
manager = TrackingManager()
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=-1, # Invalid track ID
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
with pytest.raises(TrackingError):
|
||||
manager.update_tracks([detection], current_time=1640995200.0)
|
||||
|
||||
def test_duplicate_track_ids_different_classes(self):
|
||||
"""Test handling of duplicate track IDs with different classes."""
|
||||
manager = TrackingManager()
|
||||
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox1,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
|
||||
detection2 = DetectionResult(
|
||||
class_name="truck", # Different class, same ID
|
||||
confidence=0.80,
|
||||
bbox=bbox2,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
# Should log warning but handle gracefully
|
||||
manager.update_tracks([detection1, detection2], current_time=1640995200.0)
|
||||
|
||||
# The later detection should update the track
|
||||
track = manager.active_tracks[1001]
|
||||
assert track.class_name == "truck" # Last update wins
|
386
tests/unit/detection/test_yolo_detector.py
Normal file
386
tests/unit/detection/test_yolo_detector.py
Normal file
|
@ -0,0 +1,386 @@
|
|||
"""
|
||||
Unit tests for YOLO detector with tracking functionality.
|
||||
"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
import torch
|
||||
|
||||
from detector_worker.detection.yolo_detector import YOLODetector
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
from detector_worker.core.exceptions import DetectionError
|
||||
|
||||
|
||||
class TestYOLODetector:
|
||||
"""Test YOLO detection and tracking functionality."""
|
||||
|
||||
def test_initialization_with_valid_model(self, mock_yolo_model):
|
||||
"""Test detector initialization with valid model."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
assert detector.model is mock_yolo_model
|
||||
assert detector.class_names == {}
|
||||
assert detector.is_tracking_enabled is True
|
||||
|
||||
def test_initialization_with_class_names(self, mock_yolo_model):
|
||||
"""Test detector initialization with class names."""
|
||||
class_names = {0: "car", 1: "truck", 2: "bus"}
|
||||
detector = YOLODetector(mock_yolo_model, class_names=class_names)
|
||||
|
||||
assert detector.class_names == class_names
|
||||
|
||||
def test_initialization_tracking_disabled(self, mock_yolo_model):
|
||||
"""Test detector initialization with tracking disabled."""
|
||||
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
|
||||
|
||||
assert detector.is_tracking_enabled is False
|
||||
|
||||
def test_detect_with_tracking(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with tracking enabled."""
|
||||
# Mock detection result
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0], # x1, y1, x2, y2, conf, class
|
||||
[150, 250, 350, 450, 0.85, 1]
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001, 1002])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert len(detections) == 2
|
||||
assert detections[0].confidence == 0.9
|
||||
assert detections[0].track_id == 1001
|
||||
assert detections[0].bbox.x1 == 100
|
||||
|
||||
mock_yolo_model.track.assert_called_once_with(mock_frame, persist=True, verbose=False)
|
||||
|
||||
def test_detect_without_tracking(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with tracking disabled."""
|
||||
# Mock detection result
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0]
|
||||
])
|
||||
mock_result.boxes.id = None # No tracking IDs
|
||||
|
||||
mock_yolo_model.predict.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert len(detections) == 1
|
||||
assert detections[0].track_id is None # No tracking ID
|
||||
|
||||
mock_yolo_model.predict.assert_called_once_with(mock_frame, verbose=False)
|
||||
|
||||
def test_detect_with_class_names(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with class name mapping."""
|
||||
class_names = {0: "car", 1: "truck"}
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0], # car
|
||||
[150, 250, 350, 450, 0.85, 1] # truck
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001, 1002])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model, class_names=class_names)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert detections[0].class_name == "car"
|
||||
assert detections[1].class_name == "truck"
|
||||
|
||||
def test_detect_no_boxes(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection when no objects are detected."""
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = None
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert detections == []
|
||||
|
||||
def test_detect_empty_boxes(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with empty boxes tensor."""
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([]).reshape(0, 6)
|
||||
mock_result.boxes.id = None
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert detections == []
|
||||
|
||||
def test_detect_with_confidence_threshold(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with confidence threshold filtering."""
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0], # Above threshold
|
||||
[150, 250, 350, 450, 0.3, 1] # Below threshold
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001, 1002])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame, confidence_threshold=0.5)
|
||||
|
||||
assert len(detections) == 1 # Only one above threshold
|
||||
assert detections[0].confidence == 0.9
|
||||
|
||||
def test_detect_model_error_handling(self, mock_yolo_model, mock_frame):
|
||||
"""Test error handling when model fails."""
|
||||
mock_yolo_model.track.side_effect = Exception("Model inference failed")
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
with pytest.raises(DetectionError) as exc_info:
|
||||
detector.detect(mock_frame)
|
||||
|
||||
assert "Model inference failed" in str(exc_info.value)
|
||||
|
||||
def test_detect_invalid_frame(self, mock_yolo_model):
|
||||
"""Test detection with invalid frame input."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
with pytest.raises(DetectionError) as exc_info:
|
||||
detector.detect(None)
|
||||
|
||||
assert "Invalid frame" in str(exc_info.value)
|
||||
|
||||
def test_detect_result_validation(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection result validation."""
|
||||
# Mock result with invalid bounding box (x2 <= x1)
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[300, 200, 100, 400, 0.9, 0] # Invalid: x2 < x1
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
# Invalid detections should be filtered out
|
||||
assert detections == []
|
||||
|
||||
def test_get_model_info(self, mock_yolo_model):
|
||||
"""Test getting model information."""
|
||||
mock_yolo_model.device = "cuda:0"
|
||||
mock_yolo_model.names = {0: "car", 1: "truck"}
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
info = detector.get_model_info()
|
||||
|
||||
assert info["device"] == "cuda:0"
|
||||
assert info["class_names"] == {0: "car", 1: "truck"}
|
||||
assert info["tracking_enabled"] is True
|
||||
|
||||
def test_set_tracking_enabled(self, mock_yolo_model):
|
||||
"""Test enabling/disabling tracking at runtime."""
|
||||
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
|
||||
assert detector.is_tracking_enabled is False
|
||||
|
||||
detector.set_tracking_enabled(True)
|
||||
assert detector.is_tracking_enabled is True
|
||||
|
||||
detector.set_tracking_enabled(False)
|
||||
assert detector.is_tracking_enabled is False
|
||||
|
||||
def test_update_class_names(self, mock_yolo_model):
|
||||
"""Test updating class names at runtime."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
new_class_names = {0: "vehicle", 1: "person"}
|
||||
detector.update_class_names(new_class_names)
|
||||
|
||||
assert detector.class_names == new_class_names
|
||||
|
||||
def test_reset_tracker(self, mock_yolo_model):
|
||||
"""Test resetting the tracking state."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
# This should not raise an error
|
||||
detector.reset_tracker()
|
||||
|
||||
def test_detect_with_crop_region(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with crop region specified."""
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[50, 75, 150, 175, 0.9, 0] # Relative to cropped region
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
crop_region = (100, 200, 300, 400) # x1, y1, x2, y2
|
||||
detections = detector.detect(mock_frame, crop_region=crop_region)
|
||||
|
||||
# Bounding box should be adjusted to global coordinates
|
||||
assert detections[0].bbox.x1 == 150 # 100 + 50
|
||||
assert detections[0].bbox.y1 == 275 # 200 + 75
|
||||
assert detections[0].bbox.x2 == 250 # 100 + 150
|
||||
assert detections[0].bbox.y2 == 375 # 200 + 175
|
||||
|
||||
def test_detect_batch_processing(self, mock_yolo_model):
|
||||
"""Test batch detection processing."""
|
||||
frames = [
|
||||
np.zeros((480, 640, 3), dtype=np.uint8),
|
||||
np.ones((480, 640, 3), dtype=np.uint8) * 255
|
||||
]
|
||||
|
||||
mock_results = []
|
||||
for i in range(2):
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100 + i*10, 200, 300, 400, 0.9, 0]
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001 + i])
|
||||
mock_results.append(mock_result)
|
||||
|
||||
mock_yolo_model.track.side_effect = [[result] for result in mock_results]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
batch_detections = detector.detect_batch(frames)
|
||||
|
||||
assert len(batch_detections) == 2
|
||||
assert len(batch_detections[0]) == 1
|
||||
assert len(batch_detections[1]) == 1
|
||||
assert batch_detections[0][0].bbox.x1 == 100
|
||||
assert batch_detections[1][0].bbox.x1 == 110
|
||||
|
||||
def test_detect_batch_empty_frames(self, mock_yolo_model):
|
||||
"""Test batch detection with empty frame list."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
batch_detections = detector.detect_batch([])
|
||||
|
||||
assert batch_detections == []
|
||||
|
||||
def test_detect_performance_metrics(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection performance metrics collection."""
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0]
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001])
|
||||
mock_result.speed = {"preprocess": 2.1, "inference": 15.3, "postprocess": 1.2}
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame, return_metrics=True)
|
||||
|
||||
# Check if performance metrics are available
|
||||
assert hasattr(detector, '_last_inference_time')
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda:0", "mps"])
|
||||
def test_detect_different_devices(self, device, mock_frame):
|
||||
"""Test detection on different devices."""
|
||||
mock_model = Mock()
|
||||
mock_model.device = device
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0]
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001])
|
||||
|
||||
mock_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_model)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert len(detections) == 1
|
||||
assert detections[0].confidence == 0.9
|
||||
|
||||
|
||||
class TestYOLODetectorIntegration:
|
||||
"""Integration tests for YOLO detector."""
|
||||
|
||||
def test_detect_with_real_tensor_operations(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with realistic tensor operations."""
|
||||
# Create more realistic box data
|
||||
boxes_data = torch.tensor([
|
||||
[100.5, 200.3, 299.7, 399.8, 0.95, 0],
|
||||
[150.2, 250.1, 349.9, 449.6, 0.87, 1],
|
||||
[200.0, 300.0, 400.0, 500.0, 0.45, 0] # Low confidence
|
||||
])
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = boxes_data
|
||||
mock_result.boxes.id = torch.tensor([2001, 2002, 2003])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
class_names = {0: "car", 1: "truck"}
|
||||
detector = YOLODetector(mock_yolo_model, class_names=class_names)
|
||||
|
||||
detections = detector.detect(mock_frame, confidence_threshold=0.5)
|
||||
|
||||
# Should filter out low confidence detection
|
||||
assert len(detections) == 2
|
||||
|
||||
# Check first detection
|
||||
det1 = detections[0]
|
||||
assert det1.class_name == "car"
|
||||
assert det1.confidence == pytest.approx(0.95)
|
||||
assert det1.track_id == 2001
|
||||
assert det1.bbox.x1 == pytest.approx(100.5)
|
||||
assert det1.bbox.y1 == pytest.approx(200.3)
|
||||
|
||||
# Check second detection
|
||||
det2 = detections[1]
|
||||
assert det2.class_name == "truck"
|
||||
assert det2.confidence == pytest.approx(0.87)
|
||||
assert det2.track_id == 2002
|
||||
|
||||
def test_multi_frame_tracking_consistency(self, mock_yolo_model, mock_frame):
|
||||
"""Test that tracking IDs remain consistent across frames."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
# Frame 1
|
||||
mock_result1 = Mock()
|
||||
mock_result1.boxes = Mock()
|
||||
mock_result1.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0]
|
||||
])
|
||||
mock_result1.boxes.id = torch.tensor([5001])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result1]
|
||||
detections1 = detector.detect(mock_frame)
|
||||
|
||||
# Frame 2 - same object, slightly moved
|
||||
mock_result2 = Mock()
|
||||
mock_result2.boxes = Mock()
|
||||
mock_result2.boxes.data = torch.tensor([
|
||||
[105, 205, 305, 405, 0.88, 0]
|
||||
])
|
||||
mock_result2.boxes.id = torch.tensor([5001]) # Same ID
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result2]
|
||||
detections2 = detector.detect(mock_frame)
|
||||
|
||||
# Should maintain same track ID
|
||||
assert detections1[0].track_id == detections2[0].track_id == 5001
|
Loading…
Add table
Add a link
Reference in a new issue