Refactor: PHASE 8: Testing & Integration

This commit is contained in:
ziesorx 2025-09-12 18:55:23 +07:00
parent af34f4fd08
commit 9e8c6804a7
32 changed files with 17128 additions and 0 deletions

View file

@ -0,0 +1,856 @@
"""
Unit tests for WebSocket handling functionality.
"""
import pytest
import asyncio
import json
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from fastapi.websockets import WebSocket, WebSocketDisconnect
import uuid
from detector_worker.communication.websocket_handler import (
WebSocketHandler,
ConnectionManager,
WebSocketConnection,
MessageHandler,
WebSocketError,
ConnectionError as WSConnectionError
)
from detector_worker.communication.message_processor import MessageType
from detector_worker.core.exceptions import MessageProcessingError
class TestWebSocketConnection:
"""Test WebSocket connection wrapper."""
def test_creation(self, mock_websocket):
"""Test WebSocket connection creation."""
connection = WebSocketConnection(mock_websocket, "client_001")
assert connection.websocket == mock_websocket
assert connection.client_id == "client_001"
assert connection.is_connected is False
assert connection.connected_at is None
assert connection.last_ping is None
assert connection.subscription_id is None
@pytest.mark.asyncio
async def test_accept_connection(self, mock_websocket):
"""Test accepting WebSocket connection."""
connection = WebSocketConnection(mock_websocket, "client_001")
mock_websocket.accept = AsyncMock()
await connection.accept()
assert connection.is_connected is True
assert connection.connected_at is not None
mock_websocket.accept.assert_called_once()
@pytest.mark.asyncio
async def test_send_message_json(self, mock_websocket):
"""Test sending JSON message."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.send_json = AsyncMock()
message = {"type": "test", "data": "hello"}
await connection.send_message(message)
mock_websocket.send_json.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_send_message_text(self, mock_websocket):
"""Test sending text message."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.send_text = AsyncMock()
await connection.send_message("hello world")
mock_websocket.send_text.assert_called_once_with("hello world")
@pytest.mark.asyncio
async def test_send_message_not_connected(self, mock_websocket):
"""Test sending message when not connected."""
connection = WebSocketConnection(mock_websocket, "client_001")
# Don't set is_connected = True
with pytest.raises(WebSocketError) as exc_info:
await connection.send_message({"type": "test"})
assert "not connected" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_receive_message_json(self, mock_websocket):
"""Test receiving JSON message."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.receive_json = AsyncMock(return_value={"type": "test", "data": "received"})
message = await connection.receive_message()
assert message == {"type": "test", "data": "received"}
mock_websocket.receive_json.assert_called_once()
@pytest.mark.asyncio
async def test_receive_message_text(self, mock_websocket):
"""Test receiving text message."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
# Mock receive_json to fail, then receive_text to succeed
mock_websocket.receive_json = AsyncMock(side_effect=json.JSONDecodeError("Invalid JSON", "", 0))
mock_websocket.receive_text = AsyncMock(return_value="plain text message")
message = await connection.receive_message()
assert message == "plain text message"
mock_websocket.receive_text.assert_called_once()
@pytest.mark.asyncio
async def test_ping_pong(self, mock_websocket):
"""Test ping/pong functionality."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.ping = AsyncMock()
await connection.ping()
assert connection.last_ping is not None
mock_websocket.ping.assert_called_once()
@pytest.mark.asyncio
async def test_close_connection(self, mock_websocket):
"""Test closing connection."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.close = AsyncMock()
await connection.close(code=1000, reason="Normal closure")
assert connection.is_connected is False
mock_websocket.close.assert_called_once_with(code=1000, reason="Normal closure")
def test_connection_info(self, mock_websocket):
"""Test getting connection information."""
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
connection.subscription_id = "sub_123"
info = connection.get_connection_info()
assert info["client_id"] == "client_001"
assert info["is_connected"] is True
assert info["subscription_id"] == "sub_123"
assert "connected_at" in info
assert "last_ping" in info
class TestConnectionManager:
"""Test WebSocket connection management."""
def test_initialization(self):
"""Test connection manager initialization."""
manager = ConnectionManager()
assert len(manager.connections) == 0
assert len(manager.subscriptions) == 0
assert manager.max_connections == 100
@pytest.mark.asyncio
async def test_add_connection(self, mock_websocket):
"""Test adding a connection."""
manager = ConnectionManager()
client_id = "client_001"
connection = await manager.add_connection(mock_websocket, client_id)
assert connection.client_id == client_id
assert client_id in manager.connections
assert manager.get_connection_count() == 1
@pytest.mark.asyncio
async def test_remove_connection(self, mock_websocket):
"""Test removing a connection."""
manager = ConnectionManager()
client_id = "client_001"
await manager.add_connection(mock_websocket, client_id)
assert client_id in manager.connections
removed_connection = await manager.remove_connection(client_id)
assert removed_connection is not None
assert removed_connection.client_id == client_id
assert client_id not in manager.connections
assert manager.get_connection_count() == 0
def test_get_connection(self, mock_websocket):
"""Test getting a connection."""
manager = ConnectionManager()
client_id = "client_001"
# Manually add connection for testing
connection = WebSocketConnection(mock_websocket, client_id)
manager.connections[client_id] = connection
retrieved_connection = manager.get_connection(client_id)
assert retrieved_connection == connection
assert retrieved_connection.client_id == client_id
def test_get_nonexistent_connection(self):
"""Test getting non-existent connection."""
manager = ConnectionManager()
connection = manager.get_connection("nonexistent_client")
assert connection is None
@pytest.mark.asyncio
async def test_broadcast_message(self, mock_websocket):
"""Test broadcasting message to all connections."""
manager = ConnectionManager()
# Add multiple connections
connections = []
for i in range(3):
client_id = f"client_{i}"
ws = Mock()
ws.send_json = AsyncMock()
connection = WebSocketConnection(ws, client_id)
connection.is_connected = True
manager.connections[client_id] = connection
connections.append(connection)
message = {"type": "broadcast", "data": "hello all"}
await manager.broadcast(message)
# All connections should have received the message
for connection in connections:
connection.websocket.send_json.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_broadcast_to_subscription(self, mock_websocket):
"""Test broadcasting to specific subscription."""
manager = ConnectionManager()
# Add connections with different subscriptions
subscription_id = "camera_001"
# Connection with target subscription
ws1 = Mock()
ws1.send_json = AsyncMock()
connection1 = WebSocketConnection(ws1, "client_001")
connection1.is_connected = True
connection1.subscription_id = subscription_id
manager.connections["client_001"] = connection1
manager.subscriptions[subscription_id] = {"client_001"}
# Connection with different subscription
ws2 = Mock()
ws2.send_json = AsyncMock()
connection2 = WebSocketConnection(ws2, "client_002")
connection2.is_connected = True
connection2.subscription_id = "camera_002"
manager.connections["client_002"] = connection2
manager.subscriptions["camera_002"] = {"client_002"}
message = {"type": "detection", "data": "camera detection"}
await manager.broadcast_to_subscription(subscription_id, message)
# Only connection1 should have received the message
ws1.send_json.assert_called_once_with(message)
ws2.send_json.assert_not_called()
def test_add_subscription(self):
"""Test adding subscription mapping."""
manager = ConnectionManager()
client_id = "client_001"
subscription_id = "camera_001"
manager.add_subscription(client_id, subscription_id)
assert subscription_id in manager.subscriptions
assert client_id in manager.subscriptions[subscription_id]
def test_remove_subscription(self):
"""Test removing subscription mapping."""
manager = ConnectionManager()
client_id = "client_001"
subscription_id = "camera_001"
# Add subscription first
manager.add_subscription(client_id, subscription_id)
assert client_id in manager.subscriptions[subscription_id]
# Remove subscription
manager.remove_subscription(client_id, subscription_id)
assert client_id not in manager.subscriptions.get(subscription_id, set())
def test_get_subscription_clients(self):
"""Test getting clients for a subscription."""
manager = ConnectionManager()
subscription_id = "camera_001"
clients = ["client_001", "client_002", "client_003"]
for client_id in clients:
manager.add_subscription(client_id, subscription_id)
subscription_clients = manager.get_subscription_clients(subscription_id)
assert subscription_clients == set(clients)
def test_get_client_subscriptions(self):
"""Test getting subscriptions for a client."""
manager = ConnectionManager()
client_id = "client_001"
subscriptions = ["camera_001", "camera_002", "camera_003"]
for subscription_id in subscriptions:
manager.add_subscription(client_id, subscription_id)
client_subscriptions = manager.get_client_subscriptions(client_id)
assert client_subscriptions == set(subscriptions)
@pytest.mark.asyncio
async def test_cleanup_disconnected_connections(self):
"""Test cleanup of disconnected connections."""
manager = ConnectionManager()
# Add connected and disconnected connections
ws1 = Mock()
connection1 = WebSocketConnection(ws1, "client_001")
connection1.is_connected = True
manager.connections["client_001"] = connection1
ws2 = Mock()
connection2 = WebSocketConnection(ws2, "client_002")
connection2.is_connected = False # Disconnected
manager.connections["client_002"] = connection2
# Add subscriptions
manager.add_subscription("client_001", "camera_001")
manager.add_subscription("client_002", "camera_002")
cleaned_count = await manager.cleanup_disconnected()
assert cleaned_count == 1
assert "client_001" in manager.connections # Still connected
assert "client_002" not in manager.connections # Cleaned up
# Subscriptions should also be cleaned up
assert manager.get_client_subscriptions("client_002") == set()
def test_get_connection_stats(self):
"""Test getting connection statistics."""
manager = ConnectionManager()
# Add various connections and subscriptions
for i in range(3):
client_id = f"client_{i}"
ws = Mock()
connection = WebSocketConnection(ws, client_id)
connection.is_connected = i < 2 # First 2 connected, last one disconnected
manager.connections[client_id] = connection
if i < 2: # Add subscriptions for connected clients
manager.add_subscription(client_id, f"camera_{i}")
stats = manager.get_connection_stats()
assert stats["total_connections"] == 3
assert stats["active_connections"] == 2
assert stats["total_subscriptions"] == 2
assert "uptime" in stats
class TestMessageHandler:
"""Test message handling functionality."""
def test_creation(self):
"""Test message handler creation."""
mock_processor = Mock()
handler = MessageHandler(mock_processor)
assert handler.message_processor == mock_processor
assert handler.connection_manager is None
def test_set_connection_manager(self):
"""Test setting connection manager."""
mock_processor = Mock()
mock_manager = Mock()
handler = MessageHandler(mock_processor)
handler.set_connection_manager(mock_manager)
assert handler.connection_manager == mock_manager
@pytest.mark.asyncio
async def test_handle_message_success(self, mock_websocket):
"""Test successful message handling."""
mock_processor = Mock()
mock_processor.process_message = AsyncMock(return_value={"type": "response", "status": "success"})
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
message = {"type": "subscribe", "payload": {"camera_id": "camera_001"}}
response = await handler.handle_message(connection, message)
assert response["status"] == "success"
mock_processor.process_message.assert_called_once_with(message, "client_001")
@pytest.mark.asyncio
async def test_handle_message_processing_error(self, mock_websocket):
"""Test message handling with processing error."""
mock_processor = Mock()
mock_processor.process_message = AsyncMock(side_effect=MessageProcessingError("Invalid message"))
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
message = {"type": "invalid", "payload": {}}
response = await handler.handle_message(connection, message)
assert response["type"] == "error"
assert "Invalid message" in response["message"]
@pytest.mark.asyncio
async def test_handle_message_unexpected_error(self, mock_websocket):
"""Test message handling with unexpected error."""
mock_processor = Mock()
mock_processor.process_message = AsyncMock(side_effect=Exception("Unexpected error"))
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
message = {"type": "test", "payload": {}}
response = await handler.handle_message(connection, message)
assert response["type"] == "error"
assert "internal error" in response["message"].lower()
@pytest.mark.asyncio
async def test_send_response(self, mock_websocket):
"""Test sending response to client."""
mock_processor = Mock()
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.send_json = AsyncMock()
response = {"type": "response", "data": "test response"}
await handler.send_response(connection, response)
mock_websocket.send_json.assert_called_once_with(response)
@pytest.mark.asyncio
async def test_send_error_response(self, mock_websocket):
"""Test sending error response."""
mock_processor = Mock()
handler = MessageHandler(mock_processor)
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
mock_websocket.send_json = AsyncMock()
await handler.send_error_response(connection, "Test error message", "TEST_ERROR")
mock_websocket.send_json.assert_called_once()
call_args = mock_websocket.send_json.call_args[0][0]
assert call_args["type"] == "error"
assert call_args["message"] == "Test error message"
assert call_args["error_code"] == "TEST_ERROR"
class TestWebSocketHandler:
"""Test main WebSocket handler functionality."""
def test_initialization(self):
"""Test WebSocket handler initialization."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
assert isinstance(handler.connection_manager, ConnectionManager)
assert isinstance(handler.message_handler, MessageHandler)
assert handler.message_handler.connection_manager == handler.connection_manager
assert handler.heartbeat_interval == 30.0
assert handler.max_connections == 100
def test_initialization_with_config(self):
"""Test initialization with custom configuration."""
mock_processor = Mock()
config = {
"heartbeat_interval": 60.0,
"max_connections": 200,
"connection_timeout": 300.0
}
handler = WebSocketHandler(mock_processor, config)
assert handler.heartbeat_interval == 60.0
assert handler.max_connections == 200
assert handler.connection_timeout == 300.0
@pytest.mark.asyncio
async def test_handle_websocket_connection(self, mock_websocket):
"""Test handling WebSocket connection."""
mock_processor = Mock()
mock_processor.process_message = AsyncMock(return_value={"type": "ack", "status": "success"})
handler = WebSocketHandler(mock_processor)
# Mock WebSocket behavior
mock_websocket.accept = AsyncMock()
mock_websocket.receive_json = AsyncMock(side_effect=[
{"type": "subscribe", "payload": {"camera_id": "camera_001"}},
WebSocketDisconnect() # Simulate disconnection
])
mock_websocket.send_json = AsyncMock()
client_id = "test_client_001"
# Handle connection (should not raise exception)
await handler.handle_websocket(mock_websocket, client_id)
# Verify connection was accepted
mock_websocket.accept.assert_called_once()
# Verify message was processed
mock_processor.process_message.assert_called_once()
@pytest.mark.asyncio
async def test_handle_websocket_max_connections(self, mock_websocket):
"""Test handling max connections limit."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor, {"max_connections": 1})
# Add one connection to reach limit
client1_ws = Mock()
connection1 = WebSocketConnection(client1_ws, "client_001")
handler.connection_manager.connections["client_001"] = connection1
mock_websocket.close = AsyncMock()
# Try to add second connection
await handler.handle_websocket(mock_websocket, "client_002")
# Should close connection due to limit
mock_websocket.close.assert_called_once()
@pytest.mark.asyncio
async def test_broadcast_message(self):
"""Test broadcasting message to all connections."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Mock connection manager
handler.connection_manager.broadcast = AsyncMock()
message = {"type": "system", "data": "Server maintenance in 10 minutes"}
await handler.broadcast_message(message)
handler.connection_manager.broadcast.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_send_to_client(self):
"""Test sending message to specific client."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Create mock connection
mock_websocket = Mock()
mock_websocket.send_json = AsyncMock()
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
handler.connection_manager.connections["client_001"] = connection
message = {"type": "notification", "data": "Personal message"}
result = await handler.send_to_client("client_001", message)
assert result is True
mock_websocket.send_json.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_send_to_nonexistent_client(self):
"""Test sending message to non-existent client."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
message = {"type": "notification", "data": "Message"}
result = await handler.send_to_client("nonexistent_client", message)
assert result is False
@pytest.mark.asyncio
async def test_send_to_subscription(self):
"""Test sending message to subscription."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Mock connection manager
handler.connection_manager.broadcast_to_subscription = AsyncMock()
subscription_id = "camera_001"
message = {"type": "detection", "data": {"class": "car", "confidence": 0.95}}
await handler.send_to_subscription(subscription_id, message)
handler.connection_manager.broadcast_to_subscription.assert_called_once_with(subscription_id, message)
@pytest.mark.asyncio
async def test_start_heartbeat_task(self):
"""Test starting heartbeat task."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor, {"heartbeat_interval": 0.1})
# Mock connection with ping capability
mock_websocket = Mock()
mock_websocket.ping = AsyncMock()
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
handler.connection_manager.connections["client_001"] = connection
# Start heartbeat task
heartbeat_task = asyncio.create_task(handler._heartbeat_loop())
# Let it run briefly
await asyncio.sleep(0.2)
# Cancel task
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
# Should have sent at least one ping
assert mock_websocket.ping.called
def test_get_connection_stats(self):
"""Test getting connection statistics."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Add some mock connections
for i in range(3):
client_id = f"client_{i}"
ws = Mock()
connection = WebSocketConnection(ws, client_id)
connection.is_connected = True
handler.connection_manager.connections[client_id] = connection
stats = handler.get_connection_stats()
assert stats["total_connections"] == 3
assert stats["active_connections"] == 3
def test_get_client_info(self):
"""Test getting client information."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Add mock connection
mock_websocket = Mock()
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
connection.subscription_id = "camera_001"
handler.connection_manager.connections["client_001"] = connection
info = handler.get_client_info("client_001")
assert info is not None
assert info["client_id"] == "client_001"
assert info["is_connected"] is True
assert info["subscription_id"] == "camera_001"
@pytest.mark.asyncio
async def test_disconnect_client(self):
"""Test disconnecting specific client."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Add mock connection
mock_websocket = Mock()
mock_websocket.close = AsyncMock()
connection = WebSocketConnection(mock_websocket, "client_001")
connection.is_connected = True
handler.connection_manager.connections["client_001"] = connection
result = await handler.disconnect_client("client_001", code=1000, reason="Admin disconnect")
assert result is True
mock_websocket.close.assert_called_once_with(code=1000, reason="Admin disconnect")
@pytest.mark.asyncio
async def test_cleanup_connections(self):
"""Test cleanup of disconnected connections."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
# Mock connection manager cleanup
handler.connection_manager.cleanup_disconnected = AsyncMock(return_value=2)
cleaned_count = await handler.cleanup_connections()
assert cleaned_count == 2
handler.connection_manager.cleanup_disconnected.assert_called_once()
class TestWebSocketHandlerIntegration:
"""Integration tests for WebSocket handler."""
@pytest.mark.asyncio
async def test_complete_subscription_workflow(self, mock_websocket):
"""Test complete subscription workflow."""
mock_processor = Mock()
# Mock processor responses
mock_processor.process_message = AsyncMock(side_effect=[
{"type": "subscribeAck", "status": "success", "subscription_id": "camera_001"},
{"type": "unsubscribeAck", "status": "success"}
])
handler = WebSocketHandler(mock_processor)
# Mock WebSocket behavior
mock_websocket.accept = AsyncMock()
mock_websocket.send_json = AsyncMock()
mock_websocket.receive_json = AsyncMock(side_effect=[
{"type": "subscribe", "payload": {"camera_id": "camera_001", "rtsp_url": "rtsp://example.com"}},
{"type": "unsubscribe", "payload": {"subscription_id": "camera_001"}},
WebSocketDisconnect()
])
client_id = "test_client"
# Handle complete workflow
await handler.handle_websocket(mock_websocket, client_id)
# Verify both messages were processed
assert mock_processor.process_message.call_count == 2
# Verify responses were sent
assert mock_websocket.send_json.call_count == 2
@pytest.mark.asyncio
async def test_multiple_client_management(self):
"""Test managing multiple concurrent clients."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor)
clients = []
for i in range(5):
client_id = f"client_{i}"
mock_ws = Mock()
mock_ws.send_json = AsyncMock()
connection = WebSocketConnection(mock_ws, client_id)
connection.is_connected = True
handler.connection_manager.connections[client_id] = connection
clients.append(connection)
# Test broadcasting to all clients
message = {"type": "broadcast", "data": "Hello all clients"}
await handler.broadcast_message(message)
# All clients should receive the message
for connection in clients:
connection.websocket.send_json.assert_called_once_with(message)
# Test subscription-specific messaging
subscription_id = "camera_001"
handler.connection_manager.add_subscription("client_0", subscription_id)
handler.connection_manager.add_subscription("client_2", subscription_id)
subscription_message = {"type": "detection", "camera_id": "camera_001"}
await handler.send_to_subscription(subscription_id, subscription_message)
# Only subscribed clients should receive the message
# Note: This would require additional mocking of broadcast_to_subscription
@pytest.mark.asyncio
async def test_error_handling_and_recovery(self, mock_websocket):
"""Test error handling and recovery scenarios."""
mock_processor = Mock()
# First message causes error, second succeeds
mock_processor.process_message = AsyncMock(side_effect=[
MessageProcessingError("Invalid message format"),
{"type": "ack", "status": "success"}
])
handler = WebSocketHandler(mock_processor)
mock_websocket.accept = AsyncMock()
mock_websocket.send_json = AsyncMock()
mock_websocket.receive_json = AsyncMock(side_effect=[
{"type": "invalid", "malformed": True},
{"type": "valid", "payload": {"test": True}},
WebSocketDisconnect()
])
client_id = "error_test_client"
# Should handle errors gracefully and continue processing
await handler.handle_websocket(mock_websocket, client_id)
# Both messages should have been processed
assert mock_processor.process_message.call_count == 2
# Should have sent error response and success response
assert mock_websocket.send_json.call_count == 2
# First call should be error response
first_response = mock_websocket.send_json.call_args_list[0][0][0]
assert first_response["type"] == "error"
@pytest.mark.asyncio
async def test_connection_timeout_handling(self):
"""Test connection timeout handling."""
mock_processor = Mock()
handler = WebSocketHandler(mock_processor, {"connection_timeout": 0.1})
# Add connection that hasn't been active
mock_websocket = Mock()
connection = WebSocketConnection(mock_websocket, "timeout_client")
connection.is_connected = True
# Don't update last_ping to simulate timeout
handler.connection_manager.connections["timeout_client"] = connection
# Wait longer than timeout
await asyncio.sleep(0.2)
# Manual cleanup (in real implementation this would be automatic)
cleaned = await handler.cleanup_connections()
# Connection should be identified for cleanup
# (Actual timeout logic would need to be implemented in the cleanup method)

View file

@ -0,0 +1,429 @@
"""
Unit tests for configuration management system.
"""
import pytest
import json
import os
import tempfile
from unittest.mock import Mock, patch, MagicMock
from detector_worker.core.config import (
ConfigurationManager,
JsonFileProvider,
EnvironmentProvider,
DatabaseConfig,
RedisConfig,
StreamConfig,
ModelConfig,
LoggingConfig,
get_config_manager,
validate_config
)
from detector_worker.core.exceptions import ConfigurationError
class TestJsonFileProvider:
"""Test JSON file configuration provider."""
def test_get_config_from_valid_file(self, temp_dir):
"""Test loading configuration from a valid JSON file."""
config_data = {"test_key": "test_value", "number": 42}
config_file = os.path.join(temp_dir, "config.json")
with open(config_file, 'w') as f:
json.dump(config_data, f)
provider = JsonFileProvider(config_file)
result = provider.get_config()
assert result == config_data
def test_get_config_file_not_exists(self, temp_dir):
"""Test handling of non-existent config file."""
config_file = os.path.join(temp_dir, "nonexistent.json")
provider = JsonFileProvider(config_file)
result = provider.get_config()
assert result == {}
def test_get_config_invalid_json(self, temp_dir):
"""Test handling of invalid JSON file."""
config_file = os.path.join(temp_dir, "invalid.json")
with open(config_file, 'w') as f:
f.write("invalid json content")
provider = JsonFileProvider(config_file)
result = provider.get_config()
assert result == {}
def test_reload_updates_config(self, temp_dir):
"""Test that reload updates configuration."""
config_file = os.path.join(temp_dir, "config.json")
# Initial config
initial_config = {"version": 1}
with open(config_file, 'w') as f:
json.dump(initial_config, f)
provider = JsonFileProvider(config_file)
assert provider.get_config() == initial_config
# Update config file
updated_config = {"version": 2}
with open(config_file, 'w') as f:
json.dump(updated_config, f)
# Force reload
provider.reload()
assert provider.get_config() == updated_config
class TestEnvironmentProvider:
"""Test environment variable configuration provider."""
def test_get_config_with_env_vars(self):
"""Test loading configuration from environment variables."""
env_vars = {
"DETECTOR_MAX_STREAMS": "10",
"DETECTOR_TARGET_FPS": "15",
"DETECTOR_CONFIG": '{"nested": "value"}',
"OTHER_VAR": "ignored"
}
with patch.dict(os.environ, env_vars, clear=False):
provider = EnvironmentProvider("DETECTOR_")
config = provider.get_config()
assert config["max_streams"] == "10"
assert config["target_fps"] == "15"
assert config["config"] == {"nested": "value"}
assert "other_var" not in config
def test_get_config_no_env_vars(self):
"""Test with no matching environment variables."""
with patch.dict(os.environ, {}, clear=True):
provider = EnvironmentProvider("DETECTOR_")
config = provider.get_config()
assert config == {}
def test_custom_prefix(self):
"""Test with custom prefix."""
env_vars = {"CUSTOM_TEST": "value"}
with patch.dict(os.environ, env_vars, clear=False):
provider = EnvironmentProvider("CUSTOM_")
config = provider.get_config()
assert config["test"] == "value"
class TestConfigDataclasses:
"""Test configuration dataclasses."""
def test_database_config_from_dict(self):
"""Test DatabaseConfig creation from dictionary."""
data = {
"enabled": True,
"host": "db.example.com",
"port": 5432,
"database": "testdb",
"username": "user",
"password": "pass",
"schema": "test_schema",
"unknown_field": "ignored"
}
config = DatabaseConfig.from_dict(data)
assert config.enabled is True
assert config.host == "db.example.com"
assert config.port == 5432
assert config.database == "testdb"
assert config.username == "user"
assert config.password == "pass"
assert config.schema == "test_schema"
# Unknown fields should be ignored
assert not hasattr(config, 'unknown_field')
def test_redis_config_from_dict(self):
"""Test RedisConfig creation from dictionary."""
data = {
"enabled": True,
"host": "redis.example.com",
"port": 6379,
"password": "secret",
"db": 1
}
config = RedisConfig.from_dict(data)
assert config.enabled is True
assert config.host == "redis.example.com"
assert config.port == 6379
assert config.password == "secret"
assert config.db == 1
def test_stream_config_from_dict(self):
"""Test StreamConfig creation from dictionary."""
data = {
"poll_interval_ms": 50,
"max_streams": 10,
"target_fps": 20,
"reconnect_interval_sec": 10,
"max_retries": 5
}
config = StreamConfig.from_dict(data)
assert config.poll_interval_ms == 50
assert config.max_streams == 10
assert config.target_fps == 20
assert config.reconnect_interval_sec == 10
assert config.max_retries == 5
class TestConfigurationManager:
"""Test main configuration manager."""
def test_initialization_with_defaults(self):
"""Test that manager initializes with default values."""
manager = ConfigurationManager()
# Should have default providers
assert len(manager._providers) >= 1
# Should have default configuration values
config = manager.get_all()
assert "poll_interval_ms" in config
assert "max_streams" in config
assert "target_fps" in config
def test_add_provider(self):
"""Test adding configuration providers."""
manager = ConfigurationManager()
initial_count = len(manager._providers)
mock_provider = Mock()
mock_provider.get_config.return_value = {"test": "value"}
manager.add_provider(mock_provider)
assert len(manager._providers) == initial_count + 1
def test_get_configuration_value(self):
"""Test getting specific configuration values."""
manager = ConfigurationManager()
# Test existing key
value = manager.get("poll_interval_ms")
assert value is not None
# Test non-existing key with default
value = manager.get("nonexistent", "default")
assert value == "default"
# Test non-existing key without default
value = manager.get("nonexistent")
assert value is None
def test_get_section(self):
"""Test getting configuration sections."""
manager = ConfigurationManager()
# Test existing section
db_section = manager.get_section("database")
assert isinstance(db_section, dict)
# Test non-existing section
empty_section = manager.get_section("nonexistent")
assert empty_section == {}
def test_typed_config_access(self):
"""Test typed configuration object access."""
manager = ConfigurationManager()
# Test database config
db_config = manager.get_database_config()
assert isinstance(db_config, DatabaseConfig)
# Test Redis config
redis_config = manager.get_redis_config()
assert isinstance(redis_config, RedisConfig)
# Test stream config
stream_config = manager.get_stream_config()
assert isinstance(stream_config, StreamConfig)
# Test model config
model_config = manager.get_model_config()
assert isinstance(model_config, ModelConfig)
# Test logging config
logging_config = manager.get_logging_config()
assert isinstance(logging_config, LoggingConfig)
def test_set_configuration_value(self):
"""Test setting configuration values at runtime."""
manager = ConfigurationManager()
manager.set("test_key", "test_value")
assert manager.get("test_key") == "test_value"
# Should also update typed configs
manager.set("poll_interval_ms", 200)
stream_config = manager.get_stream_config()
assert stream_config.poll_interval_ms == 200
def test_validation_success(self):
"""Test configuration validation with valid config."""
manager = ConfigurationManager()
# Set valid configuration
manager.set("poll_interval_ms", 100)
manager.set("max_streams", 5)
manager.set("target_fps", 10)
errors = manager.validate()
assert errors == []
assert manager.is_valid() is True
def test_validation_errors(self):
"""Test configuration validation with invalid values."""
manager = ConfigurationManager()
# Set invalid configuration
manager.set("poll_interval_ms", 0)
manager.set("max_streams", -1)
manager.set("target_fps", 0)
errors = manager.validate()
assert len(errors) > 0
assert manager.is_valid() is False
# Check specific errors
error_messages = " ".join(errors)
assert "poll_interval_ms must be positive" in error_messages
assert "max_streams must be positive" in error_messages
assert "target_fps must be positive" in error_messages
def test_database_validation(self):
"""Test database-specific validation."""
manager = ConfigurationManager()
# Enable database but don't provide required fields
db_config = {
"enabled": True,
"host": "",
"database": ""
}
manager.set("database", db_config)
errors = manager.validate()
error_messages = " ".join(errors)
assert "database host is required" in error_messages
assert "database name is required" in error_messages
def test_redis_validation(self):
"""Test Redis-specific validation."""
manager = ConfigurationManager()
# Enable Redis but don't provide required fields
redis_config = {
"enabled": True,
"host": ""
}
manager.set("redis", redis_config)
errors = manager.validate()
error_messages = " ".join(errors)
assert "redis host is required" in error_messages
class TestGlobalConfigurationFunctions:
"""Test global configuration functions."""
def test_get_config_manager_singleton(self):
"""Test that get_config_manager returns a singleton."""
manager1 = get_config_manager()
manager2 = get_config_manager()
assert manager1 is manager2
@patch('detector_worker.core.config.get_config_manager')
def test_validate_config_function(self, mock_get_manager):
"""Test global validate_config function."""
mock_manager = Mock()
mock_manager.validate.return_value = ["error1", "error2"]
mock_get_manager.return_value = mock_manager
errors = validate_config()
assert errors == ["error1", "error2"]
mock_manager.validate.assert_called_once()
class TestConfigurationIntegration:
"""Integration tests for configuration system."""
def test_provider_priority(self, temp_dir):
"""Test that later providers override earlier ones."""
# Create JSON file with initial config
config_file = os.path.join(temp_dir, "config.json")
json_config = {"test_value": "from_json", "json_only": "json"}
with open(config_file, 'w') as f:
json.dump(json_config, f)
# Set environment variable that should override
env_vars = {"DETECTOR_TEST_VALUE": "from_env", "DETECTOR_ENV_ONLY": "env"}
with patch.dict(os.environ, env_vars, clear=False):
manager = ConfigurationManager()
manager._providers.clear() # Start fresh
# Add providers in order
manager.add_provider(JsonFileProvider(config_file))
manager.add_provider(EnvironmentProvider("DETECTOR_"))
config = manager.get_all()
# Environment should override JSON
assert config["test_value"] == "from_env"
# Both sources should be present
assert config["json_only"] == "json"
assert config["env_only"] == "env"
def test_hot_reload(self, temp_dir):
"""Test configuration hot reload functionality."""
config_file = os.path.join(temp_dir, "config.json")
# Initial config
initial_config = {"version": 1, "feature_enabled": False}
with open(config_file, 'w') as f:
json.dump(initial_config, f)
manager = ConfigurationManager()
manager._providers.clear()
manager.add_provider(JsonFileProvider(config_file))
assert manager.get("version") == 1
assert manager.get("feature_enabled") is False
# Update config file
updated_config = {"version": 2, "feature_enabled": True}
with open(config_file, 'w') as f:
json.dump(updated_config, f)
# Reload configuration
success = manager.reload()
assert success is True
assert manager.get("version") == 2
assert manager.get("feature_enabled") is True

View file

