Refactor: PHASE 8: Testing & Integration

This commit is contained in:
ziesorx 2025-09-12 18:55:23 +07:00
parent af34f4fd08
commit 9e8c6804a7
32 changed files with 17128 additions and 0 deletions

View file

@ -0,0 +1,479 @@
"""
Unit tests for detection result data structures.
"""
import pytest
from dataclasses import asdict
import numpy as np
from detector_worker.detection.detection_result import (
BoundingBox,
DetectionResult,
LightweightDetectionResult,
DetectionSession,
TrackValidationResult
)
class TestBoundingBox:
"""Test BoundingBox data structure."""
def test_creation_from_coordinates(self):
"""Test creating bounding box from coordinates."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
assert bbox.x1 == 100
assert bbox.y1 == 200
assert bbox.x2 == 300
assert bbox.y2 == 400
def test_creation_from_list(self):
"""Test creating bounding box from list."""
coords = [100, 200, 300, 400]
bbox = BoundingBox.from_list(coords)
assert bbox.x1 == 100
assert bbox.y1 == 200
assert bbox.x2 == 300
assert bbox.y2 == 400
def test_creation_from_invalid_list(self):
"""Test error handling for invalid list."""
with pytest.raises(ValueError):
BoundingBox.from_list([100, 200, 300]) # Too few elements
def test_to_list(self):
"""Test converting bounding box to list."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
coords = bbox.to_list()
assert coords == [100, 200, 300, 400]
def test_area_calculation(self):
"""Test area calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
area = bbox.area()
expected_area = (300 - 100) * (400 - 200) # 200 * 200 = 40000
assert area == expected_area
def test_area_zero_for_invalid_bbox(self):
"""Test area is zero for invalid bounding box."""
# x2 <= x1
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
assert bbox.area() == 0
# y2 <= y1
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
assert bbox.area() == 0
def test_width_height(self):
"""Test width and height properties."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
assert bbox.width() == 200
assert bbox.height() == 200
def test_center_point(self):
"""Test center point calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
center = bbox.center()
assert center == (200, 300) # (x1+x2)/2, (y1+y2)/2
def test_is_valid(self):
"""Test bounding box validation."""
# Valid bbox
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
assert bbox.is_valid() is True
# Invalid bbox (x2 <= x1)
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
assert bbox.is_valid() is False
# Invalid bbox (y2 <= y1)
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
assert bbox.is_valid() is False
def test_intersection(self):
"""Test bounding box intersection."""
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
intersection = bbox1.intersection(bbox2)
assert intersection.x1 == 200
assert intersection.y1 == 200
assert intersection.x2 == 300
assert intersection.y2 == 300
def test_no_intersection(self):
"""Test no intersection between bounding boxes."""
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
intersection = bbox1.intersection(bbox2)
assert intersection.is_valid() is False
def test_union(self):
"""Test bounding box union."""
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
union = bbox1.union(bbox2)
assert union.x1 == 100
assert union.y1 == 100
assert union.x2 == 400
assert union.y2 == 400
def test_iou_calculation(self):
"""Test IoU (Intersection over Union) calculation."""
# Perfect overlap
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
assert bbox1.iou(bbox2) == 1.0
# No overlap
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
assert bbox1.iou(bbox2) == 0.0
# Partial overlap
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
# Intersection area: 100x100 = 10000
# Union area: 200x200 + 200x200 - 10000 = 30000
# IoU = 10000/30000 = 1/3
expected_iou = 1.0 / 3.0
assert abs(bbox1.iou(bbox2) - expected_iou) < 1e-6
class TestDetectionResult:
"""Test DetectionResult data structure."""
def test_creation_with_required_fields(self):
"""Test creating detection result with required fields."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
assert detection.class_name == "car"
assert detection.confidence == 0.85
assert detection.bbox == bbox
assert detection.track_id == 12345
def test_creation_with_all_fields(self):
"""Test creating detection result with all fields."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345,
model_id="yolo_v8",
timestamp=1640995200000,
branch_results={"brand": "Toyota"}
)
assert detection.model_id == "yolo_v8"
assert detection.timestamp == 1640995200000
assert detection.branch_results == {"brand": "Toyota"}
def test_creation_from_dict(self):
"""Test creating detection result from dictionary."""
data = {
"class": "car",
"confidence": 0.85,
"bbox": [100, 200, 300, 400],
"id": 12345,
"model_id": "yolo_v8",
"timestamp": 1640995200000
}
detection = DetectionResult.from_dict(data)
assert detection.class_name == "car"
assert detection.confidence == 0.85
assert detection.bbox.to_list() == [100, 200, 300, 400]
assert detection.track_id == 12345
def test_to_dict(self):
"""Test converting detection result to dictionary."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
data = detection.to_dict()
assert data["class"] == "car"
assert data["confidence"] == 0.85
assert data["bbox"] == [100, 200, 300, 400]
assert data["id"] == 12345
def test_is_valid_detection(self):
"""Test detection validation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
# Valid detection
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
assert detection.is_valid() is True
# Invalid confidence (too low)
detection = DetectionResult(
class_name="car",
confidence=-0.1,
bbox=bbox,
track_id=12345
)
assert detection.is_valid() is False
# Invalid confidence (too high)
detection = DetectionResult(
class_name="car",
confidence=1.5,
bbox=bbox,
track_id=12345
)
assert detection.is_valid() is False
# Invalid bounding box
invalid_bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=invalid_bbox,
track_id=12345
)
assert detection.is_valid() is False
class TestLightweightDetectionResult:
"""Test LightweightDetectionResult data structure."""
def test_creation(self):
"""Test creating lightweight detection result."""
detection = LightweightDetectionResult(
class_name="car",
confidence=0.85,
bbox_area=40000,
frame_width=1920,
frame_height=1080
)
assert detection.class_name == "car"
assert detection.confidence == 0.85
assert detection.bbox_area == 40000
assert detection.frame_width == 1920
assert detection.frame_height == 1080
def test_area_ratio_calculation(self):
"""Test bounding box area ratio calculation."""
detection = LightweightDetectionResult(
class_name="car",
confidence=0.85,
bbox_area=40000,
frame_width=1920,
frame_height=1080
)
expected_ratio = 40000 / (1920 * 1080)
assert abs(detection.area_ratio() - expected_ratio) < 1e-6
def test_meets_threshold(self):
"""Test threshold checking."""
detection = LightweightDetectionResult(
class_name="car",
confidence=0.85,
bbox_area=40000,
frame_width=1920,
frame_height=1080
)
assert detection.meets_threshold(confidence=0.8, area_ratio=0.01) is True
assert detection.meets_threshold(confidence=0.9, area_ratio=0.01) is False
assert detection.meets_threshold(confidence=0.8, area_ratio=0.1) is False
class TestDetectionSession:
"""Test DetectionSession data structure."""
def test_creation(self):
"""Test creating detection session."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
assert session.session_id == "session_123"
assert session.camera_id == "camera_001"
assert session.display_id == "display_001"
assert session.detections == []
assert session.metadata == {}
def test_add_detection(self):
"""Test adding detection to session."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
session.add_detection(detection)
assert len(session.detections) == 1
assert session.detections[0] == detection
def test_get_latest_detection(self):
"""Test getting latest detection."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
# Add multiple detections
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=12345,
timestamp=1640995200000
)
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
detection2 = DetectionResult(
class_name="car",
confidence=0.90,
bbox=bbox2,
track_id=12345,
timestamp=1640995300000
)
session.add_detection(detection1)
session.add_detection(detection2)
latest = session.get_latest_detection()
assert latest == detection2 # Should be the one with later timestamp
def test_get_detections_by_class(self):
"""Test filtering detections by class."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
car_detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
truck_detection = DetectionResult(
class_name="truck",
confidence=0.80,
bbox=bbox,
track_id=54321
)
session.add_detection(car_detection)
session.add_detection(truck_detection)
car_detections = session.get_detections_by_class("car")
assert len(car_detections) == 1
assert car_detections[0] == car_detection
truck_detections = session.get_detections_by_class("truck")
assert len(truck_detections) == 1
assert truck_detections[0] == truck_detection
class TestTrackValidationResult:
"""Test TrackValidationResult data structure."""
def test_creation(self):
"""Test creating track validation result."""
result = TrackValidationResult(
stable_tracks=[101, 102, 103],
current_tracks=[101, 102, 104, 105],
newly_stable=[103],
lost_tracks=[106]
)
assert result.stable_tracks == [101, 102, 103]
assert result.current_tracks == [101, 102, 104, 105]
assert result.newly_stable == [103]
assert result.lost_tracks == [106]
def test_has_stable_tracks(self):
"""Test checking for stable tracks."""
result = TrackValidationResult(
stable_tracks=[101, 102],
current_tracks=[101, 102, 103]
)
assert result.has_stable_tracks() is True
result_empty = TrackValidationResult(
stable_tracks=[],
current_tracks=[101, 102, 103]
)
assert result_empty.has_stable_tracks() is False
def test_get_stats(self):
"""Test getting validation statistics."""
result = TrackValidationResult(
stable_tracks=[101, 102, 103],
current_tracks=[101, 102, 104, 105],
newly_stable=[103],
lost_tracks=[106]
)
stats = result.get_stats()
assert stats["stable_count"] == 3
assert stats["current_count"] == 4
assert stats["newly_stable_count"] == 1
assert stats["lost_count"] == 1
assert stats["stability_ratio"] == 3/4 # stable/current
def test_is_track_stable(self):
"""Test checking if specific track is stable."""
result = TrackValidationResult(
stable_tracks=[101, 102, 103],
current_tracks=[101, 102, 104, 105]
)
assert result.is_track_stable(101) is True
assert result.is_track_stable(102) is True
assert result.is_track_stable(104) is False
assert result.is_track_stable(999) is False