"""
Unit tests for track stability validation.
"""
import math
import time
from collections import defaultdict
from unittest.mock import Mock, patch

import pytest

from detector_worker.detection.stability_validator import (
    StabilityValidator,
    StabilityConfig,
    ValidationResult,
    TrackStabilityMetrics,
)
from detector_worker.detection.detection_result import (
    DetectionResult,
    BoundingBox,
    TrackValidationResult,
)
from detector_worker.core.exceptions import ValidationError


# Fixed reference instant for the synthetic detections: seconds for the
# ``current_time`` arguments, milliseconds for ``DetectionResult.timestamp``
# (the tests pair ``+ i`` seconds with ``+ i * 1000`` ms throughout).
_BASE_TIME_S = 1640995200.0
_BASE_TIMESTAMP_MS = 1640995200000


def _default_bbox():
    """Return a fresh copy of the box most synthetic detections use."""
    return BoundingBox(x1=100, y1=200, x2=300, y2=400)


def _make_detection(
    track_id=1001,
    confidence=0.8,
    bbox=None,
    timestamp=_BASE_TIMESTAMP_MS,
    class_name="car",
):
    """Build a keyword-constructed ``DetectionResult`` with test-friendly defaults."""
    return DetectionResult(
        class_name=class_name,
        confidence=confidence,
        bbox=bbox if bbox is not None else _default_bbox(),
        track_id=track_id,
        timestamp=timestamp,
    )


def _register_track(validator, track_id, count, confidence=0.8):
    """Register metrics for ``track_id`` on ``validator`` and feed it ``count`` detections.

    Detections are spaced one second apart starting at ``_BASE_TIME_S``.
    Returns the metrics object stored in ``validator.track_metrics``.
    """
    validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
    metrics = validator.track_metrics[track_id]
    for i in range(count):
        detection = _make_detection(
            track_id=track_id,
            confidence=confidence,
            timestamp=_BASE_TIMESTAMP_MS + i * 1000,
        )
        metrics.add_detection(detection, current_time=_BASE_TIME_S + i)
    return metrics


class TestStabilityConfig:
    """Test stability configuration data structure."""

    def test_default_config(self):
        """Test default stability configuration."""
        config = StabilityConfig()

        assert config.min_detection_frames == 10
        assert config.max_absence_frames == 30
        assert config.confidence_threshold == 0.5
        assert config.stability_window == 60.0
        assert config.iou_threshold == 0.3
        assert config.movement_threshold == 50.0

    def test_custom_config(self):
        """Test custom stability configuration."""
        config = StabilityConfig(
            min_detection_frames=5,
            max_absence_frames=15,
            confidence_threshold=0.8,
            stability_window=30.0,
            iou_threshold=0.5,
            movement_threshold=25.0,
        )

        assert config.min_detection_frames == 5
        assert config.max_absence_frames == 15
        assert config.confidence_threshold == 0.8
        assert config.stability_window == 30.0
        assert config.iou_threshold == 0.5
        assert config.movement_threshold == 25.0

    def test_from_dict(self):
        """Test creating config from dictionary."""
        config_dict = {
            "min_detection_frames": 8,
            "max_absence_frames": 25,
            "confidence_threshold": 0.75,
            "unknown_field": "ignored",
        }

        config = StabilityConfig.from_dict(config_dict)

        assert config.min_detection_frames == 8
        assert config.max_absence_frames == 25
        assert config.confidence_threshold == 0.75
        # Known fields absent from the dict keep their defaults; the
        # unrecognised "unknown_field" key is ignored rather than raising.
        assert config.stability_window == 60.0