@ -0,0 +1,566 @@
"""
Unit tests for dependency injection system.
"""
import pytest
import threading
from unittest.mock import Mock, MagicMock
from detector_worker.core.dependency_injection import (
ServiceContainer,
ServiceLifetime,
ServiceDescriptor,
ServiceScope,
DetectorWorkerContainer,
get_container,
resolve_service,
create_service_scope
)
from detector_worker.core.exceptions import DependencyInjectionError
class TestServiceContainer:
"""Test core service container functionality."""
def test_register_singleton(self):
"""Test singleton service registration."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
# Register singleton
container.register_singleton(TestService)
# Resolve twice - should get same instance
instance1 = container.resolve(TestService)
instance2 = container.resolve(TestService)
assert instance1 is instance2
assert instance1.value == 42
def test_register_singleton_with_instance(self):
"""Test singleton registration with pre-created instance."""
container = ServiceContainer()
class TestService:
def __init__(self, value):
self.value = value
# Create instance and register
instance = TestService(99)
container.register_singleton(TestService, instance=instance)
# Resolve should return the pre-created instance
resolved = container.resolve(TestService)
assert resolved is instance
assert resolved.value == 99
def test_register_transient(self):
"""Test transient service registration."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
# Register transient
container.register_transient(TestService)
# Resolve twice - should get different instances
instance1 = container.resolve(TestService)
instance2 = container.resolve(TestService)
assert instance1 is not instance2
assert instance1.value == instance2.value == 42
def test_register_scoped(self):
"""Test scoped service registration."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
# Register scoped
container.register_scoped(TestService)
# Resolve in same scope - should get same instance
instance1 = container.resolve(TestService, scope_id="scope1")
instance2 = container.resolve(TestService, scope_id="scope1")
assert instance1 is instance2
# Resolve in different scope - should get different instance
instance3 = container.resolve(TestService, scope_id="scope2")
assert instance3 is not instance1
def test_register_with_factory(self):
"""Test service registration with factory function."""
container = ServiceContainer()
class TestService:
def __init__(self, value):
self.value = value
# Register with factory
def factory():
return TestService(100)
container.register_singleton(TestService, factory=factory)
instance = container.resolve(TestService)
assert instance.value == 100
def test_register_with_implementation_type(self):
"""Test service registration with implementation type."""
container = ServiceContainer()
class ITestService:
pass
class TestService(ITestService):
def __init__(self):
self.value = 42
# Register interface with implementation
container.register_singleton(ITestService, implementation_type=TestService)
instance = container.resolve(ITestService)
assert isinstance(instance, TestService)
assert instance.value == 42
def test_dependency_injection(self):
"""Test automatic dependency injection."""
container = ServiceContainer()
class DatabaseService:
def __init__(self):
self.connected = True
class UserService:
def __init__(self, database: DatabaseService):
self.database = database
# Register services
container.register_singleton(DatabaseService)
container.register_transient(UserService)
# Resolve should inject dependencies
user_service = container.resolve(UserService)
assert isinstance(user_service.database, DatabaseService)
assert user_service.database.connected is True
def test_circular_dependency_detection(self):
"""Test circular dependency detection."""
container = ServiceContainer()
class ServiceA:
def __init__(self, service_b: 'ServiceB'):
self.service_b = service_b
class ServiceB:
def __init__(self, service_a: ServiceA):
self.service_a = service_a
# Register circular dependencies
container.register_singleton(ServiceA)
container.register_singleton(ServiceB)
# Should raise circular dependency error
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(ServiceA)
assert "Circular dependency detected" in str(exc_info.value)
def test_unregistered_service_error(self):
"""Test error when resolving unregistered service."""
container = ServiceContainer()
class UnregisteredService:
pass
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(UnregisteredService)
assert "is not registered" in str(exc_info.value)
def test_scoped_service_without_scope_id(self):
"""Test error when resolving scoped service without scope ID."""
container = ServiceContainer()
class TestService:
pass
container.register_scoped(TestService)
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(TestService)
assert "Scope ID required" in str(exc_info.value)
def test_factory_error_handling(self):
"""Test factory error handling."""
container = ServiceContainer()
class TestService:
pass
def failing_factory():
raise ValueError("Factory failed")
container.register_singleton(TestService, factory=failing_factory)
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(TestService)
assert "Failed to create service using factory" in str(exc_info.value)
def test_constructor_dependency_with_default(self):
"""Test dependency with default value."""
container = ServiceContainer()
class TestService:
def __init__(self, value: int = 42):
self.value = value
container.register_singleton(TestService)
instance = container.resolve(TestService)
assert instance.value == 42
def test_unresolvable_dependency_with_default(self):
"""Test unresolvable dependency that has a default value."""
container = ServiceContainer()
class UnregisteredService:
pass
class TestService:
def __init__(self, dep: UnregisteredService = None):
self.dep = dep
container.register_singleton(TestService)
instance = container.resolve(TestService)
assert instance.dep is None
def test_unresolvable_dependency_without_default(self):
"""Test unresolvable dependency without default value."""
container = ServiceContainer()
class UnregisteredService:
pass
class TestService:
def __init__(self, dep: UnregisteredService):
self.dep = dep
container.register_singleton(TestService)
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(TestService)
assert "Cannot resolve dependency" in str(exc_info.value)
class TestServiceScope:
"""Test service scope functionality."""
def test_create_scope(self):
"""Test scope creation."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
container.register_scoped(TestService)
scope = container.create_scope("test_scope")
assert isinstance(scope, ServiceScope)
assert scope.scope_id == "test_scope"
def test_scope_context_manager(self):
"""Test scope as context manager."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.disposed = False
def dispose(self):
self.disposed = True
container.register_scoped(TestService)
instance = None
with container.create_scope("test_scope") as scope:
instance = scope.resolve(TestService)
assert not instance.disposed
# Instance should be disposed after scope exit
assert instance.disposed
def test_dispose_scope(self):
"""Test manual scope disposal."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.disposed = False
def dispose(self):
self.disposed = True
container.register_scoped(TestService)
instance = container.resolve(TestService, scope_id="test_scope")
assert not instance.disposed
container.dispose_scope("test_scope")
assert instance.disposed
def test_dispose_error_handling(self):
"""Test error handling during scope disposal."""
container = ServiceContainer()
class TestService:
def dispose(self):
raise ValueError("Dispose failed")
container.register_scoped(TestService)
container.resolve(TestService, scope_id="test_scope")
# Should not raise error, just log it
container.dispose_scope("test_scope")
class TestContainerIntrospection:
"""Test container introspection capabilities."""
def test_is_registered(self):
"""Test checking if service is registered."""
container = ServiceContainer()
class RegisteredService:
pass
class UnregisteredService:
pass
container.register_singleton(RegisteredService)
assert container.is_registered(RegisteredService) is True
assert container.is_registered(UnregisteredService) is False
def test_get_registration_info(self):
"""Test getting service registration information."""
container = ServiceContainer()
class TestService:
pass
container.register_singleton(TestService)
info = container.get_registration_info(TestService)
assert isinstance(info, ServiceDescriptor)
assert info.service_type == TestService
assert info.lifetime == ServiceLifetime.SINGLETON
def test_get_registered_services(self):
"""Test getting all registered services."""
container = ServiceContainer()
class Service1:
pass
class Service2:
pass
container.register_singleton(Service1)
container.register_transient(Service2)
services = container.get_registered_services()
assert len(services) == 2
assert Service1 in services
assert Service2 in services
def test_clear_singletons(self):
"""Test clearing singleton instances."""
container = ServiceContainer()
class TestService:
pass
container.register_singleton(TestService)
# Create singleton instance
instance1 = container.resolve(TestService)
# Clear singletons
container.clear_singletons()
# Next resolve should create new instance
instance2 = container.resolve(TestService)
assert instance2 is not instance1
def test_get_stats(self):
"""Test getting container statistics."""
container = ServiceContainer()
class Service1:
pass
class Service2:
pass
class Service3:
pass
container.register_singleton(Service1)
container.register_transient(Service2)
container.register_scoped(Service3)
# Create some instances
container.resolve(Service1)
container.resolve(Service3, scope_id="scope1")
stats = container.get_stats()
assert stats["registered_services"] == 3
assert stats["active_singletons"] == 1
assert stats["active_scopes"] == 1
assert stats["lifetime_breakdown"]["singleton"] == 1
assert stats["lifetime_breakdown"]["transient"] == 1
assert stats["lifetime_breakdown"]["scoped"] == 1
class TestDetectorWorkerContainer:
"""Test pre-configured detector worker container."""
def test_initialization(self):
"""Test detector worker container initialization."""
container = DetectorWorkerContainer()
assert isinstance(container.container, ServiceContainer)
# Should have core services registered
stats = container.container.get_stats()
assert stats["registered_services"] > 0
def test_resolve_convenience_method(self):
"""Test resolve convenience method."""
container = DetectorWorkerContainer()
# Should be able to resolve through convenience method
from detector_worker.core.singleton_managers import ModelStateManager
manager = container.resolve(ModelStateManager)
assert isinstance(manager, ModelStateManager)
def test_create_scope_convenience_method(self):
"""Test create scope convenience method."""
container = DetectorWorkerContainer()
scope = container.create_scope("test_scope")
assert isinstance(scope, ServiceScope)
assert scope.scope_id == "test_scope"
class TestGlobalContainerFunctions:
"""Test global container functions."""
def test_get_container_singleton(self):
"""Test that get_container returns a singleton."""
container1 = get_container()
container2 = get_container()
assert container1 is container2
assert isinstance(container1, DetectorWorkerContainer)
def test_resolve_service_convenience(self):
"""Test resolve_service convenience function."""
from detector_worker.core.singleton_managers import ModelStateManager
manager = resolve_service(ModelStateManager)
assert isinstance(manager, ModelStateManager)
def test_create_service_scope_convenience(self):
"""Test create_service_scope convenience function."""
scope = create_service_scope("test_scope")
assert isinstance(scope, ServiceScope)
assert scope.scope_id == "test_scope"
class TestThreadSafety:
"""Test thread safety of dependency injection system."""
def test_container_thread_safety(self):
"""Test that container is thread-safe."""
container = ServiceContainer()
class TestService:
def __init__(self):
import threading
self.thread_id = threading.current_thread().ident
container.register_singleton(TestService)
instances = {}
def resolve_service(thread_id):
instances[thread_id] = container.resolve(TestService)
# Create multiple threads
threads = []
for i in range(10):
thread = threading.Thread(target=resolve_service, args=(i,))
threads.append(thread)
thread.start()
# Wait for all threads
for thread in threads:
thread.join()
# All should get the same singleton instance
first_instance = list(instances.values())[0]
for instance in instances.values():
assert instance is first_instance
def test_scope_thread_safety(self):
"""Test that scoped services are thread-safe."""
container = ServiceContainer()
class TestService:
def __init__(self):
import threading
self.thread_id = threading.current_thread().ident
container.register_scoped(TestService)
results = {}
def resolve_in_scope(thread_id):
# Each thread uses its own scope
instance1 = container.resolve(TestService, scope_id=f"scope_{thread_id}")
instance2 = container.resolve(TestService, scope_id=f"scope_{thread_id}")
results[thread_id] = {
"same_instance": instance1 is instance2,
"thread_id": instance1.thread_id
}
threads = []
for i in range(5):
thread = threading.Thread(target=resolve_in_scope, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Each thread should get same instance within its scope
for thread_id, result in results.items():
assert result["same_instance"] is True

View file

@ -0,0 +1,560 @@
"""
Unit tests for singleton state managers.
"""
import pytest
import time
import threading
from unittest.mock import Mock, patch, MagicMock
from detector_worker.core.singleton_managers import (
SingletonMeta,
ModelStateManager,
StreamStateManager,
SessionStateManager,
CacheStateManager,
CameraStateManager,
PipelineStateManager,
ModelInfo,
StreamInfo,
SessionInfo
)
class TestSingletonMeta:
"""Test singleton metaclass."""
def test_singleton_behavior(self):
"""Test that singleton metaclass creates only one instance."""
class TestSingleton(metaclass=SingletonMeta):
def __init__(self):
self.value = 42
instance1 = TestSingleton()
instance2 = TestSingleton()
assert instance1 is instance2
assert instance1.value == instance2.value
def test_singleton_thread_safety(self):
"""Test that singleton is thread-safe."""
class TestSingleton(metaclass=SingletonMeta):
def __init__(self):
self.created_by = threading.current_thread().name
instances = {}
def create_instance(thread_id):
instances[thread_id] = TestSingleton()
threads = []
for i in range(10):
thread = threading.Thread(target=create_instance, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# All instances should be the same object
first_instance = instances[0]
for instance in instances.values():
assert instance is first_instance
class TestModelStateManager:
"""Test model state management."""
def test_singleton_behavior(self):
"""Test that ModelStateManager is a singleton."""
manager1 = ModelStateManager()
manager2 = ModelStateManager()
assert manager1 is manager2
def test_load_model(self):
"""Test loading a model."""
manager = ModelStateManager()
manager.clear_all() # Start fresh
mock_model = Mock()
manager.load_model("camera1", "model1", mock_model)
retrieved_model = manager.get_model("camera1", "model1")
assert retrieved_model is mock_model
def test_load_same_model_increments_reference_count(self):
"""Test that loading the same model increments reference count."""
manager = ModelStateManager()
manager.clear_all()
mock_model = Mock()
# Load same model twice
manager.load_model("camera1", "model1", mock_model)
manager.load_model("camera1", "model1", mock_model)
# Should still be accessible
assert manager.get_model("camera1", "model1") is mock_model
def test_get_camera_models(self):
"""Test getting all models for a camera."""
manager = ModelStateManager()
manager.clear_all()
mock_model1 = Mock()
mock_model2 = Mock()
manager.load_model("camera1", "model1", mock_model1)
manager.load_model("camera1", "model2", mock_model2)
models = manager.get_camera_models("camera1")
assert len(models) == 2
assert models["model1"] is mock_model1
assert models["model2"] is mock_model2
def test_unload_model_with_multiple_references(self):
"""Test unloading model with multiple references."""
manager = ModelStateManager()
manager.clear_all()
mock_model = Mock()
# Load model twice (reference count = 2)
manager.load_model("camera1", "model1", mock_model)
manager.load_model("camera1", "model1", mock_model)
# First unload should not remove model
result = manager.unload_model("camera1", "model1")
assert result is False # Still referenced
assert manager.get_model("camera1", "model1") is mock_model
# Second unload should remove model
result = manager.unload_model("camera1", "model1")
assert result is True # Completely removed
assert manager.get_model("camera1", "model1") is None
def test_unload_camera_models(self):
"""Test unloading all models for a camera."""
manager = ModelStateManager()
manager.clear_all()
mock_model1 = Mock()
mock_model2 = Mock()
manager.load_model("camera1", "model1", mock_model1)
manager.load_model("camera1", "model2", mock_model2)
manager.unload_camera_models("camera1")
assert manager.get_model("camera1", "model1") is None
assert manager.get_model("camera1", "model2") is None
def test_get_stats(self):
"""Test getting model statistics."""
manager = ModelStateManager()
manager.clear_all()
mock_model = Mock()
manager.load_model("camera1", "model1", mock_model)
manager.load_model("camera2", "model2", mock_model)
stats = manager.get_stats()
assert stats["total_models"] == 2
assert stats["total_cameras"] == 2
assert "camera1" in stats["cameras"]
assert "camera2" in stats["cameras"]
class TestStreamStateManager:
"""Test stream state management."""
def test_add_stream(self):
"""Test adding a stream."""
manager = StreamStateManager()
manager.clear_all()
config = {"rtsp_url": "rtsp://example.com", "model_id": "test"}
manager.add_stream("camera1", "sub1", config)
stream = manager.get_stream("camera1")
assert stream is not None
assert stream.camera_id == "camera1"
assert stream.subscription_id == "sub1"
assert stream.config == config
def test_subscription_mapping(self):
"""Test subscription to camera mapping."""
manager = StreamStateManager()
manager.clear_all()
config = {"rtsp_url": "rtsp://example.com"}
manager.add_stream("camera1", "sub1", config)
camera_id = manager.get_camera_by_subscription("sub1")
assert camera_id == "camera1"
def test_remove_stream(self):
"""Test removing a stream."""
manager = StreamStateManager()
manager.clear_all()
config = {"rtsp_url": "rtsp://example.com"}
manager.add_stream("camera1", "sub1", config)
removed_stream = manager.remove_stream("camera1")
assert removed_stream is not None
assert removed_stream.camera_id == "camera1"
assert manager.get_stream("camera1") is None
assert manager.get_camera_by_subscription("sub1") is None
def test_shared_stream_management(self):
"""Test shared stream management."""
manager = StreamStateManager()
manager.clear_all()
stream_data = {"reader": Mock(), "reference_count": 1}
manager.add_shared_stream("rtsp://example.com", stream_data)
retrieved_data = manager.get_shared_stream("rtsp://example.com")
assert retrieved_data == stream_data
removed_data = manager.remove_shared_stream("rtsp://example.com")
assert removed_data == stream_data
assert manager.get_shared_stream("rtsp://example.com") is None
class TestSessionStateManager:
"""Test session state management."""
def test_session_id_management(self):
"""Test session ID assignment."""
manager = SessionStateManager()
manager.clear_all()
manager.set_session_id("display1", "session123")
session_id = manager.get_session_id("display1")
assert session_id == "session123"
def test_create_session(self):
"""Test session creation with detection data."""
manager = SessionStateManager()
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", detection_data)
retrieved_data = manager.get_session_detection("session123")
assert retrieved_data == detection_data
camera_id = manager.get_camera_by_session("session123")
assert camera_id == "camera1"
def test_update_session_detection(self):
"""Test updating session detection data."""
manager = SessionStateManager()
manager.clear_all()
initial_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", initial_data)
update_data = {"brand": "Toyota"}
manager.update_session_detection("session123", update_data)
final_data = manager.get_session_detection("session123")
assert final_data["class"] == "car"
assert final_data["brand"] == "Toyota"
def test_session_expiration(self):
"""Test session expiration based on TTL."""
# Use a very short TTL for testing
manager = SessionStateManager(session_ttl=0.1)
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", detection_data)
# Session should exist initially
assert manager.get_session_detection("session123") is not None
# Wait for expiration
time.sleep(0.2)
# Clean up expired sessions
expired_count = manager.cleanup_expired_sessions()
assert expired_count == 1
assert manager.get_session_detection("session123") is None
def test_remove_session(self):
"""Test manual session removal."""
manager = SessionStateManager()
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", detection_data)
result = manager.remove_session("session123")
assert result is True
assert manager.get_session_detection("session123") is None
assert manager.get_camera_by_session("session123") is None
class TestCacheStateManager:
"""Test cache state management."""
def test_cache_detection(self):
"""Test caching detection results."""
manager = CacheStateManager()
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85, "bbox": [100, 200, 300, 400]}
manager.cache_detection("camera1", detection_data)
cached_data = manager.get_cached_detection("camera1")
assert cached_data == detection_data
def test_cache_pipeline_result(self):
"""Test caching pipeline results."""
manager = CacheStateManager()
manager.clear_all()
pipeline_result = {"status": "success", "detections": []}
manager.cache_pipeline_result("camera1", pipeline_result)
cached_result = manager.get_cached_pipeline_result("camera1")
assert cached_result == pipeline_result
def test_latest_frame_management(self):
"""Test latest frame storage."""
manager = CacheStateManager()
manager.clear_all()
frame_data = b"fake_frame_data"
manager.set_latest_frame("camera1", frame_data)
retrieved_frame = manager.get_latest_frame("camera1")
assert retrieved_frame == frame_data
def test_frame_skip_flag(self):
"""Test frame skip flag management."""
manager = CacheStateManager()
manager.clear_all()
# Initially should be False
assert manager.get_frame_skip_flag("camera1") is False
manager.set_frame_skip_flag("camera1", True)
assert manager.get_frame_skip_flag("camera1") is True
manager.set_frame_skip_flag("camera1", False)
assert manager.get_frame_skip_flag("camera1") is False
def test_clear_camera_cache(self):
"""Test clearing all cache data for a camera."""
manager = CacheStateManager()
manager.clear_all()
# Set up cache data
detection_data = {"class": "car"}
pipeline_result = {"status": "success"}
frame_data = b"frame"
manager.cache_detection("camera1", detection_data)
manager.cache_pipeline_result("camera1", pipeline_result)
manager.set_latest_frame("camera1", frame_data)
manager.set_frame_skip_flag("camera1", True)
# Clear cache
manager.clear_camera_cache("camera1")
# All data should be gone
assert manager.get_cached_detection("camera1") is None
assert manager.get_cached_pipeline_result("camera1") is None
assert manager.get_latest_frame("camera1") is None
assert manager.get_frame_skip_flag("camera1") is False
class TestCameraStateManager:
"""Test camera state management."""
def test_camera_connection_state(self):
"""Test camera connection state management."""
manager = CameraStateManager()
manager.clear_all()
# Initially connected (default)
assert manager.is_camera_connected("camera1") is True
# Set disconnected
manager.set_camera_connected("camera1", False)
assert manager.is_camera_connected("camera1") is False
# Set connected again
manager.set_camera_connected("camera1", True)
assert manager.is_camera_connected("camera1") is True
def test_notification_flags(self):
"""Test disconnection/reconnection notification flags."""
manager = CameraStateManager()
manager.clear_all()
# Set disconnected
manager.set_camera_connected("camera1", False)
# Should notify disconnection once
assert manager.should_notify_disconnection("camera1") is True
manager.mark_disconnection_notified("camera1")
assert manager.should_notify_disconnection("camera1") is False
# Reconnect
manager.set_camera_connected("camera1", True)
# Should notify reconnection
assert manager.should_notify_reconnection("camera1") is True
manager.mark_reconnection_notified("camera1")
assert manager.should_notify_reconnection("camera1") is False
def test_get_camera_state(self):
"""Test getting full camera state."""
manager = CameraStateManager()
manager.clear_all()
manager.set_camera_connected("camera1", False)
state = manager.get_camera_state("camera1")
assert state["connected"] is False
assert "last_update" in state
assert "disconnection_notified" in state
assert "reconnection_notified" in state
def test_get_stats(self):
"""Test getting camera state statistics."""
manager = CameraStateManager()
manager.clear_all()
manager.set_camera_connected("camera1", True)
manager.set_camera_connected("camera2", False)
stats = manager.get_stats()
assert stats["total_cameras"] == 2
assert stats["connected_cameras"] == 1
assert stats["disconnected_cameras"] == 1
class TestPipelineStateManager:
"""Test pipeline state management."""
def test_get_or_init_state(self):
"""Test getting or initializing pipeline state."""
manager = PipelineStateManager()
manager.clear_all()
state = manager.get_or_init_state("camera1")
assert state["mode"] == "validation_detecting"
assert state["backend_session_id"] is None
assert state["yolo_inference_enabled"] is True
assert "created_at" in state
def test_update_mode(self):
"""Test updating pipeline mode."""
manager = PipelineStateManager()
manager.clear_all()
manager.update_mode("camera1", "classification", "session123")
state = manager.get_state("camera1")
assert state["mode"] == "classification"
assert state["backend_session_id"] == "session123"
def test_set_yolo_inference_enabled(self):
"""Test setting YOLO inference state."""
manager = PipelineStateManager()
manager.clear_all()
manager.set_yolo_inference_enabled("camera1", False)
state = manager.get_state("camera1")
assert state["yolo_inference_enabled"] is False
def test_set_progression_stage(self):
"""Test setting progression stage."""
manager = PipelineStateManager()
manager.clear_all()
manager.set_progression_stage("camera1", "brand_classification")
state = manager.get_state("camera1")
assert state["progression_stage"] == "brand_classification"
def test_set_validated_detection(self):
"""Test setting validated detection."""
manager = PipelineStateManager()
manager.clear_all()
detection = {"class": "car", "confidence": 0.85}
manager.set_validated_detection("camera1", detection)
state = manager.get_state("camera1")
assert state["validated_detection"] == detection
def test_get_stats(self):
"""Test getting pipeline state statistics."""
manager = PipelineStateManager()
manager.clear_all()
manager.update_mode("camera1", "validation_detecting")
manager.update_mode("camera2", "classification")
manager.update_mode("camera3", "classification")
stats = manager.get_stats()
assert stats["total_pipeline_states"] == 3
assert stats["mode_breakdown"]["validation_detecting"] == 1
assert stats["mode_breakdown"]["classification"] == 2
class TestThreadSafety:
"""Test thread safety of singleton managers."""
def test_model_manager_thread_safety(self):
"""Test ModelStateManager thread safety."""
manager = ModelStateManager()
manager.clear_all()
results = {}
def load_models(thread_id):
for i in range(10):
model = Mock()
model.thread_id = thread_id
model.model_id = i
manager.load_model(f"camera{thread_id}", f"model{i}", model)
# Verify models
models = manager.get_camera_models(f"camera{thread_id}")
results[thread_id] = len(models)
threads = []
for i in range(5):
thread = threading.Thread(target=load_models, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Each thread should have loaded 10 models
for thread_id, model_count in results.items():
assert model_count == 10
# Total should be 50 models
stats = manager.get_stats()
assert stats["total_models"] == 50

View file

@ -0,0 +1,479 @@
"""
Unit tests for detection result data structures.
"""
import pytest
from dataclasses import asdict
import numpy as np
from detector_worker.detection.detection_result import (
BoundingBox,
DetectionResult,
LightweightDetectionResult,
DetectionSession,
TrackValidationResult
)
class TestBoundingBox:
"""Test BoundingBox data structure."""
def test_creation_from_coordinates(self):
"""Test creating bounding box from coordinates."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
assert bbox.x1 == 100
assert bbox.y1 == 200
assert bbox.x2 == 300
assert bbox.y2 == 400
def test_creation_from_list(self):
"""Test creating bounding box from list."""
coords = [100, 200, 300, 400]
bbox = BoundingBox.from_list(coords)
assert bbox.x1 == 100
assert bbox.y1 == 200
assert bbox.x2 == 300
assert bbox.y2 == 400
def test_creation_from_invalid_list(self):
"""Test error handling for invalid list."""
with pytest.raises(ValueError):
BoundingBox.from_list([100, 200, 300]) # Too few elements
def test_to_list(self):
"""Test converting bounding box to list."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
coords = bbox.to_list()
assert coords == [100, 200, 300, 400]
def test_area_calculation(self):
"""Test area calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
area = bbox.area()
expected_area = (300 - 100) * (400 - 200) # 200 * 200 = 40000
assert area == expected_area
def test_area_zero_for_invalid_bbox(self):
"""Test area is zero for invalid bounding box."""
# x2 <= x1
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
assert bbox.area() == 0
# y2 <= y1
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
assert bbox.area() == 0
def test_width_height(self):
"""Test width and height properties."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
assert bbox.width() == 200
assert bbox.height() == 200
def test_center_point(self):
"""Test center point calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
center = bbox.center()
assert center == (200, 300) # (x1+x2)/2, (y1+y2)/2
def test_is_valid(self):
"""Test bounding box validation."""
# Valid bbox
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
assert bbox.is_valid() is True
# Invalid bbox (x2 <= x1)
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
assert bbox.is_valid() is False
# Invalid bbox (y2 <= y1)
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
assert bbox.is_valid() is False
def test_intersection(self):
"""Test bounding box intersection."""
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
intersection = bbox1.intersection(bbox2)
assert intersection.x1 == 200
assert intersection.y1 == 200
assert intersection.x2 == 300
assert intersection.y2 == 300
def test_no_intersection(self):
"""Test no intersection between bounding boxes."""
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
intersection = bbox1.intersection(bbox2)
assert intersection.is_valid() is False
def test_union(self):
"""Test bounding box union."""
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
union = bbox1.union(bbox2)
assert union.x1 == 100
assert union.y1 == 100
assert union.x2 == 400
assert union.y2 == 400
def test_iou_calculation(self):
"""Test IoU (Intersection over Union) calculation."""
# Perfect overlap
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
assert bbox1.iou(bbox2) == 1.0
# No overlap
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
assert bbox1.iou(bbox2) == 0.0
# Partial overlap
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
# Intersection area: 100x100 = 10000
# Union area: 200x200 + 200x200 - 10000 = 30000
# IoU = 10000/30000 = 1/3
expected_iou = 1.0 / 3.0
assert abs(bbox1.iou(bbox2) - expected_iou) < 1e-6
class TestDetectionResult:
"""Test DetectionResult data structure."""
def test_creation_with_required_fields(self):
"""Test creating detection result with required fields."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
assert detection.class_name == "car"
assert detection.confidence == 0.85
assert detection.bbox == bbox
assert detection.track_id == 12345
def test_creation_with_all_fields(self):
"""Test creating detection result with all fields."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345,
model_id="yolo_v8",
timestamp=1640995200000,
branch_results={"brand": "Toyota"}
)
assert detection.model_id == "yolo_v8"
assert detection.timestamp == 1640995200000
assert detection.branch_results == {"brand": "Toyota"}
def test_creation_from_dict(self):
"""Test creating detection result from dictionary."""
data = {
"class": "car",
"confidence": 0.85,
"bbox": [100, 200, 300, 400],
"id": 12345,
"model_id": "yolo_v8",
"timestamp": 1640995200000
}
detection = DetectionResult.from_dict(data)
assert detection.class_name == "car"
assert detection.confidence == 0.85
assert detection.bbox.to_list() == [100, 200, 300, 400]
assert detection.track_id == 12345
def test_to_dict(self):
"""Test converting detection result to dictionary."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
data = detection.to_dict()
assert data["class"] == "car"
assert data["confidence"] == 0.85
assert data["bbox"] == [100, 200, 300, 400]
assert data["id"] == 12345
def test_is_valid_detection(self):
"""Test detection validation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
# Valid detection
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
assert detection.is_valid() is True
# Invalid confidence (too low)
detection = DetectionResult(
class_name="car",
confidence=-0.1,
bbox=bbox,
track_id=12345
)
assert detection.is_valid() is False
# Invalid confidence (too high)
detection = DetectionResult(
class_name="car",
confidence=1.5,
bbox=bbox,
track_id=12345
)
assert detection.is_valid() is False
# Invalid bounding box
invalid_bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=invalid_bbox,
track_id=12345
)
assert detection.is_valid() is False
class TestLightweightDetectionResult:
"""Test LightweightDetectionResult data structure."""
def test_creation(self):
"""Test creating lightweight detection result."""
detection = LightweightDetectionResult(
class_name="car",
confidence=0.85,
bbox_area=40000,
frame_width=1920,
frame_height=1080
)
assert detection.class_name == "car"
assert detection.confidence == 0.85
assert detection.bbox_area == 40000
assert detection.frame_width == 1920
assert detection.frame_height == 1080
def test_area_ratio_calculation(self):
"""Test bounding box area ratio calculation."""
detection = LightweightDetectionResult(
class_name="car",
confidence=0.85,
bbox_area=40000,
frame_width=1920,
frame_height=1080
)
expected_ratio = 40000 / (1920 * 1080)
assert abs(detection.area_ratio() - expected_ratio) < 1e-6
def test_meets_threshold(self):
"""Test threshold checking."""
detection = LightweightDetectionResult(
class_name="car",
confidence=0.85,
bbox_area=40000,
frame_width=1920,
frame_height=1080
)
assert detection.meets_threshold(confidence=0.8, area_ratio=0.01) is True
assert detection.meets_threshold(confidence=0.9, area_ratio=0.01) is False
assert detection.meets_threshold(confidence=0.8, area_ratio=0.1) is False
class TestDetectionSession:
"""Test DetectionSession data structure."""
def test_creation(self):
"""Test creating detection session."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
assert session.session_id == "session_123"
assert session.camera_id == "camera_001"
assert session.display_id == "display_001"
assert session.detections == []
assert session.metadata == {}
def test_add_detection(self):
"""Test adding detection to session."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
session.add_detection(detection)
assert len(session.detections) == 1
assert session.detections[0] == detection
def test_get_latest_detection(self):
"""Test getting latest detection."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
# Add multiple detections
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=12345,
timestamp=1640995200000
)
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
detection2 = DetectionResult(
class_name="car",
confidence=0.90,
bbox=bbox2,
track_id=12345,
timestamp=1640995300000
)
session.add_detection(detection1)
session.add_detection(detection2)
latest = session.get_latest_detection()
assert latest == detection2 # Should be the one with later timestamp
def test_get_detections_by_class(self):
"""Test filtering detections by class."""
session = DetectionSession(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
car_detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=12345
)
truck_detection = DetectionResult(
class_name="truck",
confidence=0.80,
bbox=bbox,
track_id=54321
)
session.add_detection(car_detection)
session.add_detection(truck_detection)
car_detections = session.get_detections_by_class("car")
assert len(car_detections) == 1
assert car_detections[0] == car_detection
truck_detections = session.get_detections_by_class("truck")
assert len(truck_detections) == 1
assert truck_detections[0] == truck_detection
class TestTrackValidationResult:
"""Test TrackValidationResult data structure."""
def test_creation(self):
"""Test creating track validation result."""
result = TrackValidationResult(
stable_tracks=[101, 102, 103],
current_tracks=[101, 102, 104, 105],
newly_stable=[103],
lost_tracks=[106]
)
assert result.stable_tracks == [101, 102, 103]
assert result.current_tracks == [101, 102, 104, 105]
assert result.newly_stable == [103]
assert result.lost_tracks == [106]
def test_has_stable_tracks(self):
"""Test checking for stable tracks."""
result = TrackValidationResult(
stable_tracks=[101, 102],
current_tracks=[101, 102, 103]
)
assert result.has_stable_tracks() is True
result_empty = TrackValidationResult(
stable_tracks=[],
current_tracks=[101, 102, 103]
)
assert result_empty.has_stable_tracks() is False
def test_get_stats(self):
"""Test getting validation statistics."""
result = TrackValidationResult(
stable_tracks=[101, 102, 103],
current_tracks=[101, 102, 104, 105],
newly_stable=[103],
lost_tracks=[106]
)
stats = result.get_stats()
assert stats["stable_count"] == 3
assert stats["current_count"] == 4
assert stats["newly_stable_count"] == 1
assert stats["lost_count"] == 1
assert stats["stability_ratio"] == 3/4 # stable/current
def test_is_track_stable(self):
"""Test checking if specific track is stable."""
result = TrackValidationResult(
stable_tracks=[101, 102, 103],
current_tracks=[101, 102, 104, 105]
)
assert result.is_track_stable(101) is True
assert result.is_track_stable(102) is True
assert result.is_track_stable(104) is False
assert result.is_track_stable(999) is False

View file

@ -0,0 +1,701 @@
"""
Unit tests for track stability validation.
"""
import pytest
import time
from unittest.mock import Mock, patch
from collections import defaultdict
from detector_worker.detection.stability_validator import (
StabilityValidator,
StabilityConfig,
ValidationResult,
TrackStabilityMetrics
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox, TrackValidationResult
from detector_worker.core.exceptions import ValidationError
class TestStabilityConfig:
"""Test stability configuration data structure."""
def test_default_config(self):
"""Test default stability configuration."""
config = StabilityConfig()
assert config.min_detection_frames == 10
assert config.max_absence_frames == 30
assert config.confidence_threshold == 0.5
assert config.stability_window == 60.0
assert config.iou_threshold == 0.3
assert config.movement_threshold == 50.0
def test_custom_config(self):
"""Test custom stability configuration."""
config = StabilityConfig(
min_detection_frames=5,
max_absence_frames=15,
confidence_threshold=0.8,
stability_window=30.0,
iou_threshold=0.5,
movement_threshold=25.0
)
assert config.min_detection_frames == 5
assert config.max_absence_frames == 15
assert config.confidence_threshold == 0.8
assert config.stability_window == 30.0
assert config.iou_threshold == 0.5
assert config.movement_threshold == 25.0
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"min_detection_frames": 8,
"max_absence_frames": 25,
"confidence_threshold": 0.75,
"unknown_field": "ignored"
}
config = StabilityConfig.from_dict(config_dict)
assert config.min_detection_frames == 8
assert config.max_absence_frames == 25
assert config.confidence_threshold == 0.75
# Unknown fields should use defaults
assert config.stability_window == 60.0
class TestTrackStabilityMetrics:
"""Test track stability metrics."""
def test_initialization(self):
"""Test metrics initialization."""
metrics = TrackStabilityMetrics(track_id=1001)
assert metrics.track_id == 1001
assert metrics.detection_count == 0
assert metrics.absence_count == 0
assert metrics.total_confidence == 0.0
assert metrics.first_detection_time is None
assert metrics.last_detection_time is None
assert metrics.bounding_boxes == []
assert metrics.confidence_scores == []
def test_add_detection(self):
"""Test adding detection to metrics."""
metrics = TrackStabilityMetrics(track_id=1001)
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
metrics.add_detection(detection, current_time=1640995200.0)
assert metrics.detection_count == 1
assert metrics.absence_count == 0
assert metrics.total_confidence == 0.85
assert metrics.first_detection_time == 1640995200.0
assert metrics.last_detection_time == 1640995200.0
assert len(metrics.bounding_boxes) == 1
assert len(metrics.confidence_scores) == 1
def test_increment_absence(self):
"""Test incrementing absence count."""
metrics = TrackStabilityMetrics(track_id=1001)
metrics.increment_absence()
assert metrics.absence_count == 1
metrics.increment_absence()
assert metrics.absence_count == 2
def test_reset_absence(self):
"""Test resetting absence count."""
metrics = TrackStabilityMetrics(track_id=1001)
metrics.increment_absence()
metrics.increment_absence()
assert metrics.absence_count == 2
metrics.reset_absence()
assert metrics.absence_count == 0
def test_average_confidence(self):
"""Test average confidence calculation."""
metrics = TrackStabilityMetrics(track_id=1001)
# No detections
assert metrics.average_confidence() == 0.0
# Add detections
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
detection2 = DetectionResult(
class_name="car",
confidence=0.9,
bbox=bbox,
track_id=1001,
timestamp=1640995300000
)
metrics.add_detection(detection1, current_time=1640995200.0)
metrics.add_detection(detection2, current_time=1640995300.0)
assert metrics.average_confidence() == 0.85 # (0.8 + 0.9) / 2
def test_tracking_duration(self):
"""Test tracking duration calculation."""
metrics = TrackStabilityMetrics(track_id=1001)
# No detections
assert metrics.tracking_duration() == 0.0
# Add detections
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
detection2 = DetectionResult(
class_name="car",
confidence=0.9,
bbox=bbox,
track_id=1001,
timestamp=1640995300000
)
metrics.add_detection(detection1, current_time=1640995200.0)
metrics.add_detection(detection2, current_time=1640995300.0)
assert metrics.tracking_duration() == 100.0 # 1640995300 - 1640995200
def test_movement_distance(self):
"""Test movement distance calculation."""
metrics = TrackStabilityMetrics(track_id=1001)
# No movement with single detection
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
metrics.add_detection(detection1, current_time=1640995200.0)
assert metrics.total_movement_distance() == 0.0
# Add second detection with movement
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
detection2 = DetectionResult(
class_name="car",
confidence=0.9,
bbox=bbox2,
track_id=1001,
timestamp=1640995300000
)
metrics.add_detection(detection2, current_time=1640995300.0)
# Distance between centers: (200,300) to (210,310) = sqrt(100+100) ≈ 14.14
movement = metrics.total_movement_distance()
assert movement == pytest.approx(14.14, rel=1e-2)
class TestValidationResult:
"""Test validation result data structure."""
def test_initialization(self):
"""Test validation result initialization."""
result = ValidationResult(
track_id=1001,
is_stable=True,
detection_count=15,
absence_count=2,
average_confidence=0.85,
tracking_duration=120.0
)
assert result.track_id == 1001
assert result.is_stable is True
assert result.detection_count == 15
assert result.absence_count == 2
assert result.average_confidence == 0.85
assert result.tracking_duration == 120.0
assert result.reasons == []
def test_with_reasons(self):
"""Test validation result with failure reasons."""
result = ValidationResult(
track_id=1001,
is_stable=False,
detection_count=5,
absence_count=35,
average_confidence=0.4,
tracking_duration=30.0,
reasons=["Insufficient detection frames", "Too many absences", "Low confidence"]
)
assert result.is_stable is False
assert len(result.reasons) == 3
assert "Insufficient detection frames" in result.reasons
class TestStabilityValidator:
"""Test stability validation functionality."""
def test_initialization_default(self):
"""Test validator initialization with default config."""
validator = StabilityValidator()
assert isinstance(validator.config, StabilityConfig)
assert validator.config.min_detection_frames == 10
assert len(validator.track_metrics) == 0
def test_initialization_custom_config(self):
"""Test validator initialization with custom config."""
config = StabilityConfig(min_detection_frames=5, confidence_threshold=0.8)
validator = StabilityValidator(config)
assert validator.config.min_detection_frames == 5
assert validator.config.confidence_threshold == 0.8
def test_update_detections_new_track(self):
"""Test updating with new track."""
validator = StabilityValidator()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
validator.update_detections([detection], current_time=1640995200.0)
assert 1001 in validator.track_metrics
metrics = validator.track_metrics[1001]
assert metrics.detection_count == 1
assert metrics.absence_count == 0
def test_update_detections_existing_track(self):
"""Test updating existing track."""
validator = StabilityValidator()
# First detection
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
validator.update_detections([detection1], current_time=1640995200.0)
# Second detection
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
detection2 = DetectionResult(
class_name="car",
confidence=0.9,
bbox=bbox2,
track_id=1001,
timestamp=1640995300000
)
validator.update_detections([detection2], current_time=1640995300.0)
metrics = validator.track_metrics[1001]
assert metrics.detection_count == 2
assert metrics.absence_count == 0
assert metrics.average_confidence() == 0.85
def test_update_detections_missing_track(self):
"""Test updating when track is missing (increment absence)."""
validator = StabilityValidator()
# Add track
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
validator.update_detections([detection], current_time=1640995200.0)
# Update with empty detections
validator.update_detections([], current_time=1640995300.0)
metrics = validator.track_metrics[1001]
assert metrics.detection_count == 1
assert metrics.absence_count == 1
def test_validate_track_stable(self):
"""Test validating a stable track."""
config = StabilityConfig(min_detection_frames=3, max_absence_frames=5)
validator = StabilityValidator(config)
# Create track with sufficient detections
track_id = 1001
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Add sufficient detections
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(5):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
result = validator.validate_track(track_id)
assert result.is_stable is True
assert result.detection_count == 5
assert result.absence_count == 0
assert len(result.reasons) == 0
def test_validate_track_insufficient_detections(self):
"""Test validating track with insufficient detections."""
config = StabilityConfig(min_detection_frames=10, max_absence_frames=5)
validator = StabilityValidator(config)
# Create track with insufficient detections
track_id = 1001
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Add only few detections
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(3):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
result = validator.validate_track(track_id)
assert result.is_stable is False
assert "Insufficient detection frames" in result.reasons
def test_validate_track_too_many_absences(self):
"""Test validating track with too many absences."""
config = StabilityConfig(min_detection_frames=3, max_absence_frames=2)
validator = StabilityValidator(config)
# Create track with too many absences
track_id = 1001
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Add detections and absences
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(5):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
# Add too many absences
for _ in range(5):
metrics.increment_absence()
result = validator.validate_track(track_id)
assert result.is_stable is False
assert "Too many absence frames" in result.reasons
def test_validate_track_low_confidence(self):
"""Test validating track with low confidence."""
config = StabilityConfig(
min_detection_frames=3,
max_absence_frames=5,
confidence_threshold=0.8
)
validator = StabilityValidator(config)
# Create track with low confidence
track_id = 1001
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Add detections with low confidence
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(5):
detection = DetectionResult(
class_name="car",
confidence=0.5, # Below threshold
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
result = validator.validate_track(track_id)
assert result.is_stable is False
assert "Low average confidence" in result.reasons
def test_validate_all_tracks(self):
"""Test validating all tracks."""
config = StabilityConfig(min_detection_frames=3)
validator = StabilityValidator(config)
# Add multiple tracks
for track_id in [1001, 1002, 1003]:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
# Make some tracks stable, others not
detection_count = 5 if track_id == 1001 else 2
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(detection_count):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
results = validator.validate_all_tracks()
assert len(results) == 3
assert results[1001].is_stable is True # 5 detections
assert results[1002].is_stable is False # 2 detections
assert results[1003].is_stable is False # 2 detections
def test_get_stable_tracks(self):
"""Test getting stable track IDs."""
config = StabilityConfig(min_detection_frames=3)
validator = StabilityValidator(config)
# Add tracks with different stability
for track_id, detection_count in [(1001, 5), (1002, 2), (1003, 4)]:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(detection_count):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
stable_tracks = validator.get_stable_tracks()
assert stable_tracks == [1001, 1003] # 5 and 4 detections respectively
def test_cleanup_expired_tracks(self):
"""Test cleanup of expired tracks."""
config = StabilityConfig(stability_window=10.0)
validator = StabilityValidator(config)
# Add tracks with different last detection times
current_time = 1640995300.0
for track_id, last_detection_time in [(1001, current_time - 5), (1002, current_time - 15)]:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=int(last_detection_time * 1000)
)
metrics.add_detection(detection, current_time=last_detection_time)
removed_count = validator.cleanup_expired_tracks(current_time)
assert removed_count == 1 # 1002 should be removed (15 > 10 seconds)
assert 1001 in validator.track_metrics
assert 1002 not in validator.track_metrics
def test_clear_all_tracks(self):
"""Test clearing all track metrics."""
validator = StabilityValidator()
# Add some tracks
for track_id in [1001, 1002]:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
assert len(validator.track_metrics) == 2
validator.clear_all_tracks()
assert len(validator.track_metrics) == 0
def test_get_validation_summary(self):
"""Test getting validation summary statistics."""
config = StabilityConfig(min_detection_frames=3)
validator = StabilityValidator(config)
# Add tracks with different characteristics
track_data = [
(1001, 5, True), # Stable
(1002, 2, False), # Unstable
(1003, 4, True), # Stable
(1004, 1, False) # Unstable
]
for track_id, detection_count, _ in track_data:
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
metrics = validator.track_metrics[track_id]
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
for i in range(detection_count):
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
metrics.add_detection(detection, current_time=1640995200.0 + i)
summary = validator.get_validation_summary()
assert summary["total_tracks"] == 4
assert summary["stable_tracks"] == 2
assert summary["unstable_tracks"] == 2
assert summary["stability_rate"] == 0.5
class TestStabilityValidatorIntegration:
"""Integration tests for stability validator."""
def test_full_tracking_lifecycle(self):
"""Test complete tracking lifecycle with stability validation."""
config = StabilityConfig(
min_detection_frames=3,
max_absence_frames=2,
confidence_threshold=0.7
)
validator = StabilityValidator(config)
track_id = 1001
# Phase 1: Initial detections (building up)
for i in range(5):
bbox = BoundingBox(x1=100+i*2, y1=200+i*2, x2=300+i*2, y2=400+i*2)
detection = DetectionResult(
class_name="car",
confidence=0.8,
bbox=bbox,
track_id=track_id,
timestamp=1640995200000 + i * 1000
)
validator.update_detections([detection], current_time=1640995200.0 + i)
# Should be stable now
result = validator.validate_track(track_id)
assert result.is_stable is True
# Phase 2: Some absences
for i in range(2):
validator.update_detections([], current_time=1640995205.0 + i)
# Still stable (within absence threshold)
result = validator.validate_track(track_id)
assert result.is_stable is True
# Phase 3: Track reappears
bbox = BoundingBox(x1=120, y1=220, x2=320, y2=420)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=track_id,
timestamp=1640995207000
)
validator.update_detections([detection], current_time=1640995207.0)
# Should reset absence count and remain stable
result = validator.validate_track(track_id)
assert result.is_stable is True
assert validator.track_metrics[track_id].absence_count == 0
def test_multi_track_validation(self):
"""Test validation with multiple tracks."""
validator = StabilityValidator()
# Simulate multi-track scenario
frame_detections = [
# Frame 1
[
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000),
DetectionResult("truck", 0.8, BoundingBox(400, 200, 600, 400), 1002, 1640995200000)
],
# Frame 2
[
DetectionResult("car", 0.85, BoundingBox(105, 205, 305, 405), 1001, 1640995201000),
DetectionResult("truck", 0.82, BoundingBox(405, 205, 605, 405), 1002, 1640995201000),
DetectionResult("car", 0.75, BoundingBox(200, 300, 400, 500), 1003, 1640995201000)
],
# Frame 3 - track 1002 disappears
[
DetectionResult("car", 0.88, BoundingBox(110, 210, 310, 410), 1001, 1640995202000),
DetectionResult("car", 0.78, BoundingBox(205, 305, 405, 505), 1003, 1640995202000)
]
]
# Process frames
for i, detections in enumerate(frame_detections):
validator.update_detections(detections, current_time=1640995200.0 + i)
# Get validation results
validation_results = validator.validate_all_tracks()
assert len(validation_results) == 3
# All tracks should be unstable (insufficient frames)
for result in validation_results.values():
assert result.is_stable is False
assert "Insufficient detection frames" in result.reasons

