"""Unit tests for BoT-SORT tracking management."""

from collections import defaultdict
from unittest.mock import MagicMock, Mock, patch

import numpy as np
import pytest

from detector_worker.core.exceptions import TrackingError
from detector_worker.detection.detection_result import BoundingBox, DetectionResult
from detector_worker.detection.tracking_manager import TrackInfo, TrackingManager

# NOTE(review): defaultdict, Mock, MagicMock, patch and np are not used by the
# tests in this module; they are kept only to avoid changing the import surface
# and are candidates for removal.

# Fixed reference time (2022-01-01T00:00:00Z as Unix seconds) so every test is
# deterministic and time offsets read as "T0 + <seconds>".
T0 = 1640995200.0


def _bbox(x1=100, y1=200, x2=300, y2=400):
    """Return a BoundingBox; the defaults are the box most tests use."""
    return BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2)


def _detection(track_id=1001, *, bbox=None, class_name="car", confidence=0.85, seen_at=T0):
    """Return a DetectionResult for ``track_id``.

    ``seen_at`` is given in seconds (the unit ``update_tracks(current_time=...)``
    receives in these tests) and stored as the detection's ``timestamp`` in
    milliseconds, matching how the tests pair the two values.
    """
    return DetectionResult(
        class_name=class_name,
        confidence=confidence,
        bbox=_bbox() if bbox is None else bbox,
        track_id=track_id,
        timestamp=int(seen_at * 1000),
    )


def _track(
    track_id=1001,
    *,
    bbox=None,
    class_name="car",
    confidence=0.85,
    first_seen=T0,
    last_seen=T0 + 100,
):
    """Return a TrackInfo constructed directly (bypassing TrackingManager)."""
    return TrackInfo(
        track_id=track_id,
        bbox=_bbox() if bbox is None else bbox,
        confidence=confidence,
        class_name=class_name,
        first_seen=first_seen,
        last_seen=last_seen,
    )


class TestTrackInfo:
    """Test TrackInfo data structure."""

    def test_creation(self):
        """Test TrackInfo creation."""
        bbox = _bbox()
        track = TrackInfo(
            track_id=1001,
            bbox=bbox,
            confidence=0.85,
            class_name="car",
            first_seen=T0,
            last_seen=T0 + 100,
        )

        assert track.track_id == 1001
        assert track.bbox == bbox
        assert track.confidence == 0.85
        assert track.class_name == "car"
        assert track.first_seen == T0
        assert track.last_seen == T0 + 100
        # A freshly created track counts as seen in one frame and never absent.
        assert track.frame_count == 1
        assert track.absence_count == 0

    def test_update_track(self):
        """Test updating track information."""
        track = _track(last_seen=T0)

        bbox2 = _bbox(110, 210, 310, 410)
        track.update(bbox2, 0.90, T0 + 100)

        assert track.bbox == bbox2
        assert track.confidence == 0.90
        assert track.last_seen == T0 + 100
        assert track.frame_count == 2
        assert track.absence_count == 0

    def test_increment_absence(self):
        """Test incrementing absence count."""
        track = _track(last_seen=T0)

        track.increment_absence()
        assert track.absence_count == 1

        track.increment_absence()
        assert track.absence_count == 2

    def test_age_calculation(self):
        """Test track age calculation."""
        track = _track(first_seen=T0, last_seen=T0 + 100)

        # Age is measured from first_seen, not last_seen.
        age = track.age(current_time=T0 + 200)
        assert age == 200.0

    def test_time_since_last_seen(self):
        """Test time since last seen calculation."""
        track = _track(first_seen=T0, last_seen=T0 + 100)

        time_since = track.time_since_last_seen(current_time=T0 + 250)
        assert time_since == 150.0

    def test_is_stable(self):
        """Test track stability checking."""
        track = _track()

        # Not stable initially (only one frame seen).
        assert track.is_stable(min_frames=5, max_absence=3) is False

        # Enough frames and few enough absences -> stable.
        track.frame_count = 10
        track.absence_count = 1
        assert track.is_stable(min_frames=5, max_absence=3) is True

        # Too many absences -> no longer stable.
        track.absence_count = 5
        assert track.is_stable(min_frames=5, max_absence=3) is False