class TestTrackStabilityMetrics:
    """Test track stability metrics."""

    def test_initialization(self):
        """Test metrics initialization."""
        metrics = TrackStabilityMetrics(track_id=1001)

        assert metrics.track_id == 1001
        assert metrics.detection_count == 0
        assert metrics.absence_count == 0
        assert metrics.total_confidence == 0.0
        assert metrics.first_detection_time is None
        assert metrics.last_detection_time is None
        assert metrics.bounding_boxes == []
        assert metrics.confidence_scores == []

    def test_add_detection(self):
        """Test adding detection to metrics."""
        metrics = TrackStabilityMetrics(track_id=1001)

        metrics.add_detection(_make_detection(confidence=0.85), current_time=_BASE_TIME_S)

        assert metrics.detection_count == 1
        assert metrics.absence_count == 0
        assert metrics.total_confidence == 0.85
        assert metrics.first_detection_time == _BASE_TIME_S
        assert metrics.last_detection_time == _BASE_TIME_S
        assert len(metrics.bounding_boxes) == 1
        assert len(metrics.confidence_scores) == 1

    def test_increment_absence(self):
        """Test incrementing absence count."""
        metrics = TrackStabilityMetrics(track_id=1001)

        metrics.increment_absence()
        assert metrics.absence_count == 1

        metrics.increment_absence()
        assert metrics.absence_count == 2

    def test_reset_absence(self):
        """Test resetting absence count."""
        metrics = TrackStabilityMetrics(track_id=1001)
        metrics.increment_absence()
        metrics.increment_absence()
        assert metrics.absence_count == 2

        metrics.reset_absence()
        assert metrics.absence_count == 0

    def test_average_confidence(self):
        """Test average confidence calculation."""
        metrics = TrackStabilityMetrics(track_id=1001)

        # No detections
        assert metrics.average_confidence() == 0.0

        metrics.add_detection(_make_detection(confidence=0.8), current_time=_BASE_TIME_S)
        metrics.add_detection(
            _make_detection(confidence=0.9, timestamp=_BASE_TIMESTAMP_MS + 100 * 1000),
            current_time=_BASE_TIME_S + 100,
        )

        # (0.8 + 0.9) / 2 -- compared approximately since it is float arithmetic.
        assert metrics.average_confidence() == pytest.approx(0.85)

    def test_tracking_duration(self):
        """Test tracking duration calculation."""
        metrics = TrackStabilityMetrics(track_id=1001)

        # No detections
        assert metrics.tracking_duration() == 0.0

        metrics.add_detection(_make_detection(confidence=0.8), current_time=_BASE_TIME_S)
        metrics.add_detection(
            _make_detection(confidence=0.9, timestamp=_BASE_TIMESTAMP_MS + 100 * 1000),
            current_time=_BASE_TIME_S + 100,
        )

        # Last minus first detection time, in seconds.
        assert metrics.tracking_duration() == 100.0

    def test_movement_distance(self):
        """Test movement distance calculation."""
        metrics = TrackStabilityMetrics(track_id=1001)

        # No movement with a single detection
        metrics.add_detection(_make_detection(confidence=0.8), current_time=_BASE_TIME_S)
        assert metrics.total_movement_distance() == 0.0

        # Second detection shifted by (+10, +10)
        moved = BoundingBox(x1=110, y1=210, x2=310, y2=410)
        metrics.add_detection(
            _make_detection(
                confidence=0.9,
                bbox=moved,
                timestamp=_BASE_TIMESTAMP_MS + 100 * 1000,
            ),
            current_time=_BASE_TIME_S + 100,
        )

        # Centers move from (200, 300) to (210, 310): sqrt(10**2 + 10**2) ~= 14.142.
        # Pin the exact value rather than a loose 1% tolerance.
        assert metrics.total_movement_distance() == pytest.approx(math.hypot(10, 10))


class TestValidationResult:
    """Test validation result data structure."""

    def test_initialization(self):
        """Test validation result initialization."""
        result = ValidationResult(
            track_id=1001,
            is_stable=True,
            detection_count=15,
            absence_count=2,
            average_confidence=0.85,
            tracking_duration=120.0,
        )

        assert result.track_id == 1001
        assert result.is_stable is True
        assert result.detection_count == 15
        assert result.absence_count == 2
        assert result.average_confidence == 0.85
        assert result.tracking_duration == 120.0
        assert result.reasons == []

    def test_with_reasons(self):
        """Test validation result with failure reasons."""
        result = ValidationResult(
            track_id=1001,
            is_stable=False,
            detection_count=5,
            absence_count=35,
            average_confidence=0.4,
            tracking_duration=30.0,
            reasons=["Insufficient detection frames", "Too many absences", "Low confidence"],
        )

        assert result.is_stable is False
        assert len(result.reasons) == 3
        assert "Insufficient detection frames" in result.reasons