View file

@ -0,0 +1,606 @@
"""
Unit tests for BoT-SORT tracking management.
"""
import pytest
import numpy as np
from unittest.mock import Mock, MagicMock, patch
from collections import defaultdict
from detector_worker.detection.tracking_manager import TrackingManager, TrackInfo
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import TrackingError
class TestTrackInfo:
"""Test TrackInfo data structure."""
def test_creation(self):
"""Test TrackInfo creation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
assert track.track_id == 1001
assert track.bbox == bbox
assert track.confidence == 0.85
assert track.class_name == "car"
assert track.first_seen == 1640995200.0
assert track.last_seen == 1640995300.0
assert track.frame_count == 1
assert track.absence_count == 0
def test_update_track(self):
"""Test updating track information."""
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox1,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995200.0
)
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
track.update(bbox2, 0.90, 1640995300.0)
assert track.bbox == bbox2
assert track.confidence == 0.90
assert track.last_seen == 1640995300.0
assert track.frame_count == 2
assert track.absence_count == 0
def test_increment_absence(self):
"""Test incrementing absence count."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995200.0
)
track.increment_absence()
assert track.absence_count == 1
track.increment_absence()
assert track.absence_count == 2
def test_age_calculation(self):
"""Test track age calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
age = track.age(current_time=1640995400.0)
assert age == 200.0 # 1640995400 - 1640995200
def test_time_since_last_seen(self):
"""Test time since last seen calculation."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
time_since = track.time_since_last_seen(current_time=1640995450.0)
assert time_since == 150.0 # 1640995450 - 1640995300
def test_is_stable(self):
"""Test track stability checking."""
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
# Not stable initially
assert track.is_stable(min_frames=5, max_absence=3) is False
# Make it stable
track.frame_count = 10
track.absence_count = 1
assert track.is_stable(min_frames=5, max_absence=3) is True
# Too many absences
track.absence_count = 5
assert track.is_stable(min_frames=5, max_absence=3) is False
class TestTrackingManager:
"""Test tracking management functionality."""
def test_initialization(self):
"""Test tracking manager initialization."""
manager = TrackingManager()
assert manager.max_absence_frames == 30
assert manager.min_stable_frames == 10
assert manager.track_timeout == 60.0
assert len(manager.active_tracks) == 0
assert len(manager.stable_tracks) == 0
def test_initialization_with_config(self):
"""Test initialization with custom configuration."""
config = {
"max_absence_frames": 20,
"min_stable_frames": 5,
"track_timeout": 30.0
}
manager = TrackingManager(config)
assert manager.max_absence_frames == 20
assert manager.min_stable_frames == 5
assert manager.track_timeout == 30.0
def test_update_tracks_new_detections(self):
"""Test updating with new detections."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
assert len(manager.active_tracks) == 1
assert 1001 in manager.active_tracks
track = manager.active_tracks[1001]
assert track.track_id == 1001
assert track.class_name == "car"
assert track.confidence == 0.85
assert track.frame_count == 1
def test_update_tracks_existing_detection(self):
"""Test updating existing track."""
manager = TrackingManager()
# First detection
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection1], current_time=1640995200.0)
# Second detection (same track, different position)
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
detection2 = DetectionResult(
class_name="car",
confidence=0.90,
bbox=bbox2,
track_id=1001,
timestamp=1640995300000
)
manager.update_tracks([detection2], current_time=1640995300.0)
assert len(manager.active_tracks) == 1
track = manager.active_tracks[1001]
assert track.frame_count == 2
assert track.confidence == 0.90
assert track.bbox == bbox2
assert track.absence_count == 0
def test_update_tracks_no_detections(self):
"""Test updating with no detections (increment absence)."""
manager = TrackingManager()
# Add initial track
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
# Update with no detections
manager.update_tracks([], current_time=1640995300.0)
track = manager.active_tracks[1001]
assert track.absence_count == 1
def test_cleanup_expired_tracks(self):
"""Test cleanup of expired tracks."""
manager = TrackingManager({"track_timeout": 10.0})
# Add track
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
assert len(manager.active_tracks) == 1
# Cleanup after timeout
removed_count = manager.cleanup_expired_tracks(current_time=1640995220.0) # 20 seconds later
assert removed_count == 1
assert len(manager.active_tracks) == 0
def test_cleanup_absent_tracks(self):
"""Test cleanup of tracks with too many absences."""
manager = TrackingManager({"max_absence_frames": 3})
# Add track
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
# Increment absence count beyond threshold
for i in range(5):
manager.update_tracks([], current_time=1640995200.0 + i)
track = manager.active_tracks[1001]
assert track.absence_count == 5
# Cleanup absent tracks
removed_count = manager.cleanup_absent_tracks()
assert removed_count == 1
assert len(manager.active_tracks) == 0
def test_get_stable_tracks(self):
"""Test getting stable tracks."""
manager = TrackingManager({"min_stable_frames": 3})
# Add track and make it stable
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track_info = TrackInfo(
track_id=1001,
bbox=bbox,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
track_info.frame_count = 5 # Make it stable
manager.active_tracks[1001] = track_info
stable_tracks = manager.get_stable_tracks()
assert len(stable_tracks) == 1
assert 1001 in stable_tracks
assert 1001 in manager.stable_tracks # Should be cached
def test_get_track_by_id(self):
"""Test getting track by ID."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
track = manager.get_track_by_id(1001)
assert track is not None
assert track.track_id == 1001
non_existent = manager.get_track_by_id(9999)
assert non_existent is None
def test_get_tracks_by_class(self):
"""Test getting tracks by class name."""
manager = TrackingManager()
# Add different classes
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
detection2 = DetectionResult(
class_name="truck",
confidence=0.80,
bbox=bbox2,
track_id=1002,
timestamp=1640995200000
)
bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500)
detection3 = DetectionResult(
class_name="car",
confidence=0.90,
bbox=bbox3,
track_id=1003,
timestamp=1640995200000
)
manager.update_tracks([detection1, detection2, detection3], current_time=1640995200.0)
car_tracks = manager.get_tracks_by_class("car")
assert len(car_tracks) == 2
assert 1001 in car_tracks
assert 1003 in car_tracks
truck_tracks = manager.get_tracks_by_class("truck")
assert len(truck_tracks) == 1
assert 1002 in truck_tracks
def test_get_track_count(self):
"""Test getting track counts."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
assert manager.get_active_track_count() == 1
assert manager.get_track_count_by_class("car") == 1
assert manager.get_track_count_by_class("truck") == 0
def test_clear_all_tracks(self):
"""Test clearing all tracks."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection], current_time=1640995200.0)
assert len(manager.active_tracks) == 1
manager.clear_all_tracks()
assert len(manager.active_tracks) == 0
assert len(manager.stable_tracks) == 0
def test_get_track_statistics(self):
"""Test getting track statistics."""
manager = TrackingManager({"min_stable_frames": 2})
# Add multiple tracks
detections = []
for i in range(3):
bbox = BoundingBox(x1=100+i*50, y1=200, x2=300+i*50, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=1001+i,
timestamp=1640995200000
)
detections.append(detection)
manager.update_tracks(detections, current_time=1640995200.0)
# Make some tracks stable
manager.active_tracks[1001].frame_count = 5
manager.active_tracks[1002].frame_count = 3
# 1003 remains unstable with frame_count=1
stats = manager.get_track_statistics()
assert stats["active_tracks"] == 3
assert stats["stable_tracks"] == 2
assert stats["unstable_tracks"] == 1
assert "average_track_age" in stats
assert "average_confidence" in stats
def test_validate_tracks(self):
"""Test track validation."""
manager = TrackingManager({"min_stable_frames": 3, "max_absence_frames": 2})
# Add tracks with different stability
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
track1 = TrackInfo(
track_id=1001,
bbox=bbox1,
confidence=0.85,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995300.0
)
track1.frame_count = 5 # Stable
track1.absence_count = 1 # Present
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
track2 = TrackInfo(
track_id=1002,
bbox=bbox2,
confidence=0.80,
class_name="car",
first_seen=1640995200.0,
last_seen=1640995250.0
)
track2.frame_count = 2 # Not stable
track2.absence_count = 1
bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500)
track3 = TrackInfo(
track_id=1003,
bbox=bbox3,
confidence=0.90,
class_name="car",
first_seen=1640995100.0,
last_seen=1640995150.0
)
track3.frame_count = 8 # Was stable but now absent
track3.absence_count = 5 # Too many absences
manager.active_tracks = {1001: track1, 1002: track2, 1003: track3}
manager.stable_tracks = {1001, 1003} # 1003 was previously stable
validation_result = manager.validate_tracks()
assert validation_result.stable_tracks == [1001]
assert validation_result.current_tracks == [1001, 1002, 1003]
assert validation_result.newly_stable == []
assert validation_result.lost_tracks == [1003]
def test_track_persistence_across_frames(self):
"""Test track persistence across multiple frames."""
manager = TrackingManager()
# Frame 1
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
manager.update_tracks([detection1], current_time=1640995200.0)
# Frame 2 - track moves
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
detection2 = DetectionResult(
class_name="car",
confidence=0.88,
bbox=bbox2,
track_id=1001,
timestamp=1640995300000
)
manager.update_tracks([detection2], current_time=1640995300.0)
# Frame 3 - track disappears
manager.update_tracks([], current_time=1640995400.0)
# Frame 4 - track reappears
bbox4 = BoundingBox(x1=120, y1=220, x2=320, y2=420)
detection4 = DetectionResult(
class_name="car",
confidence=0.82,
bbox=bbox4,
track_id=1001,
timestamp=1640995500000
)
manager.update_tracks([detection4], current_time=1640995500.0)
track = manager.active_tracks[1001]
assert track.frame_count == 3 # Seen in 3 frames
assert track.absence_count == 0 # Reset when reappeared
assert track.bbox == bbox4 # Latest position
class TestTrackingManagerErrorHandling:
"""Test error handling in tracking manager."""
def test_invalid_detection_input(self):
"""Test handling of invalid detection input."""
manager = TrackingManager()
# None detection should be handled gracefully
with pytest.raises(TrackingError):
manager.update_tracks([None], current_time=1640995200.0)
def test_negative_track_id(self):
"""Test handling of negative track ID."""
manager = TrackingManager()
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox,
track_id=-1, # Invalid track ID
timestamp=1640995200000
)
with pytest.raises(TrackingError):
manager.update_tracks([detection], current_time=1640995200.0)
def test_duplicate_track_ids_different_classes(self):
"""Test handling of duplicate track IDs with different classes."""
manager = TrackingManager()
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
detection1 = DetectionResult(
class_name="car",
confidence=0.85,
bbox=bbox1,
track_id=1001,
timestamp=1640995200000
)
bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
detection2 = DetectionResult(
class_name="truck", # Different class, same ID
confidence=0.80,
bbox=bbox2,
track_id=1001,
timestamp=1640995200000
)
# Should log warning but handle gracefully
manager.update_tracks([detection1, detection2], current_time=1640995200.0)
# The later detection should update the track
track = manager.active_tracks[1001]
assert track.class_name == "truck" # Last update wins

View file

@ -0,0 +1,386 @@
"""
Unit tests for YOLO detector with tracking functionality.
"""
import pytest
import numpy as np
from unittest.mock import Mock, MagicMock, patch
import torch
from detector_worker.detection.yolo_detector import YOLODetector
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import DetectionError
class TestYOLODetector:
"""Test YOLO detection and tracking functionality."""
def test_initialization_with_valid_model(self, mock_yolo_model):
"""Test detector initialization with valid model."""
detector = YOLODetector(mock_yolo_model)
assert detector.model is mock_yolo_model
assert detector.class_names == {}
assert detector.is_tracking_enabled is True
def test_initialization_with_class_names(self, mock_yolo_model):
"""Test detector initialization with class names."""
class_names = {0: "car", 1: "truck", 2: "bus"}
detector = YOLODetector(mock_yolo_model, class_names=class_names)
assert detector.class_names == class_names
def test_initialization_tracking_disabled(self, mock_yolo_model):
"""Test detector initialization with tracking disabled."""
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
assert detector.is_tracking_enabled is False
def test_detect_with_tracking(self, mock_yolo_model, mock_frame):
"""Test detection with tracking enabled."""
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0], # x1, y1, x2, y2, conf, class
[150, 250, 350, 450, 0.85, 1]
])
mock_result.boxes.id = torch.tensor([1001, 1002])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame)
assert len(detections) == 2
assert detections[0].confidence == 0.9
assert detections[0].track_id == 1001
assert detections[0].bbox.x1 == 100
mock_yolo_model.track.assert_called_once_with(mock_frame, persist=True, verbose=False)
def test_detect_without_tracking(self, mock_yolo_model, mock_frame):
"""Test detection with tracking disabled."""
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = None # No tracking IDs
mock_yolo_model.predict.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
detections = detector.detect(mock_frame)
assert len(detections) == 1
assert detections[0].track_id is None # No tracking ID
mock_yolo_model.predict.assert_called_once_with(mock_frame, verbose=False)
def test_detect_with_class_names(self, mock_yolo_model, mock_frame):
"""Test detection with class name mapping."""
class_names = {0: "car", 1: "truck"}
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0], # car
[150, 250, 350, 450, 0.85, 1] # truck
])
mock_result.boxes.id = torch.tensor([1001, 1002])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model, class_names=class_names)
detections = detector.detect(mock_frame)
assert detections[0].class_name == "car"
assert detections[1].class_name == "truck"
def test_detect_no_boxes(self, mock_yolo_model, mock_frame):
"""Test detection when no objects are detected."""
mock_result = Mock()
mock_result.boxes = None
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame)
assert detections == []
def test_detect_empty_boxes(self, mock_yolo_model, mock_frame):
"""Test detection with empty boxes tensor."""
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([]).reshape(0, 6)
mock_result.boxes.id = None
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame)
assert detections == []
def test_detect_with_confidence_threshold(self, mock_yolo_model, mock_frame):
"""Test detection with confidence threshold filtering."""
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0], # Above threshold
[150, 250, 350, 450, 0.3, 1] # Below threshold
])
mock_result.boxes.id = torch.tensor([1001, 1002])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame, confidence_threshold=0.5)
assert len(detections) == 1 # Only one above threshold
assert detections[0].confidence == 0.9
def test_detect_model_error_handling(self, mock_yolo_model, mock_frame):
"""Test error handling when model fails."""
mock_yolo_model.track.side_effect = Exception("Model inference failed")
detector = YOLODetector(mock_yolo_model)
with pytest.raises(DetectionError) as exc_info:
detector.detect(mock_frame)
assert "Model inference failed" in str(exc_info.value)
def test_detect_invalid_frame(self, mock_yolo_model):
"""Test detection with invalid frame input."""
detector = YOLODetector(mock_yolo_model)
with pytest.raises(DetectionError) as exc_info:
detector.detect(None)
assert "Invalid frame" in str(exc_info.value)
def test_detect_result_validation(self, mock_yolo_model, mock_frame):
"""Test detection result validation."""
# Mock result with invalid bounding box (x2 <= x1)
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[300, 200, 100, 400, 0.9, 0] # Invalid: x2 < x1
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame)
# Invalid detections should be filtered out
assert detections == []
def test_get_model_info(self, mock_yolo_model):
"""Test getting model information."""
mock_yolo_model.device = "cuda:0"
mock_yolo_model.names = {0: "car", 1: "truck"}
detector = YOLODetector(mock_yolo_model)
info = detector.get_model_info()
assert info["device"] == "cuda:0"
assert info["class_names"] == {0: "car", 1: "truck"}
assert info["tracking_enabled"] is True
def test_set_tracking_enabled(self, mock_yolo_model):
"""Test enabling/disabling tracking at runtime."""
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
assert detector.is_tracking_enabled is False
detector.set_tracking_enabled(True)
assert detector.is_tracking_enabled is True
detector.set_tracking_enabled(False)
assert detector.is_tracking_enabled is False
def test_update_class_names(self, mock_yolo_model):
"""Test updating class names at runtime."""
detector = YOLODetector(mock_yolo_model)
new_class_names = {0: "vehicle", 1: "person"}
detector.update_class_names(new_class_names)
assert detector.class_names == new_class_names
def test_reset_tracker(self, mock_yolo_model):
"""Test resetting the tracking state."""
detector = YOLODetector(mock_yolo_model)
# This should not raise an error
detector.reset_tracker()
def test_detect_with_crop_region(self, mock_yolo_model, mock_frame):
"""Test detection with crop region specified."""
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[50, 75, 150, 175, 0.9, 0] # Relative to cropped region
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
crop_region = (100, 200, 300, 400) # x1, y1, x2, y2
detections = detector.detect(mock_frame, crop_region=crop_region)
# Bounding box should be adjusted to global coordinates
assert detections[0].bbox.x1 == 150 # 100 + 50
assert detections[0].bbox.y1 == 275 # 200 + 75
assert detections[0].bbox.x2 == 250 # 100 + 150
assert detections[0].bbox.y2 == 375 # 200 + 175
def test_detect_batch_processing(self, mock_yolo_model):
"""Test batch detection processing."""
frames = [
np.zeros((480, 640, 3), dtype=np.uint8),
np.ones((480, 640, 3), dtype=np.uint8) * 255
]
mock_results = []
for i in range(2):
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100 + i*10, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001 + i])
mock_results.append(mock_result)
mock_yolo_model.track.side_effect = [[result] for result in mock_results]
detector = YOLODetector(mock_yolo_model)
batch_detections = detector.detect_batch(frames)
assert len(batch_detections) == 2
assert len(batch_detections[0]) == 1
assert len(batch_detections[1]) == 1
assert batch_detections[0][0].bbox.x1 == 100
assert batch_detections[1][0].bbox.x1 == 110
def test_detect_batch_empty_frames(self, mock_yolo_model):
"""Test batch detection with empty frame list."""
detector = YOLODetector(mock_yolo_model)
batch_detections = detector.detect_batch([])
assert batch_detections == []
def test_detect_performance_metrics(self, mock_yolo_model, mock_frame):
"""Test detection performance metrics collection."""
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_result.speed = {"preprocess": 2.1, "inference": 15.3, "postprocess": 1.2}
mock_yolo_model.track.return_value = [mock_result]
detector = YOLODetector(mock_yolo_model)
detections = detector.detect(mock_frame, return_metrics=True)
# Check if performance metrics are available
assert hasattr(detector, '_last_inference_time')
@pytest.mark.parametrize("device", ["cpu", "cuda:0", "mps"])
def test_detect_different_devices(self, device, mock_frame):
"""Test detection on different devices."""
mock_model = Mock()
mock_model.device = device
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_model.track.return_value = [mock_result]
detector = YOLODetector(mock_model)
detections = detector.detect(mock_frame)
assert len(detections) == 1
assert detections[0].confidence == 0.9
class TestYOLODetectorIntegration:
"""Integration tests for YOLO detector."""
def test_detect_with_real_tensor_operations(self, mock_yolo_model, mock_frame):
"""Test detection with realistic tensor operations."""
# Create more realistic box data
boxes_data = torch.tensor([
[100.5, 200.3, 299.7, 399.8, 0.95, 0],
[150.2, 250.1, 349.9, 449.6, 0.87, 1],
[200.0, 300.0, 400.0, 500.0, 0.45, 0] # Low confidence
])
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = boxes_data
mock_result.boxes.id = torch.tensor([2001, 2002, 2003])
mock_yolo_model.track.return_value = [mock_result]
class_names = {0: "car", 1: "truck"}
detector = YOLODetector(mock_yolo_model, class_names=class_names)
detections = detector.detect(mock_frame, confidence_threshold=0.5)
# Should filter out low confidence detection
assert len(detections) == 2
# Check first detection
det1 = detections[0]
assert det1.class_name == "car"
assert det1.confidence == pytest.approx(0.95)
assert det1.track_id == 2001
assert det1.bbox.x1 == pytest.approx(100.5)
assert det1.bbox.y1 == pytest.approx(200.3)
# Check second detection
det2 = detections[1]
assert det2.class_name == "truck"
assert det2.confidence == pytest.approx(0.87)
assert det2.track_id == 2002
def test_multi_frame_tracking_consistency(self, mock_yolo_model, mock_frame):
"""Test that tracking IDs remain consistent across frames."""
detector = YOLODetector(mock_yolo_model)
# Frame 1
mock_result1 = Mock()
mock_result1.boxes = Mock()
mock_result1.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result1.boxes.id = torch.tensor([5001])
mock_yolo_model.track.return_value = [mock_result1]
detections1 = detector.detect(mock_frame)
# Frame 2 - same object, slightly moved
mock_result2 = Mock()
mock_result2.boxes = Mock()
mock_result2.boxes.data = torch.tensor([
[105, 205, 305, 405, 0.88, 0]
])
mock_result2.boxes.id = torch.tensor([5001]) # Same ID
mock_yolo_model.track.return_value = [mock_result2]
detections2 = detector.detect(mock_frame)
# Should maintain same track ID
assert detections1[0].track_id == detections2[0].track_id == 5001

