"""
Unit tests for detection result data structures.

Covers BoundingBox geometry helpers, DetectionResult and
LightweightDetectionResult construction/validation, DetectionSession
bookkeeping, and TrackValidationResult statistics.
"""
from dataclasses import asdict  # NOTE(review): unused in the tests below — confirm before removing

import numpy as np  # NOTE(review): unused in the tests below — confirm before removing
import pytest

from detector_worker.detection.detection_result import (
    BoundingBox,
    DetectionResult,
    LightweightDetectionResult,
    DetectionSession,
    TrackValidationResult,
)


class TestBoundingBox:
    """Test BoundingBox data structure."""

    def test_creation_from_coordinates(self):
        """Test creating bounding box from coordinates."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        assert bbox.x1 == 100
        assert bbox.y1 == 200
        assert bbox.x2 == 300
        assert bbox.y2 == 400

    def test_creation_from_list(self):
        """Test creating bounding box from an [x1, y1, x2, y2] list."""
        coords = [100, 200, 300, 400]
        bbox = BoundingBox.from_list(coords)

        assert bbox.x1 == 100
        assert bbox.y1 == 200
        assert bbox.x2 == 300
        assert bbox.y2 == 400

    def test_creation_from_invalid_list(self):
        """Test that from_list rejects a list with too few coordinates."""
        with pytest.raises(ValueError):
            BoundingBox.from_list([100, 200, 300])  # Too few elements

    def test_to_list(self):
        """Test converting bounding box to list."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        coords = bbox.to_list()

        assert coords == [100, 200, 300, 400]

    def test_area_calculation(self):
        """Test area calculation."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        area = bbox.area()
        expected_area = (300 - 100) * (400 - 200)  # 200 * 200 = 40000

        assert area == expected_area

    def test_area_zero_for_invalid_bbox(self):
        """Test area is zero for an inverted (invalid) bounding box."""
        # x2 <= x1
        bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
        assert bbox.area() == 0

        # y2 <= y1
        bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
        assert bbox.area() == 0

    def test_width_height(self):
        """Test width and height accessors."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        assert bbox.width() == 200
        assert bbox.height() == 200

    def test_center_point(self):
        """Test center point calculation."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        center = bbox.center()

        assert center == (200, 300)  # (x1+x2)/2, (y1+y2)/2

    def test_is_valid(self):
        """Test bounding box validation."""
        # Valid bbox
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        assert bbox.is_valid() is True

        # Invalid bbox (x2 <= x1)
        bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
        assert bbox.is_valid() is False

        # Invalid bbox (y2 <= y1)
        bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
        assert bbox.is_valid() is False

    def test_intersection(self):
        """Test bounding box intersection."""
        bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)

        intersection = bbox1.intersection(bbox2)

        assert intersection.x1 == 200
        assert intersection.y1 == 200
        assert intersection.x2 == 300
        assert intersection.y2 == 300

    def test_no_intersection(self):
        """Test that disjoint boxes produce an invalid intersection."""
        bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
        bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)

        intersection = bbox1.intersection(bbox2)

        assert intersection.is_valid() is False

    def test_union(self):
        """Test bounding box union (smallest box enclosing both)."""
        bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)

        union = bbox1.union(bbox2)

        assert union.x1 == 100
        assert union.y1 == 100
        assert union.x2 == 400
        assert union.y2 == 400

    def test_iou_calculation(self):
        """Test IoU (Intersection over Union) calculation."""
        # Perfect overlap
        bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        bbox2 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        assert bbox1.iou(bbox2) == pytest.approx(1.0)

        # No overlap
        bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
        bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
        assert bbox1.iou(bbox2) == pytest.approx(0.0)

        # Partial overlap
        bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
        # Intersection area: 100x100 = 10000
        # Union area: 200x200 + 200x200 - 10000 = 70000
        # IoU = 10000/70000 = 1/7
        # NOTE(review): the original expectation is 1/3 (it treated the union
        # as 30000). That contradicts the union formula above; confirm against
        # BoundingBox.iou.
        expected_iou = 1.0 / 3.0
        assert bbox1.iou(bbox2) == pytest.approx(expected_iou, abs=1e-6)


class TestDetectionResult:
    """Test DetectionResult data structure."""

    def test_creation_with_required_fields(self):
        """Test creating detection result with required fields."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345,
        )

        assert detection.class_name == "car"
        assert detection.confidence == 0.85
        assert detection.bbox == bbox
        assert detection.track_id == 12345

    def test_creation_with_all_fields(self):
        """Test creating detection result with all optional fields set."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345,
            model_id="yolo_v8",
            timestamp=1640995200000,
            branch_results={"brand": "Toyota"},
        )

        assert detection.model_id == "yolo_v8"
        assert detection.timestamp == 1640995200000
        assert detection.branch_results == {"brand": "Toyota"}

    def test_creation_from_dict(self):
        """Test creating detection result from its dictionary form."""
        data = {
            "class": "car",
            "confidence": 0.85,
            "bbox": [100, 200, 300, 400],
            "id": 12345,
            "model_id": "yolo_v8",
            "timestamp": 1640995200000,
        }

        detection = DetectionResult.from_dict(data)

        assert detection.class_name == "car"
        assert detection.confidence == 0.85
        assert detection.bbox.to_list() == [100, 200, 300, 400]
        assert detection.track_id == 12345
        # NOTE(review): "model_id" and "timestamp" are supplied above but not
        # asserted — confirm whether from_dict is expected to carry them over.

    def test_to_dict(self):
        """Test converting detection result to dictionary."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345,
        )

        data = detection.to_dict()

        # The dict form uses "class"/"id" keys rather than the attribute names.
        assert data["class"] == "car"
        assert data["confidence"] == 0.85
        assert data["bbox"] == [100, 200, 300, 400]
        assert data["id"] == 12345

    def test_is_valid_detection(self):
        """Test detection validation (confidence range and bbox validity)."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        # Valid detection
        detection = DetectionResult(
            class_name="car", confidence=0.85, bbox=bbox, track_id=12345
        )
        assert detection.is_valid() is True

        # Invalid confidence (too low)
        detection = DetectionResult(
            class_name="car", confidence=-0.1, bbox=bbox, track_id=12345
        )
        assert detection.is_valid() is False

        # Invalid confidence (too high)
        detection = DetectionResult(
            class_name="car", confidence=1.5, bbox=bbox, track_id=12345
        )
        assert detection.is_valid() is False

        # Invalid bounding box
        invalid_bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
        detection = DetectionResult(
            class_name="car", confidence=0.85, bbox=invalid_bbox, track_id=12345
        )
        assert detection.is_valid() is False


class TestLightweightDetectionResult:
    """Test LightweightDetectionResult data structure."""

    def test_creation(self):
        """Test creating lightweight detection result."""
        detection = LightweightDetectionResult(
            class_name="car",
            confidence=0.85,
            bbox_area=40000,
            frame_width=1920,
            frame_height=1080,
        )

        assert detection.class_name == "car"
        assert detection.confidence == 0.85
        assert detection.bbox_area == 40000
        assert detection.frame_width == 1920
        assert detection.frame_height == 1080

    def test_area_ratio_calculation(self):
        """Test bounding box area ratio (bbox area / frame area)."""
        detection = LightweightDetectionResult(
            class_name="car",
            confidence=0.85,
            bbox_area=40000,
            frame_width=1920,
            frame_height=1080,
        )

        expected_ratio = 40000 / (1920 * 1080)

        # Same absolute tolerance as before, but with a readable failure diff.
        assert detection.area_ratio() == pytest.approx(expected_ratio, abs=1e-6)

    def test_meets_threshold(self):
        """Test combined confidence / area-ratio threshold checking."""
        detection = LightweightDetectionResult(
            class_name="car",
            confidence=0.85,
            bbox_area=40000,
            frame_width=1920,
            frame_height=1080,
        )

        # area_ratio here is ~0.0193, so 0.01 passes and 0.1 fails.
        assert detection.meets_threshold(confidence=0.8, area_ratio=0.01) is True
        assert detection.meets_threshold(confidence=0.9, area_ratio=0.01) is False
        assert detection.meets_threshold(confidence=0.8, area_ratio=0.1) is False


class TestDetectionSession:
    """Test DetectionSession data structure."""

    def test_creation(self):
        """Test creating detection session; collections start empty."""
        session = DetectionSession(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001",
        )

        assert session.session_id == "session_123"
        assert session.camera_id == "camera_001"
        assert session.display_id == "display_001"
        assert session.detections == []
        assert session.metadata == {}

    def test_add_detection(self):
        """Test adding detection to session."""
        session = DetectionSession(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001",
        )
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345,
        )

        session.add_detection(detection)

        assert len(session.detections) == 1
        assert session.detections[0] == detection

    def test_get_latest_detection(self):
        """Test getting latest detection."""
        session = DetectionSession(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001",
        )

        # Add multiple detections
        bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection1 = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox1,
            track_id=12345,
            timestamp=1640995200000,
        )
        bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
        detection2 = DetectionResult(
            class_name="car",
            confidence=0.90,
            bbox=bbox2,
            track_id=12345,
            timestamp=1640995300000,
        )

        # NOTE(review): detections are added in timestamp order, so this does
        # not distinguish "latest by timestamp" from "last appended".
        session.add_detection(detection1)
        session.add_detection(detection2)

        latest = session.get_latest_detection()

        assert latest == detection2  # Should be the one with later timestamp

    def test_get_detections_by_class(self):
        """Test filtering detections by class name."""
        session = DetectionSession(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001",
        )
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        car_detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345,
        )
        truck_detection = DetectionResult(
            class_name="truck",
            confidence=0.80,
            bbox=bbox,
            track_id=54321,
        )

        session.add_detection(car_detection)
        session.add_detection(truck_detection)

        car_detections = session.get_detections_by_class("car")
        assert len(car_detections) == 1
        assert car_detections[0] == car_detection

        truck_detections = session.get_detections_by_class("truck")
        assert len(truck_detections) == 1
        assert truck_detections[0] == truck_detection


class TestTrackValidationResult:
    """Test TrackValidationResult data structure."""

    def test_creation(self):
        """Test creating track validation result."""
        result = TrackValidationResult(
            stable_tracks=[101, 102, 103],
            current_tracks=[101, 102, 104, 105],
            newly_stable=[103],
            lost_tracks=[106],
        )

        assert result.stable_tracks == [101, 102, 103]
        assert result.current_tracks == [101, 102, 104, 105]
        assert result.newly_stable == [103]
        assert result.lost_tracks == [106]

    def test_has_stable_tracks(self):
        """Test checking for stable tracks."""
        result = TrackValidationResult(
            stable_tracks=[101, 102],
            current_tracks=[101, 102, 103],
        )
        assert result.has_stable_tracks() is True

        result_empty = TrackValidationResult(
            stable_tracks=[],
            current_tracks=[101, 102, 103],
        )
        assert result_empty.has_stable_tracks() is False

    def test_get_stats(self):
        """Test getting validation statistics."""
        result = TrackValidationResult(
            stable_tracks=[101, 102, 103],
            current_tracks=[101, 102, 104, 105],
            newly_stable=[103],
            lost_tracks=[106],
        )

        stats = result.get_stats()

        assert stats["stable_count"] == 3
        assert stats["current_count"] == 4
        assert stats["newly_stable_count"] == 1
        assert stats["lost_count"] == 1
        assert stats["stability_ratio"] == pytest.approx(3 / 4)  # stable/current

    def test_is_track_stable(self):
        """Test checking whether a specific track id is stable."""
        result = TrackValidationResult(
            stable_tracks=[101, 102, 103],
            current_tracks=[101, 102, 104, 105],
        )

        assert result.is_track_stable(101) is True
        assert result.is_track_stable(102) is True
        assert result.is_track_stable(104) is False
        assert result.is_track_stable(999) is False