class TestStabilityValidator:
    """Test stability validation functionality."""

    def test_initialization_default(self):
        """Test validator initialization with default config."""
        validator = StabilityValidator()

        assert isinstance(validator.config, StabilityConfig)
        assert validator.config.min_detection_frames == 10
        assert len(validator.track_metrics) == 0

    def test_initialization_custom_config(self):
        """Test validator initialization with custom config."""
        config = StabilityConfig(min_detection_frames=5, confidence_threshold=0.8)
        validator = StabilityValidator(config)

        assert validator.config.min_detection_frames == 5
        assert validator.config.confidence_threshold == 0.8

    def test_update_detections_new_track(self):
        """Test updating with new track."""
        validator = StabilityValidator()

        validator.update_detections([_make_detection(confidence=0.85)], current_time=_BASE_TIME_S)

        assert 1001 in validator.track_metrics
        metrics = validator.track_metrics[1001]
        assert metrics.detection_count == 1
        assert metrics.absence_count == 0

    def test_update_detections_existing_track(self):
        """Test updating existing track."""
        validator = StabilityValidator()

        # First detection
        validator.update_detections([_make_detection(confidence=0.8)], current_time=_BASE_TIME_S)

        # Second detection, shifted by (+10, +10)
        moved = BoundingBox(x1=110, y1=210, x2=310, y2=410)
        validator.update_detections(
            [
                _make_detection(
                    confidence=0.9,
                    bbox=moved,
                    timestamp=_BASE_TIMESTAMP_MS + 100 * 1000,
                )
            ],
            current_time=_BASE_TIME_S + 100,
        )

        metrics = validator.track_metrics[1001]
        assert metrics.detection_count == 2
        assert metrics.absence_count == 0
        # (0.8 + 0.9) / 2 -- compared approximately since it is float arithmetic.
        assert metrics.average_confidence() == pytest.approx(0.85)

    def test_update_detections_missing_track(self):
        """Test updating when track is missing (increment absence)."""
        validator = StabilityValidator()

        # Add track
        validator.update_detections([_make_detection(confidence=0.85)], current_time=_BASE_TIME_S)

        # A frame without the track should count as an absence.
        validator.update_detections([], current_time=_BASE_TIME_S + 100)

        metrics = validator.track_metrics[1001]
        assert metrics.detection_count == 1
        assert metrics.absence_count == 1

    def test_validate_track_stable(self):
        """Test validating a stable track."""
        config = StabilityConfig(min_detection_frames=3, max_absence_frames=5)
        validator = StabilityValidator(config)
        track_id = 1001

        # 5 detections >= min_detection_frames (3)
        _register_track(validator, track_id, count=5)

        result = validator.validate_track(track_id)
        assert result.is_stable is True
        assert result.detection_count == 5
        assert result.absence_count == 0
        assert len(result.reasons) == 0

    def test_validate_track_insufficient_detections(self):
        """Test validating track with insufficient detections."""
        config = StabilityConfig(min_detection_frames=10, max_absence_frames=5)
        validator = StabilityValidator(config)
        track_id = 1001

        # Only 3 detections < min_detection_frames (10)
        _register_track(validator, track_id, count=3)

        result = validator.validate_track(track_id)
        assert result.is_stable is False
        assert "Insufficient detection frames" in result.reasons

    def test_validate_track_too_many_absences(self):
        """Test validating track with too many absences."""
        config = StabilityConfig(min_detection_frames=3, max_absence_frames=2)
        validator = StabilityValidator(config)
        track_id = 1001

        metrics = _register_track(validator, track_id, count=5)

        # 5 absences > max_absence_frames (2)
        for _ in range(5):
            metrics.increment_absence()

        result = validator.validate_track(track_id)
        assert result.is_stable is False
        assert "Too many absence frames" in result.reasons

    def test_validate_track_low_confidence(self):
        """Test validating track with low confidence."""
        config = StabilityConfig(
            min_detection_frames=3,
            max_absence_frames=5,
            confidence_threshold=0.8,
        )
        validator = StabilityValidator(config)
        track_id = 1001

        # Every detection at 0.5, below the 0.8 threshold
        _register_track(validator, track_id, count=5, confidence=0.5)

        result = validator.validate_track(track_id)
        assert result.is_stable is False
        assert "Low average confidence" in result.reasons

    def test_validate_all_tracks(self):
        """Test validating all tracks."""
        config = StabilityConfig(min_detection_frames=3)
        validator = StabilityValidator(config)

        # Only 1001 gets enough detections to be stable.
        for track_id in [1001, 1002, 1003]:
            _register_track(validator, track_id, count=5 if track_id == 1001 else 2)

        results = validator.validate_all_tracks()

        assert len(results) == 3
        assert results[1001].is_stable is True  # 5 detections
        assert results[1002].is_stable is False  # 2 detections
        assert results[1003].is_stable is False  # 2 detections

    def test_get_stable_tracks(self):
        """Test getting stable track IDs."""
        config = StabilityConfig(min_detection_frames=3)
        validator = StabilityValidator(config)

        for track_id, detection_count in [(1001, 5), (1002, 2), (1003, 4)]:
            _register_track(validator, track_id, count=detection_count)

        stable_tracks = validator.get_stable_tracks()
        assert stable_tracks == [1001, 1003]  # 5 and 4 detections respectively

    def test_cleanup_expired_tracks(self):
        """Test cleanup of expired tracks."""
        config = StabilityConfig(stability_window=10.0)
        validator = StabilityValidator(config)

        # Each track is last seen at a different offset before current_time.
        current_time = _BASE_TIME_S + 100
        for track_id, last_detection_time in [
            (1001, current_time - 5),
            (1002, current_time - 15),
        ]:
            validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
            metrics = validator.track_metrics[track_id]
            detection = _make_detection(
                track_id=track_id,
                timestamp=int(last_detection_time * 1000),
            )
            metrics.add_detection(detection, current_time=last_detection_time)

        removed_count = validator.cleanup_expired_tracks(current_time)

        assert removed_count == 1  # 1002 should be removed (15 > 10 seconds)
        assert 1001 in validator.track_metrics
        assert 1002 not in validator.track_metrics

    def test_clear_all_tracks(self):
        """Test clearing all track metrics."""
        validator = StabilityValidator()

        for track_id in [1001, 1002]:
            validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)

        assert len(validator.track_metrics) == 2

        validator.clear_all_tracks()
        assert len(validator.track_metrics) == 0

    def test_get_validation_summary(self):
        """Test getting validation summary statistics."""
        config = StabilityConfig(min_detection_frames=3)
        validator = StabilityValidator(config)

        # (track_id, detection_count, expected stable?)
        track_data = [
            (1001, 5, True),
            (1002, 2, False),
            (1003, 4, True),
            (1004, 1, False),
        ]
        for track_id, detection_count, _ in track_data:
            _register_track(validator, track_id, count=detection_count)

        summary = validator.get_validation_summary()

        assert summary["total_tracks"] == 4
        assert summary["stable_tracks"] == 2
        assert summary["unstable_tracks"] == 2
        assert summary["stability_rate"] == 0.5