View file

@ -0,0 +1,882 @@
"""
Unit tests for model management functionality.
"""
import pytest
import os
import tempfile
import threading
import time
from unittest.mock import Mock, patch, MagicMock
import torch
import numpy as np
from detector_worker.models.model_manager import (
ModelManager,
ModelInfo,
ModelConfig,
ModelCache,
ModelLoader,
ModelError,
ModelLoadError,
ModelCacheError
)
from detector_worker.core.exceptions import ConfigurationError
class TestModelConfig:
"""Test model configuration."""
def test_creation(self):
"""Test model config creation."""
config = ModelConfig(
model_id="yolo_v8_car",
model_path="/models/yolo_v8_car.pt",
model_type="detection",
device="cuda:0"
)
assert config.model_id == "yolo_v8_car"
assert config.model_path == "/models/yolo_v8_car.pt"
assert config.model_type == "detection"
assert config.device == "cuda:0"
assert config.confidence_threshold == 0.5
assert config.max_memory_mb == 1024
def test_creation_with_optional_params(self):
"""Test config creation with optional parameters."""
config = ModelConfig(
model_id="classifier_v1",
model_path="/models/classifier.pt",
model_type="classification",
device="cpu",
confidence_threshold=0.8,
max_memory_mb=512,
class_names={0: "car", 1: "truck", 2: "bus"},
preprocessing_config={"resize": (224, 224), "normalize": True}
)
assert config.confidence_threshold == 0.8
assert config.max_memory_mb == 512
assert config.class_names[0] == "car"
assert config.preprocessing_config["resize"] == (224, 224)
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"model_id": "detection_model",
"model_path": "/path/to/model.pt",
"model_type": "detection",
"device": "cuda:0",
"confidence_threshold": 0.75,
"class_names": {0: "person", 1: "vehicle"},
"unknown_field": "ignored"
}
config = ModelConfig.from_dict(config_dict)
assert config.model_id == "detection_model"
assert config.confidence_threshold == 0.75
assert config.class_names[1] == "vehicle"
def test_validation(self):
"""Test config validation."""
# Valid config
valid_config = ModelConfig(
model_id="test_model",
model_path="/valid/path/model.pt",
model_type="detection",
device="cpu"
)
assert valid_config.is_valid() is True
# Invalid config (empty model_id)
invalid_config = ModelConfig(
model_id="",
model_path="/path/model.pt",
model_type="detection",
device="cpu"
)
assert invalid_config.is_valid() is False
def test_get_memory_limit_bytes(self):
"""Test getting memory limit in bytes."""
config = ModelConfig(
model_id="test",
model_path="/path",
model_type="detection",
device="cpu",
max_memory_mb=256
)
assert config.get_memory_limit_bytes() == 256 * 1024 * 1024
class TestModelInfo:
"""Test model information."""
def test_creation(self):
"""Test model info creation."""
config = ModelConfig(
model_id="test_model",
model_path="/path/model.pt",
model_type="detection",
device="cuda:0"
)
mock_model = Mock()
info = ModelInfo(
config=config,
model_instance=mock_model,
load_time=1.5
)
assert info.config == config
assert info.model_instance == mock_model
assert info.load_time == 1.5
assert info.reference_count == 0
assert info.last_used <= time.time()
assert info.memory_usage == 0
def test_increment_reference(self):
"""Test incrementing reference count."""
config = ModelConfig("test", "/path", "detection", "cpu")
info = ModelInfo(config, Mock(), 1.0)
assert info.reference_count == 0
info.increment_reference()
assert info.reference_count == 1
info.increment_reference()
assert info.reference_count == 2
def test_decrement_reference(self):
"""Test decrementing reference count."""
config = ModelConfig("test", "/path", "detection", "cpu")
info = ModelInfo(config, Mock(), 1.0)
info.reference_count = 3
assert info.decrement_reference() == 2
assert info.reference_count == 2
assert info.decrement_reference() == 1
assert info.decrement_reference() == 0
# Should not go below 0
assert info.decrement_reference() == 0
def test_update_usage(self):
"""Test updating usage statistics."""
config = ModelConfig("test", "/path", "detection", "cpu")
info = ModelInfo(config, Mock(), 1.0)
original_time = info.last_used
original_count = info.usage_count
time.sleep(0.01) # Small delay
info.update_usage(memory_usage=512*1024*1024) # 512MB
assert info.last_used > original_time
assert info.usage_count == original_count + 1
assert info.memory_usage == 512*1024*1024
def test_age_calculation(self):
"""Test age calculation."""
config = ModelConfig("test", "/path", "detection", "cpu")
info = ModelInfo(config, Mock(), 1.0)
time.sleep(0.01)
age = info.age()
assert age > 0
assert age < 1 # Should be less than 1 second
def test_get_stats(self):
"""Test getting model statistics."""
config = ModelConfig("test_model", "/path", "detection", "cuda:0")
info = ModelInfo(config, Mock(), 2.5)
info.reference_count = 3
info.usage_count = 100
info.memory_usage = 1024*1024*1024 # 1GB
stats = info.get_stats()
assert stats["model_id"] == "test_model"
assert stats["device"] == "cuda:0"
assert stats["load_time"] == 2.5
assert stats["reference_count"] == 3
assert stats["usage_count"] == 100
assert stats["memory_usage_mb"] == 1024
assert "age_seconds" in stats
class TestModelLoader:
"""Test model loading functionality."""
def test_creation(self):
"""Test model loader creation."""
loader = ModelLoader()
assert loader.supported_formats == [".pt", ".pth", ".onnx", ".trt"]
assert loader.default_device == "cpu"
def test_detect_device_cuda_available(self):
"""Test device detection when CUDA is available."""
loader = ModelLoader()
with patch('torch.cuda.is_available', return_value=True):
with patch('torch.cuda.device_count', return_value=2):
device = loader.detect_optimal_device()
assert device == "cuda:0"
def test_detect_device_cuda_unavailable(self):
"""Test device detection when CUDA is not available."""
loader = ModelLoader()
with patch('torch.cuda.is_available', return_value=False):
device = loader.detect_optimal_device()
assert device == "cpu"
def test_load_pytorch_model(self):
"""Test loading PyTorch model."""
loader = ModelLoader()
with patch('torch.load') as mock_torch_load:
with patch('os.path.exists', return_value=True):
mock_model = Mock()
mock_torch_load.return_value = mock_model
config = ModelConfig(
model_id="test_model",
model_path="/path/to/model.pt",
model_type="detection",
device="cpu"
)
loaded_model = loader.load_model(config)
assert loaded_model == mock_model
mock_torch_load.assert_called_once_with("/path/to/model.pt", map_location="cpu")
def test_load_model_file_not_exists(self):
"""Test loading model when file doesn't exist."""
loader = ModelLoader()
with patch('os.path.exists', return_value=False):
config = ModelConfig(
model_id="missing_model",
model_path="/nonexistent/model.pt",
model_type="detection",
device="cpu"
)
with pytest.raises(ModelLoadError) as exc_info:
loader.load_model(config)
assert "does not exist" in str(exc_info.value)
def test_load_model_invalid_format(self):
"""Test loading model with invalid format."""
loader = ModelLoader()
with patch('os.path.exists', return_value=True):
config = ModelConfig(
model_id="invalid_model",
model_path="/path/to/model.invalid",
model_type="detection",
device="cpu"
)
with pytest.raises(ModelLoadError) as exc_info:
loader.load_model(config)
assert "unsupported format" in str(exc_info.value).lower()
def test_load_model_torch_error(self):
"""Test loading model with torch loading error."""
loader = ModelLoader()
with patch('os.path.exists', return_value=True):
with patch('torch.load', side_effect=RuntimeError("CUDA out of memory")):
config = ModelConfig(
model_id="error_model",
model_path="/path/to/model.pt",
model_type="detection",
device="cuda:0"
)
with pytest.raises(ModelLoadError) as exc_info:
loader.load_model(config)
assert "CUDA out of memory" in str(exc_info.value)
def test_validate_model_pytorch(self):
"""Test validating PyTorch model."""
loader = ModelLoader()
mock_model = Mock()
mock_model.__class__.__module__ = "torch.nn"
config = ModelConfig("test", "/path", "detection", "cpu")
is_valid = loader.validate_model(mock_model, config)
assert is_valid is True
def test_validate_model_invalid(self):
"""Test validating invalid model."""
loader = ModelLoader()
invalid_model = "not_a_model"
config = ModelConfig("test", "/path", "detection", "cpu")
is_valid = loader.validate_model(invalid_model, config)
assert is_valid is False
def test_estimate_model_memory(self):
"""Test estimating model memory usage."""
loader = ModelLoader()
mock_model = Mock()
mock_param1 = Mock()
mock_param1.numel.return_value = 1000000 # 1M parameters
mock_param1.element_size.return_value = 4 # 4 bytes per parameter
mock_param2 = Mock()
mock_param2.numel.return_value = 500000 # 0.5M parameters
mock_param2.element_size.return_value = 4
mock_model.parameters.return_value = [mock_param1, mock_param2]
memory_bytes = loader.estimate_memory_usage(mock_model)
expected_bytes = (1000000 + 500000) * 4 # 6MB
assert memory_bytes == expected_bytes
class TestModelCache:
"""Test model caching functionality."""
def test_creation(self):
"""Test model cache creation."""
cache = ModelCache(max_size=5, max_memory_mb=2048)
assert cache.max_size == 5
assert cache.max_memory_mb == 2048
assert len(cache.models) == 0
assert len(cache.access_order) == 0
def test_put_and_get_model(self):
"""Test putting and getting model from cache."""
cache = ModelCache(max_size=3)
config = ModelConfig("test_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.5)
cache.put("test_model", model_info)
retrieved_info = cache.get("test_model")
assert retrieved_info == model_info
assert retrieved_info.reference_count == 1 # Should be incremented on get
def test_get_nonexistent_model(self):
"""Test getting non-existent model."""
cache = ModelCache(max_size=3)
result = cache.get("nonexistent_model")
assert result is None
def test_contains_check(self):
"""Test checking if model exists in cache."""
cache = ModelCache(max_size=3)
config = ModelConfig("test_model", "/path", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put("test_model", model_info)
assert cache.contains("test_model") is True
assert cache.contains("nonexistent_model") is False
def test_remove_model(self):
"""Test removing model from cache."""
cache = ModelCache(max_size=3)
config = ModelConfig("test_model", "/path", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put("test_model", model_info)
assert cache.contains("test_model") is True
removed_info = cache.remove("test_model")
assert removed_info == model_info
assert cache.contains("test_model") is False
def test_lru_eviction(self):
"""Test LRU eviction policy."""
cache = ModelCache(max_size=2)
# Add models to fill cache
for i in range(2):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put(f"model_{i}", model_info)
# Access model_0 to make it recently used
cache.get("model_0")
# Add another model (should evict model_1, the least recently used)
config = ModelConfig("model_2", "/path_2", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put("model_2", model_info)
assert cache.size() == 2
assert cache.contains("model_0") is True # Recently accessed
assert cache.contains("model_1") is False # Evicted
assert cache.contains("model_2") is True # Newly added
def test_memory_based_eviction(self):
"""Test memory-based eviction."""
cache = ModelCache(max_size=10, max_memory_mb=1) # 1MB limit
# Add model that uses 0.8MB
config1 = ModelConfig("model_1", "/path_1", "detection", "cpu")
model1 = Mock()
info1 = ModelInfo(config1, model1, 1.0)
info1.memory_usage = 0.8 * 1024 * 1024 # 0.8MB
cache.put("model_1", info1)
# Add model that would exceed memory limit
config2 = ModelConfig("model_2", "/path_2", "detection", "cpu")
model2 = Mock()
info2 = ModelInfo(config2, model2, 1.0)
info2.memory_usage = 0.5 * 1024 * 1024 # 0.5MB
cache.put("model_2", info2)
# First model should be evicted due to memory constraint
assert cache.contains("model_1") is False
assert cache.contains("model_2") is True
def test_get_stats(self):
"""Test getting cache statistics."""
cache = ModelCache(max_size=5)
# Add some models
for i in range(3):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
model_info.memory_usage = 100 * 1024 * 1024 # 100MB each
cache.put(f"model_{i}", model_info)
# Access some models
cache.get("model_0")
cache.get("model_1")
cache.get("nonexistent") # Miss
stats = cache.get_stats()
assert stats["size"] == 3
assert stats["max_size"] == 5
assert stats["hits"] == 2
assert stats["misses"] == 1
assert stats["hit_rate"] == 2/3
assert stats["memory_usage_mb"] == 300
def test_clear_cache(self):
"""Test clearing entire cache."""
cache = ModelCache(max_size=5)
# Add models
for i in range(3):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
cache.put(f"model_{i}", model_info)
assert cache.size() == 3
cache.clear()
assert cache.size() == 0
assert len(cache.models) == 0
assert len(cache.access_order) == 0
class TestModelManager:
"""Test main model manager functionality."""
def test_initialization(self):
"""Test model manager initialization."""
manager = ModelManager()
assert isinstance(manager.cache, ModelCache)
assert isinstance(manager.loader, ModelLoader)
assert manager.models_directory == "models"
assert manager.default_device == "cpu"
def test_initialization_with_config(self):
"""Test initialization with custom configuration."""
config = {
"models_directory": "/custom/models",
"default_device": "cuda:0",
"cache_max_size": 20,
"cache_max_memory_mb": 4096
}
manager = ModelManager(config)
assert manager.models_directory == "/custom/models"
assert manager.default_device == "cuda:0"
assert manager.cache.max_size == 20
assert manager.cache.max_memory_mb == 4096
def test_load_model_new(self):
"""Test loading new model."""
manager = ModelManager()
config = ModelConfig(
model_id="test_model",
model_path="/path/to/model.pt",
model_type="detection",
device="cpu"
)
with patch.object(manager.loader, 'load_model') as mock_load:
with patch.object(manager.loader, 'estimate_memory_usage', return_value=512*1024*1024):
mock_model = Mock()
mock_load.return_value = mock_model
loaded_model = manager.load_model(config)
assert loaded_model == mock_model
assert manager.cache.contains("test_model") is True
mock_load.assert_called_once_with(config)
def test_load_model_from_cache(self):
"""Test loading model from cache."""
manager = ModelManager()
# Pre-populate cache
config = ModelConfig("cached_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.0)
manager.cache.put("cached_model", model_info)
with patch.object(manager.loader, 'load_model') as mock_load:
loaded_model = manager.load_model(config)
assert loaded_model == mock_model
mock_load.assert_not_called() # Should not load from disk
def test_get_model_by_id(self):
"""Test getting model by ID."""
manager = ModelManager()
config = ModelConfig("test_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.0)
manager.cache.put("test_model", model_info)
retrieved_model = manager.get_model("test_model")
assert retrieved_model == mock_model
def test_get_nonexistent_model(self):
"""Test getting non-existent model."""
manager = ModelManager()
model = manager.get_model("nonexistent_model")
assert model is None
def test_unload_model_with_references(self):
"""Test unloading model with active references."""
manager = ModelManager()
config = ModelConfig("ref_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.0)
model_info.reference_count = 2 # Active references
manager.cache.put("ref_model", model_info)
result = manager.unload_model("ref_model")
assert result is False # Should not unload with active references
assert manager.cache.contains("ref_model") is True
def test_unload_model_no_references(self):
"""Test unloading model without references."""
manager = ModelManager()
config = ModelConfig("no_ref_model", "/path", "detection", "cpu")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 1.0)
model_info.reference_count = 0 # No references
manager.cache.put("no_ref_model", model_info)
result = manager.unload_model("no_ref_model")
assert result is True
assert manager.cache.contains("no_ref_model") is False
def test_list_loaded_models(self):
"""Test listing loaded models."""
manager = ModelManager()
# Add models to cache
for i in range(3):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
manager.cache.put(f"model_{i}", model_info)
loaded_models = manager.list_loaded_models()
assert len(loaded_models) == 3
assert all(info["model_id"].startswith("model_") for info in loaded_models)
def test_get_model_info(self):
"""Test getting model information."""
manager = ModelManager()
config = ModelConfig("info_model", "/path", "detection", "cuda:0")
mock_model = Mock()
model_info = ModelInfo(config, mock_model, 2.5)
model_info.usage_count = 10
manager.cache.put("info_model", model_info)
info = manager.get_model_info("info_model")
assert info is not None
assert info["model_id"] == "info_model"
assert info["device"] == "cuda:0"
assert info["load_time"] == 2.5
assert info["usage_count"] == 10
def test_cleanup_unused_models(self):
"""Test cleaning up unused models."""
manager = ModelManager()
# Add models with different reference counts
models_data = [
("used_model", 2), # Has references
("unused_model_1", 0), # No references
("unused_model_2", 0) # No references
]
for model_id, ref_count in models_data:
config = ModelConfig(model_id, f"/path/{model_id}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
model_info.reference_count = ref_count
manager.cache.put(model_id, model_info)
cleaned_count = manager.cleanup_unused_models()
assert cleaned_count == 2 # Two unused models cleaned
assert manager.cache.contains("used_model") is True
assert manager.cache.contains("unused_model_1") is False
assert manager.cache.contains("unused_model_2") is False
def test_get_memory_usage(self):
"""Test getting total memory usage."""
manager = ModelManager()
# Add models with different memory usage
memory_sizes = [256, 512, 1024] # MB
for i, memory_mb in enumerate(memory_sizes):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
model_info.memory_usage = memory_mb * 1024 * 1024 # Convert to bytes
manager.cache.put(f"model_{i}", model_info)
total_usage = manager.get_memory_usage()
expected_bytes = sum(memory_sizes) * 1024 * 1024
assert total_usage == expected_bytes
def test_health_check(self):
"""Test model manager health check."""
manager = ModelManager()
# Add models
for i in range(3):
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
model_info = ModelInfo(config, Mock(), 1.0)
model_info.memory_usage = 100 * 1024 * 1024 # 100MB each
manager.cache.put(f"model_{i}", model_info)
health_report = manager.health_check()
assert health_report["status"] == "healthy"
assert health_report["loaded_models"] == 3
assert health_report["total_memory_mb"] == 300
assert health_report["cache_hit_rate"] >= 0
class TestModelManagerIntegration:
"""Integration tests for model manager."""
def test_concurrent_model_loading(self):
"""Test concurrent model loading."""
manager = ModelManager()
# Mock loader to simulate loading time
def slow_load(config):
time.sleep(0.1) # Simulate loading time
mock_model = Mock()
mock_model.model_id = config.model_id
return mock_model
with patch.object(manager.loader, 'load_model', side_effect=slow_load):
with patch.object(manager.loader, 'estimate_memory_usage', return_value=100*1024*1024):
# Create multiple threads loading different models
results = {}
errors = []
def load_model_thread(model_id):
try:
config = ModelConfig(
model_id=model_id,
model_path=f"/path/{model_id}.pt",
model_type="detection",
device="cpu"
)
model = manager.load_model(config)
results[model_id] = model
except Exception as e:
errors.append((model_id, str(e)))
threads = []
for i in range(5):
thread = threading.Thread(target=load_model_thread, args=(f"model_{i}",))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# All models should be loaded successfully
assert len(errors) == 0
assert len(results) == 5
assert len(manager.cache.models) == 5
def test_memory_pressure_handling(self):
"""Test handling memory pressure."""
# Create manager with small memory limit
manager = ModelManager({
"cache_max_memory_mb": 200 # 200MB limit
})
with patch.object(manager.loader, 'load_model') as mock_load:
with patch.object(manager.loader, 'estimate_memory_usage', return_value=100*1024*1024): # 100MB per model
def create_mock_model(config):
mock_model = Mock()
mock_model.model_id = config.model_id
return mock_model
mock_load.side_effect = create_mock_model
# Load models that exceed memory limit
for i in range(4): # 4 * 100MB = 400MB > 200MB limit
config = ModelConfig(
model_id=f"large_model_{i}",
model_path=f"/path/large_model_{i}.pt",
model_type="detection",
device="cpu"
)
manager.load_model(config)
# Should not exceed memory limit due to eviction
total_memory = manager.get_memory_usage()
memory_limit = 200 * 1024 * 1024
assert total_memory <= memory_limit
def test_model_lifecycle_management(self):
"""Test complete model lifecycle."""
manager = ModelManager()
with patch.object(manager.loader, 'load_model') as mock_load:
with patch.object(manager.loader, 'estimate_memory_usage', return_value=50*1024*1024):
mock_model = Mock()
mock_load.return_value = mock_model
config = ModelConfig(
model_id="lifecycle_model",
model_path="/path/lifecycle_model.pt",
model_type="detection",
device="cpu"
)
# 1. Load model
loaded_model = manager.load_model(config)
assert loaded_model == mock_model
assert manager.cache.contains("lifecycle_model") is True
# 2. Get model multiple times (increase usage)
for _ in range(5):
model = manager.get_model("lifecycle_model")
assert model == mock_model
# 3. Check model info
info = manager.get_model_info("lifecycle_model")
assert info["usage_count"] >= 5
# 4. Simulate model still in use
model_info = manager.cache.get("lifecycle_model")
model_info.reference_count = 1
# Should not unload while in use
unloaded = manager.unload_model("lifecycle_model")
assert unloaded is False
assert manager.cache.contains("lifecycle_model") is True
# 5. Release reference and unload
model_info.reference_count = 0
unloaded = manager.unload_model("lifecycle_model")
assert unloaded is True
assert manager.cache.contains("lifecycle_model") is False
def test_error_recovery(self):
"""Test error recovery scenarios."""
manager = ModelManager()
# Test loading model that fails initially then succeeds
call_count = 0
def failing_then_success_load(config):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ModelLoadError("First attempt failed")
return Mock()
with patch.object(manager.loader, 'load_model', side_effect=failing_then_success_load):
with patch.object(manager.loader, 'estimate_memory_usage', return_value=50*1024*1024):
config = ModelConfig(
model_id="retry_model",
model_path="/path/retry_model.pt",
model_type="detection",
device="cpu"
)
# First attempt should fail
with pytest.raises(ModelLoadError):
manager.load_model(config)
# Model should not be in cache
assert manager.cache.contains("retry_model") is False
# Second attempt should succeed
model = manager.load_model(config)
assert model is not None
assert manager.cache.contains("retry_model") is True

View file

@ -0,0 +1,959 @@
"""
Unit tests for action execution functionality.
"""
import pytest
import asyncio
import json
import base64
import numpy as np
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
from detector_worker.pipeline.action_executor import (
ActionExecutor,
ActionResult,
ActionType,
RedisAction,
PostgreSQLAction,
FileAction
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import ActionError, RedisError, DatabaseError
class TestActionResult:
"""Test action execution result."""
def test_creation_success(self):
"""Test successful action result creation."""
result = ActionResult(
action_type=ActionType.REDIS_SAVE,
success=True,
execution_time=0.05,
metadata={"key": "saved_image_key", "expiry": 600}
)
assert result.action_type == ActionType.REDIS_SAVE
assert result.success is True
assert result.execution_time == 0.05
assert result.metadata["key"] == "saved_image_key"
assert result.error is None
def test_creation_failure(self):
"""Test failed action result creation."""
result = ActionResult(
action_type=ActionType.POSTGRESQL_INSERT,
success=False,
error="Database connection failed",
execution_time=0.02
)
assert result.action_type == ActionType.POSTGRESQL_INSERT
assert result.success is False
assert result.error == "Database connection failed"
assert result.metadata == {}
class TestRedisAction:
"""Test Redis action implementations."""
def test_creation(self):
"""Test Redis action creation."""
action_config = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}",
"expire_seconds": 600
}
action = RedisAction(action_config)
assert action.action_type == ActionType.REDIS_SAVE
assert action.region == "car"
assert action.key_template == "inference:{display_id}:{timestamp}:{session_id}"
assert action.expire_seconds == 600
def test_resolve_key_template(self):
"""Test key template resolution."""
action_config = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
"expire_seconds": 600
}
action = RedisAction(action_config)
context = {
"display_id": "display_001",
"timestamp": "1640995200000",
"session_id": "session_123",
"filename": "detection.jpg"
}
resolved_key = action.resolve_key(context)
expected_key = "inference:display_001:1640995200000:session_123:detection.jpg"
assert resolved_key == expected_key
def test_resolve_key_missing_variable(self):
"""Test key resolution with missing variable."""
action_config = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{missing_var}",
"expire_seconds": 600
}
action = RedisAction(action_config)
context = {"display_id": "display_001"}
with pytest.raises(ActionError):
action.resolve_key(context)
class TestPostgreSQLAction:
"""Test PostgreSQL action implementations."""
def test_creation_insert(self):
"""Test PostgreSQL insert action creation."""
action_config = {
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"bbox_x1": "{bbox.x1}",
"created_at": "NOW()"
}
}
action = PostgreSQLAction(action_config)
assert action.action_type == ActionType.POSTGRESQL_INSERT
assert action.table == "detections"
assert len(action.fields) == 6
assert action.key_field is None
def test_creation_update(self):
"""Test PostgreSQL update action creation."""
action_config = {
"type": "postgresql_update_combined",
"table": "car_info",
"key_field": "session_id",
"fields": {
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"updated_at": "NOW()"
},
"waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
}
action = PostgreSQLAction(action_config)
assert action.action_type == ActionType.POSTGRESQL_UPDATE
assert action.table == "car_info"
assert action.key_field == "session_id"
assert action.wait_for_branches == ["car_brand_cls", "car_bodytype_cls"]
def test_resolve_field_values(self):
"""Test field value resolution."""
action_config = {
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"brand": "{car_brand_cls.brand}"
}
}
action = PostgreSQLAction(action_config)
context = {
"camera_id": "camera_001",
"class": "car",
"confidence": 0.85
}
branch_results = {
"car_brand_cls": {"brand": "Toyota", "confidence": 0.78}
}
resolved_fields = action.resolve_field_values(context, branch_results)
assert resolved_fields["camera_id"] == "camera_001"
assert resolved_fields["detection_class"] == "car"
assert resolved_fields["confidence"] == 0.85
assert resolved_fields["brand"] == "Toyota"
class TestFileAction:
"""Test file action implementations."""
def test_creation(self):
"""Test file action creation."""
action_config = {
"type": "save_image",
"path": "/tmp/detections/{camera_id}_{timestamp}.jpg",
"region": "car",
"format": "jpeg",
"quality": 85
}
action = FileAction(action_config)
assert action.action_type == ActionType.SAVE_IMAGE
assert action.path_template == "/tmp/detections/{camera_id}_{timestamp}.jpg"
assert action.region == "car"
assert action.format == "jpeg"
assert action.quality == 85
def test_resolve_path_template(self):
"""Test path template resolution."""
action_config = {
"type": "save_image",
"path": "/tmp/detections/{camera_id}/{date}/{timestamp}.jpg"
}
action = FileAction(action_config)
context = {
"camera_id": "camera_001",
"timestamp": "1640995200000",
"date": "2022-01-01"
}
resolved_path = action.resolve_path(context)
expected_path = "/tmp/detections/camera_001/2022-01-01/1640995200000.jpg"
assert resolved_path == expected_path
class TestActionExecutor:
"""Test action execution functionality."""
def test_initialization(self):
"""Test action executor initialization."""
executor = ActionExecutor()
assert executor.redis_client is None
assert executor.db_manager is None
assert executor.max_concurrent_actions == 10
assert executor.action_timeout == 30.0
def test_initialization_with_clients(self, mock_redis_client, mock_database_connection):
"""Test initialization with client instances."""
executor = ActionExecutor(
redis_client=mock_redis_client,
db_manager=mock_database_connection
)
assert executor.redis_client is mock_redis_client
assert executor.db_manager is mock_database_connection
@pytest.mark.asyncio
async def test_execute_actions_empty_list(self):
"""Test executing empty action list."""
executor = ActionExecutor()
context = {
"camera_id": "camera_001",
"session_id": "session_123"
}
results = await executor.execute_actions([], {}, context)
assert results == []
@pytest.mark.asyncio
async def test_execute_redis_save_action(self, mock_redis_client, mock_frame):
"""Test executing Redis save image action."""
executor = ActionExecutor(redis_client=mock_redis_client)
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{camera_id}:{session_id}",
"expire_seconds": 600
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"frame_data": mock_frame
}
# Mock successful Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.REDIS_SAVE
# Verify Redis calls
mock_redis_client.set.assert_called_once()
mock_redis_client.expire.assert_called_once()
@pytest.mark.asyncio
async def test_execute_postgresql_insert_action(self, mock_database_connection):
"""Test executing PostgreSQL insert action."""
# Mock database manager
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(db_manager=mock_db_manager)
actions = [
{
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"created_at": "NOW()"
}
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"class": "car",
"confidence": 0.9
}
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.POSTGRESQL_INSERT
# Verify database call
mock_db_manager.execute_query.assert_called_once()
call_args = mock_db_manager.execute_query.call_args[0]
assert "INSERT INTO detections" in call_args[0]
@pytest.mark.asyncio
async def test_execute_postgresql_update_action(self, mock_database_connection):
"""Test executing PostgreSQL update action."""
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(db_manager=mock_db_manager)
actions = [
{
"type": "postgresql_update_combined",
"table": "car_info",
"key_field": "session_id",
"fields": {
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"updated_at": "NOW()"
},
"waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
}
]
regions = {}
context = {
"session_id": "session_123"
}
branch_results = {
"car_brand_cls": {"brand": "Toyota"},
"car_bodytype_cls": {"body_type": "Sedan"}
}
results = await executor.execute_actions(actions, regions, context, branch_results)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.POSTGRESQL_UPDATE
# Verify database call
mock_db_manager.execute_query.assert_called_once()
call_args = mock_db_manager.execute_query.call_args[0]
assert "UPDATE car_info SET" in call_args[0]
assert "WHERE session_id" in call_args[0]
@pytest.mark.asyncio
async def test_execute_file_save_action(self, mock_frame):
"""Test executing file save action."""
executor = ActionExecutor()
actions = [
{
"type": "save_image",
"path": "/tmp/test_{camera_id}_{timestamp}.jpg",
"region": "car",
"format": "jpeg",
"quality": 85
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"timestamp": "1640995200000",
"frame_data": mock_frame
}
with patch('cv2.imwrite') as mock_imwrite:
mock_imwrite.return_value = True
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.SAVE_IMAGE
# Verify file save call
mock_imwrite.assert_called_once()
call_args = mock_imwrite.call_args
assert "/tmp/test_camera_001_1640995200000.jpg" in call_args[0][0]
@pytest.mark.asyncio
async def test_execute_actions_parallel(self, mock_redis_client):
"""Test parallel execution of multiple actions."""
executor = ActionExecutor(redis_client=mock_redis_client)
# Multiple Redis actions
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:car:{session_id}",
"expire_seconds": 600
},
{
"type": "redis_publish",
"channel": "detections",
"message": "{camera_id}:car_detected"
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
}
# Mock Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
mock_redis_client.publish.return_value = 1
import time
start_time = time.time()
results = await executor.execute_actions(actions, regions, context)
execution_time = time.time() - start_time
assert len(results) == 2
assert all(result.success for result in results)
# Should execute in parallel (faster than sequential)
assert execution_time < 0.1 # Allow some overhead
@pytest.mark.asyncio
async def test_execute_actions_error_handling(self, mock_redis_client):
"""Test error handling in action execution."""
executor = ActionExecutor(redis_client=mock_redis_client)
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{session_id}",
"expire_seconds": 600
},
{
"type": "redis_save_image", # This one will fail
"region": "truck", # Region not detected
"key": "inference:truck:{session_id}",
"expire_seconds": 600
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
# No truck region
}
context = {
"session_id": "session_123",
"frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
}
# Mock Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 2
assert results[0].success is True # Car action succeeds
assert results[1].success is False # Truck action fails
assert "Region 'truck' not found" in results[1].error
@pytest.mark.asyncio
async def test_execute_actions_timeout(self, mock_redis_client):
"""Test action execution timeout."""
config = {"action_timeout": 0.001} # Very short timeout
executor = ActionExecutor(redis_client=mock_redis_client, config=config)
def slow_redis_operation(*args, **kwargs):
import time
time.sleep(1) # Longer than timeout
return True
mock_redis_client.set.side_effect = slow_redis_operation
actions = [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{session_id}",
"expire_seconds": 600
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"session_id": "session_123",
"frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
}
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is False
assert "timeout" in results[0].error.lower()
@pytest.mark.asyncio
async def test_execute_redis_publish_action(self, mock_redis_client):
"""Test executing Redis publish action."""
executor = ActionExecutor(redis_client=mock_redis_client)
actions = [
{
"type": "redis_publish",
"channel": "detections:{camera_id}",
"message": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"timestamp": "{timestamp}"
}
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"class": "car",
"confidence": 0.9,
"timestamp": "1640995200000"
}
mock_redis_client.publish.return_value = 1
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.REDIS_PUBLISH
# Verify publish call
mock_redis_client.publish.assert_called_once()
call_args = mock_redis_client.publish.call_args
assert call_args[0][0] == "detections:camera_001" # Channel
# Message should be JSON
message = call_args[0][1]
parsed_message = json.loads(message)
assert parsed_message["camera_id"] == "camera_001"
assert parsed_message["detection_class"] == "car"
@pytest.mark.asyncio
async def test_execute_conditional_action(self):
"""Test executing conditional actions."""
executor = ActionExecutor()
actions = [
{
"type": "conditional",
"condition": "{confidence} > 0.8",
"actions": [
{
"type": "log",
"message": "High confidence detection: {class} ({confidence})"
}
]
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.95, # High confidence
"detection": DetectionResult("car", 0.95, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"class": "car",
"confidence": 0.95
}
with patch('logging.info') as mock_log:
results = await executor.execute_actions(actions, regions, context)
assert len(results) == 1
assert results[0].success is True
# Should have logged the message
mock_log.assert_called_once()
log_message = mock_log.call_args[0][0]
assert "High confidence detection: car (0.95)" in log_message
def test_crop_region_from_frame(self, mock_frame):
"""Test cropping region from frame."""
executor = ActionExecutor()
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
cropped = executor._crop_region_from_frame(mock_frame, detection.bbox)
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
def test_encode_image_base64(self, mock_frame):
"""Test encoding image to base64."""
executor = ActionExecutor()
# Crop a small region
cropped_frame = mock_frame[200:400, 100:300] # 200x200 region
with patch('cv2.imencode') as mock_imencode:
# Mock successful encoding
mock_imencode.return_value = (True, np.array([1, 2, 3, 4], dtype=np.uint8))
encoded = executor._encode_image_base64(cropped_frame, format="jpeg")
# Should return base64 string
assert isinstance(encoded, str)
assert len(encoded) > 0
# Verify encoding call
mock_imencode.assert_called_once()
assert mock_imencode.call_args[0][0] == '.jpg'
def test_build_insert_query(self):
"""Test building INSERT SQL query."""
executor = ActionExecutor()
table = "detections"
fields = {
"camera_id": "camera_001",
"detection_class": "car",
"confidence": 0.9,
"created_at": "NOW()"
}
query, values = executor._build_insert_query(table, fields)
assert "INSERT INTO detections" in query
assert "camera_id, detection_class, confidence, created_at" in query
assert "VALUES (%s, %s, %s, NOW())" in query
assert values == ["camera_001", "car", 0.9]
def test_build_update_query(self):
"""Test building UPDATE SQL query."""
executor = ActionExecutor()
table = "car_info"
fields = {
"car_brand": "Toyota",
"car_body_type": "Sedan",
"updated_at": "NOW()"
}
key_field = "session_id"
key_value = "session_123"
query, values = executor._build_update_query(table, fields, key_field, key_value)
assert "UPDATE car_info SET" in query
assert "car_brand = %s" in query
assert "car_body_type = %s" in query
assert "updated_at = NOW()" in query
assert "WHERE session_id = %s" in query
assert values == ["Toyota", "Sedan", "session_123"]
def test_evaluate_condition(self):
"""Test evaluating conditional expressions."""
executor = ActionExecutor()
context = {
"confidence": 0.85,
"class": "car",
"area": 40000
}
# Simple comparisons
assert executor._evaluate_condition("{confidence} > 0.8", context) is True
assert executor._evaluate_condition("{confidence} < 0.8", context) is False
assert executor._evaluate_condition("{confidence} >= 0.85", context) is True
assert executor._evaluate_condition("{confidence} == 0.85", context) is True
# String comparisons
assert executor._evaluate_condition("{class} == 'car'", context) is True
assert executor._evaluate_condition("{class} != 'truck'", context) is True
# Complex conditions
assert executor._evaluate_condition("{confidence} > 0.8 and {area} > 30000", context) is True
assert executor._evaluate_condition("{confidence} > 0.9 or {area} > 30000", context) is True
assert executor._evaluate_condition("{confidence} > 0.9 and {area} < 30000", context) is False
def test_validate_action_config(self):
"""Test action configuration validation."""
executor = ActionExecutor()
# Valid Redis action
valid_redis = {
"type": "redis_save_image",
"region": "car",
"key": "inference:{session_id}",
"expire_seconds": 600
}
assert executor._validate_action_config(valid_redis) is True
# Invalid action (missing required fields)
invalid_action = {
"type": "redis_save_image"
# Missing region and key
}
with pytest.raises(ActionError):
executor._validate_action_config(invalid_action)
# Unknown action type
unknown_action = {
"type": "unknown_action_type",
"some_field": "value"
}
with pytest.raises(ActionError):
executor._validate_action_config(unknown_action)
class TestActionExecutorIntegration:
"""Integration tests for action execution."""
@pytest.mark.asyncio
async def test_complete_detection_workflow(self, mock_redis_client, mock_frame):
"""Test complete detection workflow with multiple actions."""
# Mock database manager
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(
redis_client=mock_redis_client,
db_manager=mock_db_manager
)
# Complete action workflow
actions = [
# Save cropped image to Redis
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{camera_id}:{timestamp}:{session_id}:car",
"expire_seconds": 600
},
# Insert initial detection record
{
"type": "postgresql_insert",
"table": "car_detections",
"fields": {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_class": "{class}",
"confidence": "{confidence}",
"bbox_x1": "{bbox.x1}",
"bbox_y1": "{bbox.y1}",
"bbox_x2": "{bbox.x2}",
"bbox_y2": "{bbox.y2}",
"created_at": "NOW()"
}
},
# Publish detection event
{
"type": "redis_publish",
"channel": "detections:{camera_id}",
"message": {
"event": "car_detected",
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"timestamp": "{timestamp}"
}
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.92,
"detection": DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = {
"camera_id": "camera_001",
"session_id": "session_123",
"timestamp": "1640995200000",
"class": "car",
"confidence": 0.92,
"bbox": {"x1": 100, "y1": 200, "x2": 300, "y2": 400},
"frame_data": mock_frame
}
# Mock all Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
mock_redis_client.publish.return_value = 1
results = await executor.execute_actions(actions, regions, context)
# All actions should succeed
assert len(results) == 3
assert all(result.success for result in results)
# Verify all operations were called
mock_redis_client.set.assert_called_once() # Image save
mock_redis_client.expire.assert_called_once() # Set expiry
mock_redis_client.publish.assert_called_once() # Publish event
mock_db_manager.execute_query.assert_called_once() # Database insert
@pytest.mark.asyncio
async def test_branch_dependent_actions(self, mock_database_connection):
"""Test actions that depend on branch results."""
mock_db_manager = Mock()
mock_db_manager.execute_query = AsyncMock(return_value=True)
executor = ActionExecutor(db_manager=mock_db_manager)
# Action that waits for classification branches
actions = [
{
"type": "postgresql_update_combined",
"table": "car_info",
"key_field": "session_id",
"fields": {
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"car_color": "{car_color_cls.color}",
"confidence_brand": "{car_brand_cls.confidence}",
"confidence_bodytype": "{car_bodytype_cls.confidence}",
"updated_at": "NOW()"
},
"waitForBranches": ["car_brand_cls", "car_bodytype_cls", "car_color_cls"]
}
]
regions = {}
context = {
"session_id": "session_123"
}
# Simulated branch results
branch_results = {
"car_brand_cls": {"brand": "Toyota", "confidence": 0.87},
"car_bodytype_cls": {"body_type": "Sedan", "confidence": 0.82},
"car_color_cls": {"color": "Red", "confidence": 0.79}
}
results = await executor.execute_actions(actions, regions, context, branch_results)
assert len(results) == 1
assert results[0].success is True
assert results[0].action_type == ActionType.POSTGRESQL_UPDATE
# Verify database call with all branch data
mock_db_manager.execute_query.assert_called_once()
call_args = mock_db_manager.execute_query.call_args
query = call_args[0][0]
values = call_args[0][1]
assert "UPDATE car_info SET" in query
assert "car_brand = %s" in query
assert "car_body_type = %s" in query
assert "car_color = %s" in query
assert "WHERE session_id = %s" in query
assert "Toyota" in values
assert "Sedan" in values
assert "Red" in values
assert "session_123" in values

View file

@ -0,0 +1,786 @@
"""
Unit tests for field mapping and template resolution.
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime
import json
from detector_worker.pipeline.field_mapper import (
FieldMapper,
MappingContext,
TemplateResolver,
FieldMappingError,
NestedFieldAccessor
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
class TestNestedFieldAccessor:
"""Test nested field access functionality."""
def test_get_nested_value_simple(self):
"""Test getting simple nested values."""
data = {
"user": {
"name": "John",
"age": 30,
"address": {
"city": "New York",
"zip": "10001"
}
}
}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "user.name") == "John"
assert accessor.get_nested_value(data, "user.age") == 30
assert accessor.get_nested_value(data, "user.address.city") == "New York"
assert accessor.get_nested_value(data, "user.address.zip") == "10001"
def test_get_nested_value_array_access(self):
"""Test accessing array elements."""
data = {
"results": [
{"score": 0.9, "label": "car"},
{"score": 0.8, "label": "truck"}
],
"bbox": [100, 200, 300, 400]
}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "results[0].score") == 0.9
assert accessor.get_nested_value(data, "results[0].label") == "car"
assert accessor.get_nested_value(data, "results[1].score") == 0.8
assert accessor.get_nested_value(data, "bbox[0]") == 100
assert accessor.get_nested_value(data, "bbox[3]") == 400
def test_get_nested_value_nonexistent_path(self):
"""Test accessing non-existent paths."""
data = {"user": {"name": "John"}}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "user.nonexistent") is None
assert accessor.get_nested_value(data, "nonexistent.field") is None
assert accessor.get_nested_value(data, "user.address.city") is None
def test_get_nested_value_with_default(self):
"""Test getting nested values with default fallback."""
data = {"user": {"name": "John"}}
accessor = NestedFieldAccessor()
assert accessor.get_nested_value(data, "user.age", default=25) == 25
assert accessor.get_nested_value(data, "user.name", default="Unknown") == "John"
def test_set_nested_value(self):
"""Test setting nested values."""
data = {}
accessor = NestedFieldAccessor()
accessor.set_nested_value(data, "user.name", "John")
assert data["user"]["name"] == "John"
accessor.set_nested_value(data, "user.address.city", "New York")
assert data["user"]["address"]["city"] == "New York"
accessor.set_nested_value(data, "scores[0]", 0.95)
assert data["scores"][0] == 0.95
def test_set_nested_value_overwrite(self):
"""Test overwriting existing nested values."""
data = {"user": {"name": "John", "age": 30}}
accessor = NestedFieldAccessor()
accessor.set_nested_value(data, "user.name", "Jane")
assert data["user"]["name"] == "Jane"
assert data["user"]["age"] == 30 # Should not affect other fields
class TestTemplateResolver:
"""Test template string resolution."""
def test_resolve_simple_template(self):
"""Test resolving simple template variables."""
resolver = TemplateResolver()
template = "Hello {name}, you are {age} years old"
context = {"name": "John", "age": 30}
result = resolver.resolve(template, context)
assert result == "Hello John, you are 30 years old"
def test_resolve_nested_template(self):
"""Test resolving nested field templates."""
resolver = TemplateResolver()
template = "User: {user.name} from {user.address.city}"
context = {
"user": {
"name": "John",
"address": {"city": "New York", "zip": "10001"}
}
}
result = resolver.resolve(template, context)
assert result == "User: John from New York"
def test_resolve_array_template(self):
"""Test resolving array element templates."""
resolver = TemplateResolver()
template = "First result: {results[0].label} ({results[0].score})"
context = {
"results": [
{"label": "car", "score": 0.95},
{"label": "truck", "score": 0.87}
]
}
result = resolver.resolve(template, context)
assert result == "First result: car (0.95)"
def test_resolve_missing_variables(self):
"""Test resolving templates with missing variables."""
resolver = TemplateResolver()
template = "Hello {name}, you are {age} years old"
context = {"name": "John"} # Missing age
with pytest.raises(FieldMappingError) as exc_info:
resolver.resolve(template, context)
assert "Variable 'age' not found" in str(exc_info.value)
def test_resolve_with_defaults(self):
"""Test resolving templates with default values."""
resolver = TemplateResolver(allow_missing=True)
template = "Hello {name}, you are {age|25} years old"
context = {"name": "John"} # Missing age, should use default
result = resolver.resolve(template, context)
assert result == "Hello John, you are 25 years old"
def test_resolve_complex_template(self):
"""Test resolving complex templates with multiple variable types."""
resolver = TemplateResolver()
template = "{camera_id}:{timestamp}:{session_id}:{results[0].class}_{bbox[0]}_{bbox[1]}"
context = {
"camera_id": "cam001",
"timestamp": 1640995200000,
"session_id": "sess123",
"results": [{"class": "car", "confidence": 0.95}],
"bbox": [100, 200, 300, 400]
}
result = resolver.resolve(template, context)
assert result == "cam001:1640995200000:sess123:car_100_200"
def test_resolve_conditional_template(self):
"""Test resolving conditional templates."""
resolver = TemplateResolver()
# Simple conditional
template = "{name} is {age > 18 ? 'adult' : 'minor'}"
context_adult = {"name": "John", "age": 25}
result_adult = resolver.resolve(template, context_adult)
assert result_adult == "John is adult"
context_minor = {"name": "Jane", "age": 16}
result_minor = resolver.resolve(template, context_minor)
assert result_minor == "Jane is minor"
def test_escape_braces(self):
"""Test escaping braces in templates."""
resolver = TemplateResolver()
template = "Literal {{braces}} and variable {name}"
context = {"name": "John"}
result = resolver.resolve(template, context)
assert result == "Literal {braces} and variable John"
class TestMappingContext:
"""Test mapping context data structure."""
def test_creation(self):
"""Test mapping context creation."""
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection,
timestamp=1640995200000
)
assert context.camera_id == "camera_001"
assert context.display_id == "display_001"
assert context.session_id == "session_123"
assert context.detection == detection
assert context.timestamp == 1640995200000
assert context.branch_results == {}
assert context.metadata == {}
def test_add_branch_result(self):
"""Test adding branch results to context."""
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
context.add_branch_result("car_brand_cls", {"brand": "Toyota", "confidence": 0.87})
context.add_branch_result("car_bodytype_cls", {"body_type": "Sedan", "confidence": 0.82})
assert len(context.branch_results) == 2
assert context.branch_results["car_brand_cls"]["brand"] == "Toyota"
assert context.branch_results["car_bodytype_cls"]["body_type"] == "Sedan"
def test_to_dict(self):
"""Test converting context to dictionary."""
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection,
timestamp=1640995200000
)
context.add_branch_result("car_brand_cls", {"brand": "Toyota"})
context.add_metadata("model_id", "yolo_v8")
context_dict = context.to_dict()
assert context_dict["camera_id"] == "camera_001"
assert context_dict["display_id"] == "display_001"
assert context_dict["session_id"] == "session_123"
assert context_dict["timestamp"] == 1640995200000
assert context_dict["class"] == "car"
assert context_dict["confidence"] == 0.9
assert context_dict["track_id"] == 1001
assert context_dict["bbox"]["x1"] == 100
assert context_dict["car_brand_cls"]["brand"] == "Toyota"
assert context_dict["model_id"] == "yolo_v8"
def test_add_metadata(self):
"""Test adding metadata to context."""
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
context.add_metadata("model_version", "v2.1")
context.add_metadata("inference_time", 0.15)
assert context.metadata["model_version"] == "v2.1"
assert context.metadata["inference_time"] == 0.15
class TestFieldMapper:
"""Test field mapping functionality."""
def test_initialization(self):
"""Test field mapper initialization."""
mapper = FieldMapper()
assert isinstance(mapper.template_resolver, TemplateResolver)
assert isinstance(mapper.field_accessor, NestedFieldAccessor)
def test_map_fields_simple(self):
"""Test simple field mapping."""
mapper = FieldMapper()
field_mappings = {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence_score": "{confidence}",
"track_identifier": "{track_id}"
}
detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection,
timestamp=1640995200000
)
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["camera_id"] == "camera_001"
assert mapped_fields["detection_class"] == "car"
assert mapped_fields["confidence_score"] == 0.92
assert mapped_fields["track_identifier"] == 1001
def test_map_fields_with_branch_results(self):
"""Test field mapping with branch results."""
mapper = FieldMapper()
field_mappings = {
"car_brand": "{car_brand_cls.brand}",
"car_model": "{car_brand_cls.model}",
"body_type": "{car_bodytype_cls.body_type}",
"brand_confidence": "{car_brand_cls.confidence}",
"combined_info": "{car_brand_cls.brand} {car_bodytype_cls.body_type}"
}
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
context.add_branch_result("car_brand_cls", {
"brand": "Toyota",
"model": "Camry",
"confidence": 0.87
})
context.add_branch_result("car_bodytype_cls", {
"body_type": "Sedan",
"confidence": 0.82
})
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["car_brand"] == "Toyota"
assert mapped_fields["car_model"] == "Camry"
assert mapped_fields["body_type"] == "Sedan"
assert mapped_fields["brand_confidence"] == 0.87
assert mapped_fields["combined_info"] == "Toyota Sedan"
def test_map_fields_bbox_access(self):
"""Test field mapping with bounding box access."""
mapper = FieldMapper()
field_mappings = {
"bbox_x1": "{bbox.x1}",
"bbox_y1": "{bbox.y1}",
"bbox_x2": "{bbox.x2}",
"bbox_y2": "{bbox.y2}",
"bbox_width": "{bbox.width}",
"bbox_height": "{bbox.height}",
"bbox_area": "{bbox.area}",
"bbox_center_x": "{bbox.center_x}",
"bbox_center_y": "{bbox.center_y}"
}
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection
)
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["bbox_x1"] == 100
assert mapped_fields["bbox_y1"] == 200
assert mapped_fields["bbox_x2"] == 300
assert mapped_fields["bbox_y2"] == 400
assert mapped_fields["bbox_width"] == 200 # 300 - 100
assert mapped_fields["bbox_height"] == 200 # 400 - 200
assert mapped_fields["bbox_area"] == 40000 # 200 * 200
assert mapped_fields["bbox_center_x"] == 200 # (100 + 300) / 2
assert mapped_fields["bbox_center_y"] == 300 # (200 + 400) / 2
def test_map_fields_with_sql_functions(self):
"""Test field mapping with SQL function templates."""
mapper = FieldMapper()
field_mappings = {
"created_at": "NOW()",
"updated_at": "CURRENT_TIMESTAMP",
"uuid_field": "UUID()",
"json_data": "JSON_OBJECT('class', '{class}', 'confidence', {confidence})"
}
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
detection=detection
)
mapped_fields = mapper.map_fields(field_mappings, context)
# SQL functions should pass through unchanged
assert mapped_fields["created_at"] == "NOW()"
assert mapped_fields["updated_at"] == "CURRENT_TIMESTAMP"
assert mapped_fields["uuid_field"] == "UUID()"
assert mapped_fields["json_data"] == "JSON_OBJECT('class', 'car', 'confidence', 0.9)"
def test_map_fields_missing_branch_data(self):
"""Test field mapping with missing branch data."""
mapper = FieldMapper()
field_mappings = {
"car_brand": "{car_brand_cls.brand}",
"car_model": "{nonexistent_branch.model}"
}
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
# Only add one branch result
context.add_branch_result("car_brand_cls", {"brand": "Toyota"})
with pytest.raises(FieldMappingError) as exc_info:
mapper.map_fields(field_mappings, context)
assert "nonexistent_branch.model" in str(exc_info.value)
def test_map_fields_with_defaults(self):
"""Test field mapping with default values."""
mapper = FieldMapper(allow_missing=True)
field_mappings = {
"car_brand": "{car_brand_cls.brand|Unknown}",
"car_model": "{car_brand_cls.model|N/A}",
"confidence": "{confidence|0.0}"
}
context = MappingContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123"
)
# Don't add any branch results
mapped_fields = mapper.map_fields(field_mappings, context)
assert mapped_fields["car_brand"] == "Unknown"
assert mapped_fields["car_model"] == "N/A"
assert mapped_fields["confidence"] == "0.0"
def test_map_database_fields(self):
"""Test mapping fields for database operations."""
mapper = FieldMapper()
# Database field mapping
db_field_mappings = {
"camera_id": "{camera_id}",
"session_id": "{session_id}",
"detection_timestamp": "{timestamp}",
"object_class": "{class}",
"detection_confidence": "{confidence}",
"track_id": "{track_id}",
"bbox_json": "JSON_OBJECT('x1', {bbox.x1}, 'y1', {bbox.y1}, 'x2', {bbox.x2}, 'y2', {bbox.y2})",
"car_brand": "{car_brand_cls.brand}",
"car_body_type": "{car_bodytype_cls.body_type}",
"license_plate": "{license_ocr.text}",
"created_at": "NOW()",
"updated_at": "NOW()"
}
detection = DetectionResult("car", 0.93, BoundingBox(150, 250, 350, 450), 2001, 1640995300000)
context = MappingContext(
camera_id="camera_002",
display_id="display_002",
session_id="session_456",
detection=detection,
timestamp=1640995300000
)
# Add branch results
context.add_branch_result("car_brand_cls", {"brand": "Honda", "confidence": 0.89})
context.add_branch_result("car_bodytype_cls", {"body_type": "SUV", "confidence": 0.85})
context.add_branch_result("license_ocr", {"text": "ABC-123", "confidence": 0.76})
mapped_fields = mapper.map_fields(db_field_mappings, context)
assert mapped_fields["camera_id"] == "camera_002"
assert mapped_fields["session_id"] == "session_456"
assert mapped_fields["detection_timestamp"] == 1640995300000
assert mapped_fields["object_class"] == "car"
assert mapped_fields["detection_confidence"] == 0.93
assert mapped_fields["track_id"] == 2001
assert mapped_fields["bbox_json"] == "JSON_OBJECT('x1', 150, 'y1', 250, 'x2', 350, 'y2', 450)"
assert mapped_fields["car_brand"] == "Honda"
assert mapped_fields["car_body_type"] == "SUV"
assert mapped_fields["license_plate"] == "ABC-123"
assert mapped_fields["created_at"] == "NOW()"
assert mapped_fields["updated_at"] == "NOW()"
def test_map_redis_keys(self):
"""Test mapping Redis key templates."""
mapper = FieldMapper()
key_templates = [
"inference:{camera_id}:{timestamp}:{session_id}:car",
"detection:{display_id}:{track_id}",
"cropped_image:{camera_id}:{session_id}:{class}",
"metadata:{session_id}:brands:{car_brand_cls.brand}",
"tracking:{camera_id}:active_tracks"
]
detection = DetectionResult("car", 0.88, BoundingBox(200, 300, 400, 500), 3001, 1640995400000)
context = MappingContext(
camera_id="camera_003",
display_id="display_003",
session_id="session_789",
detection=detection,
timestamp=1640995400000
)
context.add_branch_result("car_brand_cls", {"brand": "Ford"})
mapped_keys = [mapper.map_template(template, context) for template in key_templates]
expected_keys = [
"inference:camera_003:1640995400000:session_789:car",
"detection:display_003:3001",
"cropped_image:camera_003:session_789:car",
"metadata:session_789:brands:Ford",
"tracking:camera_003:active_tracks"
]
assert mapped_keys == expected_keys
def test_map_template(self):
"""Test single template mapping."""
mapper = FieldMapper()
template = "Camera {camera_id} detected {class} with {confidence:.2f} confidence at {timestamp}"
detection = DetectionResult("truck", 0.876, BoundingBox(100, 150, 300, 350), 4001, 1640995500000)
context = MappingContext(
camera_id="camera_004",
display_id="display_004",
session_id="session_101",
detection=detection,
timestamp=1640995500000
)
result = mapper.map_template(template, context)
expected = "Camera camera_004 detected truck with 0.88 confidence at 1640995500000"
assert result == expected
def test_validate_field_mappings(self):
"""Test field mapping validation."""
mapper = FieldMapper()
# Valid mappings
valid_mappings = {
"camera_id": "{camera_id}",
"class": "{class}",
"confidence": "{confidence}",
"created_at": "NOW()"
}
assert mapper.validate_field_mappings(valid_mappings) is True
# Invalid mappings (malformed templates)
invalid_mappings = {
"camera_id": "{camera_id", # Missing closing brace
"class": "class}", # Missing opening brace
"confidence": "{nonexistent_field}" # This might be valid depending on context
}
with pytest.raises(FieldMappingError):
mapper.validate_field_mappings(invalid_mappings)
def test_create_context_from_detection(self):
"""Test creating mapping context from detection result."""
mapper = FieldMapper()
detection = DetectionResult("car", 0.95, BoundingBox(50, 100, 250, 300), 5001, 1640995600000)
context = mapper.create_context_from_detection(
detection,
camera_id="camera_005",
display_id="display_005",
session_id="session_202"
)
assert context.camera_id == "camera_005"
assert context.display_id == "display_005"
assert context.session_id == "session_202"
assert context.detection == detection
assert context.timestamp == 1640995600000
def test_format_sql_value(self):
"""Test SQL value formatting."""
mapper = FieldMapper()
# String values should be quoted
assert mapper.format_sql_value("test_string") == "'test_string'"
assert mapper.format_sql_value("John's car") == "'John''s car'" # Escape quotes
# Numeric values should not be quoted
assert mapper.format_sql_value(42) == "42"
assert mapper.format_sql_value(3.14) == "3.14"
assert mapper.format_sql_value(0.95) == "0.95"
# Boolean values
assert mapper.format_sql_value(True) == "TRUE"
assert mapper.format_sql_value(False) == "FALSE"
# None/NULL values
assert mapper.format_sql_value(None) == "NULL"
# SQL functions should pass through
assert mapper.format_sql_value("NOW()") == "NOW()"
assert mapper.format_sql_value("CURRENT_TIMESTAMP") == "CURRENT_TIMESTAMP"
class TestFieldMapperIntegration:
"""Integration tests for field mapping."""
def test_complete_mapping_workflow(self):
"""Test complete field mapping workflow."""
mapper = FieldMapper()
# Simulate complete detection workflow
detection = DetectionResult("car", 0.91, BoundingBox(120, 180, 320, 380), 6001, 1640995700000)
context = MappingContext(
camera_id="camera_006",
display_id="display_006",
session_id="session_303",
detection=detection,
timestamp=1640995700000
)
# Add comprehensive branch results
context.add_branch_result("car_brand_cls", {
"brand": "BMW",
"model": "X5",
"confidence": 0.84,
"top3_brands": ["BMW", "Audi", "Mercedes"]
})
context.add_branch_result("car_bodytype_cls", {
"body_type": "SUV",
"confidence": 0.79,
"features": ["tall", "4_doors", "roof_rails"]
})
context.add_branch_result("car_color_cls", {
"color": "Black",
"confidence": 0.73,
"rgb_values": [20, 25, 30]
})
context.add_branch_result("license_ocr", {
"text": "XYZ-789",
"confidence": 0.68,
"region_bbox": [150, 320, 290, 360]
})
# Database field mapping
db_mappings = {
"camera_id": "{camera_id}",
"display_id": "{display_id}",
"session_id": "{session_id}",
"detection_timestamp": "{timestamp}",
"object_class": "{class}",
"detection_confidence": "{confidence}",
"track_id": "{track_id}",
"bbox_x1": "{bbox.x1}",
"bbox_y1": "{bbox.y1}",
"bbox_x2": "{bbox.x2}",
"bbox_y2": "{bbox.y2}",
"bbox_area": "{bbox.area}",
"car_brand": "{car_brand_cls.brand}",
"car_model": "{car_brand_cls.model}",
"car_body_type": "{car_bodytype_cls.body_type}",
"car_color": "{car_color_cls.color}",
"license_plate": "{license_ocr.text}",
"brand_confidence": "{car_brand_cls.confidence}",
"bodytype_confidence": "{car_bodytype_cls.confidence}",
"color_confidence": "{car_color_cls.confidence}",
"license_confidence": "{license_ocr.confidence}",
"detection_summary": "{car_brand_cls.brand} {car_bodytype_cls.body_type} ({car_color_cls.color})",
"created_at": "NOW()",
"updated_at": "NOW()"
}
mapped_db_fields = mapper.map_fields(db_mappings, context)
# Verify all mappings
assert mapped_db_fields["camera_id"] == "camera_006"
assert mapped_db_fields["session_id"] == "session_303"
assert mapped_db_fields["object_class"] == "car"
assert mapped_db_fields["detection_confidence"] == 0.91
assert mapped_db_fields["track_id"] == 6001
assert mapped_db_fields["bbox_area"] == 40000 # 200 * 200
assert mapped_db_fields["car_brand"] == "BMW"
assert mapped_db_fields["car_model"] == "X5"
assert mapped_db_fields["car_body_type"] == "SUV"
assert mapped_db_fields["car_color"] == "Black"
assert mapped_db_fields["license_plate"] == "XYZ-789"
assert mapped_db_fields["detection_summary"] == "BMW SUV (Black)"
# Redis key mapping
redis_key_templates = [
"detection:{camera_id}:{session_id}:main",
"cropped:{camera_id}:{session_id}:car_image",
"metadata:{session_id}:brand:{car_brand_cls.brand}",
"tracking:{camera_id}:track_{track_id}",
"classification:{session_id}:results"
]
mapped_redis_keys = [
mapper.map_template(template, context)
for template in redis_key_templates
]
expected_redis_keys = [
"detection:camera_006:session_303:main",
"cropped:camera_006:session_303:car_image",
"metadata:session_303:brand:BMW",
"tracking:camera_006:track_6001",
"classification:session_303:results"
]
assert mapped_redis_keys == expected_redis_keys
def test_error_handling_and_recovery(self):
"""Test error handling and recovery in field mapping."""
mapper = FieldMapper(allow_missing=True)
# Context with missing detection
context = MappingContext(
camera_id="camera_007",
display_id="display_007",
session_id="session_404"
)
# Partial branch results
context.add_branch_result("car_brand_cls", {"brand": "Unknown"})
# Missing car_bodytype_cls branch
# Field mappings with some missing data
mappings = {
"camera_id": "{camera_id}",
"detection_class": "{class|Unknown}",
"confidence": "{confidence|0.0}",
"car_brand": "{car_brand_cls.brand|N/A}",
"car_body_type": "{car_bodytype_cls.body_type|Unknown}",
"car_model": "{car_brand_cls.model|N/A}"
}
mapped_fields = mapper.map_fields(mappings, context)
assert mapped_fields["camera_id"] == "camera_007"
assert mapped_fields["detection_class"] == "Unknown"
assert mapped_fields["confidence"] == "0.0"
assert mapped_fields["car_brand"] == "Unknown"
assert mapped_fields["car_body_type"] == "Unknown"
assert mapped_fields["car_model"] == "N/A"

View file

@ -0,0 +1,921 @@
"""
Unit tests for pipeline execution functionality.
"""
import pytest
import asyncio
import numpy as np
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from concurrent.futures import ThreadPoolExecutor
import json
from detector_worker.pipeline.pipeline_executor import (
PipelineExecutor,
PipelineContext,
PipelineResult,
BranchResult,
ExecutionMode
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import PipelineError, ModelError, ActionError
class TestPipelineContext:
"""Test pipeline context data structure."""
def test_creation(self):
"""Test pipeline context creation."""
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
)
assert context.camera_id == "camera_001"
assert context.display_id == "display_001"
assert context.session_id == "session_123"
assert context.timestamp == 1640995200000
assert context.frame_data.shape == (480, 640, 3)
assert context.metadata == {}
assert context.crop_region is None
def test_creation_with_crop_region(self):
"""Test context creation with crop region."""
crop_region = (100, 200, 300, 400)
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8),
crop_region=crop_region
)
assert context.crop_region == crop_region
def test_add_metadata(self):
"""Test adding metadata to context."""
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
)
context.add_metadata("model_id", "yolo_v8")
context.add_metadata("confidence_threshold", 0.8)
assert context.metadata["model_id"] == "yolo_v8"
assert context.metadata["confidence_threshold"] == 0.8
def test_get_cropped_frame(self):
"""Test getting cropped frame."""
frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=frame,
crop_region=(100, 200, 300, 400)
)
cropped = context.get_cropped_frame()
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
assert np.all(cropped == 255)
def test_get_cropped_frame_no_crop(self):
"""Test getting frame when no crop region specified."""
frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=frame
)
cropped = context.get_cropped_frame()
assert np.array_equal(cropped, frame)
class TestBranchResult:
"""Test branch execution result."""
def test_creation_success(self):
"""Test successful branch result creation."""
detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
result = BranchResult(
branch_id="car_brand_cls",
success=True,
detections=detections,
metadata={"brand": "Toyota"},
execution_time=0.15
)
assert result.branch_id == "car_brand_cls"
assert result.success is True
assert len(result.detections) == 1
assert result.metadata["brand"] == "Toyota"
assert result.execution_time == 0.15
assert result.error is None
def test_creation_failure(self):
"""Test failed branch result creation."""
result = BranchResult(
branch_id="car_brand_cls",
success=False,
error="Model inference failed",
execution_time=0.05
)
assert result.branch_id == "car_brand_cls"
assert result.success is False
assert result.detections == []
assert result.metadata == {}
assert result.error == "Model inference failed"
class TestPipelineResult:
"""Test pipeline execution result."""
def test_creation_success(self):
"""Test successful pipeline result creation."""
main_detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
}
result = PipelineResult(
success=True,
detections=main_detections,
branch_results=branch_results,
total_execution_time=0.5
)
assert result.success is True
assert len(result.detections) == 1
assert len(result.branch_results) == 2
assert result.total_execution_time == 0.5
assert result.error is None
def test_creation_failure(self):
"""Test failed pipeline result creation."""
result = PipelineResult(
success=False,
error="Pipeline execution failed",
total_execution_time=0.1
)
assert result.success is False
assert result.detections == []
assert result.branch_results == {}
assert result.error == "Pipeline execution failed"
def test_get_combined_results(self):
"""Test getting combined results from all branches."""
main_detections = [
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
]
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
}
result = PipelineResult(
success=True,
detections=main_detections,
branch_results=branch_results,
total_execution_time=0.5
)
combined = result.get_combined_results()
assert "brand" in combined
assert "body_type" in combined
assert combined["brand"] == "Toyota"
assert combined["body_type"] == "Sedan"
class TestPipelineExecutor:
"""Test pipeline execution functionality."""
def test_initialization(self):
"""Test pipeline executor initialization."""
executor = PipelineExecutor()
assert isinstance(executor.thread_pool, ThreadPoolExecutor)
assert executor.max_workers == 4
assert executor.execution_mode == ExecutionMode.PARALLEL
assert executor.timeout == 30.0
def test_initialization_custom_config(self):
"""Test initialization with custom configuration."""
config = {
"max_workers": 8,
"execution_mode": "sequential",
"timeout": 60.0
}
executor = PipelineExecutor(config)
assert executor.max_workers == 8
assert executor.execution_mode == ExecutionMode.SEQUENTIAL
assert executor.timeout == 60.0
@pytest.mark.asyncio
async def test_execute_pipeline_simple(self, mock_yolo_model, mock_frame):
"""Test simple pipeline execution."""
# Mock pipeline configuration
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert len(result.detections) == 1
assert result.detections[0].class_name == "0" # Default class name
assert result.detections[0].confidence == 0.9
@pytest.mark.asyncio
async def test_execute_pipeline_with_branches(self, mock_yolo_model, mock_frame):
"""Test pipeline execution with classification branches."""
import torch
# Mock main detection
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0] # car detection
])
mock_detection_result.boxes.id = torch.tensor([1001])
# Mock classification results
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 2 # Toyota
mock_brand_result.probs.top1conf = 0.85
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 1 # Sedan
mock_bodytype_result.probs.top1conf = 0.78
mock_yolo_model.track.return_value = [mock_detection_result]
mock_yolo_model.predict.return_value = [mock_brand_result]
mock_brand_model = Mock()
mock_brand_model.predict.return_value = [mock_brand_result]
mock_brand_model.names = {0: "Honda", 1: "Ford", 2: "Toyota"}
mock_bodytype_model = Mock()
mock_bodytype_model.predict.return_value = [mock_bodytype_result]
mock_bodytype_model.names = {0: "SUV", 1: "Sedan", 2: "Hatchback"}
# Pipeline configuration with branches
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [
{
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
},
{
"modelId": "car_bodytype_cls",
"modelFile": "car_bodytype.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
}
],
"actions": []
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
def get_model_side_effect(model_id, camera_id):
if model_id == "car_detection_v1":
return mock_yolo_model
elif model_id == "car_brand_cls":
return mock_brand_model
elif model_id == "car_bodytype_cls":
return mock_bodytype_model
return None
mock_model_manager.return_value.get_model.side_effect = get_model_side_effect
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert len(result.detections) == 1
assert len(result.branch_results) == 2
# Check branch results
assert "car_brand_cls" in result.branch_results
assert "car_bodytype_cls" in result.branch_results
brand_result = result.branch_results["car_brand_cls"]
assert brand_result.success is True
assert brand_result.metadata.get("brand") == "Toyota"
bodytype_result = result.branch_results["car_bodytype_cls"]
assert bodytype_result.success is True
assert bodytype_result.metadata.get("body_type") == "Sedan"
@pytest.mark.asyncio
async def test_execute_pipeline_sequential_mode(self, mock_yolo_model, mock_frame):
"""Test pipeline execution in sequential mode."""
import torch
config = {"execution_mode": "sequential"}
executor = PipelineExecutor(config)
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [
{
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": False # Sequential execution
}
],
"actions": []
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
assert executor.execution_mode == ExecutionMode.SEQUENTIAL
@pytest.mark.asyncio
async def test_execute_pipeline_with_actions(self, mock_yolo_model, mock_frame):
"""Test pipeline execution with actions."""
import torch
# Mock detection result
mock_result = Mock()
mock_result.boxes = Mock()
mock_result.boxes.data = torch.tensor([
[100, 200, 300, 400, 0.9, 0]
])
mock_result.boxes.id = torch.tensor([1001])
mock_yolo_model.track.return_value = [mock_result]
# Pipeline configuration with actions
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": [
{
"type": "redis_save_image",
"region": "car",
"key": "inference:{display_id}:{timestamp}:{session_id}",
"expire_seconds": 600
},
{
"type": "postgresql_insert",
"table": "detections",
"fields": {
"camera_id": "{camera_id}",
"detection_class": "{class}",
"confidence": "{confidence}"
}
}
]
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager, \
patch('detector_worker.pipeline.action_executor.ActionExecutor') as mock_action_executor:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
mock_action_executor.return_value.execute_actions = AsyncMock(return_value=True)
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is True
# Actions should be executed
mock_action_executor.return_value.execute_actions.assert_called_once()
@pytest.mark.asyncio
async def test_execute_pipeline_model_error(self, mock_frame):
"""Test pipeline execution with model error."""
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
executor = PipelineExecutor()
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
# Model manager raises error
mock_model_manager.return_value.get_model.side_effect = ModelError("Model not found")
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is False
assert "Model not found" in result.error
@pytest.mark.asyncio
async def test_execute_pipeline_timeout(self, mock_yolo_model, mock_frame):
"""Test pipeline execution timeout."""
import torch
# Configure short timeout
config = {"timeout": 0.001} # Very short timeout
executor = PipelineExecutor(config)
# Mock slow model inference
def slow_inference(*args, **kwargs):
import time
time.sleep(1) # Longer than timeout
mock_result = Mock()
mock_result.boxes = None
return [mock_result]
mock_yolo_model.track.side_effect = slow_inference
pipeline_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8,
"branches": [],
"actions": []
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_yolo_model
result = await executor.execute_pipeline(pipeline_config, context)
assert result.success is False
assert "timeout" in result.error.lower()
@pytest.mark.asyncio
async def test_execute_branch_parallel(self, mock_frame):
"""Test parallel branch execution."""
import torch
# Mock classification model
mock_brand_model = Mock()
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
mock_brand_model.predict.return_value = [mock_result]
mock_brand_model.names = {0: "Honda", 1: "Toyota", 2: "Ford"}
executor = PipelineExecutor()
# Branch configuration
branch_config = {
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True,
"crop": True,
"cropClass": "car"
}
# Mock detected regions
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_brand_model
result = await executor._execute_branch(branch_config, regions, context)
assert result.success is True
assert result.branch_id == "car_brand_cls"
assert result.metadata.get("brand") == "Toyota"
assert result.execution_time > 0
@pytest.mark.asyncio
async def test_execute_branch_no_trigger_class(self, mock_frame):
"""Test branch execution when trigger class not detected."""
executor = PipelineExecutor()
branch_config = {
"modelId": "car_brand_cls",
"modelFile": "car_brand.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7
}
# No car detected
regions = {
"truck": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("truck", 0.9, BoundingBox(100, 200, 300, 400), 1002)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
result = await executor._execute_branch(branch_config, regions, context)
assert result.success is False
assert "trigger class not detected" in result.error.lower()
def test_wait_for_branches(self):
"""Test waiting for specific branches to complete."""
executor = PipelineExecutor()
# Mock completed branch results
branch_results = {
"car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
"car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12),
"license_ocr": BranchResult("license_ocr", True, [], {"license": "ABC123"}, 0.2)
}
# Wait for specific branches
wait_for = ["car_brand_cls", "car_bodytype_cls"]
completed = executor._wait_for_branches(branch_results, wait_for, timeout=1.0)
assert completed is True
# Wait for non-existent branch (should timeout)
wait_for_missing = ["car_brand_cls", "nonexistent_branch"]
completed = executor._wait_for_branches(branch_results, wait_for_missing, timeout=0.1)
assert completed is False
def test_validate_pipeline_config(self):
"""Test pipeline configuration validation."""
executor = PipelineExecutor()
# Valid configuration
valid_config = {
"modelId": "car_detection_v1",
"modelFile": "car_detection.pt",
"expectedClasses": ["car"],
"triggerClasses": ["car"],
"minConfidence": 0.8
}
assert executor._validate_pipeline_config(valid_config) is True
# Invalid configuration (missing required fields)
invalid_config = {
"modelFile": "car_detection.pt"
# Missing modelId
}
with pytest.raises(PipelineError):
executor._validate_pipeline_config(invalid_config)
def test_crop_frame_for_detection(self, mock_frame):
"""Test frame cropping for detection."""
executor = PipelineExecutor()
detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
cropped = executor._crop_frame_for_detection(mock_frame, detection)
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
def test_crop_frame_invalid_bounds(self, mock_frame):
"""Test frame cropping with invalid bounds."""
executor = PipelineExecutor()
# Detection outside frame bounds
detection = DetectionResult("car", 0.9, BoundingBox(-100, -200, 50, 100), 1001)
cropped = executor._crop_frame_for_detection(mock_frame, detection)
# Should handle bounds gracefully
assert cropped.shape[0] > 0
assert cropped.shape[1] > 0
class TestPipelineExecutorPerformance:
"""Test pipeline executor performance and optimization."""
@pytest.mark.asyncio
async def test_parallel_branch_execution_performance(self, mock_frame):
"""Test that parallel execution is faster than sequential."""
import time
import torch
def slow_inference(*args, **kwargs):
time.sleep(0.1) # Simulate slow inference
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
return [mock_result]
mock_model = Mock()
mock_model.predict.side_effect = slow_inference
mock_model.names = {0: "Class0", 1: "Class1"}
# Test parallel execution
parallel_executor = PipelineExecutor({"execution_mode": "parallel", "max_workers": 2})
branch_configs = [
{
"modelId": f"branch_{i}",
"modelFile": f"branch_{i}.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
}
for i in range(3) # 3 branches
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.return_value = mock_model
start_time = time.time()
# Execute branches in parallel
tasks = [
parallel_executor._execute_branch(config, regions, context)
for config in branch_configs
]
results = await asyncio.gather(*tasks)
parallel_time = time.time() - start_time
# Parallel execution should be faster than 3 * 0.1 seconds
assert parallel_time < 0.25 # Allow some overhead
assert len(results) == 3
assert all(result.success for result in results)
def test_thread_pool_management(self):
"""Test thread pool creation and management."""
# Test different worker counts
for workers in [1, 2, 4, 8]:
executor = PipelineExecutor({"max_workers": workers})
assert executor.max_workers == workers
assert executor.thread_pool._max_workers == workers
def test_memory_management_large_frames(self):
"""Test memory management with large frames."""
executor = PipelineExecutor()
# Create large frame
large_frame = np.ones((1080, 1920, 3), dtype=np.uint8) * 128
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=large_frame,
crop_region=(500, 400, 1000, 800)
)
# Get cropped frame
cropped = context.get_cropped_frame()
# Should reduce memory usage
assert cropped.shape == (400, 500, 3) # Much smaller than original
assert cropped.nbytes < large_frame.nbytes
class TestPipelineExecutorErrorHandling:
"""Test comprehensive error handling."""
@pytest.mark.asyncio
async def test_branch_execution_error_isolation(self, mock_frame):
"""Test that errors in one branch don't affect others."""
executor = PipelineExecutor()
# Mock models - one fails, one succeeds
failing_model = Mock()
failing_model.predict.side_effect = Exception("Model crashed")
success_model = Mock()
mock_result = Mock()
mock_result.probs = Mock()
mock_result.probs.top1 = 1
mock_result.probs.top1conf = 0.85
success_model.predict.return_value = [mock_result]
success_model.names = {0: "Class0", 1: "Class1"}
branch_configs = [
{
"modelId": "failing_branch",
"modelFile": "failing.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
},
{
"modelId": "success_branch",
"modelFile": "success.pt",
"triggerClasses": ["car"],
"minConfidence": 0.7,
"parallel": True
}
]
regions = {
"car": {
"bbox": [100, 200, 300, 400],
"confidence": 0.9,
"detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
}
}
context = PipelineContext(
camera_id="camera_001",
display_id="display_001",
session_id="session_123",
timestamp=1640995200000,
frame_data=mock_frame
)
def get_model_side_effect(model_id, camera_id):
if model_id == "failing_branch":
return failing_model
elif model_id == "success_branch":
return success_model
return None
with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
mock_model_manager.return_value.get_model.side_effect = get_model_side_effect
# Execute branches
tasks = [
executor._execute_branch(config, regions, context)
for config in branch_configs
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# One should fail, one should succeed
failing_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "failing_branch")
success_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "success_branch")
assert failing_result.success is False
assert "Model crashed" in failing_result.error
assert success_result.success is True
assert success_result.error is None

View file

@ -0,0 +1,976 @@
"""
Unit tests for database management functionality.
"""
import pytest
import asyncio
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
import psycopg2
import uuid
from detector_worker.storage.database_manager import (
DatabaseManager,
DatabaseConfig,
DatabaseConnection,
QueryBuilder,
TransactionManager,
DatabaseError,
ConnectionPoolError
)
from detector_worker.core.exceptions import ConfigurationError
class TestDatabaseConfig:
"""Test database configuration."""
def test_creation_minimal(self):
"""Test creating database config with minimal parameters."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
assert config.host == "localhost"
assert config.port == 5432 # Default port
assert config.database == "test_db"
assert config.username == "test_user"
assert config.password == "test_pass"
assert config.schema == "public" # Default schema
assert config.enabled is True
def test_creation_full(self):
"""Test creating database config with all parameters."""
config = DatabaseConfig(
host="db.example.com",
port=5433,
database="production_db",
username="prod_user",
password="secure_pass",
schema="gas_station_1",
enabled=True,
pool_min_conn=2,
pool_max_conn=20,
pool_timeout=30.0,
connection_timeout=10.0,
ssl_mode="require"
)
assert config.host == "db.example.com"
assert config.port == 5433
assert config.database == "production_db"
assert config.schema == "gas_station_1"
assert config.pool_min_conn == 2
assert config.pool_max_conn == 20
assert config.ssl_mode == "require"
def test_get_connection_string(self):
"""Test generating connection string."""
config = DatabaseConfig(
host="localhost",
port=5432,
database="test_db",
username="test_user",
password="test_pass"
)
conn_string = config.get_connection_string()
expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass"
assert conn_string == expected
def test_get_connection_string_with_ssl(self):
"""Test generating connection string with SSL."""
config = DatabaseConfig(
host="db.example.com",
database="secure_db",
username="user",
password="pass",
ssl_mode="require"
)
conn_string = config.get_connection_string()
assert "sslmode=require" in conn_string
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"host": "test-host",
"port": 5433,
"database": "test-db",
"username": "test-user",
"password": "test-pass",
"schema": "test_schema",
"pool_max_conn": 15,
"unknown_field": "ignored"
}
config = DatabaseConfig.from_dict(config_dict)
assert config.host == "test-host"
assert config.port == 5433
assert config.database == "test-db"
assert config.schema == "test_schema"
assert config.pool_max_conn == 15
class TestQueryBuilder:
"""Test SQL query building functionality."""
def test_build_select_query(self):
"""Test building SELECT queries."""
builder = QueryBuilder("test_schema")
query, params = builder.build_select_query(
table="users",
columns=["id", "name", "email"],
where={"status": "active", "age": 25},
order_by="created_at DESC",
limit=10
)
expected_query = (
"SELECT id, name, email FROM test_schema.users "
"WHERE status = %s AND age = %s "
"ORDER BY created_at DESC LIMIT 10"
)
assert query == expected_query
assert params == ["active", 25]
def test_build_select_all_columns(self):
"""Test building SELECT * query."""
builder = QueryBuilder("public")
query, params = builder.build_select_query("products")
expected_query = "SELECT * FROM public.products"
assert query == expected_query
assert params == []
def test_build_insert_query(self):
"""Test building INSERT queries."""
builder = QueryBuilder("inventory")
data = {
"product_name": "Widget",
"price": 19.99,
"quantity": 100,
"created_at": "NOW()"
}
query, params = builder.build_insert_query("products", data)
expected_query = (
"INSERT INTO inventory.products (product_name, price, quantity, created_at) "
"VALUES (%s, %s, %s, NOW()) RETURNING id"
)
assert query == expected_query
assert params == ["Widget", 19.99, 100]
def test_build_update_query(self):
"""Test building UPDATE queries."""
builder = QueryBuilder("sales")
data = {
"status": "shipped",
"shipped_date": "NOW()",
"tracking_number": "ABC123"
}
where_conditions = {"order_id": 12345}
query, params = builder.build_update_query("orders", data, where_conditions)
expected_query = (
"UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s "
"WHERE order_id = %s"
)
assert query == expected_query
assert params == ["shipped", "ABC123", 12345]
def test_build_delete_query(self):
"""Test building DELETE queries."""
builder = QueryBuilder("logs")
where_conditions = {
"level": "DEBUG",
"created_at": "< NOW() - INTERVAL '7 days'"
}
query, params = builder.build_delete_query("application_logs", where_conditions)
expected_query = (
"DELETE FROM logs.application_logs "
"WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'"
)
assert query == expected_query
assert params == ["DEBUG"]
def test_build_create_table_query(self):
"""Test building CREATE TABLE queries."""
builder = QueryBuilder("gas_station_1")
columns = {
"id": "SERIAL PRIMARY KEY",
"session_id": "VARCHAR(255) UNIQUE NOT NULL",
"camera_id": "VARCHAR(255) NOT NULL",
"detection_class": "VARCHAR(100)",
"confidence": "DECIMAL(4,3)",
"bbox_data": "JSON",
"created_at": "TIMESTAMP DEFAULT NOW()",
"updated_at": "TIMESTAMP DEFAULT NOW()"
}
query = builder.build_create_table_query("detections", columns)
expected_parts = [
"CREATE TABLE IF NOT EXISTS gas_station_1.detections",
"id SERIAL PRIMARY KEY",
"session_id VARCHAR(255) UNIQUE NOT NULL",
"camera_id VARCHAR(255) NOT NULL",
"bbox_data JSON",
"created_at TIMESTAMP DEFAULT NOW()"
]
for part in expected_parts:
assert part in query
def test_escape_identifier(self):
"""Test SQL identifier escaping."""
builder = QueryBuilder("test")
assert builder.escape_identifier("table") == '"table"'
assert builder.escape_identifier("column_name") == '"column_name"'
assert builder.escape_identifier("user-table") == '"user-table"'
def test_format_value_for_sql(self):
"""Test SQL value formatting."""
builder = QueryBuilder("test")
# Regular values should use placeholder
assert builder.format_value_for_sql("string") == ("%s", "string")
assert builder.format_value_for_sql(42) == ("%s", 42)
assert builder.format_value_for_sql(3.14) == ("%s", 3.14)
# SQL functions should be literal
assert builder.format_value_for_sql("NOW()") == ("NOW()", None)
assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None)
assert builder.format_value_for_sql("UUID()") == ("UUID()", None)
class TestDatabaseConnection:
"""Test database connection management."""
def test_creation(self, mock_database_connection):
"""Test connection creation."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
assert conn.config == config
assert conn.connection == mock_database_connection
assert conn.is_connected is True
def test_execute_query(self, mock_database_connection):
"""Test query execution."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchall.return_value = [
(1, "John", "john@example.com"),
(2, "Jane", "jane@example.com")
]
mock_cursor.rowcount = 2
conn = DatabaseConnection(config, mock_database_connection)
query = "SELECT id, name, email FROM users WHERE status = %s"
params = ["active"]
result = conn.execute_query(query, params)
assert result == [
(1, "John", "john@example.com"),
(2, "Jane", "jane@example.com")
]
mock_cursor.execute.assert_called_once_with(query, params)
mock_cursor.fetchall.assert_called_once()
def test_execute_query_single_result(self, mock_database_connection):
"""Test query execution with single result."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchone.return_value = (1, "John", "john@example.com")
conn = DatabaseConnection(config, mock_database_connection)
result = conn.execute_query("SELECT * FROM users WHERE id = %s", [1], fetch_one=True)
assert result == (1, "John", "john@example.com")
mock_cursor.fetchone.assert_called_once()
def test_execute_query_no_fetch(self, mock_database_connection):
"""Test query execution without fetching results."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 1
conn = DatabaseConnection(config, mock_database_connection)
result = conn.execute_query(
"INSERT INTO users (name) VALUES (%s)",
["John"],
fetch_results=False
)
assert result == 1 # Row count
mock_cursor.execute.assert_called_once()
mock_cursor.fetchall.assert_not_called()
mock_cursor.fetchone.assert_not_called()
def test_execute_query_error(self, mock_database_connection):
"""Test query execution error handling."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.execute.side_effect = psycopg2.Error("Database error")
conn = DatabaseConnection(config, mock_database_connection)
with pytest.raises(DatabaseError) as exc_info:
conn.execute_query("SELECT * FROM invalid_table")
assert "Database error" in str(exc_info.value)
def test_commit_transaction(self, mock_database_connection):
"""Test transaction commit."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
conn.commit()
mock_database_connection.commit.assert_called_once()
def test_rollback_transaction(self, mock_database_connection):
"""Test transaction rollback."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
conn.rollback()
mock_database_connection.rollback.assert_called_once()
def test_close_connection(self, mock_database_connection):
"""Test connection closing."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
conn.close()
assert conn.is_connected is False
mock_database_connection.close.assert_called_once()
class TestTransactionManager:
"""Test transaction management."""
def test_transaction_context_success(self, mock_database_connection):
"""Test successful transaction context."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
tx_manager = TransactionManager(conn)
with tx_manager:
# Simulate some database operations
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
# Should commit on successful exit
mock_database_connection.commit.assert_called_once()
mock_database_connection.rollback.assert_not_called()
def test_transaction_context_error(self, mock_database_connection):
"""Test transaction context with error."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
conn = DatabaseConnection(config, mock_database_connection)
tx_manager = TransactionManager(conn)
with pytest.raises(DatabaseError):
with tx_manager:
conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
# Simulate an error
raise DatabaseError("Something went wrong")
# Should rollback on error
mock_database_connection.rollback.assert_called_once()
mock_database_connection.commit.assert_not_called()
class TestDatabaseManager:
"""Test main database manager functionality."""
def test_initialization(self):
"""Test database manager initialization."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
assert manager.config == config
assert isinstance(manager.query_builder, QueryBuilder)
assert manager.query_builder.schema == "gas_station_1"
assert manager.connection is None
@pytest.mark.asyncio
async def test_connect_success(self):
"""Test successful database connection."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with patch('psycopg2.connect') as mock_connect:
mock_connection = Mock()
mock_connect.return_value = mock_connection
await manager.connect()
assert manager.connection is not None
assert manager.is_connected is True
mock_connect.assert_called_once()
@pytest.mark.asyncio
async def test_connect_failure(self):
"""Test database connection failure."""
config = DatabaseConfig(
host="nonexistent-host",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with patch('psycopg2.connect') as mock_connect:
mock_connect.side_effect = psycopg2.Error("Connection failed")
with pytest.raises(DatabaseError) as exc_info:
await manager.connect()
assert "Connection failed" in str(exc_info.value)
assert manager.is_connected is False
@pytest.mark.asyncio
async def test_disconnect(self):
"""Test database disconnection."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
# Mock connection
mock_connection = Mock()
manager.connection = DatabaseConnection(config, mock_connection)
await manager.disconnect()
assert manager.connection is None
mock_connection.close.assert_called_once()
@pytest.mark.asyncio
async def test_execute_query(self, mock_database_connection):
"""Test query execution through manager."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")]
result = await manager.execute_query("SELECT * FROM test_table")
assert result == [(1, "Test"), (2, "Data")]
mock_cursor.execute.assert_called_once()
@pytest.mark.asyncio
async def test_execute_query_not_connected(self):
"""Test query execution when not connected."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with pytest.raises(DatabaseError) as exc_info:
await manager.execute_query("SELECT * FROM test_table")
assert "not connected" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_insert_record(self, mock_database_connection):
"""Test inserting a record."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchone.return_value = (123,) # Returned ID
data = {
"session_id": "session_123",
"camera_id": "camera_001",
"detection_class": "car",
"confidence": 0.95,
"created_at": "NOW()"
}
record_id = await manager.insert_record("car_detections", data)
assert record_id == 123
mock_cursor.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_record(self, mock_database_connection):
"""Test updating a record."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 1
data = {
"car_brand": "Toyota",
"car_body_type": "Sedan",
"updated_at": "NOW()"
}
where_conditions = {"session_id": "session_123"}
rows_affected = await manager.update_record("car_info", data, where_conditions)
assert rows_affected == 1
mock_cursor.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_records(self, mock_database_connection):
"""Test deleting records."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 3
where_conditions = {
"created_at": "< NOW() - INTERVAL '30 days'",
"processed": True
}
rows_deleted = await manager.delete_records("old_detections", where_conditions)
assert rows_deleted == 3
mock_cursor.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_create_table(self, mock_database_connection):
"""Test creating a table."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
columns = {
"id": "SERIAL PRIMARY KEY",
"session_id": "VARCHAR(255) UNIQUE NOT NULL",
"camera_id": "VARCHAR(255) NOT NULL",
"detection_data": "JSON",
"created_at": "TIMESTAMP DEFAULT NOW()"
}
await manager.create_table("test_detections", columns)
mock_database_connection.cursor.return_value.execute.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_table_exists(self, mock_database_connection):
"""Test checking if table exists."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior - table exists
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchone.return_value = (1,)
exists = await manager.table_exists("car_detections")
assert exists is True
mock_cursor.execute.assert_called_once()
# Mock cursor behavior - table doesn't exist
mock_cursor.fetchone.return_value = None
exists = await manager.table_exists("nonexistent_table")
assert exists is False
@pytest.mark.asyncio
async def test_transaction_context(self, mock_database_connection):
"""Test transaction context manager."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
async with manager.transaction():
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
# Should commit on successful completion
mock_database_connection.commit.assert_called()
@pytest.mark.asyncio
async def test_get_table_schema(self, mock_database_connection):
"""Test getting table schema information."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behavior
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.fetchall.return_value = [
("id", "integer", "NOT NULL"),
("session_id", "character varying", "NOT NULL"),
("created_at", "timestamp without time zone", "DEFAULT now()")
]
schema = await manager.get_table_schema("car_detections")
assert len(schema) == 3
assert schema[0] == ("id", "integer", "NOT NULL")
assert schema[1] == ("session_id", "character varying", "NOT NULL")
@pytest.mark.asyncio
async def test_bulk_insert(self, mock_database_connection):
"""Test bulk insert operation."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
records = [
{"name": "John", "email": "john@example.com"},
{"name": "Jane", "email": "jane@example.com"},
{"name": "Bob", "email": "bob@example.com"}
]
mock_cursor = mock_database_connection.cursor.return_value
mock_cursor.rowcount = 3
rows_inserted = await manager.bulk_insert("users", records)
assert rows_inserted == 3
mock_cursor.executemany.assert_called_once()
mock_database_connection.commit.assert_called_once()
@pytest.mark.asyncio
async def test_get_connection_stats(self, mock_database_connection):
"""Test getting connection statistics."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
stats = manager.get_connection_stats()
assert "connected" in stats
assert "host" in stats
assert "database" in stats
assert "schema" in stats
assert stats["connected"] is True
assert stats["host"] == "localhost"
assert stats["database"] == "test_db"
class TestDatabaseManagerIntegration:
"""Integration tests for database manager."""
@pytest.mark.asyncio
async def test_complete_car_detection_workflow(self, mock_database_connection):
"""Test complete car detection database workflow."""
config = DatabaseConfig(
host="localhost",
database="gas_station_db",
username="detector_user",
password="detector_pass",
schema="gas_station_1"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Mock cursor behaviors for different operations
mock_cursor = mock_database_connection.cursor.return_value
# 1. Create initial detection record
mock_cursor.fetchone.return_value = (456,) # Returned ID
detection_data = {
"session_id": str(uuid.uuid4()),
"camera_id": "camera_001",
"display_id": "display_001",
"detection_class": "car",
"confidence": 0.92,
"bbox_x1": 100,
"bbox_y1": 200,
"bbox_x2": 300,
"bbox_y2": 400,
"track_id": 1001,
"created_at": "NOW()"
}
detection_id = await manager.insert_record("car_detections", detection_data)
assert detection_id == 456
# 2. Update with classification results
mock_cursor.rowcount = 1
classification_data = {
"car_brand": "Toyota",
"car_model": "Camry",
"car_body_type": "Sedan",
"car_color": "Blue",
"brand_confidence": 0.87,
"bodytype_confidence": 0.82,
"color_confidence": 0.79,
"updated_at": "NOW()"
}
where_conditions = {"session_id": detection_data["session_id"]}
rows_updated = await manager.update_record("car_detections", classification_data, where_conditions)
assert rows_updated == 1
# 3. Query final results
mock_cursor.fetchall.return_value = [
(456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan")
]
results = await manager.execute_query(
"SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type "
"FROM gas_station_1.car_detections WHERE session_id = %s",
[detection_data["session_id"]]
)
assert len(results) == 1
assert results[0][0] == 456 # ID
assert results[0][3] == "car" # detection_class
assert results[0][5] == "Toyota" # car_brand
# Verify all database operations were called
assert mock_cursor.execute.call_count == 3
assert mock_database_connection.commit.call_count == 2
@pytest.mark.asyncio
async def test_error_handling_and_recovery(self, mock_database_connection):
"""Test error handling and recovery scenarios."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
manager.connection = DatabaseConnection(config, mock_database_connection)
# Test transaction rollback on error
mock_cursor = mock_database_connection.cursor.return_value
with pytest.raises(DatabaseError):
async with manager.transaction():
# First operation succeeds
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
# Second operation fails
mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation")
await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])
# Should have rolled back
mock_database_connection.rollback.assert_called_once()
mock_database_connection.commit.assert_not_called()
@pytest.mark.asyncio
async def test_connection_recovery(self):
"""Test automatic connection recovery."""
config = DatabaseConfig(
host="localhost",
database="test_db",
username="test_user",
password="test_pass"
)
manager = DatabaseManager(config)
with patch('psycopg2.connect') as mock_connect:
# First connection attempt fails
mock_connect.side_effect = [
psycopg2.Error("Connection refused"),
Mock() # Second attempt succeeds
]
# First attempt should fail
with pytest.raises(DatabaseError):
await manager.connect()
# Second attempt should succeed
await manager.connect()
assert manager.is_connected is True