class TestTrackingManager:
    """Test tracking management functionality."""

    def test_initialization(self):
        """Test tracking manager initialization."""
        manager = TrackingManager()

        assert manager.max_absence_frames == 30
        assert manager.min_stable_frames == 10
        assert manager.track_timeout == 60.0
        assert len(manager.active_tracks) == 0
        assert len(manager.stable_tracks) == 0

    def test_initialization_with_config(self):
        """Test initialization with custom configuration."""
        config = {
            "max_absence_frames": 20,
            "min_stable_frames": 5,
            "track_timeout": 30.0,
        }
        manager = TrackingManager(config)

        assert manager.max_absence_frames == 20
        assert manager.min_stable_frames == 5
        assert manager.track_timeout == 30.0

    def test_update_tracks_new_detections(self):
        """Test updating with new detections."""
        manager = TrackingManager()

        manager.update_tracks([_detection(1001)], current_time=T0)

        assert len(manager.active_tracks) == 1
        assert 1001 in manager.active_tracks

        track = manager.active_tracks[1001]
        assert track.track_id == 1001
        assert track.class_name == "car"
        assert track.confidence == 0.85
        assert track.frame_count == 1

    def test_update_tracks_existing_detection(self):
        """Test updating existing track."""
        manager = TrackingManager()

        # First detection
        manager.update_tracks([_detection(1001)], current_time=T0)

        # Second detection (same track, different position)
        bbox2 = _bbox(110, 210, 310, 410)
        detection2 = _detection(1001, bbox=bbox2, confidence=0.90, seen_at=T0 + 100)
        manager.update_tracks([detection2], current_time=T0 + 100)

        assert len(manager.active_tracks) == 1
        track = manager.active_tracks[1001]
        assert track.frame_count == 2
        assert track.confidence == 0.90
        assert track.bbox == bbox2
        assert track.absence_count == 0

    def test_update_tracks_no_detections(self):
        """Test updating with no detections (increment absence)."""
        manager = TrackingManager()

        # Add initial track
        manager.update_tracks([_detection(1001)], current_time=T0)

        # Update with no detections
        manager.update_tracks([], current_time=T0 + 100)

        track = manager.active_tracks[1001]
        assert track.absence_count == 1

    def test_cleanup_expired_tracks(self):
        """Test cleanup of expired tracks."""
        manager = TrackingManager({"track_timeout": 10.0})

        manager.update_tracks([_detection(1001)], current_time=T0)
        assert len(manager.active_tracks) == 1

        # 20 seconds later exceeds the 10 second timeout.
        removed_count = manager.cleanup_expired_tracks(current_time=T0 + 20)

        assert removed_count == 1
        assert len(manager.active_tracks) == 0

    def test_cleanup_absent_tracks(self):
        """Test cleanup of tracks with too many absences."""
        manager = TrackingManager({"max_absence_frames": 3})

        manager.update_tracks([_detection(1001)], current_time=T0)

        # Increment absence count beyond threshold
        for i in range(5):
            manager.update_tracks([], current_time=T0 + i)

        track = manager.active_tracks[1001]
        assert track.absence_count == 5

        # Cleanup absent tracks
        removed_count = manager.cleanup_absent_tracks()

        assert removed_count == 1
        assert len(manager.active_tracks) == 0

    def test_get_stable_tracks(self):
        """Test getting stable tracks."""
        manager = TrackingManager({"min_stable_frames": 3})

        # Inject a track directly and push it past min_stable_frames.
        track_info = _track(1001)
        track_info.frame_count = 5
        manager.active_tracks[1001] = track_info

        stable_tracks = manager.get_stable_tracks()

        assert len(stable_tracks) == 1
        assert 1001 in stable_tracks
        assert 1001 in manager.stable_tracks  # Should be cached

    def test_get_track_by_id(self):
        """Test getting track by ID."""
        manager = TrackingManager()

        manager.update_tracks([_detection(1001)], current_time=T0)

        track = manager.get_track_by_id(1001)
        assert track is not None
        assert track.track_id == 1001

        non_existent = manager.get_track_by_id(9999)
        assert non_existent is None

    def test_get_tracks_by_class(self):
        """Test getting tracks by class name."""
        manager = TrackingManager()

        # Add different classes
        detections = [
            _detection(1001, class_name="car", confidence=0.85),
            _detection(
                1002, bbox=_bbox(150, 250, 350, 450), class_name="truck", confidence=0.80
            ),
            _detection(
                1003, bbox=_bbox(200, 300, 400, 500), class_name="car", confidence=0.90
            ),
        ]
        manager.update_tracks(detections, current_time=T0)

        car_tracks = manager.get_tracks_by_class("car")
        assert len(car_tracks) == 2
        assert 1001 in car_tracks
        assert 1003 in car_tracks

        truck_tracks = manager.get_tracks_by_class("truck")
        assert len(truck_tracks) == 1
        assert 1002 in truck_tracks

    def test_get_track_count(self):
        """Test getting track counts."""
        manager = TrackingManager()

        manager.update_tracks([_detection(1001)], current_time=T0)

        assert manager.get_active_track_count() == 1
        assert manager.get_track_count_by_class("car") == 1
        assert manager.get_track_count_by_class("truck") == 0

    def test_clear_all_tracks(self):
        """Test clearing all tracks."""
        manager = TrackingManager()

        manager.update_tracks([_detection(1001)], current_time=T0)
        assert len(manager.active_tracks) == 1

        manager.clear_all_tracks()

        assert len(manager.active_tracks) == 0
        assert len(manager.stable_tracks) == 0

    def test_get_track_statistics(self):
        """Test getting track statistics."""
        manager = TrackingManager({"min_stable_frames": 2})

        # Add three side-by-side tracks 1001..1003.
        detections = [
            _detection(1001 + i, bbox=_bbox(100 + i * 50, 200, 300 + i * 50, 400))
            for i in range(3)
        ]
        manager.update_tracks(detections, current_time=T0)

        # Make some tracks stable
        manager.active_tracks[1001].frame_count = 5
        manager.active_tracks[1002].frame_count = 3
        # 1003 remains unstable with frame_count=1

        stats = manager.get_track_statistics()

        assert stats["active_tracks"] == 3
        assert stats["stable_tracks"] == 2
        assert stats["unstable_tracks"] == 1
        assert "average_track_age" in stats
        assert "average_confidence" in stats

    def test_validate_tracks(self):
        """Test track validation."""
        manager = TrackingManager({"min_stable_frames": 3, "max_absence_frames": 2})

        # Add tracks with different stability
        track1 = _track(1001, confidence=0.85, first_seen=T0, last_seen=T0 + 100)
        track1.frame_count = 5  # Stable: >= min_stable_frames
        track1.absence_count = 1  # Within max_absence_frames (2), so still stable

        track2 = _track(
            1002,
            bbox=_bbox(150, 250, 350, 450),
            confidence=0.80,
            first_seen=T0,
            last_seen=T0 + 50,
        )
        track2.frame_count = 2  # Not stable: below min_stable_frames
        track2.absence_count = 1

        track3 = _track(
            1003,
            bbox=_bbox(200, 300, 400, 500),
            confidence=0.90,
            first_seen=T0 - 100,
            last_seen=T0 - 50,
        )
        track3.frame_count = 8  # Was stable but now absent
        track3.absence_count = 5  # Exceeds max_absence_frames

        manager.active_tracks = {1001: track1, 1002: track2, 1003: track3}
        manager.stable_tracks = {1001, 1003}  # 1003 was previously stable

        validation_result = manager.validate_tracks()

        assert validation_result.stable_tracks == [1001]
        assert validation_result.current_tracks == [1001, 1002, 1003]
        assert validation_result.newly_stable == []
        assert validation_result.lost_tracks == [1003]

    def test_track_persistence_across_frames(self):
        """Test track persistence across multiple frames."""
        manager = TrackingManager()

        # Frame 1
        manager.update_tracks([_detection(1001)], current_time=T0)

        # Frame 2 - track moves
        detection2 = _detection(
            1001, bbox=_bbox(110, 210, 310, 410), confidence=0.88, seen_at=T0 + 100
        )
        manager.update_tracks([detection2], current_time=T0 + 100)

        # Frame 3 - track disappears
        manager.update_tracks([], current_time=T0 + 200)

        # Frame 4 - track reappears
        bbox4 = _bbox(120, 220, 320, 420)
        detection4 = _detection(1001, bbox=bbox4, confidence=0.82, seen_at=T0 + 300)
        manager.update_tracks([detection4], current_time=T0 + 300)

        track = manager.active_tracks[1001]
        assert track.frame_count == 3  # Seen in 3 frames
        assert track.absence_count == 0  # Reset when reappeared
        assert track.bbox == bbox4  # Latest position


class TestTrackingManagerErrorHandling:
    """Test error handling in tracking manager."""

    def test_invalid_detection_input(self):
        """Test handling of invalid detection input."""
        manager = TrackingManager()

        # A None entry in the detection list must be rejected with TrackingError.
        with pytest.raises(TrackingError):
            manager.update_tracks([None], current_time=T0)

    def test_negative_track_id(self):
        """Test handling of negative track ID."""
        manager = TrackingManager()

        detection = _detection(-1)  # Invalid track ID

        with pytest.raises(TrackingError):
            manager.update_tracks([detection], current_time=T0)

    def test_duplicate_track_ids_different_classes(self):
        """Test handling of duplicate track IDs with different classes."""
        manager = TrackingManager()

        detection1 = _detection(1001, class_name="car", confidence=0.85)
        detection2 = _detection(
            1001,  # Same ID as detection1, different class
            bbox=_bbox(150, 250, 350, 450),
            class_name="truck",
            confidence=0.80,
        )

        # Expected not to raise (the implementation may log a warning).
        manager.update_tracks([detection1, detection2], current_time=T0)

        # The later detection should update the track
        track = manager.active_tracks[1001]
        assert track.class_name == "truck"  # Last update wins