"""Unit tests for detection result data structures."""

import pytest
|
|
from dataclasses import asdict
|
|
import numpy as np
|
|
|
|
from detector_worker.detection.detection_result import (
|
|
BoundingBox,
|
|
DetectionResult,
|
|
LightweightDetectionResult,
|
|
DetectionSession,
|
|
TrackValidationResult
|
|
)
|
|
|
|
|
|
class TestBoundingBox:
    """Test BoundingBox data structure."""

    def test_creation_from_coordinates(self):
        """Test creating bounding box from coordinates."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        assert bbox.x1 == 100
        assert bbox.y1 == 200
        assert bbox.x2 == 300
        assert bbox.y2 == 400

    def test_creation_from_list(self):
        """Test creating bounding box from an ``[x1, y1, x2, y2]`` list."""
        coords = [100, 200, 300, 400]
        bbox = BoundingBox.from_list(coords)

        assert bbox.x1 == 100
        assert bbox.y1 == 200
        assert bbox.x2 == 300
        assert bbox.y2 == 400

    def test_creation_from_invalid_list(self):
        """Test that ``from_list`` raises ValueError for too few coordinates."""
        with pytest.raises(ValueError):
            BoundingBox.from_list([100, 200, 300])  # Too few elements

    def test_to_list(self):
        """Test converting bounding box to list."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        coords = bbox.to_list()

        assert coords == [100, 200, 300, 400]

    def test_area_calculation(self):
        """Test area calculation."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        area = bbox.area()

        expected_area = (300 - 100) * (400 - 200)  # 200 * 200 = 40000
        assert area == expected_area

    def test_area_zero_for_invalid_bbox(self):
        """Test area is zero for invalid bounding box."""
        # x2 <= x1
        bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
        assert bbox.area() == 0

        # y2 <= y1
        bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
        assert bbox.area() == 0

    def test_width_height(self):
        """Test the ``width()`` and ``height()`` methods."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        assert bbox.width() == 200
        assert bbox.height() == 200

    def test_center_point(self):
        """Test center point calculation."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        center = bbox.center()

        assert center == (200, 300)  # ((x1 + x2) / 2, (y1 + y2) / 2)

    def test_is_valid(self):
        """Test bounding box validation."""
        # Valid bbox
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        assert bbox.is_valid() is True

        # Invalid bbox (x2 <= x1)
        bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
        assert bbox.is_valid() is False

        # Invalid bbox (y2 <= y1)
        bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
        assert bbox.is_valid() is False

    def test_intersection(self):
        """Test bounding box intersection."""
        bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)

        intersection = bbox1.intersection(bbox2)

        assert intersection.x1 == 200
        assert intersection.y1 == 200
        assert intersection.x2 == 300
        assert intersection.y2 == 300

    def test_no_intersection(self):
        """Test that disjoint boxes yield an invalid intersection box."""
        bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
        bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)

        intersection = bbox1.intersection(bbox2)

        assert intersection.is_valid() is False

    def test_union(self):
        """Test bounding box union (the enclosing box of both inputs)."""
        bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)

        union = bbox1.union(bbox2)

        assert union.x1 == 100
        assert union.y1 == 100
        assert union.x2 == 400
        assert union.y2 == 400

    def test_iou_calculation(self):
        """Test IoU (Intersection over Union) calculation."""
        # IoU is a float ratio; compare with an explicit absolute tolerance
        # (same 1e-6 bound as before) instead of exact equality / a
        # hand-rolled abs() check.
        tolerance = 1e-6

        # Perfect overlap
        bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        bbox2 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        assert bbox1.iou(bbox2) == pytest.approx(1.0, abs=tolerance)

        # No overlap
        bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
        bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
        assert bbox1.iou(bbox2) == pytest.approx(0.0, abs=tolerance)

        # Partial overlap
        bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
        bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)

        # Intersection area: 100x100 = 10000
        # Union area: 200x200 + 200x200 - 10000 = 70000
        # IoU = 10000 / 70000 = 1/7
        expected_iou = 1.0 / 7.0
        assert bbox1.iou(bbox2) == pytest.approx(expected_iou, abs=tolerance)

class TestDetectionResult:
    """Tests for the DetectionResult data structure."""

    @staticmethod
    def _car(bbox, confidence=0.85, **extra):
        """Build a "car" DetectionResult on track 12345.

        ``extra`` is forwarded verbatim as additional keyword arguments so
        tests can supply optional fields (model_id, timestamp, ...).
        """
        return DetectionResult(
            class_name="car",
            confidence=confidence,
            bbox=bbox,
            track_id=12345,
            **extra,
        )

    def test_creation_with_required_fields(self):
        """Only mandatory fields are given; each one is stored as passed."""
        box = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        result = self._car(box)

        assert result.class_name == "car"
        assert result.confidence == 0.85
        assert result.bbox == box
        assert result.track_id == 12345

    def test_creation_with_all_fields(self):
        """Optional fields are stored alongside the mandatory ones."""
        box = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        result = self._car(
            box,
            model_id="yolo_v8",
            timestamp=1640995200000,
            branch_results={"brand": "Toyota"},
        )

        assert result.model_id == "yolo_v8"
        assert result.timestamp == 1640995200000
        assert result.branch_results == {"brand": "Toyota"}

    def test_creation_from_dict(self):
        """``from_dict`` maps the "class"/"bbox"/"id" keys onto fields."""
        payload = {
            "class": "car",
            "confidence": 0.85,
            "bbox": [100, 200, 300, 400],
            "id": 12345,
            "model_id": "yolo_v8",
            "timestamp": 1640995200000,
        }

        result = DetectionResult.from_dict(payload)

        assert result.class_name == "car"
        assert result.confidence == 0.85
        assert result.bbox.to_list() == [100, 200, 300, 400]
        assert result.track_id == 12345

    def test_to_dict(self):
        """``to_dict`` emits the "class"/"bbox"/"id" wire keys."""
        result = self._car(BoundingBox(x1=100, y1=200, x2=300, y2=400))

        serialized = result.to_dict()

        assert serialized["class"] == "car"
        assert serialized["confidence"] == 0.85
        assert serialized["bbox"] == [100, 200, 300, 400]
        assert serialized["id"] == 12345

    def test_is_valid_detection(self):
        """A detection is valid only with a sane confidence and a valid box."""
        good_box = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        assert self._car(good_box).is_valid() is True

        # Confidence below zero or above one must be rejected.
        for bad_confidence in (-0.1, 1.5):
            assert self._car(good_box, confidence=bad_confidence).is_valid() is False

        # An inverted box (x2 < x1) is rejected even with a valid confidence.
        inverted_box = BoundingBox(x1=300, y1=200, x2=100, y2=400)
        assert self._car(inverted_box).is_valid() is False

class TestLightweightDetectionResult:
    """Tests for the LightweightDetectionResult data structure."""

    @staticmethod
    def _sample():
        """Return the car detection (bbox_area 40000, 1920x1080 frame) shared by all tests."""
        return LightweightDetectionResult(
            class_name="car",
            confidence=0.85,
            bbox_area=40000,
            frame_width=1920,
            frame_height=1080,
        )

    def test_creation(self):
        """Every constructor argument is exposed as an attribute."""
        det = self._sample()

        assert det.class_name == "car"
        assert det.confidence == 0.85
        assert det.bbox_area == 40000
        assert det.frame_width == 1920
        assert det.frame_height == 1080

    def test_area_ratio_calculation(self):
        """``area_ratio()`` is bbox_area divided by the frame's width * height."""
        det = self._sample()
        frame_pixels = 1920 * 1080

        assert abs(det.area_ratio() - 40000 / frame_pixels) < 1e-6

    def test_meets_threshold(self):
        """Both the confidence and the area-ratio threshold must be satisfied."""
        det = self._sample()

        # With area_ratio() ~= 0.0193 for this sample, 0.01 passes and 0.1 fails.
        assert det.meets_threshold(confidence=0.8, area_ratio=0.01) is True
        assert det.meets_threshold(confidence=0.9, area_ratio=0.01) is False
        assert det.meets_threshold(confidence=0.8, area_ratio=0.1) is False