View file

@ -0,0 +1,964 @@
"""
Unit tests for Redis client functionality.
"""
import pytest
import asyncio
import json
import base64
import time
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta
import redis
import numpy as np
from detector_worker.storage.redis_client import (
RedisClient,
RedisConfig,
RedisConnectionPool,
RedisPublisher,
RedisSubscriber,
RedisImageStorage,
RedisError,
ConnectionPoolError
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import ConfigurationError
class TestRedisConfig:
"""Test Redis configuration."""
def test_creation_minimal(self):
"""Test creating Redis config with minimal parameters."""
config = RedisConfig(
host="localhost"
)
assert config.host == "localhost"
assert config.port == 6379 # Default port
assert config.password is None
assert config.db == 0 # Default database
assert config.enabled is True
def test_creation_full(self):
"""Test creating Redis config with all parameters."""
config = RedisConfig(
host="redis.example.com",
port=6380,
password="secure_pass",
db=2,
enabled=True,
connection_timeout=5.0,
socket_timeout=3.0,
socket_connect_timeout=2.0,
max_connections=50,
retry_on_timeout=True,
health_check_interval=30
)
assert config.host == "redis.example.com"
assert config.port == 6380
assert config.password == "secure_pass"
assert config.db == 2
assert config.connection_timeout == 5.0
assert config.max_connections == 50
assert config.retry_on_timeout is True
def test_get_connection_params(self):
"""Test getting Redis connection parameters."""
config = RedisConfig(
host="localhost",
port=6379,
password="test_pass",
db=1,
connection_timeout=10.0
)
params = config.get_connection_params()
assert params["host"] == "localhost"
assert params["port"] == 6379
assert params["password"] == "test_pass"
assert params["db"] == 1
assert params["socket_timeout"] == 10.0
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"host": "redis-server",
"port": 6380,
"password": "secret",
"db": 3,
"max_connections": 100,
"unknown_field": "ignored"
}
config = RedisConfig.from_dict(config_dict)
assert config.host == "redis-server"
assert config.port == 6380
assert config.password == "secret"
assert config.db == 3
assert config.max_connections == 100
class TestRedisConnectionPool:
"""Test Redis connection pool management."""
def test_creation(self):
"""Test connection pool creation."""
config = RedisConfig(
host="localhost",
max_connections=20
)
pool = RedisConnectionPool(config)
assert pool.config == config
assert pool.pool is None
assert pool.is_connected is False
@pytest.mark.asyncio
async def test_connect_success(self):
"""Test successful connection to Redis."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
with patch('redis.ConnectionPool') as mock_pool_class:
mock_pool = Mock()
mock_pool_class.return_value = mock_pool
with patch('redis.Redis') as mock_redis_class:
mock_redis = Mock()
mock_redis.ping.return_value = True
mock_redis_class.return_value = mock_redis
await pool.connect()
assert pool.is_connected is True
assert pool.pool is not None
mock_pool_class.assert_called_once()
@pytest.mark.asyncio
async def test_connect_failure(self):
"""Test Redis connection failure."""
config = RedisConfig(host="nonexistent-redis")
pool = RedisConnectionPool(config)
with patch('redis.ConnectionPool') as mock_pool_class:
mock_pool_class.side_effect = redis.ConnectionError("Connection failed")
with pytest.raises(RedisError) as exc_info:
await pool.connect()
assert "Connection failed" in str(exc_info.value)
assert pool.is_connected is False
@pytest.mark.asyncio
async def test_disconnect(self):
"""Test Redis disconnection."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
# Mock connected state
mock_pool = Mock()
mock_redis = Mock()
pool.pool = mock_pool
pool._redis_client = mock_redis
pool.is_connected = True
await pool.disconnect()
assert pool.is_connected is False
assert pool.pool is None
mock_pool.disconnect.assert_called_once()
def test_get_client_connected(self):
"""Test getting Redis client when connected."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
mock_pool = Mock()
mock_redis = Mock()
pool.pool = mock_pool
pool._redis_client = mock_redis
pool.is_connected = True
client = pool.get_client()
assert client == mock_redis
def test_get_client_not_connected(self):
"""Test getting Redis client when not connected."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
with pytest.raises(RedisError) as exc_info:
pool.get_client()
assert "not connected" in str(exc_info.value).lower()
def test_health_check(self):
"""Test Redis health check."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
mock_redis = Mock()
mock_redis.ping.return_value = True
pool._redis_client = mock_redis
pool.is_connected = True
is_healthy = pool.health_check()
assert is_healthy is True
mock_redis.ping.assert_called_once()
def test_health_check_failure(self):
"""Test Redis health check failure."""
config = RedisConfig(host="localhost")
pool = RedisConnectionPool(config)
mock_redis = Mock()
mock_redis.ping.side_effect = redis.ConnectionError("Connection lost")
pool._redis_client = mock_redis
pool.is_connected = True
is_healthy = pool.health_check()
assert is_healthy is False
class TestRedisImageStorage:
"""Test Redis image storage functionality."""
def test_creation(self, mock_redis_client):
"""Test Redis image storage creation."""
storage = RedisImageStorage(mock_redis_client)
assert storage.redis_client == mock_redis_client
assert storage.default_expiry == 3600 # 1 hour
assert storage.compression_enabled is True
@pytest.mark.asyncio
async def test_store_image_success(self, mock_redis_client, mock_frame):
"""Test successful image storage."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
with patch('cv2.imencode') as mock_imencode:
# Mock successful encoding
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
result = await storage.store_image("test_key", mock_frame, expire_seconds=600)
assert result is True
mock_redis_client.set.assert_called_once()
mock_redis_client.expire.assert_called_once_with("test_key", 600)
mock_imencode.assert_called_once()
@pytest.mark.asyncio
async def test_store_image_cropped(self, mock_redis_client, mock_frame):
"""Test storing cropped image."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
result = await storage.store_image("cropped_key", mock_frame, crop_bbox=bbox)
assert result is True
mock_redis_client.set.assert_called_once()
@pytest.mark.asyncio
async def test_store_image_encoding_failure(self, mock_redis_client, mock_frame):
"""Test image storage with encoding failure."""
storage = RedisImageStorage(mock_redis_client)
with patch('cv2.imencode') as mock_imencode:
# Mock encoding failure
mock_imencode.return_value = (False, None)
with pytest.raises(RedisError) as exc_info:
await storage.store_image("test_key", mock_frame)
assert "Failed to encode image" in str(exc_info.value)
mock_redis_client.set.assert_not_called()
@pytest.mark.asyncio
async def test_store_image_redis_failure(self, mock_redis_client, mock_frame):
"""Test image storage with Redis failure."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.set.side_effect = redis.RedisError("Redis error")
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
with pytest.raises(RedisError) as exc_info:
await storage.store_image("test_key", mock_frame)
assert "Redis error" in str(exc_info.value)
@pytest.mark.asyncio
async def test_retrieve_image_success(self, mock_redis_client):
"""Test successful image retrieval."""
storage = RedisImageStorage(mock_redis_client)
# Mock encoded image data
original_image = np.ones((100, 100, 3), dtype=np.uint8) * 128
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Mock Redis returning base64 encoded data
base64_data = base64.b64encode(encoded_data.tobytes()).decode('utf-8')
mock_redis_client.get.return_value = base64_data
with patch('cv2.imdecode') as mock_imdecode:
mock_imdecode.return_value = original_image
retrieved_image = await storage.retrieve_image("test_key")
assert retrieved_image is not None
assert retrieved_image.shape == (100, 100, 3)
mock_redis_client.get.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_retrieve_image_not_found(self, mock_redis_client):
"""Test image retrieval when key not found."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.get.return_value = None
retrieved_image = await storage.retrieve_image("nonexistent_key")
assert retrieved_image is None
mock_redis_client.get.assert_called_once_with("nonexistent_key")
@pytest.mark.asyncio
async def test_delete_image(self, mock_redis_client):
"""Test image deletion."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.delete.return_value = 1
result = await storage.delete_image("test_key")
assert result is True
mock_redis_client.delete.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_delete_image_not_found(self, mock_redis_client):
"""Test deleting non-existent image."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.delete.return_value = 0
result = await storage.delete_image("nonexistent_key")
assert result is False
mock_redis_client.delete.assert_called_once_with("nonexistent_key")
@pytest.mark.asyncio
async def test_bulk_delete_images(self, mock_redis_client):
"""Test bulk image deletion."""
storage = RedisImageStorage(mock_redis_client)
keys = ["key1", "key2", "key3"]
mock_redis_client.delete.return_value = 3
deleted_count = await storage.bulk_delete_images(keys)
assert deleted_count == 3
mock_redis_client.delete.assert_called_once_with(*keys)
@pytest.mark.asyncio
async def test_cleanup_expired_images(self, mock_redis_client):
"""Test cleanup of expired images."""
storage = RedisImageStorage(mock_redis_client)
# Mock scan to return image keys
mock_redis_client.scan_iter.return_value = [
b"inference:camera1:image1",
b"inference:camera2:image2",
b"inference:camera1:image3"
]
# Mock ttl to return different expiry times
mock_redis_client.ttl.side_effect = [-1, 100, -2] # No expiry, valid, expired
mock_redis_client.delete.return_value = 1
deleted_count = await storage.cleanup_expired_images("inference:*")
assert deleted_count == 1 # Only expired images deleted
mock_redis_client.delete.assert_called_once()
def test_get_image_info(self, mock_redis_client):
"""Test getting image metadata."""
storage = RedisImageStorage(mock_redis_client)
mock_redis_client.exists.return_value = 1
mock_redis_client.ttl.return_value = 1800 # 30 minutes
mock_redis_client.memory_usage.return_value = 4096 # 4KB
info = storage.get_image_info("test_key")
assert info["exists"] is True
assert info["ttl"] == 1800
assert info["size_bytes"] == 4096
mock_redis_client.exists.assert_called_once_with("test_key")
mock_redis_client.ttl.assert_called_once_with("test_key")
class TestRedisPublisher:
"""Test Redis publisher functionality."""
def test_creation(self, mock_redis_client):
"""Test Redis publisher creation."""
publisher = RedisPublisher(mock_redis_client)
assert publisher.redis_client == mock_redis_client
@pytest.mark.asyncio
async def test_publish_message_string(self, mock_redis_client):
"""Test publishing string message."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.return_value = 2 # 2 subscribers
result = await publisher.publish("test_channel", "Hello, Redis!")
assert result == 2
mock_redis_client.publish.assert_called_once_with("test_channel", "Hello, Redis!")
@pytest.mark.asyncio
async def test_publish_message_json(self, mock_redis_client):
"""Test publishing JSON message."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.return_value = 1
message_data = {
"camera_id": "camera_001",
"detection_class": "car",
"confidence": 0.95,
"timestamp": 1640995200000
}
result = await publisher.publish("detections", message_data)
assert result == 1
# Should have been JSON serialized
expected_json = json.dumps(message_data)
mock_redis_client.publish.assert_called_once_with("detections", expected_json)
@pytest.mark.asyncio
async def test_publish_detection_event(self, mock_redis_client):
"""Test publishing detection event."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.return_value = 3
detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
result = await publisher.publish_detection_event(
"camera_detections",
detection,
camera_id="camera_001",
session_id="session_123"
)
assert result == 3
# Verify the published message structure
call_args = mock_redis_client.publish.call_args
channel = call_args[0][0]
message_str = call_args[0][1]
message_data = json.loads(message_str)
assert channel == "camera_detections"
assert message_data["event_type"] == "detection"
assert message_data["camera_id"] == "camera_001"
assert message_data["session_id"] == "session_123"
assert message_data["detection"]["class"] == "car"
assert message_data["detection"]["confidence"] == 0.92
@pytest.mark.asyncio
async def test_publish_batch_messages(self, mock_redis_client):
"""Test publishing multiple messages in batch."""
publisher = RedisPublisher(mock_redis_client)
mock_pipeline = Mock()
mock_redis_client.pipeline.return_value = mock_pipeline
mock_pipeline.execute.return_value = [1, 2, 1] # Subscriber counts
messages = [
("channel1", "message1"),
("channel2", {"data": "message2"}),
("channel1", "message3")
]
results = await publisher.publish_batch(messages)
assert results == [1, 2, 1]
mock_redis_client.pipeline.assert_called_once()
assert mock_pipeline.publish.call_count == 3
mock_pipeline.execute.assert_called_once()
@pytest.mark.asyncio
async def test_publish_error_handling(self, mock_redis_client):
"""Test error handling in publishing."""
publisher = RedisPublisher(mock_redis_client)
mock_redis_client.publish.side_effect = redis.RedisError("Publish failed")
with pytest.raises(RedisError) as exc_info:
await publisher.publish("test_channel", "test_message")
assert "Publish failed" in str(exc_info.value)
class TestRedisSubscriber:
"""Test Redis subscriber functionality."""
def test_creation(self, mock_redis_client):
"""Test Redis subscriber creation."""
subscriber = RedisSubscriber(mock_redis_client)
assert subscriber.redis_client == mock_redis_client
assert subscriber.pubsub is None
assert subscriber.subscriptions == set()
@pytest.mark.asyncio
async def test_subscribe_to_channel(self, mock_redis_client):
"""Test subscribing to a channel."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
await subscriber.subscribe("test_channel")
assert "test_channel" in subscriber.subscriptions
mock_pubsub.subscribe.assert_called_once_with("test_channel")
@pytest.mark.asyncio
async def test_subscribe_to_pattern(self, mock_redis_client):
"""Test subscribing to a pattern."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
await subscriber.subscribe_pattern("detection:*")
assert "detection:*" in subscriber.subscriptions
mock_pubsub.psubscribe.assert_called_once_with("detection:*")
@pytest.mark.asyncio
async def test_unsubscribe_from_channel(self, mock_redis_client):
"""Test unsubscribing from a channel."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
subscriber.pubsub = mock_pubsub
subscriber.subscriptions.add("test_channel")
await subscriber.unsubscribe("test_channel")
assert "test_channel" not in subscriber.subscriptions
mock_pubsub.unsubscribe.assert_called_once_with("test_channel")
@pytest.mark.asyncio
async def test_listen_for_messages(self, mock_redis_client):
"""Test listening for messages."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
mock_redis_client.pubsub.return_value = mock_pubsub
# Mock message stream
messages = [
{"type": "subscribe", "channel": "test", "data": 1},
{"type": "message", "channel": "test", "data": "Hello"},
{"type": "message", "channel": "test", "data": '{"key": "value"}'}
]
mock_pubsub.listen.return_value = iter(messages)
received_messages = []
message_count = 0
async for message in subscriber.listen():
received_messages.append(message)
message_count += 1
if message_count >= 2: # Only process actual messages
break
# Should receive 2 actual messages (excluding subscribe confirmation)
assert len(received_messages) == 2
assert received_messages[0]["data"] == "Hello"
assert received_messages[1]["data"] == {"key": "value"} # Should be parsed as JSON
@pytest.mark.asyncio
async def test_close_subscription(self, mock_redis_client):
"""Test closing subscription."""
subscriber = RedisSubscriber(mock_redis_client)
mock_pubsub = Mock()
subscriber.pubsub = mock_pubsub
subscriber.subscriptions = {"channel1", "pattern:*"}
await subscriber.close()
assert len(subscriber.subscriptions) == 0
mock_pubsub.close.assert_called_once()
assert subscriber.pubsub is None
class TestRedisClient:
"""Test main Redis client functionality."""
def test_initialization(self):
"""Test Redis client initialization."""
config = RedisConfig(host="localhost", port=6379)
client = RedisClient(config)
assert client.config == config
assert isinstance(client.connection_pool, RedisConnectionPool)
assert client.image_storage is None
assert client.publisher is None
assert client.subscriber is None
@pytest.mark.asyncio
async def test_connect_and_initialize_components(self):
"""Test connecting and initializing all components."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
mock_redis_client = Mock()
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
await client.connect()
assert client.image_storage is not None
assert client.publisher is not None
assert client.subscriber is not None
assert isinstance(client.image_storage, RedisImageStorage)
assert isinstance(client.publisher, RedisPublisher)
assert isinstance(client.subscriber, RedisSubscriber)
mock_connect.assert_called_once()
@pytest.mark.asyncio
async def test_disconnect(self):
"""Test disconnection."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.is_connected = True
client.subscriber = Mock()
client.subscriber.close = AsyncMock()
with patch.object(client.connection_pool, 'disconnect', new_callable=AsyncMock) as mock_disconnect:
await client.disconnect()
client.subscriber.close.assert_called_once()
mock_disconnect.assert_called_once()
assert client.image_storage is None
assert client.publisher is None
assert client.subscriber is None
@pytest.mark.asyncio
async def test_store_and_retrieve_data(self, mock_redis_client):
"""Test storing and retrieving data."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
# Test storing data
mock_redis_client.set.return_value = True
result = await client.set("test_key", "test_value", expire_seconds=300)
assert result is True
mock_redis_client.set.assert_called_once_with("test_key", "test_value")
mock_redis_client.expire.assert_called_once_with("test_key", 300)
# Test retrieving data
mock_redis_client.get.return_value = "test_value"
value = await client.get("test_key")
assert value == "test_value"
mock_redis_client.get.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_delete_keys(self, mock_redis_client):
"""Test deleting keys."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.delete.return_value = 2
result = await client.delete("key1", "key2")
assert result == 2
mock_redis_client.delete.assert_called_once_with("key1", "key2")
@pytest.mark.asyncio
async def test_exists_check(self, mock_redis_client):
"""Test checking key existence."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.exists.return_value = 1
exists = await client.exists("test_key")
assert exists is True
mock_redis_client.exists.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_expire_key(self, mock_redis_client):
"""Test setting key expiration."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.expire.return_value = True
result = await client.expire("test_key", 600)
assert result is True
mock_redis_client.expire.assert_called_once_with("test_key", 600)
@pytest.mark.asyncio
async def test_get_ttl(self, mock_redis_client):
"""Test getting key TTL."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.ttl.return_value = 300
ttl = await client.ttl("test_key")
assert ttl == 300
mock_redis_client.ttl.assert_called_once_with("test_key")
@pytest.mark.asyncio
async def test_scan_keys(self, mock_redis_client):
"""Test scanning for keys."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.scan_iter.return_value = [b"key1", b"key2", b"key3"]
keys = await client.scan_keys("test:*")
assert keys == ["key1", "key2", "key3"]
mock_redis_client.scan_iter.assert_called_once_with(match="test:*")
@pytest.mark.asyncio
async def test_flush_database(self, mock_redis_client):
"""Test flushing database."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_redis_client.flushdb.return_value = True
result = await client.flush_db()
assert result is True
mock_redis_client.flushdb.assert_called_once()
def test_get_connection_info(self):
"""Test getting connection information."""
config = RedisConfig(
host="redis.example.com",
port=6380,
db=2
)
client = RedisClient(config)
client.connection_pool.is_connected = True
info = client.get_connection_info()
assert info["connected"] is True
assert info["host"] == "redis.example.com"
assert info["port"] == 6380
assert info["database"] == 2
@pytest.mark.asyncio
async def test_pipeline_operations(self, mock_redis_client):
"""Test Redis pipeline operations."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
mock_pipeline = Mock()
mock_redis_client.pipeline.return_value = mock_pipeline
mock_pipeline.execute.return_value = [True, True, 1]
async with client.pipeline() as pipe:
pipe.set("key1", "value1")
pipe.set("key2", "value2")
pipe.delete("key3")
results = await pipe.execute()
assert results == [True, True, 1]
mock_redis_client.pipeline.assert_called_once()
mock_pipeline.execute.assert_called_once()
class TestRedisClientIntegration:
"""Integration tests for Redis client."""
@pytest.mark.asyncio
async def test_complete_image_workflow(self, mock_redis_client, mock_frame):
"""Test complete image storage workflow."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state and components
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
client.image_storage = RedisImageStorage(mock_redis_client)
client.publisher = RedisPublisher(mock_redis_client)
# Mock Redis operations
mock_redis_client.set.return_value = True
mock_redis_client.expire.return_value = True
mock_redis_client.publish.return_value = 2
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Store image
store_result = await client.image_storage.store_image(
"detection:camera001:1640995200:session123",
mock_frame,
expire_seconds=600
)
# Publish detection event
detection_event = {
"camera_id": "camera001",
"session_id": "session123",
"detection_class": "car",
"confidence": 0.95,
"timestamp": 1640995200000
}
publish_result = await client.publisher.publish("detections:camera001", detection_event)
assert store_result is True
assert publish_result == 2
# Verify Redis operations
mock_redis_client.set.assert_called_once()
mock_redis_client.expire.assert_called_once()
mock_redis_client.publish.assert_called_once()
@pytest.mark.asyncio
async def test_error_recovery_and_reconnection(self):
"""Test error recovery and reconnection."""
config = RedisConfig(host="localhost", retry_on_timeout=True)
client = RedisClient(config)
with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
with patch.object(client.connection_pool, 'health_check') as mock_health_check:
# First health check fails, second succeeds
mock_health_check.side_effect = [False, True]
# First connection attempt fails, second succeeds
mock_connect.side_effect = [RedisError("Connection failed"), None]
# Simulate connection recovery
try:
await client.connect()
except RedisError:
# Retry connection
await client.connect()
assert mock_connect.call_count == 2
@pytest.mark.asyncio
async def test_bulk_operations_performance(self, mock_redis_client):
"""Test bulk operations for performance."""
config = RedisConfig(host="localhost")
client = RedisClient(config)
# Mock connected state
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
client.connection_pool.is_connected = True
client.publisher = RedisPublisher(mock_redis_client)
# Mock pipeline operations
mock_pipeline = Mock()
mock_redis_client.pipeline.return_value = mock_pipeline
mock_pipeline.execute.return_value = [1] * 100 # 100 successful operations
# Prepare bulk messages
messages = [
(f"channel_{i}", f"message_{i}")
for i in range(100)
]
start_time = time.time()
results = await client.publisher.publish_batch(messages)
execution_time = time.time() - start_time
assert len(results) == 100
assert all(result == 1 for result in results)
# Should be faster than individual operations
assert execution_time < 1.0 # Should complete in less than 1 second
# Pipeline should be used for efficiency
mock_redis_client.pipeline.assert_called_once()
assert mock_pipeline.publish.call_count == 100
mock_pipeline.execute.assert_called_once()

View file

@ -0,0 +1,883 @@
"""
Unit tests for session cache management.
"""
import pytest
import time
import uuid
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from collections import defaultdict
from detector_worker.storage.session_cache import (
SessionCache,
SessionCacheManager,
SessionData,
CacheConfig,
CacheEntry,
CacheStats,
SessionError,
CacheError
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
class TestCacheConfig:
"""Test cache configuration."""
def test_creation_default(self):
"""Test creating cache config with default values."""
config = CacheConfig()
assert config.max_size == 1000
assert config.ttl_seconds == 3600 # 1 hour
assert config.cleanup_interval == 300 # 5 minutes
assert config.eviction_policy == "lru"
assert config.enable_persistence is False
def test_creation_custom(self):
"""Test creating cache config with custom values."""
config = CacheConfig(
max_size=5000,
ttl_seconds=7200,
cleanup_interval=600,
eviction_policy="lfu",
enable_persistence=True,
persistence_path="/tmp/cache"
)
assert config.max_size == 5000
assert config.ttl_seconds == 7200
assert config.cleanup_interval == 600
assert config.eviction_policy == "lfu"
assert config.enable_persistence is True
assert config.persistence_path == "/tmp/cache"
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"max_size": 2000,
"ttl_seconds": 1800,
"eviction_policy": "fifo",
"enable_persistence": True,
"unknown_field": "ignored"
}
config = CacheConfig.from_dict(config_dict)
assert config.max_size == 2000
assert config.ttl_seconds == 1800
assert config.eviction_policy == "fifo"
assert config.enable_persistence is True
class TestCacheEntry:
"""Test cache entry data structure."""
def test_creation(self):
"""Test cache entry creation."""
data = {"key": "value", "number": 42}
entry = CacheEntry(data, ttl_seconds=600)
assert entry.data == data
assert entry.ttl_seconds == 600
assert entry.created_at <= time.time()
assert entry.last_accessed <= time.time()
assert entry.access_count == 1
assert entry.size > 0
def test_is_expired(self):
"""Test expiration checking."""
# Non-expired entry
entry = CacheEntry({"data": "test"}, ttl_seconds=600)
assert entry.is_expired() is False
# Expired entry (simulate by setting old creation time)
entry.created_at = time.time() - 700 # Created 700 seconds ago
assert entry.is_expired() is True
# Entry without expiration
entry_no_ttl = CacheEntry({"data": "test"})
assert entry_no_ttl.is_expired() is False
def test_touch(self):
"""Test updating access time and count."""
entry = CacheEntry({"data": "test"})
original_access_time = entry.last_accessed
original_access_count = entry.access_count
time.sleep(0.01) # Small delay
entry.touch()
assert entry.last_accessed > original_access_time
assert entry.access_count == original_access_count + 1
def test_age(self):
"""Test age calculation."""
entry = CacheEntry({"data": "test"})
time.sleep(0.01) # Small delay
age = entry.age()
assert age > 0
assert age < 1 # Should be less than 1 second
def test_size_estimation(self):
"""Test size estimation."""
small_entry = CacheEntry({"key": "value"})
large_entry = CacheEntry({"key": "x" * 1000, "data": list(range(100))})
assert large_entry.size > small_entry.size
class TestSessionData:
"""Test session data structure."""
def test_creation(self):
"""Test session data creation."""
session_data = SessionData(
session_id="session_123",
camera_id="camera_001",
display_id="display_001"
)
assert session_data.session_id == "session_123"
assert session_data.camera_id == "camera_001"
assert session_data.display_id == "display_001"
assert session_data.created_at <= time.time()
assert session_data.last_activity <= time.time()
assert session_data.detection_data == {}
assert session_data.metadata == {}
def test_update_activity(self):
"""Test updating last activity."""
session_data = SessionData("session_123", "camera_001", "display_001")
original_activity = session_data.last_activity
time.sleep(0.01)
session_data.update_activity()
assert session_data.last_activity > original_activity
def test_add_detection_data(self):
"""Test adding detection data."""
session_data = SessionData("session_123", "camera_001", "display_001")
detection_data = {
"class": "car",
"confidence": 0.95,
"bbox": [100, 200, 300, 400]
}
session_data.add_detection_data("main_detection", detection_data)
assert "main_detection" in session_data.detection_data
assert session_data.detection_data["main_detection"] == detection_data
def test_add_metadata(self):
"""Test adding metadata."""
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_metadata("model_version", "v2.1")
session_data.add_metadata("inference_time", 0.15)
assert session_data.metadata["model_version"] == "v2.1"
assert session_data.metadata["inference_time"] == 0.15
def test_is_expired(self):
"""Test session expiration."""
session_data = SessionData("session_123", "camera_001", "display_001")
# Not expired with default timeout
assert session_data.is_expired() is False
# Expired with short timeout
assert session_data.is_expired(timeout_seconds=0.001) is True
def test_to_dict(self):
"""Test converting session to dictionary."""
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("detection", {"class": "car", "confidence": 0.9})
session_data.add_metadata("model_id", "yolo_v8")
data_dict = session_data.to_dict()
assert data_dict["session_id"] == "session_123"
assert data_dict["camera_id"] == "camera_001"
assert data_dict["detection_data"]["detection"]["class"] == "car"
assert data_dict["metadata"]["model_id"] == "yolo_v8"
assert "created_at" in data_dict
assert "last_activity" in data_dict
class TestCacheStats:
"""Test cache statistics."""
def test_creation(self):
"""Test cache stats creation."""
stats = CacheStats()
assert stats.hits == 0
assert stats.misses == 0
assert stats.evictions == 0
assert stats.size == 0
assert stats.memory_usage == 0
def test_hit_rate_calculation(self):
"""Test hit rate calculation."""
stats = CacheStats()
# No requests yet
assert stats.hit_rate() == 0.0
# Some hits and misses
stats.hits = 8
stats.misses = 2
assert stats.hit_rate() == 0.8 # 8 / (8 + 2)
def test_total_requests(self):
"""Test total requests calculation."""
stats = CacheStats()
stats.hits = 15
stats.misses = 5
assert stats.total_requests() == 20
class TestSessionCache:
"""Test session cache functionality."""
def test_creation(self):
"""Test session cache creation."""
config = CacheConfig(max_size=100, ttl_seconds=300)
cache = SessionCache(config)
assert cache.config == config
assert cache.max_size == 100
assert cache.ttl_seconds == 300
assert len(cache._cache) == 0
assert len(cache._access_order) == 0
def test_put_and_get_session(self):
"""Test putting and getting session data."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("main", {"class": "car", "confidence": 0.9})
# Put session
cache.put("session_123", session_data)
# Get session
retrieved_data = cache.get("session_123")
assert retrieved_data is not None
assert retrieved_data.session_id == "session_123"
assert retrieved_data.camera_id == "camera_001"
assert "main" in retrieved_data.detection_data
def test_get_nonexistent_session(self):
"""Test getting non-existent session."""
cache = SessionCache(CacheConfig(max_size=10))
result = cache.get("nonexistent_session")
assert result is None
def test_contains_check(self):
"""Test checking if session exists."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
assert cache.contains("session_123") is True
assert cache.contains("nonexistent_session") is False
def test_remove_session(self):
"""Test removing session."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
assert cache.contains("session_123") is True
removed_data = cache.remove("session_123")
assert removed_data is not None
assert removed_data.session_id == "session_123"
assert cache.contains("session_123") is False
def test_size_tracking(self):
"""Test cache size tracking."""
cache = SessionCache(CacheConfig(max_size=10))
assert cache.size() == 0
assert cache.is_empty() is True
# Add sessions
for i in range(3):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
assert cache.size() == 3
assert cache.is_empty() is False
def test_lru_eviction(self):
"""Test LRU eviction policy."""
cache = SessionCache(CacheConfig(max_size=3, eviction_policy="lru"))
# Fill cache to capacity
for i in range(3):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
# Access session_1 to make it recently used
cache.get("session_1")
# Add another session (should evict session_0, the least recently used)
new_session = SessionData("session_3", "camera_001", "display_001")
cache.put("session_3", new_session)
assert cache.size() == 3
assert cache.contains("session_0") is False # Evicted
assert cache.contains("session_1") is True # Recently accessed
assert cache.contains("session_2") is True
assert cache.contains("session_3") is True # Newly added
def test_ttl_expiration(self):
"""Test TTL-based expiration."""
cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1)) # 100ms TTL
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
# Should exist immediately
assert cache.contains("session_123") is True
# Wait for expiration
time.sleep(0.2)
# Should be expired (but might still be in cache until cleanup)
entry = cache._cache.get("session_123")
if entry:
assert entry.is_expired() is True
# Getting expired entry should return None and clean it up
retrieved = cache.get("session_123")
assert retrieved is None
assert cache.contains("session_123") is False
def test_cleanup_expired_entries(self):
"""Test cleanup of expired entries."""
cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))
# Add multiple sessions
for i in range(3):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
assert cache.size() == 3
# Wait for expiration
time.sleep(0.2)
# Cleanup expired entries
cleaned_count = cache.cleanup_expired()
assert cleaned_count == 3
assert cache.size() == 0
def test_clear_cache(self):
"""Test clearing entire cache."""
cache = SessionCache(CacheConfig(max_size=10))
# Add sessions
for i in range(5):
session_data = SessionData(f"session_{i}", "camera_001", "display_001")
cache.put(f"session_{i}", session_data)
assert cache.size() == 5
cache.clear()
assert cache.size() == 0
assert cache.is_empty() is True
def test_get_all_sessions(self):
"""Test getting all sessions."""
cache = SessionCache(CacheConfig(max_size=10))
sessions = []
for i in range(3):
session_data = SessionData(f"session_{i}", f"camera_{i}", "display_001")
cache.put(f"session_{i}", session_data)
sessions.append(session_data)
all_sessions = cache.get_all()
assert len(all_sessions) == 3
for session_id, session_data in all_sessions.items():
assert session_id.startswith("session_")
assert session_data.session_id == session_id
def test_get_sessions_by_camera(self):
"""Test getting sessions by camera ID."""
cache = SessionCache(CacheConfig(max_size=10))
# Add sessions for different cameras
for i in range(2):
session_data1 = SessionData(f"session_cam1_{i}", "camera_001", "display_001")
session_data2 = SessionData(f"session_cam2_{i}", "camera_002", "display_001")
cache.put(f"session_cam1_{i}", session_data1)
cache.put(f"session_cam2_{i}", session_data2)
camera1_sessions = cache.get_by_camera("camera_001")
camera2_sessions = cache.get_by_camera("camera_002")
assert len(camera1_sessions) == 2
assert len(camera2_sessions) == 2
for session_data in camera1_sessions:
assert session_data.camera_id == "camera_001"
for session_data in camera2_sessions:
assert session_data.camera_id == "camera_002"
def test_statistics_tracking(self):
"""Test cache statistics tracking."""
cache = SessionCache(CacheConfig(max_size=10))
session_data = SessionData("session_123", "camera_001", "display_001")
cache.put("session_123", session_data)
# Cache miss
cache.get("nonexistent_session")
# Cache hit
cache.get("session_123")
cache.get("session_123") # Another hit
stats = cache.get_stats()
assert stats.hits == 2
assert stats.misses == 1
assert stats.size == 1
assert stats.hit_rate() == 2/3 # 2 hits out of 3 total requests
def test_memory_usage_estimation(self):
"""Test memory usage estimation."""
cache = SessionCache(CacheConfig(max_size=10))
initial_memory = cache.get_memory_usage()
# Add large session
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("large_data", {"data": "x" * 1000})
cache.put("session_123", session_data)
after_memory = cache.get_memory_usage()
assert after_memory > initial_memory
class TestSessionCacheManager:
"""Test session cache manager."""
def test_singleton_behavior(self):
"""Test that SessionCacheManager is a singleton."""
manager1 = SessionCacheManager()
manager2 = SessionCacheManager()
assert manager1 is manager2
def test_initialization(self):
"""Test session cache manager initialization."""
manager = SessionCacheManager()
assert manager.detection_cache is not None
assert manager.pipeline_cache is not None
assert manager.session_cache is not None
assert isinstance(manager.detection_cache, SessionCache)
assert isinstance(manager.pipeline_cache, SessionCache)
assert isinstance(manager.session_cache, SessionCache)
def test_cache_detection_result(self):
"""Test caching detection results."""
manager = SessionCacheManager()
manager.clear_all() # Start fresh
detection_data = {
"class": "car",
"confidence": 0.95,
"bbox": [100, 200, 300, 400],
"track_id": 1001
}
manager.cache_detection("camera_001", detection_data)
cached_detection = manager.get_cached_detection("camera_001")
assert cached_detection is not None
assert cached_detection["class"] == "car"
assert cached_detection["confidence"] == 0.95
assert cached_detection["track_id"] == 1001
def test_cache_pipeline_result(self):
"""Test caching pipeline results."""
manager = SessionCacheManager()
manager.clear_all()
pipeline_result = {
"status": "success",
"detections": [{"class": "car", "confidence": 0.9}],
"execution_time": 0.15,
"model_id": "yolo_v8"
}
manager.cache_pipeline_result("camera_001", pipeline_result)
cached_result = manager.get_cached_pipeline_result("camera_001")
assert cached_result is not None
assert cached_result["status"] == "success"
assert cached_result["execution_time"] == 0.15
assert len(cached_result["detections"]) == 1
def test_manage_session_data(self):
"""Test session data management."""
manager = SessionCacheManager()
manager.clear_all()
session_id = str(uuid.uuid4())
# Create session
manager.create_session(session_id, "camera_001", {"initial": "data"})
# Update session
manager.update_session_detection(session_id, {"car_brand": "Toyota"})
# Get session
session_data = manager.get_session_detection(session_id)
assert session_data is not None
assert "initial" in session_data
assert session_data["car_brand"] == "Toyota"
def test_set_latest_frame(self):
"""Test setting and getting latest frame."""
manager = SessionCacheManager()
manager.clear_all()
frame_data = b"fake_frame_data"
manager.set_latest_frame("camera_001", frame_data)
retrieved_frame = manager.get_latest_frame("camera_001")
assert retrieved_frame == frame_data
def test_frame_skip_flag_management(self):
"""Test frame skip flag management."""
manager = SessionCacheManager()
manager.clear_all()
# Initially should be False
assert manager.get_frame_skip_flag("camera_001") is False
# Set to True
manager.set_frame_skip_flag("camera_001", True)
assert manager.get_frame_skip_flag("camera_001") is True
# Set back to False
manager.set_frame_skip_flag("camera_001", False)
assert manager.get_frame_skip_flag("camera_001") is False
def test_cleanup_expired_sessions(self):
"""Test cleanup of expired sessions."""
manager = SessionCacheManager()
manager.clear_all()
# Create sessions with short TTL
manager.session_cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))
# Add sessions
for i in range(3):
session_id = f"session_{i}"
manager.create_session(session_id, "camera_001", {"test": "data"})
assert manager.session_cache.size() == 3
# Wait for expiration
time.sleep(0.2)
# Cleanup
expired_count = manager.cleanup_expired_sessions()
assert expired_count == 3
assert manager.session_cache.size() == 0
def test_clear_camera_cache(self):
"""Test clearing cache for specific camera."""
manager = SessionCacheManager()
manager.clear_all()
# Add data for multiple cameras
manager.cache_detection("camera_001", {"class": "car"})
manager.cache_detection("camera_002", {"class": "truck"})
manager.cache_pipeline_result("camera_001", {"status": "success"})
manager.set_latest_frame("camera_001", b"frame1")
manager.set_latest_frame("camera_002", b"frame2")
# Clear camera_001 cache
manager.clear_camera_cache("camera_001")
# camera_001 data should be gone
assert manager.get_cached_detection("camera_001") is None
assert manager.get_cached_pipeline_result("camera_001") is None
assert manager.get_latest_frame("camera_001") is None
# camera_002 data should remain
assert manager.get_cached_detection("camera_002") is not None
assert manager.get_latest_frame("camera_002") is not None
def test_get_cache_statistics(self):
"""Test getting cache statistics."""
manager = SessionCacheManager()
manager.clear_all()
# Add some data to generate statistics
manager.cache_detection("camera_001", {"class": "car"})
manager.cache_pipeline_result("camera_001", {"status": "success"})
manager.create_session("session_123", "camera_001", {"initial": "data"})
# Access data to generate hits/misses
manager.get_cached_detection("camera_001") # Hit
manager.get_cached_detection("camera_002") # Miss
stats = manager.get_cache_statistics()
assert "detection_cache" in stats
assert "pipeline_cache" in stats
assert "session_cache" in stats
assert "total_memory_usage" in stats
detection_stats = stats["detection_cache"]
assert detection_stats["size"] >= 1
assert detection_stats["hits"] >= 1
assert detection_stats["misses"] >= 1
def test_memory_pressure_handling(self):
"""Test handling memory pressure."""
# Create manager with small cache sizes
config = CacheConfig(max_size=3)
manager = SessionCacheManager()
manager.detection_cache = SessionCache(config)
manager.pipeline_cache = SessionCache(config)
manager.session_cache = SessionCache(config)
# Fill caches beyond capacity
for i in range(5):
manager.cache_detection(f"camera_{i}", {"class": "car", "data": "x" * 100})
manager.cache_pipeline_result(f"camera_{i}", {"status": "success", "data": "y" * 100})
manager.create_session(f"session_{i}", f"camera_{i}", {"data": "z" * 100})
# Caches should not exceed max size due to eviction
assert manager.detection_cache.size() <= 3
assert manager.pipeline_cache.size() <= 3
assert manager.session_cache.size() <= 3
def test_concurrent_access_thread_safety(self):
"""Test thread safety of concurrent cache access."""
import threading
import concurrent.futures
manager = SessionCacheManager()
manager.clear_all()
results = []
errors = []
def cache_operation(thread_id):
try:
# Each thread performs multiple cache operations
for i in range(10):
session_id = f"thread_{thread_id}_session_{i}"
# Create session
manager.create_session(session_id, f"camera_{thread_id}", {"thread": thread_id, "index": i})
# Update session
manager.update_session_detection(session_id, {"updated": True})
# Read session
data = manager.get_session_detection(session_id)
if data and data.get("thread") == thread_id:
results.append((thread_id, i))
except Exception as e:
errors.append((thread_id, str(e)))
# Run operations concurrently
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(cache_operation, i) for i in range(5)]
concurrent.futures.wait(futures)
# Should have no errors and successful operations
assert len(errors) == 0
assert len(results) >= 25 # At least some operations should succeed
class TestSessionCacheIntegration:
"""Integration tests for session cache."""
def test_complete_detection_workflow(self):
"""Test complete detection workflow with caching."""
manager = SessionCacheManager()
manager.clear_all()
camera_id = "camera_001"
session_id = str(uuid.uuid4())
# 1. Cache initial detection
detection_data = {
"class": "car",
"confidence": 0.92,
"bbox": [100, 200, 300, 400],
"track_id": 1001,
"timestamp": int(time.time() * 1000)
}
manager.cache_detection(camera_id, detection_data)
# 2. Create session for tracking
initial_session_data = {
"detection_class": detection_data["class"],
"confidence": detection_data["confidence"],
"track_id": detection_data["track_id"]
}
manager.create_session(session_id, camera_id, initial_session_data)
# 3. Cache pipeline processing result
pipeline_result = {
"status": "processing",
"stage": "classification",
"detections": [detection_data],
"branches_completed": [],
"branches_pending": ["car_brand_cls", "car_bodytype_cls"]
}
manager.cache_pipeline_result(camera_id, pipeline_result)
# 4. Update session with classification results
classification_updates = [
{"car_brand": "Toyota", "brand_confidence": 0.87},
{"car_body_type": "Sedan", "bodytype_confidence": 0.82}
]
for update in classification_updates:
manager.update_session_detection(session_id, update)
# 5. Update pipeline result to completed
final_pipeline_result = {
"status": "completed",
"stage": "finished",
"detections": [detection_data],
"branches_completed": ["car_brand_cls", "car_bodytype_cls"],
"branches_pending": [],
"execution_time": 0.25
}
manager.cache_pipeline_result(camera_id, final_pipeline_result)
# 6. Verify all cached data
cached_detection = manager.get_cached_detection(camera_id)
cached_pipeline = manager.get_cached_pipeline_result(camera_id)
cached_session = manager.get_session_detection(session_id)
# Assertions
assert cached_detection["class"] == "car"
assert cached_detection["track_id"] == 1001
assert cached_pipeline["status"] == "completed"
assert len(cached_pipeline["branches_completed"]) == 2
assert cached_session["detection_class"] == "car"
assert cached_session["car_brand"] == "Toyota"
assert cached_session["car_body_type"] == "Sedan"
assert cached_session["brand_confidence"] == 0.87
def test_cache_performance_under_load(self):
"""Test cache performance under load."""
manager = SessionCacheManager()
manager.clear_all()
import time
# Measure performance of cache operations
start_time = time.time()
# Perform many cache operations
for i in range(1000):
camera_id = f"camera_{i % 10}" # 10 different cameras
session_id = f"session_{i}"
# Cache detection
detection_data = {
"class": "car",
"confidence": 0.9 + (i % 10) * 0.01,
"track_id": i,
"bbox": [i % 100, i % 100, (i % 100) + 200, (i % 100) + 200]
}
manager.cache_detection(camera_id, detection_data)
# Create session
manager.create_session(session_id, camera_id, {"index": i})
# Read back (every 10th operation)
if i % 10 == 0:
manager.get_cached_detection(camera_id)
manager.get_session_detection(session_id)
end_time = time.time()
total_time = end_time - start_time
# Should complete in reasonable time (less than 1 second)
assert total_time < 1.0
# Verify cache statistics
stats = manager.get_cache_statistics()
assert stats["detection_cache"]["size"] > 0
assert stats["session_cache"]["size"] > 0
assert stats["detection_cache"]["hits"] > 0
def test_cache_persistence_and_recovery(self):
"""Test cache persistence and recovery (if enabled)."""
# This test would be more meaningful with actual persistence
# For now, test the configuration and structure
persistence_config = CacheConfig(
max_size=100,
enable_persistence=True,
persistence_path="/tmp/detector_cache_test"
)
cache = SessionCache(persistence_config)
# Add some data
session_data = SessionData("session_123", "camera_001", "display_001")
session_data.add_detection_data("main", {"class": "car", "confidence": 0.95})
cache.put("session_123", session_data)
# Verify data exists
assert cache.contains("session_123") is True
# In a real implementation, this would test:
# 1. Saving cache to disk
# 2. Loading cache from disk
# 3. Verifying data integrity after reload