class TestStabilityValidatorIntegration:
    """Integration tests for stability validator."""

    def test_full_tracking_lifecycle(self):
        """Test complete tracking lifecycle with stability validation."""
        config = StabilityConfig(
            min_detection_frames=3,
            max_absence_frames=2,
            confidence_threshold=0.7,
        )
        validator = StabilityValidator(config)
        track_id = 1001

        # Phase 1: initial detections, drifting 2px per frame
        for i in range(5):
            offset = i * 2
            bbox = BoundingBox(
                x1=100 + offset,
                y1=200 + offset,
                x2=300 + offset,
                y2=400 + offset,
            )
            detection = _make_detection(
                track_id=track_id,
                confidence=0.8,
                bbox=bbox,
                timestamp=_BASE_TIMESTAMP_MS + i * 1000,
            )
            validator.update_detections([detection], current_time=_BASE_TIME_S + i)

        # Should be stable now
        result = validator.validate_track(track_id)
        assert result.is_stable is True

        # Phase 2: two frames without the track
        for i in range(2):
            validator.update_detections([], current_time=_BASE_TIME_S + 5 + i)

        # Still stable (within absence threshold)
        result = validator.validate_track(track_id)
        assert result.is_stable is True

        # Phase 3: track reappears
        detection = _make_detection(
            track_id=track_id,
            confidence=0.85,
            bbox=BoundingBox(x1=120, y1=220, x2=320, y2=420),
            timestamp=_BASE_TIMESTAMP_MS + 7000,
        )
        validator.update_detections([detection], current_time=_BASE_TIME_S + 7)

        # Should reset absence count and remain stable
        result = validator.validate_track(track_id)
        assert result.is_stable is True
        assert validator.track_metrics[track_id].absence_count == 0

    def test_multi_track_validation(self):
        """Test validation with multiple tracks."""
        validator = StabilityValidator()

        # Simulate multi-track scenario
        frame_detections = [
            # Frame 1
            [
                DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000),
                DetectionResult("truck", 0.8, BoundingBox(400, 200, 600, 400), 1002, 1640995200000),
            ],
            # Frame 2
            [
                DetectionResult("car", 0.85, BoundingBox(105, 205, 305, 405), 1001, 1640995201000),
                DetectionResult("truck", 0.82, BoundingBox(405, 205, 605, 405), 1002, 1640995201000),
                DetectionResult("car", 0.75, BoundingBox(200, 300, 400, 500), 1003, 1640995201000),
            ],
            # Frame 3 - track 1002 disappears
            [
                DetectionResult("car", 0.88, BoundingBox(110, 210, 310, 410), 1001, 1640995202000),
                DetectionResult("car", 0.78, BoundingBox(205, 305, 405, 505), 1003, 1640995202000),
            ],
        ]

        # Process frames
        for i, detections in enumerate(frame_detections):
            validator.update_detections(detections, current_time=_BASE_TIME_S + i)

        # Get validation results
        validation_results = validator.validate_all_tracks()

        assert len(validation_results) == 3

        # All tracks should be unstable (at most 3 frames vs. default min of 10)
        for result in validation_results.values():
            assert result.is_stable is False
            assert "Insufficient detection frames" in result.reasons