Refactor: PHASE 8: Testing & Integration
This commit is contained in:
parent
af34f4fd08
commit
9e8c6804a7
32 changed files with 17128 additions and 0 deletions
479
tests/unit/detection/test_detection_result.py
Normal file
479
tests/unit/detection/test_detection_result.py
Normal file
|
@ -0,0 +1,479 @@
|
|||
"""
|
||||
Unit tests for detection result data structures.
|
||||
"""
|
||||
import pytest
|
||||
from dataclasses import asdict
|
||||
import numpy as np
|
||||
|
||||
from detector_worker.detection.detection_result import (
|
||||
BoundingBox,
|
||||
DetectionResult,
|
||||
LightweightDetectionResult,
|
||||
DetectionSession,
|
||||
TrackValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestBoundingBox:
|
||||
"""Test BoundingBox data structure."""
|
||||
|
||||
def test_creation_from_coordinates(self):
|
||||
"""Test creating bounding box from coordinates."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
assert bbox.x1 == 100
|
||||
assert bbox.y1 == 200
|
||||
assert bbox.x2 == 300
|
||||
assert bbox.y2 == 400
|
||||
|
||||
def test_creation_from_list(self):
|
||||
"""Test creating bounding box from list."""
|
||||
coords = [100, 200, 300, 400]
|
||||
bbox = BoundingBox.from_list(coords)
|
||||
|
||||
assert bbox.x1 == 100
|
||||
assert bbox.y1 == 200
|
||||
assert bbox.x2 == 300
|
||||
assert bbox.y2 == 400
|
||||
|
||||
def test_creation_from_invalid_list(self):
|
||||
"""Test error handling for invalid list."""
|
||||
with pytest.raises(ValueError):
|
||||
BoundingBox.from_list([100, 200, 300]) # Too few elements
|
||||
|
||||
def test_to_list(self):
|
||||
"""Test converting bounding box to list."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
coords = bbox.to_list()
|
||||
|
||||
assert coords == [100, 200, 300, 400]
|
||||
|
||||
def test_area_calculation(self):
|
||||
"""Test area calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
area = bbox.area()
|
||||
|
||||
expected_area = (300 - 100) * (400 - 200) # 200 * 200 = 40000
|
||||
assert area == expected_area
|
||||
|
||||
def test_area_zero_for_invalid_bbox(self):
|
||||
"""Test area is zero for invalid bounding box."""
|
||||
# x2 <= x1
|
||||
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
|
||||
assert bbox.area() == 0
|
||||
|
||||
# y2 <= y1
|
||||
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
|
||||
assert bbox.area() == 0
|
||||
|
||||
def test_width_height(self):
|
||||
"""Test width and height properties."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
assert bbox.width() == 200
|
||||
assert bbox.height() == 200
|
||||
|
||||
def test_center_point(self):
|
||||
"""Test center point calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
center = bbox.center()
|
||||
|
||||
assert center == (200, 300) # (x1+x2)/2, (y1+y2)/2
|
||||
|
||||
def test_is_valid(self):
|
||||
"""Test bounding box validation."""
|
||||
# Valid bbox
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
assert bbox.is_valid() is True
|
||||
|
||||
# Invalid bbox (x2 <= x1)
|
||||
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
|
||||
assert bbox.is_valid() is False
|
||||
|
||||
# Invalid bbox (y2 <= y1)
|
||||
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
|
||||
assert bbox.is_valid() is False
|
||||
|
||||
def test_intersection(self):
|
||||
"""Test bounding box intersection."""
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
|
||||
|
||||
intersection = bbox1.intersection(bbox2)
|
||||
|
||||
assert intersection.x1 == 200
|
||||
assert intersection.y1 == 200
|
||||
assert intersection.x2 == 300
|
||||
assert intersection.y2 == 300
|
||||
|
||||
def test_no_intersection(self):
|
||||
"""Test no intersection between bounding boxes."""
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
|
||||
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
|
||||
|
||||
intersection = bbox1.intersection(bbox2)
|
||||
|
||||
assert intersection.is_valid() is False
|
||||
|
||||
def test_union(self):
|
||||
"""Test bounding box union."""
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
|
||||
|
||||
union = bbox1.union(bbox2)
|
||||
|
||||
assert union.x1 == 100
|
||||
assert union.y1 == 100
|
||||
assert union.x2 == 400
|
||||
assert union.y2 == 400
|
||||
|
||||
def test_iou_calculation(self):
|
||||
"""Test IoU (Intersection over Union) calculation."""
|
||||
# Perfect overlap
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
assert bbox1.iou(bbox2) == 1.0
|
||||
|
||||
# No overlap
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
|
||||
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
|
||||
assert bbox1.iou(bbox2) == 0.0
|
||||
|
||||
# Partial overlap
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
|
||||
|
||||
# Intersection area: 100x100 = 10000
|
||||
# Union area: 200x200 + 200x200 - 10000 = 30000
|
||||
# IoU = 10000/30000 = 1/3
|
||||
expected_iou = 1.0 / 3.0
|
||||
assert abs(bbox1.iou(bbox2) - expected_iou) < 1e-6
|
||||
|
||||
|
||||
class TestDetectionResult:
|
||||
"""Test DetectionResult data structure."""
|
||||
|
||||
def test_creation_with_required_fields(self):
|
||||
"""Test creating detection result with required fields."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
|
||||
assert detection.class_name == "car"
|
||||
assert detection.confidence == 0.85
|
||||
assert detection.bbox == bbox
|
||||
assert detection.track_id == 12345
|
||||
|
||||
def test_creation_with_all_fields(self):
|
||||
"""Test creating detection result with all fields."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345,
|
||||
model_id="yolo_v8",
|
||||
timestamp=1640995200000,
|
||||
branch_results={"brand": "Toyota"}
|
||||
)
|
||||
|
||||
assert detection.model_id == "yolo_v8"
|
||||
assert detection.timestamp == 1640995200000
|
||||
assert detection.branch_results == {"brand": "Toyota"}
|
||||
|
||||
def test_creation_from_dict(self):
|
||||
"""Test creating detection result from dictionary."""
|
||||
data = {
|
||||
"class": "car",
|
||||
"confidence": 0.85,
|
||||
"bbox": [100, 200, 300, 400],
|
||||
"id": 12345,
|
||||
"model_id": "yolo_v8",
|
||||
"timestamp": 1640995200000
|
||||
}
|
||||
|
||||
detection = DetectionResult.from_dict(data)
|
||||
|
||||
assert detection.class_name == "car"
|
||||
assert detection.confidence == 0.85
|
||||
assert detection.bbox.to_list() == [100, 200, 300, 400]
|
||||
assert detection.track_id == 12345
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting detection result to dictionary."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
|
||||
data = detection.to_dict()
|
||||
|
||||
assert data["class"] == "car"
|
||||
assert data["confidence"] == 0.85
|
||||
assert data["bbox"] == [100, 200, 300, 400]
|
||||
assert data["id"] == 12345
|
||||
|
||||
def test_is_valid_detection(self):
|
||||
"""Test detection validation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
# Valid detection
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
assert detection.is_valid() is True
|
||||
|
||||
# Invalid confidence (too low)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=-0.1,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
assert detection.is_valid() is False
|
||||
|
||||
# Invalid confidence (too high)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=1.5,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
assert detection.is_valid() is False
|
||||
|
||||
# Invalid bounding box
|
||||
invalid_bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=invalid_bbox,
|
||||
track_id=12345
|
||||
)
|
||||
assert detection.is_valid() is False
|
||||
|
||||
|
||||
class TestLightweightDetectionResult:
|
||||
"""Test LightweightDetectionResult data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating lightweight detection result."""
|
||||
detection = LightweightDetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox_area=40000,
|
||||
frame_width=1920,
|
||||
frame_height=1080
|
||||
)
|
||||
|
||||
assert detection.class_name == "car"
|
||||
assert detection.confidence == 0.85
|
||||
assert detection.bbox_area == 40000
|
||||
assert detection.frame_width == 1920
|
||||
assert detection.frame_height == 1080
|
||||
|
||||
def test_area_ratio_calculation(self):
|
||||
"""Test bounding box area ratio calculation."""
|
||||
detection = LightweightDetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox_area=40000,
|
||||
frame_width=1920,
|
||||
frame_height=1080
|
||||
)
|
||||
|
||||
expected_ratio = 40000 / (1920 * 1080)
|
||||
assert abs(detection.area_ratio() - expected_ratio) < 1e-6
|
||||
|
||||
def test_meets_threshold(self):
|
||||
"""Test threshold checking."""
|
||||
detection = LightweightDetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox_area=40000,
|
||||
frame_width=1920,
|
||||
frame_height=1080
|
||||
)
|
||||
|
||||
assert detection.meets_threshold(confidence=0.8, area_ratio=0.01) is True
|
||||
assert detection.meets_threshold(confidence=0.9, area_ratio=0.01) is False
|
||||
assert detection.meets_threshold(confidence=0.8, area_ratio=0.1) is False
|
||||
|
||||
|
||||
class TestDetectionSession:
|
||||
"""Test DetectionSession data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating detection session."""
|
||||
session = DetectionSession(
|
||||
session_id="session_123",
|
||||
camera_id="camera_001",
|
||||
display_id="display_001"
|
||||
)
|
||||
|
||||
assert session.session_id == "session_123"
|
||||
assert session.camera_id == "camera_001"
|
||||
assert session.display_id == "display_001"
|
||||
assert session.detections == []
|
||||
assert session.metadata == {}
|
||||
|
||||
def test_add_detection(self):
|
||||
"""Test adding detection to session."""
|
||||
session = DetectionSession(
|
||||
session_id="session_123",
|
||||
camera_id="camera_001",
|
||||
display_id="display_001"
|
||||
)
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
|
||||
session.add_detection(detection)
|
||||
|
||||
assert len(session.detections) == 1
|
||||
assert session.detections[0] == detection
|
||||
|
||||
def test_get_latest_detection(self):
|
||||
"""Test getting latest detection."""
|
||||
session = DetectionSession(
|
||||
session_id="session_123",
|
||||
camera_id="camera_001",
|
||||
display_id="display_001"
|
||||
)
|
||||
|
||||
# Add multiple detections
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox1,
|
||||
track_id=12345,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.90,
|
||||
bbox=bbox2,
|
||||
track_id=12345,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
session.add_detection(detection1)
|
||||
session.add_detection(detection2)
|
||||
|
||||
latest = session.get_latest_detection()
|
||||
assert latest == detection2 # Should be the one with later timestamp
|
||||
|
||||
def test_get_detections_by_class(self):
|
||||
"""Test filtering detections by class."""
|
||||
session = DetectionSession(
|
||||
session_id="session_123",
|
||||
camera_id="camera_001",
|
||||
display_id="display_001"
|
||||
)
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
car_detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=12345
|
||||
)
|
||||
|
||||
truck_detection = DetectionResult(
|
||||
class_name="truck",
|
||||
confidence=0.80,
|
||||
bbox=bbox,
|
||||
track_id=54321
|
||||
)
|
||||
|
||||
session.add_detection(car_detection)
|
||||
session.add_detection(truck_detection)
|
||||
|
||||
car_detections = session.get_detections_by_class("car")
|
||||
assert len(car_detections) == 1
|
||||
assert car_detections[0] == car_detection
|
||||
|
||||
truck_detections = session.get_detections_by_class("truck")
|
||||
assert len(truck_detections) == 1
|
||||
assert truck_detections[0] == truck_detection
|
||||
|
||||
|
||||
class TestTrackValidationResult:
|
||||
"""Test TrackValidationResult data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating track validation result."""
|
||||
result = TrackValidationResult(
|
||||
stable_tracks=[101, 102, 103],
|
||||
current_tracks=[101, 102, 104, 105],
|
||||
newly_stable=[103],
|
||||
lost_tracks=[106]
|
||||
)
|
||||
|
||||
assert result.stable_tracks == [101, 102, 103]
|
||||
assert result.current_tracks == [101, 102, 104, 105]
|
||||
assert result.newly_stable == [103]
|
||||
assert result.lost_tracks == [106]
|
||||
|
||||
def test_has_stable_tracks(self):
|
||||
"""Test checking for stable tracks."""
|
||||
result = TrackValidationResult(
|
||||
stable_tracks=[101, 102],
|
||||
current_tracks=[101, 102, 103]
|
||||
)
|
||||
|
||||
assert result.has_stable_tracks() is True
|
||||
|
||||
result_empty = TrackValidationResult(
|
||||
stable_tracks=[],
|
||||
current_tracks=[101, 102, 103]
|
||||
)
|
||||
|
||||
assert result_empty.has_stable_tracks() is False
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test getting validation statistics."""
|
||||
result = TrackValidationResult(
|
||||
stable_tracks=[101, 102, 103],
|
||||
current_tracks=[101, 102, 104, 105],
|
||||
newly_stable=[103],
|
||||
lost_tracks=[106]
|
||||
)
|
||||
|
||||
stats = result.get_stats()
|
||||
|
||||
assert stats["stable_count"] == 3
|
||||
assert stats["current_count"] == 4
|
||||
assert stats["newly_stable_count"] == 1
|
||||
assert stats["lost_count"] == 1
|
||||
assert stats["stability_ratio"] == 3/4 # stable/current
|
||||
|
||||
def test_is_track_stable(self):
|
||||
"""Test checking if specific track is stable."""
|
||||
result = TrackValidationResult(
|
||||
stable_tracks=[101, 102, 103],
|
||||
current_tracks=[101, 102, 104, 105]
|
||||
)
|
||||
|
||||
assert result.is_track_stable(101) is True
|
||||
assert result.is_track_stable(102) is True
|
||||
assert result.is_track_stable(104) is False
|
||||
assert result.is_track_stable(999) is False
|
Loading…
Add table
Add a link
Reference in a new issue