View file

@ -0,0 +1,818 @@
"""
Unit tests for stream management functionality.
"""
import pytest
import asyncio
import threading
import time
from unittest.mock import Mock, AsyncMock, patch, MagicMock
import numpy as np
import cv2
from detector_worker.streams.stream_manager import (
StreamManager,
StreamInfo,
StreamConfig,
StreamReader,
StreamError,
ConnectionError as StreamConnectionError
)
from detector_worker.streams.frame_reader import FrameReader
from detector_worker.core.exceptions import ConfigurationError
class TestStreamConfig:
"""Test stream configuration."""
def test_creation_rtsp(self):
"""Test creating RTSP stream config."""
config = StreamConfig(
stream_url="rtsp://example.com/stream1",
stream_type="rtsp",
target_fps=15,
reconnect_interval=5.0,
max_retries=3
)
assert config.stream_url == "rtsp://example.com/stream1"
assert config.stream_type == "rtsp"
assert config.target_fps == 15
assert config.reconnect_interval == 5.0
assert config.max_retries == 3
def test_creation_http_snapshot(self):
"""Test creating HTTP snapshot config."""
config = StreamConfig(
stream_url="http://example.com/snapshot.jpg",
stream_type="http_snapshot",
snapshot_interval=1.0,
timeout=10.0
)
assert config.stream_url == "http://example.com/snapshot.jpg"
assert config.stream_type == "http_snapshot"
assert config.snapshot_interval == 1.0
assert config.timeout == 10.0
def test_from_dict(self):
"""Test creating config from dictionary."""
config_dict = {
"stream_url": "rtsp://camera.example.com/live",
"stream_type": "rtsp",
"target_fps": 20,
"reconnect_interval": 3.0,
"max_retries": 5,
"crop_region": [100, 200, 300, 400],
"unknown_field": "ignored"
}
config = StreamConfig.from_dict(config_dict)
assert config.stream_url == "rtsp://camera.example.com/live"
assert config.target_fps == 20
assert config.crop_region == [100, 200, 300, 400]
def test_validation(self):
"""Test config validation."""
# Valid config
valid_config = StreamConfig(
stream_url="rtsp://example.com/stream",
stream_type="rtsp"
)
assert valid_config.is_valid() is True
# Invalid config (empty URL)
invalid_config = StreamConfig(
stream_url="",
stream_type="rtsp"
)
assert invalid_config.is_valid() is False
class TestStreamInfo:
"""Test stream information."""
def test_creation(self):
"""Test stream info creation."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo(
stream_id="stream_001",
config=config,
camera_id="camera_001"
)
assert info.stream_id == "stream_001"
assert info.config == config
assert info.camera_id == "camera_001"
assert info.status == "inactive"
assert info.reference_count == 0
assert info.created_at <= time.time()
def test_increment_reference(self):
"""Test incrementing reference count."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo("stream_001", config, "camera_001")
assert info.reference_count == 0
info.increment_reference()
assert info.reference_count == 1
info.increment_reference()
assert info.reference_count == 2
def test_decrement_reference(self):
"""Test decrementing reference count."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo("stream_001", config, "camera_001")
info.reference_count = 3
assert info.decrement_reference() == 2
assert info.reference_count == 2
assert info.decrement_reference() == 1
assert info.decrement_reference() == 0
# Should not go below 0
assert info.decrement_reference() == 0
def test_update_status(self):
"""Test updating stream status."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo("stream_001", config, "camera_001")
info.update_status("connecting")
assert info.status == "connecting"
assert info.last_update <= time.time()
info.update_status("active", frame_count=100)
assert info.status == "active"
assert info.frame_count == 100
def test_get_stats(self):
"""Test getting stream statistics."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
info = StreamInfo("stream_001", config, "camera_001")
info.frame_count = 1000
info.error_count = 5
info.reference_count = 2
stats = info.get_stats()
assert stats["stream_id"] == "stream_001"
assert stats["status"] == "inactive"
assert stats["frame_count"] == 1000
assert stats["error_count"] == 5
assert stats["reference_count"] == 2
assert "uptime" in stats
class TestStreamReader:
"""Test stream reader functionality."""
def test_creation(self):
"""Test stream reader creation."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
assert reader.stream_id == "stream_001"
assert reader.config == config
assert reader.is_running is False
assert reader.latest_frame is None
assert reader.frame_queue.qsize() == 0
@pytest.mark.asyncio
async def test_start_rtsp_stream(self):
"""Test starting RTSP stream."""
config = StreamConfig("rtsp://example.com/stream", "rtsp", target_fps=10)
reader = StreamReader("stream_001", config)
# Mock cv2.VideoCapture
with patch('cv2.VideoCapture') as mock_cap:
mock_cap_instance = Mock()
mock_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, np.zeros((480, 640, 3), dtype=np.uint8))
await reader.start()
assert reader.is_running is True
assert reader.capture is not None
mock_cap.assert_called_once_with("rtsp://example.com/stream")
@pytest.mark.asyncio
async def test_start_rtsp_connection_failure(self):
"""Test RTSP connection failure."""
config = StreamConfig("rtsp://invalid.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
with patch('cv2.VideoCapture') as mock_cap:
mock_cap_instance = Mock()
mock_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = False
with pytest.raises(StreamConnectionError):
await reader.start()
@pytest.mark.asyncio
async def test_start_http_snapshot(self):
"""Test starting HTTP snapshot stream."""
config = StreamConfig("http://example.com/snapshot.jpg", "http_snapshot", snapshot_interval=1.0)
reader = StreamReader("stream_001", config)
with patch('requests.get') as mock_get:
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b"fake_image_data"
mock_get.return_value = mock_response
with patch('cv2.imdecode') as mock_decode:
mock_decode.return_value = np.zeros((480, 640, 3), dtype=np.uint8)
await reader.start()
assert reader.is_running is True
mock_get.assert_called_once()
@pytest.mark.asyncio
async def test_stop_stream(self):
"""Test stopping stream."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
# Simulate running state
reader.is_running = True
reader.capture = Mock()
reader.capture.release = Mock()
reader._reader_task = Mock()
reader._reader_task.cancel = Mock()
await reader.stop()
assert reader.is_running is False
reader.capture.release.assert_called_once()
reader._reader_task.cancel.assert_called_once()
def test_get_latest_frame(self):
"""Test getting latest frame."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
reader.latest_frame = test_frame
frame = reader.get_latest_frame()
assert np.array_equal(frame, test_frame)
def test_get_frame_from_queue(self):
"""Test getting frame from queue."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
reader.frame_queue.put(test_frame)
frame = reader.get_frame(timeout=0.1)
assert np.array_equal(frame, test_frame)
def test_get_frame_timeout(self):
"""Test getting frame with timeout."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
# Queue is empty, should timeout
frame = reader.get_frame(timeout=0.1)
assert frame is None
def test_get_stats(self):
"""Test getting reader statistics."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("stream_001", config)
reader.frame_count = 500
reader.error_count = 2
stats = reader.get_stats()
assert stats["stream_id"] == "stream_001"
assert stats["frame_count"] == 500
assert stats["error_count"] == 2
assert stats["is_running"] is False
class TestStreamManager:
"""Test stream manager functionality."""
def test_initialization(self):
"""Test stream manager initialization."""
manager = StreamManager()
assert len(manager.streams) == 0
assert len(manager.readers) == 0
assert manager.max_streams == 10
assert manager.default_timeout == 30.0
def test_initialization_with_config(self):
"""Test initialization with custom configuration."""
config = {
"max_streams": 20,
"default_timeout": 60.0,
"frame_buffer_size": 5
}
manager = StreamManager(config)
assert manager.max_streams == 20
assert manager.default_timeout == 60.0
assert manager.frame_buffer_size == 5
@pytest.mark.asyncio
async def test_create_stream_new(self):
"""Test creating new stream."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
stream_info = await manager.create_stream("camera_001", config, "sub_001")
assert "camera_001" in manager.streams
assert manager.streams["camera_001"].reference_count == 1
assert manager.streams["camera_001"].camera_id == "camera_001"
@pytest.mark.asyncio
async def test_create_stream_shared(self):
"""Test creating shared stream (same URL)."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
# Create first stream
stream_info1 = await manager.create_stream("camera_001", config, "sub_001")
# Create second stream with same URL
stream_info2 = await manager.create_stream("camera_001", config, "sub_002")
assert stream_info1 == stream_info2 # Should be same stream
assert manager.streams["camera_001"].reference_count == 2
@pytest.mark.asyncio
async def test_create_stream_max_limit(self):
"""Test creating stream when at max limit."""
manager = StreamManager({"max_streams": 1})
config1 = StreamConfig("rtsp://example.com/stream1", "rtsp")
config2 = StreamConfig("rtsp://example.com/stream2", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
# Create first stream (should succeed)
await manager.create_stream("camera_001", config1, "sub_001")
# Try to create second stream (should fail)
with pytest.raises(StreamError) as exc_info:
await manager.create_stream("camera_002", config2, "sub_002")
assert "maximum number of streams" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_remove_stream_single_reference(self):
"""Test removing stream with single reference."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
with patch.object(StreamReader, 'stop', new_callable=AsyncMock):
# Create stream
await manager.create_stream("camera_001", config, "sub_001")
# Remove stream
removed = await manager.remove_stream("camera_001", "sub_001")
assert removed is True
assert "camera_001" not in manager.streams
@pytest.mark.asyncio
async def test_remove_stream_multiple_references(self):
"""Test removing stream with multiple references."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
# Create shared stream
await manager.create_stream("camera_001", config, "sub_001")
await manager.create_stream("camera_001", config, "sub_002")
assert manager.streams["camera_001"].reference_count == 2
# Remove one reference
removed = await manager.remove_stream("camera_001", "sub_001")
assert removed is True
assert "camera_001" in manager.streams # Still exists
assert manager.streams["camera_001"].reference_count == 1
def test_get_stream_info(self):
"""Test getting stream information."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
stream_info = StreamInfo("camera_001", config, "camera_001")
manager.streams["camera_001"] = stream_info
retrieved_info = manager.get_stream_info("camera_001")
assert retrieved_info == stream_info
def test_get_nonexistent_stream_info(self):
"""Test getting info for non-existent stream."""
manager = StreamManager()
info = manager.get_stream_info("nonexistent_camera")
assert info is None
def test_get_latest_frame(self):
"""Test getting latest frame from stream."""
manager = StreamManager()
# Create mock reader
mock_reader = Mock()
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
mock_reader.get_latest_frame.return_value = test_frame
manager.readers["camera_001"] = mock_reader
frame = manager.get_latest_frame("camera_001")
assert np.array_equal(frame, test_frame)
mock_reader.get_latest_frame.assert_called_once()
def test_get_frame_from_nonexistent_stream(self):
"""Test getting frame from non-existent stream."""
manager = StreamManager()
frame = manager.get_latest_frame("nonexistent_camera")
assert frame is None
def test_list_active_streams(self):
"""Test listing active streams."""
manager = StreamManager()
# Add streams
config1 = StreamConfig("rtsp://example.com/stream1", "rtsp")
config2 = StreamConfig("rtsp://example.com/stream2", "rtsp")
stream1 = StreamInfo("camera_001", config1, "camera_001")
stream1.update_status("active")
stream2 = StreamInfo("camera_002", config2, "camera_002")
stream2.update_status("inactive")
manager.streams["camera_001"] = stream1
manager.streams["camera_002"] = stream2
active_streams = manager.list_active_streams()
assert len(active_streams) == 1
assert active_streams[0]["camera_id"] == "camera_001"
assert active_streams[0]["status"] == "active"
@pytest.mark.asyncio
async def test_stop_all_streams(self):
"""Test stopping all streams."""
manager = StreamManager()
# Add mock streams
mock_reader1 = Mock()
mock_reader1.stop = AsyncMock()
mock_reader2 = Mock()
mock_reader2.stop = AsyncMock()
manager.readers["camera_001"] = mock_reader1
manager.readers["camera_002"] = mock_reader2
stopped_count = await manager.stop_all_streams()
assert stopped_count == 2
mock_reader1.stop.assert_called_once()
mock_reader2.stop.assert_called_once()
assert len(manager.readers) == 0
assert len(manager.streams) == 0
def test_get_stream_statistics(self):
"""Test getting stream statistics."""
manager = StreamManager()
# Add streams
config = StreamConfig("rtsp://example.com/stream", "rtsp")
stream1 = StreamInfo("camera_001", config, "camera_001")
stream1.update_status("active")
stream1.frame_count = 1000
stream1.reference_count = 2
stream2 = StreamInfo("camera_002", config, "camera_002")
stream2.update_status("error")
stream2.error_count = 5
manager.streams["camera_001"] = stream1
manager.streams["camera_002"] = stream2
stats = manager.get_stream_statistics()
assert stats["total_streams"] == 2
assert stats["active_streams"] == 1
assert stats["error_streams"] == 1
assert stats["total_references"] == 2
assert "status_breakdown" in stats
@pytest.mark.asyncio
async def test_reconnect_stream(self):
"""Test reconnecting failed stream."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/stream", "rtsp")
stream_info = StreamInfo("camera_001", config, "camera_001")
stream_info.update_status("error")
manager.streams["camera_001"] = stream_info
# Mock reader
mock_reader = Mock()
mock_reader.start = AsyncMock()
mock_reader.stop = AsyncMock()
manager.readers["camera_001"] = mock_reader
result = await manager.reconnect_stream("camera_001")
assert result is True
mock_reader.stop.assert_called_once()
mock_reader.start.assert_called_once()
assert stream_info.status != "error"
@pytest.mark.asyncio
async def test_health_check_streams(self):
"""Test health check of all streams."""
manager = StreamManager()
# Add streams with different states
config = StreamConfig("rtsp://example.com/stream", "rtsp")
stream1 = StreamInfo("camera_001", config, "camera_001")
stream1.update_status("active")
stream2 = StreamInfo("camera_002", config, "camera_002")
stream2.update_status("error")
manager.streams["camera_001"] = stream1
manager.streams["camera_002"] = stream2
# Mock readers
mock_reader1 = Mock()
mock_reader1.is_running = True
mock_reader2 = Mock()
mock_reader2.is_running = False
manager.readers["camera_001"] = mock_reader1
manager.readers["camera_002"] = mock_reader2
health_report = await manager.health_check()
assert health_report["total_streams"] == 2
assert health_report["healthy_streams"] == 1
assert health_report["unhealthy_streams"] == 1
assert len(health_report["unhealthy_stream_ids"]) == 1
class TestStreamManagerIntegration:
"""Integration tests for stream manager."""
@pytest.mark.asyncio
async def test_multiple_subscribers_same_stream(self):
"""Test multiple subscribers to same stream."""
manager = StreamManager()
config = StreamConfig("rtsp://example.com/shared_stream", "rtsp")
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
# Multiple subscribers to same stream
stream1 = await manager.create_stream("camera_001", config, "sub_001")
stream2 = await manager.create_stream("camera_001", config, "sub_002")
stream3 = await manager.create_stream("camera_001", config, "sub_003")
# All should reference same stream
assert stream1 == stream2 == stream3
assert manager.streams["camera_001"].reference_count == 3
assert len(manager.readers) == 1 # Only one actual reader
# Remove subscribers one by one
with patch.object(StreamReader, 'stop', new_callable=AsyncMock) as mock_stop:
await manager.remove_stream("camera_001", "sub_001") # ref_count = 2
await manager.remove_stream("camera_001", "sub_002") # ref_count = 1
# Stream should still exist
assert "camera_001" in manager.streams
mock_stop.assert_not_called()
await manager.remove_stream("camera_001", "sub_003") # ref_count = 0
# Now stream should be stopped and removed
assert "camera_001" not in manager.streams
mock_stop.assert_called_once()
@pytest.mark.asyncio
async def test_stream_failure_and_recovery(self):
"""Test stream failure and recovery workflow."""
manager = StreamManager()
config = StreamConfig("rtsp://unreliable.com/stream", "rtsp", max_retries=2)
# Mock reader that fails initially then succeeds
with patch.object(StreamReader, 'start', new_callable=AsyncMock) as mock_start:
mock_start.side_effect = [
StreamConnectionError("Connection failed"), # First attempt fails
None # Second attempt succeeds
]
# First attempt should fail
with pytest.raises(StreamConnectionError):
await manager.create_stream("camera_001", config, "sub_001")
# Retry should succeed
stream_info = await manager.create_stream("camera_001", config, "sub_001")
assert stream_info is not None
assert mock_start.call_count == 2
@pytest.mark.asyncio
async def test_concurrent_stream_operations(self):
"""Test concurrent stream operations."""
manager = StreamManager()
configs = [
StreamConfig(f"rtsp://example.com/stream{i}", "rtsp")
for i in range(5)
]
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
with patch.object(StreamReader, 'stop', new_callable=AsyncMock):
# Create streams concurrently
create_tasks = [
manager.create_stream(f"camera_{i}", configs[i], f"sub_{i}")
for i in range(5)
]
results = await asyncio.gather(*create_tasks)
assert len(results) == 5
assert len(manager.streams) == 5
# Remove streams concurrently
remove_tasks = [
manager.remove_stream(f"camera_{i}", f"sub_{i}")
for i in range(5)
]
remove_results = await asyncio.gather(*remove_tasks)
assert all(remove_results)
assert len(manager.streams) == 0
@pytest.mark.asyncio
async def test_memory_management_large_scale(self):
"""Test memory management with many streams."""
manager = StreamManager({"max_streams": 50})
# Create many streams
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
for i in range(30):
config = StreamConfig(f"rtsp://example.com/stream{i}", "rtsp")
await manager.create_stream(f"camera_{i}", config, f"sub_{i}")
# Verify memory usage is reasonable
stats = manager.get_stream_statistics()
assert stats["total_streams"] == 30
assert stats["active_streams"] <= 30
# Test bulk cleanup
with patch.object(StreamReader, 'stop', new_callable=AsyncMock):
stopped_count = await manager.stop_all_streams()
assert stopped_count == 30
assert len(manager.streams) == 0
assert len(manager.readers) == 0
class TestFrameReaderIntegration:
"""Integration tests for frame reader."""
@pytest.mark.asyncio
async def test_rtsp_frame_processing(self):
"""Test RTSP frame processing pipeline."""
config = StreamConfig(
stream_url="rtsp://example.com/stream",
stream_type="rtsp",
target_fps=10,
crop_region=[100, 100, 400, 300]
)
reader = StreamReader("test_stream", config)
# Mock cv2.VideoCapture
with patch('cv2.VideoCapture') as mock_cap:
mock_cap_instance = Mock()
mock_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
# Mock frame sequence
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
mock_cap_instance.read.side_effect = [
(True, test_frame), # First frame
(True, test_frame * 0.8), # Second frame
(False, None), # Connection lost
(True, test_frame * 1.2), # Reconnected
]
await reader.start()
# Let reader process some frames
await asyncio.sleep(0.1)
# Verify frame processing
latest_frame = reader.get_latest_frame()
assert latest_frame is not None
assert latest_frame.shape == (480, 640, 3)
await reader.stop()
@pytest.mark.asyncio
async def test_http_snapshot_processing(self):
"""Test HTTP snapshot processing."""
config = StreamConfig(
stream_url="http://camera.example.com/snapshot.jpg",
stream_type="http_snapshot",
snapshot_interval=0.5,
timeout=5.0
)
reader = StreamReader("snapshot_stream", config)
with patch('requests.get') as mock_get:
# Mock HTTP responses
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b"fake_jpeg_data"
mock_get.return_value = mock_response
with patch('cv2.imdecode') as mock_decode:
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 200
mock_decode.return_value = test_frame
await reader.start()
# Wait for snapshot capture
await asyncio.sleep(0.6)
# Verify snapshot processing
latest_frame = reader.get_latest_frame()
assert latest_frame is not None
assert np.array_equal(latest_frame, test_frame)
await reader.stop()
def test_frame_queue_management(self):
"""Test frame queue management and buffering."""
config = StreamConfig("rtsp://example.com/stream", "rtsp")
reader = StreamReader("queue_test", config, frame_buffer_size=3)
# Add frames to queue
frames = [
np.ones((100, 100, 3), dtype=np.uint8) * i
for i in range(50, 250, 50) # 4 different frames
]
for frame in frames[:3]: # Fill buffer
reader._add_frame_to_queue(frame)
assert reader.frame_queue.qsize() == 3
# Add one more (should drop oldest)
reader._add_frame_to_queue(frames[3])
assert reader.frame_queue.qsize() == 3
# Verify frame order (oldest should be dropped)
retrieved_frames = []
while not reader.frame_queue.empty():
retrieved_frames.append(reader.get_frame(timeout=0.1))
assert len(retrieved_frames) == 3
# First frame should have been dropped, so we should have frames 1,2,3
assert not np.array_equal(retrieved_frames[0], frames[0])