Refactor: PHASE 8: Testing & Integration

Author: ziesorx, 2025-09-12 18:55:23 +07:00
Commit: 9e8c6804a7 (parent: af34f4fd08)
32 changed files with 17128 additions and 0 deletions

tests/__init__.py (new file)
@@ -0,0 +1 @@
"""Test package for detector worker."""

tests/conftest.py (new file)
@@ -0,0 +1,276 @@
"""
Pytest configuration and shared fixtures for detector worker tests.
"""
import pytest
import tempfile
import os
from unittest.mock import Mock, MagicMock, AsyncMock, patch
from typing import Dict, Any, Generator
# Add the project root to the path so we can import detector_worker
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from detector_worker.core.config import get_config_manager, ConfigurationManager
from detector_worker.core.dependency_injection import get_container, DetectorWorkerContainer
from detector_worker.core.singleton_managers import (
ModelStateManager, StreamStateManager, SessionStateManager,
CacheStateManager, CameraStateManager, PipelineStateManager
)
@pytest.fixture
def temp_dir():
"""Create a temporary directory for tests."""
with tempfile.TemporaryDirectory() as tmpdir:
yield tmpdir
@pytest.fixture
def mock_config():
"""Mock configuration data for tests."""
return {
"poll_interval_ms": 100,
"max_streams": 5,
"target_fps": 10,
"reconnect_interval_sec": 5,
"max_retries": 3,
"heartbeat_interval": 2,
"session_timeout": 600,
"models_dir": "models",
"log_level": "INFO",
"database": {
"enabled": False,
"host": "localhost",
"port": 5432,
"database": "test_db",
"username": "test_user",
"password": "test_pass",
"schema": "public"
},
"redis": {
"enabled": False,
"host": "localhost",
"port": 6379,
"password": None,
"db": 0
}
}
@pytest.fixture
def mock_detection_result():
"""Mock detection result for tests."""
return {
"class": "car",
"confidence": 0.85,
"bbox": [100, 200, 300, 400],
"id": 12345,
"branch_results": {}
}
@pytest.fixture
def mock_frame():
"""Mock frame data for tests."""
import numpy as np
return np.zeros((480, 640, 3), dtype=np.uint8)
@pytest.fixture
def mock_model_tree():
"""Mock model tree structure for tests."""
return {
"modelId": "test_model_v1",
"modelFile": "test_model.pt",
"multiClass": True,
"expectedClasses": ["car", "truck"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
@pytest.fixture
def mock_pipeline_context():
"""Mock pipeline context for tests."""
return {
"camera_id": "test_camera_001",
"display_id": "display_001",
"session_id": "session_12345",
"timestamp": 1640995200000,
"subscription_id": "sub_001"
}
@pytest.fixture(autouse=True)
def reset_singletons():
"""Reset singleton managers before each test."""
# Clear all singleton state before each test
yield
# Cleanup after test
try:
ModelStateManager().clear_all()
StreamStateManager().clear_all()
SessionStateManager().clear_all()
CacheStateManager().clear_all()
CameraStateManager().clear_all()
PipelineStateManager().clear_all()
except Exception:
pass # Ignore cleanup errors
@pytest.fixture
def isolated_config_manager(temp_dir, mock_config):
"""Create an isolated configuration manager for testing."""
config_file = os.path.join(temp_dir, "test_config.json")
import json
with open(config_file, 'w') as f:
json.dump(mock_config, f)
# Create a fresh ConfigurationManager for testing
from detector_worker.core.config import JsonFileProvider, EnvironmentProvider
manager = ConfigurationManager()
manager._providers.clear() # Remove default providers
manager.add_provider(JsonFileProvider(config_file))
return manager
@pytest.fixture
def mock_websocket():
"""Mock WebSocket for testing."""
    mock_ws = Mock()
    # FastAPI WebSocket methods are coroutines, so use AsyncMock so they can be awaited
    mock_ws.accept = AsyncMock()
    mock_ws.send_text = AsyncMock()
    mock_ws.send_json = AsyncMock()
    mock_ws.receive_text = AsyncMock()
    mock_ws.receive_json = AsyncMock()
    mock_ws.close = AsyncMock()
mock_ws.client_state = Mock()
mock_ws.client_state.DISCONNECTED = False
return mock_ws
@pytest.fixture
def mock_redis_client():
"""Mock Redis client for testing."""
mock_redis = Mock()
mock_redis.get = Mock(return_value=None)
mock_redis.set = Mock(return_value=True)
mock_redis.delete = Mock(return_value=1)
mock_redis.exists = Mock(return_value=0)
mock_redis.expire = Mock(return_value=True)
mock_redis.publish = Mock(return_value=1)
return mock_redis
@pytest.fixture
def mock_database_connection():
"""Mock database connection for testing."""
mock_conn = Mock()
mock_cursor = Mock()
mock_cursor.execute = Mock()
mock_cursor.fetchone = Mock(return_value=None)
mock_cursor.fetchall = Mock(return_value=[])
mock_cursor.fetchmany = Mock(return_value=[])
mock_cursor.rowcount = 1
mock_conn.cursor = Mock(return_value=mock_cursor)
mock_conn.commit = Mock()
mock_conn.rollback = Mock()
mock_conn.close = Mock()
return mock_conn
@pytest.fixture
def mock_yolo_model():
"""Mock YOLO model for testing."""
mock_model = Mock()
# Mock results with boxes
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = Mock()
mock_result.boxes.conf = Mock()
mock_result.boxes.cls = Mock()
mock_result.boxes.id = Mock()
# Mock track method
mock_model.track = Mock(return_value=[mock_result])
mock_model.predict = Mock(return_value=[mock_result])
return mock_model
@pytest.fixture
def sample_detection_data():
"""Sample detection data for testing."""
return [
{
"class": "car",
"confidence": 0.92,
"bbox": [100, 150, 250, 300],
"id": 1001,
"branch_results": {}
},
{
"class": "truck",
"confidence": 0.87,
"bbox": [300, 200, 450, 350],
"id": 1002,
"branch_results": {}
}
]
@pytest.fixture
def sample_session_data():
"""Sample session data for testing."""
return {
"session_id": "session_test_001",
"display_id": "display_test_001",
"camera_id": "camera_test_001",
"created_at": 1640995200.0,
"last_activity": 1640995200.0,
"detection_data": {
"car_brand": "Toyota",
"car_model": "Camry",
"license_plate": "ABC-123"
}
}
# Helper functions for tests
def create_mock_detection_result(class_name: str = "car", confidence: float = 0.85, track_id: int = 1001):
"""Helper function to create mock detection results."""
return {
"class": class_name,
"confidence": confidence,
"bbox": [100, 200, 300, 400],
"id": track_id,
"branch_results": {}
}
def create_mock_regions_dict(detections: list = None):
"""Helper function to create mock regions dictionary."""
if detections is None:
detections = [create_mock_detection_result()]
regions = {}
for det in detections:
regions[det["class"]] = {
"bbox": det["bbox"],
"confidence": det["confidence"],
"detection": det
}
return regions
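
As a quick illustration, a test module could combine these fixtures and helpers along the lines of the sketch below (hypothetical test, not part of this commit; the helpers are assumed to be importable from tests.conftest):

# test_conftest_helpers.py (hypothetical usage sketch)
from tests.conftest import create_mock_detection_result, create_mock_regions_dict


def test_regions_dict_from_sample_data(sample_detection_data, mock_pipeline_context):
    """Fixtures are injected by pytest; the helpers are plain functions."""
    regions = create_mock_regions_dict(sample_detection_data)
    assert set(regions) == {"car", "truck"}
    assert regions["car"]["confidence"] == 0.92
    assert mock_pipeline_context["camera_id"] == "test_camera_001"


def test_single_detection_helper():
    det = create_mock_detection_result(class_name="bus", confidence=0.9, track_id=7)
    assert det["class"] == "bus" and det["id"] == 7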

tests/integration/__init__.py (new file)
@@ -0,0 +1,19 @@
"""
Integration tests for the detector worker application.
This package contains integration tests that verify the interaction
between multiple components and end-to-end workflows.
"""
# Integration test modules
from . import (
test_complete_detection_workflow,
test_websocket_protocol,
test_pipeline_integration
)
__all__ = [
"test_complete_detection_workflow",
"test_websocket_protocol",
"test_pipeline_integration"
]
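
Assuming standard pytest discovery, the integration package can be run on its own with an invocation along these lines (illustrative; the project's actual test-runner configuration may differ):

# run_integration_tests.py (hypothetical helper, not part of this commit)
import sys

import pytest

# Run only the integration package with verbose output and propagate the exit code.
sys.exit(pytest.main(["tests/integration", "-v"]))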

tests/integration/test_complete_detection_workflow.py (new file)
@@ -0,0 +1,681 @@
"""
Integration tests for complete detection workflow.
This module tests the full end-to-end detection pipeline from stream
to database update, ensuring all components work together correctly.
"""
import pytest
import asyncio
import uuid
import time
import json
import tempfile
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from pathlib import Path
import numpy as np
import cv2
from detector_worker.app import create_app
from detector_worker.core.config import Configuration
from detector_worker.core.dependency_injection import ServiceContainer
from detector_worker.streams.stream_manager import StreamManager, StreamConfig
from detector_worker.models.model_manager import ModelManager
from detector_worker.pipeline.pipeline_executor import PipelineExecutor
from detector_worker.storage.database_manager import DatabaseManager
from detector_worker.storage.redis_client import RedisClient, RedisConfig
from detector_worker.communication.websocket_handler import WebSocketHandler
from detector_worker.communication.message_processor import MessageProcessor
@pytest.fixture
def temp_config_file():
"""Create temporary configuration file."""
config_data = {
"poll_interval_ms": 100,
"max_streams": 5,
"target_fps": 10,
"reconnect_interval_sec": 1,
"max_retries": 2,
"database": {
"enabled": True,
"host": "localhost",
"port": 5432,
"database": "test_gas_station_1",
"user": "test_user",
"password": "test_pass"
},
"redis": {
"enabled": True,
"host": "localhost",
"port": 6379,
"db": 0
}
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(config_data, f)
temp_path = f.name
yield temp_path
# Cleanup
Path(temp_path).unlink(missing_ok=True)
@pytest.fixture
def mock_frame():
"""Create a mock frame for testing."""
return np.ones((480, 640, 3), dtype=np.uint8) * 128
@pytest.fixture
def sample_mpta_file():
"""Create a sample MPTA (pipeline) file."""
pipeline_config = {
"modelId": "car_frontal_detection_v1",
"modelFile": "car_frontal_detection_v1.pt",
"multiClass": True,
"expectedClasses": ["Car", "Frontal"],
"triggerClasses": ["Car", "Frontal"],
"minConfidence": 0.8,
"actions": [
{
"type": "redis_save_image",
"region": "Frontal",
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
"expire_seconds": 600
},
{
"type": "postgresql_create_record",
"table": "car_frontal_info",
"fields": {
"display_id": "{display_id}",
"captured_timestamp": "{timestamp}",
"session_id": "{session_id}",
"license_character": None,
"license_type": "No model available"
}
}
],
"branches": [
{
"modelId": "car_brand_cls_v1",
"modelFile": "car_brand_cls_v1.pt",
"parallel": True,
"crop": True,
"cropClass": "Frontal",
"triggerClasses": ["Frontal"],
"minConfidence": 0.85
},
{
"modelId": "car_bodytype_cls_v1",
"modelFile": "car_bodytype_cls_v1.pt",
"parallel": True,
"crop": True,
"cropClass": "Frontal",
"triggerClasses": ["Frontal"],
"minConfidence": 0.80
}
],
"parallelActions": [
{
"type": "postgresql_update_combined",
"table": "car_frontal_info",
"key_field": "session_id",
"waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
"fields": {
"car_brand": "{car_brand_cls_v1.brand}",
"car_body_type": "{car_bodytype_cls_v1.body_type}"
}
}
]
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(pipeline_config, f)
temp_path = f.name
yield temp_path
# Cleanup
Path(temp_path).unlink(missing_ok=True)
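

# Illustrative only (hypothetical helper, not the worker's implementation): the "key" and
# "fields" entries in the pipeline config above use a flat "{placeholder}" template style
# that is filled from the detection context. A minimal resolver sketch:
def _resolve_flat_template(template: str, context: dict) -> str:
    """Replace each "{name}" placeholder with str(context[name])."""
    resolved = template
    for key, value in context.items():
        resolved = resolved.replace("{" + key + "}", str(value))
    return resolved


def test_flat_template_resolution_sketch():
    key = _resolve_flat_template(
        "inference:{display_id}:{timestamp}:{session_id}:{filename}",
        {"display_id": "display_001", "timestamp": 1640995200000,
         "session_id": "session_12345", "filename": "detection_image.jpg"},
    )
    assert key == "inference:display_001:1640995200000:session_12345:detection_image.jpg"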
class TestCompleteDetectionWorkflow:
"""Test complete detection workflow from stream to database."""
@pytest.mark.asyncio
async def test_complete_rtsp_detection_workflow(self, temp_config_file, sample_mpta_file, mock_frame):
"""Test complete workflow: RTSP stream -> detection -> classification -> database."""
# Initialize configuration
config = Configuration()
config.load_from_file(temp_config_file)
# Create service container
container = ServiceContainer()
# Mock all external dependencies
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('torch.load') as mock_torch_load, \
patch('psycopg2.connect') as mock_db_connect, \
patch('redis.Redis') as mock_redis:
# Setup video capture mock
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, mock_frame)
# Setup model loading mock
mock_detection_model = Mock()
mock_brand_model = Mock()
mock_bodytype_model = Mock()
def mock_model_load(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return mock_brand_model
elif "bodytype" in path:
return mock_bodytype_model
return Mock()
mock_torch_load.side_effect = mock_model_load
# Setup detection model predictions
mock_detection_model.return_value = Mock()
mock_detection_model.return_value.boxes = Mock()
mock_detection_model.return_value.boxes.xyxy = Mock()
mock_detection_model.return_value.boxes.conf = Mock()
mock_detection_model.return_value.boxes.cls = Mock()
mock_detection_model.return_value.names = {0: "Car", 1: "Frontal"}
# Mock detection results - Car and Frontal detected
mock_detection_model.return_value.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[100, 200, 300, 400], # Car bbox
[150, 250, 250, 350] # Frontal bbox
])
mock_detection_model.return_value.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_model.return_value.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
# Setup classification model predictions
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 5 # Toyota index
mock_brand_result.probs.top1conf = Mock()
mock_brand_result.probs.top1conf.item.return_value = 0.87
mock_brand_result.names = {5: "Toyota"}
mock_brand_model.return_value = mock_brand_result
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 2 # Sedan index
mock_bodytype_result.probs.top1conf = Mock()
mock_bodytype_result.probs.top1conf.item.return_value = 0.82
mock_bodytype_result.names = {2: "Sedan"}
mock_bodytype_model.return_value = mock_bodytype_result
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
# Setup Redis mock
mock_redis_instance = Mock()
mock_redis.return_value = mock_redis_instance
mock_redis_instance.ping.return_value = True
mock_redis_instance.set.return_value = True
mock_redis_instance.expire.return_value = True
# Initialize managers
stream_manager = StreamManager()
model_manager = ModelManager()
pipeline_executor = PipelineExecutor()
# Register services in container
container.register_singleton(StreamManager, lambda: stream_manager)
container.register_singleton(ModelManager, lambda: model_manager)
container.register_singleton(PipelineExecutor, lambda: pipeline_executor)
try:
# 1. Create RTSP stream
stream_config = StreamConfig(
stream_url="rtsp://example.com/stream",
stream_type="rtsp",
target_fps=10
)
stream_info = await stream_manager.create_stream(
camera_id="camera_001",
config=stream_config,
subscription_id="sub_001"
)
# 2. Load pipeline and models
pipeline_config = json.loads(Path(sample_mpta_file).read_text())
# Mock model file paths exist
with patch('os.path.exists', return_value=True):
await model_manager.load_models_from_config(pipeline_config)
# 3. Get frame from stream
frame = stream_manager.get_latest_frame("camera_001")
assert frame is not None
# 4. Run detection pipeline
detection_context = {
"camera_id": "camera_001",
"display_id": "display_001",
"frame": mock_frame,
"timestamp": int(time.time() * 1000),
"session_id": str(uuid.uuid4())
}
# Mock cv2.imencode for Redis image storage
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
pipeline_result = await pipeline_executor.execute_pipeline(
pipeline_config,
detection_context
)
# 5. Verify pipeline execution
assert pipeline_result is not None
assert pipeline_result.get("status") == "completed"
# 6. Verify database operations
# Should have called create record
create_calls = [call for call in mock_cursor.execute.call_args_list
if "INSERT" in str(call)]
assert len(create_calls) >= 1
# Should have called update with classification results
update_calls = [call for call in mock_cursor.execute.call_args_list
if "UPDATE" in str(call)]
assert len(update_calls) >= 1
# 7. Verify Redis operations
# Should have saved cropped image
assert mock_redis_instance.set.called
assert mock_redis_instance.expire.called
# 8. Verify final results
assert "detections" in pipeline_result
detections = pipeline_result["detections"]
assert len(detections) >= 2 # Car and Frontal
# Check detection classes
detected_classes = [d.get("class") for d in detections]
assert "Car" in detected_classes
assert "Frontal" in detected_classes
# 9. Verify classification results in context
classification_results = pipeline_result.get("classification_results", {})
assert "car_brand_cls_v1" in classification_results
assert "car_bodytype_cls_v1" in classification_results
brand_result = classification_results["car_brand_cls_v1"]
assert brand_result.get("brand") == "Toyota"
assert brand_result.get("confidence") == 0.87
bodytype_result = classification_results["car_bodytype_cls_v1"]
assert bodytype_result.get("body_type") == "Sedan"
assert bodytype_result.get("confidence") == 0.82
finally:
# Cleanup
await stream_manager.stop_all_streams()
@pytest.mark.asyncio
async def test_websocket_subscription_workflow(self, temp_config_file, sample_mpta_file):
"""Test complete WebSocket subscription workflow."""
# Mock WebSocket for testing
mock_websocket = Mock()
mock_websocket.accept = AsyncMock()
mock_websocket.send_json = AsyncMock()
mock_websocket.receive_json = AsyncMock()
# Initialize configuration
config = Configuration()
config.load_from_file(temp_config_file)
# Initialize message processor and WebSocket handler
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('torch.load') as mock_torch_load, \
patch('psycopg2.connect') as mock_db_connect, \
patch('redis.Redis') as mock_redis:
# Setup mocks
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))
mock_torch_load.return_value = Mock()
mock_db_connect.return_value = Mock()
mock_redis.return_value = Mock()
# Simulate WebSocket message sequence
subscription_message = {
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001",
"rtspUrl": "rtsp://example.com/stream",
"modelUrl": f"file://{sample_mpta_file}",
"modelId": 101,
"modelName": "Vehicle Detection",
"cropX1": 100, "cropY1": 200,
"cropX2": 300, "cropY2": 400
}
}
mock_websocket.receive_json.side_effect = [
subscription_message,
{"type": "requestState"},
{"type": "unsubscribe", "payload": {"subscriptionIdentifier": "display-001;cam-001"}}
]
# Mock file operations
with patch('builtins.open', mock_open(read_data=json.dumps(json.loads(Path(sample_mpta_file).read_text())))):
with patch('os.path.exists', return_value=True):
try:
# Start WebSocket handler (will process messages)
await websocket_handler.handle_websocket(mock_websocket, "client_001")
# Verify WebSocket interactions
mock_websocket.accept.assert_called_once()
# Should have sent subscription acknowledgment
send_calls = mock_websocket.send_json.call_args_list
assert len(send_calls) >= 1
# Check subscription acknowledgment
first_response = send_calls[0][0][0]
assert first_response.get("type") == "subscribeAck"
assert first_response.get("status") == "success"
# Should have sent state report
state_responses = [call[0][0] for call in send_calls
if call[0][0].get("type") == "stateReport"]
assert len(state_responses) >= 1
except Exception as e:
# WebSocket disconnect is expected at end of message sequence
pass
@pytest.mark.asyncio
async def test_http_snapshot_workflow(self, temp_config_file, mock_frame):
"""Test HTTP snapshot workflow."""
config = Configuration()
config.load_from_file(temp_config_file)
stream_manager = StreamManager()
with patch('requests.get') as mock_requests:
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b"fake_jpeg_data"
mock_requests.return_value = mock_response
with patch('cv2.imdecode') as mock_imdecode:
mock_imdecode.return_value = mock_frame
try:
# Create HTTP snapshot stream
stream_config = StreamConfig(
stream_url="http://camera.example.com/snapshot.jpg",
stream_type="http_snapshot",
snapshot_interval=1.0
)
stream_info = await stream_manager.create_stream(
camera_id="camera_002",
config=stream_config,
subscription_id="sub_002"
)
# Wait for snapshot capture
await asyncio.sleep(1.2)
# Verify frame was captured
frame = stream_manager.get_latest_frame("camera_002")
assert frame is not None
assert np.array_equal(frame, mock_frame)
# Verify HTTP request was made
mock_requests.assert_called()
mock_imdecode.assert_called()
finally:
await stream_manager.stop_all_streams()
@pytest.mark.asyncio
async def test_error_recovery_workflow(self, temp_config_file):
"""Test error recovery and resilience."""
config = Configuration()
config.load_from_file(temp_config_file)
stream_manager = StreamManager()
with patch('cv2.VideoCapture') as mock_video_cap:
# Simulate connection failures then success
mock_cap_instances = []
# First attempt fails
mock_cap_fail = Mock()
mock_cap_fail.isOpened.return_value = False
mock_cap_instances.append(mock_cap_fail)
# Second attempt succeeds
mock_cap_success = Mock()
mock_cap_success.isOpened.return_value = True
mock_cap_success.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))
mock_cap_instances.append(mock_cap_success)
mock_video_cap.side_effect = mock_cap_instances
try:
stream_config = StreamConfig(
stream_url="rtsp://unreliable.example.com/stream",
stream_type="rtsp",
max_retries=2
)
# First attempt should fail
try:
await stream_manager.create_stream(
camera_id="camera_003",
config=stream_config,
subscription_id="sub_003"
)
assert False, "Should have failed on first attempt"
except Exception:
pass
# Retry should succeed
stream_info = await stream_manager.create_stream(
camera_id="camera_003",
config=stream_config,
subscription_id="sub_003"
)
assert stream_info is not None
assert mock_video_cap.call_count == 2
finally:
await stream_manager.stop_all_streams()
@pytest.mark.asyncio
async def test_concurrent_streams_workflow(self, temp_config_file, mock_frame):
"""Test handling multiple concurrent streams."""
config = Configuration()
config.load_from_file(temp_config_file)
stream_manager = StreamManager()
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('requests.get') as mock_requests:
# Setup RTSP mock
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, mock_frame)
# Setup HTTP mock
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b"fake_jpeg_data"
mock_requests.return_value = mock_response
with patch('cv2.imdecode', return_value=mock_frame):
try:
# Create multiple streams concurrently
stream_tasks = []
# RTSP streams
for i in range(3):
config = StreamConfig(
stream_url=f"rtsp://camera{i}.example.com/stream",
stream_type="rtsp"
)
task = stream_manager.create_stream(
camera_id=f"camera_rtsp_{i}",
config=config,
subscription_id=f"sub_rtsp_{i}"
)
stream_tasks.append(task)
# HTTP snapshot streams
for i in range(2):
config = StreamConfig(
stream_url=f"http://camera{i}.example.com/snapshot.jpg",
stream_type="http_snapshot"
)
task = stream_manager.create_stream(
camera_id=f"camera_http_{i}",
config=config,
subscription_id=f"sub_http_{i}"
)
stream_tasks.append(task)
# Wait for all streams to be created
stream_results = await asyncio.gather(*stream_tasks, return_exceptions=True)
# Verify all streams were created successfully
successful_streams = [r for r in stream_results if not isinstance(r, Exception)]
assert len(successful_streams) == 5
# Verify frames can be retrieved from all streams
await asyncio.sleep(0.5) # Allow time for frame capture
for i in range(3):
frame = stream_manager.get_latest_frame(f"camera_rtsp_{i}")
assert frame is not None
for i in range(2):
frame = stream_manager.get_latest_frame(f"camera_http_{i}")
# HTTP snapshots might not have frames immediately
# Verify stream statistics
stats = stream_manager.get_stream_statistics()
assert stats["total_streams"] == 5
assert stats["active_streams"] >= 3 # At least RTSP streams should be active
finally:
await stream_manager.stop_all_streams()
@pytest.mark.asyncio
async def test_memory_usage_workflow(self, temp_config_file):
"""Test memory usage tracking and cleanup."""
config = Configuration()
config.load_from_file(temp_config_file)
# Create managers with small limits for testing
stream_manager = StreamManager({"max_streams": 10})
model_manager = ModelManager({"cache_max_size": 5})
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('torch.load') as mock_torch_load:
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, np.ones((100, 100, 3), dtype=np.uint8))
# Mock model loading
def create_mock_model():
model = Mock()
# Mock model parameters for memory estimation
param1 = Mock()
param1.numel.return_value = 1000
param1.element_size.return_value = 4
model.parameters.return_value = [param1]
return model
mock_torch_load.side_effect = lambda *args, **kwargs: create_mock_model()
try:
# Create streams up to limit
for i in range(8):
stream_config = StreamConfig(
stream_url=f"rtsp://test{i}.example.com/stream",
stream_type="rtsp"
)
await stream_manager.create_stream(
camera_id=f"test_camera_{i}",
config=stream_config,
subscription_id=f"test_sub_{i}"
)
# Load models up to cache limit
with patch('os.path.exists', return_value=True):
for i in range(7):
config = {
"model_id": f"test_model_{i}",
"model_path": f"/fake/path/model_{i}.pt",
"model_type": "detection",
"device": "cpu"
}
await model_manager.load_model_from_dict(config)
# Check memory usage tracking
stream_stats = stream_manager.get_stream_statistics()
model_stats = model_manager.get_cache_statistics()
assert stream_stats["total_streams"] == 8
assert model_stats["size"] <= 5 # Should be limited by cache size
# Test cleanup
cleaned_models = model_manager.cleanup_unused_models()
assert cleaned_models >= 0
stopped_streams = await stream_manager.stop_all_streams()
assert stopped_streams == 8
# Verify cleanup
final_stream_stats = stream_manager.get_stream_statistics()
assert final_stream_stats["total_streams"] == 0
finally:
await stream_manager.stop_all_streams()
def mock_open(read_data=""):
"""Create a mock file opener."""
from unittest.mock import mock_open as _mock_open
return _mock_open(read_data=read_data)
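
For reference, a minimal sketch of how this mock_open wrapper feeds file reads in the tests above (illustrative only, not part of this commit):

def _demo_mock_open_usage():
    """Illustrative: patch builtins.open so any read returns canned JSON."""
    with patch("builtins.open", mock_open(read_data='{"modelId": "demo"}')):
        with open("pipeline.json") as f:
            assert f.read() == '{"modelId": "demo"}'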

tests/integration/test_pipeline_integration.py (new file)
@@ -0,0 +1,738 @@
"""
Integration tests for pipeline execution workflows.
Tests the complete machine learning pipeline execution including
detection, classification, database updates, and Redis actions.
"""
import pytest
import asyncio
import json
import tempfile
import uuid
import time
from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock
import numpy as np
from detector_worker.pipeline.pipeline_executor import PipelineExecutor
from detector_worker.pipeline.action_executor import ActionExecutor
from detector_worker.pipeline.field_mapper import FieldMapper
from detector_worker.models.model_manager import ModelManager
from detector_worker.storage.database_manager import DatabaseManager
from detector_worker.storage.redis_client import RedisClient, RedisConfig
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
@pytest.fixture
def sample_detection_pipeline():
"""Create sample detection pipeline configuration."""
return {
"modelId": "car_frontal_detection_v1",
"modelFile": "car_frontal_detection_v1.pt",
"multiClass": True,
"expectedClasses": ["Car", "Frontal"],
"triggerClasses": ["Car", "Frontal"],
"minConfidence": 0.8,
"actions": [
{
"type": "redis_save_image",
"region": "Frontal",
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
"expire_seconds": 600
},
{
"type": "postgresql_create_record",
"table": "car_frontal_info",
"fields": {
"display_id": "{display_id}",
"captured_timestamp": "{timestamp}",
"session_id": "{session_id}",
"license_character": None,
"license_type": "No model available"
}
}
],
"branches": [
{
"modelId": "car_brand_cls_v1",
"modelFile": "car_brand_cls_v1.pt",
"parallel": True,
"crop": True,
"cropClass": "Frontal",
"triggerClasses": ["Frontal"],
"minConfidence": 0.85
},
{
"modelId": "car_bodytype_cls_v1",
"modelFile": "car_bodytype_cls_v1.pt",
"parallel": True,
"crop": True,
"cropClass": "Frontal",
"triggerClasses": ["Frontal"],
"minConfidence": 0.80
}
],
"parallelActions": [
{
"type": "postgresql_update_combined",
"table": "car_frontal_info",
"key_field": "session_id",
"waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
"fields": {
"car_brand": "{car_brand_cls_v1.brand}",
"car_body_type": "{car_bodytype_cls_v1.body_type}"
}
}
]
}
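

# Illustrative only (hypothetical helper, not the worker's implementation): the
# parallelActions fields above reference branch outputs with a dotted
# "{branch_id.field}" syntax. A tiny resolver sketch with a sanity check:
def _resolve_branch_field(template: str, branch_results: dict) -> str:
    """Resolve a "{branch_id.field}" reference against collected branch results."""
    branch_id, field = template.strip("{}").split(".", 1)
    return str(branch_results[branch_id][field])


def test_branch_field_resolution_sketch():
    branch_results = {
        "car_brand_cls_v1": {"brand": "Toyota", "confidence": 0.87},
        "car_bodytype_cls_v1": {"body_type": "Sedan", "confidence": 0.82},
    }
    assert _resolve_branch_field("{car_brand_cls_v1.brand}", branch_results) == "Toyota"
    assert _resolve_branch_field("{car_bodytype_cls_v1.body_type}", branch_results) == "Sedan"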
@pytest.fixture
def sample_frame():
"""Create sample frame for testing."""
return np.ones((480, 640, 3), dtype=np.uint8) * 128
@pytest.fixture
def detection_context():
"""Create sample detection context."""
return {
"camera_id": "camera_001",
"display_id": "display_001",
"timestamp": int(time.time() * 1000),
"session_id": str(uuid.uuid4()),
"frame": np.ones((480, 640, 3), dtype=np.uint8) * 128,
"filename": "detection_image.jpg"
}
class TestPipelineIntegration:
"""Test complete pipeline integration workflows."""
@pytest.mark.asyncio
async def test_complete_detection_classification_pipeline(self, sample_detection_pipeline, detection_context):
"""Test complete detection to classification pipeline."""
pipeline_executor = PipelineExecutor()
model_manager = ModelManager()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect, \
patch('redis.Redis') as mock_redis:
# Setup detection model mock
mock_detection_model = Mock()
mock_detection_result = Mock()
# Mock successful multi-class detection
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
# Detection results: Car and Frontal detected with high confidence
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], # Car bbox
[150, 200, 300, 400] # Frontal bbox (within Car)
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
# Setup classification models
mock_brand_model = Mock()
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 3 # Toyota index
mock_brand_result.probs.top1conf = Mock()
mock_brand_result.probs.top1conf.item.return_value = 0.87
mock_brand_result.names = {3: "Toyota"}
mock_brand_model.return_value = mock_brand_result
mock_bodytype_model = Mock()
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 1 # Sedan index
mock_bodytype_result.probs.top1conf = Mock()
mock_bodytype_result.probs.top1conf.item.return_value = 0.82
mock_bodytype_result.names = {1: "Sedan"}
mock_bodytype_model.return_value = mock_bodytype_result
# Route model loading to appropriate mocks
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return mock_brand_model
elif "bodytype" in path:
return mock_bodytype_model
return Mock()
mock_torch_load.side_effect = model_loader
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = None
# Setup Redis mock
mock_redis_instance = Mock()
mock_redis.return_value = mock_redis_instance
mock_redis_instance.ping.return_value = True
mock_redis_instance.set.return_value = True
mock_redis_instance.expire.return_value = True
# Mock image encoding for Redis storage
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Execute complete pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Verify pipeline execution
assert result is not None
assert result.get("status") == "completed"
assert "detections" in result
# Verify detection results
detections = result["detections"]
assert len(detections) == 2 # Car and Frontal
detection_classes = [d.get("class") for d in detections]
assert "Car" in detection_classes
assert "Frontal" in detection_classes
# Verify classification results
assert "classification_results" in result
classification_results = result["classification_results"]
assert "car_brand_cls_v1" in classification_results
brand_result = classification_results["car_brand_cls_v1"]
assert brand_result.get("brand") == "Toyota"
assert brand_result.get("confidence") == 0.87
assert "car_bodytype_cls_v1" in classification_results
bodytype_result = classification_results["car_bodytype_cls_v1"]
assert bodytype_result.get("body_type") == "Sedan"
assert bodytype_result.get("confidence") == 0.82
# Verify database operations
db_calls = mock_cursor.execute.call_args_list
# Should have INSERT for initial record creation
insert_calls = [call for call in db_calls if "INSERT" in str(call[0])]
assert len(insert_calls) >= 1
# Should have UPDATE for classification results
update_calls = [call for call in db_calls if "UPDATE" in str(call[0])]
assert len(update_calls) >= 1
# Verify Redis operations
assert mock_redis_instance.set.called
assert mock_redis_instance.expire.called
@pytest.mark.asyncio
async def test_pipeline_with_missing_detections(self, sample_detection_pipeline, detection_context):
"""Test pipeline behavior when expected detections are missing."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True):
# Setup detection model that doesn't find expected classes
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
# Only detect Car, no Frontal
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450] # Only Car bbox
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0])
mock_detection_model.return_value = mock_detection_result
mock_torch_load.return_value = mock_detection_model
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Pipeline should complete but skip classification branches
assert result is not None
assert "detections" in result
detections = result["detections"]
assert len(detections) == 1 # Only Car detected
assert detections[0].get("class") == "Car"
# Classification should not have run (no Frontal detected)
classification_results = result.get("classification_results", {})
assert len(classification_results) == 0 or all(
not res for res in classification_results.values()
)
@pytest.mark.asyncio
async def test_pipeline_with_low_confidence_detections(self, sample_detection_pipeline, detection_context):
"""Test pipeline with detections below confidence threshold."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True):
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
# Detections with low confidence (below 0.8 threshold)
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], # Car bbox
[150, 200, 300, 400] # Frontal bbox
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.75, 0.70]) # Below threshold
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
mock_torch_load.return_value = mock_detection_model
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Should complete but with filtered detections
assert result is not None
# Low confidence detections should be filtered out
detections = result.get("detections", [])
high_conf_detections = [d for d in detections if d.get("confidence", 0) >= 0.8]
assert len(high_conf_detections) == 0
@pytest.mark.asyncio
async def test_pipeline_branch_execution_order(self, sample_detection_pipeline, detection_context):
"""Test that pipeline branches execute in correct order and parallel mode works."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect:
# Track execution order
execution_order = []
# Setup detection model
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
def track_detection_execution(*args, **kwargs):
execution_order.append("detection")
return mock_detection_result
mock_detection_model.side_effect = track_detection_execution
# Setup classification models with execution tracking
def create_tracked_model(model_id):
def track_execution(*args, **kwargs):
execution_order.append(model_id)
result = Mock()
result.probs = Mock()
result.probs.top1 = 0
result.probs.top1conf = Mock()
result.probs.top1conf.item.return_value = 0.90
result.names = {0: "TestResult"}
return result
model = Mock()
model.side_effect = track_execution
return model
# Route models with execution tracking
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return create_tracked_model("car_brand_cls_v1")
elif "bodytype" in path:
return create_tracked_model("car_bodytype_cls_v1")
return Mock()
mock_torch_load.side_effect = model_loader
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Verify execution order
assert "detection" in execution_order
assert execution_order[0] == "detection" # Detection should run first
# Classification models should run after detection
brand_index = execution_order.index("car_brand_cls_v1") if "car_brand_cls_v1" in execution_order else -1
bodytype_index = execution_order.index("car_bodytype_cls_v1") if "car_bodytype_cls_v1" in execution_order else -1
detection_index = execution_order.index("detection")
if brand_index >= 0:
assert brand_index > detection_index
if bodytype_index >= 0:
assert bodytype_index > detection_index
# Since branches are parallel, they could run in any order relative to each other
# but both should run after detection
@pytest.mark.asyncio
async def test_pipeline_error_recovery(self, sample_detection_pipeline, detection_context):
"""Test pipeline error handling and recovery."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect:
# Setup detection model that works
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
# Setup classification models - one fails, one succeeds
mock_brand_model = Mock()
mock_brand_model.side_effect = RuntimeError("Model inference failed")
mock_bodytype_model = Mock()
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 1
mock_bodytype_result.probs.top1conf = Mock()
mock_bodytype_result.probs.top1conf.item.return_value = 0.85
mock_bodytype_result.names = {1: "SUV"}
mock_bodytype_model.return_value = mock_bodytype_result
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return mock_brand_model
elif "bodytype" in path:
return mock_bodytype_model
return Mock()
mock_torch_load.side_effect = model_loader
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Pipeline should complete despite one branch failing
assert result is not None
# Detection should succeed
assert "detections" in result
detections = result["detections"]
assert len(detections) == 2
# Classification results should be partial
classification_results = result.get("classification_results", {})
# Brand classification should have failed
brand_result = classification_results.get("car_brand_cls_v1")
assert brand_result is None or brand_result.get("error") is not None
# Body type classification should have succeeded
bodytype_result = classification_results.get("car_bodytype_cls_v1")
assert bodytype_result is not None
assert bodytype_result.get("body_type") == "SUV"
assert bodytype_result.get("confidence") == 0.85
@pytest.mark.asyncio
async def test_field_mapping_and_database_update(self, sample_detection_pipeline, detection_context):
"""Test field mapping and database update integration."""
pipeline_executor = PipelineExecutor()
field_mapper = FieldMapper()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect:
# Setup successful detection and classification
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
# Setup classification models
mock_brand_model = Mock()
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 2
mock_brand_result.probs.top1conf = Mock()
mock_brand_result.probs.top1conf.item.return_value = 0.88
mock_brand_result.names = {2: "Honda"}
mock_brand_model.return_value = mock_brand_result
mock_bodytype_model = Mock()
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 0
mock_bodytype_result.probs.top1conf = Mock()
mock_bodytype_result.probs.top1conf.item.return_value = 0.91
mock_bodytype_result.names = {0: "Hatchback"}
mock_bodytype_model.return_value = mock_bodytype_result
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return mock_brand_model
elif "bodytype" in path:
return mock_bodytype_model
return Mock()
mock_torch_load.side_effect = model_loader
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Verify pipeline completed successfully
assert result is not None
assert result.get("status") == "completed"
# Check database operations
db_calls = mock_cursor.execute.call_args_list
# Should have INSERT and UPDATE operations
insert_calls = [call for call in db_calls if "INSERT" in str(call[0])]
update_calls = [call for call in db_calls if "UPDATE" in str(call[0])]
assert len(insert_calls) >= 1
assert len(update_calls) >= 1
# Check that UPDATE includes field mapping results
update_sql = str(update_calls[0][0])
assert "car_brand" in update_sql.lower()
assert "car_body_type" in update_sql.lower()
# Check that classification results were properly mapped
classification_results = result.get("classification_results", {})
assert "car_brand_cls_v1" in classification_results
assert "car_bodytype_cls_v1" in classification_results
brand_result = classification_results["car_brand_cls_v1"]
bodytype_result = classification_results["car_bodytype_cls_v1"]
assert brand_result.get("brand") == "Honda"
assert brand_result.get("confidence") == 0.88
assert bodytype_result.get("body_type") == "Hatchback"
assert bodytype_result.get("confidence") == 0.91
@pytest.mark.asyncio
async def test_redis_image_storage_integration(self, sample_detection_pipeline, detection_context):
"""Test Redis image storage integration in pipeline."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('redis.Redis') as mock_redis, \
patch('cv2.imencode') as mock_imencode:
# Setup successful detection
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
mock_torch_load.return_value = mock_detection_model
# Setup Redis mock
mock_redis_instance = Mock()
mock_redis.return_value = mock_redis_instance
mock_redis_instance.ping.return_value = True
mock_redis_instance.set.return_value = True
mock_redis_instance.expire.return_value = True
# Setup image encoding mock
encoded_data = np.array([1, 2, 3, 4, 5], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Verify Redis operations
assert mock_redis_instance.set.called
assert mock_redis_instance.expire.called
# Check that image was encoded
assert mock_imencode.called
# Verify correct key format was used
set_call = mock_redis_instance.set.call_args
redis_key = set_call[0][0]
# Key should contain display_id, timestamp, session_id
assert detection_context["display_id"] in redis_key
assert detection_context["session_id"] in redis_key
assert str(detection_context["timestamp"]) in redis_key
# Should set expiration
expire_call = mock_redis_instance.expire.call_args
expire_key = expire_call[0][0]
expire_seconds = expire_call[0][1]
assert expire_key == redis_key
assert expire_seconds == 600 # As configured in pipeline
@pytest.mark.asyncio
async def test_pipeline_performance_timing(self, sample_detection_pipeline, detection_context):
"""Test pipeline execution timing and performance."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect, \
patch('redis.Redis') as mock_redis, \
patch('cv2.imencode') as mock_imencode:
# Setup fast mocks
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
# Setup fast classification models
def create_fast_model():
model = Mock()
result = Mock()
result.probs = Mock()
result.probs.top1 = 0
result.probs.top1conf = Mock()
result.probs.top1conf.item.return_value = 0.90
result.names = {0: "TestClass"}
model.return_value = result
return model
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
else:
return create_fast_model()
mock_torch_load.side_effect = model_loader
# Setup fast database and Redis
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
mock_redis_instance = Mock()
mock_redis.return_value = mock_redis_instance
mock_redis_instance.ping.return_value = True
mock_redis_instance.set.return_value = True
mock_redis_instance.expire.return_value = True
encoded_data = np.array([1, 2, 3], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Measure execution time
start_time = time.time()
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
end_time = time.time()
execution_time = end_time - start_time
# Pipeline should complete quickly (less than 1 second with mocks)
assert execution_time < 1.0
# Should have timing information in result
assert result is not None
if "execution_time" in result:
assert result["execution_time"] > 0
# Verify pipeline completed successfully
assert result.get("status") == "completed"

tests/integration/test_websocket_protocol.py (new file)
@@ -0,0 +1,579 @@
"""
Integration tests for WebSocket protocol compliance.
Tests the complete WebSocket communication protocol to ensure
compatibility with existing clients and proper message handling.
"""
import pytest
import asyncio
import json
import uuid
from unittest.mock import Mock, AsyncMock, patch
from fastapi.websockets import WebSocket
from fastapi.testclient import TestClient
from detector_worker.app import create_app
from detector_worker.communication.websocket_handler import WebSocketHandler
from detector_worker.communication.message_processor import MessageProcessor, MessageType
from detector_worker.core.exceptions import MessageProcessingError
@pytest.fixture
def test_app():
"""Create test FastAPI application."""
return create_app()
@pytest.fixture
def mock_websocket():
"""Create mock WebSocket for testing."""
websocket = Mock(spec=WebSocket)
websocket.accept = AsyncMock()
websocket.send_json = AsyncMock()
websocket.send_text = AsyncMock()
websocket.receive_json = AsyncMock()
websocket.receive_text = AsyncMock()
websocket.close = AsyncMock()
websocket.ping = AsyncMock()
return websocket
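

# Illustrative sketch (hypothetical test, not part of this commit): the tests below drive
# the handler by queuing messages on receive_json.side_effect, with a trailing exception
# entry simulating a client disconnect. Shown here in isolation:
@pytest.mark.asyncio
async def test_receive_json_side_effect_pattern_sketch():
    ws = Mock(spec=WebSocket)
    ws.receive_json = AsyncMock(side_effect=[{"type": "requestState"}, asyncio.CancelledError()])
    # First await returns the queued message ...
    assert await ws.receive_json() == {"type": "requestState"}
    # ... the next await raises, which the handler treats as the client disconnecting.
    with pytest.raises(asyncio.CancelledError):
        await ws.receive_json()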
class TestWebSocketProtocol:
"""Test WebSocket protocol compliance."""
@pytest.mark.asyncio
async def test_subscription_message_protocol(self, mock_websocket):
"""Test subscription message handling protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Mock external dependencies
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('torch.load') as mock_torch_load, \
patch('builtins.open') as mock_file_open:
# Setup video capture mock
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
# Setup model loading mock
mock_torch_load.return_value = Mock()
# Setup pipeline file mock
pipeline_config = {
"modelId": "test_detection_model",
"modelFile": "test_model.pt",
"expectedClasses": ["Car"],
"minConfidence": 0.8
}
mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)
# Test message sequence
subscription_message = {
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001",
"rtspUrl": "rtsp://example.com/stream",
"modelUrl": "http://example.com/model.mpta",
"modelId": 101,
"modelName": "Test Detection",
"cropX1": 0,
"cropY1": 0,
"cropX2": 640,
"cropY2": 480
}
}
request_state_message = {
"type": "requestState"
}
unsubscribe_message = {
"type": "unsubscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001"
}
}
# Mock WebSocket message sequence
mock_websocket.receive_json.side_effect = [
subscription_message,
request_state_message,
unsubscribe_message,
asyncio.CancelledError() # Simulate client disconnect
]
client_id = "test_client_001"
try:
# Handle WebSocket connection
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass # Expected when client disconnects
# Verify protocol compliance
mock_websocket.accept.assert_called_once()
# Check sent messages
sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]
# Should receive subscription acknowledgment
subscribe_acks = [msg for msg in sent_messages if msg.get("type") == "subscribeAck"]
assert len(subscribe_acks) >= 1
subscribe_ack = subscribe_acks[0]
assert subscribe_ack["status"] in ["success", "error"]
if subscribe_ack["status"] == "success":
assert "subscriptionId" in subscribe_ack
# Should receive state report
state_reports = [msg for msg in sent_messages if msg.get("type") == "stateReport"]
assert len(state_reports) >= 1
state_report = state_reports[0]
assert "payload" in state_report
assert "subscriptions" in state_report["payload"]
assert "system" in state_report["payload"]
# Should receive unsubscribe acknowledgment
unsubscribe_acks = [msg for msg in sent_messages if msg.get("type") == "unsubscribeAck"]
assert len(unsubscribe_acks) >= 1
unsubscribe_ack = unsubscribe_acks[0]
assert unsubscribe_ack["status"] in ["success", "error"]
@pytest.mark.asyncio
async def test_invalid_message_handling(self, mock_websocket):
"""Test handling of invalid messages."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Test invalid message types
invalid_messages = [
{"invalid": "message"}, # Missing type
{"type": "unknown_type", "payload": {}}, # Unknown type
{"type": "subscribe"}, # Missing payload
{"type": "subscribe", "payload": {}}, # Missing required fields
{"type": "subscribe", "payload": {"subscriptionIdentifier": "test"}}, # Missing URL
]
mock_websocket.receive_json.side_effect = invalid_messages + [asyncio.CancelledError()]
client_id = "test_client_error"
try:
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass
# Verify error responses
sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]
error_messages = [msg for msg in sent_messages if msg.get("type") == "error"]
# Should receive error responses for invalid messages
assert len(error_messages) >= len(invalid_messages)
for error_msg in error_messages:
assert "message" in error_msg
assert error_msg["message"] # Non-empty error message
@pytest.mark.asyncio
async def test_session_management_protocol(self, mock_websocket):
"""Test session management protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
session_id = str(uuid.uuid4())
# Test session messages
set_session_message = {
"type": "setSessionId",
"payload": {
"sessionId": session_id,
"displayId": "display-001"
}
}
patch_session_message = {
"type": "patchSession",
"payload": {
"sessionId": session_id,
"data": {
"car_brand": "Toyota",
"confidence": 0.92
}
}
}
mock_websocket.receive_json.side_effect = [
set_session_message,
patch_session_message,
asyncio.CancelledError()
]
client_id = "test_client_session"
try:
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass
# Verify session responses
sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]
# Should receive acknowledgments for session operations
set_session_acks = [msg for msg in sent_messages
if msg.get("type") == "setSessionIdAck"]
assert len(set_session_acks) >= 1
patch_session_acks = [msg for msg in sent_messages
if msg.get("type") == "patchSessionAck"]
assert len(patch_session_acks) >= 1
@pytest.mark.asyncio
async def test_heartbeat_protocol(self, mock_websocket):
"""Test heartbeat/ping-pong protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor, {"heartbeat_interval": 0.1})
# Simulate long-running connection
mock_websocket.receive_json.side_effect = [
asyncio.TimeoutError(), # No messages for a while
asyncio.TimeoutError(),
asyncio.CancelledError() # Then disconnect
]
client_id = "test_client_heartbeat"
# Start heartbeat task
heartbeat_task = asyncio.create_task(websocket_handler._heartbeat_loop())
try:
# Let heartbeat run briefly
await asyncio.sleep(0.2)
# Cancel heartbeat
heartbeat_task.cancel()
await heartbeat_task
except asyncio.CancelledError:
pass
# Verify ping was sent
assert mock_websocket.ping.called
@pytest.mark.asyncio
async def test_detection_result_protocol(self, mock_websocket):
"""Test detection result message protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Mock a detection result being sent
detection_result = {
"type": "imageDetection",
"payload": {
"subscriptionId": "display-001;cam-001",
"detections": [
{
"class": "Car",
"confidence": 0.95,
"bbox": [100, 200, 300, 400],
"trackId": 1001
}
],
"timestamp": 1640995200000,
"modelInfo": {
"modelId": 101,
"modelName": "Vehicle Detection"
}
}
}
# Send detection result to client
await websocket_handler.send_to_client("test_client", detection_result)
# Verify message was sent
mock_websocket.send_json.assert_called_with(detection_result)
@pytest.mark.asyncio
async def test_error_recovery_protocol(self, mock_websocket):
"""Test error recovery and graceful degradation."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Simulate WebSocket errors
mock_websocket.send_json.side_effect = [
None, # First message succeeds
ConnectionError("Connection lost"), # Second fails
None # Third succeeds after recovery
]
# Try to send multiple messages
messages = [
{"type": "test", "message": "1"},
{"type": "test", "message": "2"},
{"type": "test", "message": "3"}
]
results = []
for msg in messages:
try:
await websocket_handler.send_to_client("test_client", msg)
results.append("success")
except Exception:
results.append("error")
# Should handle errors gracefully
assert "error" in results
# But should still be able to send other messages
assert "success" in results
@pytest.mark.asyncio
async def test_concurrent_client_protocol(self, mock_websocket):
"""Test handling multiple concurrent clients."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Create multiple mock WebSocket connections
mock_websockets = []
for i in range(3):
ws = Mock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
ws.receive_json = AsyncMock(side_effect=[asyncio.CancelledError()])
mock_websockets.append(ws)
# Handle multiple clients concurrently
client_tasks = []
for i, ws in enumerate(mock_websockets):
task = asyncio.create_task(
websocket_handler.handle_websocket(ws, f"client_{i}")
)
client_tasks.append(task)
# Wait briefly then cancel all
await asyncio.sleep(0.1)
for task in client_tasks:
task.cancel()
try:
await asyncio.gather(*client_tasks)
except asyncio.CancelledError:
pass
# Verify all connections were accepted
for ws in mock_websockets:
ws.accept.assert_called_once()
@pytest.mark.asyncio
async def test_subscription_sharing_protocol(self, mock_websocket):
"""Test shared subscription protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Mock multiple clients subscribing to same camera
with patch('cv2.VideoCapture') as mock_video_cap:
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
# First client subscribes
subscription_msg1 = {
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001",
"rtspUrl": "rtsp://shared.example.com/stream",
"modelUrl": "http://example.com/model.mpta"
}
}
# Second client subscribes to same camera
subscription_msg2 = {
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-002;cam-001",
"rtspUrl": "rtsp://shared.example.com/stream", # Same URL
"modelUrl": "http://example.com/model.mpta"
}
}
# Mock file operations
with patch('builtins.open') as mock_file_open:
pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)
# Process both subscriptions
response1 = await message_processor.process_message(subscription_msg1, "client_1")
response2 = await message_processor.process_message(subscription_msg2, "client_2")
# Both should succeed and reference same underlying stream
assert response1.get("status") == "success"
assert response2.get("status") == "success"
# Should only create one video capture instance (shared stream)
assert mock_video_cap.call_count == 1
@pytest.mark.asyncio
async def test_message_ordering_protocol(self, mock_websocket):
"""Test message ordering and sequencing."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Test sequence of related messages
messages = [
{"type": "subscribe", "payload": {"subscriptionIdentifier": "test", "rtspUrl": "rtsp://test.com"}},
{"type": "setSessionId", "payload": {"sessionId": "session_123", "displayId": "display_001"}},
{"type": "requestState"},
{"type": "patchSession", "payload": {"sessionId": "session_123", "data": {"test": "data"}}},
{"type": "unsubscribe", "payload": {"subscriptionIdentifier": "test"}}
]
mock_websocket.receive_json.side_effect = messages + [asyncio.CancelledError()]
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('builtins.open') as mock_file_open:
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)
client_id = "test_client_ordering"
try:
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass
# Verify responses were sent in correct order
sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]
# Should receive responses for each message type
response_types = [msg.get("type") for msg in sent_messages]
expected_types = ["subscribeAck", "setSessionIdAck", "stateReport", "patchSessionAck", "unsubscribeAck"]
# Check that we got appropriate responses (order may vary slightly)
for expected_type in expected_types:
assert expected_type in response_types, f"Missing response type: {expected_type}"
class TestWebSocketPerformance:
"""Test WebSocket performance characteristics."""
@pytest.mark.asyncio
async def test_message_throughput(self, mock_websocket):
"""Test message processing throughput."""
message_processor = MessageProcessor()
# Prepare batch of simple messages
state_request = {"type": "requestState"}
import time
start_time = time.time()
# Process many messages quickly
for _ in range(100):
await message_processor.process_message(state_request, "test_client")
end_time = time.time()
processing_time = end_time - start_time
# Should process messages quickly (less than 1 second for 100 messages)
assert processing_time < 1.0
# Calculate throughput
throughput = 100 / processing_time
assert throughput > 100 # Should handle > 100 messages/second
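# Note: the WebSocket layer is not exercised here, so this figure reflects
# MessageProcessor overhead only, not network or serialization cost.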
@pytest.mark.asyncio
async def test_concurrent_message_handling(self, mock_websocket):
"""Test concurrent message handling."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Create multiple mock clients
num_clients = 10
mock_websockets = []
for i in range(num_clients):
ws = Mock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
ws.receive_json = AsyncMock(side_effect=[
{"type": "requestState"},
asyncio.CancelledError()
])
mock_websockets.append(ws)
# Handle all clients concurrently
client_tasks = []
for i, ws in enumerate(mock_websockets):
task = asyncio.create_task(
websocket_handler.handle_websocket(ws, f"perf_client_{i}")
)
client_tasks.append(task)
import time
start_time = time.time()
# Wait for all to complete
try:
await asyncio.gather(*client_tasks)
except asyncio.CancelledError:
pass
end_time = time.time()
total_time = end_time - start_time
# Should handle all clients efficiently
assert total_time < 2.0 # Should complete in less than 2 seconds
# All clients should have been accepted
for ws in mock_websockets:
ws.accept.assert_called_once()
@pytest.mark.asyncio
async def test_memory_usage_stability(self, mock_websocket):
"""Test memory usage remains stable."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Simulate many connection/disconnection cycles
for cycle in range(10):
# Create client
client_id = f"memory_test_client_{cycle}"
mock_websocket.receive_json.side_effect = [
{"type": "requestState"},
asyncio.CancelledError()
]
try:
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass
# Reset mock for next cycle
mock_websocket.reset_mock()
mock_websocket.accept = AsyncMock()
mock_websocket.send_json = AsyncMock()
mock_websocket.receive_json = AsyncMock()
# Connection manager should not accumulate stale connections
stats = websocket_handler.get_connection_stats()
assert stats["total_connections"] == 0 # All should be cleaned up

View file

@ -0,0 +1,19 @@
"""
Performance tests for the detector worker application.
This package contains performance benchmarks and load tests to ensure
the application meets scalability and throughput requirements.
"""
# Performance test modules
from . import (
test_detection_performance,
test_websocket_performance,
test_storage_performance
)
__all__ = [
"test_detection_performance",
"test_websocket_performance",
"test_storage_performance"
]
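# Note: pytest collects these modules directly from their files, so the imports above
# are not required for test discovery; they mainly expose the modules as package
# attributes for programmatic access.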

View file

@ -0,0 +1,672 @@
"""
Performance tests for detection pipeline components.
These tests benchmark the performance of key detection pipeline
components to ensure they meet performance requirements.
"""
import pytest
import time
import asyncio
import statistics
from unittest.mock import Mock, patch
import numpy as np
import psutil
import gc
from detector_worker.detection.yolo_detector import YOLODetector
from detector_worker.detection.tracking_manager import TrackingManager
from detector_worker.detection.stability_validator import StabilityValidator
from detector_worker.pipeline.pipeline_executor import PipelineExecutor
from detector_worker.models.model_manager import ModelManager
from detector_worker.streams.stream_manager import StreamManager
@pytest.fixture
def sample_frame():
"""Create a sample frame for performance testing."""
return np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
@pytest.fixture
def large_frame():
"""Create a large frame for stress testing."""
return np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8)
@pytest.fixture
def performance_config():
"""Configuration for performance tests."""
return {
"target_fps": 30,
"max_detection_time_ms": 100,
"max_tracking_time_ms": 50,
"max_pipeline_time_ms": 500,
"memory_limit_mb": 1024
}
class TestDetectionPerformance:
"""Test detection performance benchmarks."""
def test_yolo_detection_speed(self, sample_frame, performance_config):
"""Benchmark YOLO detection speed."""
detector = YOLODetector()
with patch('torch.load') as mock_torch_load:
# Setup fast mock model
mock_model = Mock()
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.xyxy = Mock()
mock_result.boxes.conf = Mock()
mock_result.boxes.cls = Mock()
mock_result.names = {0: "car", 1: "person"}
# Mock detection results
mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[100, 200, 300, 400],
[150, 250, 350, 450]
])
mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9, 0.8])
mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_model.return_value = mock_result
mock_torch_load.return_value = mock_model
# Warm up
for _ in range(5):
detector.detect(sample_frame, confidence_threshold=0.5)
# Benchmark detection speed
detection_times = []
num_iterations = 100
for _ in range(num_iterations):
start_time = time.perf_counter()
detections = detector.detect(sample_frame, confidence_threshold=0.5)
end_time = time.perf_counter()
detection_time_ms = (end_time - start_time) * 1000
detection_times.append(detection_time_ms)
# Calculate statistics
avg_detection_time = statistics.mean(detection_times)
median_detection_time = statistics.median(detection_times)
max_detection_time = max(detection_times)
min_detection_time = min(detection_times)
# Performance assertions
assert avg_detection_time < performance_config["max_detection_time_ms"]
assert median_detection_time < performance_config["max_detection_time_ms"]
# Calculate theoretical FPS
theoretical_fps = 1000 / avg_detection_time
assert theoretical_fps >= performance_config["target_fps"]
print(f"\nDetection Performance Metrics:")
print(f"Average detection time: {avg_detection_time:.2f} ms")
print(f"Median detection time: {median_detection_time:.2f} ms")
print(f"Min detection time: {min_detection_time:.2f} ms")
print(f"Max detection time: {max_detection_time:.2f} ms")
print(f"Theoretical FPS: {theoretical_fps:.1f}")
def test_tracking_performance(self, sample_frame, performance_config):
"""Benchmark object tracking performance."""
tracking_manager = TrackingManager()
# Create mock detections
detections = [
{"class": "car", "confidence": 0.9, "bbox": [100, 200, 300, 400]},
{"class": "car", "confidence": 0.8, "bbox": [150, 250, 350, 450]},
{"class": "person", "confidence": 0.7, "bbox": [200, 300, 250, 400]}
]
# Warm up tracking
for i in range(10):
tracking_manager.update_tracks(detections, frame_id=i)
# Benchmark tracking speed
tracking_times = []
num_iterations = 100
for i in range(num_iterations):
# Simulate moving detections
moving_detections = []
for det in detections:
moved_det = det.copy()
# Add small random movement
bbox = moved_det["bbox"]
moved_det["bbox"] = [
bbox[0] + np.random.randint(-5, 5),
bbox[1] + np.random.randint(-5, 5),
bbox[2] + np.random.randint(-5, 5),
bbox[3] + np.random.randint(-5, 5)
]
moving_detections.append(moved_det)
start_time = time.perf_counter()
tracks = tracking_manager.update_tracks(moving_detections, frame_id=i + 10)
end_time = time.perf_counter()
tracking_time_ms = (end_time - start_time) * 1000
tracking_times.append(tracking_time_ms)
# Calculate statistics
avg_tracking_time = statistics.mean(tracking_times)
max_tracking_time = max(tracking_times)
# Performance assertions
assert avg_tracking_time < performance_config["max_tracking_time_ms"]
assert max_tracking_time < performance_config["max_tracking_time_ms"] * 2
print(f"\nTracking Performance Metrics:")
print(f"Average tracking time: {avg_tracking_time:.2f} ms")
print(f"Max tracking time: {max_tracking_time:.2f} ms")
def test_stability_validation_performance(self, performance_config):
"""Benchmark stability validation performance."""
validator = StabilityValidator()
# Create stable detections sequence
base_detection = {
"class": "car",
"confidence": 0.9,
"bbox": [100, 200, 300, 400],
"track_id": 1001
}
# Add sequence of stable detections
for i in range(20):
detection = base_detection.copy()
# Add small variations to simulate real detection noise
detection["confidence"] = 0.9 + np.random.normal(0, 0.02)
bbox = detection["bbox"]
detection["bbox"] = [
bbox[0] + np.random.normal(0, 2),
bbox[1] + np.random.normal(0, 2),
bbox[2] + np.random.normal(0, 2),
bbox[3] + np.random.normal(0, 2)
]
validator.add_detection(detection, frame_id=i)
# Benchmark validation performance
validation_times = []
num_iterations = 1000
for i in range(num_iterations):
test_detection = base_detection.copy()
test_detection["confidence"] = 0.85 + np.random.normal(0, 0.05)
start_time = time.perf_counter()
is_stable = validator.is_detection_stable(
test_detection,
stability_frames=10,
confidence_threshold=0.8
)
end_time = time.perf_counter()
validation_time_ms = (end_time - start_time) * 1000
validation_times.append(validation_time_ms)
avg_validation_time = statistics.mean(validation_times)
max_validation_time = max(validation_times)
# Should be very fast (< 1ms typically)
assert avg_validation_time < 1.0
assert max_validation_time < 5.0
print(f"\nStability Validation Performance Metrics:")
print(f"Average validation time: {avg_validation_time:.3f} ms")
print(f"Max validation time: {max_validation_time:.3f} ms")
@pytest.mark.asyncio
async def test_pipeline_executor_performance(self, sample_frame, performance_config):
"""Benchmark complete pipeline execution performance."""
pipeline_executor = PipelineExecutor()
# Simple pipeline configuration
pipeline_config = {
"modelId": "fast_detection_model",
"modelFile": "fast_model.pt",
"expectedClasses": ["car"],
"minConfidence": 0.5,
"actions": [],
"branches": []
}
detection_context = {
"camera_id": "perf_camera",
"display_id": "perf_display",
"frame": sample_frame,
"timestamp": int(time.time() * 1000),
"session_id": "perf_session"
}
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True):
# Setup fast mock model
mock_model = Mock()
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.xyxy = Mock()
mock_result.boxes.conf = Mock()
mock_result.boxes.cls = Mock()
mock_result.names = {0: "car"}
mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([[100, 200, 300, 400]])
mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9])
mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0])
mock_model.return_value = mock_result
mock_torch_load.return_value = mock_model
# Warm up
for _ in range(3):
await pipeline_executor.execute_pipeline(pipeline_config, detection_context)
# Benchmark pipeline execution
pipeline_times = []
num_iterations = 50
for _ in range(num_iterations):
start_time = time.perf_counter()
result = await pipeline_executor.execute_pipeline(pipeline_config, detection_context)
end_time = time.perf_counter()
pipeline_time_ms = (end_time - start_time) * 1000
pipeline_times.append(pipeline_time_ms)
# Ensure result is valid
assert result is not None
avg_pipeline_time = statistics.mean(pipeline_times)
max_pipeline_time = max(pipeline_times)
# Performance assertions
assert avg_pipeline_time < performance_config["max_pipeline_time_ms"]
print(f"\nPipeline Execution Performance Metrics:")
print(f"Average pipeline time: {avg_pipeline_time:.2f} ms")
print(f"Max pipeline time: {max_pipeline_time:.2f} ms")
def test_memory_usage_detection(self, sample_frame, performance_config):
"""Test memory usage during detection operations."""
detector = YOLODetector()
with patch('torch.load') as mock_torch_load:
# Setup mock model
mock_model = Mock()
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.xyxy = Mock()
mock_result.boxes.conf = Mock()
mock_result.boxes.cls = Mock()
mock_result.names = {0: "car"}
mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([[100, 200, 300, 400]])
mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9])
mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0])
mock_model.return_value = mock_result
mock_torch_load.return_value = mock_model
# Measure memory usage
gc.collect() # Clean up before measurement
initial_memory = psutil.Process().memory_info().rss / 1024 / 1024 # MB
# Run detections and monitor memory
memory_measurements = []
for i in range(100):
detections = detector.detect(sample_frame, confidence_threshold=0.5)
if i % 10 == 0: # Measure every 10 iterations
current_memory = psutil.Process().memory_info().rss / 1024 / 1024
memory_measurements.append(current_memory - initial_memory)
# Final memory measurement
gc.collect()
final_memory = psutil.Process().memory_info().rss / 1024 / 1024
memory_increase = final_memory - initial_memory
# Memory should not grow significantly
assert memory_increase < 100 # Less than 100MB increase
# Memory should be relatively stable (not constantly growing)
if len(memory_measurements) > 1:
memory_trend = memory_measurements[-1] - memory_measurements[0]
assert memory_trend < 50 # Less than 50MB trend growth
print(f"\nMemory Usage Metrics:")
print(f"Initial memory: {initial_memory:.1f} MB")
print(f"Final memory: {final_memory:.1f} MB")
print(f"Memory increase: {memory_increase:.1f} MB")
def test_concurrent_detection_performance(self, sample_frame):
"""Test performance with concurrent detection operations."""
with patch('torch.load') as mock_torch_load:
# Setup mock model
mock_model = Mock()
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.xyxy = Mock()
mock_result.boxes.conf = Mock()
mock_result.boxes.cls = Mock()
mock_result.names = {0: "car"}
mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([[100, 200, 300, 400]])
mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9])
mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0])
mock_model.return_value = mock_result
mock_torch_load.return_value = mock_model
# Create multiple detectors
detectors = [YOLODetector() for _ in range(4)]
import concurrent.futures
def run_detection(detector, frame, iterations=25):
"""Run detection iterations."""
times = []
for _ in range(iterations):
start_time = time.perf_counter()
detections = detector.detect(frame, confidence_threshold=0.5)
end_time = time.perf_counter()
times.append((end_time - start_time) * 1000)
return times
# Run concurrent detections
start_time = time.perf_counter()
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures = [
executor.submit(run_detection, detector, sample_frame)
for detector in detectors
]
results = [future.result() for future in concurrent.futures.as_completed(futures)]
end_time = time.perf_counter()
total_time = end_time - start_time
# Analyze results
all_times = [time_ms for result in results for time_ms in result]
total_detections = len(all_times)
avg_detection_time = statistics.mean(all_times)
# Calculate effective throughput
effective_fps = total_detections / total_time
print(f"\nConcurrent Detection Performance:")
print(f"Total detections: {total_detections}")
print(f"Total time: {total_time:.2f} seconds")
print(f"Average detection time: {avg_detection_time:.2f} ms")
print(f"Effective throughput: {effective_fps:.1f} FPS")
# Should maintain reasonable performance under load
assert avg_detection_time < 200 # Less than 200ms average
assert effective_fps > 20 # More than 20 effective FPS
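# Because torch.load is patched with an instantly-returning mock, these thresholds
# mainly measure detector-wrapper overhead and thread scheduling rather than real
# inference cost.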
def test_large_frame_performance(self, large_frame):
"""Test detection performance with large frames."""
detector = YOLODetector()
with patch('torch.load') as mock_torch_load:
# Setup mock model
mock_model = Mock()
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.xyxy = Mock()
mock_result.boxes.conf = Mock()
mock_result.boxes.cls = Mock()
mock_result.names = {0: "car", 1: "person"}
# Larger frame might have more detections
mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[100, 200, 300, 400],
[500, 600, 700, 800],
[1000, 200, 1200, 400]
])
mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9, 0.8, 0.7])
mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1, 0])
mock_model.return_value = mock_result
mock_torch_load.return_value = mock_model
# Benchmark large frame detection
detection_times = []
num_iterations = 20 # Fewer iterations for large frames
for _ in range(num_iterations):
start_time = time.perf_counter()
detections = detector.detect(large_frame, confidence_threshold=0.5)
end_time = time.perf_counter()
detection_time_ms = (end_time - start_time) * 1000
detection_times.append(detection_time_ms)
avg_detection_time = statistics.mean(detection_times)
max_detection_time = max(detection_times)
print(f"\nLarge Frame Detection Performance:")
print(f"Frame size: {large_frame.shape}")
print(f"Average detection time: {avg_detection_time:.2f} ms")
print(f"Max detection time: {max_detection_time:.2f} ms")
# Large frames should still be processed in reasonable time
assert avg_detection_time < 300 # Less than 300ms for large frames
assert max_detection_time < 500 # Less than 500ms max
class TestStreamPerformance:
"""Test stream management performance."""
@pytest.mark.asyncio
async def test_stream_creation_performance(self):
"""Test performance of stream creation and management."""
stream_manager = StreamManager()
with patch('cv2.VideoCapture') as mock_video_cap:
# Setup fast mock
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))
# Benchmark stream creation
creation_times = []
num_streams = 20
try:
for i in range(num_streams):
from detector_worker.streams.stream_manager import StreamConfig
config = StreamConfig(
stream_url=f"rtsp://test{i}.example.com/stream",
stream_type="rtsp"
)
start_time = time.perf_counter()
await stream_manager.create_stream(f"camera_{i}", config, f"sub_{i}")
end_time = time.perf_counter()
creation_time_ms = (end_time - start_time) * 1000
creation_times.append(creation_time_ms)
avg_creation_time = statistics.mean(creation_times)
max_creation_time = max(creation_times)
# Stream creation should be fast
assert avg_creation_time < 100 # Less than 100ms average
assert max_creation_time < 500 # Less than 500ms max
print(f"\nStream Creation Performance:")
print(f"Streams created: {num_streams}")
print(f"Average creation time: {avg_creation_time:.2f} ms")
print(f"Max creation time: {max_creation_time:.2f} ms")
finally:
await stream_manager.stop_all_streams()
@pytest.mark.asyncio
async def test_frame_retrieval_performance(self, sample_frame):
"""Test performance of frame retrieval operations."""
stream_manager = StreamManager()
with patch('cv2.VideoCapture') as mock_video_cap:
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, sample_frame)
try:
# Create test stream
from detector_worker.streams.stream_manager import StreamConfig
config = StreamConfig(
stream_url="rtsp://perf.example.com/stream",
stream_type="rtsp"
)
await stream_manager.create_stream("perf_camera", config, "perf_sub")
# Let stream capture some frames
await asyncio.sleep(0.1)
# Benchmark frame retrieval
retrieval_times = []
num_retrievals = 1000
for _ in range(num_retrievals):
start_time = time.perf_counter()
frame = stream_manager.get_latest_frame("perf_camera")
end_time = time.perf_counter()
retrieval_time_ms = (end_time - start_time) * 1000
retrieval_times.append(retrieval_time_ms)
avg_retrieval_time = statistics.mean(retrieval_times)
max_retrieval_time = max(retrieval_times)
# Frame retrieval should be very fast
assert avg_retrieval_time < 1.0 # Less than 1ms average
assert max_retrieval_time < 10.0 # Less than 10ms max
print(f"\nFrame Retrieval Performance:")
print(f"Frame retrievals: {num_retrievals}")
print(f"Average retrieval time: {avg_retrieval_time:.3f} ms")
print(f"Max retrieval time: {max_retrieval_time:.3f} ms")
finally:
await stream_manager.stop_all_streams()
class TestModelPerformance:
"""Test model management performance."""
def test_model_loading_performance(self):
"""Test performance of model loading operations."""
model_manager = ModelManager()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True):
# Setup mock model
def create_mock_model():
model = Mock()
# Mock model parameters for memory estimation
param = Mock()
param.numel.return_value = 1000000 # 1M parameters
param.element_size.return_value = 4 # 4 bytes each
model.parameters.return_value = [param]
return model
mock_torch_load.side_effect = lambda *args, **kwargs: create_mock_model()
# Benchmark model loading
loading_times = []
num_models = 10
for i in range(num_models):
from detector_worker.models.model_manager import ModelConfig
config = ModelConfig(
model_id=f"perf_model_{i}",
model_path=f"/fake/path/model_{i}.pt",
model_type="detection",
device="cpu"
)
start_time = time.perf_counter()
model = model_manager.load_model(config)
end_time = time.perf_counter()
loading_time_ms = (end_time - start_time) * 1000
loading_times.append(loading_time_ms)
avg_loading_time = statistics.mean(loading_times)
max_loading_time = max(loading_times)
print(f"\nModel Loading Performance:")
print(f"Models loaded: {num_models}")
print(f"Average loading time: {avg_loading_time:.2f} ms")
print(f"Max loading time: {max_loading_time:.2f} ms")
# Model loading should be reasonable
assert avg_loading_time < 200 # Less than 200ms average
def test_model_cache_performance(self):
"""Test performance of model cache operations."""
model_manager = ModelManager()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True):
mock_torch_load.return_value = Mock()
# Load model first
from detector_worker.models.model_manager import ModelConfig
config = ModelConfig(
model_id="cache_perf_model",
model_path="/fake/path/model.pt",
model_type="detection",
device="cpu"
)
# Initial load
model_manager.load_model(config)
# Benchmark cache retrieval
cache_times = []
num_retrievals = 10000
for _ in range(num_retrievals):
start_time = time.perf_counter()
model = model_manager.get_model("cache_perf_model")
end_time = time.perf_counter()
cache_time_ms = (end_time - start_time) * 1000
cache_times.append(cache_time_ms)
avg_cache_time = statistics.mean(cache_times)
max_cache_time = max(cache_times)
print(f"\nModel Cache Performance:")
print(f"Cache retrievals: {num_retrievals}")
print(f"Average cache time: {avg_cache_time:.4f} ms")
print(f"Max cache time: {max_cache_time:.4f} ms")
# Cache should be very fast
assert avg_cache_time < 0.1 # Less than 0.1ms average
assert max_cache_time < 1.0 # Less than 1ms max

View file

@ -0,0 +1,828 @@
"""
Performance tests for storage components (database, Redis, session cache).
These tests benchmark storage operations to ensure they meet
performance requirements for high-throughput scenarios.
"""
import pytest
import asyncio
import time
import statistics
import uuid
from unittest.mock import Mock, patch, MagicMock
import psutil
import gc
import numpy as np
from detector_worker.storage.database_manager import DatabaseManager
from detector_worker.storage.redis_client import RedisClient, RedisConfig
from detector_worker.storage.session_cache import SessionCacheManager, SessionCache, CacheConfig
@pytest.fixture
def performance_config():
"""Configuration for performance tests."""
return {
"max_db_query_time_ms": 50,
"max_redis_operation_time_ms": 10,
"max_cache_operation_time_ms": 1,
"min_db_throughput_ops_per_sec": 1000,
"min_redis_throughput_ops_per_sec": 5000,
"min_cache_throughput_ops_per_sec": 10000
}
class TestDatabasePerformance:
"""Test database performance benchmarks."""
def test_database_connection_performance(self, performance_config):
"""Test database connection establishment performance."""
with patch('psycopg2.connect') as mock_connect:
# Setup mock connection
mock_conn = Mock()
mock_cursor = Mock()
mock_conn.cursor.return_value = mock_cursor
mock_connect.return_value = mock_conn
db_manager = DatabaseManager()
# Benchmark connection times
connection_times = []
num_connections = 100
for _ in range(num_connections):
start_time = time.perf_counter()
db_manager.connect()
end_time = time.perf_counter()
connection_time_ms = (end_time - start_time) * 1000
connection_times.append(connection_time_ms)
# Disconnect for next test
db_manager.disconnect()
avg_connection_time = statistics.mean(connection_times)
max_connection_time = max(connection_times)
print(f"\nDatabase Connection Performance:")
print(f"Connections: {num_connections}")
print(f"Average connection time: {avg_connection_time:.2f} ms")
print(f"Max connection time: {max_connection_time:.2f} ms")
# Connection should be fast
assert avg_connection_time < 10.0 # Less than 10ms average
assert max_connection_time < 50.0 # Less than 50ms max
@pytest.mark.asyncio
async def test_database_insert_performance(self, performance_config):
"""Test database insert performance."""
with patch('psycopg2.connect') as mock_connect:
# Setup mock database
mock_conn = Mock()
mock_cursor = Mock()
mock_conn.cursor.return_value = mock_cursor
mock_connect.return_value = mock_conn
db_manager = DatabaseManager()
db_manager.connect()
# Prepare test data
table_name = "car_frontal_info"
test_records = [
{
"display_id": f"display_{i}",
"captured_timestamp": str(int(time.time() * 1000) + i),
"session_id": str(uuid.uuid4()),
"license_character": None,
"license_type": "No model available"
}
for i in range(1000)
]
# Benchmark single inserts
insert_times = []
for record in test_records[:100]: # Test first 100 for individual timing
start_time = time.perf_counter()
await db_manager.create_record(table_name, record)
end_time = time.perf_counter()
insert_time_ms = (end_time - start_time) * 1000
insert_times.append(insert_time_ms)
# Benchmark aggregate insert throughput (sequential inserts timed together)
start_time = time.perf_counter()
for record in test_records[100:]:
await db_manager.create_record(table_name, record)
end_time = time.perf_counter()
batch_time = end_time - start_time
batch_throughput = 900 / batch_time # 900 records in batch
avg_insert_time = statistics.mean(insert_times)
max_insert_time = max(insert_times)
print(f"\nDatabase Insert Performance:")
print(f"Average insert time: {avg_insert_time:.2f} ms")
print(f"Max insert time: {max_insert_time:.2f} ms")
print(f"Batch throughput: {batch_throughput:.0f} inserts/second")
assert avg_insert_time < performance_config["max_db_query_time_ms"]
assert batch_throughput > performance_config["min_db_throughput_ops_per_sec"]
@pytest.mark.asyncio
async def test_database_update_performance(self, performance_config):
"""Test database update performance."""
with patch('psycopg2.connect') as mock_connect:
# Setup mock database
mock_conn = Mock()
mock_cursor = Mock()
mock_conn.cursor.return_value = mock_cursor
mock_connect.return_value = mock_conn
db_manager = DatabaseManager()
db_manager.connect()
table_name = "car_frontal_info"
session_ids = [str(uuid.uuid4()) for _ in range(1000)]
# Benchmark updates
update_times = []
for session_id in session_ids[:100]: # Test first 100 for individual timing
update_data = {
"car_brand": "Toyota",
"car_body_type": "Sedan",
"updated_at": "NOW()"
}
start_time = time.perf_counter()
await db_manager.update_record(table_name, session_id, update_data, key_field="session_id")
end_time = time.perf_counter()
update_time_ms = (end_time - start_time) * 1000
update_times.append(update_time_ms)
# Benchmark aggregate update throughput (sequential updates timed together)
start_time = time.perf_counter()
for session_id in session_ids[100:]:
update_data = {
"car_brand": "Honda",
"car_body_type": "Hatchback"
}
await db_manager.update_record(table_name, session_id, update_data, key_field="session_id")
end_time = time.perf_counter()
batch_time = end_time - start_time
batch_throughput = 900 / batch_time
avg_update_time = statistics.mean(update_times)
max_update_time = max(update_times)
print(f"\nDatabase Update Performance:")
print(f"Average update time: {avg_update_time:.2f} ms")
print(f"Max update time: {max_update_time:.2f} ms")
print(f"Batch throughput: {batch_throughput:.0f} updates/second")
assert avg_update_time < performance_config["max_db_query_time_ms"]
assert batch_throughput > performance_config["min_db_throughput_ops_per_sec"]
@pytest.mark.asyncio
async def test_database_query_performance(self, performance_config):
"""Test database query performance."""
with patch('psycopg2.connect') as mock_connect:
# Setup mock database
mock_conn = Mock()
mock_cursor = Mock()
mock_conn.cursor.return_value = mock_cursor
# Mock query results
mock_cursor.fetchone.return_value = ("display_1", "1640995200", "session_123", None, "No model", "Toyota", "Sedan")
mock_cursor.fetchall.return_value = [
("display_1", "1640995200", "session_123", None, "No model", "Toyota", "Sedan"),
("display_2", "1640995201", "session_124", None, "No model", "Honda", "Hatchback")
]
mock_connect.return_value = mock_conn
db_manager = DatabaseManager()
db_manager.connect()
table_name = "car_frontal_info"
# Benchmark single record queries
query_times = []
num_queries = 1000
for i in range(num_queries):
session_id = f"session_{i}"
start_time = time.perf_counter()
result = await db_manager.get_record(table_name, session_id, key_field="session_id")
end_time = time.perf_counter()
query_time_ms = (end_time - start_time) * 1000
query_times.append(query_time_ms)
avg_query_time = statistics.mean(query_times)
max_query_time = max(query_times)
query_throughput = num_queries / (sum(query_times) / 1000)
print(f"\nDatabase Query Performance:")
print(f"Queries: {num_queries}")
print(f"Average query time: {avg_query_time:.2f} ms")
print(f"Max query time: {max_query_time:.2f} ms")
print(f"Query throughput: {query_throughput:.0f} queries/second")
assert avg_query_time < performance_config["max_db_query_time_ms"]
assert query_throughput > performance_config["min_db_throughput_ops_per_sec"]
class TestRedisPerformance:
"""Test Redis client performance benchmarks."""
@pytest.mark.asyncio
async def test_redis_connection_performance(self):
"""Test Redis connection performance."""
with patch('redis.Redis') as mock_redis_class, \
patch('redis.ConnectionPool') as mock_pool_class:
mock_redis = Mock()
mock_redis.ping.return_value = True
mock_redis_class.return_value = mock_redis
mock_pool = Mock()
mock_pool_class.return_value = mock_pool
config = RedisConfig(host="localhost", port=6379)
# Benchmark connection times
connection_times = []
num_connections = 100
for _ in range(num_connections):
redis_client = RedisClient(config)
start_time = time.perf_counter()
await redis_client.connect()
end_time = time.perf_counter()
connection_time_ms = (end_time - start_time) * 1000
connection_times.append(connection_time_ms)
await redis_client.disconnect()
avg_connection_time = statistics.mean(connection_times)
max_connection_time = max(connection_times)
print(f"\nRedis Connection Performance:")
print(f"Connections: {num_connections}")
print(f"Average connection time: {avg_connection_time:.2f} ms")
print(f"Max connection time: {max_connection_time:.2f} ms")
# Redis connections should be very fast
assert avg_connection_time < 5.0 # Less than 5ms average
assert max_connection_time < 20.0 # Less than 20ms max
@pytest.mark.asyncio
async def test_redis_basic_operations_performance(self, performance_config):
"""Test basic Redis operations performance."""
with patch('redis.Redis') as mock_redis_class:
mock_redis = Mock()
mock_redis.ping.return_value = True
mock_redis.set.return_value = True
mock_redis.get.return_value = "test_value"
mock_redis.delete.return_value = 1
mock_redis.exists.return_value = 1
mock_redis_class.return_value = mock_redis
config = RedisConfig(host="localhost")
redis_client = RedisClient(config)
await redis_client.connect()
# Benchmark SET operations
set_times = []
num_operations = 10000
for i in range(num_operations):
start_time = time.perf_counter()
await redis_client.set(f"key_{i}", f"value_{i}", expire_seconds=300)
end_time = time.perf_counter()
set_time_ms = (end_time - start_time) * 1000
set_times.append(set_time_ms)
# Benchmark GET operations
get_times = []
for i in range(num_operations):
start_time = time.perf_counter()
value = await redis_client.get(f"key_{i}")
end_time = time.perf_counter()
get_time_ms = (end_time - start_time) * 1000
get_times.append(get_time_ms)
# Benchmark DELETE operations
delete_times = []
for i in range(num_operations):
start_time = time.perf_counter()
result = await redis_client.delete(f"key_{i}")
end_time = time.perf_counter()
delete_time_ms = (end_time - start_time) * 1000
delete_times.append(delete_time_ms)
# Calculate statistics
avg_set_time = statistics.mean(set_times)
avg_get_time = statistics.mean(get_times)
avg_delete_time = statistics.mean(delete_times)
set_throughput = num_operations / (sum(set_times) / 1000)
get_throughput = num_operations / (sum(get_times) / 1000)
delete_throughput = num_operations / (sum(delete_times) / 1000)
print(f"\nRedis Basic Operations Performance:")
print(f"Operations per type: {num_operations}")
print(f"Average SET time: {avg_set_time:.3f} ms")
print(f"Average GET time: {avg_get_time:.3f} ms")
print(f"Average DELETE time: {avg_delete_time:.3f} ms")
print(f"SET throughput: {set_throughput:.0f} ops/second")
print(f"GET throughput: {get_throughput:.0f} ops/second")
print(f"DELETE throughput: {delete_throughput:.0f} ops/second")
assert avg_set_time < performance_config["max_redis_operation_time_ms"]
assert avg_get_time < performance_config["max_redis_operation_time_ms"]
assert avg_delete_time < performance_config["max_redis_operation_time_ms"]
assert set_throughput > performance_config["min_redis_throughput_ops_per_sec"]
assert get_throughput > performance_config["min_redis_throughput_ops_per_sec"]
@pytest.mark.asyncio
async def test_redis_image_storage_performance(self):
"""Test Redis image storage performance."""
with patch('redis.Redis') as mock_redis_class, \
patch('cv2.imencode') as mock_imencode:
mock_redis = Mock()
mock_redis.ping.return_value = True
mock_redis.set.return_value = True
mock_redis.expire.return_value = True
mock_redis_class.return_value = mock_redis
# Mock image encoding
encoded_data = np.array([1, 2, 3, 4, 5], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
config = RedisConfig(host="localhost")
redis_client = RedisClient(config)
await redis_client.connect()
# Create test frames
small_frame = np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8)
medium_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
large_frame = np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8)
frames = [
("small", small_frame),
("medium", medium_frame),
("large", large_frame)
]
for frame_type, frame in frames:
storage_times = []
num_images = 100
for i in range(num_images):
key = f"test_image_{frame_type}_{i}"
start_time = time.perf_counter()
await redis_client.image_storage.store_image(key, frame, expire_seconds=300)
end_time = time.perf_counter()
storage_time_ms = (end_time - start_time) * 1000
storage_times.append(storage_time_ms)
avg_storage_time = statistics.mean(storage_times)
max_storage_time = max(storage_times)
throughput = num_images / (sum(storage_times) / 1000)
print(f"\n{frame_type.capitalize()} Frame Storage Performance:")
print(f"Frame size: {frame.shape}")
print(f"Images stored: {num_images}")
print(f"Average storage time: {avg_storage_time:.2f} ms")
print(f"Max storage time: {max_storage_time:.2f} ms")
print(f"Storage throughput: {throughput:.1f} images/second")
# Performance should scale reasonably with image size
expected_max_time = {"small": 50, "medium": 100, "large": 200}
assert avg_storage_time < expected_max_time[frame_type]
@pytest.mark.asyncio
async def test_redis_pipeline_performance(self):
"""Test Redis pipeline performance."""
with patch('redis.Redis') as mock_redis_class:
mock_redis = Mock()
mock_redis.ping.return_value = True
mock_redis_class.return_value = mock_redis
# Mock pipeline
mock_pipeline = Mock()
mock_pipeline.execute.return_value = [True] * 1000
mock_redis.pipeline.return_value = mock_pipeline
config = RedisConfig(host="localhost")
redis_client = RedisClient(config)
await redis_client.connect()
# Benchmark pipeline operations
num_operations = 1000
start_time = time.perf_counter()
async with redis_client.pipeline() as pipe:
for i in range(num_operations):
pipe.set(f"pipeline_key_{i}", f"pipeline_value_{i}")
results = await pipe.execute()
end_time = time.perf_counter()
total_time = end_time - start_time
throughput = num_operations / total_time
print(f"\nRedis Pipeline Performance:")
print(f"Operations: {num_operations}")
print(f"Total time: {total_time:.3f} seconds")
print(f"Throughput: {throughput:.0f} ops/second")
# Pipeline should be much faster than individual operations
assert throughput > 10000 # Should exceed 10k ops/second with pipeline
assert len(results) == num_operations
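# With redis.Redis mocked, this throughput floor reflects only Python-side batching
# overhead; a real Redis server would add network round trips and serialization cost.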
class TestSessionCachePerformance:
"""Test session cache performance benchmarks."""
def test_cache_basic_operations_performance(self, performance_config):
"""Test basic cache operations performance."""
cache_config = CacheConfig(max_size=10000, ttl_seconds=3600)
cache = SessionCache(cache_config)
# Prepare test data
test_sessions = []
for i in range(10000):
from detector_worker.storage.session_cache import SessionData
session_data = SessionData(
session_id=f"session_{i}",
camera_id=f"camera_{i % 100}", # 100 unique cameras
display_id=f"display_{i % 50}" # 50 unique displays
)
session_data.add_detection_data("main", {"class": "car", "confidence": 0.9})
test_sessions.append((f"session_{i}", session_data))
# Benchmark PUT operations
put_times = []
for session_id, session_data in test_sessions:
start_time = time.perf_counter()
cache.put(session_id, session_data)
end_time = time.perf_counter()
put_time_ms = (end_time - start_time) * 1000
put_times.append(put_time_ms)
# Benchmark GET operations
get_times = []
for session_id, _ in test_sessions:
start_time = time.perf_counter()
retrieved_data = cache.get(session_id)
end_time = time.perf_counter()
get_time_ms = (end_time - start_time) * 1000
get_times.append(get_time_ms)
# Calculate statistics
avg_put_time = statistics.mean(put_times)
avg_get_time = statistics.mean(get_times)
max_put_time = max(put_times)
max_get_time = max(get_times)
put_throughput = len(test_sessions) / (sum(put_times) / 1000)
get_throughput = len(test_sessions) / (sum(get_times) / 1000)
print(f"\nSession Cache Basic Operations Performance:")
print(f"Operations per type: {len(test_sessions)}")
print(f"Average PUT time: {avg_put_time:.3f} ms")
print(f"Average GET time: {avg_get_time:.3f} ms")
print(f"Max PUT time: {max_put_time:.3f} ms")
print(f"Max GET time: {max_get_time:.3f} ms")
print(f"PUT throughput: {put_throughput:.0f} ops/second")
print(f"GET throughput: {get_throughput:.0f} ops/second")
assert avg_put_time < performance_config["max_cache_operation_time_ms"]
assert avg_get_time < performance_config["max_cache_operation_time_ms"]
assert put_throughput > performance_config["min_cache_throughput_ops_per_sec"]
assert get_throughput > performance_config["min_cache_throughput_ops_per_sec"]
def test_cache_manager_performance(self, performance_config):
"""Test session cache manager performance."""
cache_manager = SessionCacheManager()
cache_manager.clear_all()
# Benchmark detection caching
detection_times = []
num_operations = 5000
for i in range(num_operations):
camera_id = f"camera_{i % 50}"
detection_data = {
"class": "car",
"confidence": 0.9,
"bbox": [100, 200, 300, 400],
"track_id": i
}
start_time = time.perf_counter()
cache_manager.cache_detection(camera_id, detection_data)
end_time = time.perf_counter()
detection_time_ms = (end_time - start_time) * 1000
detection_times.append(detection_time_ms)
# Benchmark detection retrieval
retrieval_times = []
for i in range(num_operations):
camera_id = f"camera_{i % 50}"
start_time = time.perf_counter()
cached_detection = cache_manager.get_cached_detection(camera_id)
end_time = time.perf_counter()
retrieval_time_ms = (end_time - start_time) * 1000
retrieval_times.append(retrieval_time_ms)
# Benchmark session operations
session_times = []
for i in range(1000): # Fewer session operations as they're more complex
session_id = str(uuid.uuid4())
camera_id = f"camera_{i % 20}"
start_time = time.perf_counter()
cache_manager.create_session(session_id, camera_id, {"initial": "data"})
cache_manager.update_session_detection(session_id, {"car_brand": "Toyota"})
session_data = cache_manager.get_session_detection(session_id)
end_time = time.perf_counter()
session_time_ms = (end_time - start_time) * 1000
session_times.append(session_time_ms)
# Calculate statistics
avg_detection_time = statistics.mean(detection_times)
avg_retrieval_time = statistics.mean(retrieval_times)
avg_session_time = statistics.mean(session_times)
detection_throughput = num_operations / (sum(detection_times) / 1000)
retrieval_throughput = num_operations / (sum(retrieval_times) / 1000)
session_throughput = 1000 / (sum(session_times) / 1000)
print(f"\nCache Manager Performance:")
print(f"Average detection cache time: {avg_detection_time:.3f} ms")
print(f"Average retrieval time: {avg_retrieval_time:.3f} ms")
print(f"Average session operation time: {avg_session_time:.3f} ms")
print(f"Detection throughput: {detection_throughput:.0f} ops/second")
print(f"Retrieval throughput: {retrieval_throughput:.0f} ops/second")
print(f"Session throughput: {session_throughput:.0f} ops/second")
assert avg_detection_time < performance_config["max_cache_operation_time_ms"] * 2
assert avg_retrieval_time < performance_config["max_cache_operation_time_ms"]
assert detection_throughput > performance_config["min_cache_throughput_ops_per_sec"] / 2
def test_cache_memory_performance(self):
"""Test cache memory usage and performance."""
# Measure initial memory
gc.collect()
initial_memory = psutil.Process().memory_info().rss / 1024 / 1024 # MB
cache_config = CacheConfig(max_size=10000, ttl_seconds=3600)
cache = SessionCache(cache_config)
# Add many sessions to test memory usage
num_sessions = 5000
memory_measurements = []
for i in range(num_sessions):
from detector_worker.storage.session_cache import SessionData
session_data = SessionData(
session_id=f"memory_session_{i}",
camera_id=f"camera_{i % 100}",
display_id=f"display_{i % 50}"
)
# Add some detection data
session_data.add_detection_data("detection", {
"class": "car",
"confidence": 0.9,
"bbox": [100, 200, 300, 400],
"features": [float(j) for j in range(50)] # Add some bulk
})
cache.put(f"memory_session_{i}", session_data)
# Measure memory periodically
if i % 500 == 0 and i > 0:
current_memory = psutil.Process().memory_info().rss / 1024 / 1024
memory_increase = current_memory - initial_memory
memory_measurements.append((i, memory_increase))
# Final memory measurement
gc.collect()
final_memory = psutil.Process().memory_info().rss / 1024 / 1024
total_memory_increase = final_memory - initial_memory
# Calculate memory per session
memory_per_session = total_memory_increase / num_sessions
print(f"\nCache Memory Performance:")
print(f"Sessions cached: {num_sessions}")
print(f"Initial memory: {initial_memory:.1f} MB")
print(f"Final memory: {final_memory:.1f} MB")
print(f"Total memory increase: {total_memory_increase:.1f} MB")
print(f"Memory per session: {memory_per_session * 1024:.1f} KB")
# Memory usage should be reasonable
assert memory_per_session < 0.1 # Less than 100KB per session
assert total_memory_increase < 500 # Total increase less than 500MB
# Test access performance with full cache
access_times = []
for i in range(1000):
session_id = f"memory_session_{i}"
start_time = time.perf_counter()
session_data = cache.get(session_id)
end_time = time.perf_counter()
access_time_ms = (end_time - start_time) * 1000
access_times.append(access_time_ms)
avg_access_time = statistics.mean(access_times)
max_access_time = max(access_times)
print(f"Full cache access performance:")
print(f"Average access time: {avg_access_time:.3f} ms")
print(f"Max access time: {max_access_time:.3f} ms")
# Access should remain fast even with full cache
assert avg_access_time < 1.0 # Less than 1ms average
assert max_access_time < 10.0 # Less than 10ms max
def test_cache_eviction_performance(self):
"""Test cache eviction performance."""
# Create cache with small size to force evictions
cache_config = CacheConfig(max_size=1000, eviction_policy="lru")
cache = SessionCache(cache_config)
# Fill cache beyond capacity
num_sessions = 2000 # Double the capacity
eviction_times = []
for i in range(num_sessions):
from detector_worker.storage.session_cache import SessionData
session_data = SessionData(
session_id=f"eviction_session_{i}",
camera_id=f"camera_{i % 100}",
display_id=f"display_{i % 50}"
)
start_time = time.perf_counter()
cache.put(f"eviction_session_{i}", session_data)
end_time = time.perf_counter()
operation_time_ms = (end_time - start_time) * 1000
eviction_times.append(operation_time_ms)
# Analyze eviction performance
avg_operation_time = statistics.mean(eviction_times)
max_operation_time = max(eviction_times)
# Check that cache size is maintained
assert cache.size() == 1000 # Should not exceed max_size
print(f"\nCache Eviction Performance:")
print(f"Sessions processed: {num_sessions}")
print(f"Final cache size: {cache.size()}")
print(f"Average operation time: {avg_operation_time:.3f} ms")
print(f"Max operation time: {max_operation_time:.3f} ms")
# Eviction should not significantly slow down operations
assert avg_operation_time < 5.0 # Less than 5ms average with eviction
assert max_operation_time < 20.0 # Less than 20ms max
class TestStorageIntegrationPerformance:
"""Test integrated storage performance scenarios."""
@pytest.mark.asyncio
async def test_full_storage_pipeline_performance(self):
"""Test performance of complete storage pipeline."""
with patch('psycopg2.connect') as mock_db_connect, \
patch('redis.Redis') as mock_redis_class:
# Setup mocks
mock_db_conn = Mock()
mock_db_cursor = Mock()
mock_db_conn.cursor.return_value = mock_db_cursor
mock_db_connect.return_value = mock_db_conn
mock_redis = Mock()
mock_redis.ping.return_value = True
mock_redis.set.return_value = True
mock_redis.expire.return_value = True
mock_redis_class.return_value = mock_redis
# Initialize storage components
db_manager = DatabaseManager()
db_manager.connect()
redis_config = RedisConfig(host="localhost")
redis_client = RedisClient(redis_config)
await redis_client.connect()
cache_manager = SessionCacheManager()
cache_manager.clear_all()
# Benchmark complete storage pipeline
pipeline_times = []
num_iterations = 500
for i in range(num_iterations):
session_id = str(uuid.uuid4())
camera_id = f"camera_{i % 20}"
start_time = time.perf_counter()
# 1. Cache detection
detection_data = {
"class": "car",
"confidence": 0.9,
"bbox": [100, 200, 300, 400],
"track_id": i + 1000
}
cache_manager.cache_detection(camera_id, detection_data)
# 2. Create session
cache_manager.create_session(session_id, camera_id, {"initial": "data"})
# 3. Database insert
await db_manager.create_record("car_frontal_info", {
"session_id": session_id,
"display_id": f"display_{i % 10}",
"captured_timestamp": str(int(time.time() * 1000)),
"license_type": "No model available"
})
# 4. Redis store
await redis_client.set(f"detection:{session_id}", "image_data", expire_seconds=600)
# 5. Update session with results
cache_manager.update_session_detection(session_id, {
"car_brand": "Toyota",
"car_body_type": "Sedan"
})
# 6. Database update
await db_manager.update_record("car_frontal_info", session_id, {
"car_brand": "Toyota",
"car_body_type": "Sedan"
}, key_field="session_id")
end_time = time.perf_counter()
pipeline_time_ms = (end_time - start_time) * 1000
pipeline_times.append(pipeline_time_ms)
# Analyze pipeline performance
avg_pipeline_time = statistics.mean(pipeline_times)
max_pipeline_time = max(pipeline_times)
pipeline_throughput = num_iterations / (sum(pipeline_times) / 1000)
print(f"\nFull Storage Pipeline Performance:")
print(f"Pipeline iterations: {num_iterations}")
print(f"Average pipeline time: {avg_pipeline_time:.2f} ms")
print(f"Max pipeline time: {max_pipeline_time:.2f} ms")
print(f"Pipeline throughput: {pipeline_throughput:.1f} pipelines/second")
# Complete pipeline should be efficient
assert avg_pipeline_time < 100 # Less than 100ms per complete pipeline
assert pipeline_throughput > 50 # At least 50 pipelines/second
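# The six storage steps above run sequentially inside the timed region, so the
# throughput is roughly 1000 / avg_pipeline_time (in ms); all backends are mocked.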

View file

@ -0,0 +1,596 @@
"""
Performance tests for WebSocket communication and message processing.
These tests benchmark WebSocket throughput, latency, and concurrent
connection handling to ensure scalability requirements are met.
"""
import pytest
import asyncio
import time
import statistics
import json
from unittest.mock import Mock, AsyncMock
from concurrent.futures import ThreadPoolExecutor
import psutil
from detector_worker.communication.websocket_handler import WebSocketHandler, ConnectionManager
from detector_worker.communication.message_processor import MessageProcessor
@pytest.fixture
def performance_config():
"""Configuration for performance tests."""
return {
"max_message_latency_ms": 10,
"min_throughput_msgs_per_sec": 1000,
"max_concurrent_connections": 100,
"max_memory_per_connection_kb": 100
}
@pytest.fixture
def mock_websocket():
"""Create mock WebSocket for performance testing."""
websocket = Mock()
websocket.accept = AsyncMock()
websocket.send_json = AsyncMock()
websocket.send_text = AsyncMock()
websocket.receive_json = AsyncMock()
websocket.receive_text = AsyncMock()
websocket.close = AsyncMock()
websocket.ping = AsyncMock()
return websocket
class TestWebSocketMessagePerformance:
"""Test WebSocket message processing performance."""
@pytest.mark.asyncio
async def test_message_processing_throughput(self, performance_config):
"""Test message processing throughput."""
message_processor = MessageProcessor()
# Simple state request message
test_message = {"type": "requestState"}
client_id = "perf_client"
# Warm up
for _ in range(10):
await message_processor.process_message(test_message, client_id)
# Benchmark throughput
num_messages = 10000
start_time = time.perf_counter()
for _ in range(num_messages):
await message_processor.process_message(test_message, client_id)
end_time = time.perf_counter()
total_time = end_time - start_time
throughput = num_messages / total_time
print(f"\nMessage Processing Throughput:")
print(f"Messages processed: {num_messages}")
print(f"Total time: {total_time:.2f} seconds")
print(f"Throughput: {throughput:.0f} messages/second")
assert throughput >= performance_config["min_throughput_msgs_per_sec"]
@pytest.mark.asyncio
async def test_message_processing_latency(self, performance_config):
"""Test individual message processing latency."""
message_processor = MessageProcessor()
test_messages = [
{"type": "requestState"},
{"type": "setSessionId", "payload": {"sessionId": "test", "displayId": "display"}},
{"type": "patchSession", "payload": {"sessionId": "test", "data": {"test": "value"}}}
]
client_id = "latency_client"
# Benchmark individual message latency
all_latencies = []
        for test_message in test_messages:
latencies = []
for _ in range(1000):
start_time = time.perf_counter()
await message_processor.process_message(test_message, client_id)
end_time = time.perf_counter()
latency_ms = (end_time - start_time) * 1000
latencies.append(latency_ms)
avg_latency = statistics.mean(latencies)
max_latency = max(latencies)
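            # statistics.quantiles(..., n=20) returns 19 cut points; index 18 is the
            # value below which roughly 95% of samples fall, i.e. the 95th percentile.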
p95_latency = statistics.quantiles(latencies, n=20)[18] # 95th percentile
all_latencies.extend(latencies)
print(f"\nMessage Type: {test_message['type']}")
print(f"Average latency: {avg_latency:.3f} ms")
print(f"Max latency: {max_latency:.3f} ms")
print(f"95th percentile: {p95_latency:.3f} ms")
assert avg_latency < performance_config["max_message_latency_ms"]
assert p95_latency < performance_config["max_message_latency_ms"] * 2
# Overall statistics
overall_avg = statistics.mean(all_latencies)
overall_p95 = statistics.quantiles(all_latencies, n=20)[18]
print(f"\nOverall Message Latency:")
print(f"Average latency: {overall_avg:.3f} ms")
print(f"95th percentile: {overall_p95:.3f} ms")
@pytest.mark.asyncio
async def test_concurrent_message_processing(self, performance_config):
"""Test concurrent message processing performance."""
message_processor = MessageProcessor()
async def process_messages_batch(client_id, num_messages):
"""Process a batch of messages for one client."""
test_message = {"type": "requestState"}
latencies = []
for _ in range(num_messages):
start_time = time.perf_counter()
await message_processor.process_message(test_message, client_id)
end_time = time.perf_counter()
latency_ms = (end_time - start_time) * 1000
latencies.append(latency_ms)
return latencies
# Run concurrent processing
num_clients = 50
messages_per_client = 100
start_time = time.perf_counter()
tasks = [
process_messages_batch(f"client_{i}", messages_per_client)
for i in range(num_clients)
]
results = await asyncio.gather(*tasks)
end_time = time.perf_counter()
total_time = end_time - start_time
# Analyze results
all_latencies = [latency for client_latencies in results for latency in client_latencies]
total_messages = len(all_latencies)
avg_latency = statistics.mean(all_latencies)
throughput = total_messages / total_time
print(f"\nConcurrent Message Processing:")
print(f"Clients: {num_clients}")
print(f"Total messages: {total_messages}")
print(f"Total time: {total_time:.2f} seconds")
print(f"Throughput: {throughput:.0f} messages/second")
print(f"Average latency: {avg_latency:.3f} ms")
assert throughput >= performance_config["min_throughput_msgs_per_sec"] / 2 # Reduced due to concurrency overhead
assert avg_latency < performance_config["max_message_latency_ms"] * 2
@pytest.mark.asyncio
async def test_large_message_performance(self):
"""Test performance with large messages."""
message_processor = MessageProcessor()
# Create large message (simulating detection results)
large_payload = {
"detections": [
{
"class": f"object_{i}",
"confidence": 0.9,
"bbox": [i*10, i*10, (i+1)*10, (i+1)*10],
"metadata": {
"feature_vector": [float(j) for j in range(100)],
"description": "x" * 500 # Large text field
}
}
for i in range(50) # 50 detections
],
"camera_info": {
"resolution": [1920, 1080],
"settings": {"brightness": 50, "contrast": 75},
"history": [{"timestamp": i, "event": f"event_{i}"} for i in range(100)]
}
}
large_message = {
"type": "imageDetection",
"payload": large_payload
}
client_id = "large_msg_client"
# Measure message size
message_size_bytes = len(json.dumps(large_message))
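        # Note: len() of the serialised string counts characters; an exact wire size
        # would be len(json.dumps(large_message).encode("utf-8")), which is identical
        # here because the payload is pure ASCII.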
print(f"\nLarge Message Performance:")
print(f"Message size: {message_size_bytes / 1024:.1f} KB")
# Benchmark large message processing
processing_times = []
num_iterations = 100
for _ in range(num_iterations):
start_time = time.perf_counter()
await message_processor.process_message(large_message, client_id)
end_time = time.perf_counter()
processing_time_ms = (end_time - start_time) * 1000
processing_times.append(processing_time_ms)
avg_processing_time = statistics.mean(processing_times)
max_processing_time = max(processing_times)
print(f"Average processing time: {avg_processing_time:.2f} ms")
print(f"Max processing time: {max_processing_time:.2f} ms")
# Large messages should still be processed reasonably quickly
assert avg_processing_time < 100 # Less than 100ms for large messages
assert max_processing_time < 500 # Less than 500ms max
class TestConnectionManagerPerformance:
"""Test connection manager performance."""
def test_connection_creation_performance(self, performance_config, mock_websocket):
"""Test connection creation and management performance."""
connection_manager = ConnectionManager()
# Benchmark connection creation
creation_times = []
num_connections = 1000
for i in range(num_connections):
start_time = time.perf_counter()
connection_manager._create_connection(mock_websocket, f"client_{i}")
end_time = time.perf_counter()
creation_time_ms = (end_time - start_time) * 1000
creation_times.append(creation_time_ms)
avg_creation_time = statistics.mean(creation_times)
max_creation_time = max(creation_times)
print(f"\nConnection Creation Performance:")
print(f"Connections created: {num_connections}")
print(f"Average creation time: {avg_creation_time:.3f} ms")
print(f"Max creation time: {max_creation_time:.3f} ms")
# Connection creation should be very fast
assert avg_creation_time < 1.0 # Less than 1ms average
assert max_creation_time < 10.0 # Less than 10ms max
@pytest.mark.asyncio
async def test_broadcast_performance(self, mock_websocket):
"""Test broadcast message performance."""
connection_manager = ConnectionManager()
# Create many mock connections
num_connections = 1000
mock_websockets = []
for i in range(num_connections):
ws = Mock()
ws.send_json = AsyncMock()
ws.send_text = AsyncMock()
mock_websockets.append(ws)
# Add to connection manager
connection = connection_manager._create_connection(ws, f"client_{i}")
connection.is_connected = True
connection_manager.connections[f"client_{i}"] = connection
# Test broadcast performance
test_message = {"type": "broadcast", "data": "test message"}
broadcast_times = []
num_broadcasts = 100
for _ in range(num_broadcasts):
start_time = time.perf_counter()
await connection_manager.broadcast(test_message)
end_time = time.perf_counter()
broadcast_time_ms = (end_time - start_time) * 1000
broadcast_times.append(broadcast_time_ms)
avg_broadcast_time = statistics.mean(broadcast_times)
max_broadcast_time = max(broadcast_times)
print(f"\nBroadcast Performance:")
print(f"Connections: {num_connections}")
print(f"Broadcasts: {num_broadcasts}")
print(f"Average broadcast time: {avg_broadcast_time:.2f} ms")
print(f"Max broadcast time: {max_broadcast_time:.2f} ms")
# Broadcast should scale reasonably
assert avg_broadcast_time < 50 # Less than 50ms for 1000 connections
# Verify all connections received the message
for ws in mock_websockets:
assert ws.send_json.call_count == num_broadcasts
def test_subscription_management_performance(self):
"""Test subscription management performance."""
connection_manager = ConnectionManager()
# Test subscription operations performance
num_operations = 10000
# Add subscriptions
add_times = []
for i in range(num_operations):
client_id = f"client_{i % 100}" # 100 unique clients
subscription_id = f"camera_{i % 50}" # 50 unique cameras
start_time = time.perf_counter()
connection_manager.add_subscription(client_id, subscription_id)
end_time = time.perf_counter()
add_time_ms = (end_time - start_time) * 1000
add_times.append(add_time_ms)
# Query subscriptions
query_times = []
for i in range(1000):
client_id = f"client_{i % 100}"
start_time = time.perf_counter()
subscriptions = connection_manager.get_client_subscriptions(client_id)
end_time = time.perf_counter()
query_time_ms = (end_time - start_time) * 1000
query_times.append(query_time_ms)
# Remove subscriptions
remove_times = []
for i in range(num_operations):
client_id = f"client_{i % 100}"
subscription_id = f"camera_{i % 50}"
start_time = time.perf_counter()
connection_manager.remove_subscription(client_id, subscription_id)
end_time = time.perf_counter()
remove_time_ms = (end_time - start_time) * 1000
remove_times.append(remove_time_ms)
# Analyze results
avg_add_time = statistics.mean(add_times)
avg_query_time = statistics.mean(query_times)
avg_remove_time = statistics.mean(remove_times)
print(f"\nSubscription Management Performance:")
print(f"Average add time: {avg_add_time:.4f} ms")
print(f"Average query time: {avg_query_time:.4f} ms")
print(f"Average remove time: {avg_remove_time:.4f} ms")
# Should be very fast operations
assert avg_add_time < 0.1
assert avg_query_time < 0.1
assert avg_remove_time < 0.1
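
# Illustrative sketch (an assumption, not the project's ConnectionManager code):
# one way to stay inside the <50 ms budget for 1000 connections asserted above
# is to fan the sends out concurrently instead of awaiting them one by one.
async def _example_fanout_broadcast(connections, message):
    """Send `message` to every connected client concurrently."""
    await asyncio.gather(
        *(conn.websocket.send_json(message) for conn in connections if conn.is_connected),
        return_exceptions=True,  # one slow or broken socket should not block the rest
    )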
class TestWebSocketHandlerPerformance:
"""Test complete WebSocket handler performance."""
@pytest.mark.asyncio
async def test_concurrent_connections_performance(self, performance_config):
"""Test performance with many concurrent connections."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
async def simulate_client_session(client_id, num_messages=50):
"""Simulate a client WebSocket session."""
mock_ws = Mock()
mock_ws.accept = AsyncMock()
mock_ws.send_json = AsyncMock()
mock_ws.receive_json = AsyncMock()
# Simulate message sequence
messages = [
{"type": "requestState"} for _ in range(num_messages)
]
messages.append(asyncio.CancelledError()) # Disconnect
mock_ws.receive_json.side_effect = messages
processing_times = []
try:
await websocket_handler.handle_websocket(mock_ws, client_id)
except asyncio.CancelledError:
pass # Expected disconnect
return len(messages) - 1 # Exclude the disconnect
# Test concurrent connections
num_concurrent_clients = 100
messages_per_client = 25
start_time = time.perf_counter()
tasks = [
simulate_client_session(f"perf_client_{i}", messages_per_client)
for i in range(num_concurrent_clients)
]
results = await asyncio.gather(*tasks, return_exceptions=True)
end_time = time.perf_counter()
total_time = end_time - start_time
# Analyze results
successful_clients = len([r for r in results if not isinstance(r, Exception)])
total_messages = sum(r for r in results if isinstance(r, int))
print(f"\nConcurrent Connections Performance:")
print(f"Concurrent clients: {num_concurrent_clients}")
print(f"Successful clients: {successful_clients}")
print(f"Total messages: {total_messages}")
print(f"Total time: {total_time:.2f} seconds")
print(f"Messages per second: {total_messages / total_time:.0f}")
assert successful_clients >= num_concurrent_clients * 0.95 # 95% success rate
assert total_messages / total_time >= 1000 # At least 1000 msg/sec throughput
@pytest.mark.asyncio
async def test_memory_usage_under_load(self, performance_config):
"""Test memory usage under high connection load."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Measure initial memory
initial_memory = psutil.Process().memory_info().rss / 1024 / 1024 # MB
# Create many connections
num_connections = 500
connections = []
for i in range(num_connections):
mock_ws = Mock()
mock_ws.accept = AsyncMock()
mock_ws.send_json = AsyncMock()
connection = websocket_handler.connection_manager._create_connection(
mock_ws, f"mem_test_client_{i}"
)
connection.is_connected = True
websocket_handler.connection_manager.connections[f"mem_test_client_{i}"] = connection
connections.append(connection)
# Measure memory after connections
after_connections_memory = psutil.Process().memory_info().rss / 1024 / 1024
memory_per_connection = (after_connections_memory - initial_memory) / num_connections * 1024 # KB
# Simulate some activity
test_message = {"type": "broadcast", "data": "test"}
for _ in range(10):
await websocket_handler.connection_manager.broadcast(test_message)
# Measure memory after activity
after_activity_memory = psutil.Process().memory_info().rss / 1024 / 1024
print(f"\nMemory Usage Under Load:")
print(f"Initial memory: {initial_memory:.1f} MB")
print(f"After {num_connections} connections: {after_connections_memory:.1f} MB")
print(f"After activity: {after_activity_memory:.1f} MB")
print(f"Memory per connection: {memory_per_connection:.1f} KB")
# Memory usage should be reasonable
assert memory_per_connection < performance_config["max_memory_per_connection_kb"]
# Clean up
websocket_handler.connection_manager.connections.clear()
@pytest.mark.asyncio
async def test_heartbeat_performance(self):
"""Test heartbeat mechanism performance."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor, {"heartbeat_interval": 0.1})
# Create connections with mock WebSockets
num_connections = 100
mock_websockets = []
for i in range(num_connections):
mock_ws = Mock()
mock_ws.ping = AsyncMock()
mock_websockets.append(mock_ws)
connection = websocket_handler.connection_manager._create_connection(
mock_ws, f"heartbeat_client_{i}"
)
connection.is_connected = True
websocket_handler.connection_manager.connections[f"heartbeat_client_{i}"] = connection
# Start heartbeat task
heartbeat_task = asyncio.create_task(websocket_handler._heartbeat_loop())
# Let it run for several heartbeat cycles
start_time = time.perf_counter()
await asyncio.sleep(0.5) # 5 heartbeat cycles
end_time = time.perf_counter()
# Cancel heartbeat
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
# Analyze heartbeat performance
elapsed_time = end_time - start_time
expected_pings = int(elapsed_time / 0.1) * num_connections
actual_pings = sum(ws.ping.call_count for ws in mock_websockets)
ping_efficiency = actual_pings / expected_pings if expected_pings > 0 else 0
print(f"\nHeartbeat Performance:")
print(f"Connections: {num_connections}")
print(f"Elapsed time: {elapsed_time:.2f} seconds")
print(f"Expected pings: {expected_pings}")
print(f"Actual pings: {actual_pings}")
print(f"Ping efficiency: {ping_efficiency:.2%}")
# Should achieve reasonable ping efficiency
assert ping_efficiency > 0.8 # At least 80% efficiency
# Clean up
websocket_handler.connection_manager.connections.clear()
@pytest.mark.asyncio
async def test_error_handling_performance(self):
"""Test performance impact of error handling."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Create messages that will cause errors
error_messages = [
{"invalid": "message"}, # Missing type
{"type": "unknown_type"}, # Unknown type
{"type": "subscribe"}, # Missing payload
]
valid_message = {"type": "requestState"}
# Mix error messages with valid ones
test_sequence = (error_messages + [valid_message]) * 250 # 1000 total messages
start_time = time.perf_counter()
for message in test_sequence:
await message_processor.process_message(message, "error_perf_client")
end_time = time.perf_counter()
total_time = end_time - start_time
throughput = len(test_sequence) / total_time
print(f"\nError Handling Performance:")
print(f"Total messages (with errors): {len(test_sequence)}")
print(f"Total time: {total_time:.2f} seconds")
print(f"Throughput: {throughput:.0f} messages/second")
# Error handling shouldn't significantly impact performance
assert throughput > 500 # Should still process > 500 msg/sec with errors
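
# Illustrative sketch (an assumption, not the handler's real _heartbeat_loop):
# the heartbeat test above expects roughly one ping per connection per
# heartbeat_interval, which a loop of this shape would produce.
async def _example_heartbeat_loop(connection_manager, interval_sec):
    """Ping every connected client once per interval until the task is cancelled."""
    while True:
        for connection in list(connection_manager.connections.values()):
            if connection.is_connected:
                await connection.websocket.ping()
        await asyncio.sleep(interval_sec)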

View file

@@ -0,0 +1,856 @@
"""
Unit tests for WebSocket handling functionality.
"""
import pytest
import asyncio
import json
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from fastapi.websockets import WebSocket, WebSocketDisconnect
import uuid
from detector_worker.communication.websocket_handler import (
WebSocketHandler,
ConnectionManager,
WebSocketConnection,
MessageHandler,
WebSocketError,
ConnectionError as WSConnectionError
)
from detector_worker.communication.message_processor import MessageType
from detector_worker.core.exceptions import MessageProcessingError
class TestWebSocketConnection:
"""Test WebSocket connection wrapper."""
def test_creation(self, mock_websocket):
"""Test WebSocket connection creation."""
connection = WebSocketConnection(mock_websocket, "client_001")
assert connection.websocket == mock_websocket
assert connection.client_id == "client_001"
assert connection.is_connected is False
assert connection.connected_at is None
assert connection.last_ping is None
assert connection.subscription_id is None
@pytest.mark.asyncio
async def test_accept_connection(self, mock_websocket):
"""Test accepting WebSocket connection."""
connection = WebSocketConnection(mock_websocket, "client_001")
mock_websocket.accept = AsyncMock()
await connection.accept()
assert connection.is_connected is True
assert connection.connected_at is not None
mock_websocket.accept.assert_called_once()
@pytest.mark.asyncio
async def test_send_message_json(self, mock_websocket):
"""Test sending JSON message."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.send_json = AsyncMock()
message = {"type": "test", "data": "hello"}
await connection.send_message(message)
mock_websocket.send_json.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_send_message_text(self, mock_websocket):
"""Test sending text message."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.send_text = AsyncMock()
await connection.send_message("hello world")
mock_websocket.send_text.assert_called_once_with("hello world")
@pytest.mark.asyncio
async def test_send_message_not_connected(self, mock_websocket):
"""Test sending message when not connected."""
connection = WebSocketConnection(mock_websocket, "client_001")
# Don't set is_connected = True
with pytest.raises(WebSocketError) as exc_info:
await connection.send_message({"type": "test"})
assert "not connected" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_receive_message_json(self, mock_websocket):
"""Test receiving JSON message."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.receive_json = AsyncMock(return_value={"type": "test", "data": "received"})
message = await connection.receive_message()
assert message == {"type": "test", "data": "received"}
mock_websocket.receive_json.assert_called_once()
@pytest.mark.asyncio
async def test_receive_message_text(self, mock_websocket):
"""Test receiving text message."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
# Mock receive_json to fail, then receive_text to succeed
mock_websocket.receive_json = AsyncMock(side_effect=json.JSONDecodeError("Invalid JSON", "", 0))
mock_websocket.receive_text = AsyncMock(return_value="plain text message")
message = await connection.receive_message()
assert message == "plain text message"
mock_websocket.receive_text.assert_called_once()
@pytest.mark.asyncio
async def test_ping_pong(self, mock_websocket):
"""Test ping/pong functionality."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.ping = AsyncMock()
await connection.ping()
assert connection.last_ping is not None
mock_websocket.ping.assert_called_once()
@pytest.mark.asyncio
async def test_close_connection(self, mock_websocket):
"""Test closing connection."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.close = AsyncMock()
await connection.close(code=1000, reason="Normal closure")
assert connection.is_connected is False
mock_websocket.close.assert_called_once_with(code=1000, reason="Normal closure")
def test_connection_info(self, mock_websocket):
"""Test getting connection information."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
connection.subscription_id = "sub_123"
info = connection.get_connection_info()
assert info["client_id"] == "client_001"
assert info["is_connected"] is True
assert info["subscription_id"] == "sub_123"
assert "connected_at" in info
assert "last_ping" in info
class TestConnectionManager:
"""Test WebSocket connection management."""
def test_initialization(self):
"""Test connection manager initialization."""
manager = ConnectionManager()
assert len(manager.connections) == 0
assert len(manager.subscriptions) == 0
assert manager.max_connections == 100
@pytest.mark.asyncio
async def test_add_connection(self, mock_websocket):
"""Test adding a connection."""
manager = ConnectionManager()
client_id = "client_001"
connection = await manager.add_connection(mock_websocket, client_id)
assert connection.client_id == client_id
assert client_id in manager.connections
assert manager.get_connection_count() == 1
@pytest.mark.asyncio
async def test_remove_connection(self, mock_websocket):
"""Test removing a connection."""
manager = ConnectionManager()
client_id = "client_001"
await manager.add_connection(mock_websocket, client_id)
assert client_id in manager.connections
removed_connection = await manager.remove_connection(client_id)
assert removed_connection is not None
assert removed_connection.client_id == client_id
assert client_id not in manager.connections
assert manager.get_connection_count() == 0
def test_get_connection(self, mock_websocket):
"""Test getting a connection."""
manager = ConnectionManager()
client_id = "client_001"
# Manually add connection for testing
connection = WebSocketConnection(mock_websocket, client_id)
manager.connections[client_id] = connection
retrieved_connection = manager.get_connection(client_id)
assert retrieved_connection == connection
assert retrieved_connection.client_id == client_id
def test_get_nonexistent_connection(self):
"""Test getting non-existent connection."""
manager = ConnectionManager()
connection = manager.get_connection("nonexistent_client")
assert connection is None
@pytest.mark.asyncio
async def test_broadcast_message(self, mock_websocket):
"""Test broadcasting message to all connections."""
manager = ConnectionManager()
# Add multiple connections
connections = []
for i in range(3):
client_id = f"client_{i}"
ws = Mock()
ws.send_json = AsyncMock()
connection = WebSocketConnection(ws, client_id)
connection.is_connected = True
manager.connections[client_id] = connection
connections.append(connection)
message = {"type": "broadcast", "data": "hello all"}
await manager.broadcast(message)
# All connections should have received the message
for connection in connections:
connection.websocket.send_json.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_broadcast_to_subscription(self, mock_websocket):
"""Test broadcasting to specific subscription."""
manager = ConnectionManager()
# Add connections with different subscriptions
subscription_id = "camera_001"
# Connection with target subscription
ws1 = Mock()
ws1.send_json = AsyncMock()
connection1 = WebSocketConnection(ws1, "client_001")
connection1.is_connected = True
connection1.subscription_id = subscription_id
manager.connections["client_001"] = connection1
manager.subscriptions[subscription_id] = {"client_001"}
# Connection with different subscription
ws2 = Mock()
ws2.send_json = AsyncMock()
connection2 = WebSocketConnection(ws2, "client_002")
connection2.is_connected = True
connection2.subscription_id = "camera_002"
manager.connections["client_002"] = connection2
manager.subscriptions["camera_002"] = {"client_002"}
message = {"type": "detection", "data": "camera detection"}
await manager.broadcast_to_subscription(subscription_id, message)
# Only connection1 should have received the message
ws1.send_json.assert_called_once_with(message)
ws2.send_json.assert_not_called()
def test_add_subscription(self):
"""Test adding subscription mapping."""
manager = ConnectionManager()
client_id = "client_001"
subscription_id = "camera_001"
manager.add_subscription(client_id, subscription_id)
assert subscription_id in manager.subscriptions
assert client_id in manager.subscriptions[subscription_id]
def test_remove_subscription(self):
"""Test removing subscription mapping."""
manager = ConnectionManager()
client_id = "client_001"
subscription_id = "camera_001"
# Add subscription first
manager.add_subscription(client_id, subscription_id)
assert client_id in manager.subscriptions[subscription_id]
# Remove subscription
manager.remove_subscription(client_id, subscription_id)
assert client_id not in manager.subscriptions.get(subscription_id, set())
def test_get_subscription_clients(self):
"""Test getting clients for a subscription."""
manager = ConnectionManager()
subscription_id = "camera_001"
clients = ["client_001", "client_002", "client_003"]
for client_id in clients:
manager.add_subscription(client_id, subscription_id)
subscription_clients = manager.get_subscription_clients(subscription_id)
assert subscription_clients == set(clients)
def test_get_client_subscriptions(self):
"""Test getting subscriptions for a client."""
manager = ConnectionManager()
client_id = "client_001"
subscriptions = ["camera_001", "camera_002", "camera_003"]
for subscription_id in subscriptions:
manager.add_subscription(client_id, subscription_id)
client_subscriptions = manager.get_client_subscriptions(client_id)
assert client_subscriptions == set(subscriptions)
@pytest.mark.asyncio
async def test_cleanup_disconnected_connections(self):
"""Test cleanup of disconnected connections."""
manager = ConnectionManager()
# Add connected and disconnected connections
ws1 = Mock()
connection1 = WebSocketConnection(ws1, "client_001")
connection1.is_connected = True
manager.connections["client_001"] = connection1
ws2 = Mock()
connection2 = WebSocketConnection(ws2, "client_002")
connection2.is_connected = False # Disconnected
manager.connections["client_002"] = connection2
# Add subscriptions
manager.add_subscription("client_001", "camera_001")
manager.add_subscription("client_002", "camera_002")
cleaned_count = await manager.cleanup_disconnected()
assert cleaned_count == 1
assert "client_001" in manager.connections # Still connected
assert "client_002" not in manager.connections # Cleaned up
# Subscriptions should also be cleaned up
assert manager.get_client_subscriptions("client_002") == set()
def test_get_connection_stats(self):
"""Test getting connection statistics."""
manager = ConnectionManager()
# Add various connections and subscriptions
for i in range(3):
client_id = f"client_{i}"
ws = Mock()
connection = WebSocketConnection(ws, client_id)
connection.is_connected = i < 2 # First 2 connected, last one disconnected
manager.connections[client_id] = connection
if i < 2: # Add subscriptions for connected clients
manager.add_subscription(client_id, f"camera_{i}")
stats = manager.get_connection_stats()
assert stats["total_connections"] == 3
assert stats["active_connections"] == 2
assert stats["total_subscriptions"] == 2
assert "uptime" in stats
class TestMessageHandler:
"""Test message handling functionality."""
def test_creation(self):
"""Test message handler creation."""
mock_processor = Mock()
handler = MessageHandler(mock_processor)
assert handler.message_processor == mock_processor
assert handler.connection_manager is None
def test_set_connection_manager(self):
"""Test setting connection manager."""
mock_processor = Mock()
mock_manager = Mock()
handler = MessageHandler(mock_processor)
handler.set_connection_manager(mock_manager)
assert handler.connection_manager == mock_manager
@pytest.mark.asyncio
async def test_handle_message_success(self, mock_websocket):
"""Test successful message handling."""
mock_processor = Mock()
mock_processor.process_message = AsyncMock(return_value={"type": "response", "status": "success"})
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
message = {"type": "subscribe", "payload": {"camera_id": "camera_001"}}
response = await handler.handle_message(connection, message)
assert response["status"] == "success"
mock_processor.process_message.assert_called_once_with(message, "client_001")
@pytest.mark.asyncio
async def test_handle_message_processing_error(self, mock_websocket):
"""Test message handling with processing error."""
mock_processor = Mock()
mock_processor.process_message = AsyncMock(side_effect=MessageProcessingError("Invalid message"))
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
message = {"type": "invalid", "payload": {}}
response = await handler.handle_message(connection, message)
assert response["type"] == "error"
assert "Invalid message" in response["message"]
@pytest.mark.asyncio
async def test_handle_message_unexpected_error(self, mock_websocket):
"""Test message handling with unexpected error."""
mock_processor = Mock()
mock_processor.process_message = AsyncMock(side_effect=Exception("Unexpected error"))
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
message = {"type": "test", "payload": {}}
response = await handler.handle_message(connection, message)
assert response["type"] == "error"
assert "internal error" in response["message"].lower()
@pytest.mark.asyncio
async def test_send_response(self, mock_websocket):
"""Test sending response to client."""
mock_processor = Mock()
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.send_json = AsyncMock()
response = {"type": "response", "data": "test response"}
await handler.send_response(connection, response)
mock_websocket.send_json.assert_called_once_with(response)
@pytest.mark.asyncio
async def test_send_error_response(self, mock_websocket):
"""Test sending error response."""
mock_processor = Mock()
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.send_json = AsyncMock()
await handler.send_error_response(connection, "Test error message", "TEST_ERROR")
mock_websocket.send_json.assert_called_once()
call_args = mock_websocket.send_json.call_args[0][0]
assert call_args["type"] == "error"
assert call_args["message"] == "Test error message"
assert call_args["error_code"] == "TEST_ERROR"
class TestWebSocketHandler:
"""Test main WebSocket handler functionality."""
def test_initialization(self):
"""Test WebSocket handler initialization."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
assert isinstance(handler.connection_manager, ConnectionManager)
assert isinstance(handler.message_handler, MessageHandler)
assert handler.message_handler.connection_manager == handler.connection_manager
assert handler.heartbeat_interval == 30.0
assert handler.max_connections == 100
def test_initialization_with_config(self):
"""Test initialization with custom configuration."""
mock_processor = Mock()
config = {
"heartbeat_interval": 60.0,
"max_connections": 200,
"connection_timeout": 300.0
}
handler = WebSocketHandler(mock_processor, config)
assert handler.heartbeat_interval == 60.0
assert handler.max_connections == 200
assert handler.connection_timeout == 300.0
@pytest.mark.asyncio
async def test_handle_websocket_connection(self, mock_websocket):
"""Test handling WebSocket connection."""
mock_processor = Mock()
mock_processor.process_message = AsyncMock(return_value={"type": "ack", "status": "success"})
handler = WebSocketHandler(mock_processor)
# Mock WebSocket behavior
mock_websocket.accept = AsyncMock()
mock_websocket.receive_json = AsyncMock(side_effect=[
{"type": "subscribe", "payload": {"camera_id": "camera_001"}},
WebSocketDisconnect() # Simulate disconnection
])
mock_websocket.send_json = AsyncMock()
client_id = "test_client_001"
# Handle connection (should not raise exception)
await handler.handle_websocket(mock_websocket, client_id)
# Verify connection was accepted
mock_websocket.accept.assert_called_once()
# Verify message was processed
mock_processor.process_message.assert_called_once()
@pytest.mark.asyncio
async def test_handle_websocket_max_connections(self, mock_websocket):
"""Test handling max connections limit."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor, {"max_connections": 1})
# Add one connection to reach limit
client1_ws = Mock()
connection1 = WebSocketConnection(client1_ws, "client_001")
handler.connection_manager.connections["client_001"] = connection1
mock_websocket.close = AsyncMock()
# Try to add second connection
await handler.handle_websocket(mock_websocket, "client_002")
# Should close connection due to limit
mock_websocket.close.assert_called_once()
@pytest.mark.asyncio
async def test_broadcast_message(self):
"""Test broadcasting message to all connections."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Mock connection manager
handler.connection_manager.broadcast = AsyncMock()
message = {"type": "system", "data": "Server maintenance in 10 minutes"}
await handler.broadcast_message(message)
handler.connection_manager.broadcast.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_send_to_client(self):
"""Test sending message to specific client."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Create mock connection
mock_websocket = Mock()
mock_websocket.send_json = AsyncMock()
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
handler.connection_manager.connections["client_001"] = connection
message = {"type": "notification", "data": "Personal message"}
result = await handler.send_to_client("client_001", message)
assert result is True
mock_websocket.send_json.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_send_to_nonexistent_client(self):
"""Test sending message to non-existent client."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
message = {"type": "notification", "data": "Message"}
result = await handler.send_to_client("nonexistent_client", message)
assert result is False
@pytest.mark.asyncio
async def test_send_to_subscription(self):
"""Test sending message to subscription."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Mock connection manager
handler.connection_manager.broadcast_to_subscription = AsyncMock()
subscription_id = "camera_001"
message = {"type": "detection", "data": {"class": "car", "confidence": 0.95}}
await handler.send_to_subscription(subscription_id, message)
handler.connection_manager.broadcast_to_subscription.assert_called_once_with(subscription_id, message)
@pytest.mark.asyncio
async def test_start_heartbeat_task(self):
"""Test starting heartbeat task."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor, {"heartbeat_interval": 0.1})
# Mock connection with ping capability
mock_websocket = Mock()
mock_websocket.ping = AsyncMock()
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
handler.connection_manager.connections["client_001"] = connection
# Start heartbeat task
heartbeat_task = asyncio.create_task(handler._heartbeat_loop())
# Let it run briefly
await asyncio.sleep(0.2)
# Cancel task
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
# Should have sent at least one ping
assert mock_websocket.ping.called
def test_get_connection_stats(self):
"""Test getting connection statistics."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Add some mock connections
for i in range(3):
client_id = f"client_{i}"
ws = Mock()
connection = WebSocketConnection(ws, client_id)
connection.is_connected = True
handler.connection_manager.connections[client_id] = connection
stats = handler.get_connection_stats()
assert stats["total_connections"] == 3
assert stats["active_connections"] == 3
def test_get_client_info(self):
"""Test getting client information."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Add mock connection
mock_websocket = Mock()
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
connection.subscription_id = "camera_001"
handler.connection_manager.connections["client_001"] = connection
info = handler.get_client_info("client_001")
assert info is not None
assert info["client_id"] == "client_001"
assert info["is_connected"] is True
assert info["subscription_id"] == "camera_001"
@pytest.mark.asyncio
async def test_disconnect_client(self):
"""Test disconnecting specific client."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Add mock connection
mock_websocket = Mock()
mock_websocket.close = AsyncMock()
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
handler.connection_manager.connections["client_001"] = connection
result = await handler.disconnect_client("client_001", code=1000, reason="Admin disconnect")
assert result is True
mock_websocket.close.assert_called_once_with(code=1000, reason="Admin disconnect")
@pytest.mark.asyncio
async def test_cleanup_connections(self):
"""Test cleanup of disconnected connections."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Mock connection manager cleanup
handler.connection_manager.cleanup_disconnected = AsyncMock(return_value=2)
cleaned_count = await handler.cleanup_connections()
assert cleaned_count == 2
handler.connection_manager.cleanup_disconnected.assert_called_once()
class TestWebSocketHandlerIntegration:
"""Integration tests for WebSocket handler."""
@pytest.mark.asyncio
async def test_complete_subscription_workflow(self, mock_websocket):
"""Test complete subscription workflow."""
mock_processor = Mock()
# Mock processor responses
mock_processor.process_message = AsyncMock(side_effect=[
{"type": "subscribeAck", "status": "success", "subscription_id": "camera_001"},
{"type": "unsubscribeAck", "status": "success"}
])
handler = WebSocketHandler(mock_processor)
# Mock WebSocket behavior
mock_websocket.accept = AsyncMock()
mock_websocket.send_json = AsyncMock()
mock_websocket.receive_json = AsyncMock(side_effect=[
{"type": "subscribe", "payload": {"camera_id": "camera_001", "rtsp_url": "rtsp://example.com"}},
{"type": "unsubscribe", "payload": {"subscription_id": "camera_001"}},
WebSocketDisconnect()
])
client_id = "test_client"
# Handle complete workflow
await handler.handle_websocket(mock_websocket, client_id)
# Verify both messages were processed
assert mock_processor.process_message.call_count == 2
# Verify responses were sent
assert mock_websocket.send_json.call_count == 2
@pytest.mark.asyncio
async def test_multiple_client_management(self):
"""Test managing multiple concurrent clients."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
clients = []
for i in range(5):
client_id = f"client_{i}"
mock_ws = Mock()
mock_ws.send_json = AsyncMock()
connection = WebSocketConnection(mock_ws, client_id)
connection.is_connected = True
handler.connection_manager.connections[client_id] = connection
clients.append(connection)
# Test broadcasting to all clients
message = {"type": "broadcast", "data": "Hello all clients"}
await handler.broadcast_message(message)
# All clients should receive the message
for connection in clients:
connection.websocket.send_json.assert_called_once_with(message)
# Test subscription-specific messaging
subscription_id = "camera_001"
handler.connection_manager.add_subscription("client_0", subscription_id)
handler.connection_manager.add_subscription("client_2", subscription_id)
subscription_message = {"type": "detection", "camera_id": "camera_001"}
await handler.send_to_subscription(subscription_id, subscription_message)
# Only subscribed clients should receive the message
# Note: This would require additional mocking of broadcast_to_subscription
@pytest.mark.asyncio
async def test_error_handling_and_recovery(self, mock_websocket):
"""Test error handling and recovery scenarios."""
mock_processor = Mock()
# First message causes error, second succeeds
mock_processor.process_message = AsyncMock(side_effect=[
MessageProcessingError("Invalid message format"),
{"type": "ack", "status": "success"}
])
handler = WebSocketHandler(mock_processor)
mock_websocket.accept = AsyncMock()
mock_websocket.send_json = AsyncMock()
mock_websocket.receive_json = AsyncMock(side_effect=[
{"type": "invalid", "malformed": True},
{"type": "valid", "payload": {"test": True}},
WebSocketDisconnect()
])
client_id = "error_test_client"
# Should handle errors gracefully and continue processing
await handler.handle_websocket(mock_websocket, client_id)
# Both messages should have been processed
assert mock_processor.process_message.call_count == 2
# Should have sent error response and success response
assert mock_websocket.send_json.call_count == 2
# First call should be error response
first_response = mock_websocket.send_json.call_args_list[0][0][0]
assert first_response["type"] == "error"
@pytest.mark.asyncio
async def test_connection_timeout_handling(self):
"""Test connection timeout handling."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor, {"connection_timeout": 0.1})
# Add connection that hasn't been active
mock_websocket = Mock()
connection = WebSocketConnection(mock_websocket, "timeout_client")
connection.is_connected = True
# Don't update last_ping to simulate timeout
handler.connection_manager.connections["timeout_client"] = connection
# Wait longer than timeout
await asyncio.sleep(0.2)
# Manual cleanup (in real implementation this would be automatic)
cleaned = await handler.cleanup_connections()
# Connection should be identified for cleanup
# (Actual timeout logic would need to be implemented in the cleanup method)
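
# Illustrative sketch (an assumption, prompted by the note in the timeout test
# above): a cleanup pass could treat a connection as stale when neither
# last_ping nor connected_at is newer than connection_timeout seconds.
def _example_is_stale(connection, connection_timeout, now):
    """Return True if the connection has shown no activity within the timeout."""
    last_activity = connection.last_ping or connection.connected_at
    return last_activity is None or (now - last_activity) > connection_timeout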

View file

@@ -0,0 +1,429 @@
"""
Unit tests for configuration management system.
"""
import pytest
import json
import os
import tempfile
from unittest.mock import Mock, patch, MagicMock
from detector_worker.core.config import (
ConfigurationManager,
JsonFileProvider,
EnvironmentProvider,
DatabaseConfig,
RedisConfig,
StreamConfig,
ModelConfig,
LoggingConfig,
get_config_manager,
validate_config
)
from detector_worker.core.exceptions import ConfigurationError
class TestJsonFileProvider:
"""Test JSON file configuration provider."""
def test_get_config_from_valid_file(self, temp_dir):
"""Test loading configuration from a valid JSON file."""
config_data = {"test_key": "test_value", "number": 42}
config_file = os.path.join(temp_dir, "config.json")
with open(config_file, 'w') as f:
json.dump(config_data, f)
provider = JsonFileProvider(config_file)
result = provider.get_config()
assert result == config_data
def test_get_config_file_not_exists(self, temp_dir):
"""Test handling of non-existent config file."""
config_file = os.path.join(temp_dir, "nonexistent.json")
provider = JsonFileProvider(config_file)
result = provider.get_config()
assert result == {}
def test_get_config_invalid_json(self, temp_dir):
"""Test handling of invalid JSON file."""
config_file = os.path.join(temp_dir, "invalid.json")
with open(config_file, 'w') as f:
f.write("invalid json content")
provider = JsonFileProvider(config_file)
result = provider.get_config()
assert result == {}
def test_reload_updates_config(self, temp_dir):
"""Test that reload updates configuration."""
config_file = os.path.join(temp_dir, "config.json")
# Initial config
initial_config = {"version": 1}
with open(config_file, 'w') as f:
json.dump(initial_config, f)
provider = JsonFileProvider(config_file)
assert provider.get_config() == initial_config
# Update config file
updated_config = {"version": 2}
with open(config_file, 'w') as f:
json.dump(updated_config, f)
# Force reload
provider.reload()
assert provider.get_config() == updated_config
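
# Minimal sketch of the file-provider behaviour asserted above (illustration
# only, not the project's JsonFileProvider): return the parsed JSON, or an
# empty dict when the file is missing or malformed.
def _example_load_json_config(path):
    if not os.path.exists(path):
        return {}
    try:
        with open(path) as f:
            return json.load(f)
    except (json.JSONDecodeError, OSError):
        return {}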
class TestEnvironmentProvider:
"""Test environment variable configuration provider."""
def test_get_config_with_env_vars(self):
"""Test loading configuration from environment variables."""
env_vars = {
"DETECTOR_MAX_STREAMS": "10",
"DETECTOR_TARGET_FPS": "15",
"DETECTOR_CONFIG": '{"nested": "value"}',
"OTHER_VAR": "ignored"
}
with patch.dict(os.environ, env_vars, clear=False):
provider = EnvironmentProvider("DETECTOR_")
config = provider.get_config()
assert config["max_streams"] == "10"
assert config["target_fps"] == "15"
assert config["config"] == {"nested": "value"}
assert "other_var" not in config
def test_get_config_no_env_vars(self):
"""Test with no matching environment variables."""
with patch.dict(os.environ, {}, clear=True):
provider = EnvironmentProvider("DETECTOR_")
config = provider.get_config()
assert config == {}
def test_custom_prefix(self):
"""Test with custom prefix."""
env_vars = {"CUSTOM_TEST": "value"}
with patch.dict(os.environ, env_vars, clear=False):
provider = EnvironmentProvider("CUSTOM_")
config = provider.get_config()
assert config["test"] == "value"
class TestConfigDataclasses:
"""Test configuration dataclasses."""
def test_database_config_from_dict(self):
"""Test DatabaseConfig creation from dictionary."""
data = {
"enabled": True,
"host": "db.example.com",
"port": 5432,
"database": "testdb",
"username": "user",
"password": "pass",
"schema": "test_schema",
"unknown_field": "ignored"
}
config = DatabaseConfig.from_dict(data)
assert config.enabled is True
assert config.host == "db.example.com"
assert config.port == 5432
assert config.database == "testdb"
assert config.username == "user"
assert config.password == "pass"
assert config.schema == "test_schema"
# Unknown fields should be ignored
assert not hasattr(config, 'unknown_field')
def test_redis_config_from_dict(self):
"""Test RedisConfig creation from dictionary."""
data = {
"enabled": True,
"host": "redis.example.com",
"port": 6379,
"password": "secret",
"db": 1
}
config = RedisConfig.from_dict(data)
assert config.enabled is True
assert config.host == "redis.example.com"
assert config.port == 6379
assert config.password == "secret"
assert config.db == 1
def test_stream_config_from_dict(self):
"""Test StreamConfig creation from dictionary."""
data = {
"poll_interval_ms": 50,
"max_streams": 10,
"target_fps": 20,
"reconnect_interval_sec": 10,
"max_retries": 5
}
config = StreamConfig.from_dict(data)
assert config.poll_interval_ms == 50
assert config.max_streams == 10
assert config.target_fps == 20
assert config.reconnect_interval_sec == 10
assert config.max_retries == 5
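
# Illustrative sketch of the from_dict behaviour asserted above (not the
# project's code): keep only keys that are declared dataclass fields, so
# entries such as "unknown_field" are silently dropped.
def _example_from_dict(dataclass_type, data):
    import dataclasses
    field_names = {f.name for f in dataclasses.fields(dataclass_type)}
    return dataclass_type(**{k: v for k, v in data.items() if k in field_names})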
class TestConfigurationManager:
"""Test main configuration manager."""
def test_initialization_with_defaults(self):
"""Test that manager initializes with default values."""
manager = ConfigurationManager()
# Should have default providers
assert len(manager._providers) >= 1
# Should have default configuration values
config = manager.get_all()
assert "poll_interval_ms" in config
assert "max_streams" in config
assert "target_fps" in config
def test_add_provider(self):
"""Test adding configuration providers."""
manager = ConfigurationManager()
initial_count = len(manager._providers)
mock_provider = Mock()
mock_provider.get_config.return_value = {"test": "value"}
manager.add_provider(mock_provider)
assert len(manager._providers) == initial_count + 1
def test_get_configuration_value(self):
"""Test getting specific configuration values."""
manager = ConfigurationManager()
# Test existing key
value = manager.get("poll_interval_ms")
assert value is not None
# Test non-existing key with default
value = manager.get("nonexistent", "default")
assert value == "default"
# Test non-existing key without default
value = manager.get("nonexistent")
assert value is None
def test_get_section(self):
"""Test getting configuration sections."""
manager = ConfigurationManager()
# Test existing section
db_section = manager.get_section("database")
assert isinstance(db_section, dict)
# Test non-existing section
empty_section = manager.get_section("nonexistent")
assert empty_section == {}
def test_typed_config_access(self):
"""Test typed configuration object access."""
manager = ConfigurationManager()
# Test database config
db_config = manager.get_database_config()
assert isinstance(db_config, DatabaseConfig)
# Test Redis config
redis_config = manager.get_redis_config()
assert isinstance(redis_config, RedisConfig)
# Test stream config
stream_config = manager.get_stream_config()
assert isinstance(stream_config, StreamConfig)
# Test model config
model_config = manager.get_model_config()
assert isinstance(model_config, ModelConfig)
# Test logging config
logging_config = manager.get_logging_config()
assert isinstance(logging_config, LoggingConfig)
def test_set_configuration_value(self):
"""Test setting configuration values at runtime."""
manager = ConfigurationManager()
manager.set("test_key", "test_value")
assert manager.get("test_key") == "test_value"
# Should also update typed configs
manager.set("poll_interval_ms", 200)
stream_config = manager.get_stream_config()
assert stream_config.poll_interval_ms == 200
def test_validation_success(self):
"""Test configuration validation with valid config."""
manager = ConfigurationManager()
# Set valid configuration
manager.set("poll_interval_ms", 100)
manager.set("max_streams", 5)
manager.set("target_fps", 10)
errors = manager.validate()
assert errors == []
assert manager.is_valid() is True
def test_validation_errors(self):
"""Test configuration validation with invalid values."""
manager = ConfigurationManager()
# Set invalid configuration
manager.set("poll_interval_ms", 0)
manager.set("max_streams", -1)
manager.set("target_fps", 0)
errors = manager.validate()
assert len(errors) > 0
assert manager.is_valid() is False
# Check specific errors
error_messages = " ".join(errors)
assert "poll_interval_ms must be positive" in error_messages
assert "max_streams must be positive" in error_messages
assert "target_fps must be positive" in error_messages
def test_database_validation(self):
"""Test database-specific validation."""
manager = ConfigurationManager()
# Enable database but don't provide required fields
db_config = {
"enabled": True,
"host": "",
"database": ""
}
manager.set("database", db_config)
errors = manager.validate()
error_messages = " ".join(errors)
assert "database host is required" in error_messages
assert "database name is required" in error_messages
def test_redis_validation(self):
"""Test Redis-specific validation."""
manager = ConfigurationManager()
# Enable Redis but don't provide required fields
redis_config = {
"enabled": True,
"host": ""
}
manager.set("redis", redis_config)
errors = manager.validate()
error_messages = " ".join(errors)
assert "redis host is required" in error_messages
class TestGlobalConfigurationFunctions:
"""Test global configuration functions."""
def test_get_config_manager_singleton(self):
"""Test that get_config_manager returns a singleton."""
manager1 = get_config_manager()
manager2 = get_config_manager()
assert manager1 is manager2
@patch('detector_worker.core.config.get_config_manager')
def test_validate_config_function(self, mock_get_manager):
"""Test global validate_config function."""
mock_manager = Mock()
mock_manager.validate.return_value = ["error1", "error2"]
mock_get_manager.return_value = mock_manager
errors = validate_config()
assert errors == ["error1", "error2"]
mock_manager.validate.assert_called_once()
class TestConfigurationIntegration:
"""Integration tests for configuration system."""
def test_provider_priority(self, temp_dir):
"""Test that later providers override earlier ones."""
# Create JSON file with initial config
config_file = os.path.join(temp_dir, "config.json")
json_config = {"test_value": "from_json", "json_only": "json"}
with open(config_file, 'w') as f:
json.dump(json_config, f)
# Set environment variable that should override
env_vars = {"DETECTOR_TEST_VALUE": "from_env", "DETECTOR_ENV_ONLY": "env"}
with patch.dict(os.environ, env_vars, clear=False):
manager = ConfigurationManager()
manager._providers.clear() # Start fresh
# Add providers in order
manager.add_provider(JsonFileProvider(config_file))
manager.add_provider(EnvironmentProvider("DETECTOR_"))
config = manager.get_all()
# Environment should override JSON
assert config["test_value"] == "from_env"
# Both sources should be present
assert config["json_only"] == "json"
assert config["env_only"] == "env"
def test_hot_reload(self, temp_dir):
"""Test configuration hot reload functionality."""
config_file = os.path.join(temp_dir, "config.json")
# Initial config
initial_config = {"version": 1, "feature_enabled": False}
with open(config_file, 'w') as f:
json.dump(initial_config, f)
manager = ConfigurationManager()
manager._providers.clear()
manager.add_provider(JsonFileProvider(config_file))
assert manager.get("version") == 1
assert manager.get("feature_enabled") is False
# Update config file
updated_config = {"version": 2, "feature_enabled": True}
with open(config_file, 'w') as f:
json.dump(updated_config, f)
# Reload configuration
success = manager.reload()
assert success is True
assert manager.get("version") == 2
assert manager.get("feature_enabled") is True

View file

@@ -0,0 +1,566 @@
"""
Unit tests for dependency injection system.
"""
import pytest
import threading
from unittest.mock import Mock, MagicMock
from detector_worker.core.dependency_injection import (
ServiceContainer,
ServiceLifetime,
ServiceDescriptor,
ServiceScope,
DetectorWorkerContainer,
get_container,
resolve_service,
create_service_scope
)
from detector_worker.core.exceptions import DependencyInjectionError
class TestServiceContainer:
"""Test core service container functionality."""
def test_register_singleton(self):
"""Test singleton service registration."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
# Register singleton
container.register_singleton(TestService)
# Resolve twice - should get same instance
instance1 = container.resolve(TestService)
instance2 = container.resolve(TestService)
assert instance1 is instance2
assert instance1.value == 42
def test_register_singleton_with_instance(self):
"""Test singleton registration with pre-created instance."""
container = ServiceContainer()
class TestService:
def __init__(self, value):
self.value = value
# Create instance and register
instance = TestService(99)
container.register_singleton(TestService, instance=instance)
# Resolve should return the pre-created instance
resolved = container.resolve(TestService)
assert resolved is instance
assert resolved.value == 99
def test_register_transient(self):
"""Test transient service registration."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
# Register transient
container.register_transient(TestService)
# Resolve twice - should get different instances
instance1 = container.resolve(TestService)
instance2 = container.resolve(TestService)
assert instance1 is not instance2
assert instance1.value == instance2.value == 42
def test_register_scoped(self):
"""Test scoped service registration."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
# Register scoped
container.register_scoped(TestService)
# Resolve in same scope - should get same instance
instance1 = container.resolve(TestService, scope_id="scope1")
instance2 = container.resolve(TestService, scope_id="scope1")
assert instance1 is instance2
# Resolve in different scope - should get different instance
instance3 = container.resolve(TestService, scope_id="scope2")
assert instance3 is not instance1
def test_register_with_factory(self):
"""Test service registration with factory function."""
container = ServiceContainer()
class TestService:
def __init__(self, value):
self.value = value
# Register with factory
def factory():
return TestService(100)
container.register_singleton(TestService, factory=factory)
instance = container.resolve(TestService)
assert instance.value == 100
def test_register_with_implementation_type(self):
"""Test service registration with implementation type."""
container = ServiceContainer()
class ITestService:
pass
class TestService(ITestService):
def __init__(self):
self.value = 42
# Register interface with implementation
container.register_singleton(ITestService, implementation_type=TestService)
instance = container.resolve(ITestService)
assert isinstance(instance, TestService)
assert instance.value == 42
def test_dependency_injection(self):
"""Test automatic dependency injection."""
container = ServiceContainer()
class DatabaseService:
def __init__(self):
self.connected = True
class UserService:
def __init__(self, database: DatabaseService):
self.database = database
# Register services
container.register_singleton(DatabaseService)
container.register_transient(UserService)
# Resolve should inject dependencies
user_service = container.resolve(UserService)
assert isinstance(user_service.database, DatabaseService)
assert user_service.database.connected is True
def test_circular_dependency_detection(self):
"""Test circular dependency detection."""
container = ServiceContainer()
class ServiceA:
def __init__(self, service_b: 'ServiceB'):
self.service_b = service_b
class ServiceB:
def __init__(self, service_a: ServiceA):
self.service_a = service_a
# Register circular dependencies
container.register_singleton(ServiceA)
container.register_singleton(ServiceB)
# Should raise circular dependency error
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(ServiceA)
assert "Circular dependency detected" in str(exc_info.value)
def test_unregistered_service_error(self):
"""Test error when resolving unregistered service."""
container = ServiceContainer()
class UnregisteredService:
pass
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(UnregisteredService)
assert "is not registered" in str(exc_info.value)
def test_scoped_service_without_scope_id(self):
"""Test error when resolving scoped service without scope ID."""
container = ServiceContainer()
class TestService:
pass
container.register_scoped(TestService)
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(TestService)
assert "Scope ID required" in str(exc_info.value)
def test_factory_error_handling(self):
"""Test factory error handling."""
container = ServiceContainer()
class TestService:
pass
def failing_factory():
raise ValueError("Factory failed")
container.register_singleton(TestService, factory=failing_factory)
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(TestService)
assert "Failed to create service using factory" in str(exc_info.value)
def test_constructor_dependency_with_default(self):
"""Test dependency with default value."""
container = ServiceContainer()
class TestService:
def __init__(self, value: int = 42):
self.value = value
container.register_singleton(TestService)
instance = container.resolve(TestService)
assert instance.value == 42
def test_unresolvable_dependency_with_default(self):
"""Test unresolvable dependency that has a default value."""
container = ServiceContainer()
class UnregisteredService:
pass
class TestService:
def __init__(self, dep: UnregisteredService = None):
self.dep = dep
container.register_singleton(TestService)
instance = container.resolve(TestService)
assert instance.dep is None
def test_unresolvable_dependency_without_default(self):
"""Test unresolvable dependency without default value."""
container = ServiceContainer()
class UnregisteredService:
pass
class TestService:
def __init__(self, dep: UnregisteredService):
self.dep = dep
container.register_singleton(TestService)
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(TestService)
assert "Cannot resolve dependency" in str(exc_info.value)
class TestServiceScope:
"""Test service scope functionality."""
def test_create_scope(self):
"""Test scope creation."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
container.register_scoped(TestService)
scope = container.create_scope("test_scope")
assert isinstance(scope, ServiceScope)
assert scope.scope_id == "test_scope"
def test_scope_context_manager(self):
"""Test scope as context manager."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.disposed = False
def dispose(self):
self.disposed = True
container.register_scoped(TestService)
instance = None
with container.create_scope("test_scope") as scope:
instance = scope.resolve(TestService)
assert not instance.disposed
# Instance should be disposed after scope exit
assert instance.disposed
def test_dispose_scope(self):
"""Test manual scope disposal."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.disposed = False
def dispose(self):
self.disposed = True
container.register_scoped(TestService)
instance = container.resolve(TestService, scope_id="test_scope")
assert not instance.disposed
container.dispose_scope("test_scope")
assert instance.disposed
def test_dispose_error_handling(self):
"""Test error handling during scope disposal."""
container = ServiceContainer()
class TestService:
def dispose(self):
raise ValueError("Dispose failed")
container.register_scoped(TestService)
container.resolve(TestService, scope_id="test_scope")
# Should not raise error, just log it
container.dispose_scope("test_scope")
class TestContainerIntrospection:
"""Test container introspection capabilities."""
def test_is_registered(self):
"""Test checking if service is registered."""
container = ServiceContainer()
class RegisteredService:
pass
class UnregisteredService:
pass
container.register_singleton(RegisteredService)
assert container.is_registered(RegisteredService) is True
assert container.is_registered(UnregisteredService) is False
def test_get_registration_info(self):
"""Test getting service registration information."""
container = ServiceContainer()
class TestService:
pass
container.register_singleton(TestService)
info = container.get_registration_info(TestService)
assert isinstance(info, ServiceDescriptor)
assert info.service_type == TestService
assert info.lifetime == ServiceLifetime.SINGLETON
def test_get_registered_services(self):
"""Test getting all registered services."""
container = ServiceContainer()
class Service1:
pass
class Service2:
pass
container.register_singleton(Service1)
container.register_transient(Service2)
services = container.get_registered_services()
assert len(services) == 2
assert Service1 in services
assert Service2 in services
def test_clear_singletons(self):
"""Test clearing singleton instances."""
container = ServiceContainer()
class TestService:
pass
container.register_singleton(TestService)
# Create singleton instance
instance1 = container.resolve(TestService)
# Clear singletons
container.clear_singletons()
# Next resolve should create new instance
instance2 = container.resolve(TestService)
assert instance2 is not instance1
def test_get_stats(self):
"""Test getting container statistics."""
container = ServiceContainer()
class Service1:
pass
class Service2:
pass
class Service3:
pass
container.register_singleton(Service1)
container.register_transient(Service2)
container.register_scoped(Service3)
# Create some instances
container.resolve(Service1)
container.resolve(Service3, scope_id="scope1")
stats = container.get_stats()
assert stats["registered_services"] == 3
assert stats["active_singletons"] == 1
assert stats["active_scopes"] == 1
assert stats["lifetime_breakdown"]["singleton"] == 1
assert stats["lifetime_breakdown"]["transient"] == 1
assert stats["lifetime_breakdown"]["scoped"] == 1
class TestDetectorWorkerContainer:
"""Test pre-configured detector worker container."""
def test_initialization(self):
"""Test detector worker container initialization."""
container = DetectorWorkerContainer()
assert isinstance(container.container, ServiceContainer)
# Should have core services registered
stats = container.container.get_stats()
assert stats["registered_services"] > 0
def test_resolve_convenience_method(self):
"""Test resolve convenience method."""
container = DetectorWorkerContainer()
# Should be able to resolve through convenience method
from detector_worker.core.singleton_managers import ModelStateManager
manager = container.resolve(ModelStateManager)
assert isinstance(manager, ModelStateManager)
def test_create_scope_convenience_method(self):
"""Test create scope convenience method."""
container = DetectorWorkerContainer()
scope = container.create_scope("test_scope")
assert isinstance(scope, ServiceScope)
assert scope.scope_id == "test_scope"
class TestGlobalContainerFunctions:
"""Test global container functions."""
def test_get_container_singleton(self):
"""Test that get_container returns a singleton."""
container1 = get_container()
container2 = get_container()
assert container1 is container2
assert isinstance(container1, DetectorWorkerContainer)
def test_resolve_service_convenience(self):
"""Test resolve_service convenience function."""
from detector_worker.core.singleton_managers import ModelStateManager
manager = resolve_service(ModelStateManager)
assert isinstance(manager, ModelStateManager)
def test_create_service_scope_convenience(self):
"""Test create_service_scope convenience function."""
scope = create_service_scope("test_scope")
assert isinstance(scope, ServiceScope)
assert scope.scope_id == "test_scope"
class TestThreadSafety:
"""Test thread safety of dependency injection system."""
def test_container_thread_safety(self):
"""Test that container is thread-safe."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.thread_id = threading.current_thread().ident
container.register_singleton(TestService)
instances = {}
def resolve_in_thread(thread_id):
instances[thread_id] = container.resolve(TestService)
# Create multiple threads
threads = []
for i in range(10):
thread = threading.Thread(target=resolve_in_thread, args=(i,))
threads.append(thread)
thread.start()
# Wait for all threads
for thread in threads:
thread.join()
# All should get the same singleton instance
first_instance = list(instances.values())[0]
for instance in instances.values():
assert instance is first_instance
def test_scope_thread_safety(self):
"""Test that scoped services are thread-safe."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.thread_id = threading.current_thread().ident
container.register_scoped(TestService)
results = {}
def resolve_in_scope(thread_id):
# Each thread uses its own scope
instance1 = container.resolve(TestService, scope_id=f"scope_{thread_id}")
instance2 = container.resolve(TestService, scope_id=f"scope_{thread_id}")
results[thread_id] = {
"same_instance": instance1 is instance2,
"thread_id": instance1.thread_id
}
threads = []
for i in range(5):
thread = threading.Thread(target=resolve_in_scope, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Each thread should get same instance within its scope
for thread_id, result in results.items():
assert result["same_instance"] is True

View file

@ -0,0 +1,560 @@
"""
Unit tests for singleton state managers.
"""
import pytest
import time
import threading
from unittest.mock import Mock, patch, MagicMock
from detector_worker.core.singleton_managers import (
SingletonMeta,
ModelStateManager,
StreamStateManager,
SessionStateManager,
CacheStateManager,
CameraStateManager,
PipelineStateManager,
ModelInfo,
StreamInfo,
SessionInfo
)
class TestSingletonMeta:
"""Test singleton metaclass."""
def test_singleton_behavior(self):
"""Test that singleton metaclass creates only one instance."""
class TestSingleton(metaclass=SingletonMeta):
def __init__(self):
self.value = 42
instance1 = TestSingleton()
instance2 = TestSingleton()
assert instance1 is instance2
assert instance1.value == instance2.value
def test_singleton_thread_safety(self):
"""Test that singleton is thread-safe."""
class TestSingleton(metaclass=SingletonMeta):
def __init__(self):
self.created_by = threading.current_thread().name
instances = {}
def create_instance(thread_id):
instances[thread_id] = TestSingleton()
threads = []
for i in range(10):
thread = threading.Thread(target=create_instance, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# All instances should be the same object
first_instance = instances[0]
for instance in instances.values():
assert instance is first_instance
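# Minimal sketch of a thread-safe singleton metaclass consistent with the two tests
# above; the real SingletonMeta may differ (e.g. per-class locks or __init__ handling).
class SketchSingletonMeta(type):
    _instances = {}
    _lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            with SketchSingletonMeta._lock:
                # double-checked locking: re-test after acquiring the lock
                if cls not in cls._instances:
                    cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]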
class TestModelStateManager:
"""Test model state management."""
def test_singleton_behavior(self):
"""Test that ModelStateManager is a singleton."""
manager1 = ModelStateManager()
manager2 = ModelStateManager()
assert manager1 is manager2
def test_load_model(self):
"""Test loading a model."""
manager = ModelStateManager()
manager.clear_all() # Start fresh
mock_model = Mock()
manager.load_model("camera1", "model1", mock_model)
retrieved_model = manager.get_model("camera1", "model1")
assert retrieved_model is mock_model
def test_load_same_model_increments_reference_count(self):
"""Test that loading the same model increments reference count."""
manager = ModelStateManager()
manager.clear_all()
mock_model = Mock()
# Load same model twice
manager.load_model("camera1", "model1", mock_model)
manager.load_model("camera1", "model1", mock_model)
# Should still be accessible
assert manager.get_model("camera1", "model1") is mock_model
def test_get_camera_models(self):
"""Test getting all models for a camera."""
manager = ModelStateManager()
manager.clear_all()
mock_model1 = Mock()
mock_model2 = Mock()
manager.load_model("camera1", "model1", mock_model1)
manager.load_model("camera1", "model2", mock_model2)
models = manager.get_camera_models("camera1")
assert len(models) == 2
assert models["model1"] is mock_model1
assert models["model2"] is mock_model2
def test_unload_model_with_multiple_references(self):
"""Test unloading model with multiple references."""
manager = ModelStateManager()
manager.clear_all()
mock_model = Mock()
# Load model twice (reference count = 2)
manager.load_model("camera1", "model1", mock_model)
manager.load_model("camera1", "model1", mock_model)
# First unload should not remove model
result = manager.unload_model("camera1", "model1")
assert result is False # Still referenced
assert manager.get_model("camera1", "model1") is mock_model
# Second unload should remove model
result = manager.unload_model("camera1", "model1")
assert result is True # Completely removed
assert manager.get_model("camera1", "model1") is None
def test_unload_camera_models(self):
"""Test unloading all models for a camera."""
manager = ModelStateManager()
manager.clear_all()
mock_model1 = Mock()
mock_model2 = Mock()
manager.load_model("camera1", "model1", mock_model1)
manager.load_model("camera1", "model2", mock_model2)
manager.unload_camera_models("camera1")
assert manager.get_model("camera1", "model1") is None
assert manager.get_model("camera1", "model2") is None
def test_get_stats(self):
"""Test getting model statistics."""
manager = ModelStateManager()
manager.clear_all()
mock_model = Mock()
manager.load_model("camera1", "model1", mock_model)
manager.load_model("camera2", "model2", mock_model)
stats = manager.get_stats()
assert stats["total_models"] == 2
assert stats["total_cameras"] == 2
assert "camera1" in stats["cameras"]
assert "camera2" in stats["cameras"]
class TestStreamStateManager:
"""Test stream state management."""
def test_add_stream(self):
"""Test adding a stream."""
manager = StreamStateManager()
manager.clear_all()
config = {"rtsp_url": "rtsp://example.com", "model_id": "test"}
manager.add_stream("camera1", "sub1", config)
stream = manager.get_stream("camera1")
assert stream is not None
assert stream.camera_id == "camera1"
assert stream.subscription_id == "sub1"
assert stream.config == config
def test_subscription_mapping(self):
"""Test subscription to camera mapping."""
manager = StreamStateManager()
manager.clear_all()
config = {"rtsp_url": "rtsp://example.com"}
manager.add_stream("camera1", "sub1", config)
camera_id = manager.get_camera_by_subscription("sub1")
assert camera_id == "camera1"
def test_remove_stream(self):
"""Test removing a stream."""
manager = StreamStateManager()
manager.clear_all()
config = {"rtsp_url": "rtsp://example.com"}
manager.add_stream("camera1", "sub1", config)
removed_stream = manager.remove_stream("camera1")
assert removed_stream is not None
assert removed_stream.camera_id == "camera1"
assert manager.get_stream("camera1") is None
assert manager.get_camera_by_subscription("sub1") is None
def test_shared_stream_management(self):
"""Test shared stream management."""
manager = StreamStateManager()
manager.clear_all()
stream_data = {"reader": Mock(), "reference_count": 1}
manager.add_shared_stream("rtsp://example.com", stream_data)
retrieved_data = manager.get_shared_stream("rtsp://example.com")
assert retrieved_data == stream_data
removed_data = manager.remove_shared_stream("rtsp://example.com")
assert removed_data == stream_data
assert manager.get_shared_stream("rtsp://example.com") is None
class TestSessionStateManager:
"""Test session state management."""
def test_session_id_management(self):
"""Test session ID assignment."""
manager = SessionStateManager()
manager.clear_all()
manager.set_session_id("display1", "session123")
session_id = manager.get_session_id("display1")
assert session_id == "session123"
def test_create_session(self):
"""Test session creation with detection data."""
manager = SessionStateManager()
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", detection_data)
retrieved_data = manager.get_session_detection("session123")
assert retrieved_data == detection_data
camera_id = manager.get_camera_by_session("session123")
assert camera_id == "camera1"
def test_update_session_detection(self):
"""Test updating session detection data."""
manager = SessionStateManager()
manager.clear_all()
initial_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", initial_data)
update_data = {"brand": "Toyota"}
manager.update_session_detection("session123", update_data)
final_data = manager.get_session_detection("session123")
assert final_data["class"] == "car"
assert final_data["brand"] == "Toyota"
def test_session_expiration(self):
"""Test session expiration based on TTL."""
# Use a very short TTL for testing
manager = SessionStateManager(session_ttl=0.1)
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", detection_data)
# Session should exist initially
assert manager.get_session_detection("session123") is not None
# Wait for expiration
time.sleep(0.2)
# Clean up expired sessions
expired_count = manager.cleanup_expired_sessions()
assert expired_count == 1
assert manager.get_session_detection("session123") is None
def test_remove_session(self):
"""Test manual session removal."""
manager = SessionStateManager()
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", detection_data)
result = manager.remove_session("session123")
assert result is True
assert manager.get_session_detection("session123") is None
assert manager.get_camera_by_session("session123") is None
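# Sketch of the TTL-based cleanup exercised in test_session_expiration. The
# "created_at" field name and the time.time() clock are assumptions.
def sketch_cleanup_expired_sessions(sessions, session_ttl):
    """sessions: session_id -> {"created_at": float, ...}; returns number removed."""
    now = time.time()
    expired = [sid for sid, data in sessions.items()
               if now - data["created_at"] > session_ttl]
    for sid in expired:
        del sessions[sid]
    return len(expired)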
class TestCacheStateManager:
"""Test cache state management."""
def test_cache_detection(self):
"""Test caching detection results."""
manager = CacheStateManager()
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85, "bbox": [100, 200, 300, 400]}
manager.cache_detection("camera1", detection_data)
cached_data = manager.get_cached_detection("camera1")
assert cached_data == detection_data
def test_cache_pipeline_result(self):
"""Test caching pipeline results."""
manager = CacheStateManager()
manager.clear_all()
pipeline_result = {"status": "success", "detections": []}
manager.cache_pipeline_result("camera1", pipeline_result)
cached_result = manager.get_cached_pipeline_result("camera1")
assert cached_result == pipeline_result
def test_latest_frame_management(self):
"""Test latest frame storage."""
manager = CacheStateManager()
manager.clear_all()
frame_data = b"fake_frame_data"
manager.set_latest_frame("camera1", frame_data)
retrieved_frame = manager.get_latest_frame("camera1")
assert retrieved_frame == frame_data
def test_frame_skip_flag(self):
"""Test frame skip flag management."""
manager = CacheStateManager()
manager.clear_all()
# Initially should be False
assert manager.get_frame_skip_flag("camera1") is False
manager.set_frame_skip_flag("camera1", True)
assert manager.get_frame_skip_flag("camera1") is True
manager.set_frame_skip_flag("camera1", False)
assert manager.get_frame_skip_flag("camera1") is False
def test_clear_camera_cache(self):
"""Test clearing all cache data for a camera."""
manager = CacheStateManager()
manager.clear_all()
# Set up cache data
detection_data = {"class": "car"}
pipeline_result = {"status": "success"}
frame_data = b"frame"
manager.cache_detection("camera1", detection_data)
manager.cache_pipeline_result("camera1", pipeline_result)
manager.set_latest_frame("camera1", frame_data)
manager.set_frame_skip_flag("camera1", True)
# Clear cache
manager.clear_camera_cache("camera1")
# All data should be gone
assert manager.get_cached_detection("camera1") is None
assert manager.get_cached_pipeline_result("camera1") is None
assert manager.get_latest_frame("camera1") is None
assert manager.get_frame_skip_flag("camera1") is False
class TestCameraStateManager:
"""Test camera state management."""
def test_camera_connection_state(self):
"""Test camera connection state management."""
manager = CameraStateManager()
manager.clear_all()
# Initially connected (default)
assert manager.is_camera_connected("camera1") is True
# Set disconnected
manager.set_camera_connected("camera1", False)
assert manager.is_camera_connected("camera1") is False
# Set connected again
manager.set_camera_connected("camera1", True)
assert manager.is_camera_connected("camera1") is True
def test_notification_flags(self):
"""Test disconnection/reconnection notification flags."""
manager = CameraStateManager()
manager.clear_all()
# Set disconnected
manager.set_camera_connected("camera1", False)
# Should notify disconnection once
assert manager.should_notify_disconnection("camera1") is True
manager.mark_disconnection_notified("camera1")
assert manager.should_notify_disconnection("camera1") is False
# Reconnect
manager.set_camera_connected("camera1", True)
# Should notify reconnection
assert manager.should_notify_reconnection("camera1") is True
manager.mark_reconnection_notified("camera1")
assert manager.should_notify_reconnection("camera1") is False
def test_get_camera_state(self):
"""Test getting full camera state."""
manager = CameraStateManager()
manager.clear_all()
manager.set_camera_connected("camera1", False)
state = manager.get_camera_state("camera1")
assert state["connected"] is False
assert "last_update" in state
assert "disconnection_notified" in state
assert "reconnection_notified" in state
def test_get_stats(self):
"""Test getting camera state statistics."""
manager = CameraStateManager()
manager.clear_all()
manager.set_camera_connected("camera1", True)
manager.set_camera_connected("camera2", False)
stats = manager.get_stats()
assert stats["total_cameras"] == 2
assert stats["connected_cameras"] == 1
assert stats["disconnected_cameras"] == 1
class TestPipelineStateManager:
"""Test pipeline state management."""
def test_get_or_init_state(self):
"""Test getting or initializing pipeline state."""
manager = PipelineStateManager()
manager.clear_all()
state = manager.get_or_init_state("camera1")
assert state["mode"] == "validation_detecting"
assert state["backend_session_id"] is None
assert state["yolo_inference_enabled"] is True
assert "created_at" in state
def test_update_mode(self):
"""Test updating pipeline mode."""
manager = PipelineStateManager()
manager.clear_all()
manager.update_mode("camera1", "classification", "session123")
state = manager.get_state("camera1")
assert state["mode"] == "classification"
assert state["backend_session_id"] == "session123"
def test_set_yolo_inference_enabled(self):
"""Test setting YOLO inference state."""
manager = PipelineStateManager()
manager.clear_all()
manager.set_yolo_inference_enabled("camera1", False)
state = manager.get_state("camera1")
assert state["yolo_inference_enabled"] is False
def test_set_progression_stage(self):
"""Test setting progression stage."""
manager = PipelineStateManager()
manager.clear_all()
manager.set_progression_stage("camera1", "brand_classification")
state = manager.get_state("camera1")
assert state["progression_stage"] == "brand_classification"
def test_set_validated_detection(self):
"""Test setting validated detection."""
manager = PipelineStateManager()
manager.clear_all()
detection = {"class": "car", "confidence": 0.85}
manager.set_validated_detection("camera1", detection)
state = manager.get_state("camera1")
assert state["validated_detection"] == detection
def test_get_stats(self):
"""Test getting pipeline state statistics."""
manager = PipelineStateManager()
manager.clear_all()
manager.update_mode("camera1", "validation_detecting")
manager.update_mode("camera2", "classification")
manager.update_mode("camera3", "classification")
stats = manager.get_stats()
assert stats["total_pipeline_states"] == 3
assert stats["mode_breakdown"]["validation_detecting"] == 1
assert stats["mode_breakdown"]["classification"] == 2
class TestThreadSafety:
"""Test thread safety of singleton managers."""
def test_model_manager_thread_safety(self):
"""Test ModelStateManager thread safety."""
manager = ModelStateManager()
manager.clear_all()
results = {}
def load_models(thread_id):
for i in range(10):
model = Mock()
model.thread_id = thread_id
model.model_id = i
manager.load_model(f"camera{thread_id}", f"model{i}", model)
# Verify models
models = manager.get_camera_models(f"camera{thread_id}")
results[thread_id] = len(models)
threads = []
for i in range(5):
thread = threading.Thread(target=load_models, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Each thread should have loaded 10 models
for thread_id, model_count in results.items():
assert model_count == 10
# Total should be 50 models
stats = manager.get_stats()
assert stats["total_models"] == 50

View file

@ -0,0 +1,479 @@
"""
Unit tests for detection result data structures.
"""
import pytest
from dataclasses import asdict
import numpy as np
from detector_worker.detection.detection_result import (
BoundingBox,
DetectionResult,
LightweightDetectionResult,
DetectionSession,
TrackValidationResult
)
class TestBoundingBox:
"""Test BoundingBox data structure."""
def test_creation_from_coordinates(self):
"""Test creating bounding box from coordinates."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
assert bbox.x1 == 100
assert bbox.y1 == 200
assert bbox.x2 == 300
assert bbox.y2 == 400
def test_creation_from_list(self):
"""Test creating bounding box from list."""
coords = [100, 200, 300, 400]
bbox = BoundingBox.from_list(coords)
assert bbox.x1 == 100
assert bbox.y1 == 200
assert bbox.x2 == 300
assert bbox.y2 == 400
def test_creation_from_invalid_list(self):
"""Test error handling for invalid list."""
with pytest.raises(ValueError):
BoundingBox.from_list([100, 200, 300]) # Too few elements
def test_to_list(self):
"""Test converting bounding box to list."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
coords = bbox.to_list()
assert coords == [100, 200, 300, 400]
def test_area_calculation(self):
"""Test area calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
area = bbox.area()
expected_area = (300 - 100) * (400 - 200) # 200 * 200 = 40000
assert area == expected_area
def test_area_zero_for_invalid_bbox(self):
"""Test area is zero for invalid bounding box."""
# x2 <= x1
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
assert bbox.area() == 0
# y2 <= y1
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
assert bbox.area() == 0
def test_width_height(self):
"""Test width and height properties."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
assert bbox.width() == 200
assert bbox.height() == 200
def test_center_point(self):
"""Test center point calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
center = bbox.center()
assert center == (200, 300) # (x1+x2)/2, (y1+y2)/2
def test_is_valid(self):
"""Test bounding box validation."""
# Valid bbox
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
assert bbox.is_valid() is True
# Invalid bbox (x2 <= x1)
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
assert bbox.is_valid() is False
# Invalid bbox (y2 <= y1)
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
assert bbox.is_valid() is False
def test_intersection(self):
"""Test bounding box intersection."""
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
intersection = bbox1.intersection(bbox2)
assert intersection.x1 == 200
assert intersection.y1 == 200
assert intersection.x2 == 300
assert intersection.y2 == 300
def test_no_intersection(self):
"""Test no intersection between bounding boxes."""
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
intersection = bbox1.intersection(bbox2)
assert intersection.is_valid() is False
def test_union(self):
"""Test bounding box union."""
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
union = bbox1.union(bbox2)
assert union.x1 == 100
assert union.y1 == 100
assert union.x2 == 400
assert union.y2 == 400
def test_iou_calculation(self):
"""Test IoU (Intersection over Union) calculation."""
# Perfect overlap
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
assert bbox1.iou(bbox2) == 1.0
# No overlap
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
assert bbox1.iou(bbox2) == 0.0
# Partial overlap
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
# Intersection area: 100x100 = 10000
# Union area: 40000 + 40000 - 10000 = 70000
# IoU = 10000/70000 = 1/7
expected_iou = 1.0 / 7.0
assert abs(bbox1.iou(bbox2) - expected_iou) < 1e-6
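# Worked sketch of the IoU computation relied on above, built only from the
# BoundingBox helpers exercised in this class (not the actual implementation).
def sketch_iou(a, b):
    inter_area = a.intersection(b).area()  # non-overlapping boxes give zero area
    union_area = a.area() + b.area() - inter_area
    return inter_area / union_area if union_area > 0 else 0.0

# Partial-overlap case above: inter = 100 * 100 = 10000,
# union = 40000 + 40000 - 10000 = 70000, so IoU = 10000 / 70000 = 1/7.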
class TestDetectionResult:
"""Test DetectionResult data structure."""
def test_creation_with_required_fields(self):
"""Test creating detection result with required fields."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
assert detection.class_name == "car"
assert detection.confidence == 0.85
assert detection.bbox == bbox
assert detection.track_id == 12345
def test_creation_with_all_fields(self):
"""Test creating detection result with all fields."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345,
model_id="yolo_v8",
timestamp=1640995200000,
branch_results={"brand": "Toyota"}
)
assert detection.model_id == "yolo_v8"
assert detection.timestamp == 1640995200000
assert detection.branch_results == {"brand": "Toyota"}
def test_creation_from_dict(self):
"""Test creating detection result from dictionary."""
data = {
"class": "car",
"confidence": 0.85,
"bbox": [100, 200, 300, 400],
"id": 12345,
"model_id": "yolo_v8",
"timestamp": 1640995200000
}
detection = DetectionResult.from_dict(data)
assert detection.class_name == "car"
assert detection.confidence == 0.85
assert detection.bbox.to_list() == [100, 200, 300, 400]
assert detection.track_id == 12345
def test_to_dict(self):
"""Test converting detection result to dictionary."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
data = detection.to_dict()
assert data["class"] == "car"
assert data["confidence"] == 0.85
assert data["bbox"] == [100, 200, 300, 400]
assert data["id"] == 12345
def test_is_valid_detection(self):
"""Test detection validation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
# Valid detection
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
assert detection.is_valid() is True
# Invalid confidence (too low)
detection = DetectionResult(
class_name="car",
confidence=-0.1,
bbox=bbox,
track_id=12345
)
assert detection.is_valid() is False
# Invalid confidence (too high)
detection = DetectionResult(
class_name="car",
confidence=1.5,
bbox=bbox,
track_id=12345
)
assert detection.is_valid() is False
# Invalid bounding box
invalid_bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=invalid_bbox,
track_id=12345
)
assert detection.is_valid() is False
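# Sketch of the dict <-> dataclass mapping the serialization tests above assume:
# wire keys "class" and "id" map onto class_name and track_id. Reuses the
# DetectionResult/BoundingBox types imported at the top of this module.
def sketch_detection_from_dict(data):
    kwargs = dict(
        class_name=data["class"],
        confidence=data["confidence"],
        bbox=BoundingBox.from_list(data["bbox"]),
        track_id=data["id"],
    )
    # optional fields keep their dataclass defaults when absent
    for optional in ("model_id", "timestamp", "branch_results"):
        if optional in data:
            kwargs[optional] = data[optional]
    return DetectionResult(**kwargs)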
class TestLightweightDetectionResult:
"""Test LightweightDetectionResult data structure."""
def test_creation(self):
"""Test creating lightweight detection result."""
detection = LightweightDetectionResult(
class_name="car",
confidence=0.85,
bbox_area=40000,
frame_width=1920,
frame_height=1080
)
assert detection.class_name == "car"
assert detection.confidence == 0.85
assert detection.bbox_area == 40000
assert detection.frame_width == 1920
assert detection.frame_height == 1080
def test_area_ratio_calculation(self):
"""Test bounding box area ratio calculation."""
detection = LightweightDetectionResult(
class_name="car",
confidence=0.85,
bbox_area=40000,
frame_width=1920,
frame_height=1080
)
expected_ratio = 40000 / (1920 * 1080)
assert abs(detection.area_ratio() - expected_ratio) < 1e-6
def test_meets_threshold(self):
"""Test threshold checking."""
detection = LightweightDetectionResult(
class_name="car",
confidence=0.85,
bbox_area=40000,
frame_width=1920,
frame_height=1080
)
assert detection.meets_threshold(confidence=0.8, area_ratio=0.01) is True
assert detection.meets_threshold(confidence=0.9, area_ratio=0.01) is False
assert detection.meets_threshold(confidence=0.8, area_ratio=0.1) is False
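# Sketch of the threshold check used above (assumed semantics): a detection passes
# when both its confidence and its bbox-to-frame area ratio clear the given minima.
def sketch_meets_threshold(det, confidence, area_ratio):
    frame_area = det.frame_width * det.frame_height
    return det.confidence >= confidence and (det.bbox_area / frame_area) >= area_ratio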
class TestDetectionSession:
"""Test DetectionSession data structure."""
def test_creation(self):
"""Test creating detection session."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
assert session.session_id == "session_123"
assert session.camera_id == "camera_001"
assert session.display_id == "display_001"
assert session.detections == []
assert session.metadata == {}
def test_add_detection(self):
"""Test adding detection to session."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
session.add_detection(detection)
assert len(session.detections) == 1
assert session.detections[0] == detection
def test_get_latest_detection(self):
"""Test getting latest detection."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
# Add multiple detections
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=12345,
timestamp=1640995200000
)
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
detection2 = DetectionResult(
class_name="car",
confidence=0.90,
bbox=bbox2,
track_id=12345,
timestamp=1640995300000
)
session.add_detection(detection1)
session.add_detection(detection2)
latest = session.get_latest_detection()
assert latest == detection2 # Should be the one with later timestamp
def test_get_detections_by_class(self):
"""Test filtering detections by class."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
car_detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
truck_detection = DetectionResult(
class_name="truck",
confidence=0.80,
bbox=bbox,
track_id=54321
)
session.add_detection(car_detection)
session.add_detection(truck_detection)
car_detections = session.get_detections_by_class("car")
assert len(car_detections) == 1
assert car_detections[0] == car_detection
truck_detections = session.get_detections_by_class("truck")
assert len(truck_detections) == 1
assert truck_detections[0] == truck_detection
class TestTrackValidationResult:
"""Test TrackValidationResult data structure."""
def test_creation(self):
"""Test creating track validation result."""
result = TrackValidationResult(
stable_tracks=[101, 102, 103],
current_tracks=[101, 102, 104, 105],
newly_stable=[103],
lost_tracks=[106]
)
assert result.stable_tracks == [101, 102, 103]
assert result.current_tracks == [101, 102, 104, 105]
assert result.newly_stable == [103]
assert result.lost_tracks == [106]
def test_has_stable_tracks(self):
"""Test checking for stable tracks."""
result = TrackValidationResult(
stable_tracks=[101, 102],
current_tracks=[101, 102, 103]
)
assert result.has_stable_tracks() is True
result_empty = TrackValidationResult(
stable_tracks=[],
current_tracks=[101, 102, 103]
)
assert result_empty.has_stable_tracks() is False
def test_get_stats(self):
"""Test getting validation statistics."""
result = TrackValidationResult(
stable_tracks=[101, 102, 103],
current_tracks=[101, 102, 104, 105],
newly_stable=[103],
lost_tracks=[106]
)
stats = result.get_stats()
assert stats["stable_count"] == 3
assert stats["current_count"] == 4
assert stats["newly_stable_count"] == 1
assert stats["lost_count"] == 1
assert stats["stability_ratio"] == 3/4 # stable/current
def test_is_track_stable(self):
"""Test checking if specific track is stable."""
result = TrackValidationResult(
stable_tracks=[101, 102, 103],
current_tracks=[101, 102, 104, 105]
)
assert result.is_track_stable(101) is True
assert result.is_track_stable(102) is True
assert result.is_track_stable(104) is False
assert result.is_track_stable(999) is False

View file

@ -0,0 +1,701 @@
"""
Unit tests for track stability validation.
"""
import pytest
import time
from unittest.mock import Mock, patch
from collections import defaultdict
from detector_worker.detection.stability_validator import (
StabilityValidator,
StabilityConfig,
ValidationResult,
TrackStabilityMetrics
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox, TrackValidationResult
from detector_worker.core.exceptions import ValidationError
class TestStabilityConfig:
"""Test stability configuration data structure."""
def test_default_config(self):
"""Test default stability configuration."""
config = StabilityConfig()
assert config.min_detection_frames == 10
assert config.max_absence_frames == 30
assert config.confidence_threshold == 0.5
assert config.stability_window == 60.0
assert config.iou_threshold == 0.3
assert config.movement_threshold == 50.0
def test_custom_config(self):
"""Test custom stability configuration."""
config = StabilityConfig(
min_detection_frames=5,
max_absence_frames=15,
confidence_threshold=0.8,
stability_window=30.0,
iou_threshold=0.5,
movement_threshold=25.0
)
assert config.min_detection_frames == 5
assert config.max_absence_frames == 15
assert config.confidence_threshold == 0.8
assert config.stability_window == 30.0
assert config.iou_threshold == 0.5
assert config.movement_threshold == 25.0
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"min_detection_frames": 8,
"max_absence_frames": 25,
"confidence_threshold": 0.75,
"unknown_field": "ignored"
}
config = StabilityConfig.from_dict(config_dict)
assert config.min_detection_frames == 8
assert config.max_absence_frames == 25
assert config.confidence_threshold == 0.75
# Unknown fields should use defaults
assert config.stability_window == 60.0
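# Sketch of a from_dict that tolerates unknown keys, as test_from_dict expects.
# Assumes StabilityConfig is a dataclass; the real implementation may differ.
from dataclasses import fields

def sketch_config_from_dict(cls, data):
    known = {f.name for f in fields(cls)}  # drop keys that are not dataclass fields
    return cls(**{k: v for k, v in data.items() if k in known})

# sketch_config_from_dict(StabilityConfig, {"min_detection_frames": 8, "unknown_field": 1})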
class TestTrackStabilityMetrics:
"""Test track stability metrics."""
def test_initialization(self):
"""Test metrics initialization."""
metrics = TrackStabilityMetrics(track_id=1001)
assert metrics.track_id == 1001
assert metrics.detection_count == 0
assert metrics.absence_count == 0
assert metrics.total_confidence == 0.0
assert metrics.first_detection_time is None
assert metrics.last_detection_time is None
assert metrics.bounding_boxes == []
assert metrics.confidence_scores == []
def test_add_detection(self):
"""Test adding detection to metrics."""
metrics = TrackStabilityMetrics(track_id=1001)
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
metrics.add_detection(detection, current_time=1640995200.0)
assert metrics.detection_count == 1
assert metrics.absence_count == 0
assert metrics.total_confidence == 0.85
assert metrics.first_detection_time == 1640995200.0
assert metrics.last_detection_time == 1640995200.0
assert len(metrics.bounding_boxes) == 1
assert len(metrics.confidence_scores) == 1
def test_increment_absence(self):
"""Test incrementing absence count."""
metrics = TrackStabilityMetrics(track_id=1001)
metrics.increment_absence()
assert metrics.absence_count == 1
metrics.increment_absence()
assert metrics.absence_count == 2
def test_reset_absence(self):
"""Test resetting absence count."""
metrics = TrackStabilityMetrics(track_id=1001)
metrics.increment_absence()
metrics.increment_absence()
assert metrics.absence_count == 2
metrics.reset_absence()
assert metrics.absence_count == 0
def test_average_confidence(self):
"""Test average confidence calculation."""
metrics = TrackStabilityMetrics(track_id=1001)
# No detections
assert metrics.average_confidence() == 0.0
# Add detections
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
detection2 = DetectionResult(
class_name="car",
confidence=0.9,
bbox=bbox,
track_id=1001,
timestamp=1640995300000
)
metrics.add_detection(detection1, current_time=1640995200.0)
metrics.add_detection(detection2, current_time=1640995300.0)
assert metrics.average_confidence() == 0.85 # (0.8 + 0.9) / 2
def test_tracking_duration(self):
"""Test tracking duration calculation."""
metrics = TrackStabilityMetrics(track_id=1001)
# No detections
assert metrics.tracking_duration() == 0.0
# Add detections
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
detection2 = DetectionResult(
class_name="car",
confidence=0.9,
bbox=bbox,
track_id=1001,
timestamp=1640995300000
)
metrics.add_detection(detection1, current_time=1640995200.0)
metrics.add_detection(detection2, current_time=1640995300.0)
assert metrics.tracking_duration() == 100.0 # 1640995300 - 1640995200
def test_movement_distance(self):
"""Test movement distance calculation."""
metrics = TrackStabilityMetrics(track_id=1001)
# No movement with single detection
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
metrics.add_detection(detection1, current_time=1640995200.0)
assert metrics.total_movement_distance() == 0.0
# Add second detection with movement
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
detection2 = DetectionResult(
class_name="car",
confidence=0.9,
bbox=bbox2,
track_id=1001,
timestamp=1640995300000
)
metrics.add_detection(detection2, current_time=1640995300.0)
# Distance between centers: (200,300) to (210,310) = sqrt(100+100) ≈ 14.14
movement = metrics.total_movement_distance()
assert movement == pytest.approx(14.14, rel=1e-2)
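# Sketch of the movement metric checked above (assumed definition): the summed
# Euclidean distance between consecutive bounding-box centers.
import math

def sketch_total_movement(bounding_boxes):
    centers = [bbox.center() for bbox in bounding_boxes]
    return sum(math.dist(a, b) for a, b in zip(centers, centers[1:]))

# Centers (200, 300) -> (210, 310): sqrt(10**2 + 10**2) ≈ 14.14, matching the assert.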
class TestValidationResult:
"""Test validation result data structure."""
def test_initialization(self):
"""Test validation result initialization."""
result = ValidationResult(
track_id=1001,
is_stable=True,
detection_count=15,
absence_count=2,
average_confidence=0.85,
tracking_duration=120.0
)
assert result.track_id == 1001
assert result.is_stable is True
assert result.detection_count == 15
assert result.absence_count == 2
assert result.average_confidence == 0.85
assert result.tracking_duration == 120.0
assert result.reasons == []
def test_with_reasons(self):
"""Test validation result with failure reasons."""
result = ValidationResult(
track_id=1001,
is_stable=False,
detection_count=5,
absence_count=35,
average_confidence=0.4,
tracking_duration=30.0,
reasons=["Insufficient detection frames", "Too many absences", "Low confidence"]
)
assert result.is_stable is False
assert len(result.reasons) == 3
assert "Insufficient detection frames" in result.reasons
class TestStabilityValidator:
"""Test stability validation functionality."""
def test_initialization_default(self):
"""Test validator initialization with default config."""
validator = StabilityValidator()
assert isinstance(validator.config, StabilityConfig)
assert validator.config.min_detection_frames == 10
assert len(validator.track_metrics) == 0
def test_initialization_custom_config(self):
"""Test validator initialization with custom config."""
config = StabilityConfig(min_detection_frames=5, confidence_threshold=0.8)
validator = StabilityValidator(config)
assert validator.config.min_detection_frames == 5
assert validator.config.confidence_threshold == 0.8
def test_update_detections_new_track(self):
"""Test updating with new track."""
validator = StabilityValidator()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
validator.update_detections([detection], current_time=1640995200.0)
assert 1001 in validator.track_metrics
metrics = validator.track_metrics[1001]
assert metrics.detection_count == 1
assert metrics.absence_count == 0
def test_update_detections_existing_track(self):
"""Test updating existing track."""
validator = StabilityValidator()
# First detection
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
validator.update_detections([detection1], current_time=1640995200.0)
# Second detection
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
detection2 = DetectionResult(
class_name="car",
confidence=0.9,
bbox=bbox2,
track_id=1001,
timestamp=1640995300000
)
validator.update_detections([detection2], current_time=1640995300.0)
metrics = validator.track_metrics[1001]
assert metrics.detection_count == 2
assert metrics.absence_count == 0
assert metrics.average_confidence() == 0.85
def test_update_detections_missing_track(self):
"""Test updating when track is missing (increment absence)."""
validator = StabilityValidator()
# Add track
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
validator.update_detections([detection], current_time=1640995200.0)
# Update with empty detections
validator.update_detections([], current_time=1640995300.0)
metrics = validator.track_metrics[1001]
assert metrics.detection_count == 1
assert metrics.absence_count == 1
def test_validate_track_stable(self):
"""Test validating a stable track."""
config = StabilityConfig(min_detection_frames=3, max_absence_frames=5)
validator = StabilityValidator(config)
# Create track with sufficient detections
track_id = 1001
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Add sufficient detections
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(5):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
result = validator.validate_track(track_id)
assert result.is_stable is True
assert result.detection_count == 5
assert result.absence_count == 0
assert len(result.reasons) == 0
def test_validate_track_insufficient_detections(self):
"""Test validating track with insufficient detections."""
config = StabilityConfig(min_detection_frames=10, max_absence_frames=5)
validator = StabilityValidator(config)
# Create track with insufficient detections
track_id = 1001
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Add only few detections
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(3):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
result = validator.validate_track(track_id)
assert result.is_stable is False
assert "Insufficient detection frames" in result.reasons
def test_validate_track_too_many_absences(self):
"""Test validating track with too many absences."""
config = StabilityConfig(min_detection_frames=3, max_absence_frames=2)
validator = StabilityValidator(config)
# Create track with too many absences
track_id = 1001
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Add detections and absences
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(5):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
# Add too many absences
for _ in range(5):
metrics.increment_absence()
result = validator.validate_track(track_id)
assert result.is_stable is False
assert "Too many absence frames" in result.reasons
def test_validate_track_low_confidence(self):
"""Test validating track with low confidence."""
config = StabilityConfig(
min_detection_frames=3,
max_absence_frames=5,
confidence_threshold=0.8
)
validator = StabilityValidator(config)
# Create track with low confidence
track_id = 1001
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Add detections with low confidence
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(5):
detection = DetectionResult(
class_name="car",
confidence=0.5, # Below threshold
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
result = validator.validate_track(track_id)
assert result.is_stable is False
assert "Low average confidence" in result.reasons
def test_validate_all_tracks(self):
"""Test validating all tracks."""
config = StabilityConfig(min_detection_frames=3)
validator = StabilityValidator(config)
# Add multiple tracks
for track_id in [1001, 1002, 1003]:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Make some tracks stable, others not
detection_count = 5 if track_id == 1001 else 2
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(detection_count):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
results = validator.validate_all_tracks()
assert len(results) == 3
assert results[1001].is_stable is True # 5 detections
assert results[1002].is_stable is False # 2 detections
assert results[1003].is_stable is False # 2 detections
def test_get_stable_tracks(self):
"""Test getting stable track IDs."""
config = StabilityConfig(min_detection_frames=3)
validator = StabilityValidator(config)
# Add tracks with different stability
for track_id, detection_count in [(1001, 5), (1002, 2), (1003, 4)]:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(detection_count):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
stable_tracks = validator.get_stable_tracks()
assert stable_tracks == [1001, 1003] # 5 and 4 detections respectively
def test_cleanup_expired_tracks(self):
"""Test cleanup of expired tracks."""
config = StabilityConfig(stability_window=10.0)
validator = StabilityValidator(config)
# Add tracks with different last detection times
current_time = 1640995300.0
for track_id, last_detection_time in [(1001, current_time - 5), (1002, current_time - 15)]:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=int(last_detection_time * 1000)
)
metrics.add_detection(detection, current_time=last_detection_time)
removed_count = validator.cleanup_expired_tracks(current_time)
assert removed_count == 1 # 1002 should be removed (15 > 10 seconds)
assert 1001 in validator.track_metrics
assert 1002 not in validator.track_metrics
def test_clear_all_tracks(self):
"""Test clearing all track metrics."""
validator = StabilityValidator()
# Add some tracks
for track_id in [1001, 1002]:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
assert len(validator.track_metrics) == 2
validator.clear_all_tracks()
assert len(validator.track_metrics) == 0
def test_get_validation_summary(self):
"""Test getting validation summary statistics."""
config = StabilityConfig(min_detection_frames=3)
validator = StabilityValidator(config)
# Add tracks with different characteristics
track_data = [
(1001, 5, True), # Stable
(1002, 2, False), # Unstable
(1003, 4, True), # Stable
(1004, 1, False) # Unstable
]
for track_id, detection_count, _ in track_data:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(detection_count):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
summary = validator.get_validation_summary()
assert summary["total_tracks"] == 4
assert summary["stable_tracks"] == 2
assert summary["unstable_tracks"] == 2
assert summary["stability_rate"] == 0.5
class TestStabilityValidatorIntegration:
"""Integration tests for stability validator."""
def test_full_tracking_lifecycle(self):
"""Test complete tracking lifecycle with stability validation."""
config = StabilityConfig(
min_detection_frames=3,
max_absence_frames=2,
confidence_threshold=0.7
)
validator = StabilityValidator(config)
track_id = 1001
# Phase 1: Initial detections (building up)
for i in range(5):
bbox = BoundingBox(x1=100+i*2, y1=200+i*2, x2=300+i*2, y2=400+i*2)
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
validator.update_detections([detection], current_time=1640995200.0 + i)
# Should be stable now
result = validator.validate_track(track_id)
assert result.is_stable is True
# Phase 2: Some absences
for i in range(2):
validator.update_detections([], current_time=1640995205.0 + i)
# Still stable (within absence threshold)
result = validator.validate_track(track_id)
assert result.is_stable is True
# Phase 3: Track reappears
bbox = BoundingBox(x1=120, y1=220, x2=320, y2=420)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=track_id,
timestamp=1640995207000
)
validator.update_detections([detection], current_time=1640995207.0)
# Should reset absence count and remain stable
result = validator.validate_track(track_id)
assert result.is_stable is True
assert validator.track_metrics[track_id].absence_count == 0
def test_multi_track_validation(self):
"""Test validation with multiple tracks."""
validator = StabilityValidator()
# Simulate multi-track scenario
# Timestamps are passed by keyword so they bind to the timestamp field rather
# than the positional model_id parameter.
frame_detections = [
# Frame 1
[
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, timestamp=1640995200000),
DetectionResult("truck", 0.8, BoundingBox(400, 200, 600, 400), 1002, timestamp=1640995200000)
],
# Frame 2
[
DetectionResult("car", 0.85, BoundingBox(105, 205, 305, 405), 1001, timestamp=1640995201000),
DetectionResult("truck", 0.82, BoundingBox(405, 205, 605, 405), 1002, timestamp=1640995201000),
DetectionResult("car", 0.75, BoundingBox(200, 300, 400, 500), 1003, timestamp=1640995201000)
],
# Frame 3 - track 1002 disappears
[
DetectionResult("car", 0.88, BoundingBox(110, 210, 310, 410), 1001, timestamp=1640995202000),
DetectionResult("car", 0.78, BoundingBox(205, 305, 405, 505), 1003, timestamp=1640995202000)
]
]
# Process frames
for i, detections in enumerate(frame_detections):
validator.update_detections(detections, current_time=1640995200.0 + i)
# Get validation results
validation_results = validator.validate_all_tracks()
assert len(validation_results) == 3
# All tracks should be unstable (insufficient frames)
for result in validation_results.values():
assert result.is_stable is False
assert "Insufficient detection frames" in result.reasons

View file

@ -0,0 +1,606 @@
"""
Unit tests for BoT-SORT tracking management.
"""
import pytest
import numpy as np
from unittest.mock import Mock, MagicMock, patch
from collections import defaultdict
from detector_worker.detection.tracking_manager import TrackingManager, TrackInfo
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import TrackingError
class TestTrackInfo:
"""Test TrackInfo data structure."""
def test_creation(self):
"""Test TrackInfo creation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
assert track.track_id == 1001
assert track.bbox == bbox
assert track.confidence == 0.85
assert track.class_name == "car"
assert track.first_seen == 1640995200.0
assert track.last_seen == 1640995300.0
assert track.frame_count == 1
assert track.absence_count == 0
def test_update_track(self):
"""Test updating track information."""
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox1,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995200.0
)
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
track.update(bbox2, 0.90, 1640995300.0)
assert track.bbox == bbox2
assert track.confidence == 0.90
assert track.last_seen == 1640995300.0
assert track.frame_count == 2
assert track.absence_count == 0
def test_increment_absence(self):
"""Test incrementing absence count."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995200.0
)
track.increment_absence()
assert track.absence_count == 1
track.increment_absence()
assert track.absence_count == 2
def test_age_calculation(self):
"""Test track age calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
age = track.age(current_time=1640995400.0)
assert age == 200.0 # 1640995400 - 1640995200
def test_time_since_last_seen(self):
"""Test time since last seen calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
time_since = track.time_since_last_seen(current_time=1640995450.0)
assert time_since == 150.0 # 1640995450 - 1640995300
def test_is_stable(self):
"""Test track stability checking."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
# Not stable initially
assert track.is_stable(min_frames=5, max_absence=3) is False
# Make it stable
track.frame_count = 10
track.absence_count = 1
assert track.is_stable(min_frames=5, max_absence=3) is True
# Too many absences
track.absence_count = 5
assert track.is_stable(min_frames=5, max_absence=3) is False
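# Sketch of a TrackInfo shape consistent with the assertions above; the real class lives in
# detector_worker.detection.tracking_manager and may differ in detail.
from dataclasses import dataclass

@dataclass
class _TrackInfoSketch:
    track_id: int
    bbox: BoundingBox
    confidence: float
    class_name: str
    first_seen: float
    last_seen: float
    frame_count: int = 1
    absence_count: int = 0

    def update(self, bbox, confidence, current_time):
        self.bbox, self.confidence, self.last_seen = bbox, confidence, current_time
        self.frame_count += 1
        self.absence_count = 0

    def increment_absence(self):
        self.absence_count += 1

    def age(self, current_time):
        return current_time - self.first_seen

    def time_since_last_seen(self, current_time):
        return current_time - self.last_seen

    def is_stable(self, min_frames, max_absence):
        return self.frame_count >= min_frames and self.absence_count <= max_absence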
class TestTrackingManager:
"""Test tracking management functionality."""
def test_initialization(self):
"""Test tracking manager initialization."""
manager = TrackingManager()
assert manager.max_absence_frames == 30
assert manager.min_stable_frames == 10
assert manager.track_timeout == 60.0
assert len(manager.active_tracks) == 0
assert len(manager.stable_tracks) == 0
def test_initialization_with_config(self):
"""Test initialization with custom configuration."""
config = {
"max_absence_frames": 20,
"min_stable_frames": 5,
"track_timeout": 30.0
}
manager = TrackingManager(config)
assert manager.max_absence_frames == 20
assert manager.min_stable_frames == 5
assert manager.track_timeout == 30.0
def test_update_tracks_new_detections(self):
"""Test updating with new detections."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
assert len(manager.active_tracks) == 1
assert 1001 in manager.active_tracks
track = manager.active_tracks[1001]
assert track.track_id == 1001
assert track.class_name == "car"
assert track.confidence == 0.85
assert track.frame_count == 1
def test_update_tracks_existing_detection(self):
"""Test updating existing track."""
manager = TrackingManager()
# First detection
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection1], current_time=1640995200.0)
# Second detection (same track, different position)
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
detection2 = DetectionResult(
class_name="car",
confidence=0.90,
bbox=bbox2,
track_id=1001,
timestamp=1640995300000
)
manager.update_tracks([detection2], current_time=1640995300.0)
assert len(manager.active_tracks) == 1
track = manager.active_tracks[1001]
assert track.frame_count == 2
assert track.confidence == 0.90
assert track.bbox == bbox2
assert track.absence_count == 0
def test_update_tracks_no_detections(self):
"""Test updating with no detections (increment absence)."""
manager = TrackingManager()
# Add initial track
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
# Update with no detections
manager.update_tracks([], current_time=1640995300.0)
track = manager.active_tracks[1001]
assert track.absence_count == 1
def test_cleanup_expired_tracks(self):
"""Test cleanup of expired tracks."""
manager = TrackingManager({"track_timeout": 10.0})
# Add track
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
assert len(manager.active_tracks) == 1
# Cleanup after timeout
removed_count = manager.cleanup_expired_tracks(current_time=1640995220.0) # 20 seconds later
assert removed_count == 1
assert len(manager.active_tracks) == 0
def test_cleanup_absent_tracks(self):
"""Test cleanup of tracks with too many absences."""
manager = TrackingManager({"max_absence_frames": 3})
# Add track
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
# Increment absence count beyond threshold
for i in range(5):
manager.update_tracks([], current_time=1640995200.0 + i)
track = manager.active_tracks[1001]
assert track.absence_count == 5
# Cleanup absent tracks
removed_count = manager.cleanup_absent_tracks()
assert removed_count == 1
assert len(manager.active_tracks) == 0
def test_get_stable_tracks(self):
"""Test getting stable tracks."""
manager = TrackingManager({"min_stable_frames": 3})
# Add track and make it stable
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track_info = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
track_info.frame_count = 5 # Make it stable
manager.active_tracks[1001] = track_info
stable_tracks = manager.get_stable_tracks()
assert len(stable_tracks) == 1
assert 1001 in stable_tracks
assert 1001 in manager.stable_tracks # Should be cached
def test_get_track_by_id(self):
"""Test getting track by ID."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
track = manager.get_track_by_id(1001)
assert track is not None
assert track.track_id == 1001
non_existent = manager.get_track_by_id(9999)
assert non_existent is None
def test_get_tracks_by_class(self):
"""Test getting tracks by class name."""
manager = TrackingManager()
# Add different classes
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
detection2 = DetectionResult(
class_name="truck",
confidence=0.80,
bbox=bbox2,
track_id=1002,
timestamp=1640995200000
)
bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500)
detection3 = DetectionResult(
class_name="car",
confidence=0.90,
bbox=bbox3,
track_id=1003,
timestamp=1640995200000
)
manager.update_tracks([detection1, detection2, detection3], current_time=1640995200.0)
car_tracks = manager.get_tracks_by_class("car")
assert len(car_tracks) == 2
assert 1001 in car_tracks
assert 1003 in car_tracks
truck_tracks = manager.get_tracks_by_class("truck")
assert len(truck_tracks) == 1
assert 1002 in truck_tracks
def test_get_track_count(self):
"""Test getting track counts."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
assert manager.get_active_track_count() == 1
assert manager.get_track_count_by_class("car") == 1
assert manager.get_track_count_by_class("truck") == 0
def test_clear_all_tracks(self):
"""Test clearing all tracks."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
assert len(manager.active_tracks) == 1
manager.clear_all_tracks()
assert len(manager.active_tracks) == 0
assert len(manager.stable_tracks) == 0
def test_get_track_statistics(self):
"""Test getting track statistics."""
manager = TrackingManager({"min_stable_frames": 2})
# Add multiple tracks
detections = []
for i in range(3):
bbox = BoundingBox(x1=100+i*50, y1=200, x2=300+i*50, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001+i,
timestamp=1640995200000
)
detections.append(detection)
manager.update_tracks(detections, current_time=1640995200.0)
# Make some tracks stable
manager.active_tracks[1001].frame_count = 5
manager.active_tracks[1002].frame_count = 3
# 1003 remains unstable with frame_count=1
stats = manager.get_track_statistics()
assert stats["active_tracks"] == 3
assert stats["stable_tracks"] == 2
assert stats["unstable_tracks"] == 1
assert "average_track_age" in stats
assert "average_confidence" in stats
def test_validate_tracks(self):
"""Test track validation."""
manager = TrackingManager({"min_stable_frames": 3, "max_absence_frames": 2})
# Add tracks with different stability
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track1 = TrackInfo(
track_id=1001,
bbox=bbox1,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
track1.frame_count = 5 # Stable
track1.absence_count = 1 # Present
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
track2 = TrackInfo(
track_id=1002,
bbox=bbox2,
confidence=0.80,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995250.0
)
track2.frame_count = 2 # Not stable
track2.absence_count = 1
bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500)
track3 = TrackInfo(
track_id=1003,
bbox=bbox3,
confidence=0.90,
class_name="car",
first_seen=1640995100.0,
last_seen=1640995150.0
)
track3.frame_count = 8 # Was stable but now absent
track3.absence_count = 5 # Too many absences
manager.active_tracks = {1001: track1, 1002: track2, 1003: track3}
manager.stable_tracks = {1001, 1003} # 1003 was previously stable
validation_result = manager.validate_tracks()
assert validation_result.stable_tracks == [1001]
assert validation_result.current_tracks == [1001, 1002, 1003]
assert validation_result.newly_stable == []
assert validation_result.lost_tracks == [1003]
def test_track_persistence_across_frames(self):
"""Test track persistence across multiple frames."""
manager = TrackingManager()
# Frame 1
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection1], current_time=1640995200.0)
# Frame 2 - track moves
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
detection2 = DetectionResult(
class_name="car",
confidence=0.88,
bbox=bbox2,
track_id=1001,
timestamp=1640995300000
)
manager.update_tracks([detection2], current_time=1640995300.0)
# Frame 3 - track disappears
manager.update_tracks([], current_time=1640995400.0)
# Frame 4 - track reappears
bbox4 = BoundingBox(x1=120, y1=220, x2=320, y2=420)
detection4 = DetectionResult(
class_name="car",
confidence=0.82,
bbox=bbox4,
track_id=1001,
timestamp=1640995500000
)
manager.update_tracks([detection4], current_time=1640995500.0)
track = manager.active_tracks[1001]
assert track.frame_count == 3 # Seen in 3 frames
assert track.absence_count == 0 # Reset when reappeared
assert track.bbox == bbox4 # Latest position
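# A rough sketch of the update/absence bookkeeping exercised above. TrackInfo is the class
# imported at the top of this module; the loop structure itself is an assumption.
def _sketch_update_tracks(active_tracks, detections, current_time):
    seen = set()
    for det in detections:
        seen.add(det.track_id)
        track = active_tracks.get(det.track_id)
        if track is None:
            active_tracks[det.track_id] = TrackInfo(
                track_id=det.track_id,
                bbox=det.bbox,
                confidence=det.confidence,
                class_name=det.class_name,
                first_seen=current_time,
                last_seen=current_time,
            )
        else:
            track.update(det.bbox, det.confidence, current_time)
    # Every known track that was not matched this frame accrues one absence.
    for track_id, track in active_tracks.items():
        if track_id not in seen:
            track.increment_absence()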
class TestTrackingManagerErrorHandling:
"""Test error handling in tracking manager."""
def test_invalid_detection_input(self):
"""Test handling of invalid detection input."""
manager = TrackingManager()
# A None entry in the detections list should be rejected with a TrackingError
with pytest.raises(TrackingError):
manager.update_tracks([None], current_time=1640995200.0)
def test_negative_track_id(self):
"""Test handling of negative track ID."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=-1, # Invalid track ID
timestamp=1640995200000
)
with pytest.raises(TrackingError):
manager.update_tracks([detection], current_time=1640995200.0)
def test_duplicate_track_ids_different_classes(self):
"""Test handling of duplicate track IDs with different classes."""
manager = TrackingManager()
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
detection2 = DetectionResult(
class_name="truck", # Different class, same ID
confidence=0.80,
bbox=bbox2,
track_id=1001,
timestamp=1640995200000
)
# Should log warning but handle gracefully
manager.update_tracks([detection1, detection2], current_time=1640995200.0)
# The later detection should update the track
track = manager.active_tracks[1001]
assert track.class_name == "truck" # Last update wins
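# For reference, the shape that validate_tracks() appears to return in the tests above;
# the concrete result type used by detector_worker is an assumption here.
from dataclasses import dataclass, field
from typing import List

@dataclass
class _TrackValidationResultSketch:
    stable_tracks: List[int] = field(default_factory=list)   # IDs that are currently stable
    current_tracks: List[int] = field(default_factory=list)  # all active IDs this frame
    newly_stable: List[int] = field(default_factory=list)    # IDs that became stable in this call
    lost_tracks: List[int] = field(default_factory=list)     # previously stable IDs now lost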

View file

@ -0,0 +1,386 @@
"""
Unit tests for YOLO detector with tracking functionality.
"""
import pytest
import numpy as np
from unittest.mock import Mock, MagicMock, patch
import torch
from detector_worker.detection.yolo_detector import YOLODetector
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import DetectionError
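# The mock_yolo_model fixture used throughout this module is not shown in this excerpt
# (it presumably lives in conftest.py); a minimal stand-in consistent with these tests
# might look like the factory below. Names and defaults are assumptions.
def _sketch_mock_yolo_model():
    model = MagicMock()
    model.device = "cpu"
    model.names = {0: "car", 1: "truck"}
    return model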
class TestYOLODetector:
"""Test YOLO detection and tracking functionality."""
def test_initialization_with_valid_model(self, mock_yolo_model):
"""Test detector initialization with valid model."""
detector = YOLODetector(mock_yolo_model)
assert detector.model is mock_yolo_model
assert detector.class_names == {}
assert detector.is_tracking_enabled is True
def test_initialization_with_class_names(self, mock_yolo_model):
"""Test detector initialization with class names."""
class_names = {0: "car", 1: "truck", 2: "bus"}
detector = YOLODetector(mock_yolo_model, class_names=class_names)
assert detector.class_names == class_names
def test_initialization_tracking_disabled(self, mock_yolo_model):
"""Test detector initialization with tracking disabled."""
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
assert detector.is_tracking_enabled is False
def test_detect_with_tracking(self, mock_yolo_model, mock_frame):
"""Test detection with tracking enabled."""
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0], # x1, y1, x2, y2, conf, class
[150, 250, 350, 450, 0.85, 1]
])
mock_result.boxes.id = torch.tensor([1001, 1002])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame)
assert len(detections) == 2
assert detections[0].confidence == 0.9
assert detections[0].track_id == 1001
assert detections[0].bbox.x1 == 100
mock_yolo_model.track.assert_called_once_with(mock_frame, persist=True, verbose=False)
def test_detect_without_tracking(self, mock_yolo_model, mock_frame):
"""Test detection with tracking disabled."""
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = None # No tracking IDs
mock_yolo_model.predict.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
detections = detector.detect(mock_frame)
assert len(detections) == 1
assert detections[0].track_id is None # No tracking ID
mock_yolo_model.predict.assert_called_once_with(mock_frame, verbose=False)
def test_detect_with_class_names(self, mock_yolo_model, mock_frame):
"""Test detection with class name mapping."""
class_names = {0: "car", 1: "truck"}
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0], # car
[150, 250, 350, 450, 0.85, 1] # truck
])
mock_result.boxes.id = torch.tensor([1001, 1002])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model, class_names=class_names)
detections = detector.detect(mock_frame)
assert detections[0].class_name == "car"
assert detections[1].class_name == "truck"
def test_detect_no_boxes(self, mock_yolo_model, mock_frame):
"""Test detection when no objects are detected."""
mock_result = Mock()
mock_result.boxes = None
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame)
assert detections == []
def test_detect_empty_boxes(self, mock_yolo_model, mock_frame):
"""Test detection with empty boxes tensor."""
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([]).reshape(0, 6)
mock_result.boxes.id = None
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame)
assert detections == []
def test_detect_with_confidence_threshold(self, mock_yolo_model, mock_frame):
"""Test detection with confidence threshold filtering."""
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0], # Above threshold
[150, 250, 350, 450, 0.3, 1] # Below threshold
])
mock_result.boxes.id = torch.tensor([1001, 1002])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame, confidence_threshold=0.5)
assert len(detections) == 1 # Only one above threshold
assert detections[0].confidence == 0.9
def test_detect_model_error_handling(self, mock_yolo_model, mock_frame):
"""Test error handling when model fails."""
mock_yolo_model.track.side_effect = Exception("Model inference failed")
detector = YOLODetector(mock_yolo_model)
with pytest.raises(DetectionError) as exc_info:
detector.detect(mock_frame)
assert "Model inference failed" in str(exc_info.value)
def test_detect_invalid_frame(self, mock_yolo_model):
"""Test detection with invalid frame input."""
detector = YOLODetector(mock_yolo_model)
with pytest.raises(DetectionError) as exc_info:
detector.detect(None)
assert "Invalid frame" in str(exc_info.value)
def test_detect_result_validation(self, mock_yolo_model, mock_frame):
"""Test detection result validation."""
# Mock result with invalid bounding box (x2 <= x1)
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[300, 200, 100, 400, 0.9, 0] # Invalid: x2 < x1
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame)
# Invalid detections should be filtered out
assert detections == []
def test_get_model_info(self, mock_yolo_model):
"""Test getting model information."""
mock_yolo_model.device = "cuda:0"
mock_yolo_model.names = {0: "car", 1: "truck"}
detector = YOLODetector(mock_yolo_model)
info = detector.get_model_info()
assert info["device"] == "cuda:0"
assert info["class_names"] == {0: "car", 1: "truck"}
assert info["tracking_enabled"] is True
def test_set_tracking_enabled(self, mock_yolo_model):
"""Test enabling/disabling tracking at runtime."""
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
assert detector.is_tracking_enabled is False
detector.set_tracking_enabled(True)
assert detector.is_tracking_enabled is True
detector.set_tracking_enabled(False)
assert detector.is_tracking_enabled is False
def test_update_class_names(self, mock_yolo_model):
"""Test updating class names at runtime."""
detector = YOLODetector(mock_yolo_model)
new_class_names = {0: "vehicle", 1: "person"}
detector.update_class_names(new_class_names)
assert detector.class_names == new_class_names
def test_reset_tracker(self, mock_yolo_model):
"""Test resetting the tracking state."""
detector = YOLODetector(mock_yolo_model)
# This should not raise an error
detector.reset_tracker()
def test_detect_with_crop_region(self, mock_yolo_model, mock_frame):
"""Test detection with crop region specified."""
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[50, 75, 150, 175, 0.9, 0] # Relative to cropped region
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
crop_region = (100, 200, 300, 400) # x1, y1, x2, y2
detections = detector.detect(mock_frame, crop_region=crop_region)
# Bounding box should be adjusted to global coordinates
assert detections[0].bbox.x1 == 150 # 100 + 50
assert detections[0].bbox.y1 == 275 # 200 + 75
assert detections[0].bbox.x2 == 250 # 100 + 150
assert detections[0].bbox.y2 == 375 # 200 + 175
def test_detect_batch_processing(self, mock_yolo_model):
"""Test batch detection processing."""
frames = [
np.zeros((480, 640, 3), dtype=np.uint8),
np.ones((480, 640, 3), dtype=np.uint8) * 255
]
mock_results = []
for i in range(2):
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100 + i*10, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001 + i])
mock_results.append(mock_result)
mock_yolo_model.track.side_effect = [[result] for result in mock_results]
detector = YOLODetector(mock_yolo_model)
batch_detections = detector.detect_batch(frames)
assert len(batch_detections) == 2
assert len(batch_detections[0]) == 1
assert len(batch_detections[1]) == 1
assert batch_detections[0][0].bbox.x1 == 100
assert batch_detections[1][0].bbox.x1 == 110
def test_detect_batch_empty_frames(self, mock_yolo_model):
"""Test batch detection with empty frame list."""
detector = YOLODetector(mock_yolo_model)
batch_detections = detector.detect_batch([])
assert batch_detections == []
def test_detect_performance_metrics(self, mock_yolo_model, mock_frame):
"""Test detection performance metrics collection."""
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_result.speed = {"preprocess": 2.1, "inference": 15.3, "postprocess": 1.2}
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame, return_metrics=True)
# Check if performance metrics are available
assert hasattr(detector, '_last_inference_time')
@pytest.mark.parametrize("device", ["cpu", "cuda:0", "mps"])
def test_detect_different_devices(self, device, mock_frame):
"""Test detection on different devices."""
mock_model = Mock()
mock_model.device = device
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_model.track.return_value = [mock_result]
detector = YOLODetector(mock_model)
detections = detector.detect(mock_frame)
assert len(detections) == 1
assert detections[0].confidence == 0.9
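# The coordinate arithmetic behind test_detect_with_crop_region above: boxes detected on a
# cropped frame are shifted back into full-frame coordinates by adding the crop origin.
# This is a sketch of the expected maths, not the detector's actual code.
def _sketch_to_global_coords(bbox, crop_region):
    cx1, cy1, _, _ = crop_region  # (x1, y1, x2, y2) of the crop within the full frame
    x1, y1, x2, y2 = bbox
    return (x1 + cx1, y1 + cy1, x2 + cx1, y2 + cy1)

assert _sketch_to_global_coords((50, 75, 150, 175), (100, 200, 300, 400)) == (150, 275, 250, 375)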
class TestYOLODetectorIntegration:
"""Integration tests for YOLO detector."""
def test_detect_with_real_tensor_operations(self, mock_yolo_model, mock_frame):
"""Test detection with realistic tensor operations."""
# Create more realistic box data
boxes_data = torch.tensor([
[100.5, 200.3, 299.7, 399.8, 0.95, 0],
[150.2, 250.1, 349.9, 449.6, 0.87, 1],
[200.0, 300.0, 400.0, 500.0, 0.45, 0] # Low confidence
])
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = boxes_data
mock_result.boxes.id = torch.tensor([2001, 2002, 2003])
mock_yolo_model.track.return_value = [mock_result]
class_names = {0: "car", 1: "truck"}
detector = YOLODetector(mock_yolo_model, class_names=class_names)
detections = detector.detect(mock_frame, confidence_threshold=0.5)
# Should filter out low confidence detection
assert len(detections) == 2
# Check first detection
det1 = detections[0]
assert det1.class_name == "car"
assert det1.confidence == pytest.approx(0.95)
assert det1.track_id == 2001
assert det1.bbox.x1 == pytest.approx(100.5)
assert det1.bbox.y1 == pytest.approx(200.3)
# Check second detection
det2 = detections[1]
assert det2.class_name == "truck"
assert det2.confidence == pytest.approx(0.87)
assert det2.track_id == 2002
def test_multi_frame_tracking_consistency(self, mock_yolo_model, mock_frame):
"""Test that tracking IDs remain consistent across frames."""
detector = YOLODetector(mock_yolo_model)
# Frame 1
mock_result1 = Mock()
mock_result1.boxes = Mock()
mock_result1.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result1.boxes.id = torch.tensor([5001])
mock_yolo_model.track.return_value = [mock_result1]
detections1 = detector.detect(mock_frame)
# Frame 2 - same object, slightly moved
mock_result2 = Mock()
mock_result2.boxes = Mock()
mock_result2.boxes.data = torch.tensor([
[105, 205, 305, 405, 0.88, 0]
])
mock_result2.boxes.id = torch.tensor([5001]) # Same ID
mock_yolo_model.track.return_value = [mock_result2]
detections2 = detector.detect(mock_frame)
# Should maintain same track ID
assert detections1[0].track_id == detections2[0].track_id == 5001

View file

@ -0,0 +1,882 @@
"""
Unit tests for model management functionality.
"""
import pytest
import os
import tempfile
import threading
import time
from unittest.mock import Mock, patch, MagicMock
import torch
import numpy as np
from detector_worker.models.model_manager import (
ModelManager,
ModelInfo,
ModelConfig,
ModelCache,
ModelLoader,
ModelError,
ModelLoadError,
ModelCacheError
)
from detector_worker.core.exceptions import ConfigurationError
class TestModelConfig:
"""Test model configuration."""
def test_creation(self):
"""Test model config creation."""
config = ModelConfig(
model_id="yolo_v8_car",
model_path="/models/yolo_v8_car.pt",
model_type="detection",
device="cuda:0"
)
assert config.model_id == "yolo_v8_car"
assert config.model_path == "/models/yolo_v8_car.pt"
assert config.model_type == "detection"
assert config.device == "cuda:0"
assert config.confidence_threshold == 0.5
assert config.max_memory_mb == 1024
def test_creation_with_optional_params(self):
"""Test config creation with optional parameters."""
config = ModelConfig(
model_id="classifier_v1",
model_path="/models/classifier.pt",
model_type="classification",
device="cpu",
confidence_threshold=0.8,
max_memory_mb=512,
class_names={0: "car", 1: "truck", 2: "bus"},
preprocessing_config={"resize": (224, 224), "normalize": True}
)
assert config.confidence_threshold == 0.8
assert config.max_memory_mb == 512
assert config.class_names[0] == "car"
assert config.preprocessing_config["resize"] == (224, 224)
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"model_id": "detection_model",
"model_path": "/path/to/model.pt",
"model_type": "detection",
"device": "cuda:0",
"confidence_threshold": 0.75,
"class_names": {0: "person", 1: "vehicle"},
"unknown_field": "ignored"
}
config = ModelConfig.from_dict(config_dict)
assert config.model_id == "detection_model"
assert config.confidence_threshold == 0.75
assert config.class_names[1] == "vehicle"
def test_validation(self):
"""Test config validation."""
# Valid config
valid_config = ModelConfig(
model_id="test_model",
model_path="/valid/path/model.pt",
model_type="detection",
device="cpu"
)
assert valid_config.is_valid() is True
# Invalid config (empty model_id)
invalid_config = ModelConfig(
model_id="",
model_path="/path/model.pt",
model_type="detection",
device="cpu"
)
assert invalid_config.is_valid() is False
def test_get_memory_limit_bytes(self):
"""Test getting memory limit in bytes."""
config = ModelConfig(
model_id="test",
model_path="/path",
model_type="detection",
device="cpu",
max_memory_mb=256
)
assert config.get_memory_limit_bytes() == 256 * 1024 * 1024
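# Sketch of a ModelConfig consistent with the expectations above (defaults, from_dict
# ignoring unknown keys, byte conversion); the shipped dataclass may differ.
from dataclasses import dataclass, fields as dc_fields

@dataclass
class _ModelConfigSketch:
    model_id: str
    model_path: str
    model_type: str
    device: str
    confidence_threshold: float = 0.5
    max_memory_mb: int = 1024
    class_names: dict = None
    preprocessing_config: dict = None

    @classmethod
    def from_dict(cls, data):
        allowed = {f.name for f in dc_fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in allowed})

    def is_valid(self):
        return bool(self.model_id and self.model_path and self.model_type and self.device)

    def get_memory_limit_bytes(self):
        return self.max_memory_mb * 1024 * 1024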
class TestModelInfo:
"""Test model information."""
def test_creation(self):
"""Test model info creation."""
config = ModelConfig(
model_id="test_model",
model_path="/path/model.pt",
model_type="detection",
device="cuda:0"
)
mock_model = Mock()
info = ModelInfo(
config=config,
model_instance=mock_model,
load_time=1.5
)
assert info.config == config
assert info.model_instance == mock_model
assert info.load_time == 1.5
assert info.reference_count == 0
assert info.last_used <= time.time()
assert info.memory_usage == 0
def test_increment_reference(self):
"""Test incrementing reference count."""
config = ModelConfig("test", "/path", "detection", "cpu")
info = ModelInfo(config, Mock(), 1.0)
assert info.reference_count == 0
info.increment_reference()
assert info.reference_count == 1
info.increment_reference()
assert info.reference_count == 2
def test_decrement_reference(self):
"""Test decrementing reference count."""
config = ModelConfig("test", "/path", "detection", "cpu")
info = ModelInfo(config, Mock(), 1.0)
info.reference_count = 3
assert info.decrement_reference() == 2
assert info.reference_count == 2
assert info.decrement_reference() == 1
assert info.decrement_reference() == 0
# Should not go below 0
assert info.decrement_reference() == 0
def test_update_usage(self):
"""Test updating usage statistics."""
config = ModelConfig("test", "/path", "detection", "cpu")
info = ModelInfo(config, Mock(), 1.0)
original_time = info.last_used
original_count = info.usage_count
time.sleep(0.01) # Small delay
info.update_usage(memory_usage=512*1024*1024) # 512MB
assert info.last_used > original_time
assert info.usage_count == original_count + 1
assert info.memory_usage == 512*1024*1024
def test_age_calculation(self):
"""Test age calculation."""
config = ModelConfig("test", "/path", "detection", "cpu")
info = ModelInfo(config, Mock(), 1.0)
time.sleep(0.01)
age = info.age()
assert age > 0
assert age < 1 # Should be less than 1 second
def test_get_stats(self):
"""Test getting model statistics."""
config = ModelConfig("test_model", "/path", "detection", "cuda:0")
info = ModelInfo(config, Mock(), 2.5)
info.reference_count = 3
info.usage_count = 100
info.memory_usage = 1024*1024*1024 # 1GB
stats = info.get_stats()
assert stats["model_id"] == "test_model"
assert stats["device"] == "cuda:0"
assert stats["load_time"] == 2.5
assert stats["reference_count"] == 3
assert stats["usage_count"] == 100
assert stats["memory_usage_mb"] == 1024
assert "age_seconds" in stats
class TestModelLoader:
"""Test model loading functionality."""
def test_creation(self):
"""Test model loader creation."""
loader = ModelLoader()
assert loader.supported_formats == [".pt", ".pth", ".onnx", ".trt"]
assert loader.default_device == "cpu"
def test_detect_device_cuda_available(self):
"""Test device detection when CUDA is available."""
loader = ModelLoader()
with patch('torch.cuda.is_available', return_value=True):
with patch('torch.cuda.device_count', return_value=2):
device = loader.detect_optimal_device()
assert device == "cuda:0"
def test_detect_device_cuda_unavailable(self):
"""Test device detection when CUDA is not available."""
loader = ModelLoader()
with patch('torch.cuda.is_available', return_value=False):
device = loader.detect_optimal_device()
assert device == "cpu"
def test_load_pytorch_model(self):
"""Test loading PyTorch model."""
loader = ModelLoader()
with patch('torch.load') as mock_torch_load:
with patch('os.path.exists', return_value=True):
mock_model = Mock()
mock_torch_load.return_value = mock_model
config = ModelConfig(
model_id="test_model",
model_path="/path/to/model.pt",
model_type="detection",
device="cpu"
)
loaded_model = loader.load_model(config)
assert loaded_model == mock_model
mock_torch_load.assert_called_once_with("/path/to/model.pt", map_location="cpu")
def test_load_model_file_not_exists(self):
"""Test loading model when file doesn't exist."""
loader = ModelLoader()
with patch('os.path.exists', return_value=False):
config = ModelConfig(
model_id="missing_model",
model_path="/nonexistent/model.pt",
model_type="detection",
device="cpu"
)
with pytest.raises(ModelLoadError) as exc_info:
loader.load_model(config)
assert "does not exist" in str(exc_info.value)
def test_load_model_invalid_format(self):
"""Test loading model with invalid format."""
loader = ModelLoader()
with patch('os.path.exists', return_value=True):
config = ModelConfig(
model_id="invalid_model",
model_path="/path/to/model.invalid",
model_type="detection",
device="cpu"
)
with pytest.raises(ModelLoadError) as exc_info:
loader.load_model(config)
assert "unsupported format" in str(exc_info.value).lower()
def test_load_model_torch_error(self):
"""Test loading model with torch loading error."""
loader = ModelLoader()
with patch('os.path.exists', return_value=True):
with patch('torch.load', side_effect=RuntimeError("CUDA out of memory")):
config = ModelConfig(
model_id="error_model",
model_path="/path/to/model.pt",
model_type="detection",
device="cuda:0"
)
with pytest.raises(ModelLoadError) as exc_info:
loader.load_model(config)
assert "CUDA out of memory" in str(exc_info.value)
def test_validate_model_pytorch(self):
"""Test validating PyTorch model."""
loader = ModelLoader()
mock_model = Mock()
mock_model.__class__.__module__ = "torch.nn"
config = ModelConfig("test", "/path", "detection", "cpu")
is_valid = loader.validate_model(mock_model, config)
assert is_valid is True
def test_validate_model_invalid(self):
"""Test validating invalid model."""
loader = ModelLoader()
invalid_model = "not_a_model"
config = ModelConfig("test", "/path", "detection", "cpu")
is_valid = loader.validate_model(invalid_model, config)
assert is_valid is False
def test_estimate_model_memory(self):
"""Test estimating model memory usage."""
loader = ModelLoader()
mock_model = Mock()
mock_param1 = Mock()
mock_param1.numel.return_value = 1000000 # 1M parameters
mock_param1.element_size.return_value = 4 # 4 bytes per parameter
mock_param2 = Mock()
mock_param2.numel.return_value = 500000 # 0.5M parameters
mock_param2.element_size.return_value = 4
mock_model.parameters.return_value = [mock_param1, mock_param2]
memory_bytes = loader.estimate_memory_usage(mock_model)
expected_bytes = (1000000 + 500000) * 4 # 6MB
assert memory_bytes == expected_bytes
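# Sketch of the loader behaviour these tests pin down (existence/format checks, map_location,
# parameter-based memory estimate); the control flow and error messages are assumptions.
import os
import torch

def _sketch_load_model(config, supported=(".pt", ".pth", ".onnx", ".trt")):
    if not os.path.exists(config.model_path):
        raise ModelLoadError(f"Model file does not exist: {config.model_path}")
    if not config.model_path.endswith(supported):
        raise ModelLoadError(f"Unsupported format: {config.model_path}")
    try:
        return torch.load(config.model_path, map_location=config.device)
    except Exception as exc:  # surface the underlying reason, e.g. CUDA out of memory
        raise ModelLoadError(str(exc)) from exc

def _sketch_estimate_memory(model):
    # bytes = number of elements * bytes per element, summed over all parameter tensors
    return sum(p.numel() * p.element_size() for p in model.parameters())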
class TestModelCache:
"""Test model caching functionality."""
def test_creation(self):
"""Test model cache creation."""
cache = ModelCache(max_size=5, max_memory_mb=2048)
assert cache.max_size == 5
assert cache.max_memory_mb == 2048
assert len(cache.models) == 0
assert len(cache.access_order) == 0
def test_put_and_get_model(self):
"""Test putting and getting model from cache."""
cache = ModelCache(max_size=3)
config = ModelConfig("test_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.5)
cache.put("test_model", model_info)
retrieved_info = cache.get("test_model")
assert retrieved_info == model_info
assert retrieved_info.reference_count == 1 # Should be incremented on get
def test_get_nonexistent_model(self):
"""Test getting non-existent model."""
cache = ModelCache(max_size=3)
result = cache.get("nonexistent_model")
assert result is None
def test_contains_check(self):
"""Test checking if model exists in cache."""
cache = ModelCache(max_size=3)
config = ModelConfig("test_model", "/path", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put("test_model", model_info)
assert cache.contains("test_model") is True
assert cache.contains("nonexistent_model") is False
def test_remove_model(self):
"""Test removing model from cache."""
cache = ModelCache(max_size=3)
config = ModelConfig("test_model", "/path", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put("test_model", model_info)
assert cache.contains("test_model") is True
removed_info = cache.remove("test_model")
assert removed_info == model_info
assert cache.contains("test_model") is False
def test_lru_eviction(self):
"""Test LRU eviction policy."""
cache = ModelCache(max_size=2)
# Add models to fill cache
for i in range(2):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put(f"model_{i}", model_info)
# Access model_0 to make it recently used
cache.get("model_0")
# Add another model (should evict model_1, the least recently used)
config = ModelConfig("model_2", "/path_2", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put("model_2", model_info)
assert cache.size() == 2
assert cache.contains("model_0") is True # Recently accessed
assert cache.contains("model_1") is False # Evicted
assert cache.contains("model_2") is True # Newly added
def test_memory_based_eviction(self):
"""Test memory-based eviction."""
cache = ModelCache(max_size=10, max_memory_mb=1) # 1MB limit
# Add model that uses 0.8MB
config1 = ModelConfig("model_1", "/path_1", "detection", "cpu")
model1 = Mock()
info1 = ModelInfo(config1, model1, 1.0)
info1.memory_usage = 0.8 * 1024 * 1024 # 0.8MB
cache.put("model_1", info1)
# Add model that would exceed memory limit
config2 = ModelConfig("model_2", "/path_2", "detection", "cpu")
model2 = Mock()
info2 = ModelInfo(config2, model2, 1.0)
info2.memory_usage = 0.5 * 1024 * 1024 # 0.5MB
cache.put("model_2", info2)
# First model should be evicted due to memory constraint
assert cache.contains("model_1") is False
assert cache.contains("model_2") is True
def test_get_stats(self):
"""Test getting cache statistics."""
cache = ModelCache(max_size=5)
# Add some models
for i in range(3):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
model_info.memory_usage = 100 * 1024 * 1024 # 100MB each
cache.put(f"model_{i}", model_info)
# Access some models
cache.get("model_0")
cache.get("model_1")
cache.get("nonexistent") # Miss
stats = cache.get_stats()
assert stats["size"] == 3
assert stats["max_size"] == 5
assert stats["hits"] == 2
assert stats["misses"] == 1
assert stats["hit_rate"] == 2/3
assert stats["memory_usage_mb"] == 300
def test_clear_cache(self):
"""Test clearing entire cache."""
cache = ModelCache(max_size=5)
# Add models
for i in range(3):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put(f"model_{i}", model_info)
assert cache.size() == 3
cache.clear()
assert cache.size() == 0
assert len(cache.models) == 0
assert len(cache.access_order) == 0
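# Sketch of the eviction policy the cache tests above describe: least-recently-used entries
# are dropped when either the entry count or the memory budget is exceeded. The real
# ModelCache also tracks hit/miss statistics; this sketch omits them.
from collections import OrderedDict

class _ModelCacheSketch:
    def __init__(self, max_size, max_memory_mb):
        self.max_size = max_size
        self.max_memory_bytes = max_memory_mb * 1024 * 1024
        self.models = OrderedDict()  # model_id -> ModelInfo, oldest entry first

    def _total_memory(self):
        return sum(info.memory_usage for info in self.models.values())

    def put(self, model_id, info):
        self.models[model_id] = info
        self.models.move_to_end(model_id)
        while len(self.models) > self.max_size or self._total_memory() > self.max_memory_bytes:
            self.models.popitem(last=False)  # evict the least recently used entry

    def get(self, model_id):
        info = self.models.get(model_id)
        if info is not None:
            self.models.move_to_end(model_id)
            info.increment_reference()
        return info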
class TestModelManager:
"""Test main model manager functionality."""
def test_initialization(self):
"""Test model manager initialization."""
manager = ModelManager()
assert isinstance(manager.cache, ModelCache)
assert isinstance(manager.loader, ModelLoader)
assert manager.models_directory == "models"
assert manager.default_device == "cpu"
def test_initialization_with_config(self):
"""Test initialization with custom configuration."""
config = {
"models_directory": "/custom/models",
"default_device": "cuda:0",
"cache_max_size": 20,
"cache_max_memory_mb": 4096
}
manager = ModelManager(config)
assert manager.models_directory == "/custom/models"
assert manager.default_device == "cuda:0"
assert manager.cache.max_size == 20
assert manager.cache.max_memory_mb == 4096
def test_load_model_new(self):
"""Test loading new model."""
manager = ModelManager()
config = ModelConfig(
model_id="test_model",
model_path="/path/to/model.pt",
model_type="detection",
device="cpu"
)
with patch.object(manager.loader, 'load_model') as mock_load:
with patch.object(manager.loader, 'estimate_memory_usage', return_value=512*1024*1024):
mock_model = Mock()
mock_load.return_value = mock_model
loaded_model = manager.load_model(config)
assert loaded_model == mock_model
assert manager.cache.contains("test_model") is True
mock_load.assert_called_once_with(config)
def test_load_model_from_cache(self):
"""Test loading model from cache."""
manager = ModelManager()
# Pre-populate cache
config = ModelConfig("cached_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.0)
manager.cache.put("cached_model", model_info)
with patch.object(manager.loader, 'load_model') as mock_load:
loaded_model = manager.load_model(config)
assert loaded_model == mock_model
mock_load.assert_not_called() # Should not load from disk
def test_get_model_by_id(self):
"""Test getting model by ID."""
manager = ModelManager()
config = ModelConfig("test_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.0)
manager.cache.put("test_model", model_info)
retrieved_model = manager.get_model("test_model")
assert retrieved_model == mock_model
def test_get_nonexistent_model(self):
"""Test getting non-existent model."""
manager = ModelManager()
model = manager.get_model("nonexistent_model")
assert model is None
def test_unload_model_with_references(self):
"""Test unloading model with active references."""
manager = ModelManager()
config = ModelConfig("ref_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.0)
model_info.reference_count = 2 # Active references
manager.cache.put("ref_model", model_info)
result = manager.unload_model("ref_model")
assert result is False # Should not unload with active references
assert manager.cache.contains("ref_model") is True
def test_unload_model_no_references(self):
"""Test unloading model without references."""
manager = ModelManager()
config = ModelConfig("no_ref_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.0)
model_info.reference_count = 0 # No references
manager.cache.put("no_ref_model", model_info)
result = manager.unload_model("no_ref_model")
assert result is True
assert manager.cache.contains("no_ref_model") is False
def test_list_loaded_models(self):
"""Test listing loaded models."""
manager = ModelManager()
# Add models to cache
for i in range(3):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
manager.cache.put(f"model_{i}", model_info)
loaded_models = manager.list_loaded_models()
assert len(loaded_models) == 3
assert all(info["model_id"].startswith("model_") for info in loaded_models)
def test_get_model_info(self):
"""Test getting model information."""
manager = ModelManager()
config = ModelConfig("info_model", "/path", "detection", "cuda:0")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 2.5)
model_info.usage_count = 10
manager.cache.put("info_model", model_info)
info = manager.get_model_info("info_model")
assert info is not None
assert info["model_id"] == "info_model"
assert info["device"] == "cuda:0"
assert info["load_time"] == 2.5
assert info["usage_count"] == 10
def test_cleanup_unused_models(self):
"""Test cleaning up unused models."""
manager = ModelManager()
# Add models with different reference counts
models_data = [
("used_model", 2), # Has references
("unused_model_1", 0), # No references
("unused_model_2", 0) # No references
]
for model_id, ref_count in models_data:
config = ModelConfig(model_id, f"/path/{model_id}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
model_info.reference_count = ref_count
manager.cache.put(model_id, model_info)
cleaned_count = manager.cleanup_unused_models()
assert cleaned_count == 2 # Two unused models cleaned
assert manager.cache.contains("used_model") is True
assert manager.cache.contains("unused_model_1") is False
assert manager.cache.contains("unused_model_2") is False
def test_get_memory_usage(self):
"""Test getting total memory usage."""
manager = ModelManager()
# Add models with different memory usage
memory_sizes = [256, 512, 1024] # MB
for i, memory_mb in enumerate(memory_sizes):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
model_info.memory_usage = memory_mb * 1024 * 1024 # Convert to bytes
manager.cache.put(f"model_{i}", model_info)
total_usage = manager.get_memory_usage()
expected_bytes = sum(memory_sizes) * 1024 * 1024
assert total_usage == expected_bytes
def test_health_check(self):
"""Test model manager health check."""
manager = ModelManager()
# Add models
for i in range(3):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
model_info.memory_usage = 100 * 1024 * 1024 # 100MB each
manager.cache.put(f"model_{i}", model_info)
health_report = manager.health_check()
assert health_report["status"] == "healthy"
assert health_report["loaded_models"] == 3
assert health_report["total_memory_mb"] == 300
assert health_report["cache_hit_rate"] >= 0
class TestModelManagerIntegration:
"""Integration tests for model manager."""
def test_concurrent_model_loading(self):
"""Test concurrent model loading."""
manager = ModelManager()
# Mock loader to simulate loading time
def slow_load(config):
time.sleep(0.1) # Simulate loading time
mock_model = Mock()
mock_model.model_id = config.model_id
return mock_model
with patch.object(manager.loader, 'load_model', side_effect=slow_load):
with patch.object(manager.loader, 'estimate_memory_usage', return_value=100*1024*1024):
# Create multiple threads loading different models
results = {}
errors = []
def load_model_thread(model_id):
try:
config = ModelConfig(
model_id=model_id,
model_path=f"/path/{model_id}.pt",
model_type="detection",
device="cpu"
)
model = manager.load_model(config)
results[model_id] = model
except Exception as e:
errors.append((model_id, str(e)))
threads = []
for i in range(5):
thread = threading.Thread(target=load_model_thread, args=(f"model_{i}",))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# All models should be loaded successfully
assert len(errors) == 0
assert len(results) == 5
assert len(manager.cache.models) == 5
def test_memory_pressure_handling(self):
"""Test handling memory pressure."""
# Create manager with small memory limit
manager = ModelManager({
"cache_max_memory_mb": 200 # 200MB limit
})
with patch.object(manager.loader, 'load_model') as mock_load:
with patch.object(manager.loader, 'estimate_memory_usage', return_value=100*1024*1024): # 100MB per model
def create_mock_model(config):
mock_model = Mock()
mock_model.model_id = config.model_id
return mock_model
mock_load.side_effect = create_mock_model
# Load models that exceed memory limit
for i in range(4): # 4 * 100MB = 400MB > 200MB limit
config = ModelConfig(
model_id=f"large_model_{i}",
model_path=f"/path/large_model_{i}.pt",
model_type="detection",
device="cpu"
)
manager.load_model(config)
# Should not exceed memory limit due to eviction
total_memory = manager.get_memory_usage()
memory_limit = 200 * 1024 * 1024
assert total_memory <= memory_limit
def test_model_lifecycle_management(self):
"""Test complete model lifecycle."""
manager = ModelManager()
with patch.object(manager.loader, 'load_model') as mock_load:
with patch.object(manager.loader, 'estimate_memory_usage', return_value=50*1024*1024):
mock_model = Mock()
mock_load.return_value = mock_model
config = ModelConfig(
model_id="lifecycle_model",
model_path="/path/lifecycle_model.pt",
model_type="detection",
device="cpu"
)
# 1. Load model
loaded_model = manager.load_model(config)
assert loaded_model == mock_model
assert manager.cache.contains("lifecycle_model") is True
# 2. Get model multiple times (increase usage)
for _ in range(5):
model = manager.get_model("lifecycle_model")
assert model == mock_model
# 3. Check model info
info = manager.get_model_info("lifecycle_model")
assert info["usage_count"] >= 5
# 4. Simulate model still in use
model_info = manager.cache.get("lifecycle_model")
model_info.reference_count = 1
# Should not unload while in use
unloaded = manager.unload_model("lifecycle_model")
assert unloaded is False
assert manager.cache.contains("lifecycle_model") is True
# 5. Release reference and unload
model_info.reference_count = 0
unloaded = manager.unload_model("lifecycle_model")
assert unloaded is True
assert manager.cache.contains("lifecycle_model") is False
def test_error_recovery(self):
"""Test error recovery scenarios."""
manager = ModelManager()
# Test loading model that fails initially then succeeds
call_count = 0
def failing_then_success_load(config):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ModelLoadError("First attempt failed")
return Mock()
with patch.object(manager.loader, 'load_model', side_effect=failing_then_success_load):
with patch.object(manager.loader, 'estimate_memory_usage', return_value=50*1024*1024):
config = ModelConfig(
model_id="retry_model",
model_path="/path/retry_model.pt",
model_type="detection",
device="cpu"
)
# First attempt should fail
with pytest.raises(ModelLoadError):
manager.load_model(config)
# Model should not be in cache
assert manager.cache.contains("retry_model") is False
# Second attempt should succeed
model = manager.load_model(config)
assert model is not None
assert manager.cache.contains("retry_model") is True

View file

@ -0,0 +1,959 @@
"""
Unit tests for action execution functionality.
"""
import pytest
import asyncio
import json
import base64
import numpy as np
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
from detector_worker.pipeline.action_executor import (
ActionExecutor,
ActionResult,
ActionType,
RedisAction,
PostgreSQLAction,
FileAction
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import ActionError, RedisError, DatabaseError
class TestActionResult:
"""Test action execution result."""
def test_creation_success(self):
"""Test successful action result creation."""
result = ActionResult(
action_type=ActionType.REDIS_SAVE,
success=True,
execution_time=0.05,
metadata={"key": "saved_image_key", "expiry": 600}
)
assert result.action_type == ActionType.REDIS_SAVE
assert result.success is True
assert result.execution_time == 0.05
assert result.metadata["key"] == "saved_image_key"
assert result.error is None
def test_creation_failure(self):
"""Test failed action result creation."""
result = ActionResult(
action_type=ActionType.POSTGRESQL_INSERT,
success=False,
error="Database connection failed",
execution_time=0.02
)
assert result.action_type == ActionType.POSTGRESQL_INSERT
assert result.success is False
assert result.error == "Database connection failed"
assert result.metadata == {}
class TestRedisAction:
"""Test Redis action implementations."""
def test_creation(self):
"""Test Redis action creation."""
action_config = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}",
"expire_seconds": 600
}
action = RedisAction(action_config)
assert action.action_type == ActionType.REDIS_SAVE
assert action.region == "car"
assert action.key_template == "inference:{display_id}:{timestamp}:{session_id}"
assert action.expire_seconds == 600
def test_resolve_key_template(self):
"""Test key template resolution."""
action_config = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
"expire_seconds": 600
}
action = RedisAction(action_config)
context = {
"display_id": "display_001",
"timestamp": "1640995200000",
"session_id": "session_123",
"filename": "detection.jpg"
}
resolved_key = action.resolve_key(context)
expected_key = "inference:display_001:1640995200000:session_123:detection.jpg"
assert resolved_key == expected_key
def test_resolve_key_missing_variable(self):
"""Test key resolution with missing variable."""
action_config = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{missing_var}",
"expire_seconds": 600
}
action = RedisAction(action_config)
context = {"display_id": "display_001"}
with pytest.raises(ActionError):
action.resolve_key(context)
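# Key resolution as pinned down above is plain template substitution; a missing context
# variable becomes an ActionError (a sketch only — the real class may format differently).
def _sketch_resolve_key(key_template, context):
    try:
        return key_template.format(**context)
    except KeyError as exc:
        raise ActionError(f"Missing template variable: {exc}") from exc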
class TestPostgreSQLAction:
"""Test PostgreSQL action implementations."""
def test_creation_insert(self):
"""Test PostgreSQL insert action creation."""
action_config = {
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"bbox_x1": "{bbox.x1}",
"created_at": "NOW()"
}
}
action = PostgreSQLAction(action_config)
assert action.action_type == ActionType.POSTGRESQL_INSERT
assert action.table == "detections"
assert len(action.fields) == 6
assert action.key_field is None
def test_creation_update(self):
"""Test PostgreSQL update action creation."""
action_config = {
"type": "postgresql_update_combined",
"table": "car_info",
"key_field": "session_id",
"fields": {
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"updated_at": "NOW()"
},
"waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
}
action = PostgreSQLAction(action_config)
assert action.action_type == ActionType.POSTGRESQL_UPDATE
assert action.table == "car_info"
assert action.key_field == "session_id"
assert action.wait_for_branches == ["car_brand_cls", "car_bodytype_cls"]
def test_resolve_field_values(self):
"""Test field value resolution."""
action_config = {
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"brand": "{car_brand_cls.brand}"
}
}
action = PostgreSQLAction(action_config)
context = {
"camera_id": "camera_001",
"class": "car",
"confidence": 0.85
}
branch_results = {
"car_brand_cls": {"brand": "Toyota", "confidence": 0.78}
}
resolved_fields = action.resolve_field_values(context, branch_results)
assert resolved_fields["camera_id"] == "camera_001"
assert resolved_fields["detection_class"] == "car"
assert resolved_fields["confidence"] == 0.85
assert resolved_fields["brand"] == "Toyota"
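# Sketch of the field resolution asserted above: plain "{name}" placeholders are filled from
# the context, dotted "{branch.key}" placeholders from branch results, and anything else
# (e.g. NOW()) passes through untouched. Helper name and fallbacks are illustrative.
def _sketch_resolve_fields(fields, context, branch_results):
    resolved = {}
    for column, template in fields.items():
        if template.startswith("{") and template.endswith("}"):
            ref = template[1:-1]
            if "." in ref:
                branch, key = ref.split(".", 1)
                resolved[column] = branch_results.get(branch, {}).get(key)
            else:
                resolved[column] = context.get(ref)
        else:
            resolved[column] = template  # literal SQL fragment such as NOW()
    return resolved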
class TestFileAction:
"""Test file action implementations."""
def test_creation(self):
"""Test file action creation."""
action_config = {
"type": "save_image",
"path": "/tmp/detections/{camera_id}_{timestamp}.jpg",
"region": "car",
"format": "jpeg",
"quality": 85
}
action = FileAction(action_config)
assert action.action_type == ActionType.SAVE_IMAGE
assert action.path_template == "/tmp/detections/{camera_id}_{timestamp}.jpg"
assert action.region == "car"
assert action.format == "jpeg"
assert action.quality == 85
def test_resolve_path_template(self):
"""Test path template resolution."""
action_config = {
"type": "save_image",
"path": "/tmp/detections/{camera_id}/{date}/{timestamp}.jpg"
}
action = FileAction(action_config)
context = {
"camera_id": "camera_001",
"timestamp": "1640995200000",
"date": "2022-01-01"
}
resolved_path = action.resolve_path(context)
expected_path = "/tmp/detections/camera_001/2022-01-01/1640995200000.jpg"
assert resolved_path == expected_path
class TestActionExecutor:
"""Test action execution functionality."""
def test_initialization(self):
"""Test action executor initialization."""
executor = ActionExecutor()
assert executor.redis_client is None
assert executor.db_manager is None
assert executor.max_concurrent_actions == 10
assert executor.action_timeout == 30.0
def test_initialization_with_clients(self, mock_redis_client, mock_database_connection):
"""Test initialization with client instances."""
executor = ActionExecutor(
redis_client=mock_redis_client,
db_manager=mock_database_connection
)
assert executor.redis_client is mock_redis_client
assert executor.db_manager is mock_database_connection
@pytest.mark.asyncio
async def test_execute_actions_empty_list(self):
"""Test executing empty action list."""
executor = ActionExecutor()
context = {
"camera_id": "camera_001",
"session_id": "session_123"
}
results = await executor.execute_actions([], {}, context)
assert results == []
@pytest.mark.asyncio
async def test_execute_redis_save_action(self, mock_redis_client, mock_frame):
"""Test executing Redis save image action."""
executor = ActionExecutor(redis_client=mock_redis_client)
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{camera_id}:{session_id}",
"expire_seconds": 600
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"frame_data": mock_frame
}
# Mock successful Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.REDIS_SAVE
# Verify Redis calls
mock_redis_client.set.assert_called_once()
mock_redis_client.expire.assert_called_once()
@pytest.mark.asyncio
async def test_execute_postgresql_insert_action(self, mock_database_connection):
"""Test executing PostgreSQL insert action."""
# Mock database manager
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(db_manager=mock_db_manager)
actions = [
{
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"created_at": "NOW()"
}
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"class": "car",
"confidence": 0.9
}
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.POSTGRESQL_INSERT
# Verify database call
mock_db_manager.execute_query.assert_called_once()
call_args = mock_db_manager.execute_query.call_args[0]
assert "INSERT INTO detections" in call_args[0]
@pytest.mark.asyncio
async def test_execute_postgresql_update_action(self, mock_database_connection):
"""Test executing PostgreSQL update action."""
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(db_manager=mock_db_manager)
actions = [
{
"type": "postgresql_update_combined",
"table": "car_info",
"key_field": "session_id",
"fields": {
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"updated_at": "NOW()"
},
"waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
}
]
regions = {}
context = {
"session_id": "session_123"
}
branch_results = {
"car_brand_cls": {"brand": "Toyota"},
"car_bodytype_cls": {"body_type": "Sedan"}
}
results = await executor.execute_actions(actions, regions, context, branch_results)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.POSTGRESQL_UPDATE
# Verify database call
mock_db_manager.execute_query.assert_called_once()
call_args = mock_db_manager.execute_query.call_args[0]
assert "UPDATE car_info SET" in call_args[0]
assert "WHERE session_id" in call_args[0]
@pytest.mark.asyncio
async def test_execute_file_save_action(self, mock_frame):
"""Test executing file save action."""
executor = ActionExecutor()
actions = [
{
"type": "save_image",
"path": "/tmp/test_{camera_id}_{timestamp}.jpg",
"region": "car",
"format": "jpeg",
"quality": 85
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"timestamp": "1640995200000",
"frame_data": mock_frame
}
with patch('cv2.imwrite') as mock_imwrite:
mock_imwrite.return_value = True
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.SAVE_IMAGE
# Verify file save call
mock_imwrite.assert_called_once()
call_args = mock_imwrite.call_args
assert "/tmp/test_camera_001_1640995200000.jpg" in call_args[0][0]
@pytest.mark.asyncio
async def test_execute_actions_parallel(self, mock_redis_client):
"""Test parallel execution of multiple actions."""
executor = ActionExecutor(redis_client=mock_redis_client)
# Multiple Redis actions
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:car:{session_id}",
"expire_seconds": 600
},
{
"type": "redis_publish",
"channel": "detections",
"message": "{camera_id}:car_detected"
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
}
# Mock Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
mock_redis_client.publish.return_value = 1
import time
start_time = time.time()
results = await executor.execute_actions(actions, regions, context)
execution_time = time.time() - start_time
assert len(results) == 2
assert all(result.success for result in results)
# Should execute in parallel (faster than sequential)
assert execution_time < 0.1 # Allow some overhead
@pytest.mark.asyncio
async def test_execute_actions_error_handling(self, mock_redis_client):
"""Test error handling in action execution."""
executor = ActionExecutor(redis_client=mock_redis_client)
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{session_id}",
"expire_seconds": 600
},
{
"type": "redis_save_image", # This one will fail
"region": "truck", # Region not detected
"key": "inference:truck:{session_id}",
"expire_seconds": 600
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
# No truck region
}
context = {
"session_id": "session_123",
"frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
}
# Mock Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 2
assert results[0].success is True # Car action succeeds
assert results[1].success is False # Truck action fails
assert "Region 'truck' not found" in results[1].error
@pytest.mark.asyncio
async def test_execute_actions_timeout(self, mock_redis_client):
"""Test action execution timeout."""
config = {"action_timeout": 0.001} # Very short timeout
executor = ActionExecutor(redis_client=mock_redis_client, config=config)
def slow_redis_operation(*args, **kwargs):
import time
time.sleep(1) # Longer than timeout
return True
mock_redis_client.set.side_effect = slow_redis_operation
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{session_id}",
"expire_seconds": 600
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"session_id": "session_123",
"frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
}
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is False
assert "timeout" in results[0].error.lower()
@pytest.mark.asyncio
async def test_execute_redis_publish_action(self, mock_redis_client):
"""Test executing Redis publish action."""
executor = ActionExecutor(redis_client=mock_redis_client)
actions = [
{
"type": "redis_publish",
"channel": "detections:{camera_id}",
"message": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"timestamp": "{timestamp}"
}
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"class": "car",
"confidence": 0.9,
"timestamp": "1640995200000"
}
mock_redis_client.publish.return_value = 1
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.REDIS_PUBLISH
# Verify publish call
mock_redis_client.publish.assert_called_once()
call_args = mock_redis_client.publish.call_args
assert call_args[0][0] == "detections:camera_001" # Channel
# Message should be JSON
message = call_args[0][1]
parsed_message = json.loads(message)
assert parsed_message["camera_id"] == "camera_001"
assert parsed_message["detection_class"] == "car"
@pytest.mark.asyncio
async def test_execute_conditional_action(self):
"""Test executing conditional actions."""
executor = ActionExecutor()
actions = [
{
"type": "conditional",
"condition": "{confidence} > 0.8",
"actions": [
{
"type": "log",
"message": "High confidence detection: {class} ({confidence})"
}
]
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.95, # High confidence
"detection": DetectionResult("car", 0.95, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"class": "car",
"confidence": 0.95
}
with patch('logging.info') as mock_log:
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
# Should have logged the message
mock_log.assert_called_once()
log_message = mock_log.call_args[0][0]
assert "High confidence detection: car (0.95)" in log_message
def test_crop_region_from_frame(self, mock_frame):
"""Test cropping region from frame."""
executor = ActionExecutor()
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
cropped = executor._crop_region_from_frame(mock_frame, detection.bbox)
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
def test_encode_image_base64(self, mock_frame):
"""Test encoding image to base64."""
executor = ActionExecutor()
# Crop a small region
cropped_frame = mock_frame[200:400, 100:300] # 200x200 region
with patch('cv2.imencode') as mock_imencode:
# Mock successful encoding
mock_imencode.return_value = (True, np.array([1, 2, 3, 4], dtype=np.uint8))
encoded = executor._encode_image_base64(cropped_frame, format="jpeg")
# Should return base64 string
assert isinstance(encoded, str)
assert len(encoded) > 0
# Verify encoding call
mock_imencode.assert_called_once()
assert mock_imencode.call_args[0][0] == '.jpg'
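# --- Illustrative sketch (not part of the test suite) ---
# Minimal helpers matching the two assertions above: slice the frame with the
# bounding box (rows are indexed by y, columns by x) and base64-encode the
# bytes produced by cv2.imencode. Error handling in the real executor may differ.
import base64
import cv2
import numpy as np

def crop_region_sketch(frame: np.ndarray, bbox) -> np.ndarray:
    x1, y1, x2, y2 = int(bbox.x1), int(bbox.y1), int(bbox.x2), int(bbox.y2)
    return frame[y1:y2, x1:x2]

def encode_image_base64_sketch(image: np.ndarray, format: str = "jpeg") -> str:
    ext = ".jpg" if format.lower() in ("jpg", "jpeg") else f".{format.lower()}"
    ok, buffer = cv2.imencode(ext, image)
    if not ok:
        raise ValueError("cv2.imencode failed")
    return base64.b64encode(buffer.tobytes()).decode("ascii")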
def test_build_insert_query(self):
"""Test building INSERT SQL query."""
executor = ActionExecutor()
table = "detections"
fields = {
"camera_id": "camera_001",
"detection_class": "car",
"confidence": 0.9,
"created_at": "NOW()"
}
query, values = executor._build_insert_query(table, fields)
assert "INSERT INTO detections" in query
assert "camera_id, detection_class, confidence, created_at" in query
assert "VALUES (%s, %s, %s, NOW())" in query
assert values == ["camera_001", "car", 0.9]
def test_build_update_query(self):
"""Test building UPDATE SQL query."""
executor = ActionExecutor()
table = "car_info"
fields = {
"car_brand": "Toyota",
"car_body_type": "Sedan",
"updated_at": "NOW()"
}
key_field = "session_id"
key_value = "session_123"
query, values = executor._build_update_query(table, fields, key_field, key_value)
assert "UPDATE car_info SET" in query
assert "car_brand = %s" in query
assert "car_body_type = %s" in query
assert "updated_at = NOW()" in query
assert "WHERE session_id = %s" in query
assert values == ["Toyota", "Sedan", "session_123"]
def test_evaluate_condition(self):
"""Test evaluating conditional expressions."""
executor = ActionExecutor()
context = {
"confidence": 0.85,
"class": "car",
"area": 40000
}
# Simple comparisons
assert executor._evaluate_condition("{confidence} > 0.8", context) is True
assert executor._evaluate_condition("{confidence} < 0.8", context) is False
assert executor._evaluate_condition("{confidence} >= 0.85", context) is True
assert executor._evaluate_condition("{confidence} == 0.85", context) is True
# String comparisons
assert executor._evaluate_condition("{class} == 'car'", context) is True
assert executor._evaluate_condition("{class} != 'truck'", context) is True
# Complex conditions
assert executor._evaluate_condition("{confidence} > 0.8 and {area} > 30000", context) is True
assert executor._evaluate_condition("{confidence} > 0.9 or {area} > 30000", context) is True
assert executor._evaluate_condition("{confidence} > 0.9 and {area} < 30000", context) is False
def test_validate_action_config(self):
"""Test action configuration validation."""
executor = ActionExecutor()
# Valid Redis action
valid_redis = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{session_id}",
"expire_seconds": 600
}
assert executor._validate_action_config(valid_redis) is True
# Invalid action (missing required fields)
invalid_action = {
"type": "redis_save_image"
# Missing region and key
}
with pytest.raises(ActionError):
executor._validate_action_config(invalid_action)
# Unknown action type
unknown_action = {
"type": "unknown_action_type",
"some_field": "value"
}
with pytest.raises(ActionError):
executor._validate_action_config(unknown_action)
class TestActionExecutorIntegration:
"""Integration tests for action execution."""
@pytest.mark.asyncio
async def test_complete_detection_workflow(self, mock_redis_client, mock_frame):
"""Test complete detection workflow with multiple actions."""
# Mock database manager
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(
redis_client=mock_redis_client,
db_manager=mock_db_manager
)
# Complete action workflow
actions = [
# Save cropped image to Redis
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{camera_id}:{timestamp}:{session_id}:car",
"expire_seconds": 600
},
# Insert initial detection record
{
"type": "postgresql_insert",
"table": "car_detections",
"fields": {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"bbox_x1": "{bbox.x1}",
"bbox_y1": "{bbox.y1}",
"bbox_x2": "{bbox.x2}",
"bbox_y2": "{bbox.y2}",
"created_at": "NOW()"
}
},
# Publish detection event
{
"type": "redis_publish",
"channel": "detections:{camera_id}",
"message": {
"event": "car_detected",
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"timestamp": "{timestamp}"
}
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.92,
"detection": DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"timestamp": "1640995200000",
"class": "car",
"confidence": 0.92,
"bbox": {"x1": 100, "y1": 200, "x2": 300, "y2": 400},
"frame_data": mock_frame
}
# Mock all Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
mock_redis_client.publish.return_value = 1
results = await executor.execute_actions(actions, regions, context)
# All actions should succeed
assert len(results) == 3
assert all(result.success for result in results)
# Verify all operations were called
mock_redis_client.set.assert_called_once() # Image save
mock_redis_client.expire.assert_called_once() # Set expiry
mock_redis_client.publish.assert_called_once() # Publish event
mock_db_manager.execute_query.assert_called_once() # Database insert
@pytest.mark.asyncio
async def test_branch_dependent_actions(self, mock_database_connection):
"""Test actions that depend on branch results."""
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(db_manager=mock_db_manager)
# Action that waits for classification branches
actions = [
{
"type": "postgresql_update_combined",
"table": "car_info",
"key_field": "session_id",
"fields": {
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"car_color": "{car_color_cls.color}",
"confidence_brand": "{car_brand_cls.confidence}",
"confidence_bodytype": "{car_bodytype_cls.confidence}",
"updated_at": "NOW()"
},
"waitForBranches": ["car_brand_cls", "car_bodytype_cls", "car_color_cls"]
}
]
regions = {}
context = {
"session_id": "session_123"
}
# Simulated branch results
branch_results = {
"car_brand_cls": {"brand": "Toyota", "confidence": 0.87},
"car_bodytype_cls": {"body_type": "Sedan", "confidence": 0.82},
"car_color_cls": {"color": "Red", "confidence": 0.79}
}
results = await executor.execute_actions(actions, regions, context, branch_results)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.POSTGRESQL_UPDATE
# Verify database call with all branch data
mock_db_manager.execute_query.assert_called_once()
call_args = mock_db_manager.execute_query.call_args
query = call_args[0][0]
values = call_args[0][1]
assert "UPDATE car_info SET" in query
assert "car_brand = %s" in query
assert "car_body_type = %s" in query
assert "car_color = %s" in query
assert "WHERE session_id = %s" in query
assert "Toyota" in values
assert "Sedan" in values
assert "Red" in values
assert "session_123" in values

View file

@ -0,0 +1,786 @@
"""
Unit tests for field mapping and template resolution.
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime
import json
from detector_worker.pipeline.field_mapper import (
FieldMapper,
MappingContext,
TemplateResolver,
FieldMappingError,
NestedFieldAccessor
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
class TestNestedFieldAccessor:
"""Test nested field access functionality."""
def test_get_nested_value_simple(self):
"""Test getting simple nested values."""
data = {
"user": {
"name": "John",
"age": 30,
"address": {
"city": "New York",
"zip": "10001"
}
}
}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "user.name") == "John"
assert accessor.get_nested_value(data, "user.age") == 30
assert accessor.get_nested_value(data, "user.address.city") == "New York"
assert accessor.get_nested_value(data, "user.address.zip") == "10001"
def test_get_nested_value_array_access(self):
"""Test accessing array elements."""
data = {
"results": [
{"score": 0.9, "label": "car"},
{"score": 0.8, "label": "truck"}
],
"bbox": [100, 200, 300, 400]
}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "results[0].score") == 0.9
assert accessor.get_nested_value(data, "results[0].label") == "car"
assert accessor.get_nested_value(data, "results[1].score") == 0.8
assert accessor.get_nested_value(data, "bbox[0]") == 100
assert accessor.get_nested_value(data, "bbox[3]") == 400
def test_get_nested_value_nonexistent_path(self):
"""Test accessing non-existent paths."""
data = {"user": {"name": "John"}}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "user.nonexistent") is None
assert accessor.get_nested_value(data, "nonexistent.field") is None
assert accessor.get_nested_value(data, "user.address.city") is None
def test_get_nested_value_with_default(self):
"""Test getting nested values with default fallback."""
data = {"user": {"name": "John"}}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "user.age", default=25) == 25
assert accessor.get_nested_value(data, "user.name", default="Unknown") == "John"
def test_set_nested_value(self):
"""Test setting nested values."""
data = {}
accessor = NestedFieldAccessor()
accessor.set_nested_value(data, "user.name", "John")
assert data["user"]["name"] == "John"
accessor.set_nested_value(data, "user.address.city", "New York")
assert data["user"]["address"]["city"] == "New York"
accessor.set_nested_value(data, "scores[0]", 0.95)
assert data["scores"][0] == 0.95
def test_set_nested_value_overwrite(self):
"""Test overwriting existing nested values."""
data = {"user": {"name": "John", "age": 30}}
accessor = NestedFieldAccessor()
accessor.set_nested_value(data, "user.name", "Jane")
assert data["user"]["name"] == "Jane"
assert data["user"]["age"] == 30 # Should not affect other fields
class TestTemplateResolver:
"""Test template string resolution."""
def test_resolve_simple_template(self):
"""Test resolving simple template variables."""
resolver = TemplateResolver()
template = "Hello {name}, you are {age} years old"
context = {"name": "John", "age": 30}
result = resolver.resolve(template, context)
assert result == "Hello John, you are 30 years old"
def test_resolve_nested_template(self):
"""Test resolving nested field templates."""
resolver = TemplateResolver()
template = "User: {user.name} from {user.address.city}"
context = {
"user": {
"name": "John",
"address": {"city": "New York", "zip": "10001"}
}
}
result = resolver.resolve(template, context)
assert result == "User: John from New York"
def test_resolve_array_template(self):
"""Test resolving array element templates."""
resolver = TemplateResolver()
template = "First result: {results[0].label} ({results[0].score})"
context = {
"results": [
{"label": "car", "score": 0.95},
{"label": "truck", "score": 0.87}
]
}
result = resolver.resolve(template, context)
assert result == "First result: car (0.95)"
def test_resolve_missing_variables(self):
"""Test resolving templates with missing variables."""
resolver = TemplateResolver()
template = "Hello {name}, you are {age} years old"
context = {"name": "John"} # Missing age
with pytest.raises(FieldMappingError) as exc_info:
resolver.resolve(template, context)
assert "Variable 'age' not found" in str(exc_info.value)
def test_resolve_with_defaults(self):
"""Test resolving templates with default values."""
resolver = TemplateResolver(allow_missing=True)
template = "Hello {name}, you are {age|25} years old"
context = {"name": "John"} # Missing age, should use default
result = resolver.resolve(template, context)
assert result == "Hello John, you are 25 years old"
def test_resolve_complex_template(self):
"""Test resolving complex templates with multiple variable types."""
resolver = TemplateResolver()
template = "{camera_id}:{timestamp}:{session_id}:{results[0].class}_{bbox[0]}_{bbox[1]}"
context = {
"camera_id": "cam001",
"timestamp": 1640995200000,
"session_id": "sess123",
"results": [{"class": "car", "confidence": 0.95}],
"bbox": [100, 200, 300, 400]
}
result = resolver.resolve(template, context)
assert result == "cam001:1640995200000:sess123:car_100_200"
def test_resolve_conditional_template(self):
"""Test resolving conditional templates."""
resolver = TemplateResolver()
# Simple conditional
template = "{name} is {age > 18 ? 'adult' : 'minor'}"
context_adult = {"name": "John", "age": 25}
result_adult = resolver.resolve(template, context_adult)
assert result_adult == "John is adult"
context_minor = {"name": "Jane", "age": 16}
result_minor = resolver.resolve(template, context_minor)
assert result_minor == "Jane is minor"
def test_escape_braces(self):
"""Test escaping braces in templates."""
resolver = TemplateResolver()
template = "Literal {{braces}} and variable {name}"
context = {"name": "John"}
result = resolver.resolve(template, context)
assert result == "Literal {braces} and variable John"
class TestMappingContext:
"""Test mapping context data structure."""
def test_creation(self):
"""Test mapping context creation."""
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection,
timestamp=1640995200000
)
assert context.camera_id == "camera_001"
assert context.display_id == "display_001"
assert context.session_id == "session_123"
assert context.detection == detection
assert context.timestamp == 1640995200000
assert context.branch_results == {}
assert context.metadata == {}
def test_add_branch_result(self):
"""Test adding branch results to context."""
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
context.add_branch_result("car_brand_cls", {"brand": "Toyota", "confidence": 0.87})
context.add_branch_result("car_bodytype_cls", {"body_type": "Sedan", "confidence": 0.82})
assert len(context.branch_results) == 2
assert context.branch_results["car_brand_cls"]["brand"] == "Toyota"
assert context.branch_results["car_bodytype_cls"]["body_type"] == "Sedan"
def test_to_dict(self):
"""Test converting context to dictionary."""
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection,
timestamp=1640995200000
)
context.add_branch_result("car_brand_cls", {"brand": "Toyota"})
context.add_metadata("model_id", "yolo_v8")
context_dict = context.to_dict()
assert context_dict["camera_id"] == "camera_001"
assert context_dict["display_id"] == "display_001"
assert context_dict["session_id"] == "session_123"
assert context_dict["timestamp"] == 1640995200000
assert context_dict["class"] == "car"
assert context_dict["confidence"] == 0.9
assert context_dict["track_id"] == 1001
assert context_dict["bbox"]["x1"] == 100
assert context_dict["car_brand_cls"]["brand"] == "Toyota"
assert context_dict["model_id"] == "yolo_v8"
def test_add_metadata(self):
"""Test adding metadata to context."""
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
context.add_metadata("model_version", "v2.1")
context.add_metadata("inference_time", 0.15)
assert context.metadata["model_version"] == "v2.1"
assert context.metadata["inference_time"] == 0.15
class TestFieldMapper:
"""Test field mapping functionality."""
def test_initialization(self):
"""Test field mapper initialization."""
mapper = FieldMapper()
assert isinstance(mapper.template_resolver, TemplateResolver)
assert isinstance(mapper.field_accessor, NestedFieldAccessor)
def test_map_fields_simple(self):
"""Test simple field mapping."""
mapper = FieldMapper()
field_mappings = {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence_score": "{confidence}",
"track_identifier": "{track_id}"
}
detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection,
timestamp=1640995200000
)
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["camera_id"] == "camera_001"
assert mapped_fields["detection_class"] == "car"
assert mapped_fields["confidence_score"] == 0.92
assert mapped_fields["track_identifier"] == 1001
def test_map_fields_with_branch_results(self):
"""Test field mapping with branch results."""
mapper = FieldMapper()
field_mappings = {
"car_brand": "{car_brand_cls.brand}",
"car_model": "{car_brand_cls.model}",
"body_type": "{car_bodytype_cls.body_type}",
"brand_confidence": "{car_brand_cls.confidence}",
"combined_info": "{car_brand_cls.brand} {car_bodytype_cls.body_type}"
}
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
context.add_branch_result("car_brand_cls", {
"brand": "Toyota",
"model": "Camry",
"confidence": 0.87
})
context.add_branch_result("car_bodytype_cls", {
"body_type": "Sedan",
"confidence": 0.82
})
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["car_brand"] == "Toyota"
assert mapped_fields["car_model"] == "Camry"
assert mapped_fields["body_type"] == "Sedan"
assert mapped_fields["brand_confidence"] == 0.87
assert mapped_fields["combined_info"] == "Toyota Sedan"
def test_map_fields_bbox_access(self):
"""Test field mapping with bounding box access."""
mapper = FieldMapper()
field_mappings = {
"bbox_x1": "{bbox.x1}",
"bbox_y1": "{bbox.y1}",
"bbox_x2": "{bbox.x2}",
"bbox_y2": "{bbox.y2}",
"bbox_width": "{bbox.width}",
"bbox_height": "{bbox.height}",
"bbox_area": "{bbox.area}",
"bbox_center_x": "{bbox.center_x}",
"bbox_center_y": "{bbox.center_y}"
}
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection
)
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["bbox_x1"] == 100
assert mapped_fields["bbox_y1"] == 200
assert mapped_fields["bbox_x2"] == 300
assert mapped_fields["bbox_y2"] == 400
assert mapped_fields["bbox_width"] == 200 # 300 - 100
assert mapped_fields["bbox_height"] == 200 # 400 - 200
assert mapped_fields["bbox_area"] == 40000 # 200 * 200
assert mapped_fields["bbox_center_x"] == 200 # (100 + 300) / 2
assert mapped_fields["bbox_center_y"] == 300 # (200 + 400) / 2
def test_map_fields_with_sql_functions(self):
"""Test field mapping with SQL function templates."""
mapper = FieldMapper()
field_mappings = {
"created_at": "NOW()",
"updated_at": "CURRENT_TIMESTAMP",
"uuid_field": "UUID()",
"json_data": "JSON_OBJECT('class', '{class}', 'confidence', {confidence})"
}
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection
)
mapped_fields = mapper.map_fields(field_mappings, context)
# SQL functions should pass through unchanged
assert mapped_fields["created_at"] == "NOW()"
assert mapped_fields["updated_at"] == "CURRENT_TIMESTAMP"
assert mapped_fields["uuid_field"] == "UUID()"
assert mapped_fields["json_data"] == "JSON_OBJECT('class', 'car', 'confidence', 0.9)"
def test_map_fields_missing_branch_data(self):
"""Test field mapping with missing branch data."""
mapper = FieldMapper()
field_mappings = {
"car_brand": "{car_brand_cls.brand}",
"car_model": "{nonexistent_branch.model}"
}
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
# Only add one branch result
context.add_branch_result("car_brand_cls", {"brand": "Toyota"})
with pytest.raises(FieldMappingError) as exc_info:
mapper.map_fields(field_mappings, context)
assert "nonexistent_branch.model" in str(exc_info.value)
def test_map_fields_with_defaults(self):
"""Test field mapping with default values."""
mapper = FieldMapper(allow_missing=True)
field_mappings = {
"car_brand": "{car_brand_cls.brand|Unknown}",
"car_model": "{car_brand_cls.model|N/A}",
"confidence": "{confidence|0.0}"
}
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
# Don't add any branch results
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["car_brand"] == "Unknown"
assert mapped_fields["car_model"] == "N/A"
assert mapped_fields["confidence"] == "0.0"
def test_map_database_fields(self):
"""Test mapping fields for database operations."""
mapper = FieldMapper()
# Database field mapping
db_field_mappings = {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_timestamp": "{timestamp}",
"object_class": "{class}",
"detection_confidence": "{confidence}",
"track_id": "{track_id}",
"bbox_json": "JSON_OBJECT('x1', {bbox.x1}, 'y1', {bbox.y1}, 'x2', {bbox.x2}, 'y2', {bbox.y2})",
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"license_plate": "{license_ocr.text}",
"created_at": "NOW()",
"updated_at": "NOW()"
}
detection = DetectionResult("car", 0.93, BoundingBox(150, 250, 350, 450), 2001, 1640995300000)
context = MappingContext(
camera_id="camera_002",
display_id="display_002",
session_id="session_456",
detection=detection,
timestamp=1640995300000
)
# Add branch results
context.add_branch_result("car_brand_cls", {"brand": "Honda", "confidence": 0.89})
context.add_branch_result("car_bodytype_cls", {"body_type": "SUV", "confidence": 0.85})
context.add_branch_result("license_ocr", {"text": "ABC-123", "confidence": 0.76})
mapped_fields = mapper.map_fields(db_field_mappings, context)
assert mapped_fields["camera_id"] == "camera_002"
assert mapped_fields["session_id"] == "session_456"
assert mapped_fields["detection_timestamp"] == 1640995300000
assert mapped_fields["object_class"] == "car"
assert mapped_fields["detection_confidence"] == 0.93
assert mapped_fields["track_id"] == 2001
assert mapped_fields["bbox_json"] == "JSON_OBJECT('x1', 150, 'y1', 250, 'x2', 350, 'y2', 450)"
assert mapped_fields["car_brand"] == "Honda"
assert mapped_fields["car_body_type"] == "SUV"
assert mapped_fields["license_plate"] == "ABC-123"
assert mapped_fields["created_at"] == "NOW()"
assert mapped_fields["updated_at"] == "NOW()"
def test_map_redis_keys(self):
"""Test mapping Redis key templates."""
mapper = FieldMapper()
key_templates = [
"inference:{camera_id}:{timestamp}:{session_id}:car",
"detection:{display_id}:{track_id}",
"cropped_image:{camera_id}:{session_id}:{class}",
"metadata:{session_id}:brands:{car_brand_cls.brand}",
"tracking:{camera_id}:active_tracks"
]
detection = DetectionResult("car", 0.88, BoundingBox(200, 300, 400, 500), 3001, 1640995400000)
context = MappingContext(
camera_id="camera_003",
display_id="display_003",
session_id="session_789",
detection=detection,
timestamp=1640995400000
)
context.add_branch_result("car_brand_cls", {"brand": "Ford"})
mapped_keys = [mapper.map_template(template, context) for template in key_templates]
expected_keys = [
"inference:camera_003:1640995400000:session_789:car",
"detection:display_003:3001",
"cropped_image:camera_003:session_789:car",
"metadata:session_789:brands:Ford",
"tracking:camera_003:active_tracks"
]
assert mapped_keys == expected_keys
def test_map_template(self):
"""Test single template mapping."""
mapper = FieldMapper()
template = "Camera {camera_id} detected {class} with {confidence:.2f} confidence at {timestamp}"
detection = DetectionResult("truck", 0.876, BoundingBox(100, 150, 300, 350), 4001, 1640995500000)
context = MappingContext(
camera_id="camera_004",
display_id="display_004",
session_id="session_101",
detection=detection,
timestamp=1640995500000
)
result = mapper.map_template(template, context)
expected = "Camera camera_004 detected truck with 0.88 confidence at 1640995500000"
assert result == expected
def test_validate_field_mappings(self):
"""Test field mapping validation."""
mapper = FieldMapper()
# Valid mappings
valid_mappings = {
"camera_id": "{camera_id}",
"class": "{class}",
"confidence": "{confidence}",
"created_at": "NOW()"
}
assert mapper.validate_field_mappings(valid_mappings) is True
# Invalid mappings (malformed templates)
invalid_mappings = {
"camera_id": "{camera_id", # Missing closing brace
"class": "class}", # Missing opening brace
"confidence": "{nonexistent_field}" # This might be valid depending on context
}
with pytest.raises(FieldMappingError):
mapper.validate_field_mappings(invalid_mappings)
def test_create_context_from_detection(self):
"""Test creating mapping context from detection result."""
mapper = FieldMapper()
detection = DetectionResult("car", 0.95, BoundingBox(50, 100, 250, 300), 5001, 1640995600000)
context = mapper.create_context_from_detection(
detection,
camera_id="camera_005",
display_id="display_005",
session_id="session_202"
)
assert context.camera_id == "camera_005"
assert context.display_id == "display_005"
assert context.session_id == "session_202"
assert context.detection == detection
assert context.timestamp == 1640995600000
def test_format_sql_value(self):
"""Test SQL value formatting."""
mapper = FieldMapper()
# String values should be quoted
assert mapper.format_sql_value("test_string") == "'test_string'"
assert mapper.format_sql_value("John's car") == "'John''s car'" # Escape quotes
# Numeric values should not be quoted
assert mapper.format_sql_value(42) == "42"
assert mapper.format_sql_value(3.14) == "3.14"
assert mapper.format_sql_value(0.95) == "0.95"
# Boolean values
assert mapper.format_sql_value(True) == "TRUE"
assert mapper.format_sql_value(False) == "FALSE"
# None/NULL values
assert mapper.format_sql_value(None) == "NULL"
# SQL functions should pass through
assert mapper.format_sql_value("NOW()") == "NOW()"
assert mapper.format_sql_value("CURRENT_TIMESTAMP") == "CURRENT_TIMESTAMP"
class TestFieldMapperIntegration:
"""Integration tests for field mapping."""
def test_complete_mapping_workflow(self):
"""Test complete field mapping workflow."""
mapper = FieldMapper()
# Simulate complete detection workflow
detection = DetectionResult("car", 0.91, BoundingBox(120, 180, 320, 380), 6001, 1640995700000)
context = MappingContext(
camera_id="camera_006",
display_id="display_006",
session_id="session_303",
detection=detection,
timestamp=1640995700000
)
# Add comprehensive branch results
context.add_branch_result("car_brand_cls", {
"brand": "BMW",
"model": "X5",
"confidence": 0.84,
"top3_brands": ["BMW", "Audi", "Mercedes"]
})
context.add_branch_result("car_bodytype_cls", {
"body_type": "SUV",
"confidence": 0.79,
"features": ["tall", "4_doors", "roof_rails"]
})
context.add_branch_result("car_color_cls", {
"color": "Black",
"confidence": 0.73,
"rgb_values": [20, 25, 30]
})
context.add_branch_result("license_ocr", {
"text": "XYZ-789",
"confidence": 0.68,
"region_bbox": [150, 320, 290, 360]
})
# Database field mapping
db_mappings = {
"camera_id": "{camera_id}",
"display_id": "{display_id}",
"session_id": "{session_id}",
"detection_timestamp": "{timestamp}",
"object_class": "{class}",
"detection_confidence": "{confidence}",
"track_id": "{track_id}",
"bbox_x1": "{bbox.x1}",
"bbox_y1": "{bbox.y1}",
"bbox_x2": "{bbox.x2}",
"bbox_y2": "{bbox.y2}",
"bbox_area": "{bbox.area}",
"car_brand": "{car_brand_cls.brand}",
"car_model": "{car_brand_cls.model}",
"car_body_type": "{car_bodytype_cls.body_type}",
"car_color": "{car_color_cls.color}",
"license_plate": "{license_ocr.text}",
"brand_confidence": "{car_brand_cls.confidence}",
"bodytype_confidence": "{car_bodytype_cls.confidence}",
"color_confidence": "{car_color_cls.confidence}",
"license_confidence": "{license_ocr.confidence}",
"detection_summary": "{car_brand_cls.brand} {car_bodytype_cls.body_type} ({car_color_cls.color})",
"created_at": "NOW()",
"updated_at": "NOW()"
}
mapped_db_fields = mapper.map_fields(db_mappings, context)
# Verify all mappings
assert mapped_db_fields["camera_id"] == "camera_006"
assert mapped_db_fields["session_id"] == "session_303"
assert mapped_db_fields["object_class"] == "car"
assert mapped_db_fields["detection_confidence"] == 0.91
assert mapped_db_fields["track_id"] == 6001
assert mapped_db_fields["bbox_area"] == 40000 # 200 * 200
assert mapped_db_fields["car_brand"] == "BMW"
assert mapped_db_fields["car_model"] == "X5"
assert mapped_db_fields["car_body_type"] == "SUV"
assert mapped_db_fields["car_color"] == "Black"
assert mapped_db_fields["license_plate"] == "XYZ-789"
assert mapped_db_fields["detection_summary"] == "BMW SUV (Black)"
# Redis key mapping
redis_key_templates = [
"detection:{camera_id}:{session_id}:main",
"cropped:{camera_id}:{session_id}:car_image",
"metadata:{session_id}:brand:{car_brand_cls.brand}",
"tracking:{camera_id}:track_{track_id}",
"classification:{session_id}:results"
]
mapped_redis_keys = [
mapper.map_template(template, context)
for template in redis_key_templates
]
expected_redis_keys = [
"detection:camera_006:session_303:main",
"cropped:camera_006:session_303:car_image",
"metadata:session_303:brand:BMW",
"tracking:camera_006:track_6001",
"classification:session_303:results"
]
assert mapped_redis_keys == expected_redis_keys
def test_error_handling_and_recovery(self):
"""Test error handling and recovery in field mapping."""
mapper = FieldMapper(allow_missing=True)
# Context with missing detection
context = MappingContext(
camera_id="camera_007",
display_id="display_007",
session_id="session_404"
)
# Partial branch results
context.add_branch_result("car_brand_cls", {"brand": "Unknown"})
# Missing car_bodytype_cls branch
# Field mappings with some missing data
mappings = {
"camera_id": "{camera_id}",
"detection_class": "{class|Unknown}",
"confidence": "{confidence|0.0}",
"car_brand": "{car_brand_cls.brand|N/A}",
"car_body_type": "{car_bodytype_cls.body_type|Unknown}",
"car_model": "{car_brand_cls.model|N/A}"
}
mapped_fields = mapper.map_fields(mappings, context)
assert mapped_fields["camera_id"] == "camera_007"
assert mapped_fields["detection_class"] == "Unknown"
assert mapped_fields["confidence"] == "0.0"
assert mapped_fields["car_brand"] == "Unknown"
assert mapped_fields["car_body_type"] == "Unknown"
assert mapped_fields["car_model"] == "N/A"

View file

@ -0,0 +1,921 @@
"""
Unit tests for pipeline execution functionality.
"""
import pytest
import asyncio
import numpy as np
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from concurrent.futures import ThreadPoolExecutor
import json
from detector_worker.pipeline.pipeline_executor import (
PipelineExecutor,
PipelineContext,
PipelineResult,
BranchResult,
ExecutionMode
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import PipelineError, ModelError, ActionError
class TestPipelineContext:
"""Test pipeline context data structure."""
def test_creation(self):
"""Test pipeline context creation."""
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
)
assert context.camera_id == "camera_001"
assert context.display_id == "display_001"
assert context.session_id == "session_123"
assert context.timestamp == 1640995200000
assert context.frame_data.shape == (480, 640, 3)
assert context.metadata == {}
assert context.crop_region is None
def test_creation_with_crop_region(self):
"""Test context creation with crop region."""
crop_region = (100, 200, 300, 400)
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8),
crop_region=crop_region
)
assert context.crop_region == crop_region
def test_add_metadata(self):
"""Test adding metadata to context."""
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
)
context.add_metadata("model_id", "yolo_v8")
context.add_metadata("confidence_threshold", 0.8)
assert context.metadata["model_id"] == "yolo_v8"
assert context.metadata["confidence_threshold"] == 0.8
def test_get_cropped_frame(self):
"""Test getting cropped frame."""
frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=frame,
crop_region=(100, 200, 300, 400)
)
cropped = context.get_cropped_frame()
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
assert np.all(cropped == 255)
def test_get_cropped_frame_no_crop(self):
"""Test getting frame when no crop region specified."""
frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=frame
)
cropped = context.get_cropped_frame()
assert np.array_equal(cropped, frame)
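# --- Illustrative sketch (not part of the test suite) ---
# get_cropped_frame as asserted above: crop_region is (x1, y1, x2, y2), rows
# are indexed by y and columns by x, and the full frame is returned when no
# crop region is set.
import numpy as np

def get_cropped_frame_sketch(frame_data: np.ndarray, crop_region=None) -> np.ndarray:
    if crop_region is None:
        return frame_data
    x1, y1, x2, y2 = crop_region
    return frame_data[y1:y2, x1:x2]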
class TestBranchResult:
"""Test branch execution result."""
def test_creation_success(self):
"""Test successful branch result creation."""
detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
result = BranchResult(
branch_id="car_brand_cls",
success=True,
detections=detections,
metadata={"brand": "Toyota"},
execution_time=0.15
)
assert result.branch_id == "car_brand_cls"
assert result.success is True
assert len(result.detections) == 1
assert result.metadata["brand"] == "Toyota"
assert result.execution_time == 0.15
assert result.error is None
def test_creation_failure(self):
"""Test failed branch result creation."""
result = BranchResult(
branch_id="car_brand_cls",
success=False,
error="Model inference failed",
execution_time=0.05
)
assert result.branch_id == "car_brand_cls"
assert result.success is False
assert result.detections == []
assert result.metadata == {}
assert result.error == "Model inference failed"
class TestPipelineResult:
"""Test pipeline execution result."""
def test_creation_success(self):
"""Test successful pipeline result creation."""
main_detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
}
result = PipelineResult(
success=True,
detections=main_detections,
branch_results=branch_results,
total_execution_time=0.5
)
assert result.success is True
assert len(result.detections) == 1
assert len(result.branch_results) == 2
assert result.total_execution_time == 0.5
assert result.error is None
def test_creation_failure(self):
"""Test failed pipeline result creation."""
result = PipelineResult(
success=False,
error="Pipeline execution failed",
total_execution_time=0.1
)
assert result.success is False
assert result.detections == []
assert result.branch_results == {}
assert result.error == "Pipeline execution failed"
def test_get_combined_results(self):
"""Test getting combined results from all branches."""
main_detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
}
result = PipelineResult(
success=True,
detections=main_detections,
branch_results=branch_results,
total_execution_time=0.5
)
combined = result.get_combined_results()
assert "brand" in combined
assert "body_type" in combined
assert combined["brand"] == "Toyota"
assert combined["body_type"] == "Sedan"
class TestPipelineExecutor:
"""Test pipeline execution functionality."""
def test_initialization(self):
"""Test pipeline executor initialization."""
executor = PipelineExecutor()
assert isinstance(executor.thread_pool, ThreadPoolExecutor)
assert executor.max_workers == 4
assert executor.execution_mode == ExecutionMode.PARALLEL
assert executor.timeout == 30.0
def test_initialization_custom_config(self):
"""Test initialization with custom configuration."""
config = {
"max_workers": 8,
"execution_mode": "sequential",
"timeout": 60.0
}
executor = PipelineExecutor(config)
assert executor.max_workers == 8
assert executor.execution_mode == ExecutionMode.SEQUENTIAL
assert executor.timeout == 60.0
@pytest.mark.asyncio
async def test_execute_pipeline_simple(self, mock_yolo_model, mock_frame):
"""Test simple pipeline execution."""
# Mock pipeline configuration
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
# Mock detection result (torch is imported here because the module-level imports above do not include it)
import torch
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert len(result.detections) == 1
assert result.detections[0].class_name == "0" # Default class name
assert result.detections[0].confidence == 0.9
@pytest.mark.asyncio
async def test_execute_pipeline_with_branches(self, mock_yolo_model, mock_frame):
"""Test pipeline execution with classification branches."""
import torch
# Mock main detection
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0] # car detection
])
mock_detection_result.boxes.id = torch.tensor([1001])
# Mock classification results
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 2 # Toyota
mock_brand_result.probs.top1conf = 0.85
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 1 # Sedan
mock_bodytype_result.probs.top1conf = 0.78
mock_yolo_model.track.return_value = [mock_detection_result]
mock_yolo_model.predict.return_value = [mock_brand_result]
mock_brand_model = Mock()
mock_brand_model.predict.return_value = [mock_brand_result]
mock_brand_model.names = {0: "Honda", 1: "Ford", 2: "Toyota"}
mock_bodytype_model = Mock()
mock_bodytype_model.predict.return_value = [mock_bodytype_result]
mock_bodytype_model.names = {0: "SUV", 1: "Sedan", 2: "Hatchback"}
# Pipeline configuration with branches
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [
{
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
},
{
"modelId": "car_bodytype_cls",
"modelFile": "car_bodytype.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
}
],
"actions": []
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
def get_model_side_effect(model_id, camera_id):
if model_id == "car_detection_v1":
return mock_yolo_model
elif model_id == "car_brand_cls":
return mock_brand_model
elif model_id == "car_bodytype_cls":
return mock_bodytype_model
return None
mock_model_manager.return_value.get_model.side_effect = get_model_side_effect
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert len(result.detections) == 1
assert len(result.branch_results) == 2
# Check branch results
assert "car_brand_cls" in result.branch_results
assert "car_bodytype_cls" in result.branch_results
brand_result = result.branch_results["car_brand_cls"]
assert brand_result.success is True
assert brand_result.metadata.get("brand") == "Toyota"
bodytype_result = result.branch_results["car_bodytype_cls"]
assert bodytype_result.success is True
assert bodytype_result.metadata.get("body_type") == "Sedan"
@pytest.mark.asyncio
async def test_execute_pipeline_sequential_mode(self, mock_yolo_model, mock_frame):
"""Test pipeline execution in sequential mode."""
import torch
config = {"execution_mode": "sequential"}
executor = PipelineExecutor(config)
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [
{
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": False # Sequential execution
}
],
"actions": []
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert executor.execution_mode == ExecutionMode.SEQUENTIAL
@pytest.mark.asyncio
async def test_execute_pipeline_with_actions(self, mock_yolo_model, mock_frame):
"""Test pipeline execution with actions."""
import torch
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
# Pipeline configuration with actions
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}",
"expire_seconds": 600
},
{
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}"
}
}
]
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager, \
patch('detector_worker.pipeline.action_executor.ActionExecutor') as mock_action_executor:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
mock_action_executor.return_value.execute_actions = AsyncMock(return_value=True)
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
# Actions should be executed
mock_action_executor.return_value.execute_actions.assert_called_once()
@pytest.mark.asyncio
async def test_execute_pipeline_model_error(self, mock_frame):
"""Test pipeline execution with model error."""
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
# Model manager raises error
mock_model_manager.return_value.get_model.side_effect = ModelError("Model not found")
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is False
assert "Model not found" in result.error
@pytest.mark.asyncio
async def test_execute_pipeline_timeout(self, mock_yolo_model, mock_frame):
"""Test pipeline execution timeout."""
import torch
# Configure short timeout
config = {"timeout": 0.001} # Very short timeout
executor = PipelineExecutor(config)
# Mock slow model inference
def slow_inference(*args, **kwargs):
import time
time.sleep(1) # Longer than timeout
mock_result = Mock()
mock_result.boxes = None
return [mock_result]
mock_yolo_model.track.side_effect = slow_inference
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is False
assert "timeout" in result.error.lower()
@pytest.mark.asyncio
async def test_execute_branch_parallel(self, mock_frame):
"""Test parallel branch execution."""
import torch
# Mock classification model
mock_brand_model = Mock()
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
mock_brand_model.predict.return_value = [mock_result]
mock_brand_model.names = {0: "Honda", 1: "Toyota", 2: "Ford"}
executor = PipelineExecutor()
# Branch configuration
branch_config = {
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
}
# Mock detected regions
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_brand_model
result = await executor._execute_branch(branch_config, regions, context)
assert result.success is True
assert result.branch_id == "car_brand_cls"
assert result.metadata.get("brand") == "Toyota"
assert result.execution_time > 0
@pytest.mark.asyncio
async def test_execute_branch_no_trigger_class(self, mock_frame):
"""Test branch execution when trigger class not detected."""
executor = PipelineExecutor()
branch_config = {
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7
}
# No car detected
regions = {
"truck": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("truck", 0.9, BoundingBox(100, 200, 300, 400), 1002)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
result = await executor._execute_branch(branch_config, regions, context)
assert result.success is False
assert "trigger class not detected" in result.error.lower()
def test_wait_for_branches(self):
"""Test waiting for specific branches to complete."""
executor = PipelineExecutor()
# Mock completed branch results
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12),
"license_ocr": BranchResult("license_ocr", True, [], {"license": "ABC123"}, 0.2)
}
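# _wait_for_branches should return True once every named branch has a result, and False if any is still missing when the timeout elapses.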
# Wait for specific branches
wait_for = ["car_brand_cls", "car_bodytype_cls"]
completed = executor._wait_for_branches(branch_results, wait_for, timeout=1.0)
assert completed is True
# Wait for non-existent branch (should timeout)
wait_for_missing = ["car_brand_cls", "nonexistent_branch"]
completed = executor._wait_for_branches(branch_results, wait_for_missing, timeout=0.1)
assert completed is False
def test_validate_pipeline_config(self):
"""Test pipeline configuration validation."""
executor = PipelineExecutor()
# Valid configuration
valid_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8
}
assert executor._validate_pipeline_config(valid_config) is True
# Invalid configuration (missing required fields)
invalid_config = {
"modelFile": "car_detection.pt"
# Missing modelId
}
with pytest.raises(PipelineError):
executor._validate_pipeline_config(invalid_config)
def test_crop_frame_for_detection(self, mock_frame):
"""Test frame cropping for detection."""
executor = PipelineExecutor()
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
cropped = executor._crop_frame_for_detection(mock_frame, detection)
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
def test_crop_frame_invalid_bounds(self, mock_frame):
"""Test frame cropping with invalid bounds."""
executor = PipelineExecutor()
# Detection outside frame bounds
detection = DetectionResult("car", 0.9, BoundingBox(-100, -200, 50, 100), 1001)
cropped = executor._crop_frame_for_detection(mock_frame, detection)
# Out-of-frame coordinates are expected to be clamped to the frame, yielding a small but non-empty crop
assert cropped.shape[0] > 0
assert cropped.shape[1] > 0
class TestPipelineExecutorPerformance:
"""Test pipeline executor performance and optimization."""
@pytest.mark.asyncio
async def test_parallel_branch_execution_performance(self, mock_frame):
"""Test that parallel execution is faster than sequential."""
import time
def slow_inference(*args, **kwargs):
time.sleep(0.1) # Simulate slow inference
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
return [mock_result]
mock_model = Mock()
mock_model.predict.side_effect = slow_inference
mock_model.names = {0: "Class0", 1: "Class1"}
# Test parallel execution
parallel_executor = PipelineExecutor({"execution_mode": "parallel", "max_workers": 2})
branch_configs = [
{
"modelId": f"branch_{i}",
"modelFile": f"branch_{i}.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
}
for i in range(3) # 3 branches
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_model
start_time = time.time()
# Execute branches in parallel
tasks = [
parallel_executor._execute_branch(config, regions, context)
for config in branch_configs
]
results = await asyncio.gather(*tasks)
parallel_time = time.time() - start_time
# Sequential execution of three 0.1 s inferences would take at least 0.3 s; parallel execution should finish well under that
assert parallel_time < 0.25 # Allow some overhead
assert len(results) == 3
assert all(result.success for result in results)
def test_thread_pool_management(self):
"""Test thread pool creation and management."""
# Test different worker counts
for workers in [1, 2, 4, 8]:
executor = PipelineExecutor({"max_workers": workers})
assert executor.max_workers == workers
assert executor.thread_pool._max_workers == workers
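# Note: _max_workers is a private ThreadPoolExecutor attribute, so this check depends on CPython's current implementation.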
def test_memory_management_large_frames(self):
"""Test memory management with large frames."""
executor = PipelineExecutor()
# Create large frame
large_frame = np.ones((1080, 1920, 3), dtype=np.uint8) * 128
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=large_frame,
crop_region=(500, 400, 1000, 800)
)
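# crop_region is assumed to be (x1, y1, x2, y2), so the crop below should be 500 px wide and 400 px tall.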
# Get cropped frame
cropped = context.get_cropped_frame()
# Should reduce memory usage
assert cropped.shape == (400, 500, 3) # Much smaller than original
assert cropped.nbytes < large_frame.nbytes
class TestPipelineExecutorErrorHandling:
"""Test comprehensive error handling."""
@pytest.mark.asyncio
async def test_branch_execution_error_isolation(self, mock_frame):
"""Test that errors in one branch don't affect others."""
executor = PipelineExecutor()
# Mock models - one fails, one succeeds
failing_model = Mock()
failing_model.predict.side_effect = Exception("Model crashed")
success_model = Mock()
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
success_model.predict.return_value = [mock_result]
success_model.names = {0: "Class0", 1: "Class1"}
branch_configs = [
{
"modelId": "failing_branch",
"modelFile": "failing.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
},
{
"modelId": "success_branch",
"modelFile": "success.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
def get_model_side_effect(model_id, camera_id):
if model_id == "failing_branch":
return failing_model
elif model_id == "success_branch":
return success_model
return None
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.side_effect = get_model_side_effect
# Execute branches
tasks = [
executor._execute_branch(config, regions, context)
for config in branch_configs
]
results = await asyncio.gather(*tasks, return_exceptions=True)
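# return_exceptions=True keeps one branch's failure from cancelling the gather; the executor itself is expected to wrap model errors in a failed BranchResult rather than letting them propagate.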
# One should fail, one should succeed
failing_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "failing_branch")
success_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "success_branch")
assert failing_result.success is False
assert "Model crashed" in failing_result.error
assert success_result.success is True
assert success_result.error is None


@ -0,0 +1,976 @@
"""
Unit tests for database management functionality.
"""
import pytest
import asyncio
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
import psycopg2
import uuid
from detector_worker.storage.database_manager import (
DatabaseManager,
DatabaseConfig,
DatabaseConnection,
QueryBuilder,
TransactionManager,
DatabaseError,
ConnectionPoolError
)
from detector_worker.core.exceptions import ConfigurationError
class TestDatabaseConfig:
"""Test database configuration."""
def test_creation_minimal(self):
"""Test creating database config with minimal parameters."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
assert config.host == "localhost"
assert config.port == 5432 # Default port
assert config.database == "test_db"
assert config.username == "test_user"
assert config.password == "test_pass"
assert config.schema == "public" # Default schema
assert config.enabled is True
def test_creation_full(self):
"""Test creating database config with all parameters."""
config = DatabaseConfig(
host="db.example.com",
port=5433,
database="production_db",
username="prod_user",
password="secure_pass",
schema="gas_station_1",
enabled=True,
pool_min_conn=2,
pool_max_conn=20,
pool_timeout=30.0,
connection_timeout=10.0,
ssl_mode="require"
)
assert config.host == "db.example.com"
assert config.port == 5433
assert config.database == "production_db"
assert config.schema == "gas_station_1"
assert config.pool_min_conn == 2
assert config.pool_max_conn == 20
assert config.ssl_mode == "require"
def test_get_connection_string(self):
"""Test generating connection string."""
config = DatabaseConfig(
host="localhost",
port=5432,
database="test_db",
username="test_user",
password="test_pass"
)
conn_string = config.get_connection_string()
expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass"
assert conn_string == expected
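# Note: libpq connection strings use dbname= rather than database=; the manager presumably passes these values to psycopg2.connect() as keyword arguments instead.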
def test_get_connection_string_with_ssl(self):
"""Test generating connection string with SSL."""
config = DatabaseConfig(
host="db.example.com",
database="secure_db",
username="user",
password="pass",
ssl_mode="require"
)
conn_string = config.get_connection_string()
assert "sslmode=require" in conn_string
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"host": "test-host",
"port": 5433,
"database": "test-db",
"username": "test-user",
"password": "test-pass",
"schema": "test_schema",
"pool_max_conn": 15,
"unknown_field": "ignored"
}
config = DatabaseConfig.from_dict(config_dict)
assert config.host == "test-host"
assert config.port == 5433
assert config.database == "test-db"
assert config.schema == "test_schema"
assert config.pool_max_conn == 15
class TestQueryBuilder:
"""Test SQL query building functionality."""
def test_build_select_query(self):
"""Test building SELECT queries."""
builder = QueryBuilder("test_schema")
query, params = builder.build_select_query(
table="users",
columns=["id", "name", "email"],
where={"status": "active", "age": 25},
order_by="created_at DESC",
limit=10
)
expected_query = (
"SELECT id, name, email FROM test_schema.users "
"WHERE status = %s AND age = %s "
"ORDER BY created_at DESC LIMIT 10"
)
assert query == expected_query
assert params == ["active", 25]
def test_build_select_all_columns(self):
"""Test building SELECT * query."""
builder = QueryBuilder("public")
query, params = builder.build_select_query("products")
expected_query = "SELECT * FROM public.products"
assert query == expected_query
assert params == []
def test_build_insert_query(self):
"""Test building INSERT queries."""
builder = QueryBuilder("inventory")
data = {
"product_name": "Widget",
"price": 19.99,
"quantity": 100,
"created_at": "NOW()"
}
query, params = builder.build_insert_query("products", data)
expected_query = (
"INSERT INTO inventory.products (product_name, price, quantity, created_at) "
"VALUES (%s, %s, %s, NOW()) RETURNING id"
)
assert query == expected_query
assert params == ["Widget", 19.99, 100]
def test_build_update_query(self):
"""Test building UPDATE queries."""
builder = QueryBuilder("sales")
data = {
"status": "shipped",
"shipped_date": "NOW()",
"tracking_number": "ABC123"
}
where_conditions = {"order_id": 12345}
query, params = builder.build_update_query("orders", data, where_conditions)
expected_query = (
"UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s "
"WHERE order_id = %s"
)
assert query == expected_query
assert params == ["shipped", "ABC123", 12345]
def test_build_delete_query(self):
"""Test building DELETE queries."""
builder = QueryBuilder("logs")
where_conditions = {
"level": "DEBUG",
"created_at": "< NOW() - INTERVAL '7 days'"
}
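# Values that begin with a comparison operator are assumed to be inlined verbatim by the builder, so only "DEBUG" is bound as a parameter below.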
query, params = builder.build_delete_query("application_logs", where_conditions)
expected_query = (
"DELETE FROM logs.application_logs "
"WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'"
)
assert query == expected_query
assert params == ["DEBUG"]
def test_build_create_table_query(self):
"""Test building CREATE TABLE queries."""
builder = QueryBuilder("gas_station_1")
columns = {
"id": "SERIAL PRIMARY KEY",
"session_id": "VARCHAR(255) UNIQUE NOT NULL",
"camera_id": "VARCHAR(255) NOT NULL",
"detection_class": "VARCHAR(100)",
"confidence": "DECIMAL(4,3)",
"bbox_data": "JSON",
"created_at": "TIMESTAMP DEFAULT NOW()",
"updated_at": "TIMESTAMP DEFAULT NOW()"
}
query = builder.build_create_table_query("detections", columns)
expected_parts = [
"CREATE TABLE IF NOT EXISTS gas_station_1.detections",
"id SERIAL PRIMARY KEY",
"session_id VARCHAR(255) UNIQUE NOT NULL",
"camera_id VARCHAR(255) NOT NULL",
"bbox_data JSON",
"created_at TIMESTAMP DEFAULT NOW()"
]
for part in expected_parts:
assert part in query
def test_escape_identifier(self):
"""Test SQL identifier escaping."""
builder = QueryBuilder("test")
assert builder.escape_identifier("table") == '"table"'
assert builder.escape_identifier("column_name") == '"column_name"'
assert builder.escape_identifier("user-table") == '"user-table"'
def test_format_value_for_sql(self):
"""Test SQL value formatting."""
builder = QueryBuilder("test")
# Regular values should use placeholder
assert builder.format_value_for_sql("string") == ("%s", "string")
assert builder.format_value_for_sql(42) == ("%s", 42)
assert builder.format_value_for_sql(3.14) == ("%s", 3.14)
# SQL functions should be literal
assert builder.format_value_for_sql("NOW()") == ("NOW()", None)
assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None)
assert builder.format_value_for_sql("UUID()") == ("UUID()", None)
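# The builder presumably keeps a small whitelist of SQL function expressions; everything else is bound as a parameter to avoid injection.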
class TestDatabaseConnection:
"""Test database connection management."""
def test_creation(self, mock_database_connection):
"""Test connection creation."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
assert conn.config == config
assert conn.connection == mock_database_connection
assert conn.is_connected is True
def test_execute_query(self, mock_database_connection):
"""Test query execution."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchall.return_value = [
(1, "John", "john@example.com"),
(2, "Jane", "jane@example.com")
]
mock_cursor.rowcount = 2
conn = DatabaseConnection(config, mock_database_connection)
query = "SELECT id, name, email FROM users WHERE status = %s"
params = ["active"]
result = conn.execute_query(query, params)
assert result == [
(1, "John", "john@example.com"),
(2, "Jane", "jane@example.com")
]
mock_cursor.execute.assert_called_once_with(query, params)
mock_cursor.fetchall.assert_called_once()
def test_execute_query_single_result(self, mock_database_connection):
"""Test query execution with single result."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchone.return_value = (1, "John", "john@example.com")
conn = DatabaseConnection(config, mock_database_connection)
result = conn.execute_query("SELECT * FROM users WHERE id = %s", [1], fetch_one=True)
assert result == (1, "John", "john@example.com")
mock_cursor.fetchone.assert_called_once()
def test_execute_query_no_fetch(self, mock_database_connection):
"""Test query execution without fetching results."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 1
conn = DatabaseConnection(config, mock_database_connection)
result = conn.execute_query(
"INSERT INTO users (name) VALUES (%s)",
["John"],
fetch_results=False
)
assert result == 1 # Row count
mock_cursor.execute.assert_called_once()
mock_cursor.fetchall.assert_not_called()
mock_cursor.fetchone.assert_not_called()
def test_execute_query_error(self, mock_database_connection):
"""Test query execution error handling."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.execute.side_effect = psycopg2.Error("Database error")
conn = DatabaseConnection(config, mock_database_connection)
with pytest.raises(DatabaseError) as exc_info:
conn.execute_query("SELECT * FROM invalid_table")
assert "Database error" in str(exc_info.value)
def test_commit_transaction(self, mock_database_connection):
"""Test transaction commit."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
conn.commit()
mock_database_connection.commit.assert_called_once()
def test_rollback_transaction(self, mock_database_connection):
"""Test transaction rollback."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
conn.rollback()
mock_database_connection.rollback.assert_called_once()
def test_close_connection(self, mock_database_connection):
"""Test connection closing."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
conn.close()
assert conn.is_connected is False
mock_database_connection.close.assert_called_once()
class TestTransactionManager:
"""Test transaction management."""
def test_transaction_context_success(self, mock_database_connection):
"""Test successful transaction context."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
tx_manager = TransactionManager(conn)
with tx_manager:
# Simulate some database operations
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
# Should commit on successful exit
mock_database_connection.commit.assert_called_once()
mock_database_connection.rollback.assert_not_called()
def test_transaction_context_error(self, mock_database_connection):
"""Test transaction context with error."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
tx_manager = TransactionManager(conn)
with pytest.raises(DatabaseError):
with tx_manager:
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
# Simulate an error
raise DatabaseError("Something went wrong")
# Should rollback on error
mock_database_connection.rollback.assert_called_once()
mock_database_connection.commit.assert_not_called()
class TestDatabaseManager:
"""Test main database manager functionality."""
def test_initialization(self):
"""Test database manager initialization."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
assert manager.config == config
assert isinstance(manager.query_builder, QueryBuilder)
assert manager.query_builder.schema == "gas_station_1"
assert manager.connection is None
@pytest.mark.asyncio
async def test_connect_success(self):
"""Test successful database connection."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with patch('psycopg2.connect') as mock_connect:
mock_connection = Mock()
mock_connect.return_value = mock_connection
await manager.connect()
assert manager.connection is not None
assert manager.is_connected is True
mock_connect.assert_called_once()
@pytest.mark.asyncio
async def test_connect_failure(self):
"""Test database connection failure."""
config = DatabaseConfig(
host="nonexistent-host",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with patch('psycopg2.connect') as mock_connect:
mock_connect.side_effect = psycopg2.Error("Connection failed")
with pytest.raises(DatabaseError) as exc_info:
await manager.connect()
assert "Connection failed" in str(exc_info.value)
assert manager.is_connected is False
@pytest.mark.asyncio
async def test_disconnect(self):
"""Test database disconnection."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
# Mock connection
mock_connection = Mock()
manager.connection = DatabaseConnection(config, mock_connection)
await manager.disconnect()
assert manager.connection is None
mock_connection.close.assert_called_once()
@pytest.mark.asyncio
async def test_execute_query(self, mock_database_connection):
"""Test query execution through manager."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")]
result = await manager.execute_query("SELECT * FROM test_table")
assert result == [(1, "Test"), (2, "Data")]
mock_cursor.execute.assert_called_once()
@pytest.mark.asyncio
async def test_execute_query_not_connected(self):
"""Test query execution when not connected."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with pytest.raises(DatabaseError) as exc_info:
await manager.execute_query("SELECT * FROM test_table")
assert "not connected" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_insert_record(self, mock_database_connection):
"""Test inserting a record."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchone.return_value = (123,) # Returned ID
data = {
"session_id": "session_123",
"camera_id": "camera_001",
"detection_class": "car",
"confidence": 0.95,
"created_at": "NOW()"
}
record_id = await manager.insert_record("car_detections", data)
assert record_id == 123
mock_cursor.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_record(self, mock_database_connection):
"""Test updating a record."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 1
data = {
"car_brand": "Toyota",
"car_body_type": "Sedan",
"updated_at": "NOW()"
}
where_conditions = {"session_id": "session_123"}
rows_affected = await manager.update_record("car_info", data, where_conditions)
assert rows_affected == 1
mock_cursor.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_records(self, mock_database_connection):
"""Test deleting records."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 3
where_conditions = {
"created_at": "< NOW() - INTERVAL '30 days'",
"processed": True
}
rows_deleted = await manager.delete_records("old_detections", where_conditions)
assert rows_deleted == 3
mock_cursor.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_create_table(self, mock_database_connection):
"""Test creating a table."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
columns = {
"id": "SERIAL PRIMARY KEY",
"session_id": "VARCHAR(255) UNIQUE NOT NULL",
"camera_id": "VARCHAR(255) NOT NULL",
"detection_data": "JSON",
"created_at": "TIMESTAMP DEFAULT NOW()"
}
await manager.create_table("test_detections", columns)
mock_database_connection.cursor.return_value.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_table_exists(self, mock_database_connection):
"""Test checking if table exists."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior - table exists
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchone.return_value = (1,)
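# table_exists presumably queries information_schema.tables for the configured schema; a single returned row means the table exists.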
exists = await manager.table_exists("car_detections")
assert exists is True
mock_cursor.execute.assert_called_once()
# Mock cursor behavior - table doesn't exist
mock_cursor.fetchone.return_value = None
exists = await manager.table_exists("nonexistent_table")
assert exists is False
@pytest.mark.asyncio
async def test_transaction_context(self, mock_database_connection):
"""Test transaction context manager."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
async with manager.transaction():
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
# Should commit on successful completion
mock_database_connection.commit.assert_called()
@pytest.mark.asyncio
async def test_get_table_schema(self, mock_database_connection):
"""Test getting table schema information."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchall.return_value = [
("id", "integer", "NOT NULL"),
("session_id", "character varying", "NOT NULL"),
("created_at", "timestamp without time zone", "DEFAULT now()")
]
schema = await manager.get_table_schema("car_detections")
assert len(schema) == 3
assert schema[0] == ("id", "integer", "NOT NULL")
assert schema[1] == ("session_id", "character varying", "NOT NULL")
@pytest.mark.asyncio
async def test_bulk_insert(self, mock_database_connection):
"""Test bulk insert operation."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
records = [
{"name": "John", "email": "john@example.com"},
{"name": "Jane", "email": "jane@example.com"},
{"name": "Bob", "email": "bob@example.com"}
]
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 3
rows_inserted = await manager.bulk_insert("users", records)
assert rows_inserted == 3
mock_cursor.executemany.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_get_connection_stats(self, mock_database_connection):
"""Test getting connection statistics."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
stats = manager.get_connection_stats()
assert "connected" in stats
assert "host" in stats
assert "database" in stats
assert "schema" in stats
assert stats["connected"] is True
assert stats["host"] == "localhost"
assert stats["database"] == "test_db"
class TestDatabaseManagerIntegration:
"""Integration tests for database manager."""
@pytest.mark.asyncio
async def test_complete_car_detection_workflow(self, mock_database_connection):
"""Test complete car detection database workflow."""
config = DatabaseConfig(
host="localhost",
database="gas_station_db",
username="detector_user",
password="detector_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behaviors for different operations
mock_cursor = mock_database_connection.cursor.return_value
# 1. Create initial detection record
mock_cursor.fetchone.return_value = (456,) # Returned ID
detection_data = {
"session_id": str(uuid.uuid4()),
"camera_id": "camera_001",
"display_id": "display_001",
"detection_class": "car",
"confidence": 0.92,
"bbox_x1": 100,
"bbox_y1": 200,
"bbox_x2": 300,
"bbox_y2": 400,
"track_id": 1001,
"created_at": "NOW()"
}
detection_id = await manager.insert_record("car_detections", detection_data)
assert detection_id == 456
# 2. Update with classification results
mock_cursor.rowcount = 1
classification_data = {
"car_brand": "Toyota",
"car_model": "Camry",
"car_body_type": "Sedan",
"car_color": "Blue",
"brand_confidence": 0.87,
"bodytype_confidence": 0.82,
"color_confidence": 0.79,
"updated_at": "NOW()"
}
where_conditions = {"session_id": detection_data["session_id"]}
rows_updated = await manager.update_record("car_detections", classification_data, where_conditions)
assert rows_updated == 1
# 3. Query final results
mock_cursor.fetchall.return_value = [
(456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan")
]
results = await manager.execute_query(
"SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type "
"FROM gas_station_1.car_detections WHERE session_id = %s",
[detection_data["session_id"]]
)
assert len(results) == 1
assert results[0][0] == 456 # ID
assert results[0][3] == "car" # detection_class
assert results[0][5] == "Toyota" # car_brand
# Verify all three statements executed; only the insert and update commit (the SELECT does not)
assert mock_cursor.execute.call_count == 3
assert mock_database_connection.commit.call_count == 2
@pytest.mark.asyncio
async def test_error_handling_and_recovery(self, mock_database_connection):
"""Test error handling and recovery scenarios."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Test transaction rollback on error
mock_cursor = mock_database_connection.cursor.return_value
with pytest.raises(DatabaseError):
async with manager.transaction():
# First operation succeeds
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
# Second operation fails
mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation")
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
# Should have rolled back
mock_database_connection.rollback.assert_called_once()
mock_database_connection.commit.assert_not_called()
@pytest.mark.asyncio
async def test_connection_recovery(self):
"""Test automatic connection recovery."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with patch('psycopg2.connect') as mock_connect:
# First connection attempt fails
mock_connect.side_effect = [
psycopg2.Error("Connection refused"),
Mock() # Second attempt succeeds
]
# First attempt should fail
with pytest.raises(DatabaseError):
await manager.connect()
# Second attempt should succeed
await manager.connect()
assert manager.is_connected is True


@ -0,0 +1,964 @@
"""
Unit tests for Redis client functionality.
"""
import pytest
import asyncio
import json
import base64
import time
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
import redis
import numpy as np
from detector_worker.storage.redis_client import (
RedisClient,
RedisConfig,
RedisConnectionPool,
RedisPublisher,
RedisSubscriber,
RedisImageStorage,
RedisError,
ConnectionPoolError
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import ConfigurationError
class TestRedisConfig:
"""Test Redis configuration."""
def test_creation_minimal(self):
"""Test creating Redis config with minimal parameters."""
config = RedisConfig(
host="localhost"
)
assert config.host == "localhost"
assert config.port == 6379 # Default port
assert config.password is None
assert config.db == 0 # Default database
assert config.enabled is True
def test_creation_full(self):
"""Test creating Redis config with all parameters."""
config = RedisConfig(
host="redis.example.com",
port=6380,
password="secure_pass",
db=2,
enabled=True,
connection_timeout=5.0,
socket_timeout=3.0,
socket_connect_timeout=2.0,
max_connections=50,
retry_on_timeout=True,
health_check_interval=30
)
assert config.host == "redis.example.com"
assert config.port == 6380
assert config.password == "secure_pass"
assert config.db == 2
assert config.connection_timeout == 5.0
assert config.max_connections == 50
assert config.retry_on_timeout is True
def test_get_connection_params(self):
"""Test getting Redis connection parameters."""
config = RedisConfig(
host="localhost",
port=6379,
password="test_pass",
db=1,
connection_timeout=10.0
)
params = config.get_connection_params()
assert params["host"] == "localhost"
assert params["port"] == 6379
assert params["password"] == "test_pass"
assert params["db"] == 1
assert params["socket_timeout"] == 10.0
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"host": "redis-server",
"port": 6380,
"password": "secret",
"db": 3,
"max_connections": 100,
"unknown_field": "ignored"
}
config = RedisConfig.from_dict(config_dict)
assert config.host == "redis-server"
assert config.port == 6380
assert config.password == "secret"
assert config.db == 3
assert config.max_connections == 100
class TestRedisConnectionPool:
"""Test Redis connection pool management."""
def test_creation(self):
"""Test connection pool creation."""
config = RedisConfig(
host="localhost",
max_connections=20
)
pool = RedisConnectionPool(config)
assert pool.config == config
assert pool.pool is None
assert pool.is_connected is False
@pytest.mark.asyncio
async def test_connect_success(self):
"""Test successful connection to Redis."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
with patch('redis.ConnectionPool') as mock_pool_class:
mock_pool = Mock()
mock_pool_class.return_value = mock_pool
with patch('redis.Redis') as mock_redis_class:
mock_redis = Mock()
mock_redis.ping.return_value = True
mock_redis_class.return_value = mock_redis
await pool.connect()
assert pool.is_connected is True
assert pool.pool is not None
mock_pool_class.assert_called_once()
@pytest.mark.asyncio
async def test_connect_failure(self):
"""Test Redis connection failure."""
config = RedisConfig(host="nonexistent-redis")
pool = RedisConnectionPool(config)
with patch('redis.ConnectionPool') as mock_pool_class:
mock_pool_class.side_effect = redis.ConnectionError("Connection failed")
with pytest.raises(RedisError) as exc_info:
await pool.connect()
assert "Connection failed" in str(exc_info.value)
assert pool.is_connected is False
@pytest.mark.asyncio
async def test_disconnect(self):
"""Test Redis disconnection."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
# Mock connected state
mock_pool = Mock()
mock_redis = Mock()
pool.pool = mock_pool
pool._redis_client = mock_redis
pool.is_connected = True
await pool.disconnect()
assert pool.is_connected is False
assert pool.pool is None
mock_pool.disconnect.assert_called_once()
def test_get_client_connected(self):
"""Test getting Redis client when connected."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
mock_pool = Mock()
mock_redis = Mock()
pool.pool = mock_pool
pool._redis_client = mock_redis
pool.is_connected = True
client = pool.get_client()
assert client == mock_redis
def test_get_client_not_connected(self):
"""Test getting Redis client when not connected."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
with pytest.raises(RedisError) as exc_info:
pool.get_client()
assert "not connected" in str(exc_info.value).lower()
def test_health_check(self):
"""Test Redis health check."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
mock_redis = Mock()
mock_redis.ping.return_value = True
pool._redis_client = mock_redis
pool.is_connected = True
is_healthy = pool.health_check()
assert is_healthy is True
mock_redis.ping.assert_called_once()
def test_health_check_failure(self):
"""Test Redis health check failure."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
mock_redis = Mock()
mock_redis.ping.side_effect = redis.ConnectionError("Connection lost")
pool._redis_client = mock_redis
pool.is_connected = True
is_healthy = pool.health_check()
assert is_healthy is False
class TestRedisImageStorage:
"""Test Redis image storage functionality."""
def test_creation(self, mock_redis_client):
"""Test Redis image storage creation."""
storage = RedisImageStorage(mock_redis_client)
assert storage.redis_client == mock_redis_client
assert storage.default_expiry == 3600 # 1 hour
assert storage.compression_enabled is True
@pytest.mark.asyncio
async def test_store_image_success(self, mock_redis_client, mock_frame):
"""Test successful image storage."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
with patch('cv2.imencode') as mock_imencode:
# Mock successful encoding
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
result = await storage.store_image("test_key", mock_frame, expire_seconds=600)
assert result is True
mock_redis_client.set.assert_called_once()
mock_redis_client.expire.assert_called_once_with("test_key", 600)
mock_imencode.assert_called_once()
@pytest.mark.asyncio
async def test_store_image_cropped(self, mock_redis_client, mock_frame):
"""Test storing cropped image."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
result = await storage.store_image("cropped_key", mock_frame, crop_bbox=bbox)
assert result is True
mock_redis_client.set.assert_called_once()
@pytest.mark.asyncio
async def test_store_image_encoding_failure(self, mock_redis_client, mock_frame):
"""Test image storage with encoding failure."""
storage = RedisImageStorage(mock_redis_client)
with patch('cv2.imencode') as mock_imencode:
# Mock encoding failure
mock_imencode.return_value = (False, None)
with pytest.raises(RedisError) as exc_info:
await storage.store_image("test_key", mock_frame)
assert "Failed to encode image" in str(exc_info.value)
mock_redis_client.set.assert_not_called()
@pytest.mark.asyncio
async def test_store_image_redis_failure(self, mock_redis_client, mock_frame):
"""Test image storage with Redis failure."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.set.side_effect = redis.RedisError("Redis error")
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
with pytest.raises(RedisError) as exc_info:
await storage.store_image("test_key", mock_frame)
assert "Redis error" in str(exc_info.value)
@pytest.mark.asyncio
async def test_retrieve_image_success(self, mock_redis_client):
"""Test successful image retrieval."""
storage = RedisImageStorage(mock_redis_client)
# Mock encoded image data
original_image = np.ones((100, 100, 3), dtype=np.uint8) * 128
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Mock Redis returning base64 encoded data
base64_data = base64.b64encode(encoded_data.tobytes()).decode('utf-8')
mock_redis_client.get.return_value = base64_data
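# Images appear to be stored as base64-encoded bytes, so retrieval must decode them before cv2.imdecode reconstructs the array.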
with patch('cv2.imdecode') as mock_imdecode:
mock_imdecode.return_value = original_image
retrieved_image = await storage.retrieve_image("test_key")
assert retrieved_image is not None
assert retrieved_image.shape == (100, 100, 3)
mock_redis_client.get.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_retrieve_image_not_found(self, mock_redis_client):
"""Test image retrieval when key not found."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.get.return_value = None
retrieved_image = await storage.retrieve_image("nonexistent_key")
assert retrieved_image is None
mock_redis_client.get.assert_called_once_with("nonexistent_key")
@pytest.mark.asyncio
async def test_delete_image(self, mock_redis_client):
"""Test image deletion."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.delete.return_value = 1
result = await storage.delete_image("test_key")
assert result is True
mock_redis_client.delete.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_delete_image_not_found(self, mock_redis_client):
"""Test deleting non-existent image."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.delete.return_value = 0
result = await storage.delete_image("nonexistent_key")
assert result is False
mock_redis_client.delete.assert_called_once_with("nonexistent_key")
@pytest.mark.asyncio
async def test_bulk_delete_images(self, mock_redis_client):
"""Test bulk image deletion."""
storage = RedisImageStorage(mock_redis_client)
keys = ["key1", "key2", "key3"]
mock_redis_client.delete.return_value = 3
deleted_count = await storage.bulk_delete_images(keys)
assert deleted_count == 3
mock_redis_client.delete.assert_called_once_with(*keys)
@pytest.mark.asyncio
async def test_cleanup_expired_images(self, mock_redis_client):
"""Test cleanup of expired images."""
storage = RedisImageStorage(mock_redis_client)
# Mock scan to return image keys
mock_redis_client.scan_iter.return_value = [
b"inference:camera1:image1",
b"inference:camera2:image2",
b"inference:camera1:image3"
]
# Mock ttl to return different expiry times
mock_redis_client.ttl.side_effect = [-1, 100, -2] # No expiry, valid, expired
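# Redis TTL returns -1 for a key without an expiry and -2 for a missing key; cleanup is assumed to delete only the latter.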
mock_redis_client.delete.return_value = 1
deleted_count = await storage.cleanup_expired_images("inference:*")
assert deleted_count == 1 # Only expired images deleted
mock_redis_client.delete.assert_called_once()
def test_get_image_info(self, mock_redis_client):
"""Test getting image metadata."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.exists.return_value = 1
mock_redis_client.ttl.return_value = 1800 # 30 minutes
mock_redis_client.memory_usage.return_value = 4096 # 4KB
info = storage.get_image_info("test_key")
assert info["exists"] is True
assert info["ttl"] == 1800
assert info["size_bytes"] == 4096
mock_redis_client.exists.assert_called_once_with("test_key")
mock_redis_client.ttl.assert_called_once_with("test_key")
class TestRedisPublisher:
"""Test Redis publisher functionality."""
def test_creation(self, mock_redis_client):
"""Test Redis publisher creation."""
publisher = RedisPublisher(mock_redis_client)
assert publisher.redis_client == mock_redis_client
@pytest.mark.asyncio
async def test_publish_message_string(self, mock_redis_client):
"""Test publishing string message."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.return_value = 2 # 2 subscribers
result = await publisher.publish("test_channel", "Hello, Redis!")
assert result == 2
mock_redis_client.publish.assert_called_once_with("test_channel", "Hello, Redis!")
@pytest.mark.asyncio
async def test_publish_message_json(self, mock_redis_client):
"""Test publishing JSON message."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.return_value = 1
message_data = {
"camera_id": "camera_001",
"detection_class": "car",
"confidence": 0.95,
"timestamp": 1640995200000
}
result = await publisher.publish("detections", message_data)
assert result == 1
# Should have been JSON serialized
expected_json = json.dumps(message_data)
mock_redis_client.publish.assert_called_once_with("detections", expected_json)
@pytest.mark.asyncio
async def test_publish_detection_event(self, mock_redis_client):
"""Test publishing detection event."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.return_value = 3
detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
result = await publisher.publish_detection_event(
"camera_detections",
detection,
camera_id="camera_001",
session_id="session_123"
)
assert result == 3
# Verify the published message structure
call_args = mock_redis_client.publish.call_args
channel = call_args[0][0]
message_str = call_args[0][1]
message_data = json.loads(message_str)
assert channel == "camera_detections"
assert message_data["event_type"] == "detection"
assert message_data["camera_id"] == "camera_001"
assert message_data["session_id"] == "session_123"
assert message_data["detection"]["class"] == "car"
assert message_data["detection"]["confidence"] == 0.92
@pytest.mark.asyncio
async def test_publish_batch_messages(self, mock_redis_client):
"""Test publishing multiple messages in batch."""
publisher = RedisPublisher(mock_redis_client)
mock_pipeline = Mock()
mock_redis_client.pipeline.return_value = mock_pipeline
mock_pipeline.execute.return_value = [1, 2, 1] # Subscriber counts
messages = [
("channel1", "message1"),
("channel2", {"data": "message2"}),
("channel1", "message3")
]
results = await publisher.publish_batch(messages)
assert results == [1, 2, 1]
mock_redis_client.pipeline.assert_called_once()
assert mock_pipeline.publish.call_count == 3
mock_pipeline.execute.assert_called_once()
@pytest.mark.asyncio
async def test_publish_error_handling(self, mock_redis_client):
"""Test error handling in publishing."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.side_effect = redis.RedisError("Publish failed")
with pytest.raises(RedisError) as exc_info:
await publisher.publish("test_channel", "test_message")
assert "Publish failed" in str(exc_info.value)
class TestRedisSubscriber:
"""Test Redis subscriber functionality."""
def test_creation(self, mock_redis_client):
"""Test Redis subscriber creation."""
subscriber = RedisSubscriber(mock_redis_client)
assert subscriber.redis_client == mock_redis_client
assert subscriber.pubsub is None
assert subscriber.subscriptions == set()
@pytest.mark.asyncio
async def test_subscribe_to_channel(self, mock_redis_client):
"""Test subscribing to a channel."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
await subscriber.subscribe("test_channel")
assert "test_channel" in subscriber.subscriptions
mock_pubsub.subscribe.assert_called_once_with("test_channel")
@pytest.mark.asyncio
async def test_subscribe_to_pattern(self, mock_redis_client):
"""Test subscribing to a pattern."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
await subscriber.subscribe_pattern("detection:*")
assert "detection:*" in subscriber.subscriptions
mock_pubsub.psubscribe.assert_called_once_with("detection:*")
@pytest.mark.asyncio
async def test_unsubscribe_from_channel(self, mock_redis_client):
"""Test unsubscribing from a channel."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
subscriber.pubsub = mock_pubsub
subscriber.subscriptions.add("test_channel")
await subscriber.unsubscribe("test_channel")
assert "test_channel" not in subscriber.subscriptions
mock_pubsub.unsubscribe.assert_called_once_with("test_channel")
@pytest.mark.asyncio
async def test_listen_for_messages(self, mock_redis_client):
"""Test listening for messages."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
# Mock message stream
messages = [
{"type": "subscribe", "channel": "test", "data": 1},
{"type": "message", "channel": "test", "data": "Hello"},
{"type": "message", "channel": "test", "data": '{"key": "value"}'}
]
mock_pubsub.listen.return_value = iter(messages)
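# The assertions below rely on listen() skipping the subscribe confirmation and JSON-decoding payloads that parse as JSON.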
received_messages = []
message_count = 0
async for message in subscriber.listen():
received_messages.append(message)
message_count += 1
if message_count >= 2: # Only process actual messages
break
# Should receive 2 actual messages (excluding subscribe confirmation)
assert len(received_messages) == 2
assert received_messages[0]["data"] == "Hello"
assert received_messages[1]["data"] == {"key": "value"} # Should be parsed as JSON
@pytest.mark.asyncio
async def test_close_subscription(self, mock_redis_client):
"""Test closing subscription."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
subscriber.pubsub = mock_pubsub
subscriber.subscriptions = {"channel1", "pattern:*"}
await subscriber.close()
assert len(subscriber.subscriptions) == 0
mock_pubsub.close.assert_called_once()
assert subscriber.pubsub is None
class TestRedisClient:
"""Test main Redis client functionality."""
def test_initialization(self):
"""Test Redis client initialization."""
config = RedisConfig(host="localhost", port=6379)
client = RedisClient(config)
assert client.config == config
assert isinstance(client.connection_pool, RedisConnectionPool)
assert client.image_storage is None
assert client.publisher is None
assert client.subscriber is None
@pytest.mark.asyncio
async def test_connect_and_initialize_components(self):
"""Test connecting and initializing all components."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
mock_redis_client = Mock()
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
await client.connect()
assert client.image_storage is not None
assert client.publisher is not None
assert client.subscriber is not None
assert isinstance(client.image_storage, RedisImageStorage)
assert isinstance(client.publisher, RedisPublisher)
assert isinstance(client.subscriber, RedisSubscriber)
mock_connect.assert_called_once()
@pytest.mark.asyncio
async def test_disconnect(self):
"""Test disconnection."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.is_connected = True
client.subscriber = Mock()
client.subscriber.close = AsyncMock()
with patch.object(client.connection_pool, 'disconnect', new_callable=AsyncMock) as mock_disconnect:
await client.disconnect()
client.subscriber.close.assert_called_once()
mock_disconnect.assert_called_once()
assert client.image_storage is None
assert client.publisher is None
assert client.subscriber is None
@pytest.mark.asyncio
async def test_store_and_retrieve_data(self, mock_redis_client):
"""Test storing and retrieving data."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
# Test storing data
mock_redis_client.set.return_value = True
result = await client.set("test_key", "test_value", expire_seconds=300)
assert result is True
mock_redis_client.set.assert_called_once_with("test_key", "test_value")
mock_redis_client.expire.assert_called_once_with("test_key", 300)
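# These assertions pin down that the client issues SET followed by a separate EXPIRE rather than SET with an ex= argument.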
# Test retrieving data
mock_redis_client.get.return_value = "test_value"
value = await client.get("test_key")
assert value == "test_value"
mock_redis_client.get.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_delete_keys(self, mock_redis_client):
"""Test deleting keys."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.delete.return_value = 2
result = await client.delete("key1", "key2")
assert result == 2
mock_redis_client.delete.assert_called_once_with("key1", "key2")
@pytest.mark.asyncio
async def test_exists_check(self, mock_redis_client):
"""Test checking key existence."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.exists.return_value = 1
exists = await client.exists("test_key")
assert exists is True
mock_redis_client.exists.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_expire_key(self, mock_redis_client):
"""Test setting key expiration."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.expire.return_value = True
result = await client.expire("test_key", 600)
assert result is True
mock_redis_client.expire.assert_called_once_with("test_key", 600)
@pytest.mark.asyncio
async def test_get_ttl(self, mock_redis_client):
"""Test getting key TTL."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.ttl.return_value = 300
ttl = await client.ttl("test_key")
assert ttl == 300
mock_redis_client.ttl.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_scan_keys(self, mock_redis_client):
"""Test scanning for keys."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.scan_iter.return_value = [b"key1", b"key2", b"key3"]
keys = await client.scan_keys("test:*")
assert keys == ["key1", "key2", "key3"]
mock_redis_client.scan_iter.assert_called_once_with(match="test:*")
@pytest.mark.asyncio
async def test_flush_database(self, mock_redis_client):
"""Test flushing database."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.flushdb.return_value = True
result = await client.flush_db()
assert result is True
mock_redis_client.flushdb.assert_called_once()
def test_get_connection_info(self):
"""Test getting connection information."""
config = RedisConfig(
host="redis.example.com",
port=6380,
db=2
)
client = RedisClient(config)
client.connection_pool.is_connected = True
info = client.get_connection_info()
assert info["connected"] is True
assert info["host"] == "redis.example.com"
assert info["port"] == 6380
assert info["database"] == 2
@pytest.mark.asyncio
async def test_pipeline_operations(self, mock_redis_client):
"""Test Redis pipeline operations."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_pipeline = Mock()
mock_redis_client.pipeline.return_value = mock_pipeline
mock_pipeline.execute.return_value = [True, True, 1]
async with client.pipeline() as pipe:
pipe.set("key1", "value1")
pipe.set("key2", "value2")
pipe.delete("key3")
results = await pipe.execute()
assert results == [True, True, 1]
mock_redis_client.pipeline.assert_called_once()
mock_pipeline.execute.assert_called_once()
class TestRedisClientIntegration:
"""Integration tests for Redis client."""
@pytest.mark.asyncio
async def test_complete_image_workflow(self, mock_redis_client, mock_frame):
"""Test complete image storage workflow."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state and components
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
client.image_storage = RedisImageStorage(mock_redis_client)
client.publisher = RedisPublisher(mock_redis_client)
# Mock Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
mock_redis_client.publish.return_value = 2
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Store image
store_result = await client.image_storage.store_image(
"detection:camera001:1640995200:session123",
mock_frame,
expire_seconds=600
)
# Publish detection event
detection_event = {
"camera_id": "camera001",
"session_id": "session123",
"detection_class": "car",
"confidence": 0.95,
"timestamp": 1640995200000
}
publish_result = await client.publisher.publish("detections:camera001", detection_event)
assert store_result is True
assert publish_result == 2
# Verify Redis operations
mock_redis_client.set.assert_called_once()
mock_redis_client.expire.assert_called_once()
mock_redis_client.publish.assert_called_once()
@pytest.mark.asyncio
async def test_error_recovery_and_reconnection(self):
"""Test error recovery and reconnection."""
config = RedisConfig(host="localhost", retry_on_timeout=True)
client = RedisClient(config)
with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
with patch.object(client.connection_pool, 'health_check') as mock_health_check:
# First health check fails, second succeeds
mock_health_check.side_effect = [False, True]
# First connection attempt fails, second succeeds
mock_connect.side_effect = [RedisError("Connection failed"), None]
# Simulate connection recovery
try:
await client.connect()
except RedisError:
# Retry connection
await client.connect()
assert mock_connect.call_count == 2
@pytest.mark.asyncio
async def test_bulk_operations_performance(self, mock_redis_client):
"""Test bulk operations for performance."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
client.publisher = RedisPublisher(mock_redis_client)
# Mock pipeline operations
mock_pipeline = Mock()
mock_redis_client.pipeline.return_value = mock_pipeline
mock_pipeline.execute.return_value = [1] * 100 # 100 successful operations
# Prepare bulk messages
messages = [
(f"channel_{i}", f"message_{i}")
for i in range(100)
]
start_time = time.time()
results = await client.publisher.publish_batch(messages)
execution_time = time.time() - start_time
assert len(results) == 100
assert all(result == 1 for result in results)
# Should be faster than individual operations
assert execution_time < 1.0 # Should complete in less than 1 second
# Pipeline should be used for efficiency
mock_redis_client.pipeline.assert_called_once()
assert mock_pipeline.publish.call_count == 100
mock_pipeline.execute.assert_called_once()
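# --- Editorial sketch (not part of this commit's implementation) -----------
# The bulk test above asserts that publish_batch() funnels every PUBLISH
# through a single Redis pipeline.  A minimal sketch of that pattern with
# redis-py's pipeline API is shown below; the function name and signature are
# assumptions for illustration, not the RedisPublisher code under test.
from typing import Iterable, List, Tuple


def sketch_publish_batch(redis_client, messages: Iterable[Tuple[str, str]]) -> List[int]:
    """Queue every PUBLISH on one pipeline and send them in a single round trip."""
    pipe = redis_client.pipeline()
    for channel, payload in messages:
        pipe.publish(channel, payload)   # queued locally, nothing sent yet
    return pipe.execute()                # one round trip; list of receiver counts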

View file

@ -0,0 +1,883 @@
"""
Unit tests for session cache management.
"""
import pytest
import time
import uuid
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from collections import defaultdict
from detector_worker.storage.session_cache import (
SessionCache,
SessionCacheManager,
SessionData,
CacheConfig,
CacheEntry,
CacheStats,
SessionError,
CacheError
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
class TestCacheConfig:
"""Test cache configuration."""
def test_creation_default(self):
"""Test creating cache config with default values."""
config = CacheConfig()
assert config.max_size == 1000
assert config.ttl_seconds == 3600 # 1 hour
assert config.cleanup_interval == 300 # 5 minutes
assert config.eviction_policy == "lru"
assert config.enable_persistence is False
def test_creation_custom(self):
"""Test creating cache config with custom values."""
config = CacheConfig(
max_size=5000,
ttl_seconds=7200,
cleanup_interval=600,
eviction_policy="lfu",
enable_persistence=True,
persistence_path="/tmp/cache"
)
assert config.max_size == 5000
assert config.ttl_seconds == 7200
assert config.cleanup_interval == 600
assert config.eviction_policy == "lfu"
assert config.enable_persistence is True
assert config.persistence_path == "/tmp/cache"
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"max_size": 2000,
"ttl_seconds": 1800,
"eviction_policy": "fifo",
"enable_persistence": True,
"unknown_field": "ignored"
}
config = CacheConfig.from_dict(config_dict)
assert config.max_size == 2000
assert config.ttl_seconds == 1800
assert config.eviction_policy == "fifo"
assert config.enable_persistence is True
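# --- Editorial sketch (not part of this commit's implementation) -----------
# test_from_dict above relies on unknown keys ("unknown_field") being dropped
# silently.  A common way to get that behaviour is to filter the input dict
# against the declared dataclass fields; the class below is an illustrative
# assumption, not the CacheConfig shipped with the worker.
import dataclasses


@dataclasses.dataclass
class SketchCacheConfig:
    max_size: int = 1000
    ttl_seconds: int = 3600
    cleanup_interval: int = 300
    eviction_policy: str = "lru"
    enable_persistence: bool = False
    persistence_path: str = ""

    @classmethod
    def from_dict(cls, raw: dict) -> "SketchCacheConfig":
        allowed = {f.name for f in dataclasses.fields(cls)}
        return cls(**{key: value for key, value in raw.items() if key in allowed})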
class TestCacheEntry:
"""Test cache entry data structure."""
def test_creation(self):
"""Test cache entry creation."""
data = {"key": "value", "number": 42}
entry = CacheEntry(data, ttl_seconds=600)
assert entry.data == data
assert entry.ttl_seconds == 600
assert entry.created_at <= time.time()
assert entry.last_accessed <= time.time()
assert entry.access_count == 1
assert entry.size > 0
def test_is_expired(self):
"""Test expiration checking."""
# Non-expired entry
entry = CacheEntry({"data": "test"}, ttl_seconds=600)
assert entry.is_expired() is False
# Expired entry (simulate by setting old creation time)
entry.created_at = time.time() - 700 # Created 700 seconds ago
assert entry.is_expired() is True
# Entry without expiration
entry_no_ttl = CacheEntry({"data": "test"})
assert entry_no_ttl.is_expired() is False
def test_touch(self):
"""Test updating access time and count."""
entry = CacheEntry({"data": "test"})
original_access_time = entry.last_accessed
original_access_count = entry.access_count
time.sleep(0.01) # Small delay
entry.touch()
assert entry.last_accessed > original_access_time
assert entry.access_count == original_access_count + 1
def test_age(self):
"""Test age calculation."""
entry = CacheEntry({"data": "test"})
time.sleep(0.01) # Small delay
age = entry.age()
assert age > 0
assert age < 1 # Should be less than 1 second
def test_size_estimation(self):
"""Test size estimation."""
small_entry = CacheEntry({"key": "value"})
large_entry = CacheEntry({"key": "x" * 1000, "data": list(range(100))})
assert large_entry.size > small_entry.size
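# --- Editorial sketch (not part of this commit's implementation) -----------
# The assertions above pin down the CacheEntry contract: timestamps set on
# creation, an access counter that starts at 1, optional TTL expiry, touch(),
# age() and a rough size estimate.  A minimal entry consistent with those
# tests could look like this; len(repr(data)) as the size proxy is an
# assumption, the real estimator is not visible in this diff.
import time


class SketchCacheEntry:
    def __init__(self, data, ttl_seconds=None):
        self.data = data
        self.ttl_seconds = ttl_seconds
        self.created_at = time.time()
        self.last_accessed = self.created_at
        self.access_count = 1              # creation counts as the first access
        self.size = len(repr(data))        # crude size proxy for eviction decisions

    def is_expired(self) -> bool:
        if self.ttl_seconds is None:
            return False
        return (time.time() - self.created_at) > self.ttl_seconds

    def touch(self) -> None:
        self.last_accessed = time.time()
        self.access_count += 1

    def age(self) -> float:
        return time.time() - self.created_at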
class TestSessionData:
"""Test session data structure."""
def test_creation(self):
"""Test session data creation."""
session_data = SessionData(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
assert session_data.session_id == "session_123"
assert session_data.camera_id == "camera_001"
assert session_data.display_id == "display_001"
assert session_data.created_at <= time.time()
assert session_data.last_activity <= time.time()
assert session_data.detection_data == {}
assert session_data.metadata == {}
def test_update_activity(self):
"""Test updating last activity."""
session_data = SessionData("session_123", "camera_001", "display_001")
original_activity = session_data.last_activity
time.sleep(0.01)
session_data.update_activity()
assert session_data.last_activity > original_activity
def test_add_detection_data(self):
"""Test adding detection data."""
session_data = SessionData("session_123", "camera_001", "display_001")
detection_data = {
"class": "car",
"confidence": 0.95,
"bbox": [100, 200, 300, 400]
}
session_data.add_detection_data("main_detection", detection_data)
assert "main_detection" in session_data.detection_data
assert session_data.detection_data["main_detection"] == detection_data
def test_add_metadata(self):
"""Test adding metadata."""
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_metadata("model_version", "v2.1")
session_data.add_metadata("inference_time", 0.15)
assert session_data.metadata["model_version"] == "v2.1"
assert session_data.metadata["inference_time"] == 0.15
def test_is_expired(self):
"""Test session expiration."""
session_data = SessionData("session_123", "camera_001", "display_001")
# Not expired with default timeout
assert session_data.is_expired() is False
# Expired with a very short timeout (brief pause so the timeout can elapse)
time.sleep(0.01)
assert session_data.is_expired(timeout_seconds=0.001) is True
def test_to_dict(self):
"""Test converting session to dictionary."""
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("detection", {"class": "car", "confidence": 0.9})
session_data.add_metadata("model_id", "yolo_v8")
data_dict = session_data.to_dict()
assert data_dict["session_id"] == "session_123"
assert data_dict["camera_id"] == "camera_001"
assert data_dict["detection_data"]["detection"]["class"] == "car"
assert data_dict["metadata"]["model_id"] == "yolo_v8"
assert "created_at" in data_dict
assert "last_activity" in data_dict
class TestCacheStats:
"""Test cache statistics."""
def test_creation(self):
"""Test cache stats creation."""
stats = CacheStats()
assert stats.hits == 0
assert stats.misses == 0
assert stats.evictions == 0
assert stats.size == 0
assert stats.memory_usage == 0
def test_hit_rate_calculation(self):
"""Test hit rate calculation."""
stats = CacheStats()
# No requests yet
assert stats.hit_rate() == 0.0
# Some hits and misses
stats.hits = 8
stats.misses = 2
assert stats.hit_rate() == 0.8 # 8 / (8 + 2)
def test_total_requests(self):
"""Test total requests calculation."""
stats = CacheStats()
stats.hits = 15
stats.misses = 5
assert stats.total_requests() == 20
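# --- Editorial sketch (not part of this commit's implementation) -----------
# The two tests above imply hit_rate() is simply hits / (hits + misses) with a
# guard for the no-traffic case.  Sketch under that assumption:
class SketchCacheStats:
    def __init__(self):
        self.hits = 0
        self.misses = 0
        self.evictions = 0
        self.size = 0
        self.memory_usage = 0

    def total_requests(self) -> int:
        return self.hits + self.misses

    def hit_rate(self) -> float:
        total = self.total_requests()
        return self.hits / total if total else 0.0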
class TestSessionCache:
"""Test session cache functionality."""
def test_creation(self):
"""Test session cache creation."""
config = CacheConfig(max_size=100, ttl_seconds=300)
cache = SessionCache(config)
assert cache.config == config
assert cache.max_size == 100
assert cache.ttl_seconds == 300
assert len(cache._cache) == 0
assert len(cache._access_order) == 0
def test_put_and_get_session(self):
"""Test putting and getting session data."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("main", {"class": "car", "confidence": 0.9})
# Put session
cache.put("session_123", session_data)
# Get session
retrieved_data = cache.get("session_123")
assert retrieved_data is not None
assert retrieved_data.session_id == "session_123"
assert retrieved_data.camera_id == "camera_001"
assert "main" in retrieved_data.detection_data
def test_get_nonexistent_session(self):
"""Test getting non-existent session."""
cache = SessionCache(CacheConfig(max_size=10))
result = cache.get("nonexistent_session")
assert result is None
def test_contains_check(self):
"""Test checking if session exists."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
assert cache.contains("session_123") is True
assert cache.contains("nonexistent_session") is False
def test_remove_session(self):
"""Test removing session."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
assert cache.contains("session_123") is True
removed_data = cache.remove("session_123")
assert removed_data is not None
assert removed_data.session_id == "session_123"
assert cache.contains("session_123") is False
def test_size_tracking(self):
"""Test cache size tracking."""
cache = SessionCache(CacheConfig(max_size=10))
assert cache.size() == 0
assert cache.is_empty() is True
# Add sessions
for i in range(3):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
assert cache.size() == 3
assert cache.is_empty() is False
def test_lru_eviction(self):
"""Test LRU eviction policy."""
cache = SessionCache(CacheConfig(max_size=3, eviction_policy="lru"))
# Fill cache to capacity
for i in range(3):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
# Access session_1 to make it recently used
cache.get("session_1")
# Add another session (should evict session_0, the least recently used)
new_session = SessionData("session_3", "camera_001", "display_001")
cache.put("session_3", new_session)
assert cache.size() == 3
assert cache.contains("session_0") is False # Evicted
assert cache.contains("session_1") is True # Recently accessed
assert cache.contains("session_2") is True
assert cache.contains("session_3") is True # Newly added
def test_ttl_expiration(self):
"""Test TTL-based expiration."""
cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1)) # 100ms TTL
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
# Should exist immediately
assert cache.contains("session_123") is True
# Wait for expiration
time.sleep(0.2)
# Should be expired (but might still be in cache until cleanup)
entry = cache._cache.get("session_123")
if entry:
assert entry.is_expired() is True
# Getting expired entry should return None and clean it up
retrieved = cache.get("session_123")
assert retrieved is None
assert cache.contains("session_123") is False
def test_cleanup_expired_entries(self):
"""Test cleanup of expired entries."""
cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))
# Add multiple sessions
for i in range(3):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
assert cache.size() == 3
# Wait for expiration
time.sleep(0.2)
# Cleanup expired entries
cleaned_count = cache.cleanup_expired()
assert cleaned_count == 3
assert cache.size() == 0
def test_clear_cache(self):
"""Test clearing entire cache."""
cache = SessionCache(CacheConfig(max_size=10))
# Add sessions
for i in range(5):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
assert cache.size() == 5
cache.clear()
assert cache.size() == 0
assert cache.is_empty() is True
def test_get_all_sessions(self):
"""Test getting all sessions."""
cache = SessionCache(CacheConfig(max_size=10))
sessions = []
for i in range(3):
session_data = SessionData(f"session_{i}", f"camera_{i}", "display_001")
cache.put(f"session_{i}", session_data)
sessions.append(session_data)
all_sessions = cache.get_all()
assert len(all_sessions) == 3
for session_id, session_data in all_sessions.items():
assert session_id.startswith("session_")
assert session_data.session_id == session_id
def test_get_sessions_by_camera(self):
"""Test getting sessions by camera ID."""
cache = SessionCache(CacheConfig(max_size=10))
# Add sessions for different cameras
for i in range(2):
session_data1 = SessionData(f"session_cam1_{i}", "camera_001", "display_001")
session_data2 = SessionData(f"session_cam2_{i}", "camera_002", "display_001")
cache.put(f"session_cam1_{i}", session_data1)
cache.put(f"session_cam2_{i}", session_data2)
camera1_sessions = cache.get_by_camera("camera_001")
camera2_sessions = cache.get_by_camera("camera_002")
assert len(camera1_sessions) == 2
assert len(camera2_sessions) == 2
for session_data in camera1_sessions:
assert session_data.camera_id == "camera_001"
for session_data in camera2_sessions:
assert session_data.camera_id == "camera_002"
def test_statistics_tracking(self):
"""Test cache statistics tracking."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
# Cache miss
cache.get("nonexistent_session")
# Cache hit
cache.get("session_123")
cache.get("session_123") # Another hit
stats = cache.get_stats()
assert stats.hits == 2
assert stats.misses == 1
assert stats.size == 1
assert stats.hit_rate() == 2/3 # 2 hits out of 3 total requests
def test_memory_usage_estimation(self):
"""Test memory usage estimation."""
cache = SessionCache(CacheConfig(max_size=10))
initial_memory = cache.get_memory_usage()
# Add large session
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("large_data", {"data": "x" * 1000})
cache.put("session_123", session_data)
after_memory = cache.get_memory_usage()
assert after_memory > initial_memory
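# --- Editorial sketch (not part of this commit's implementation) -----------
# The eviction and TTL tests above describe a classic OrderedDict-backed LRU:
# get() refreshes recency and lazily drops expired entries, put() evicts the
# least recently used key once max_size is reached.  The class below shows
# only that core policy (stats, per-camera lookups and the other eviction
# policies are omitted); it is an assumption, not the shipped SessionCache.
import time
from collections import OrderedDict


class SketchLRUCache:
    def __init__(self, max_size, ttl_seconds=None):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self._cache = OrderedDict()            # key -> (stored_at, value)

    def put(self, key, value):
        if key in self._cache:
            self._cache.pop(key)               # re-insert to refresh recency
        elif len(self._cache) >= self.max_size:
            self._cache.popitem(last=False)    # evict the least recently used key
        self._cache[key] = (time.time(), value)

    def get(self, key):
        item = self._cache.get(key)
        if item is None:
            return None
        stored_at, value = item
        if self.ttl_seconds is not None and time.time() - stored_at > self.ttl_seconds:
            del self._cache[key]               # expired entries are dropped on access
            return None
        self._cache.move_to_end(key)           # mark as most recently used
        return value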
class TestSessionCacheManager:
"""Test session cache manager."""
def test_singleton_behavior(self):
"""Test that SessionCacheManager is a singleton."""
manager1 = SessionCacheManager()
manager2 = SessionCacheManager()
assert manager1 is manager2
def test_initialization(self):
"""Test session cache manager initialization."""
manager = SessionCacheManager()
assert manager.detection_cache is not None
assert manager.pipeline_cache is not None
assert manager.session_cache is not None
assert isinstance(manager.detection_cache, SessionCache)
assert isinstance(manager.pipeline_cache, SessionCache)
assert isinstance(manager.session_cache, SessionCache)
def test_cache_detection_result(self):
"""Test caching detection results."""
manager = SessionCacheManager()
manager.clear_all() # Start fresh
detection_data = {
"class": "car",
"confidence": 0.95,
"bbox": [100, 200, 300, 400],
"track_id": 1001
}
manager.cache_detection("camera_001", detection_data)
cached_detection = manager.get_cached_detection("camera_001")
assert cached_detection is not None
assert cached_detection["class"] == "car"
assert cached_detection["confidence"] == 0.95
assert cached_detection["track_id"] == 1001
def test_cache_pipeline_result(self):
"""Test caching pipeline results."""
manager = SessionCacheManager()
manager.clear_all()
pipeline_result = {
"status": "success",
"detections": [{"class": "car", "confidence": 0.9}],
"execution_time": 0.15,
"model_id": "yolo_v8"
}
manager.cache_pipeline_result("camera_001", pipeline_result)
cached_result = manager.get_cached_pipeline_result("camera_001")
assert cached_result is not None
assert cached_result["status"] == "success"
assert cached_result["execution_time"] == 0.15
assert len(cached_result["detections"]) == 1
def test_manage_session_data(self):
"""Test session data management."""
manager = SessionCacheManager()
manager.clear_all()
session_id = str(uuid.uuid4())
# Create session
manager.create_session(session_id, "camera_001", {"initial": "data"})
# Update session
manager.update_session_detection(session_id, {"car_brand": "Toyota"})
# Get session
session_data = manager.get_session_detection(session_id)
assert session_data is not None
assert "initial" in session_data
assert session_data["car_brand"] == "Toyota"
def test_set_latest_frame(self):
"""Test setting and getting latest frame."""
manager = SessionCacheManager()
manager.clear_all()
frame_data = b"fake_frame_data"
manager.set_latest_frame("camera_001", frame_data)
retrieved_frame = manager.get_latest_frame("camera_001")
assert retrieved_frame == frame_data
def test_frame_skip_flag_management(self):
"""Test frame skip flag management."""
manager = SessionCacheManager()
manager.clear_all()
# Initially should be False
assert manager.get_frame_skip_flag("camera_001") is False
# Set to True
manager.set_frame_skip_flag("camera_001", True)
assert manager.get_frame_skip_flag("camera_001") is True
# Set back to False
manager.set_frame_skip_flag("camera_001", False)
assert manager.get_frame_skip_flag("camera_001") is False
def test_cleanup_expired_sessions(self):
"""Test cleanup of expired sessions."""
manager = SessionCacheManager()
manager.clear_all()
# Create sessions with short TTL
manager.session_cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))
# Add sessions
for i in range(3):
session_id = f"session_{i}"
manager.create_session(session_id, "camera_001", {"test": "data"})
assert manager.session_cache.size() == 3
# Wait for expiration
time.sleep(0.2)
# Cleanup
expired_count = manager.cleanup_expired_sessions()
assert expired_count == 3
assert manager.session_cache.size() == 0
def test_clear_camera_cache(self):
"""Test clearing cache for specific camera."""
manager = SessionCacheManager()
manager.clear_all()
# Add data for multiple cameras
manager.cache_detection("camera_001", {"class": "car"})
manager.cache_detection("camera_002", {"class": "truck"})
manager.cache_pipeline_result("camera_001", {"status": "success"})
manager.set_latest_frame("camera_001", b"frame1")
manager.set_latest_frame("camera_002", b"frame2")
# Clear camera_001 cache
manager.clear_camera_cache("camera_001")
# camera_001 data should be gone
assert manager.get_cached_detection("camera_001") is None
assert manager.get_cached_pipeline_result("camera_001") is None
assert manager.get_latest_frame("camera_001") is None
# camera_002 data should remain
assert manager.get_cached_detection("camera_002") is not None
assert manager.get_latest_frame("camera_002") is not None
def test_get_cache_statistics(self):
"""Test getting cache statistics."""
manager = SessionCacheManager()
manager.clear_all()
# Add some data to generate statistics
manager.cache_detection("camera_001", {"class": "car"})
manager.cache_pipeline_result("camera_001", {"status": "success"})
manager.create_session("session_123", "camera_001", {"initial": "data"})
# Access data to generate hits/misses
manager.get_cached_detection("camera_001") # Hit
manager.get_cached_detection("camera_002") # Miss
stats = manager.get_cache_statistics()
assert "detection_cache" in stats
assert "pipeline_cache" in stats
assert "session_cache" in stats
assert "total_memory_usage" in stats
detection_stats = stats["detection_cache"]
assert detection_stats["size"] >= 1
assert detection_stats["hits"] >= 1
assert detection_stats["misses"] >= 1
def test_memory_pressure_handling(self):
"""Test handling memory pressure."""
# Create manager with small cache sizes
config = CacheConfig(max_size=3)
manager = SessionCacheManager()
manager.detection_cache = SessionCache(config)
manager.pipeline_cache = SessionCache(config)
manager.session_cache = SessionCache(config)
# Fill caches beyond capacity
for i in range(5):
manager.cache_detection(f"camera_{i}", {"class": "car", "data": "x" * 100})
manager.cache_pipeline_result(f"camera_{i}", {"status": "success", "data": "y" * 100})
manager.create_session(f"session_{i}", f"camera_{i}", {"data": "z" * 100})
# Caches should not exceed max size due to eviction
assert manager.detection_cache.size() <= 3
assert manager.pipeline_cache.size() <= 3
assert manager.session_cache.size() <= 3
def test_concurrent_access_thread_safety(self):
"""Test thread safety of concurrent cache access."""
import threading
import concurrent.futures
manager = SessionCacheManager()
manager.clear_all()
results = []
errors = []
def cache_operation(thread_id):
try:
# Each thread performs multiple cache operations
for i in range(10):
session_id = f"thread_{thread_id}_session_{i}"
# Create session
manager.create_session(session_id, f"camera_{thread_id}", {"thread": thread_id, "index": i})
# Update session
manager.update_session_detection(session_id, {"updated": True})
# Read session
data = manager.get_session_detection(session_id)
if data and data.get("thread") == thread_id:
results.append((thread_id, i))
except Exception as e:
errors.append((thread_id, str(e)))
# Run operations concurrently
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(cache_operation, i) for i in range(5)]
concurrent.futures.wait(futures)
# Should have no errors and successful operations
assert len(errors) == 0
assert len(results) >= 25 # At least some operations should succeed
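# --- Editorial sketch (not part of this commit's implementation) -----------
# test_singleton_behavior above only needs "two constructions return the same
# object", and the concurrency test expects that to hold across threads.  A
# lock-guarded __new__ is one common way to get both; this is an assumption
# about how SessionCacheManager might do it, not its actual code.
import threading


class SketchSingleton:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            with cls._lock:                    # double-checked locking
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance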
class TestSessionCacheIntegration:
"""Integration tests for session cache."""
def test_complete_detection_workflow(self):
"""Test complete detection workflow with caching."""
manager = SessionCacheManager()
manager.clear_all()
camera_id = "camera_001"
session_id = str(uuid.uuid4())
# 1. Cache initial detection
detection_data = {
"class": "car",
"confidence": 0.92,
"bbox": [100, 200, 300, 400],
"track_id": 1001,
"timestamp": int(time.time() * 1000)
}
manager.cache_detection(camera_id, detection_data)
# 2. Create session for tracking
initial_session_data = {
"detection_class": detection_data["class"],
"confidence": detection_data["confidence"],
"track_id": detection_data["track_id"]
}
manager.create_session(session_id, camera_id, initial_session_data)
# 3. Cache pipeline processing result
pipeline_result = {
"status": "processing",
"stage": "classification",
"detections": [detection_data],
"branches_completed": [],
"branches_pending": ["car_brand_cls", "car_bodytype_cls"]
}
manager.cache_pipeline_result(camera_id, pipeline_result)
# 4. Update session with classification results
classification_updates = [
{"car_brand": "Toyota", "brand_confidence": 0.87},
{"car_body_type": "Sedan", "bodytype_confidence": 0.82}
]
for update in classification_updates:
manager.update_session_detection(session_id, update)
# 5. Update pipeline result to completed
final_pipeline_result = {
"status": "completed",
"stage": "finished",
"detections": [detection_data],
"branches_completed": ["car_brand_cls", "car_bodytype_cls"],
"branches_pending": [],
"execution_time": 0.25
}
manager.cache_pipeline_result(camera_id, final_pipeline_result)
# 6. Verify all cached data
cached_detection = manager.get_cached_detection(camera_id)
cached_pipeline = manager.get_cached_pipeline_result(camera_id)
cached_session = manager.get_session_detection(session_id)
# Assertions
assert cached_detection["class"] == "car"
assert cached_detection["track_id"] == 1001
assert cached_pipeline["status"] == "completed"
assert len(cached_pipeline["branches_completed"]) == 2
assert cached_session["detection_class"] == "car"
assert cached_session["car_brand"] == "Toyota"
assert cached_session["car_body_type"] == "Sedan"
assert cached_session["brand_confidence"] == 0.87
def test_cache_performance_under_load(self):
"""Test cache performance under load."""
manager = SessionCacheManager()
manager.clear_all()
import time
# Measure performance of cache operations
start_time = time.time()
# Perform many cache operations
for i in range(1000):
camera_id = f"camera_{i % 10}" # 10 different cameras
session_id = f"session_{i}"
# Cache detection
detection_data = {
"class": "car",
"confidence": 0.9 + (i % 10) * 0.01,
"track_id": i,
"bbox": [i % 100, i % 100, (i % 100) + 200, (i % 100) + 200]
}
manager.cache_detection(camera_id, detection_data)
# Create session
manager.create_session(session_id, camera_id, {"index": i})
# Read back (every 10th operation)
if i % 10 == 0:
manager.get_cached_detection(camera_id)
manager.get_session_detection(session_id)
end_time = time.time()
total_time = end_time - start_time
# Should complete in reasonable time (less than 1 second)
assert total_time < 1.0
# Verify cache statistics
stats = manager.get_cache_statistics()
assert stats["detection_cache"]["size"] > 0
assert stats["session_cache"]["size"] > 0
assert stats["detection_cache"]["hits"] > 0
def test_cache_persistence_and_recovery(self):
"""Test cache persistence and recovery (if enabled)."""
# This test would be more meaningful with actual persistence
# For now, test the configuration and structure
persistence_config = CacheConfig(
max_size=100,
enable_persistence=True,
persistence_path="/tmp/detector_cache_test"
)
cache = SessionCache(persistence_config)
# Add some data
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("main", {"class": "car", "confidence": 0.95})
cache.put("session_123", session_data)
# Verify data exists
assert cache.contains("session_123") is True
# In a real implementation, this would test:
# 1. Saving cache to disk
# 2. Loading cache from disk
# 3. Verifying data integrity after reload
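# --- Editorial sketch (not part of this commit's implementation) -----------
# The comment above lists what a persistence test would cover once the cache
# can actually write to disk.  A round-trip along those lines is sketched
# below; save_to_disk()/load_from_disk() are hypothetical method names that do
# not appear in this diff, so treat the whole helper as an assumption.
def sketch_persistence_round_trip(tmp_path):
    """Save the cache, rebuild it from disk, and check the data survived."""
    config = CacheConfig(
        max_size=100,
        enable_persistence=True,
        persistence_path=str(tmp_path / "session_cache.json"),
    )
    original = SessionCache(config)
    session = SessionData("session_123", "camera_001", "display_001")
    session.add_detection_data("main", {"class": "car", "confidence": 0.95})
    original.put("session_123", session)

    original.save_to_disk()                    # hypothetical API
    restored = SessionCache(config)
    restored.load_from_disk()                  # hypothetical API

    reloaded = restored.get("session_123")
    assert reloaded is not None
    assert reloaded.detection_data["main"]["confidence"] == 0.95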

View file

@ -0,0 +1,818 @@
"""
Unit tests for stream management functionality.
"""
import pytest
import asyncio
import threading
import time
from unittest.mock import Mock, AsyncMock, patch, MagicMock
import numpy as np
import cv2
from detector_worker.streams.stream_manager import (
StreamManager,
StreamInfo,
StreamConfig,
StreamReader,
StreamError,
ConnectionError as StreamConnectionError
)
from detector_worker.streams.frame_reader import FrameReader
from detector_worker.core.exceptions import ConfigurationError
class TestStreamConfig:
"""Test stream configuration."""
def test_creation_rtsp(self):
"""Test creating RTSP stream config."""
config = StreamConfig(
stream_url="rtsp://example.com/stream1",
stream_type="rtsp",
target_fps=15,
reconnect_interval=5.0,
max_retries=3
)
assert config.stream_url == "rtsp://example.com/stream1"
assert config.stream_type == "rtsp"
assert config.target_fps == 15
assert config.reconnect_interval == 5.0
assert config.max_retries == 3
def test_creation_http_snapshot(self):
"""Test creating HTTP snapshot config."""
config = StreamConfig(
stream_url="http://example.com/snapshot.jpg",
stream_type="http_snapshot",
snapshot_interval=1.0,
timeout=10.0
)
assert config.stream_url == "http://example.com/snapshot.jpg"
assert config.stream_type == "http_snapshot"
assert config.snapshot_interval == 1.0
assert config.timeout == 10.0
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"stream_url": "rtsp://camera.example.com/live",
"stream_type": "rtsp",
"target_fps": 20,
"reconnect_interval": 3.0,
"max_retries": 5,
"crop_region": [100, 200, 300, 400],
"unknown_field": "ignored"
}
config = StreamConfig.from_dict(config_dict)
assert config.stream_url == "rtsp://camera.example.com/live"
assert config.target_fps == 20
assert config.crop_region == [100, 200, 300, 400]
def test_validation(self):
"""Test config validation."""
# Valid config
valid_config = StreamConfig(
stream_url="rtsp://example.com/stream",
stream_type="rtsp"
)
assert valid_config.is_valid() is True
# Invalid config (empty URL)
invalid_config = StreamConfig(
stream_url="",
stream_type="rtsp"
)
assert invalid_config.is_valid() is False
class TestStreamInfo:
"""Test stream information."""
def test_creation(self):
"""Test stream info creation."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo(
stream_id="stream_001",
config=config,
camera_id="camera_001"
)
assert info.stream_id == "stream_001"
assert info.config == config
assert info.camera_id == "camera_001"
assert info.status == "inactive"
assert info.reference_count == 0
assert info.created_at <= time.time()
def test_increment_reference(self):
"""Test incrementing reference count."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo("stream_001", config, "camera_001")
assert info.reference_count == 0
info.increment_reference()
assert info.reference_count == 1
info.increment_reference()
assert info.reference_count == 2
def test_decrement_reference(self):
"""Test decrementing reference count."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo("stream_001", config, "camera_001")
info.reference_count = 3
assert info.decrement_reference() == 2
assert info.reference_count == 2
assert info.decrement_reference() == 1
assert info.decrement_reference() == 0
# Should not go below 0
assert info.decrement_reference() == 0
def test_update_status(self):
"""Test updating stream status."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo("stream_001", config, "camera_001")
info.update_status("connecting")
assert info.status == "connecting"
assert info.last_update <= time.time()
info.update_status("active", frame_count=100)
assert info.status == "active"
assert info.frame_count == 100
def test_get_stats(self):
"""Test getting stream statistics."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo("stream_001", config, "camera_001")
info.frame_count = 1000
info.error_count = 5
info.reference_count = 2
stats = info.get_stats()
assert stats["stream_id"] == "stream_001"
assert stats["status"] == "inactive"
assert stats["frame_count"] == 1000
assert stats["error_count"] == 5
assert stats["reference_count"] == 2
assert "uptime" in stats
class TestStreamReader:
"""Test stream reader functionality."""
def test_creation(self):
"""Test stream reader creation."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
assert reader.stream_id == "stream_001"
assert reader.config == config
assert reader.is_running is False
assert reader.latest_frame is None
assert reader.frame_queue.qsize() == 0
@pytest.mark.asyncio
async def test_start_rtsp_stream(self):
"""Test starting RTSP stream."""
config = StreamConfig("rtsp://example.com/stream", "rtsp", target_fps=10)
reader = StreamReader("stream_001", config)
# Mock cv2.VideoCapture
with patch('cv2.VideoCapture') as mock_cap:
mock_cap_instance = Mock()
mock_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, np.zeros((480, 640, 3), dtype=np.uint8))
await reader.start()
assert reader.is_running is True
assert reader.capture is not None
mock_cap.assert_called_once_with("rtsp://example.com/stream")
@pytest.mark.asyncio
async def test_start_rtsp_connection_failure(self):
"""Test RTSP connection failure."""
config = StreamConfig("rtsp://invalid.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
with patch('cv2.VideoCapture') as mock_cap:
mock_cap_instance = Mock()
mock_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = False
with pytest.raises(StreamConnectionError):
await reader.start()
@pytest.mark.asyncio
async def test_start_http_snapshot(self):
"""Test starting HTTP snapshot stream."""
config = StreamConfig("http://example.com/snapshot.jpg", "http_snapshot", snapshot_interval=1.0)
reader = StreamReader("stream_001", config)
with patch('requests.get') as mock_get:
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b"fake_image_data"
mock_get.return_value = mock_response
with patch('cv2.imdecode') as mock_decode:
mock_decode.return_value = np.zeros((480, 640, 3), dtype=np.uint8)
await reader.start()
assert reader.is_running is True
mock_get.assert_called_once()
@pytest.mark.asyncio
async def test_stop_stream(self):
"""Test stopping stream."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
# Simulate running state
reader.is_running = True
reader.capture = Mock()
reader.capture.release = Mock()
reader._reader_task = Mock()
reader._reader_task.cancel = Mock()
await reader.stop()
assert reader.is_running is False
reader.capture.release.assert_called_once()
reader._reader_task.cancel.assert_called_once()
def test_get_latest_frame(self):
"""Test getting latest frame."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
reader.latest_frame = test_frame
frame = reader.get_latest_frame()
assert np.array_equal(frame, test_frame)
def test_get_frame_from_queue(self):
"""Test getting frame from queue."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
reader.frame_queue.put(test_frame)
frame = reader.get_frame(timeout=0.1)
assert np.array_equal(frame, test_frame)
def test_get_frame_timeout(self):
"""Test getting frame with timeout."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
# Queue is empty, should timeout
frame = reader.get_frame(timeout=0.1)
assert frame is None
def test_get_stats(self):
"""Test getting reader statistics."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
reader.frame_count = 500
reader.error_count = 2
stats = reader.get_stats()
assert stats["stream_id"] == "stream_001"
assert stats["frame_count"] == 500
assert stats["error_count"] == 2
assert stats["is_running"] is False
class TestStreamManager:
"""Test stream manager functionality."""
def test_initialization(self):
"""Test stream manager initialization."""
manager = StreamManager()
assert len(manager.streams) == 0
assert len(manager.readers) == 0
assert manager.max_streams == 10
assert manager.default_timeout == 30.0
def test_initialization_with_config(self):
"""Test initialization with custom configuration."""
config = {
"max_streams": 20,
"default_timeout": 60.0,
"frame_buffer_size": 5
}
manager = StreamManager(config)
assert manager.max_streams == 20
assert manager.default_timeout == 60.0
assert manager.frame_buffer_size == 5
@pytest.mark.asyncio
async def test_create_stream_new(self):
"""Test creating new stream."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
stream_info = await manager.create_stream("camera_001", config, "sub_001")
assert "camera_001" in manager.streams
assert manager.streams["camera_001"].reference_count == 1
assert manager.streams["camera_001"].camera_id == "camera_001"
@pytest.mark.asyncio
async def test_create_stream_shared(self):
"""Test creating shared stream (same URL)."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
# Create first stream
stream_info1 = await manager.create_stream("camera_001", config, "sub_001")
# Create second stream with same URL
stream_info2 = await manager.create_stream("camera_001", config, "sub_002")
assert stream_info1 == stream_info2 # Should be same stream
assert manager.streams["camera_001"].reference_count == 2
@pytest.mark.asyncio
async def test_create_stream_max_limit(self):
"""Test creating stream when at max limit."""
manager = StreamManager({"max_streams": 1})
config1 = StreamConfig("rtsp://example.com/stream1", "rtsp")
config2 = StreamConfig("rtsp://example.com/stream2", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
# Create first stream (should succeed)
await manager.create_stream("camera_001", config1, "sub_001")
# Try to create second stream (should fail)
with pytest.raises(StreamError) as exc_info:
await manager.create_stream("camera_002", config2, "sub_002")
assert "maximum number of streams" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_remove_stream_single_reference(self):
"""Test removing stream with single reference."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
with patch.object(StreamReader, 'stop', new_callable=AsyncMock):
# Create stream
await manager.create_stream("camera_001", config, "sub_001")
# Remove stream
removed = await manager.remove_stream("camera_001", "sub_001")
assert removed is True
assert "camera_001" not in manager.streams
@pytest.mark.asyncio
async def test_remove_stream_multiple_references(self):
"""Test removing stream with multiple references."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
# Create shared stream
await manager.create_stream("camera_001", config, "sub_001")
await manager.create_stream("camera_001", config, "sub_002")
assert manager.streams["camera_001"].reference_count == 2
# Remove one reference
removed = await manager.remove_stream("camera_001", "sub_001")
assert removed is True
assert "camera_001" in manager.streams # Still exists
assert manager.streams["camera_001"].reference_count == 1
def test_get_stream_info(self):
"""Test getting stream information."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
stream_info = StreamInfo("camera_001", config, "camera_001")
manager.streams["camera_001"] = stream_info
retrieved_info = manager.get_stream_info("camera_001")
assert retrieved_info == stream_info
def test_get_nonexistent_stream_info(self):
"""Test getting info for non-existent stream."""
manager = StreamManager()
info = manager.get_stream_info("nonexistent_camera")
assert info is None
def test_get_latest_frame(self):
"""Test getting latest frame from stream."""
manager = StreamManager()
# Create mock reader
mock_reader = Mock()
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
mock_reader.get_latest_frame.return_value = test_frame
manager.readers["camera_001"] = mock_reader
frame = manager.get_latest_frame("camera_001")
assert np.array_equal(frame, test_frame)
mock_reader.get_latest_frame.assert_called_once()
def test_get_frame_from_nonexistent_stream(self):
"""Test getting frame from non-existent stream."""
manager = StreamManager()
frame = manager.get_latest_frame("nonexistent_camera")
assert frame is None
def test_list_active_streams(self):
"""Test listing active streams."""
manager = StreamManager()
# Add streams
config1 = StreamConfig("rtsp://example.com/stream1", "rtsp")
config2 = StreamConfig("rtsp://example.com/stream2", "rtsp")
stream1 = StreamInfo("camera_001", config1, "camera_001")
stream1.update_status("active")
stream2 = StreamInfo("camera_002", config2, "camera_002")
stream2.update_status("inactive")
manager.streams["camera_001"] = stream1
manager.streams["camera_002"] = stream2
active_streams = manager.list_active_streams()
assert len(active_streams) == 1
assert active_streams[0]["camera_id"] == "camera_001"
assert active_streams[0]["status"] == "active"
@pytest.mark.asyncio
async def test_stop_all_streams(self):
"""Test stopping all streams."""
manager = StreamManager()
# Add mock streams
mock_reader1 = Mock()
mock_reader1.stop = AsyncMock()
mock_reader2 = Mock()
mock_reader2.stop = AsyncMock()
manager.readers["camera_001"] = mock_reader1
manager.readers["camera_002"] = mock_reader2
stopped_count = await manager.stop_all_streams()
assert stopped_count == 2
mock_reader1.stop.assert_called_once()
mock_reader2.stop.assert_called_once()
assert len(manager.readers) == 0
assert len(manager.streams) == 0
def test_get_stream_statistics(self):
"""Test getting stream statistics."""
manager = StreamManager()
# Add streams
config = StreamConfig("rtsp://example.com/stream", "rtsp")
stream1 = StreamInfo("camera_001", config, "camera_001")
stream1.update_status("active")
stream1.frame_count = 1000
stream1.reference_count = 2
stream2 = StreamInfo("camera_002", config, "camera_002")
stream2.update_status("error")
stream2.error_count = 5
manager.streams["camera_001"] = stream1
manager.streams["camera_002"] = stream2
stats = manager.get_stream_statistics()
assert stats["total_streams"] == 2
assert stats["active_streams"] == 1
assert stats["error_streams"] == 1
assert stats["total_references"] == 2
assert "status_breakdown" in stats
@pytest.mark.asyncio
async def test_reconnect_stream(self):
"""Test reconnecting failed stream."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
stream_info = StreamInfo("camera_001", config, "camera_001")
stream_info.update_status("error")
manager.streams["camera_001"] = stream_info
# Mock reader
mock_reader = Mock()
mock_reader.start = AsyncMock()
mock_reader.stop = AsyncMock()
manager.readers["camera_001"] = mock_reader
result = await manager.reconnect_stream("camera_001")
assert result is True
mock_reader.stop.assert_called_once()
mock_reader.start.assert_called_once()
assert stream_info.status != "error"
@pytest.mark.asyncio
async def test_health_check_streams(self):
"""Test health check of all streams."""
manager = StreamManager()
# Add streams with different states
config = StreamConfig("rtsp://example.com/stream", "rtsp")
stream1 = StreamInfo("camera_001", config, "camera_001")
stream1.update_status("active")
stream2 = StreamInfo("camera_002", config, "camera_002")
stream2.update_status("error")
manager.streams["camera_001"] = stream1
manager.streams["camera_002"] = stream2
# Mock readers
mock_reader1 = Mock()
mock_reader1.is_running = True
mock_reader2 = Mock()
mock_reader2.is_running = False
manager.readers["camera_001"] = mock_reader1
manager.readers["camera_002"] = mock_reader2
health_report = await manager.health_check()
assert health_report["total_streams"] == 2
assert health_report["healthy_streams"] == 1
assert health_report["unhealthy_streams"] == 1
assert len(health_report["unhealthy_stream_ids"]) == 1
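# --- Editorial sketch (not part of this commit's implementation) -----------
# Several tests above (shared streams, multi-reference removal, stop_all)
# exercise one lifecycle rule: a single reader per camera, one reference per
# subscriber, and the reader is only stopped when the last reference is gone.
# The sketch below illustrates just that rule; the reader_factory parameter,
# the dict layout and the error type are assumptions, not StreamManager's code.
class SketchSharedStreamManager:
    def __init__(self, max_streams=10):
        self.max_streams = max_streams
        self.streams = {}      # camera_id -> set of subscription ids
        self.readers = {}      # camera_id -> reader with async start()/stop()

    async def create_stream(self, camera_id, reader_factory, subscription_id):
        if camera_id in self.streams:
            self.streams[camera_id].add(subscription_id)   # share the existing reader
            return self.readers[camera_id]
        if len(self.streams) >= self.max_streams:
            raise RuntimeError("maximum number of streams reached")
        reader = reader_factory()
        await reader.start()
        self.readers[camera_id] = reader
        self.streams[camera_id] = {subscription_id}
        return reader

    async def remove_stream(self, camera_id, subscription_id):
        refs = self.streams.get(camera_id)
        if refs is None:
            return False
        refs.discard(subscription_id)
        if not refs:                                       # last subscriber gone
            await self.readers.pop(camera_id).stop()
            del self.streams[camera_id]
        return True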
class TestStreamManagerIntegration:
"""Integration tests for stream manager."""
@pytest.mark.asyncio
async def test_multiple_subscribers_same_stream(self):
"""Test multiple subscribers to same stream."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/shared_stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
# Multiple subscribers to same stream
stream1 = await manager.create_stream("camera_001", config, "sub_001")
stream2 = await manager.create_stream("camera_001", config, "sub_002")
stream3 = await manager.create_stream("camera_001", config, "sub_003")
# All should reference same stream
assert stream1 == stream2 == stream3
assert manager.streams["camera_001"].reference_count == 3
assert len(manager.readers) == 1 # Only one actual reader
# Remove subscribers one by one
with patch.object(StreamReader, 'stop', new_callable=AsyncMock) as mock_stop:
await manager.remove_stream("camera_001", "sub_001") # ref_count = 2
await manager.remove_stream("camera_001", "sub_002") # ref_count = 1
# Stream should still exist
assert "camera_001" in manager.streams
mock_stop.assert_not_called()
await manager.remove_stream("camera_001", "sub_003") # ref_count = 0
# Now stream should be stopped and removed
assert "camera_001" not in manager.streams
mock_stop.assert_called_once()
@pytest.mark.asyncio
async def test_stream_failure_and_recovery(self):
"""Test stream failure and recovery workflow."""
manager = StreamManager()
config = StreamConfig("rtsp://unreliable.com/stream", "rtsp", max_retries=2)
# Mock reader that fails initially then succeeds
with patch.object(StreamReader, 'start', new_callable=AsyncMock) as mock_start:
mock_start.side_effect = [
StreamConnectionError("Connection failed"), # First attempt fails
None # Second attempt succeeds
]
# First attempt should fail
with pytest.raises(StreamConnectionError):
await manager.create_stream("camera_001", config, "sub_001")
# Retry should succeed
stream_info = await manager.create_stream("camera_001", config, "sub_001")
assert stream_info is not None
assert mock_start.call_count == 2
@pytest.mark.asyncio
async def test_concurrent_stream_operations(self):
"""Test concurrent stream operations."""
manager = StreamManager()
configs = [
StreamConfig(f"rtsp://example.com/stream{i}", "rtsp")
for i in range(5)
]
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
with patch.object(StreamReader, 'stop', new_callable=AsyncMock):
# Create streams concurrently
create_tasks = [
manager.create_stream(f"camera_{i}", configs[i], f"sub_{i}")
for i in range(5)
]
results = await asyncio.gather(*create_tasks)
assert len(results) == 5
assert len(manager.streams) == 5
# Remove streams concurrently
remove_tasks = [
manager.remove_stream(f"camera_{i}", f"sub_{i}")
for i in range(5)
]
remove_results = await asyncio.gather(*remove_tasks)
assert all(remove_results)
assert len(manager.streams) == 0
@pytest.mark.asyncio
async def test_memory_management_large_scale(self):
"""Test memory management with many streams."""
manager = StreamManager({"max_streams": 50})
# Create many streams
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
for i in range(30):
config = StreamConfig(f"rtsp://example.com/stream{i}", "rtsp")
await manager.create_stream(f"camera_{i}", config, f"sub_{i}")
# Verify memory usage is reasonable
stats = manager.get_stream_statistics()
assert stats["total_streams"] == 30
assert stats["active_streams"] <= 30
# Test bulk cleanup
with patch.object(StreamReader, 'stop', new_callable=AsyncMock):
stopped_count = await manager.stop_all_streams()
assert stopped_count == 30
assert len(manager.streams) == 0
assert len(manager.readers) == 0
class TestFrameReaderIntegration:
"""Integration tests for frame reader."""
@pytest.mark.asyncio
async def test_rtsp_frame_processing(self):
"""Test RTSP frame processing pipeline."""
config = StreamConfig(
stream_url="rtsp://example.com/stream",
stream_type="rtsp",
target_fps=10,
crop_region=[100, 100, 400, 300]
)
reader = StreamReader("test_stream", config)
# Mock cv2.VideoCapture
with patch('cv2.VideoCapture') as mock_cap:
mock_cap_instance = Mock()
mock_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
# Mock frame sequence
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
mock_cap_instance.read.side_effect = [
(True, test_frame), # First frame
(True, test_frame * 0.8), # Second frame
(False, None), # Connection lost
(True, test_frame * 1.2), # Reconnected
]
await reader.start()
# Let reader process some frames
await asyncio.sleep(0.1)
# Verify frame processing
latest_frame = reader.get_latest_frame()
assert latest_frame is not None
assert latest_frame.shape == (480, 640, 3)
await reader.stop()
@pytest.mark.asyncio
async def test_http_snapshot_processing(self):
"""Test HTTP snapshot processing."""
config = StreamConfig(
stream_url="http://camera.example.com/snapshot.jpg",
stream_type="http_snapshot",
snapshot_interval=0.5,
timeout=5.0
)
reader = StreamReader("snapshot_stream", config)
with patch('requests.get') as mock_get:
# Mock HTTP responses
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b"fake_jpeg_data"
mock_get.return_value = mock_response
with patch('cv2.imdecode') as mock_decode:
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 200
mock_decode.return_value = test_frame
await reader.start()
# Wait for snapshot capture
await asyncio.sleep(0.6)
# Verify snapshot processing
latest_frame = reader.get_latest_frame()
assert latest_frame is not None
assert np.array_equal(latest_frame, test_frame)
await reader.stop()
def test_frame_queue_management(self):
"""Test frame queue management and buffering."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("queue_test", config, frame_buffer_size=3)
# Add frames to queue
frames = [
np.ones((100, 100, 3), dtype=np.uint8) * i
for i in range(50, 250, 50) # 4 different frames
]
for frame in frames[:3]: # Fill buffer
reader._add_frame_to_queue(frame)
assert reader.frame_queue.qsize() == 3
# Add one more (should drop oldest)
reader._add_frame_to_queue(frames[3])
assert reader.frame_queue.qsize() == 3
# Verify frame order (oldest should be dropped)
retrieved_frames = []
while not reader.frame_queue.empty():
retrieved_frames.append(reader.get_frame(timeout=0.1))
assert len(retrieved_frames) == 3
# The oldest frame (frames[0]) should have been dropped, leaving frames[1:4]
assert not np.array_equal(retrieved_frames[0], frames[0])
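# --- Editorial sketch (not part of this commit's implementation) -----------
# test_frame_queue_management above expects a bounded buffer that drops the
# oldest frame when full.  With a standard queue.Queue that is usually done by
# popping one item before the non-blocking put; the helper below shows that
# pattern and is an assumption about _add_frame_to_queue, which is not in
# this diff.
import queue


def sketch_add_frame_to_queue(frame_queue, frame):
    """Keep only the newest frames: discard the oldest entry once the buffer is full."""
    if frame_queue.full():
        try:
            frame_queue.get_nowait()           # drop the oldest frame
        except queue.Empty:
            pass                               # a consumer emptied the queue first
    frame_queue.put_nowait(frame)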