class TestDetectionSession:
    """Tests for the DetectionSession container."""

    @staticmethod
    def _new_session():
        """Create the session (same ids every time) used by each test."""
        return DetectionSession(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001",
        )

    def test_creation(self):
        """Identifiers are stored; detections and metadata start out empty."""
        session = self._new_session()

        assert session.session_id == "session_123"
        assert session.camera_id == "camera_001"
        assert session.display_id == "display_001"
        assert session.detections == []
        assert session.metadata == {}

    def test_add_detection(self):
        """An added detection shows up in ``session.detections``."""
        session = self._new_session()
        det = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=BoundingBox(x1=100, y1=200, x2=300, y2=400),
            track_id=12345,
        )

        session.add_detection(det)

        assert len(session.detections) == 1
        assert session.detections[0] == det

    def test_get_latest_detection(self):
        """``get_latest_detection`` returns the newer of two detections."""
        session = self._new_session()
        earlier = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=BoundingBox(x1=100, y1=200, x2=300, y2=400),
            track_id=12345,
            timestamp=1640995200000,
        )
        later = DetectionResult(
            class_name="car",
            confidence=0.90,
            bbox=BoundingBox(x1=150, y1=250, x2=350, y2=450),
            track_id=12345,
            timestamp=1640995300000,
        )

        for det in (earlier, later):
            session.add_detection(det)

        # NOTE(review): `later` has the newer timestamp AND was added last, so
        # this cannot tell "latest by timestamp" apart from "most recently
        # added". Consider adding the detections out of chronological order.
        assert session.get_latest_detection() == later

    def test_get_detections_by_class(self):
        """Filtering by class name returns only detections of that class."""
        session = self._new_session()
        shared_box = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        car = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=shared_box,
            track_id=12345,
        )
        truck = DetectionResult(
            class_name="truck",
            confidence=0.80,
            bbox=shared_box,
            track_id=54321,
        )

        session.add_detection(car)
        session.add_detection(truck)

        for label, expected in (("car", car), ("truck", truck)):
            matches = session.get_detections_by_class(label)
            assert len(matches) == 1
            assert matches[0] == expected

class TestTrackValidationResult:
    """Tests for the TrackValidationResult summary object."""

    @staticmethod
    def _full_result():
        """Result with all four track lists populated (creation and stats tests)."""
        return TrackValidationResult(
            stable_tracks=[101, 102, 103],
            current_tracks=[101, 102, 104, 105],
            newly_stable=[103],
            lost_tracks=[106],
        )

    def test_creation(self):
        """Each track list is stored exactly as given."""
        result = self._full_result()

        assert result.stable_tracks == [101, 102, 103]
        assert result.current_tracks == [101, 102, 104, 105]
        assert result.newly_stable == [103]
        assert result.lost_tracks == [106]

    def test_has_stable_tracks(self):
        """``has_stable_tracks`` reflects whether stable_tracks is non-empty."""
        with_stable = TrackValidationResult(
            stable_tracks=[101, 102],
            current_tracks=[101, 102, 103],
        )
        without_stable = TrackValidationResult(
            stable_tracks=[],
            current_tracks=[101, 102, 103],
        )

        assert with_stable.has_stable_tracks() is True
        assert without_stable.has_stable_tracks() is False

    def test_get_stats(self):
        """``get_stats`` reports per-list counts and the stable/current ratio."""
        stats = self._full_result().get_stats()

        expected_counts = {
            "stable_count": 3,
            "current_count": 4,
            "newly_stable_count": 1,
            "lost_count": 1,
        }
        for key, count in expected_counts.items():
            assert stats[key] == count

        assert stats["stability_ratio"] == 3 / 4  # stable / current

    def test_is_track_stable(self):
        """Only ids listed in stable_tracks are reported as stable."""
        result = TrackValidationResult(
            stable_tracks=[101, 102, 103],
            current_tracks=[101, 102, 104, 105],
        )

        for track_id in (101, 102):
            assert result.is_track_stable(track_id) is True
        for track_id in (104, 999):
            assert result.is_track_stable(track_id) is False