Refactor: PHASE 8: Testing & Integration

parent af34f4fd08
commit 9e8c6804a7

32 changed files with 17128 additions and 0 deletions
856 tests/unit/communication/test_websocket_handler.py (new file)

@@ -0,0 +1,856 @@
"""
|
||||
Unit tests for WebSocket handling functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||
from fastapi.websockets import WebSocket, WebSocketDisconnect
|
||||
import uuid
|
||||
|
||||
from detector_worker.communication.websocket_handler import (
|
||||
WebSocketHandler,
|
||||
ConnectionManager,
|
||||
WebSocketConnection,
|
||||
MessageHandler,
|
||||
WebSocketError,
|
||||
ConnectionError as WSConnectionError
|
||||
)
|
||||
from detector_worker.communication.message_processor import MessageType
|
||||
from detector_worker.core.exceptions import MessageProcessingError
|
||||
|
||||
|
||||
class TestWebSocketConnection:
|
||||
"""Test WebSocket connection wrapper."""
|
||||
|
||||
def test_creation(self, mock_websocket):
|
||||
"""Test WebSocket connection creation."""
|
||||
connection = WebSocketConnection(mock_websocket, "client_001")
|
||||
|
||||
assert connection.websocket == mock_websocket
|
||||
assert connection.client_id == "client_001"
|
||||
assert connection.is_connected is False
|
||||
assert connection.connected_at is None
|
||||
assert connection.last_ping is None
|
||||
assert connection.subscription_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_connection(self, mock_websocket):
|
||||
"""Test accepting WebSocket connection."""
|
||||
connection = WebSocketConnection(mock_websocket, "client_001")
|
||||
|
||||
mock_websocket.accept = AsyncMock()
|
||||
|
||||
await connection.accept()
|
||||
|
||||
assert connection.is_connected is True
|
||||
assert connection.connected_at is not None
|
||||
mock_websocket.accept.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_json(self, mock_websocket):
|
||||
"""Test sending JSON message."""
|
||||
connection = WebSocketConnection(mock_websocket, "client_001")
|
||||
connection.is_connected = True
|
||||
|
||||
mock_websocket.send_json = AsyncMock()
|
||||
|
||||
message = {"type": "test", "data": "hello"}
|
||||
await connection.send_message(message)
|
||||
|
||||
mock_websocket.send_json.assert_called_once_with(message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_text(self, mock_websocket):
|
||||
"""Test sending text message."""
|
||||
connection = WebSocketConnection(mock_websocket, "client_001")
|
||||
connection.is_connected = True
|
||||
|
||||
mock_websocket.send_text = AsyncMock()
|
||||
|
||||
await connection.send_message("hello world")
|
||||
|
||||
mock_websocket.send_text.assert_called_once_with("hello world")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_not_connected(self, mock_websocket):
|
||||
"""Test sending message when not connected."""
|
||||
connection = WebSocketConnection(mock_websocket, "client_001")
|
||||
# Don't set is_connected = True
|
||||
|
||||
with pytest.raises(WebSocketError) as exc_info:
|
||||
await connection.send_message({"type": "test"})
|
||||
|
||||
assert "not connected" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_message_json(self, mock_websocket):
|
||||
"""Test receiving JSON message."""
|
||||
connection = WebSocketConnection(mock_websocket, "client_001")
|
||||
connection.is_connected = True
|
||||
|
||||
mock_websocket.receive_json = AsyncMock(return_value={"type": "test", "data": "received"})
|
||||
|
||||
message = await connection.receive_message()
|
||||
|
||||
assert message == {"type": "test", "data": "received"}
|
||||
mock_websocket.receive_json.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_message_text(self, mock_websocket):
|
||||
"""Test receiving text message."""
|
||||
connection = WebSocketConnection(mock_websocket, "client_001")
|
||||
connection.is_connected = True
|
||||
|
||||
# Mock receive_json to fail, then receive_text to succeed
|
||||
mock_websocket.receive_json = AsyncMock(side_effect=json.JSONDecodeError("Invalid JSON", "", 0))
|
||||
mock_websocket.receive_text = AsyncMock(return_value="plain text message")
|
||||
|
||||
message = await connection.receive_message()
|
||||
|
||||
assert message == "plain text message"
|
||||
mock_websocket.receive_text.assert_called_once()
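        # What this test implies about WebSocketConnection.receive_message (a
        # hedged sketch, not the actual implementation): try JSON first, and
        # fall back to plain text when decoding fails, roughly:
        #
        #     try:
        #         return await self.websocket.receive_json()
        #     except json.JSONDecodeError:
        #         return await self.websocket.receive_text()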

    @pytest.mark.asyncio
    async def test_ping_pong(self, mock_websocket):
        """Test ping/pong functionality."""
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True

        mock_websocket.ping = AsyncMock()

        await connection.ping()

        assert connection.last_ping is not None
        mock_websocket.ping.assert_called_once()

    @pytest.mark.asyncio
    async def test_close_connection(self, mock_websocket):
        """Test closing connection."""
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True

        mock_websocket.close = AsyncMock()

        await connection.close(code=1000, reason="Normal closure")

        assert connection.is_connected is False
        mock_websocket.close.assert_called_once_with(code=1000, reason="Normal closure")

    def test_connection_info(self, mock_websocket):
        """Test getting connection information."""
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        connection.subscription_id = "sub_123"

        info = connection.get_connection_info()

        assert info["client_id"] == "client_001"
        assert info["is_connected"] is True
        assert info["subscription_id"] == "sub_123"
        assert "connected_at" in info
        assert "last_ping" in info


class TestConnectionManager:
    """Test WebSocket connection management."""

    def test_initialization(self):
        """Test connection manager initialization."""
        manager = ConnectionManager()

        assert len(manager.connections) == 0
        assert len(manager.subscriptions) == 0
        assert manager.max_connections == 100

    @pytest.mark.asyncio
    async def test_add_connection(self, mock_websocket):
        """Test adding a connection."""
        manager = ConnectionManager()

        client_id = "client_001"
        connection = await manager.add_connection(mock_websocket, client_id)

        assert connection.client_id == client_id
        assert client_id in manager.connections
        assert manager.get_connection_count() == 1

    @pytest.mark.asyncio
    async def test_remove_connection(self, mock_websocket):
        """Test removing a connection."""
        manager = ConnectionManager()

        client_id = "client_001"
        await manager.add_connection(mock_websocket, client_id)

        assert client_id in manager.connections

        removed_connection = await manager.remove_connection(client_id)

        assert removed_connection is not None
        assert removed_connection.client_id == client_id
        assert client_id not in manager.connections
        assert manager.get_connection_count() == 0

    def test_get_connection(self, mock_websocket):
        """Test getting a connection."""
        manager = ConnectionManager()

        client_id = "client_001"
        # Manually add connection for testing
        connection = WebSocketConnection(mock_websocket, client_id)
        manager.connections[client_id] = connection

        retrieved_connection = manager.get_connection(client_id)

        assert retrieved_connection == connection
        assert retrieved_connection.client_id == client_id

    def test_get_nonexistent_connection(self):
        """Test getting non-existent connection."""
        manager = ConnectionManager()

        connection = manager.get_connection("nonexistent_client")

        assert connection is None

    @pytest.mark.asyncio
    async def test_broadcast_message(self, mock_websocket):
        """Test broadcasting message to all connections."""
        manager = ConnectionManager()

        # Add multiple connections
        connections = []
        for i in range(3):
            client_id = f"client_{i}"
            ws = Mock()
            ws.send_json = AsyncMock()
            connection = WebSocketConnection(ws, client_id)
            connection.is_connected = True
            manager.connections[client_id] = connection
            connections.append(connection)

        message = {"type": "broadcast", "data": "hello all"}

        await manager.broadcast(message)

        # All connections should have received the message
        for connection in connections:
            connection.websocket.send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_broadcast_to_subscription(self, mock_websocket):
        """Test broadcasting to specific subscription."""
        manager = ConnectionManager()

        # Add connections with different subscriptions
        subscription_id = "camera_001"

        # Connection with target subscription
        ws1 = Mock()
        ws1.send_json = AsyncMock()
        connection1 = WebSocketConnection(ws1, "client_001")
        connection1.is_connected = True
        connection1.subscription_id = subscription_id
        manager.connections["client_001"] = connection1
        manager.subscriptions[subscription_id] = {"client_001"}

        # Connection with different subscription
        ws2 = Mock()
        ws2.send_json = AsyncMock()
        connection2 = WebSocketConnection(ws2, "client_002")
        connection2.is_connected = True
        connection2.subscription_id = "camera_002"
        manager.connections["client_002"] = connection2
        manager.subscriptions["camera_002"] = {"client_002"}

        message = {"type": "detection", "data": "camera detection"}

        await manager.broadcast_to_subscription(subscription_id, message)

        # Only connection1 should have received the message
        ws1.send_json.assert_called_once_with(message)
        ws2.send_json.assert_not_called()

    def test_add_subscription(self):
        """Test adding subscription mapping."""
        manager = ConnectionManager()

        client_id = "client_001"
        subscription_id = "camera_001"

        manager.add_subscription(client_id, subscription_id)

        assert subscription_id in manager.subscriptions
        assert client_id in manager.subscriptions[subscription_id]

    def test_remove_subscription(self):
        """Test removing subscription mapping."""
        manager = ConnectionManager()

        client_id = "client_001"
        subscription_id = "camera_001"

        # Add subscription first
        manager.add_subscription(client_id, subscription_id)
        assert client_id in manager.subscriptions[subscription_id]

        # Remove subscription
        manager.remove_subscription(client_id, subscription_id)

        assert client_id not in manager.subscriptions.get(subscription_id, set())

    def test_get_subscription_clients(self):
        """Test getting clients for a subscription."""
        manager = ConnectionManager()

        subscription_id = "camera_001"
        clients = ["client_001", "client_002", "client_003"]

        for client_id in clients:
            manager.add_subscription(client_id, subscription_id)

        subscription_clients = manager.get_subscription_clients(subscription_id)

        assert subscription_clients == set(clients)

    def test_get_client_subscriptions(self):
        """Test getting subscriptions for a client."""
        manager = ConnectionManager()

        client_id = "client_001"
        subscriptions = ["camera_001", "camera_002", "camera_003"]

        for subscription_id in subscriptions:
            manager.add_subscription(client_id, subscription_id)

        client_subscriptions = manager.get_client_subscriptions(client_id)

        assert client_subscriptions == set(subscriptions)

    @pytest.mark.asyncio
    async def test_cleanup_disconnected_connections(self):
        """Test cleanup of disconnected connections."""
        manager = ConnectionManager()

        # Add connected and disconnected connections
        ws1 = Mock()
        connection1 = WebSocketConnection(ws1, "client_001")
        connection1.is_connected = True
        manager.connections["client_001"] = connection1

        ws2 = Mock()
        connection2 = WebSocketConnection(ws2, "client_002")
        connection2.is_connected = False  # Disconnected
        manager.connections["client_002"] = connection2

        # Add subscriptions
        manager.add_subscription("client_001", "camera_001")
        manager.add_subscription("client_002", "camera_002")

        cleaned_count = await manager.cleanup_disconnected()

        assert cleaned_count == 1
        assert "client_001" in manager.connections  # Still connected
        assert "client_002" not in manager.connections  # Cleaned up

        # Subscriptions should also be cleaned up
        assert manager.get_client_subscriptions("client_002") == set()

    def test_get_connection_stats(self):
        """Test getting connection statistics."""
        manager = ConnectionManager()

        # Add various connections and subscriptions
        for i in range(3):
            client_id = f"client_{i}"
            ws = Mock()
            connection = WebSocketConnection(ws, client_id)
            connection.is_connected = i < 2  # First 2 connected, last one disconnected
            manager.connections[client_id] = connection

            if i < 2:  # Add subscriptions for connected clients
                manager.add_subscription(client_id, f"camera_{i}")

        stats = manager.get_connection_stats()

        assert stats["total_connections"] == 3
        assert stats["active_connections"] == 2
        assert stats["total_subscriptions"] == 2
        assert "uptime" in stats


class TestMessageHandler:
    """Test message handling functionality."""

    def test_creation(self):
        """Test message handler creation."""
        mock_processor = Mock()
        handler = MessageHandler(mock_processor)

        assert handler.message_processor == mock_processor
        assert handler.connection_manager is None

    def test_set_connection_manager(self):
        """Test setting connection manager."""
        mock_processor = Mock()
        mock_manager = Mock()
        handler = MessageHandler(mock_processor)

        handler.set_connection_manager(mock_manager)

        assert handler.connection_manager == mock_manager

    @pytest.mark.asyncio
    async def test_handle_message_success(self, mock_websocket):
        """Test successful message handling."""
        mock_processor = Mock()
        mock_processor.process_message = AsyncMock(return_value={"type": "response", "status": "success"})

        handler = MessageHandler(mock_processor)
        connection = WebSocketConnection(mock_websocket, "client_001")

        message = {"type": "subscribe", "payload": {"camera_id": "camera_001"}}

        response = await handler.handle_message(connection, message)

        assert response["status"] == "success"
        mock_processor.process_message.assert_called_once_with(message, "client_001")

    @pytest.mark.asyncio
    async def test_handle_message_processing_error(self, mock_websocket):
        """Test message handling with processing error."""
        mock_processor = Mock()
        mock_processor.process_message = AsyncMock(side_effect=MessageProcessingError("Invalid message"))

        handler = MessageHandler(mock_processor)
        connection = WebSocketConnection(mock_websocket, "client_001")

        message = {"type": "invalid", "payload": {}}

        response = await handler.handle_message(connection, message)

        assert response["type"] == "error"
        assert "Invalid message" in response["message"]

    @pytest.mark.asyncio
    async def test_handle_message_unexpected_error(self, mock_websocket):
        """Test message handling with unexpected error."""
        mock_processor = Mock()
        mock_processor.process_message = AsyncMock(side_effect=Exception("Unexpected error"))

        handler = MessageHandler(mock_processor)
        connection = WebSocketConnection(mock_websocket, "client_001")

        message = {"type": "test", "payload": {}}

        response = await handler.handle_message(connection, message)

        assert response["type"] == "error"
        assert "internal error" in response["message"].lower()

    @pytest.mark.asyncio
    async def test_send_response(self, mock_websocket):
        """Test sending response to client."""
        mock_processor = Mock()
        handler = MessageHandler(mock_processor)

        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        mock_websocket.send_json = AsyncMock()

        response = {"type": "response", "data": "test response"}

        await handler.send_response(connection, response)

        mock_websocket.send_json.assert_called_once_with(response)

    @pytest.mark.asyncio
    async def test_send_error_response(self, mock_websocket):
        """Test sending error response."""
        mock_processor = Mock()
        handler = MessageHandler(mock_processor)

        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        mock_websocket.send_json = AsyncMock()

        await handler.send_error_response(connection, "Test error message", "TEST_ERROR")

        mock_websocket.send_json.assert_called_once()
        call_args = mock_websocket.send_json.call_args[0][0]

        assert call_args["type"] == "error"
        assert call_args["message"] == "Test error message"
        assert call_args["error_code"] == "TEST_ERROR"


class TestWebSocketHandler:
    """Test main WebSocket handler functionality."""

    def test_initialization(self):
        """Test WebSocket handler initialization."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        assert isinstance(handler.connection_manager, ConnectionManager)
        assert isinstance(handler.message_handler, MessageHandler)
        assert handler.message_handler.connection_manager == handler.connection_manager
        assert handler.heartbeat_interval == 30.0
        assert handler.max_connections == 100

    def test_initialization_with_config(self):
        """Test initialization with custom configuration."""
        mock_processor = Mock()
        config = {
            "heartbeat_interval": 60.0,
            "max_connections": 200,
            "connection_timeout": 300.0
        }

        handler = WebSocketHandler(mock_processor, config)

        assert handler.heartbeat_interval == 60.0
        assert handler.max_connections == 200
        assert handler.connection_timeout == 300.0

    @pytest.mark.asyncio
    async def test_handle_websocket_connection(self, mock_websocket):
        """Test handling WebSocket connection."""
        mock_processor = Mock()
        mock_processor.process_message = AsyncMock(return_value={"type": "ack", "status": "success"})

        handler = WebSocketHandler(mock_processor)

        # Mock WebSocket behavior
        mock_websocket.accept = AsyncMock()
        mock_websocket.receive_json = AsyncMock(side_effect=[
            {"type": "subscribe", "payload": {"camera_id": "camera_001"}},
            WebSocketDisconnect()  # Simulate disconnection
        ])
        mock_websocket.send_json = AsyncMock()

        client_id = "test_client_001"

        # Handle connection (should not raise exception)
        await handler.handle_websocket(mock_websocket, client_id)

        # Verify connection was accepted
        mock_websocket.accept.assert_called_once()

        # Verify message was processed
        mock_processor.process_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_handle_websocket_max_connections(self, mock_websocket):
        """Test handling max connections limit."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor, {"max_connections": 1})

        # Add one connection to reach limit
        client1_ws = Mock()
        connection1 = WebSocketConnection(client1_ws, "client_001")
        handler.connection_manager.connections["client_001"] = connection1

        mock_websocket.close = AsyncMock()

        # Try to add second connection
        await handler.handle_websocket(mock_websocket, "client_002")

        # Should close connection due to limit
        mock_websocket.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_broadcast_message(self):
        """Test broadcasting message to all connections."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        # Mock connection manager
        handler.connection_manager.broadcast = AsyncMock()

        message = {"type": "system", "data": "Server maintenance in 10 minutes"}

        await handler.broadcast_message(message)

        handler.connection_manager.broadcast.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_send_to_client(self):
        """Test sending message to specific client."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        # Create mock connection
        mock_websocket = Mock()
        mock_websocket.send_json = AsyncMock()
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True

        handler.connection_manager.connections["client_001"] = connection

        message = {"type": "notification", "data": "Personal message"}

        result = await handler.send_to_client("client_001", message)

        assert result is True
        mock_websocket.send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_send_to_nonexistent_client(self):
        """Test sending message to non-existent client."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        message = {"type": "notification", "data": "Message"}

        result = await handler.send_to_client("nonexistent_client", message)

        assert result is False

    @pytest.mark.asyncio
    async def test_send_to_subscription(self):
        """Test sending message to subscription."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        # Mock connection manager
        handler.connection_manager.broadcast_to_subscription = AsyncMock()

        subscription_id = "camera_001"
        message = {"type": "detection", "data": {"class": "car", "confidence": 0.95}}

        await handler.send_to_subscription(subscription_id, message)

        handler.connection_manager.broadcast_to_subscription.assert_called_once_with(subscription_id, message)

    @pytest.mark.asyncio
    async def test_start_heartbeat_task(self):
        """Test starting heartbeat task."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor, {"heartbeat_interval": 0.1})

        # Mock connection with ping capability
        mock_websocket = Mock()
        mock_websocket.ping = AsyncMock()
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        handler.connection_manager.connections["client_001"] = connection

        # Start heartbeat task
        heartbeat_task = asyncio.create_task(handler._heartbeat_loop())

        # Let it run briefly
        await asyncio.sleep(0.2)

        # Cancel task
        heartbeat_task.cancel()

        try:
            await heartbeat_task
        except asyncio.CancelledError:
            pass

        # Should have sent at least one ping
        assert mock_websocket.ping.called
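        # The loop under test is private (_heartbeat_loop); a minimal sketch
        # of the behaviour these assertions imply, assuming the attribute
        # names used elsewhere in this file:
        #
        #     while True:
        #         await asyncio.sleep(self.heartbeat_interval)
        #         for conn in list(self.connection_manager.connections.values()):
        #             if conn.is_connected:
        #                 await conn.ping()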

    def test_get_connection_stats(self):
        """Test getting connection statistics."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        # Add some mock connections
        for i in range(3):
            client_id = f"client_{i}"
            ws = Mock()
            connection = WebSocketConnection(ws, client_id)
            connection.is_connected = True
            handler.connection_manager.connections[client_id] = connection

        stats = handler.get_connection_stats()

        assert stats["total_connections"] == 3
        assert stats["active_connections"] == 3

    def test_get_client_info(self):
        """Test getting client information."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        # Add mock connection
        mock_websocket = Mock()
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        connection.subscription_id = "camera_001"
        handler.connection_manager.connections["client_001"] = connection

        info = handler.get_client_info("client_001")

        assert info is not None
        assert info["client_id"] == "client_001"
        assert info["is_connected"] is True
        assert info["subscription_id"] == "camera_001"

    @pytest.mark.asyncio
    async def test_disconnect_client(self):
        """Test disconnecting specific client."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        # Add mock connection
        mock_websocket = Mock()
        mock_websocket.close = AsyncMock()
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        handler.connection_manager.connections["client_001"] = connection

        result = await handler.disconnect_client("client_001", code=1000, reason="Admin disconnect")

        assert result is True
        mock_websocket.close.assert_called_once_with(code=1000, reason="Admin disconnect")

    @pytest.mark.asyncio
    async def test_cleanup_connections(self):
        """Test cleanup of disconnected connections."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        # Mock connection manager cleanup
        handler.connection_manager.cleanup_disconnected = AsyncMock(return_value=2)

        cleaned_count = await handler.cleanup_connections()

        assert cleaned_count == 2
        handler.connection_manager.cleanup_disconnected.assert_called_once()


class TestWebSocketHandlerIntegration:
    """Integration tests for WebSocket handler."""

    @pytest.mark.asyncio
    async def test_complete_subscription_workflow(self, mock_websocket):
        """Test complete subscription workflow."""
        mock_processor = Mock()

        # Mock processor responses
        mock_processor.process_message = AsyncMock(side_effect=[
            {"type": "subscribeAck", "status": "success", "subscription_id": "camera_001"},
            {"type": "unsubscribeAck", "status": "success"}
        ])

        handler = WebSocketHandler(mock_processor)

        # Mock WebSocket behavior
        mock_websocket.accept = AsyncMock()
        mock_websocket.send_json = AsyncMock()
        mock_websocket.receive_json = AsyncMock(side_effect=[
            {"type": "subscribe", "payload": {"camera_id": "camera_001", "rtsp_url": "rtsp://example.com"}},
            {"type": "unsubscribe", "payload": {"subscription_id": "camera_001"}},
            WebSocketDisconnect()
        ])

        client_id = "test_client"

        # Handle complete workflow
        await handler.handle_websocket(mock_websocket, client_id)

        # Verify both messages were processed
        assert mock_processor.process_message.call_count == 2

        # Verify responses were sent
        assert mock_websocket.send_json.call_count == 2

    @pytest.mark.asyncio
    async def test_multiple_client_management(self):
        """Test managing multiple concurrent clients."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)

        clients = []
        for i in range(5):
            client_id = f"client_{i}"
            mock_ws = Mock()
            mock_ws.send_json = AsyncMock()

            connection = WebSocketConnection(mock_ws, client_id)
            connection.is_connected = True
            handler.connection_manager.connections[client_id] = connection
            clients.append(connection)

        # Test broadcasting to all clients
        message = {"type": "broadcast", "data": "Hello all clients"}
        await handler.broadcast_message(message)

        # All clients should receive the message
        for connection in clients:
            connection.websocket.send_json.assert_called_once_with(message)

        # Test subscription-specific messaging
        subscription_id = "camera_001"
        handler.connection_manager.add_subscription("client_0", subscription_id)
        handler.connection_manager.add_subscription("client_2", subscription_id)

        subscription_message = {"type": "detection", "camera_id": "camera_001"}
        await handler.send_to_subscription(subscription_id, subscription_message)

        # Only subscribed clients should receive the message; verifying that
        # here would require additional mocking of broadcast_to_subscription,
        # as sketched below.
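        # A possible way to assert that (a sketch mirroring
        # test_send_to_subscription above): install an AsyncMock before the
        # send call, then check its arguments afterwards:
        #
        #     handler.connection_manager.broadcast_to_subscription = AsyncMock()
        #     await handler.send_to_subscription(subscription_id, subscription_message)
        #     handler.connection_manager.broadcast_to_subscription \
        #         .assert_called_once_with(subscription_id, subscription_message)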

    @pytest.mark.asyncio
    async def test_error_handling_and_recovery(self, mock_websocket):
        """Test error handling and recovery scenarios."""
        mock_processor = Mock()

        # First message causes error, second succeeds
        mock_processor.process_message = AsyncMock(side_effect=[
            MessageProcessingError("Invalid message format"),
            {"type": "ack", "status": "success"}
        ])

        handler = WebSocketHandler(mock_processor)

        mock_websocket.accept = AsyncMock()
        mock_websocket.send_json = AsyncMock()
        mock_websocket.receive_json = AsyncMock(side_effect=[
            {"type": "invalid", "malformed": True},
            {"type": "valid", "payload": {"test": True}},
            WebSocketDisconnect()
        ])

        client_id = "error_test_client"

        # Should handle errors gracefully and continue processing
        await handler.handle_websocket(mock_websocket, client_id)

        # Both messages should have been processed
        assert mock_processor.process_message.call_count == 2

        # Should have sent error response and success response
        assert mock_websocket.send_json.call_count == 2

        # First call should be error response
        first_response = mock_websocket.send_json.call_args_list[0][0][0]
        assert first_response["type"] == "error"

    @pytest.mark.asyncio
    async def test_connection_timeout_handling(self):
        """Test connection timeout handling."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor, {"connection_timeout": 0.1})

        # Add connection that hasn't been active
        mock_websocket = Mock()
        connection = WebSocketConnection(mock_websocket, "timeout_client")
        connection.is_connected = True
        # Don't update last_ping to simulate timeout

        handler.connection_manager.connections["timeout_client"] = connection

        # Wait longer than timeout
        await asyncio.sleep(0.2)

        # Manual cleanup (in real implementation this would be automatic)
        cleaned = await handler.cleanup_connections()

        # Connection should be identified for cleanup; the actual timeout
        # check still needs to be implemented in the cleanup method, roughly
        # as sketched below.
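        # One hypothetical shape for that check inside cleanup (illustrative
        # only; assumes a time.time()-based last_ping and the configured
        # connection_timeout):
        #
        #     now = time.time()
        #     stale = [cid for cid, conn in self.connection_manager.connections.items()
        #              if conn.last_ping is None
        #              or now - conn.last_ping > self.connection_timeout]
        #     for cid in stale:
        #         await self.connection_manager.remove_connection(cid)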

429 tests/unit/core/test_config.py (new file)

@@ -0,0 +1,429 @@
"""
|
||||
Unit tests for configuration management system.
|
||||
"""
|
||||
import pytest
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from detector_worker.core.config import (
|
||||
ConfigurationManager,
|
||||
JsonFileProvider,
|
||||
EnvironmentProvider,
|
||||
DatabaseConfig,
|
||||
RedisConfig,
|
||||
StreamConfig,
|
||||
ModelConfig,
|
||||
LoggingConfig,
|
||||
get_config_manager,
|
||||
validate_config
|
||||
)
|
||||
from detector_worker.core.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestJsonFileProvider:
|
||||
"""Test JSON file configuration provider."""
|
||||
|
||||
def test_get_config_from_valid_file(self, temp_dir):
|
||||
"""Test loading configuration from a valid JSON file."""
|
||||
config_data = {"test_key": "test_value", "number": 42}
|
||||
config_file = os.path.join(temp_dir, "config.json")
|
||||
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
provider = JsonFileProvider(config_file)
|
||||
result = provider.get_config()
|
||||
|
||||
assert result == config_data
|
||||
|
||||
def test_get_config_file_not_exists(self, temp_dir):
|
||||
"""Test handling of non-existent config file."""
|
||||
config_file = os.path.join(temp_dir, "nonexistent.json")
|
||||
provider = JsonFileProvider(config_file)
|
||||
|
||||
result = provider.get_config()
|
||||
assert result == {}
|
||||
|
||||
def test_get_config_invalid_json(self, temp_dir):
|
||||
"""Test handling of invalid JSON file."""
|
||||
config_file = os.path.join(temp_dir, "invalid.json")
|
||||
|
||||
with open(config_file, 'w') as f:
|
||||
f.write("invalid json content")
|
||||
|
||||
provider = JsonFileProvider(config_file)
|
||||
result = provider.get_config()
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_reload_updates_config(self, temp_dir):
|
||||
"""Test that reload updates configuration."""
|
||||
config_file = os.path.join(temp_dir, "config.json")
|
||||
|
||||
# Initial config
|
||||
initial_config = {"version": 1}
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(initial_config, f)
|
||||
|
||||
provider = JsonFileProvider(config_file)
|
||||
assert provider.get_config() == initial_config
|
||||
|
||||
# Update config file
|
||||
updated_config = {"version": 2}
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(updated_config, f)
|
||||
|
||||
# Force reload
|
||||
provider.reload()
|
||||
assert provider.get_config() == updated_config
|
||||
|
||||
|
||||
class TestEnvironmentProvider:
|
||||
"""Test environment variable configuration provider."""
|
||||
|
||||
def test_get_config_with_env_vars(self):
|
||||
"""Test loading configuration from environment variables."""
|
||||
env_vars = {
|
||||
"DETECTOR_MAX_STREAMS": "10",
|
||||
"DETECTOR_TARGET_FPS": "15",
|
||||
"DETECTOR_CONFIG": '{"nested": "value"}',
|
||||
"OTHER_VAR": "ignored"
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
provider = EnvironmentProvider("DETECTOR_")
|
||||
config = provider.get_config()
|
||||
|
||||
assert config["max_streams"] == "10"
|
||||
assert config["target_fps"] == "15"
|
||||
assert config["config"] == {"nested": "value"}
|
||||
assert "other_var" not in config

    def test_get_config_no_env_vars(self):
        """Test with no matching environment variables."""
        with patch.dict(os.environ, {}, clear=True):
            provider = EnvironmentProvider("DETECTOR_")
            config = provider.get_config()

        assert config == {}

    def test_custom_prefix(self):
        """Test with custom prefix."""
        env_vars = {"CUSTOM_TEST": "value"}

        with patch.dict(os.environ, env_vars, clear=False):
            provider = EnvironmentProvider("CUSTOM_")
            config = provider.get_config()

        assert config["test"] == "value"


class TestConfigDataclasses:
    """Test configuration dataclasses."""

    def test_database_config_from_dict(self):
        """Test DatabaseConfig creation from dictionary."""
        data = {
            "enabled": True,
            "host": "db.example.com",
            "port": 5432,
            "database": "testdb",
            "username": "user",
            "password": "pass",
            "schema": "test_schema",
            "unknown_field": "ignored"
        }

        config = DatabaseConfig.from_dict(data)

        assert config.enabled is True
        assert config.host == "db.example.com"
        assert config.port == 5432
        assert config.database == "testdb"
        assert config.username == "user"
        assert config.password == "pass"
        assert config.schema == "test_schema"
        # Unknown fields should be ignored
        assert not hasattr(config, 'unknown_field')

    def test_redis_config_from_dict(self):
        """Test RedisConfig creation from dictionary."""
        data = {
            "enabled": True,
            "host": "redis.example.com",
            "port": 6379,
            "password": "secret",
            "db": 1
        }

        config = RedisConfig.from_dict(data)

        assert config.enabled is True
        assert config.host == "redis.example.com"
        assert config.port == 6379
        assert config.password == "secret"
        assert config.db == 1

    def test_stream_config_from_dict(self):
        """Test StreamConfig creation from dictionary."""
        data = {
            "poll_interval_ms": 50,
            "max_streams": 10,
            "target_fps": 20,
            "reconnect_interval_sec": 10,
            "max_retries": 5
        }

        config = StreamConfig.from_dict(data)

        assert config.poll_interval_ms == 50
        assert config.max_streams == 10
        assert config.target_fps == 20
        assert config.reconnect_interval_sec == 10
        assert config.max_retries == 5


class TestConfigurationManager:
    """Test main configuration manager."""

    def test_initialization_with_defaults(self):
        """Test that manager initializes with default values."""
        manager = ConfigurationManager()

        # Should have default providers
        assert len(manager._providers) >= 1

        # Should have default configuration values
        config = manager.get_all()
        assert "poll_interval_ms" in config
        assert "max_streams" in config
        assert "target_fps" in config

    def test_add_provider(self):
        """Test adding configuration providers."""
        manager = ConfigurationManager()
        initial_count = len(manager._providers)

        mock_provider = Mock()
        mock_provider.get_config.return_value = {"test": "value"}

        manager.add_provider(mock_provider)

        assert len(manager._providers) == initial_count + 1

    def test_get_configuration_value(self):
        """Test getting specific configuration values."""
        manager = ConfigurationManager()

        # Test existing key
        value = manager.get("poll_interval_ms")
        assert value is not None

        # Test non-existing key with default
        value = manager.get("nonexistent", "default")
        assert value == "default"

        # Test non-existing key without default
        value = manager.get("nonexistent")
        assert value is None

    def test_get_section(self):
        """Test getting configuration sections."""
        manager = ConfigurationManager()

        # Test existing section
        db_section = manager.get_section("database")
        assert isinstance(db_section, dict)

        # Test non-existing section
        empty_section = manager.get_section("nonexistent")
        assert empty_section == {}

    def test_typed_config_access(self):
        """Test typed configuration object access."""
        manager = ConfigurationManager()

        # Test database config
        db_config = manager.get_database_config()
        assert isinstance(db_config, DatabaseConfig)

        # Test Redis config
        redis_config = manager.get_redis_config()
        assert isinstance(redis_config, RedisConfig)

        # Test stream config
        stream_config = manager.get_stream_config()
        assert isinstance(stream_config, StreamConfig)

        # Test model config
        model_config = manager.get_model_config()
        assert isinstance(model_config, ModelConfig)

        # Test logging config
        logging_config = manager.get_logging_config()
        assert isinstance(logging_config, LoggingConfig)

    def test_set_configuration_value(self):
        """Test setting configuration values at runtime."""
        manager = ConfigurationManager()

        manager.set("test_key", "test_value")

        assert manager.get("test_key") == "test_value"

        # Should also update typed configs
        manager.set("poll_interval_ms", 200)
        stream_config = manager.get_stream_config()
        assert stream_config.poll_interval_ms == 200

    def test_validation_success(self):
        """Test configuration validation with valid config."""
        manager = ConfigurationManager()

        # Set valid configuration
        manager.set("poll_interval_ms", 100)
        manager.set("max_streams", 5)
        manager.set("target_fps", 10)

        errors = manager.validate()
        assert errors == []
        assert manager.is_valid() is True

    def test_validation_errors(self):
        """Test configuration validation with invalid values."""
        manager = ConfigurationManager()

        # Set invalid configuration
        manager.set("poll_interval_ms", 0)
        manager.set("max_streams", -1)
        manager.set("target_fps", 0)

        errors = manager.validate()
        assert len(errors) > 0
        assert manager.is_valid() is False

        # Check specific errors
        error_messages = " ".join(errors)
        assert "poll_interval_ms must be positive" in error_messages
        assert "max_streams must be positive" in error_messages
        assert "target_fps must be positive" in error_messages

    def test_database_validation(self):
        """Test database-specific validation."""
        manager = ConfigurationManager()

        # Enable database but don't provide required fields
        db_config = {
            "enabled": True,
            "host": "",
            "database": ""
        }
        manager.set("database", db_config)

        errors = manager.validate()
        error_messages = " ".join(errors)

        assert "database host is required" in error_messages
        assert "database name is required" in error_messages

    def test_redis_validation(self):
        """Test Redis-specific validation."""
        manager = ConfigurationManager()

        # Enable Redis but don't provide required fields
        redis_config = {
            "enabled": True,
            "host": ""
        }
        manager.set("redis", redis_config)

        errors = manager.validate()
        error_messages = " ".join(errors)

        assert "redis host is required" in error_messages


class TestGlobalConfigurationFunctions:
    """Test global configuration functions."""

    def test_get_config_manager_singleton(self):
        """Test that get_config_manager returns a singleton."""
        manager1 = get_config_manager()
        manager2 = get_config_manager()

        assert manager1 is manager2

    @patch('detector_worker.core.config.get_config_manager')
    def test_validate_config_function(self, mock_get_manager):
        """Test global validate_config function."""
        mock_manager = Mock()
        mock_manager.validate.return_value = ["error1", "error2"]
        mock_get_manager.return_value = mock_manager

        errors = validate_config()

        assert errors == ["error1", "error2"]
        mock_manager.validate.assert_called_once()


class TestConfigurationIntegration:
    """Integration tests for configuration system."""

    def test_provider_priority(self, temp_dir):
        """Test that later providers override earlier ones."""
        # Create JSON file with initial config
        config_file = os.path.join(temp_dir, "config.json")
        json_config = {"test_value": "from_json", "json_only": "json"}

        with open(config_file, 'w') as f:
            json.dump(json_config, f)

        # Set environment variable that should override
        env_vars = {"DETECTOR_TEST_VALUE": "from_env", "DETECTOR_ENV_ONLY": "env"}

        with patch.dict(os.environ, env_vars, clear=False):
            manager = ConfigurationManager()
            manager._providers.clear()  # Start fresh

            # Add providers in order
            manager.add_provider(JsonFileProvider(config_file))
            manager.add_provider(EnvironmentProvider("DETECTOR_"))

            config = manager.get_all()

            # Environment should override JSON
            assert config["test_value"] == "from_env"

            # Both sources should be present
            assert config["json_only"] == "json"
            assert config["env_only"] == "env"

    def test_hot_reload(self, temp_dir):
        """Test configuration hot reload functionality."""
        config_file = os.path.join(temp_dir, "config.json")

        # Initial config
        initial_config = {"version": 1, "feature_enabled": False}
        with open(config_file, 'w') as f:
            json.dump(initial_config, f)

        manager = ConfigurationManager()
        manager._providers.clear()
        manager.add_provider(JsonFileProvider(config_file))

        assert manager.get("version") == 1
        assert manager.get("feature_enabled") is False

        # Update config file
        updated_config = {"version": 2, "feature_enabled": True}
        with open(config_file, 'w') as f:
            json.dump(updated_config, f)

        # Reload configuration
        success = manager.reload()
        assert success is True

        assert manager.get("version") == 2
        assert manager.get("feature_enabled") is True

566 tests/unit/core/test_dependency_injection.py (new file)

@@ -0,0 +1,566 @@
"""
|
||||
Unit tests for dependency injection system.
|
||||
"""
|
||||
import pytest
|
||||
import threading
|
||||
from unittest.mock import Mock, MagicMock
|
||||
|
||||
from detector_worker.core.dependency_injection import (
|
||||
ServiceContainer,
|
||||
ServiceLifetime,
|
||||
ServiceDescriptor,
|
||||
ServiceScope,
|
||||
DetectorWorkerContainer,
|
||||
get_container,
|
||||
resolve_service,
|
||||
create_service_scope
|
||||
)
|
||||
from detector_worker.core.exceptions import DependencyInjectionError
|
||||
|
||||
|
||||
class TestServiceContainer:
|
||||
"""Test core service container functionality."""
|
||||
|
||||
def test_register_singleton(self):
|
||||
"""Test singleton service registration."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
# Register singleton
|
||||
container.register_singleton(TestService)
|
||||
|
||||
# Resolve twice - should get same instance
|
||||
instance1 = container.resolve(TestService)
|
||||
instance2 = container.resolve(TestService)
|
||||
|
||||
assert instance1 is instance2
|
||||
assert instance1.value == 42
|
||||
|
||||
def test_register_singleton_with_instance(self):
|
||||
"""Test singleton registration with pre-created instance."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
# Create instance and register
|
||||
instance = TestService(99)
|
||||
container.register_singleton(TestService, instance=instance)
|
||||
|
||||
# Resolve should return the pre-created instance
|
||||
resolved = container.resolve(TestService)
|
||||
assert resolved is instance
|
||||
assert resolved.value == 99
|
||||
|
||||
def test_register_transient(self):
|
||||
"""Test transient service registration."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
# Register transient
|
||||
container.register_transient(TestService)
|
||||
|
||||
# Resolve twice - should get different instances
|
||||
instance1 = container.resolve(TestService)
|
||||
instance2 = container.resolve(TestService)
|
||||
|
||||
assert instance1 is not instance2
|
||||
assert instance1.value == instance2.value == 42
|
||||
|
||||
def test_register_scoped(self):
|
||||
"""Test scoped service registration."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
# Register scoped
|
||||
container.register_scoped(TestService)
|
||||
|
||||
# Resolve in same scope - should get same instance
|
||||
instance1 = container.resolve(TestService, scope_id="scope1")
|
||||
instance2 = container.resolve(TestService, scope_id="scope1")
|
||||
|
||||
assert instance1 is instance2
|
||||
|
||||
# Resolve in different scope - should get different instance
|
||||
instance3 = container.resolve(TestService, scope_id="scope2")
|
||||
assert instance3 is not instance1
|
||||
|
||||
def test_register_with_factory(self):
|
||||
"""Test service registration with factory function."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
# Register with factory
|
||||
def factory():
|
||||
return TestService(100)
|
||||
|
||||
container.register_singleton(TestService, factory=factory)
|
||||
|
||||
instance = container.resolve(TestService)
|
||||
assert instance.value == 100
|
||||
|
||||
def test_register_with_implementation_type(self):
|
||||
"""Test service registration with implementation type."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class ITestService:
|
||||
pass
|
||||
|
||||
class TestService(ITestService):
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
# Register interface with implementation
|
||||
container.register_singleton(ITestService, implementation_type=TestService)
|
||||
|
||||
instance = container.resolve(ITestService)
|
||||
assert isinstance(instance, TestService)
|
||||
assert instance.value == 42
|
||||
|
||||
def test_dependency_injection(self):
|
||||
"""Test automatic dependency injection."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class DatabaseService:
|
||||
def __init__(self):
|
||||
self.connected = True
|
||||
|
||||
class UserService:
|
||||
def __init__(self, database: DatabaseService):
|
||||
self.database = database
|
||||
|
||||
# Register services
|
||||
container.register_singleton(DatabaseService)
|
||||
container.register_transient(UserService)
|
||||
|
||||
# Resolve should inject dependencies
|
||||
user_service = container.resolve(UserService)
|
||||
assert isinstance(user_service.database, DatabaseService)
|
||||
assert user_service.database.connected is True
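        # The constructor injection exercised here implies hint-driven
        # resolution (hedged sketch using stdlib introspection; parameter
        # names are illustrative):
        #
        #     import typing
        #     hints = typing.get_type_hints(cls.__init__)
        #     kwargs = {name: self.resolve(dep_type)
        #               for name, dep_type in hints.items() if name != "return"}
        #     instance = cls(**kwargs)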

    def test_circular_dependency_detection(self):
        """Test circular dependency detection."""
        container = ServiceContainer()

        class ServiceA:
            def __init__(self, service_b: 'ServiceB'):
                self.service_b = service_b

        class ServiceB:
            def __init__(self, service_a: ServiceA):
                self.service_a = service_a

        # Register circular dependencies
        container.register_singleton(ServiceA)
        container.register_singleton(ServiceB)

        # Should raise circular dependency error
        with pytest.raises(DependencyInjectionError) as exc_info:
            container.resolve(ServiceA)

        assert "Circular dependency detected" in str(exc_info.value)

    def test_unregistered_service_error(self):
        """Test error when resolving unregistered service."""
        container = ServiceContainer()

        class UnregisteredService:
            pass

        with pytest.raises(DependencyInjectionError) as exc_info:
            container.resolve(UnregisteredService)

        assert "is not registered" in str(exc_info.value)

    def test_scoped_service_without_scope_id(self):
        """Test error when resolving scoped service without scope ID."""
        container = ServiceContainer()

        class TestService:
            pass

        container.register_scoped(TestService)

        with pytest.raises(DependencyInjectionError) as exc_info:
            container.resolve(TestService)

        assert "Scope ID required" in str(exc_info.value)

    def test_factory_error_handling(self):
        """Test factory error handling."""
        container = ServiceContainer()

        class TestService:
            pass

        def failing_factory():
            raise ValueError("Factory failed")

        container.register_singleton(TestService, factory=failing_factory)

        with pytest.raises(DependencyInjectionError) as exc_info:
            container.resolve(TestService)

        assert "Failed to create service using factory" in str(exc_info.value)

    def test_constructor_dependency_with_default(self):
        """Test dependency with default value."""
        container = ServiceContainer()

        class TestService:
            def __init__(self, value: int = 42):
                self.value = value

        container.register_singleton(TestService)

        instance = container.resolve(TestService)
        assert instance.value == 42

    def test_unresolvable_dependency_with_default(self):
        """Test unresolvable dependency that has a default value."""
        container = ServiceContainer()

        class UnregisteredService:
            pass

        class TestService:
            def __init__(self, dep: UnregisteredService = None):
                self.dep = dep

        container.register_singleton(TestService)

        instance = container.resolve(TestService)
        assert instance.dep is None

    def test_unresolvable_dependency_without_default(self):
        """Test unresolvable dependency without default value."""
        container = ServiceContainer()

        class UnregisteredService:
            pass

        class TestService:
            def __init__(self, dep: UnregisteredService):
                self.dep = dep

        container.register_singleton(TestService)

        with pytest.raises(DependencyInjectionError) as exc_info:
            container.resolve(TestService)

        assert "Cannot resolve dependency" in str(exc_info.value)


class TestServiceScope:
    """Test service scope functionality."""

    def test_create_scope(self):
        """Test scope creation."""
        container = ServiceContainer()

        class TestService:
            def __init__(self):
                self.value = 42

        container.register_scoped(TestService)

        scope = container.create_scope("test_scope")
        assert isinstance(scope, ServiceScope)
        assert scope.scope_id == "test_scope"

    def test_scope_context_manager(self):
        """Test scope as context manager."""
        container = ServiceContainer()

        class TestService:
            def __init__(self):
                self.disposed = False

            def dispose(self):
                self.disposed = True

        container.register_scoped(TestService)

        instance = None
        with container.create_scope("test_scope") as scope:
            instance = scope.resolve(TestService)
            assert not instance.disposed

        # Instance should be disposed after scope exit
        assert instance.disposed

    def test_dispose_scope(self):
        """Test manual scope disposal."""
        container = ServiceContainer()

        class TestService:
            def __init__(self):
                self.disposed = False

            def dispose(self):
                self.disposed = True

        container.register_scoped(TestService)

        instance = container.resolve(TestService, scope_id="test_scope")
        assert not instance.disposed

        container.dispose_scope("test_scope")
        assert instance.disposed

    def test_dispose_error_handling(self):
        """Test error handling during scope disposal."""
        container = ServiceContainer()

        class TestService:
            def dispose(self):
                raise ValueError("Dispose failed")

        container.register_scoped(TestService)

        container.resolve(TestService, scope_id="test_scope")

        # Should not raise error, just log it
        container.dispose_scope("test_scope")
|
||||
|
||||
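# --- Illustrative sketch, not part of this commit ---
# The two disposal tests above pin down the scope contract: on context exit
# (or dispose_scope), every scoped instance gets dispose() called if it has
# one, and disposal errors are swallowed rather than raised. A minimal shape
# for that behavior (hypothetical class, not the real ServiceScope):
class _ScopeSketch:
    def __init__(self, scope_id):
        self.scope_id = scope_id
        self._instances = {}  # service type -> instance

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        for instance in self._instances.values():
            dispose = getattr(instance, "dispose", None)
            if callable(dispose):
                try:
                    dispose()
                except Exception:
                    pass  # logged, never raised (see test_dispose_error_handling)
        self._instances.clear()
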
class TestContainerIntrospection:
    """Test container introspection capabilities."""

    def test_is_registered(self):
        """Test checking if service is registered."""
        container = ServiceContainer()

        class RegisteredService:
            pass

        class UnregisteredService:
            pass

        container.register_singleton(RegisteredService)

        assert container.is_registered(RegisteredService) is True
        assert container.is_registered(UnregisteredService) is False

    def test_get_registration_info(self):
        """Test getting service registration information."""
        container = ServiceContainer()

        class TestService:
            pass

        container.register_singleton(TestService)

        info = container.get_registration_info(TestService)
        assert isinstance(info, ServiceDescriptor)
        assert info.service_type == TestService
        assert info.lifetime == ServiceLifetime.SINGLETON

    def test_get_registered_services(self):
        """Test getting all registered services."""
        container = ServiceContainer()

        class Service1:
            pass

        class Service2:
            pass

        container.register_singleton(Service1)
        container.register_transient(Service2)

        services = container.get_registered_services()
        assert len(services) == 2
        assert Service1 in services
        assert Service2 in services

    def test_clear_singletons(self):
        """Test clearing singleton instances."""
        container = ServiceContainer()

        class TestService:
            pass

        container.register_singleton(TestService)

        # Create singleton instance
        instance1 = container.resolve(TestService)

        # Clear singletons
        container.clear_singletons()

        # Next resolve should create new instance
        instance2 = container.resolve(TestService)
        assert instance2 is not instance1

    def test_get_stats(self):
        """Test getting container statistics."""
        container = ServiceContainer()

        class Service1:
            pass

        class Service2:
            pass

        class Service3:
            pass

        container.register_singleton(Service1)
        container.register_transient(Service2)
        container.register_scoped(Service3)

        # Create some instances
        container.resolve(Service1)
        container.resolve(Service3, scope_id="scope1")

        stats = container.get_stats()

        assert stats["registered_services"] == 3
        assert stats["active_singletons"] == 1
        assert stats["active_scopes"] == 1
        assert stats["lifetime_breakdown"]["singleton"] == 1
        assert stats["lifetime_breakdown"]["transient"] == 1
        assert stats["lifetime_breakdown"]["scoped"] == 1


class TestDetectorWorkerContainer:
    """Test pre-configured detector worker container."""

    def test_initialization(self):
        """Test detector worker container initialization."""
        container = DetectorWorkerContainer()

        assert isinstance(container.container, ServiceContainer)

        # Should have core services registered
        stats = container.container.get_stats()
        assert stats["registered_services"] > 0

    def test_resolve_convenience_method(self):
        """Test resolve convenience method."""
        container = DetectorWorkerContainer()

        # Should be able to resolve through convenience method
        from detector_worker.core.singleton_managers import ModelStateManager

        manager = container.resolve(ModelStateManager)
        assert isinstance(manager, ModelStateManager)

    def test_create_scope_convenience_method(self):
        """Test create scope convenience method."""
        container = DetectorWorkerContainer()

        scope = container.create_scope("test_scope")
        assert isinstance(scope, ServiceScope)
        assert scope.scope_id == "test_scope"


class TestGlobalContainerFunctions:
    """Test global container functions."""

    def test_get_container_singleton(self):
        """Test that get_container returns a singleton."""
        container1 = get_container()
        container2 = get_container()

        assert container1 is container2
        assert isinstance(container1, DetectorWorkerContainer)

    def test_resolve_service_convenience(self):
        """Test resolve_service convenience function."""
        from detector_worker.core.singleton_managers import ModelStateManager

        manager = resolve_service(ModelStateManager)
        assert isinstance(manager, ModelStateManager)

    def test_create_service_scope_convenience(self):
        """Test create_service_scope convenience function."""
        scope = create_service_scope("test_scope")
        assert isinstance(scope, ServiceScope)
        assert scope.scope_id == "test_scope"

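# --- Illustrative sketch, not part of this commit ---
# get_container() must hand every caller the same DetectorWorkerContainer.
# A typical module-level implementation is a lock-guarded lazy singleton; the
# names below are hypothetical, only the observable behavior is pinned down
# by the tests above.
import threading

_container = None
_container_lock = threading.Lock()

def _get_container_sketch():
    global _container
    if _container is None:
        with _container_lock:
            if _container is None:  # re-check after acquiring the lock
                _container = DetectorWorkerContainer()
    return _container
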
class TestThreadSafety:
    """Test thread safety of dependency injection system."""

    def test_container_thread_safety(self):
        """Test that container is thread-safe."""
        container = ServiceContainer()

        class TestService:
            def __init__(self):
                import threading
                self.thread_id = threading.current_thread().ident

        container.register_singleton(TestService)

        instances = {}

        def resolve_service(thread_id):
            instances[thread_id] = container.resolve(TestService)

        # Create multiple threads
        threads = []
        for i in range(10):
            thread = threading.Thread(target=resolve_service, args=(i,))
            threads.append(thread)
            thread.start()

        # Wait for all threads
        for thread in threads:
            thread.join()

        # All should get the same singleton instance
        first_instance = list(instances.values())[0]
        for instance in instances.values():
            assert instance is first_instance

    def test_scope_thread_safety(self):
        """Test that scoped services are thread-safe."""
        container = ServiceContainer()

        class TestService:
            def __init__(self):
                import threading
                self.thread_id = threading.current_thread().ident

        container.register_scoped(TestService)

        results = {}

        def resolve_in_scope(thread_id):
            # Each thread uses its own scope
            instance1 = container.resolve(TestService, scope_id=f"scope_{thread_id}")
            instance2 = container.resolve(TestService, scope_id=f"scope_{thread_id}")

            results[thread_id] = {
                "same_instance": instance1 is instance2,
                "thread_id": instance1.thread_id
            }

        threads = []
        for i in range(5):
            thread = threading.Thread(target=resolve_in_scope, args=(i,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Each thread should get same instance within its scope
        for thread_id, result in results.items():
            assert result["same_instance"] is True
560
tests/unit/core/test_singleton_managers.py
Normal file
@ -0,0 +1,560 @@
"""
|
||||
Unit tests for singleton state managers.
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
import threading
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from detector_worker.core.singleton_managers import (
|
||||
SingletonMeta,
|
||||
ModelStateManager,
|
||||
StreamStateManager,
|
||||
SessionStateManager,
|
||||
CacheStateManager,
|
||||
CameraStateManager,
|
||||
PipelineStateManager,
|
||||
ModelInfo,
|
||||
StreamInfo,
|
||||
SessionInfo
|
||||
)
|
||||
|
||||
|
||||
class TestSingletonMeta:
|
||||
"""Test singleton metaclass."""
|
||||
|
||||
def test_singleton_behavior(self):
|
||||
"""Test that singleton metaclass creates only one instance."""
|
||||
class TestSingleton(metaclass=SingletonMeta):
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
instance1 = TestSingleton()
|
||||
instance2 = TestSingleton()
|
||||
|
||||
assert instance1 is instance2
|
||||
assert instance1.value == instance2.value
|
||||
|
||||
def test_singleton_thread_safety(self):
|
||||
"""Test that singleton is thread-safe."""
|
||||
class TestSingleton(metaclass=SingletonMeta):
|
||||
def __init__(self):
|
||||
self.created_by = threading.current_thread().name
|
||||
|
||||
instances = {}
|
||||
|
||||
def create_instance(thread_id):
|
||||
instances[thread_id] = TestSingleton()
|
||||
|
||||
threads = []
|
||||
for i in range(10):
|
||||
thread = threading.Thread(target=create_instance, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# All instances should be the same object
|
||||
first_instance = instances[0]
|
||||
for instance in instances.values():
|
||||
assert instance is first_instance
|
||||
|
||||
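# --- Illustrative sketch, not part of this commit ---
# A metaclass satisfying both tests above is the canonical double-checked
# locking pattern; the real SingletonMeta may differ in detail. Reuses the
# module-level threading import.
class _SingletonMetaSketch(type):
    _instances = {}
    _lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            with cls._lock:
                if cls not in cls._instances:  # re-check under the lock
                    cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
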
class TestModelStateManager:
    """Test model state management."""

    def test_singleton_behavior(self):
        """Test that ModelStateManager is a singleton."""
        manager1 = ModelStateManager()
        manager2 = ModelStateManager()

        assert manager1 is manager2

    def test_load_model(self):
        """Test loading a model."""
        manager = ModelStateManager()
        manager.clear_all()  # Start fresh

        mock_model = Mock()
        manager.load_model("camera1", "model1", mock_model)

        retrieved_model = manager.get_model("camera1", "model1")
        assert retrieved_model is mock_model

    def test_load_same_model_increments_reference_count(self):
        """Test that loading the same model increments reference count."""
        manager = ModelStateManager()
        manager.clear_all()

        mock_model = Mock()

        # Load same model twice
        manager.load_model("camera1", "model1", mock_model)
        manager.load_model("camera1", "model1", mock_model)

        # Should still be accessible
        assert manager.get_model("camera1", "model1") is mock_model

    def test_get_camera_models(self):
        """Test getting all models for a camera."""
        manager = ModelStateManager()
        manager.clear_all()

        mock_model1 = Mock()
        mock_model2 = Mock()

        manager.load_model("camera1", "model1", mock_model1)
        manager.load_model("camera1", "model2", mock_model2)

        models = manager.get_camera_models("camera1")

        assert len(models) == 2
        assert models["model1"] is mock_model1
        assert models["model2"] is mock_model2

    def test_unload_model_with_multiple_references(self):
        """Test unloading model with multiple references."""
        manager = ModelStateManager()
        manager.clear_all()

        mock_model = Mock()

        # Load model twice (reference count = 2)
        manager.load_model("camera1", "model1", mock_model)
        manager.load_model("camera1", "model1", mock_model)

        # First unload should not remove model
        result = manager.unload_model("camera1", "model1")
        assert result is False  # Still referenced
        assert manager.get_model("camera1", "model1") is mock_model

        # Second unload should remove model
        result = manager.unload_model("camera1", "model1")
        assert result is True  # Completely removed
        assert manager.get_model("camera1", "model1") is None

    def test_unload_camera_models(self):
        """Test unloading all models for a camera."""
        manager = ModelStateManager()
        manager.clear_all()

        mock_model1 = Mock()
        mock_model2 = Mock()

        manager.load_model("camera1", "model1", mock_model1)
        manager.load_model("camera1", "model2", mock_model2)

        manager.unload_camera_models("camera1")

        assert manager.get_model("camera1", "model1") is None
        assert manager.get_model("camera1", "model2") is None

    def test_get_stats(self):
        """Test getting model statistics."""
        manager = ModelStateManager()
        manager.clear_all()

        mock_model = Mock()
        manager.load_model("camera1", "model1", mock_model)
        manager.load_model("camera2", "model2", mock_model)

        stats = manager.get_stats()

        assert stats["total_models"] == 2
        assert stats["total_cameras"] == 2
        assert "camera1" in stats["cameras"]
        assert "camera2" in stats["cameras"]

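# --- Illustrative sketch, not part of this commit ---
# The unload tests above imply plain reference counting per (camera, model)
# key: load increments, unload decrements and only evicts at zero, and the
# return value reports whether the model was actually removed. Standalone
# sketch of that rule (hypothetical helper operating on two dicts):
def _unload_model_sketch(ref_counts, models, key):
    ref_counts[key] -= 1
    if ref_counts[key] > 0:
        return False  # still referenced elsewhere
    del ref_counts[key]
    del models[key]
    return True  # completely removed
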
class TestStreamStateManager:
    """Test stream state management."""

    def test_add_stream(self):
        """Test adding a stream."""
        manager = StreamStateManager()
        manager.clear_all()

        config = {"rtsp_url": "rtsp://example.com", "model_id": "test"}
        manager.add_stream("camera1", "sub1", config)

        stream = manager.get_stream("camera1")
        assert stream is not None
        assert stream.camera_id == "camera1"
        assert stream.subscription_id == "sub1"
        assert stream.config == config

    def test_subscription_mapping(self):
        """Test subscription to camera mapping."""
        manager = StreamStateManager()
        manager.clear_all()

        config = {"rtsp_url": "rtsp://example.com"}
        manager.add_stream("camera1", "sub1", config)

        camera_id = manager.get_camera_by_subscription("sub1")
        assert camera_id == "camera1"

    def test_remove_stream(self):
        """Test removing a stream."""
        manager = StreamStateManager()
        manager.clear_all()

        config = {"rtsp_url": "rtsp://example.com"}
        manager.add_stream("camera1", "sub1", config)

        removed_stream = manager.remove_stream("camera1")

        assert removed_stream is not None
        assert removed_stream.camera_id == "camera1"
        assert manager.get_stream("camera1") is None
        assert manager.get_camera_by_subscription("sub1") is None

    def test_shared_stream_management(self):
        """Test shared stream management."""
        manager = StreamStateManager()
        manager.clear_all()

        stream_data = {"reader": Mock(), "reference_count": 1}
        manager.add_shared_stream("rtsp://example.com", stream_data)

        retrieved_data = manager.get_shared_stream("rtsp://example.com")
        assert retrieved_data == stream_data

        removed_data = manager.remove_shared_stream("rtsp://example.com")
        assert removed_data == stream_data
        assert manager.get_shared_stream("rtsp://example.com") is None


class TestSessionStateManager:
    """Test session state management."""

    def test_session_id_management(self):
        """Test session ID assignment."""
        manager = SessionStateManager()
        manager.clear_all()

        manager.set_session_id("display1", "session123")

        session_id = manager.get_session_id("display1")
        assert session_id == "session123"

    def test_create_session(self):
        """Test session creation with detection data."""
        manager = SessionStateManager()
        manager.clear_all()

        detection_data = {"class": "car", "confidence": 0.85}
        manager.create_session("session123", "camera1", detection_data)

        retrieved_data = manager.get_session_detection("session123")
        assert retrieved_data == detection_data

        camera_id = manager.get_camera_by_session("session123")
        assert camera_id == "camera1"

    def test_update_session_detection(self):
        """Test updating session detection data."""
        manager = SessionStateManager()
        manager.clear_all()

        initial_data = {"class": "car", "confidence": 0.85}
        manager.create_session("session123", "camera1", initial_data)

        update_data = {"brand": "Toyota"}
        manager.update_session_detection("session123", update_data)

        final_data = manager.get_session_detection("session123")
        assert final_data["class"] == "car"
        assert final_data["brand"] == "Toyota"

    def test_session_expiration(self):
        """Test session expiration based on TTL."""
        # Use a very short TTL for testing
        manager = SessionStateManager(session_ttl=0.1)
        manager.clear_all()

        detection_data = {"class": "car", "confidence": 0.85}
        manager.create_session("session123", "camera1", detection_data)

        # Session should exist initially
        assert manager.get_session_detection("session123") is not None

        # Wait for expiration
        time.sleep(0.2)

        # Clean up expired sessions
        expired_count = manager.cleanup_expired_sessions()

        assert expired_count == 1
        assert manager.get_session_detection("session123") is None

    def test_remove_session(self):
        """Test manual session removal."""
        manager = SessionStateManager()
        manager.clear_all()

        detection_data = {"class": "car", "confidence": 0.85}
        manager.create_session("session123", "camera1", detection_data)

        result = manager.remove_session("session123")
        assert result is True

        assert manager.get_session_detection("session123") is None
        assert manager.get_camera_by_session("session123") is None

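# --- Illustrative sketch, not part of this commit ---
# Session expiry as exercised above: each session carries a timestamp, and
# cleanup_expired_sessions() sweeps everything older than the TTL, returning
# the eviction count. Sketch assuming {session_id: created_at} bookkeeping;
# reuses the module-level time import.
def _cleanup_expired_sketch(sessions, created_at, ttl):
    now = time.time()
    expired = [sid for sid, ts in created_at.items() if now - ts > ttl]
    for sid in expired:
        sessions.pop(sid, None)
        created_at.pop(sid, None)
    return len(expired)
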
class TestCacheStateManager:
    """Test cache state management."""

    def test_cache_detection(self):
        """Test caching detection results."""
        manager = CacheStateManager()
        manager.clear_all()

        detection_data = {"class": "car", "confidence": 0.85, "bbox": [100, 200, 300, 400]}
        manager.cache_detection("camera1", detection_data)

        cached_data = manager.get_cached_detection("camera1")
        assert cached_data == detection_data

    def test_cache_pipeline_result(self):
        """Test caching pipeline results."""
        manager = CacheStateManager()
        manager.clear_all()

        pipeline_result = {"status": "success", "detections": []}
        manager.cache_pipeline_result("camera1", pipeline_result)

        cached_result = manager.get_cached_pipeline_result("camera1")
        assert cached_result == pipeline_result

    def test_latest_frame_management(self):
        """Test latest frame storage."""
        manager = CacheStateManager()
        manager.clear_all()

        frame_data = b"fake_frame_data"
        manager.set_latest_frame("camera1", frame_data)

        retrieved_frame = manager.get_latest_frame("camera1")
        assert retrieved_frame == frame_data

    def test_frame_skip_flag(self):
        """Test frame skip flag management."""
        manager = CacheStateManager()
        manager.clear_all()

        # Initially should be False
        assert manager.get_frame_skip_flag("camera1") is False

        manager.set_frame_skip_flag("camera1", True)
        assert manager.get_frame_skip_flag("camera1") is True

        manager.set_frame_skip_flag("camera1", False)
        assert manager.get_frame_skip_flag("camera1") is False

    def test_clear_camera_cache(self):
        """Test clearing all cache data for a camera."""
        manager = CacheStateManager()
        manager.clear_all()

        # Set up cache data
        detection_data = {"class": "car"}
        pipeline_result = {"status": "success"}
        frame_data = b"frame"

        manager.cache_detection("camera1", detection_data)
        manager.cache_pipeline_result("camera1", pipeline_result)
        manager.set_latest_frame("camera1", frame_data)
        manager.set_frame_skip_flag("camera1", True)

        # Clear cache
        manager.clear_camera_cache("camera1")

        # All data should be gone
        assert manager.get_cached_detection("camera1") is None
        assert manager.get_cached_pipeline_result("camera1") is None
        assert manager.get_latest_frame("camera1") is None
        assert manager.get_frame_skip_flag("camera1") is False


class TestCameraStateManager:
    """Test camera state management."""

    def test_camera_connection_state(self):
        """Test camera connection state management."""
        manager = CameraStateManager()
        manager.clear_all()

        # Initially connected (default)
        assert manager.is_camera_connected("camera1") is True

        # Set disconnected
        manager.set_camera_connected("camera1", False)
        assert manager.is_camera_connected("camera1") is False

        # Set connected again
        manager.set_camera_connected("camera1", True)
        assert manager.is_camera_connected("camera1") is True

    def test_notification_flags(self):
        """Test disconnection/reconnection notification flags."""
        manager = CameraStateManager()
        manager.clear_all()

        # Set disconnected
        manager.set_camera_connected("camera1", False)

        # Should notify disconnection once
        assert manager.should_notify_disconnection("camera1") is True
        manager.mark_disconnection_notified("camera1")
        assert manager.should_notify_disconnection("camera1") is False

        # Reconnect
        manager.set_camera_connected("camera1", True)

        # Should notify reconnection
        assert manager.should_notify_reconnection("camera1") is True
        manager.mark_reconnection_notified("camera1")
        assert manager.should_notify_reconnection("camera1") is False

    def test_get_camera_state(self):
        """Test getting full camera state."""
        manager = CameraStateManager()
        manager.clear_all()

        manager.set_camera_connected("camera1", False)

        state = manager.get_camera_state("camera1")

        assert state["connected"] is False
        assert "last_update" in state
        assert "disconnection_notified" in state
        assert "reconnection_notified" in state

    def test_get_stats(self):
        """Test getting camera state statistics."""
        manager = CameraStateManager()
        manager.clear_all()

        manager.set_camera_connected("camera1", True)
        manager.set_camera_connected("camera2", False)

        stats = manager.get_stats()

        assert stats["total_cameras"] == 2
        assert stats["connected_cameras"] == 1
        assert stats["disconnected_cameras"] == 1


class TestPipelineStateManager:
    """Test pipeline state management."""

    def test_get_or_init_state(self):
        """Test getting or initializing pipeline state."""
        manager = PipelineStateManager()
        manager.clear_all()

        state = manager.get_or_init_state("camera1")

        assert state["mode"] == "validation_detecting"
        assert state["backend_session_id"] is None
        assert state["yolo_inference_enabled"] is True
        assert "created_at" in state

    def test_update_mode(self):
        """Test updating pipeline mode."""
        manager = PipelineStateManager()
        manager.clear_all()

        manager.update_mode("camera1", "classification", "session123")

        state = manager.get_state("camera1")
        assert state["mode"] == "classification"
        assert state["backend_session_id"] == "session123"

    def test_set_yolo_inference_enabled(self):
        """Test setting YOLO inference state."""
        manager = PipelineStateManager()
        manager.clear_all()

        manager.set_yolo_inference_enabled("camera1", False)

        state = manager.get_state("camera1")
        assert state["yolo_inference_enabled"] is False

    def test_set_progression_stage(self):
        """Test setting progression stage."""
        manager = PipelineStateManager()
        manager.clear_all()

        manager.set_progression_stage("camera1", "brand_classification")

        state = manager.get_state("camera1")
        assert state["progression_stage"] == "brand_classification"

    def test_set_validated_detection(self):
        """Test setting validated detection."""
        manager = PipelineStateManager()
        manager.clear_all()

        detection = {"class": "car", "confidence": 0.85}
        manager.set_validated_detection("camera1", detection)

        state = manager.get_state("camera1")
        assert state["validated_detection"] == detection

    def test_get_stats(self):
        """Test getting pipeline state statistics."""
        manager = PipelineStateManager()
        manager.clear_all()

        manager.update_mode("camera1", "validation_detecting")
        manager.update_mode("camera2", "classification")
        manager.update_mode("camera3", "classification")

        stats = manager.get_stats()

        assert stats["total_pipeline_states"] == 3
        assert stats["mode_breakdown"]["validation_detecting"] == 1
        assert stats["mode_breakdown"]["classification"] == 2


class TestThreadSafety:
    """Test thread safety of singleton managers."""

    def test_model_manager_thread_safety(self):
        """Test ModelStateManager thread safety."""
        manager = ModelStateManager()
        manager.clear_all()

        results = {}

        def load_models(thread_id):
            for i in range(10):
                model = Mock()
                model.thread_id = thread_id
                model.model_id = i
                manager.load_model(f"camera{thread_id}", f"model{i}", model)

            # Verify models
            models = manager.get_camera_models(f"camera{thread_id}")
            results[thread_id] = len(models)

        threads = []
        for i in range(5):
            thread = threading.Thread(target=load_models, args=(i,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Each thread should have loaded 10 models
        for thread_id, model_count in results.items():
            assert model_count == 10

        # Total should be 50 models
        stats = manager.get_stats()
        assert stats["total_models"] == 50
479
tests/unit/detection/test_detection_result.py
Normal file
@ -0,0 +1,479 @@
"""
|
||||
Unit tests for detection result data structures.
|
||||
"""
|
||||
import pytest
|
||||
from dataclasses import asdict
|
||||
import numpy as np
|
||||
|
||||
from detector_worker.detection.detection_result import (
|
||||
BoundingBox,
|
||||
DetectionResult,
|
||||
LightweightDetectionResult,
|
||||
DetectionSession,
|
||||
TrackValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestBoundingBox:
|
||||
"""Test BoundingBox data structure."""
|
||||
|
||||
def test_creation_from_coordinates(self):
|
||||
"""Test creating bounding box from coordinates."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
assert bbox.x1 == 100
|
||||
assert bbox.y1 == 200
|
||||
assert bbox.x2 == 300
|
||||
assert bbox.y2 == 400
|
||||
|
||||
def test_creation_from_list(self):
|
||||
"""Test creating bounding box from list."""
|
||||
coords = [100, 200, 300, 400]
|
||||
bbox = BoundingBox.from_list(coords)
|
||||
|
||||
assert bbox.x1 == 100
|
||||
assert bbox.y1 == 200
|
||||
assert bbox.x2 == 300
|
||||
assert bbox.y2 == 400
|
||||
|
||||
def test_creation_from_invalid_list(self):
|
||||
"""Test error handling for invalid list."""
|
||||
with pytest.raises(ValueError):
|
||||
BoundingBox.from_list([100, 200, 300]) # Too few elements
|
||||
|
||||
def test_to_list(self):
|
||||
"""Test converting bounding box to list."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
coords = bbox.to_list()
|
||||
|
||||
assert coords == [100, 200, 300, 400]
|
||||
|
||||
def test_area_calculation(self):
|
||||
"""Test area calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
area = bbox.area()
|
||||
|
||||
expected_area = (300 - 100) * (400 - 200) # 200 * 200 = 40000
|
||||
assert area == expected_area
|
||||
|
||||
def test_area_zero_for_invalid_bbox(self):
|
||||
"""Test area is zero for invalid bounding box."""
|
||||
# x2 <= x1
|
||||
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
|
||||
assert bbox.area() == 0
|
||||
|
||||
# y2 <= y1
|
||||
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
|
||||
assert bbox.area() == 0
|
||||
|
||||
def test_width_height(self):
|
||||
"""Test width and height properties."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
assert bbox.width() == 200
|
||||
assert bbox.height() == 200
|
||||
|
||||
def test_center_point(self):
|
||||
"""Test center point calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
center = bbox.center()
|
||||
|
||||
assert center == (200, 300) # (x1+x2)/2, (y1+y2)/2
|
||||
|
||||
def test_is_valid(self):
|
||||
"""Test bounding box validation."""
|
||||
# Valid bbox
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
assert bbox.is_valid() is True
|
||||
|
||||
# Invalid bbox (x2 <= x1)
|
||||
bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
|
||||
assert bbox.is_valid() is False
|
||||
|
||||
# Invalid bbox (y2 <= y1)
|
||||
bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200)
|
||||
assert bbox.is_valid() is False
|
||||
|
||||
def test_intersection(self):
|
||||
"""Test bounding box intersection."""
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
|
||||
|
||||
intersection = bbox1.intersection(bbox2)
|
||||
|
||||
assert intersection.x1 == 200
|
||||
assert intersection.y1 == 200
|
||||
assert intersection.x2 == 300
|
||||
assert intersection.y2 == 300
|
||||
|
||||
def test_no_intersection(self):
|
||||
"""Test no intersection between bounding boxes."""
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
|
||||
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
|
||||
|
||||
intersection = bbox1.intersection(bbox2)
|
||||
|
||||
assert intersection.is_valid() is False
|
||||
|
||||
def test_union(self):
|
||||
"""Test bounding box union."""
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
|
||||
|
||||
union = bbox1.union(bbox2)
|
||||
|
||||
assert union.x1 == 100
|
||||
assert union.y1 == 100
|
||||
assert union.x2 == 400
|
||||
assert union.y2 == 400
|
||||
|
||||
def test_iou_calculation(self):
|
||||
"""Test IoU (Intersection over Union) calculation."""
|
||||
# Perfect overlap
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
assert bbox1.iou(bbox2) == 1.0
|
||||
|
||||
# No overlap
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200)
|
||||
bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400)
|
||||
assert bbox1.iou(bbox2) == 0.0
|
||||
|
||||
# Partial overlap
|
||||
bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300)
|
||||
bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400)
|
||||
|
||||
# Intersection area: 100x100 = 10000
|
||||
# Union area: 200x200 + 200x200 - 10000 = 30000
|
||||
# IoU = 10000/30000 = 1/3
|
||||
expected_iou = 1.0 / 3.0
|
||||
assert abs(bbox1.iou(bbox2) - expected_iou) < 1e-6
|
||||
|
||||
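# --- Illustrative sketch, not part of this commit ---
# The IoU expectations above can be checked by hand with the standard
# axis-aligned formula. A standalone version of that arithmetic, independent
# of the BoundingBox implementation:
def _iou_sketch(a, b):
    """a and b are (x1, y1, x2, y2) tuples."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union else 0.0

# e.g. _iou_sketch((100, 100, 300, 300), (200, 200, 400, 400))
# == 10000 / 30000 == 1/3, matching test_iou_calculation above.
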
class TestDetectionResult:
    """Test DetectionResult data structure."""

    def test_creation_with_required_fields(self):
        """Test creating detection result with required fields."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345
        )

        assert detection.class_name == "car"
        assert detection.confidence == 0.85
        assert detection.bbox == bbox
        assert detection.track_id == 12345

    def test_creation_with_all_fields(self):
        """Test creating detection result with all fields."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345,
            model_id="yolo_v8",
            timestamp=1640995200000,
            branch_results={"brand": "Toyota"}
        )

        assert detection.model_id == "yolo_v8"
        assert detection.timestamp == 1640995200000
        assert detection.branch_results == {"brand": "Toyota"}

    def test_creation_from_dict(self):
        """Test creating detection result from dictionary."""
        data = {
            "class": "car",
            "confidence": 0.85,
            "bbox": [100, 200, 300, 400],
            "id": 12345,
            "model_id": "yolo_v8",
            "timestamp": 1640995200000
        }

        detection = DetectionResult.from_dict(data)

        assert detection.class_name == "car"
        assert detection.confidence == 0.85
        assert detection.bbox.to_list() == [100, 200, 300, 400]
        assert detection.track_id == 12345

    def test_to_dict(self):
        """Test converting detection result to dictionary."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345
        )

        data = detection.to_dict()

        assert data["class"] == "car"
        assert data["confidence"] == 0.85
        assert data["bbox"] == [100, 200, 300, 400]
        assert data["id"] == 12345

    def test_is_valid_detection(self):
        """Test detection validation."""
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        # Valid detection
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345
        )
        assert detection.is_valid() is True

        # Invalid confidence (too low)
        detection = DetectionResult(
            class_name="car",
            confidence=-0.1,
            bbox=bbox,
            track_id=12345
        )
        assert detection.is_valid() is False

        # Invalid confidence (too high)
        detection = DetectionResult(
            class_name="car",
            confidence=1.5,
            bbox=bbox,
            track_id=12345
        )
        assert detection.is_valid() is False

        # Invalid bounding box
        invalid_bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=invalid_bbox,
            track_id=12345
        )
        assert detection.is_valid() is False


class TestLightweightDetectionResult:
    """Test LightweightDetectionResult data structure."""

    def test_creation(self):
        """Test creating lightweight detection result."""
        detection = LightweightDetectionResult(
            class_name="car",
            confidence=0.85,
            bbox_area=40000,
            frame_width=1920,
            frame_height=1080
        )

        assert detection.class_name == "car"
        assert detection.confidence == 0.85
        assert detection.bbox_area == 40000
        assert detection.frame_width == 1920
        assert detection.frame_height == 1080

    def test_area_ratio_calculation(self):
        """Test bounding box area ratio calculation."""
        detection = LightweightDetectionResult(
            class_name="car",
            confidence=0.85,
            bbox_area=40000,
            frame_width=1920,
            frame_height=1080
        )

        expected_ratio = 40000 / (1920 * 1080)
        assert abs(detection.area_ratio() - expected_ratio) < 1e-6

    def test_meets_threshold(self):
        """Test threshold checking."""
        detection = LightweightDetectionResult(
            class_name="car",
            confidence=0.85,
            bbox_area=40000,
            frame_width=1920,
            frame_height=1080
        )

        assert detection.meets_threshold(confidence=0.8, area_ratio=0.01) is True
        assert detection.meets_threshold(confidence=0.9, area_ratio=0.01) is False
        assert detection.meets_threshold(confidence=0.8, area_ratio=0.1) is False

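# --- Illustrative sketch, not part of this commit ---
# meets_threshold() as pinned down above is two comparisons: confidence and
# the bbox-to-frame area ratio must both clear the given floors. With the
# values used in the test, area_ratio = 40000 / (1920 * 1080) ~= 0.019, so
# only the (0.8, 0.01) pair passes. Standalone predicate:
def _meets_threshold_sketch(confidence, bbox_area, frame_w, frame_h,
                            min_confidence, min_area_ratio):
    area_ratio = bbox_area / (frame_w * frame_h)
    return confidence >= min_confidence and area_ratio >= min_area_ratio
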
class TestDetectionSession:
    """Test DetectionSession data structure."""

    def test_creation(self):
        """Test creating detection session."""
        session = DetectionSession(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001"
        )

        assert session.session_id == "session_123"
        assert session.camera_id == "camera_001"
        assert session.display_id == "display_001"
        assert session.detections == []
        assert session.metadata == {}

    def test_add_detection(self):
        """Test adding detection to session."""
        session = DetectionSession(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001"
        )

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345
        )

        session.add_detection(detection)

        assert len(session.detections) == 1
        assert session.detections[0] == detection

    def test_get_latest_detection(self):
        """Test getting latest detection."""
        session = DetectionSession(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001"
        )

        # Add multiple detections
        bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection1 = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox1,
            track_id=12345,
            timestamp=1640995200000
        )

        bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
        detection2 = DetectionResult(
            class_name="car",
            confidence=0.90,
            bbox=bbox2,
            track_id=12345,
            timestamp=1640995300000
        )

        session.add_detection(detection1)
        session.add_detection(detection2)

        latest = session.get_latest_detection()
        assert latest == detection2  # Should be the one with later timestamp

    def test_get_detections_by_class(self):
        """Test filtering detections by class."""
        session = DetectionSession(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001"
        )

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        car_detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=12345
        )

        truck_detection = DetectionResult(
            class_name="truck",
            confidence=0.80,
            bbox=bbox,
            track_id=54321
        )

        session.add_detection(car_detection)
        session.add_detection(truck_detection)

        car_detections = session.get_detections_by_class("car")
        assert len(car_detections) == 1
        assert car_detections[0] == car_detection

        truck_detections = session.get_detections_by_class("truck")
        assert len(truck_detections) == 1
        assert truck_detections[0] == truck_detection


class TestTrackValidationResult:
    """Test TrackValidationResult data structure."""

    def test_creation(self):
        """Test creating track validation result."""
        result = TrackValidationResult(
            stable_tracks=[101, 102, 103],
            current_tracks=[101, 102, 104, 105],
            newly_stable=[103],
            lost_tracks=[106]
        )

        assert result.stable_tracks == [101, 102, 103]
        assert result.current_tracks == [101, 102, 104, 105]
        assert result.newly_stable == [103]
        assert result.lost_tracks == [106]

    def test_has_stable_tracks(self):
        """Test checking for stable tracks."""
        result = TrackValidationResult(
            stable_tracks=[101, 102],
            current_tracks=[101, 102, 103]
        )

        assert result.has_stable_tracks() is True

        result_empty = TrackValidationResult(
            stable_tracks=[],
            current_tracks=[101, 102, 103]
        )

        assert result_empty.has_stable_tracks() is False

    def test_get_stats(self):
        """Test getting validation statistics."""
        result = TrackValidationResult(
            stable_tracks=[101, 102, 103],
            current_tracks=[101, 102, 104, 105],
            newly_stable=[103],
            lost_tracks=[106]
        )

        stats = result.get_stats()

        assert stats["stable_count"] == 3
        assert stats["current_count"] == 4
        assert stats["newly_stable_count"] == 1
        assert stats["lost_count"] == 1
        assert stats["stability_ratio"] == 3/4  # stable/current

    def test_is_track_stable(self):
        """Test checking if specific track is stable."""
        result = TrackValidationResult(
            stable_tracks=[101, 102, 103],
            current_tracks=[101, 102, 104, 105]
        )

        assert result.is_track_stable(101) is True
        assert result.is_track_stable(102) is True
        assert result.is_track_stable(104) is False
        assert result.is_track_stable(999) is False
701
tests/unit/detection/test_stability_validator.py
Normal file
@ -0,0 +1,701 @@
"""
|
||||
Unit tests for track stability validation.
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import Mock, patch
|
||||
from collections import defaultdict
|
||||
|
||||
from detector_worker.detection.stability_validator import (
|
||||
StabilityValidator,
|
||||
StabilityConfig,
|
||||
ValidationResult,
|
||||
TrackStabilityMetrics
|
||||
)
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox, TrackValidationResult
|
||||
from detector_worker.core.exceptions import ValidationError
|
||||
|
||||
|
||||
class TestStabilityConfig:
|
||||
"""Test stability configuration data structure."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default stability configuration."""
|
||||
config = StabilityConfig()
|
||||
|
||||
assert config.min_detection_frames == 10
|
||||
assert config.max_absence_frames == 30
|
||||
assert config.confidence_threshold == 0.5
|
||||
assert config.stability_window == 60.0
|
||||
assert config.iou_threshold == 0.3
|
||||
assert config.movement_threshold == 50.0
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test custom stability configuration."""
|
||||
config = StabilityConfig(
|
||||
min_detection_frames=5,
|
||||
max_absence_frames=15,
|
||||
confidence_threshold=0.8,
|
||||
stability_window=30.0,
|
||||
iou_threshold=0.5,
|
||||
movement_threshold=25.0
|
||||
)
|
||||
|
||||
assert config.min_detection_frames == 5
|
||||
assert config.max_absence_frames == 15
|
||||
assert config.confidence_threshold == 0.8
|
||||
assert config.stability_window == 30.0
|
||||
assert config.iou_threshold == 0.5
|
||||
assert config.movement_threshold == 25.0
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
config_dict = {
|
||||
"min_detection_frames": 8,
|
||||
"max_absence_frames": 25,
|
||||
"confidence_threshold": 0.75,
|
||||
"unknown_field": "ignored"
|
||||
}
|
||||
|
||||
config = StabilityConfig.from_dict(config_dict)
|
||||
|
||||
assert config.min_detection_frames == 8
|
||||
assert config.max_absence_frames == 25
|
||||
assert config.confidence_threshold == 0.75
|
||||
# Unknown fields should use defaults
|
||||
assert config.stability_window == 60.0
|
||||
|
||||
|
||||
class TestTrackStabilityMetrics:
|
||||
"""Test track stability metrics."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test metrics initialization."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
assert metrics.track_id == 1001
|
||||
assert metrics.detection_count == 0
|
||||
assert metrics.absence_count == 0
|
||||
assert metrics.total_confidence == 0.0
|
||||
assert metrics.first_detection_time is None
|
||||
assert metrics.last_detection_time is None
|
||||
assert metrics.bounding_boxes == []
|
||||
assert metrics.confidence_scores == []
|
||||
|
||||
def test_add_detection(self):
|
||||
"""Test adding detection to metrics."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection, current_time=1640995200.0)
|
||||
|
||||
assert metrics.detection_count == 1
|
||||
assert metrics.absence_count == 0
|
||||
assert metrics.total_confidence == 0.85
|
||||
assert metrics.first_detection_time == 1640995200.0
|
||||
assert metrics.last_detection_time == 1640995200.0
|
||||
assert len(metrics.bounding_boxes) == 1
|
||||
assert len(metrics.confidence_scores) == 1
|
||||
|
||||
def test_increment_absence(self):
|
||||
"""Test incrementing absence count."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
metrics.increment_absence()
|
||||
assert metrics.absence_count == 1
|
||||
|
||||
metrics.increment_absence()
|
||||
assert metrics.absence_count == 2
|
||||
|
||||
def test_reset_absence(self):
|
||||
"""Test resetting absence count."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
metrics.increment_absence()
|
||||
metrics.increment_absence()
|
||||
assert metrics.absence_count == 2
|
||||
|
||||
metrics.reset_absence()
|
||||
assert metrics.absence_count == 0
|
||||
|
||||
def test_average_confidence(self):
|
||||
"""Test average confidence calculation."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
# No detections
|
||||
assert metrics.average_confidence() == 0.0
|
||||
|
||||
# Add detections
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.9,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection1, current_time=1640995200.0)
|
||||
metrics.add_detection(detection2, current_time=1640995300.0)
|
||||
|
||||
assert metrics.average_confidence() == 0.85 # (0.8 + 0.9) / 2
|
||||
|
||||
def test_tracking_duration(self):
|
||||
"""Test tracking duration calculation."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
# No detections
|
||||
assert metrics.tracking_duration() == 0.0
|
||||
|
||||
# Add detections
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.9,
|
||||
bbox=bbox,
|
||||
track_id=1001,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection1, current_time=1640995200.0)
|
||||
metrics.add_detection(detection2, current_time=1640995300.0)
|
||||
|
||||
assert metrics.tracking_duration() == 100.0 # 1640995300 - 1640995200
|
||||
|
||||
def test_movement_distance(self):
|
||||
"""Test movement distance calculation."""
|
||||
metrics = TrackStabilityMetrics(track_id=1001)
|
||||
|
||||
# No movement with single detection
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection1 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox1,
|
||||
track_id=1001,
|
||||
timestamp=1640995200000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection1, current_time=1640995200.0)
|
||||
assert metrics.total_movement_distance() == 0.0
|
||||
|
||||
# Add second detection with movement
|
||||
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
|
||||
detection2 = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.9,
|
||||
bbox=bbox2,
|
||||
track_id=1001,
|
||||
timestamp=1640995300000
|
||||
)
|
||||
|
||||
metrics.add_detection(detection2, current_time=1640995300.0)
|
||||
|
||||
# Distance between centers: (200,300) to (210,310) = sqrt(100+100) ≈ 14.14
|
||||
movement = metrics.total_movement_distance()
|
||||
assert movement == pytest.approx(14.14, rel=1e-2)
|
||||
|
||||
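# --- Illustrative sketch, not part of this commit ---
# total_movement_distance() as tested above sums Euclidean distances between
# consecutive bbox centers; (200, 300) -> (210, 310) gives sqrt(10^2 + 10^2)
# ~= 14.14. Standalone arithmetic over a list of center points:
import math

def _total_movement_sketch(centers):
    return sum(math.dist(centers[i - 1], centers[i])
               for i in range(1, len(centers)))
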
class TestValidationResult:
    """Test validation result data structure."""

    def test_initialization(self):
        """Test validation result initialization."""
        result = ValidationResult(
            track_id=1001,
            is_stable=True,
            detection_count=15,
            absence_count=2,
            average_confidence=0.85,
            tracking_duration=120.0
        )

        assert result.track_id == 1001
        assert result.is_stable is True
        assert result.detection_count == 15
        assert result.absence_count == 2
        assert result.average_confidence == 0.85
        assert result.tracking_duration == 120.0
        assert result.reasons == []

    def test_with_reasons(self):
        """Test validation result with failure reasons."""
        result = ValidationResult(
            track_id=1001,
            is_stable=False,
            detection_count=5,
            absence_count=35,
            average_confidence=0.4,
            tracking_duration=30.0,
            reasons=["Insufficient detection frames", "Too many absences", "Low confidence"]
        )

        assert result.is_stable is False
        assert len(result.reasons) == 3
        assert "Insufficient detection frames" in result.reasons

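# --- Illustrative sketch, not part of this commit ---
# The validator tests below expect validate_track() to apply the configured
# thresholds as independent checks, collecting a reason string per failure.
# The messages are taken from the assertions; the control flow here is an
# assumption about the implementation, not a copy of it.
def _validate_track_sketch(metrics, config):
    reasons = []
    if metrics.detection_count < config.min_detection_frames:
        reasons.append("Insufficient detection frames")
    if metrics.absence_count > config.max_absence_frames:
        reasons.append("Too many absence frames")
    if metrics.average_confidence() < config.confidence_threshold:
        reasons.append("Low average confidence")
    return ValidationResult(
        track_id=metrics.track_id,
        is_stable=not reasons,
        detection_count=metrics.detection_count,
        absence_count=metrics.absence_count,
        average_confidence=metrics.average_confidence(),
        tracking_duration=metrics.tracking_duration(),
        reasons=reasons,
    )
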
class TestStabilityValidator:
    """Test stability validation functionality."""

    def test_initialization_default(self):
        """Test validator initialization with default config."""
        validator = StabilityValidator()

        assert isinstance(validator.config, StabilityConfig)
        assert validator.config.min_detection_frames == 10
        assert len(validator.track_metrics) == 0

    def test_initialization_custom_config(self):
        """Test validator initialization with custom config."""
        config = StabilityConfig(min_detection_frames=5, confidence_threshold=0.8)
        validator = StabilityValidator(config)

        assert validator.config.min_detection_frames == 5
        assert validator.config.confidence_threshold == 0.8

    def test_update_detections_new_track(self):
        """Test updating with new track."""
        validator = StabilityValidator()

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=1001,
            timestamp=1640995200000
        )

        validator.update_detections([detection], current_time=1640995200.0)

        assert 1001 in validator.track_metrics
        metrics = validator.track_metrics[1001]
        assert metrics.detection_count == 1
        assert metrics.absence_count == 0

    def test_update_detections_existing_track(self):
        """Test updating existing track."""
        validator = StabilityValidator()

        # First detection
        bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection1 = DetectionResult(
            class_name="car",
            confidence=0.8,
            bbox=bbox1,
            track_id=1001,
            timestamp=1640995200000
        )

        validator.update_detections([detection1], current_time=1640995200.0)

        # Second detection
        bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
        detection2 = DetectionResult(
            class_name="car",
            confidence=0.9,
            bbox=bbox2,
            track_id=1001,
            timestamp=1640995300000
        )

        validator.update_detections([detection2], current_time=1640995300.0)

        metrics = validator.track_metrics[1001]
        assert metrics.detection_count == 2
        assert metrics.absence_count == 0
        assert metrics.average_confidence() == 0.85

    def test_update_detections_missing_track(self):
        """Test updating when track is missing (increment absence)."""
        validator = StabilityValidator()

        # Add track
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=1001,
            timestamp=1640995200000
        )

        validator.update_detections([detection], current_time=1640995200.0)

        # Update with empty detections
        validator.update_detections([], current_time=1640995300.0)

        metrics = validator.track_metrics[1001]
        assert metrics.detection_count == 1
        assert metrics.absence_count == 1

    def test_validate_track_stable(self):
        """Test validating a stable track."""
        config = StabilityConfig(min_detection_frames=3, max_absence_frames=5)
        validator = StabilityValidator(config)

        # Create track with sufficient detections
        track_id = 1001
        validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
        metrics = validator.track_metrics[track_id]

        # Add sufficient detections
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        for i in range(5):
            detection = DetectionResult(
                class_name="car",
                confidence=0.8,
                bbox=bbox,
                track_id=track_id,
                timestamp=1640995200000 + i * 1000
            )
            metrics.add_detection(detection, current_time=1640995200.0 + i)

        result = validator.validate_track(track_id)

        assert result.is_stable is True
        assert result.detection_count == 5
        assert result.absence_count == 0
        assert len(result.reasons) == 0

    def test_validate_track_insufficient_detections(self):
        """Test validating track with insufficient detections."""
        config = StabilityConfig(min_detection_frames=10, max_absence_frames=5)
        validator = StabilityValidator(config)

        # Create track with insufficient detections
        track_id = 1001
        validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
        metrics = validator.track_metrics[track_id]

        # Add only few detections
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        for i in range(3):
            detection = DetectionResult(
                class_name="car",
                confidence=0.8,
                bbox=bbox,
                track_id=track_id,
                timestamp=1640995200000 + i * 1000
            )
            metrics.add_detection(detection, current_time=1640995200.0 + i)

        result = validator.validate_track(track_id)

        assert result.is_stable is False
        assert "Insufficient detection frames" in result.reasons

    def test_validate_track_too_many_absences(self):
        """Test validating track with too many absences."""
        config = StabilityConfig(min_detection_frames=3, max_absence_frames=2)
        validator = StabilityValidator(config)

        # Create track with too many absences
        track_id = 1001
        validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
        metrics = validator.track_metrics[track_id]

        # Add detections and absences
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        for i in range(5):
            detection = DetectionResult(
                class_name="car",
                confidence=0.8,
                bbox=bbox,
                track_id=track_id,
                timestamp=1640995200000 + i * 1000
            )
            metrics.add_detection(detection, current_time=1640995200.0 + i)

        # Add too many absences
        for _ in range(5):
            metrics.increment_absence()

        result = validator.validate_track(track_id)

        assert result.is_stable is False
        assert "Too many absence frames" in result.reasons

    def test_validate_track_low_confidence(self):
        """Test validating track with low confidence."""
        config = StabilityConfig(
            min_detection_frames=3,
            max_absence_frames=5,
            confidence_threshold=0.8
        )
        validator = StabilityValidator(config)

        # Create track with low confidence
        track_id = 1001
        validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
        metrics = validator.track_metrics[track_id]

        # Add detections with low confidence
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        for i in range(5):
            detection = DetectionResult(
                class_name="car",
                confidence=0.5,  # Below threshold
                bbox=bbox,
                track_id=track_id,
                timestamp=1640995200000 + i * 1000
            )
            metrics.add_detection(detection, current_time=1640995200.0 + i)

        result = validator.validate_track(track_id)

        assert result.is_stable is False
        assert "Low average confidence" in result.reasons

    def test_validate_all_tracks(self):
        """Test validating all tracks."""
        config = StabilityConfig(min_detection_frames=3)
        validator = StabilityValidator(config)

        # Add multiple tracks
        for track_id in [1001, 1002, 1003]:
            validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
            metrics = validator.track_metrics[track_id]

            # Make some tracks stable, others not
            detection_count = 5 if track_id == 1001 else 2
            bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

            for i in range(detection_count):
                detection = DetectionResult(
                    class_name="car",
                    confidence=0.8,
                    bbox=bbox,
                    track_id=track_id,
                    timestamp=1640995200000 + i * 1000
                )
                metrics.add_detection(detection, current_time=1640995200.0 + i)

        results = validator.validate_all_tracks()

        assert len(results) == 3
        assert results[1001].is_stable is True  # 5 detections
        assert results[1002].is_stable is False  # 2 detections
        assert results[1003].is_stable is False  # 2 detections

    def test_get_stable_tracks(self):
        """Test getting stable track IDs."""
        config = StabilityConfig(min_detection_frames=3)
        validator = StabilityValidator(config)

        # Add tracks with different stability
        for track_id, detection_count in [(1001, 5), (1002, 2), (1003, 4)]:
            validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
            metrics = validator.track_metrics[track_id]

            bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
            for i in range(detection_count):
                detection = DetectionResult(
                    class_name="car",
                    confidence=0.8,
                    bbox=bbox,
                    track_id=track_id,
                    timestamp=1640995200000 + i * 1000
                )
                metrics.add_detection(detection, current_time=1640995200.0 + i)

        stable_tracks = validator.get_stable_tracks()

        assert stable_tracks == [1001, 1003]  # 5 and 4 detections respectively
|
||||
|
||||
def test_cleanup_expired_tracks(self):
|
||||
"""Test cleanup of expired tracks."""
|
||||
config = StabilityConfig(stability_window=10.0)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Add tracks with different last detection times
|
||||
current_time = 1640995300.0
|
||||
|
||||
for track_id, last_detection_time in [(1001, current_time - 5), (1002, current_time - 15)]:
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=int(last_detection_time * 1000)
|
||||
)
|
||||
metrics.add_detection(detection, current_time=last_detection_time)
|
||||
|
||||
removed_count = validator.cleanup_expired_tracks(current_time)
|
||||
|
||||
assert removed_count == 1 # 1002 should be removed (15 > 10 seconds)
|
||||
assert 1001 in validator.track_metrics
|
||||
assert 1002 not in validator.track_metrics
|
||||
|
||||
def test_clear_all_tracks(self):
|
||||
"""Test clearing all track metrics."""
|
||||
validator = StabilityValidator()
|
||||
|
||||
# Add some tracks
|
||||
for track_id in [1001, 1002]:
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
|
||||
assert len(validator.track_metrics) == 2
|
||||
|
||||
validator.clear_all_tracks()
|
||||
|
||||
assert len(validator.track_metrics) == 0
|
||||
|
||||
def test_get_validation_summary(self):
|
||||
"""Test getting validation summary statistics."""
|
||||
config = StabilityConfig(min_detection_frames=3)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
# Add tracks with different characteristics
|
||||
track_data = [
|
||||
(1001, 5, True), # Stable
|
||||
(1002, 2, False), # Unstable
|
||||
(1003, 4, True), # Stable
|
||||
(1004, 1, False) # Unstable
|
||||
]
|
||||
|
||||
for track_id, detection_count, _ in track_data:
|
||||
validator.track_metrics[track_id] = TrackStabilityMetrics(track_id)
|
||||
metrics = validator.track_metrics[track_id]
|
||||
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
for i in range(detection_count):
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
metrics.add_detection(detection, current_time=1640995200.0 + i)
|
||||
|
||||
summary = validator.get_validation_summary()
|
||||
|
||||
assert summary["total_tracks"] == 4
|
||||
assert summary["stable_tracks"] == 2
|
||||
assert summary["unstable_tracks"] == 2
|
||||
assert summary["stability_rate"] == 0.5
|
||||
|
||||
|
||||
class TestStabilityValidatorIntegration:
|
||||
"""Integration tests for stability validator."""
|
||||
|
||||
def test_full_tracking_lifecycle(self):
|
||||
"""Test complete tracking lifecycle with stability validation."""
|
||||
config = StabilityConfig(
|
||||
min_detection_frames=3,
|
||||
max_absence_frames=2,
|
||||
confidence_threshold=0.7
|
||||
)
|
||||
validator = StabilityValidator(config)
|
||||
|
||||
track_id = 1001
|
||||
|
||||
# Phase 1: Initial detections (building up)
|
||||
for i in range(5):
|
||||
bbox = BoundingBox(x1=100+i*2, y1=200+i*2, x2=300+i*2, y2=400+i*2)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.8,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995200000 + i * 1000
|
||||
)
|
||||
validator.update_detections([detection], current_time=1640995200.0 + i)
|
||||
|
||||
# Should be stable now
|
||||
result = validator.validate_track(track_id)
|
||||
assert result.is_stable is True
|
||||
|
||||
# Phase 2: Some absences
|
||||
for i in range(2):
|
||||
validator.update_detections([], current_time=1640995205.0 + i)
|
||||
|
||||
# Still stable (within absence threshold)
|
||||
result = validator.validate_track(track_id)
|
||||
assert result.is_stable is True
|
||||
|
||||
# Phase 3: Track reappears
|
||||
bbox = BoundingBox(x1=120, y1=220, x2=320, y2=420)
|
||||
detection = DetectionResult(
|
||||
class_name="car",
|
||||
confidence=0.85,
|
||||
bbox=bbox,
|
||||
track_id=track_id,
|
||||
timestamp=1640995207000
|
||||
)
|
||||
validator.update_detections([detection], current_time=1640995207.0)
|
||||
|
||||
# Should reset absence count and remain stable
|
||||
result = validator.validate_track(track_id)
|
||||
assert result.is_stable is True
|
||||
assert validator.track_metrics[track_id].absence_count == 0
|
||||
|
||||
def test_multi_track_validation(self):
|
||||
"""Test validation with multiple tracks."""
|
||||
validator = StabilityValidator()
|
||||
|
||||
# Simulate multi-track scenario
|
||||
frame_detections = [
|
||||
# Frame 1
|
||||
[
|
||||
DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000),
|
||||
DetectionResult("truck", 0.8, BoundingBox(400, 200, 600, 400), 1002, 1640995200000)
|
||||
],
|
||||
# Frame 2
|
||||
[
|
||||
DetectionResult("car", 0.85, BoundingBox(105, 205, 305, 405), 1001, 1640995201000),
|
||||
DetectionResult("truck", 0.82, BoundingBox(405, 205, 605, 405), 1002, 1640995201000),
|
||||
DetectionResult("car", 0.75, BoundingBox(200, 300, 400, 500), 1003, 1640995201000)
|
||||
],
|
||||
# Frame 3 - track 1002 disappears
|
||||
[
|
||||
DetectionResult("car", 0.88, BoundingBox(110, 210, 310, 410), 1001, 1640995202000),
|
||||
DetectionResult("car", 0.78, BoundingBox(205, 305, 405, 505), 1003, 1640995202000)
|
||||
]
|
||||
]
|
||||
|
||||
# Process frames
|
||||
for i, detections in enumerate(frame_detections):
|
||||
validator.update_detections(detections, current_time=1640995200.0 + i)
|
||||
|
||||
# Get validation results
|
||||
validation_results = validator.validate_all_tracks()
|
||||
|
||||
assert len(validation_results) == 3
|
||||
|
||||
# All tracks should be unstable (insufficient frames)
|
||||
for result in validation_results.values():
|
||||
assert result.is_stable is False
|
||||
assert "Insufficient detection frames" in result.reasons

606
tests/unit/detection/test_tracking_manager.py
Normal file
@@ -0,0 +1,606 @@

"""
|
||||
Unit tests for BoT-SORT tracking management.
|
||||
"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from collections import defaultdict
|
||||
|
||||
from detector_worker.detection.tracking_manager import TrackingManager, TrackInfo
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
from detector_worker.core.exceptions import TrackingError
|
||||
|
||||
|
||||
class TestTrackInfo:
|
||||
"""Test TrackInfo data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test TrackInfo creation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
|
||||
assert track.track_id == 1001
|
||||
assert track.bbox == bbox
|
||||
assert track.confidence == 0.85
|
||||
assert track.class_name == "car"
|
||||
assert track.first_seen == 1640995200.0
|
||||
assert track.last_seen == 1640995300.0
|
||||
assert track.frame_count == 1
|
||||
assert track.absence_count == 0
|
||||
|
||||
def test_update_track(self):
|
||||
"""Test updating track information."""
|
||||
bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox1,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995200.0
|
||||
)
|
||||
|
||||
bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
|
||||
track.update(bbox2, 0.90, 1640995300.0)
|
||||
|
||||
assert track.bbox == bbox2
|
||||
assert track.confidence == 0.90
|
||||
assert track.last_seen == 1640995300.0
|
||||
assert track.frame_count == 2
|
||||
assert track.absence_count == 0
|
||||
|
||||
def test_increment_absence(self):
|
||||
"""Test incrementing absence count."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995200.0
|
||||
)
|
||||
|
||||
track.increment_absence()
|
||||
assert track.absence_count == 1
|
||||
|
||||
track.increment_absence()
|
||||
assert track.absence_count == 2
|
||||
|
||||
def test_age_calculation(self):
|
||||
"""Test track age calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
|
||||
age = track.age(current_time=1640995400.0)
|
||||
assert age == 200.0 # 1640995400 - 1640995200
|
||||
|
||||
def test_time_since_last_seen(self):
|
||||
"""Test time since last seen calculation."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
|
||||
time_since = track.time_since_last_seen(current_time=1640995450.0)
|
||||
assert time_since == 150.0 # 1640995450 - 1640995300
|
||||
|
||||
def test_is_stable(self):
|
||||
"""Test track stability checking."""
|
||||
bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
|
||||
track = TrackInfo(
|
||||
track_id=1001,
|
||||
bbox=bbox,
|
||||
confidence=0.85,
|
||||
class_name="car",
|
||||
first_seen=1640995200.0,
|
||||
last_seen=1640995300.0
|
||||
)
|
||||
|
||||
# Not stable initially
|
||||
assert track.is_stable(min_frames=5, max_absence=3) is False
|
||||
|
||||
# Make it stable
|
||||
track.frame_count = 10
|
||||
track.absence_count = 1
|
||||
assert track.is_stable(min_frames=5, max_absence=3) is True
|
||||
|
||||
# Too many absences
|
||||
track.absence_count = 5
|
||||
assert track.is_stable(min_frames=5, max_absence=3) is False
|
||||
|
||||
|
||||
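
# A minimal sketch of the bookkeeping the TrackInfo tests above rely on
# (hypothetical; names mirror the test code, not the real implementation):
# frame_count starts at 1, update() advances the box and clears the absence
# streak, and stability requires enough frames with few enough misses.
class _TrackInfoSketch:
    def __init__(self, bbox, confidence, first_seen):
        self.bbox = bbox
        self.confidence = confidence
        self.first_seen = first_seen
        self.last_seen = first_seen
        self.frame_count = 1
        self.absence_count = 0

    def update(self, bbox, confidence, timestamp):
        self.bbox = bbox
        self.confidence = confidence
        self.last_seen = timestamp
        self.frame_count += 1
        self.absence_count = 0  # A fresh detection clears the absence streak

    def increment_absence(self):
        self.absence_count += 1

    def is_stable(self, min_frames, max_absence):
        return self.frame_count >= min_frames and self.absence_count <= max_absence
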

class TestTrackingManager:
    """Test tracking management functionality."""

    def test_initialization(self):
        """Test tracking manager initialization."""
        manager = TrackingManager()

        assert manager.max_absence_frames == 30
        assert manager.min_stable_frames == 10
        assert manager.track_timeout == 60.0
        assert len(manager.active_tracks) == 0
        assert len(manager.stable_tracks) == 0

    def test_initialization_with_config(self):
        """Test initialization with custom configuration."""
        config = {
            "max_absence_frames": 20,
            "min_stable_frames": 5,
            "track_timeout": 30.0
        }
        manager = TrackingManager(config)

        assert manager.max_absence_frames == 20
        assert manager.min_stable_frames == 5
        assert manager.track_timeout == 30.0

    def test_update_tracks_new_detections(self):
        """Test updating with new detections."""
        manager = TrackingManager()

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=1001,
            timestamp=1640995200000
        )

        manager.update_tracks([detection], current_time=1640995200.0)

        assert len(manager.active_tracks) == 1
        assert 1001 in manager.active_tracks

        track = manager.active_tracks[1001]
        assert track.track_id == 1001
        assert track.class_name == "car"
        assert track.confidence == 0.85
        assert track.frame_count == 1

    def test_update_tracks_existing_detection(self):
        """Test updating an existing track."""
        manager = TrackingManager()

        # First detection
        bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection1 = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox1,
            track_id=1001,
            timestamp=1640995200000
        )

        manager.update_tracks([detection1], current_time=1640995200.0)

        # Second detection (same track, different position)
        bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
        detection2 = DetectionResult(
            class_name="car",
            confidence=0.90,
            bbox=bbox2,
            track_id=1001,
            timestamp=1640995300000
        )

        manager.update_tracks([detection2], current_time=1640995300.0)

        assert len(manager.active_tracks) == 1
        track = manager.active_tracks[1001]
        assert track.frame_count == 2
        assert track.confidence == 0.90
        assert track.bbox == bbox2
        assert track.absence_count == 0

    def test_update_tracks_no_detections(self):
        """Test updating with no detections (increment absence)."""
        manager = TrackingManager()

        # Add initial track
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=1001,
            timestamp=1640995200000
        )

        manager.update_tracks([detection], current_time=1640995200.0)

        # Update with no detections
        manager.update_tracks([], current_time=1640995300.0)

        track = manager.active_tracks[1001]
        assert track.absence_count == 1

    def test_cleanup_expired_tracks(self):
        """Test cleanup of expired tracks."""
        manager = TrackingManager({"track_timeout": 10.0})

        # Add track
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=1001,
            timestamp=1640995200000
        )

        manager.update_tracks([detection], current_time=1640995200.0)
        assert len(manager.active_tracks) == 1

        # Cleanup after timeout
        removed_count = manager.cleanup_expired_tracks(current_time=1640995220.0)  # 20 seconds later

        assert removed_count == 1
        assert len(manager.active_tracks) == 0

    def test_cleanup_absent_tracks(self):
        """Test cleanup of tracks with too many absences."""
        manager = TrackingManager({"max_absence_frames": 3})

        # Add track
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=1001,
            timestamp=1640995200000
        )

        manager.update_tracks([detection], current_time=1640995200.0)

        # Increment absence count beyond the threshold
        for i in range(5):
            manager.update_tracks([], current_time=1640995200.0 + i)

        track = manager.active_tracks[1001]
        assert track.absence_count == 5

        # Cleanup absent tracks
        removed_count = manager.cleanup_absent_tracks()

        assert removed_count == 1
        assert len(manager.active_tracks) == 0

    def test_get_stable_tracks(self):
        """Test getting stable tracks."""
        manager = TrackingManager({"min_stable_frames": 3})

        # Add a track and make it stable
        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        track_info = TrackInfo(
            track_id=1001,
            bbox=bbox,
            confidence=0.85,
            class_name="car",
            first_seen=1640995200.0,
            last_seen=1640995300.0
        )
        track_info.frame_count = 5  # Make it stable

        manager.active_tracks[1001] = track_info

        stable_tracks = manager.get_stable_tracks()

        assert len(stable_tracks) == 1
        assert 1001 in stable_tracks
        assert 1001 in manager.stable_tracks  # Should be cached

    def test_get_track_by_id(self):
        """Test getting track by ID."""
        manager = TrackingManager()

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=1001,
            timestamp=1640995200000
        )

        manager.update_tracks([detection], current_time=1640995200.0)

        track = manager.get_track_by_id(1001)
        assert track is not None
        assert track.track_id == 1001

        non_existent = manager.get_track_by_id(9999)
        assert non_existent is None

    def test_get_tracks_by_class(self):
        """Test getting tracks by class name."""
        manager = TrackingManager()

        # Add different classes
        bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection1 = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox1,
            track_id=1001,
            timestamp=1640995200000
        )

        bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
        detection2 = DetectionResult(
            class_name="truck",
            confidence=0.80,
            bbox=bbox2,
            track_id=1002,
            timestamp=1640995200000
        )

        bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500)
        detection3 = DetectionResult(
            class_name="car",
            confidence=0.90,
            bbox=bbox3,
            track_id=1003,
            timestamp=1640995200000
        )

        manager.update_tracks([detection1, detection2, detection3], current_time=1640995200.0)

        car_tracks = manager.get_tracks_by_class("car")
        assert len(car_tracks) == 2
        assert 1001 in car_tracks
        assert 1003 in car_tracks

        truck_tracks = manager.get_tracks_by_class("truck")
        assert len(truck_tracks) == 1
        assert 1002 in truck_tracks

    def test_get_track_count(self):
        """Test getting track counts."""
        manager = TrackingManager()

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=1001,
            timestamp=1640995200000
        )

        manager.update_tracks([detection], current_time=1640995200.0)

        assert manager.get_active_track_count() == 1
        assert manager.get_track_count_by_class("car") == 1
        assert manager.get_track_count_by_class("truck") == 0

    def test_clear_all_tracks(self):
        """Test clearing all tracks."""
        manager = TrackingManager()

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=1001,
            timestamp=1640995200000
        )

        manager.update_tracks([detection], current_time=1640995200.0)
        assert len(manager.active_tracks) == 1

        manager.clear_all_tracks()

        assert len(manager.active_tracks) == 0
        assert len(manager.stable_tracks) == 0

    def test_get_track_statistics(self):
        """Test getting track statistics."""
        manager = TrackingManager({"min_stable_frames": 2})

        # Add multiple tracks
        detections = []
        for i in range(3):
            bbox = BoundingBox(x1=100 + i * 50, y1=200, x2=300 + i * 50, y2=400)
            detection = DetectionResult(
                class_name="car",
                confidence=0.85,
                bbox=bbox,
                track_id=1001 + i,
                timestamp=1640995200000
            )
            detections.append(detection)

        manager.update_tracks(detections, current_time=1640995200.0)

        # Make some tracks stable
        manager.active_tracks[1001].frame_count = 5
        manager.active_tracks[1002].frame_count = 3
        # 1003 remains unstable with frame_count=1

        stats = manager.get_track_statistics()

        assert stats["active_tracks"] == 3
        assert stats["stable_tracks"] == 2
        assert stats["unstable_tracks"] == 1
        assert "average_track_age" in stats
        assert "average_confidence" in stats

    def test_validate_tracks(self):
        """Test track validation."""
        manager = TrackingManager({"min_stable_frames": 3, "max_absence_frames": 2})

        # Add tracks with different stability
        bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        track1 = TrackInfo(
            track_id=1001,
            bbox=bbox1,
            confidence=0.85,
            class_name="car",
            first_seen=1640995200.0,
            last_seen=1640995300.0
        )
        track1.frame_count = 5    # Stable
        track1.absence_count = 1  # Present

        bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
        track2 = TrackInfo(
            track_id=1002,
            bbox=bbox2,
            confidence=0.80,
            class_name="car",
            first_seen=1640995200.0,
            last_seen=1640995250.0
        )
        track2.frame_count = 2    # Not stable
        track2.absence_count = 1

        bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500)
        track3 = TrackInfo(
            track_id=1003,
            bbox=bbox3,
            confidence=0.90,
            class_name="car",
            first_seen=1640995100.0,
            last_seen=1640995150.0
        )
        track3.frame_count = 8    # Was stable but is now absent
        track3.absence_count = 5  # Too many absences

        manager.active_tracks = {1001: track1, 1002: track2, 1003: track3}
        manager.stable_tracks = {1001, 1003}  # 1003 was previously stable

        validation_result = manager.validate_tracks()

        assert validation_result.stable_tracks == [1001]
        assert validation_result.current_tracks == [1001, 1002, 1003]
        assert validation_result.newly_stable == []
        assert validation_result.lost_tracks == [1003]

    def test_track_persistence_across_frames(self):
        """Test track persistence across multiple frames."""
        manager = TrackingManager()

        # Frame 1
        bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection1 = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox1,
            track_id=1001,
            timestamp=1640995200000
        )

        manager.update_tracks([detection1], current_time=1640995200.0)

        # Frame 2 - track moves
        bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410)
        detection2 = DetectionResult(
            class_name="car",
            confidence=0.88,
            bbox=bbox2,
            track_id=1001,
            timestamp=1640995300000
        )

        manager.update_tracks([detection2], current_time=1640995300.0)

        # Frame 3 - track disappears
        manager.update_tracks([], current_time=1640995400.0)

        # Frame 4 - track reappears
        bbox4 = BoundingBox(x1=120, y1=220, x2=320, y2=420)
        detection4 = DetectionResult(
            class_name="car",
            confidence=0.82,
            bbox=bbox4,
            track_id=1001,
            timestamp=1640995500000
        )

        manager.update_tracks([detection4], current_time=1640995500.0)

        track = manager.active_tracks[1001]
        assert track.frame_count == 3    # Seen in 3 frames
        assert track.absence_count == 0  # Reset when the track reappeared
        assert track.bbox == bbox4       # Latest position


class TestTrackingManagerErrorHandling:
    """Test error handling in the tracking manager."""

    def test_invalid_detection_input(self):
        """Test handling of invalid detection input."""
        manager = TrackingManager()

        # A None detection should be rejected with a TrackingError
        with pytest.raises(TrackingError):
            manager.update_tracks([None], current_time=1640995200.0)

    def test_negative_track_id(self):
        """Test handling of a negative track ID."""
        manager = TrackingManager()

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox,
            track_id=-1,  # Invalid track ID
            timestamp=1640995200000
        )

        with pytest.raises(TrackingError):
            manager.update_tracks([detection], current_time=1640995200.0)

    def test_duplicate_track_ids_different_classes(self):
        """Test handling of duplicate track IDs with different classes."""
        manager = TrackingManager()

        bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400)
        detection1 = DetectionResult(
            class_name="car",
            confidence=0.85,
            bbox=bbox1,
            track_id=1001,
            timestamp=1640995200000
        )

        bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450)
        detection2 = DetectionResult(
            class_name="truck",  # Different class, same ID
            confidence=0.80,
            bbox=bbox2,
            track_id=1001,
            timestamp=1640995200000
        )

        # Should log a warning but handle gracefully
        manager.update_tracks([detection1, detection2], current_time=1640995200.0)

        # The later detection should update the track
        track = manager.active_tracks[1001]
        assert track.class_name == "truck"  # Last update wins
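
# test_validate_tracks above pins down the classification: a track is
# "stable" when it meets the frame/absence thresholds, and "lost" when it
# was previously stable but has exceeded the absence budget. A hedged
# sketch of that pass (assumed shape, not the manager's real code):
def _classify_tracks(active_tracks, previously_stable, min_frames, max_absence):
    stable, lost = [], []
    for track_id, track in active_tracks.items():
        if track.frame_count >= min_frames and track.absence_count <= max_absence:
            stable.append(track_id)
        elif track_id in previously_stable:
            lost.append(track_id)  # Fell out of stability, e.g. too many absences
    return stable, lost
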

386
tests/unit/detection/test_yolo_detector.py
Normal file
@@ -0,0 +1,386 @@

"""
|
||||
Unit tests for YOLO detector with tracking functionality.
|
||||
"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
import torch
|
||||
|
||||
from detector_worker.detection.yolo_detector import YOLODetector
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
from detector_worker.core.exceptions import DetectionError
|
||||
|
||||
|
||||
class TestYOLODetector:
|
||||
"""Test YOLO detection and tracking functionality."""
|
||||
|
||||
def test_initialization_with_valid_model(self, mock_yolo_model):
|
||||
"""Test detector initialization with valid model."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
assert detector.model is mock_yolo_model
|
||||
assert detector.class_names == {}
|
||||
assert detector.is_tracking_enabled is True
|
||||
|
||||
def test_initialization_with_class_names(self, mock_yolo_model):
|
||||
"""Test detector initialization with class names."""
|
||||
class_names = {0: "car", 1: "truck", 2: "bus"}
|
||||
detector = YOLODetector(mock_yolo_model, class_names=class_names)
|
||||
|
||||
assert detector.class_names == class_names
|
||||
|
||||
def test_initialization_tracking_disabled(self, mock_yolo_model):
|
||||
"""Test detector initialization with tracking disabled."""
|
||||
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
|
||||
|
||||
assert detector.is_tracking_enabled is False
|
||||
|
||||
def test_detect_with_tracking(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with tracking enabled."""
|
||||
# Mock detection result
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0], # x1, y1, x2, y2, conf, class
|
||||
[150, 250, 350, 450, 0.85, 1]
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001, 1002])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert len(detections) == 2
|
||||
assert detections[0].confidence == 0.9
|
||||
assert detections[0].track_id == 1001
|
||||
assert detections[0].bbox.x1 == 100
|
||||
|
||||
mock_yolo_model.track.assert_called_once_with(mock_frame, persist=True, verbose=False)
|
||||
|
||||
def test_detect_without_tracking(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with tracking disabled."""
|
||||
# Mock detection result
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0]
|
||||
])
|
||||
mock_result.boxes.id = None # No tracking IDs
|
||||
|
||||
mock_yolo_model.predict.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert len(detections) == 1
|
||||
assert detections[0].track_id is None # No tracking ID
|
||||
|
||||
mock_yolo_model.predict.assert_called_once_with(mock_frame, verbose=False)
|
||||
|
||||
def test_detect_with_class_names(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with class name mapping."""
|
||||
class_names = {0: "car", 1: "truck"}
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0], # car
|
||||
[150, 250, 350, 450, 0.85, 1] # truck
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001, 1002])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model, class_names=class_names)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert detections[0].class_name == "car"
|
||||
assert detections[1].class_name == "truck"
|
||||
|
||||
def test_detect_no_boxes(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection when no objects are detected."""
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = None
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert detections == []
|
||||
|
||||
def test_detect_empty_boxes(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with empty boxes tensor."""
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([]).reshape(0, 6)
|
||||
mock_result.boxes.id = None
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
assert detections == []
|
||||
|
||||
def test_detect_with_confidence_threshold(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with confidence threshold filtering."""
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[100, 200, 300, 400, 0.9, 0], # Above threshold
|
||||
[150, 250, 350, 450, 0.3, 1] # Below threshold
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001, 1002])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame, confidence_threshold=0.5)
|
||||
|
||||
assert len(detections) == 1 # Only one above threshold
|
||||
assert detections[0].confidence == 0.9
|
||||
|
||||
def test_detect_model_error_handling(self, mock_yolo_model, mock_frame):
|
||||
"""Test error handling when model fails."""
|
||||
mock_yolo_model.track.side_effect = Exception("Model inference failed")
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
with pytest.raises(DetectionError) as exc_info:
|
||||
detector.detect(mock_frame)
|
||||
|
||||
assert "Model inference failed" in str(exc_info.value)
|
||||
|
||||
def test_detect_invalid_frame(self, mock_yolo_model):
|
||||
"""Test detection with invalid frame input."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
with pytest.raises(DetectionError) as exc_info:
|
||||
detector.detect(None)
|
||||
|
||||
assert "Invalid frame" in str(exc_info.value)
|
||||
|
||||
def test_detect_result_validation(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection result validation."""
|
||||
# Mock result with invalid bounding box (x2 <= x1)
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[300, 200, 100, 400, 0.9, 0] # Invalid: x2 < x1
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
detections = detector.detect(mock_frame)
|
||||
|
||||
# Invalid detections should be filtered out
|
||||
assert detections == []
|
||||
|
||||
def test_get_model_info(self, mock_yolo_model):
|
||||
"""Test getting model information."""
|
||||
mock_yolo_model.device = "cuda:0"
|
||||
mock_yolo_model.names = {0: "car", 1: "truck"}
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
info = detector.get_model_info()
|
||||
|
||||
assert info["device"] == "cuda:0"
|
||||
assert info["class_names"] == {0: "car", 1: "truck"}
|
||||
assert info["tracking_enabled"] is True
|
||||
|
||||
def test_set_tracking_enabled(self, mock_yolo_model):
|
||||
"""Test enabling/disabling tracking at runtime."""
|
||||
detector = YOLODetector(mock_yolo_model, enable_tracking=False)
|
||||
assert detector.is_tracking_enabled is False
|
||||
|
||||
detector.set_tracking_enabled(True)
|
||||
assert detector.is_tracking_enabled is True
|
||||
|
||||
detector.set_tracking_enabled(False)
|
||||
assert detector.is_tracking_enabled is False
|
||||
|
||||
def test_update_class_names(self, mock_yolo_model):
|
||||
"""Test updating class names at runtime."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
new_class_names = {0: "vehicle", 1: "person"}
|
||||
detector.update_class_names(new_class_names)
|
||||
|
||||
assert detector.class_names == new_class_names
|
||||
|
||||
def test_reset_tracker(self, mock_yolo_model):
|
||||
"""Test resetting the tracking state."""
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
|
||||
# This should not raise an error
|
||||
detector.reset_tracker()
|
||||
|
||||
def test_detect_with_crop_region(self, mock_yolo_model, mock_frame):
|
||||
"""Test detection with crop region specified."""
|
||||
mock_result = Mock()
|
||||
mock_result.boxes = Mock()
|
||||
mock_result.boxes.data = torch.tensor([
|
||||
[50, 75, 150, 175, 0.9, 0] # Relative to cropped region
|
||||
])
|
||||
mock_result.boxes.id = torch.tensor([1001])
|
||||
|
||||
mock_yolo_model.track.return_value = [mock_result]
|
||||
|
||||
detector = YOLODetector(mock_yolo_model)
|
||||
crop_region = (100, 200, 300, 400) # x1, y1, x2, y2
|
||||
detections = detector.detect(mock_frame, crop_region=crop_region)
|
||||
|
||||
# Bounding box should be adjusted to global coordinates
|
||||
assert detections[0].bbox.x1 == 150 # 100 + 50
|
||||
assert detections[0].bbox.y1 == 275 # 200 + 75
|
||||
assert detections[0].bbox.x2 == 250 # 100 + 150
|
||||
assert detections[0].bbox.y2 == 375 # 200 + 175
|
||||
|
||||
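
    # The crop-region test above fixes the coordinate contract: boxes come
    # back relative to the crop and must be shifted by the crop origin. A
    # sketch of that mapping (the helper name is ours, not the detector's API):
    #
    #     def to_global_coords(box, crop_region):
    #         cx1, cy1, _, _ = crop_region
    #         x1, y1, x2, y2 = box
    #         return (x1 + cx1, y1 + cy1, x2 + cx1, y2 + cy1)
    #
    #     to_global_coords((50, 75, 150, 175), (100, 200, 300, 400))
    #     # -> (150, 275, 250, 375), matching the asserts above
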
    def test_detect_batch_processing(self, mock_yolo_model):
        """Test batch detection processing."""
        frames = [
            np.zeros((480, 640, 3), dtype=np.uint8),
            np.ones((480, 640, 3), dtype=np.uint8) * 255
        ]

        mock_results = []
        for i in range(2):
            mock_result = Mock()
            mock_result.boxes = Mock()
            mock_result.boxes.data = torch.tensor([
                [100 + i * 10, 200, 300, 400, 0.9, 0]
            ])
            mock_result.boxes.id = torch.tensor([1001 + i])
            mock_results.append(mock_result)

        mock_yolo_model.track.side_effect = [[result] for result in mock_results]

        detector = YOLODetector(mock_yolo_model)
        batch_detections = detector.detect_batch(frames)

        assert len(batch_detections) == 2
        assert len(batch_detections[0]) == 1
        assert len(batch_detections[1]) == 1
        assert batch_detections[0][0].bbox.x1 == 100
        assert batch_detections[1][0].bbox.x1 == 110

    def test_detect_batch_empty_frames(self, mock_yolo_model):
        """Test batch detection with empty frame list."""
        detector = YOLODetector(mock_yolo_model)
        batch_detections = detector.detect_batch([])

        assert batch_detections == []

    def test_detect_performance_metrics(self, mock_yolo_model, mock_frame):
        """Test detection performance metrics collection."""
        mock_result = Mock()
        mock_result.boxes = Mock()
        mock_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]
        ])
        mock_result.boxes.id = torch.tensor([1001])
        mock_result.speed = {"preprocess": 2.1, "inference": 15.3, "postprocess": 1.2}

        mock_yolo_model.track.return_value = [mock_result]

        detector = YOLODetector(mock_yolo_model)
        detections = detector.detect(mock_frame, return_metrics=True)

        # Check if performance metrics are available
        assert hasattr(detector, '_last_inference_time')

    @pytest.mark.parametrize("device", ["cpu", "cuda:0", "mps"])
    def test_detect_different_devices(self, device, mock_frame):
        """Test detection on different devices."""
        mock_model = Mock()
        mock_model.device = device

        mock_result = Mock()
        mock_result.boxes = Mock()
        mock_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]
        ])
        mock_result.boxes.id = torch.tensor([1001])

        mock_model.track.return_value = [mock_result]

        detector = YOLODetector(mock_model)
        detections = detector.detect(mock_frame)

        assert len(detections) == 1
        assert detections[0].confidence == 0.9


class TestYOLODetectorIntegration:
    """Integration tests for YOLO detector."""

    def test_detect_with_real_tensor_operations(self, mock_yolo_model, mock_frame):
        """Test detection with realistic tensor operations."""
        # Create more realistic box data
        boxes_data = torch.tensor([
            [100.5, 200.3, 299.7, 399.8, 0.95, 0],
            [150.2, 250.1, 349.9, 449.6, 0.87, 1],
            [200.0, 300.0, 400.0, 500.0, 0.45, 0]  # Low confidence
        ])

        mock_result = Mock()
        mock_result.boxes = Mock()
        mock_result.boxes.data = boxes_data
        mock_result.boxes.id = torch.tensor([2001, 2002, 2003])

        mock_yolo_model.track.return_value = [mock_result]

        class_names = {0: "car", 1: "truck"}
        detector = YOLODetector(mock_yolo_model, class_names=class_names)

        detections = detector.detect(mock_frame, confidence_threshold=0.5)

        # Should filter out the low-confidence detection
        assert len(detections) == 2

        # Check first detection
        det1 = detections[0]
        assert det1.class_name == "car"
        assert det1.confidence == pytest.approx(0.95)
        assert det1.track_id == 2001
        assert det1.bbox.x1 == pytest.approx(100.5)
        assert det1.bbox.y1 == pytest.approx(200.3)

        # Check second detection
        det2 = detections[1]
        assert det2.class_name == "truck"
        assert det2.confidence == pytest.approx(0.87)
        assert det2.track_id == 2002

    def test_multi_frame_tracking_consistency(self, mock_yolo_model, mock_frame):
        """Test that tracking IDs remain consistent across frames."""
        detector = YOLODetector(mock_yolo_model)

        # Frame 1
        mock_result1 = Mock()
        mock_result1.boxes = Mock()
        mock_result1.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]
        ])
        mock_result1.boxes.id = torch.tensor([5001])

        mock_yolo_model.track.return_value = [mock_result1]
        detections1 = detector.detect(mock_frame)

        # Frame 2 - same object, slightly moved
        mock_result2 = Mock()
        mock_result2.boxes = Mock()
        mock_result2.boxes.data = torch.tensor([
            [105, 205, 305, 405, 0.88, 0]
        ])
        mock_result2.boxes.id = torch.tensor([5001])  # Same ID

        mock_yolo_model.track.return_value = [mock_result2]
        detections2 = detector.detect(mock_frame)

        # Should maintain the same track ID
        assert detections1[0].track_id == detections2[0].track_id == 5001
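
# Throughout these tests, result.boxes.data rows are treated as
# [x1, y1, x2, y2, conf, cls] with boxes.id carrying track IDs
# (ultralytics-style results). A minimal decode sketch under that
# assumption (illustrative only, not the detector's implementation):
def _decode_rows(data, ids, class_names, conf_threshold=0.0):
    detections = []
    for i, row in enumerate(data.tolist()):
        x1, y1, x2, y2, conf, cls = row
        if conf < conf_threshold or x2 <= x1 or y2 <= y1:
            continue  # Drop low-confidence and degenerate boxes
        detections.append({
            "bbox": (x1, y1, x2, y2),
            "confidence": conf,
            "class_name": class_names.get(int(cls), str(int(cls))),
            "track_id": int(ids[i]) if ids is not None else None,
        })
    return detections
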

882
tests/unit/models/test_model_manager.py
Normal file
@@ -0,0 +1,882 @@

"""
|
||||
Unit tests for model management functionality.
|
||||
"""
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from detector_worker.models.model_manager import (
|
||||
ModelManager,
|
||||
ModelInfo,
|
||||
ModelConfig,
|
||||
ModelCache,
|
||||
ModelLoader,
|
||||
ModelError,
|
||||
ModelLoadError,
|
||||
ModelCacheError
|
||||
)
|
||||
from detector_worker.core.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestModelConfig:
|
||||
"""Test model configuration."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test model config creation."""
|
||||
config = ModelConfig(
|
||||
model_id="yolo_v8_car",
|
||||
model_path="/models/yolo_v8_car.pt",
|
||||
model_type="detection",
|
||||
device="cuda:0"
|
||||
)
|
||||
|
||||
assert config.model_id == "yolo_v8_car"
|
||||
assert config.model_path == "/models/yolo_v8_car.pt"
|
||||
assert config.model_type == "detection"
|
||||
assert config.device == "cuda:0"
|
||||
assert config.confidence_threshold == 0.5
|
||||
assert config.max_memory_mb == 1024
|
||||
|
||||
def test_creation_with_optional_params(self):
|
||||
"""Test config creation with optional parameters."""
|
||||
config = ModelConfig(
|
||||
model_id="classifier_v1",
|
||||
model_path="/models/classifier.pt",
|
||||
model_type="classification",
|
||||
device="cpu",
|
||||
confidence_threshold=0.8,
|
||||
max_memory_mb=512,
|
||||
class_names={0: "car", 1: "truck", 2: "bus"},
|
||||
preprocessing_config={"resize": (224, 224), "normalize": True}
|
||||
)
|
||||
|
||||
assert config.confidence_threshold == 0.8
|
||||
assert config.max_memory_mb == 512
|
||||
assert config.class_names[0] == "car"
|
||||
assert config.preprocessing_config["resize"] == (224, 224)
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
config_dict = {
|
||||
"model_id": "detection_model",
|
||||
"model_path": "/path/to/model.pt",
|
||||
"model_type": "detection",
|
||||
"device": "cuda:0",
|
||||
"confidence_threshold": 0.75,
|
||||
"class_names": {0: "person", 1: "vehicle"},
|
||||
"unknown_field": "ignored"
|
||||
}
|
||||
|
||||
config = ModelConfig.from_dict(config_dict)
|
||||
|
||||
assert config.model_id == "detection_model"
|
||||
assert config.confidence_threshold == 0.75
|
||||
assert config.class_names[1] == "vehicle"
|
||||
|
||||
def test_validation(self):
|
||||
"""Test config validation."""
|
||||
# Valid config
|
||||
valid_config = ModelConfig(
|
||||
model_id="test_model",
|
||||
model_path="/valid/path/model.pt",
|
||||
model_type="detection",
|
||||
device="cpu"
|
||||
)
|
||||
assert valid_config.is_valid() is True
|
||||
|
||||
# Invalid config (empty model_id)
|
||||
invalid_config = ModelConfig(
|
||||
model_id="",
|
||||
model_path="/path/model.pt",
|
||||
model_type="detection",
|
||||
device="cpu"
|
||||
)
|
||||
assert invalid_config.is_valid() is False
|
||||
|
||||
def test_get_memory_limit_bytes(self):
|
||||
"""Test getting memory limit in bytes."""
|
||||
config = ModelConfig(
|
||||
model_id="test",
|
||||
model_path="/path",
|
||||
model_type="detection",
|
||||
device="cpu",
|
||||
max_memory_mb=256
|
||||
)
|
||||
|
||||
assert config.get_memory_limit_bytes() == 256 * 1024 * 1024
|
||||
|
||||
|
||||
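
# test_from_dict above requires unknown keys ("unknown_field") to be
# tolerated. One common way to get that behaviour is filtering the input
# against the dataclass's declared fields; a sketch of that idea (assumed,
# not the actual ModelConfig.from_dict):
import dataclasses


def _filtered_kwargs(cls, raw):
    """Keep only keys that match declared dataclass fields of cls."""
    allowed = {f.name for f in dataclasses.fields(cls)}
    return {k: v for k, v in raw.items() if k in allowed}
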

class TestModelInfo:
    """Test model information."""

    def test_creation(self):
        """Test model info creation."""
        config = ModelConfig(
            model_id="test_model",
            model_path="/path/model.pt",
            model_type="detection",
            device="cuda:0"
        )

        mock_model = Mock()

        info = ModelInfo(
            config=config,
            model_instance=mock_model,
            load_time=1.5
        )

        assert info.config == config
        assert info.model_instance == mock_model
        assert info.load_time == 1.5
        assert info.reference_count == 0
        assert info.last_used <= time.time()
        assert info.memory_usage == 0

    def test_increment_reference(self):
        """Test incrementing reference count."""
        config = ModelConfig("test", "/path", "detection", "cpu")
        info = ModelInfo(config, Mock(), 1.0)

        assert info.reference_count == 0

        info.increment_reference()
        assert info.reference_count == 1

        info.increment_reference()
        assert info.reference_count == 2

    def test_decrement_reference(self):
        """Test decrementing reference count."""
        config = ModelConfig("test", "/path", "detection", "cpu")
        info = ModelInfo(config, Mock(), 1.0)
        info.reference_count = 3

        assert info.decrement_reference() == 2
        assert info.reference_count == 2

        assert info.decrement_reference() == 1
        assert info.decrement_reference() == 0

        # Should not go below 0
        assert info.decrement_reference() == 0

    def test_update_usage(self):
        """Test updating usage statistics."""
        config = ModelConfig("test", "/path", "detection", "cpu")
        info = ModelInfo(config, Mock(), 1.0)

        original_time = info.last_used
        original_count = info.usage_count

        time.sleep(0.01)  # Small delay
        info.update_usage(memory_usage=512 * 1024 * 1024)  # 512MB

        assert info.last_used > original_time
        assert info.usage_count == original_count + 1
        assert info.memory_usage == 512 * 1024 * 1024

    def test_age_calculation(self):
        """Test age calculation."""
        config = ModelConfig("test", "/path", "detection", "cpu")
        info = ModelInfo(config, Mock(), 1.0)

        time.sleep(0.01)
        age = info.age()

        assert age > 0
        assert age < 1  # Should be less than 1 second

    def test_get_stats(self):
        """Test getting model statistics."""
        config = ModelConfig("test_model", "/path", "detection", "cuda:0")
        info = ModelInfo(config, Mock(), 2.5)

        info.reference_count = 3
        info.usage_count = 100
        info.memory_usage = 1024 * 1024 * 1024  # 1GB

        stats = info.get_stats()

        assert stats["model_id"] == "test_model"
        assert stats["device"] == "cuda:0"
        assert stats["load_time"] == 2.5
        assert stats["reference_count"] == 3
        assert stats["usage_count"] == 100
        assert stats["memory_usage_mb"] == 1024
        assert "age_seconds" in stats


class TestModelLoader:
    """Test model loading functionality."""

    def test_creation(self):
        """Test model loader creation."""
        loader = ModelLoader()

        assert loader.supported_formats == [".pt", ".pth", ".onnx", ".trt"]
        assert loader.default_device == "cpu"

    def test_detect_device_cuda_available(self):
        """Test device detection when CUDA is available."""
        loader = ModelLoader()

        with patch('torch.cuda.is_available', return_value=True):
            with patch('torch.cuda.device_count', return_value=2):
                device = loader.detect_optimal_device()

                assert device == "cuda:0"

    def test_detect_device_cuda_unavailable(self):
        """Test device detection when CUDA is not available."""
        loader = ModelLoader()

        with patch('torch.cuda.is_available', return_value=False):
            device = loader.detect_optimal_device()

            assert device == "cpu"

    def test_load_pytorch_model(self):
        """Test loading PyTorch model."""
        loader = ModelLoader()

        with patch('torch.load') as mock_torch_load:
            with patch('os.path.exists', return_value=True):
                mock_model = Mock()
                mock_torch_load.return_value = mock_model

                config = ModelConfig(
                    model_id="test_model",
                    model_path="/path/to/model.pt",
                    model_type="detection",
                    device="cpu"
                )

                loaded_model = loader.load_model(config)

                assert loaded_model == mock_model
                mock_torch_load.assert_called_once_with("/path/to/model.pt", map_location="cpu")

    def test_load_model_file_not_exists(self):
        """Test loading model when the file doesn't exist."""
        loader = ModelLoader()

        with patch('os.path.exists', return_value=False):
            config = ModelConfig(
                model_id="missing_model",
                model_path="/nonexistent/model.pt",
                model_type="detection",
                device="cpu"
            )

            with pytest.raises(ModelLoadError) as exc_info:
                loader.load_model(config)

            assert "does not exist" in str(exc_info.value)

    def test_load_model_invalid_format(self):
        """Test loading model with invalid format."""
        loader = ModelLoader()

        with patch('os.path.exists', return_value=True):
            config = ModelConfig(
                model_id="invalid_model",
                model_path="/path/to/model.invalid",
                model_type="detection",
                device="cpu"
            )

            with pytest.raises(ModelLoadError) as exc_info:
                loader.load_model(config)

            assert "unsupported format" in str(exc_info.value).lower()

    def test_load_model_torch_error(self):
        """Test loading model with torch loading error."""
        loader = ModelLoader()

        with patch('os.path.exists', return_value=True):
            with patch('torch.load', side_effect=RuntimeError("CUDA out of memory")):
                config = ModelConfig(
                    model_id="error_model",
                    model_path="/path/to/model.pt",
                    model_type="detection",
                    device="cuda:0"
                )

                with pytest.raises(ModelLoadError) as exc_info:
                    loader.load_model(config)

                assert "CUDA out of memory" in str(exc_info.value)

    def test_validate_model_pytorch(self):
        """Test validating PyTorch model."""
        loader = ModelLoader()

        mock_model = Mock()
        mock_model.__class__.__module__ = "torch.nn"

        config = ModelConfig("test", "/path", "detection", "cpu")

        is_valid = loader.validate_model(mock_model, config)

        assert is_valid is True

    def test_validate_model_invalid(self):
        """Test validating invalid model."""
        loader = ModelLoader()

        invalid_model = "not_a_model"
        config = ModelConfig("test", "/path", "detection", "cpu")

        is_valid = loader.validate_model(invalid_model, config)

        assert is_valid is False

    def test_estimate_model_memory(self):
        """Test estimating model memory usage."""
        loader = ModelLoader()

        mock_model = Mock()
        mock_param1 = Mock()
        mock_param1.numel.return_value = 1000000  # 1M parameters
        mock_param1.element_size.return_value = 4  # 4 bytes per parameter

        mock_param2 = Mock()
        mock_param2.numel.return_value = 500000  # 0.5M parameters
        mock_param2.element_size.return_value = 4

        mock_model.parameters.return_value = [mock_param1, mock_param2]

        memory_bytes = loader.estimate_memory_usage(mock_model)

        expected_bytes = (1000000 + 500000) * 4  # 6MB
        assert memory_bytes == expected_bytes
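
# The memory estimate above is just parameter count times element size,
# summed over the model's parameters. A one-function sketch of that
# calculation (assumed, not necessarily the project's exact implementation):
def _estimate_param_bytes(model):
    """Sum of numel * element_size over all parameters (buffers excluded)."""
    return sum(p.numel() * p.element_size() for p in model.parameters())
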

class TestModelCache:
    """Test model caching functionality."""

    def test_creation(self):
        """Test model cache creation."""
        cache = ModelCache(max_size=5, max_memory_mb=2048)

        assert cache.max_size == 5
        assert cache.max_memory_mb == 2048
        assert len(cache.models) == 0
        assert len(cache.access_order) == 0

    def test_put_and_get_model(self):
        """Test putting and getting model from cache."""
        cache = ModelCache(max_size=3)

        config = ModelConfig("test_model", "/path", "detection", "cpu")
        mock_model = Mock()
        model_info = ModelInfo(config, mock_model, 1.5)

        cache.put("test_model", model_info)

        retrieved_info = cache.get("test_model")

        assert retrieved_info == model_info
        assert retrieved_info.reference_count == 1  # Should be incremented on get

    def test_get_nonexistent_model(self):
        """Test getting non-existent model."""
        cache = ModelCache(max_size=3)

        result = cache.get("nonexistent_model")

        assert result is None

    def test_contains_check(self):
        """Test checking if model exists in cache."""
        cache = ModelCache(max_size=3)

        config = ModelConfig("test_model", "/path", "detection", "cpu")
        model_info = ModelInfo(config, Mock(), 1.0)
        cache.put("test_model", model_info)

        assert cache.contains("test_model") is True
        assert cache.contains("nonexistent_model") is False

    def test_remove_model(self):
        """Test removing model from cache."""
        cache = ModelCache(max_size=3)

        config = ModelConfig("test_model", "/path", "detection", "cpu")
        model_info = ModelInfo(config, Mock(), 1.0)
        cache.put("test_model", model_info)

        assert cache.contains("test_model") is True

        removed_info = cache.remove("test_model")

        assert removed_info == model_info
        assert cache.contains("test_model") is False

    def test_lru_eviction(self):
        """Test LRU eviction policy."""
        cache = ModelCache(max_size=2)

        # Add models to fill cache
        for i in range(2):
            config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
            model_info = ModelInfo(config, Mock(), 1.0)
            cache.put(f"model_{i}", model_info)

        # Access model_0 to make it recently used
        cache.get("model_0")

        # Add another model (should evict model_1, the least recently used)
        config = ModelConfig("model_2", "/path_2", "detection", "cpu")
        model_info = ModelInfo(config, Mock(), 1.0)
        cache.put("model_2", model_info)

        assert cache.size() == 2
        assert cache.contains("model_0") is True   # Recently accessed
        assert cache.contains("model_1") is False  # Evicted
        assert cache.contains("model_2") is True   # Newly added
def test_memory_based_eviction(self):
|
||||
"""Test memory-based eviction."""
|
||||
cache = ModelCache(max_size=10, max_memory_mb=1) # 1MB limit
|
||||
|
||||
# Add model that uses 0.8MB
|
||||
config1 = ModelConfig("model_1", "/path_1", "detection", "cpu")
|
||||
model1 = Mock()
|
||||
info1 = ModelInfo(config1, model1, 1.0)
|
||||
info1.memory_usage = 0.8 * 1024 * 1024 # 0.8MB
|
||||
cache.put("model_1", info1)
|
||||
|
||||
# Add model that would exceed memory limit
|
||||
config2 = ModelConfig("model_2", "/path_2", "detection", "cpu")
|
||||
model2 = Mock()
|
||||
info2 = ModelInfo(config2, model2, 1.0)
|
||||
info2.memory_usage = 0.5 * 1024 * 1024 # 0.5MB
|
||||
cache.put("model_2", info2)
|
||||
|
||||
# First model should be evicted due to memory constraint
|
||||
assert cache.contains("model_1") is False
|
||||
assert cache.contains("model_2") is True
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test getting cache statistics."""
|
||||
cache = ModelCache(max_size=5)
|
||||
|
||||
# Add some models
|
||||
for i in range(3):
|
||||
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
|
||||
model_info = ModelInfo(config, Mock(), 1.0)
|
||||
model_info.memory_usage = 100 * 1024 * 1024 # 100MB each
|
||||
cache.put(f"model_{i}", model_info)
|
||||
|
||||
# Access some models
|
||||
cache.get("model_0")
|
||||
cache.get("model_1")
|
||||
cache.get("nonexistent") # Miss
|
||||
|
||||
stats = cache.get_stats()
|
||||
|
||||
assert stats["size"] == 3
|
||||
assert stats["max_size"] == 5
|
||||
assert stats["hits"] == 2
|
||||
assert stats["misses"] == 1
|
||||
assert stats["hit_rate"] == 2/3
|
||||
assert stats["memory_usage_mb"] == 300
|
||||
|
||||
def test_clear_cache(self):
|
||||
"""Test clearing entire cache."""
|
||||
cache = ModelCache(max_size=5)
|
||||
|
||||
# Add models
|
||||
for i in range(3):
|
||||
config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
|
||||
model_info = ModelInfo(config, Mock(), 1.0)
|
||||
cache.put(f"model_{i}", model_info)
|
||||
|
||||
assert cache.size() == 3
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert cache.size() == 0
|
||||
assert len(cache.models) == 0
|
||||
assert len(cache.access_order) == 0
|
||||
|
||||
|
||||
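
# A minimal sketch of the LRU behavior the TestModelCache assertions above pin
# down: reference counting on get(), an access-order list, and eviction by both
# entry count and memory budget. Illustrative only; the names mirror the tests,
# but the ModelCache implementation actually under test may differ in detail.
class _LRUModelCacheSketch:
    def __init__(self, max_size=10, max_memory_mb=1024):
        self.max_size = max_size
        self.max_memory_mb = max_memory_mb
        self.models = {}        # model_id -> ModelInfo
        self.access_order = []  # least recently used first
        self.hits = 0
        self.misses = 0

    def get(self, model_id):
        info = self.models.get(model_id)
        if info is None:
            self.misses += 1
            return None
        self.hits += 1
        info.reference_count += 1
        self.access_order.remove(model_id)  # move to most-recently-used end
        self.access_order.append(model_id)
        return info

    def put(self, model_id, info):
        self.models[model_id] = info
        if model_id in self.access_order:
            self.access_order.remove(model_id)
        self.access_order.append(model_id)
        budget = self.max_memory_mb * 1024 * 1024
        # Evict from the LRU end while over either the size or memory budget.
        while (len(self.models) > self.max_size
               or sum(getattr(i, "memory_usage", 0) for i in self.models.values()) > budget):
            lru_id = self.access_order[0]
            if lru_id == model_id:  # never evict the entry just inserted
                break
            self.access_order.pop(0)
            del self.models[lru_id]
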
class TestModelManager:
    """Test main model manager functionality."""

    def test_initialization(self):
        """Test model manager initialization."""
        manager = ModelManager()

        assert isinstance(manager.cache, ModelCache)
        assert isinstance(manager.loader, ModelLoader)
        assert manager.models_directory == "models"
        assert manager.default_device == "cpu"

    def test_initialization_with_config(self):
        """Test initialization with custom configuration."""
        config = {
            "models_directory": "/custom/models",
            "default_device": "cuda:0",
            "cache_max_size": 20,
            "cache_max_memory_mb": 4096
        }

        manager = ModelManager(config)

        assert manager.models_directory == "/custom/models"
        assert manager.default_device == "cuda:0"
        assert manager.cache.max_size == 20
        assert manager.cache.max_memory_mb == 4096

    def test_load_model_new(self):
        """Test loading new model."""
        manager = ModelManager()

        config = ModelConfig(
            model_id="test_model",
            model_path="/path/to/model.pt",
            model_type="detection",
            device="cpu"
        )

        with patch.object(manager.loader, 'load_model') as mock_load:
            with patch.object(manager.loader, 'estimate_memory_usage', return_value=512*1024*1024):
                mock_model = Mock()
                mock_load.return_value = mock_model

                loaded_model = manager.load_model(config)

                assert loaded_model == mock_model
                assert manager.cache.contains("test_model") is True
                mock_load.assert_called_once_with(config)

    def test_load_model_from_cache(self):
        """Test loading model from cache."""
        manager = ModelManager()

        # Pre-populate cache
        config = ModelConfig("cached_model", "/path", "detection", "cpu")
        mock_model = Mock()
        model_info = ModelInfo(config, mock_model, 1.0)
        manager.cache.put("cached_model", model_info)

        with patch.object(manager.loader, 'load_model') as mock_load:
            loaded_model = manager.load_model(config)

            assert loaded_model == mock_model
            mock_load.assert_not_called()  # Should not load from disk

    def test_get_model_by_id(self):
        """Test getting model by ID."""
        manager = ModelManager()

        config = ModelConfig("test_model", "/path", "detection", "cpu")
        mock_model = Mock()
        model_info = ModelInfo(config, mock_model, 1.0)
        manager.cache.put("test_model", model_info)

        retrieved_model = manager.get_model("test_model")

        assert retrieved_model == mock_model

    def test_get_nonexistent_model(self):
        """Test getting non-existent model."""
        manager = ModelManager()

        model = manager.get_model("nonexistent_model")

        assert model is None

    def test_unload_model_with_references(self):
        """Test unloading model with active references."""
        manager = ModelManager()

        config = ModelConfig("ref_model", "/path", "detection", "cpu")
        mock_model = Mock()
        model_info = ModelInfo(config, mock_model, 1.0)
        model_info.reference_count = 2  # Active references
        manager.cache.put("ref_model", model_info)

        result = manager.unload_model("ref_model")

        assert result is False  # Should not unload with active references
        assert manager.cache.contains("ref_model") is True

    def test_unload_model_no_references(self):
        """Test unloading model without references."""
        manager = ModelManager()

        config = ModelConfig("no_ref_model", "/path", "detection", "cpu")
        mock_model = Mock()
        model_info = ModelInfo(config, mock_model, 1.0)
        model_info.reference_count = 0  # No references
        manager.cache.put("no_ref_model", model_info)

        result = manager.unload_model("no_ref_model")

        assert result is True
        assert manager.cache.contains("no_ref_model") is False

    def test_list_loaded_models(self):
        """Test listing loaded models."""
        manager = ModelManager()

        # Add models to cache
        for i in range(3):
            config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
            model_info = ModelInfo(config, Mock(), 1.0)
            manager.cache.put(f"model_{i}", model_info)

        loaded_models = manager.list_loaded_models()

        assert len(loaded_models) == 3
        assert all(info["model_id"].startswith("model_") for info in loaded_models)

    def test_get_model_info(self):
        """Test getting model information."""
        manager = ModelManager()

        config = ModelConfig("info_model", "/path", "detection", "cuda:0")
        mock_model = Mock()
        model_info = ModelInfo(config, mock_model, 2.5)
        model_info.usage_count = 10
        manager.cache.put("info_model", model_info)

        info = manager.get_model_info("info_model")

        assert info is not None
        assert info["model_id"] == "info_model"
        assert info["device"] == "cuda:0"
        assert info["load_time"] == 2.5
        assert info["usage_count"] == 10

    def test_cleanup_unused_models(self):
        """Test cleaning up unused models."""
        manager = ModelManager()

        # Add models with different reference counts
        models_data = [
            ("used_model", 2),      # Has references
            ("unused_model_1", 0),  # No references
            ("unused_model_2", 0)   # No references
        ]

        for model_id, ref_count in models_data:
            config = ModelConfig(model_id, f"/path/{model_id}", "detection", "cpu")
            model_info = ModelInfo(config, Mock(), 1.0)
            model_info.reference_count = ref_count
            manager.cache.put(model_id, model_info)

        cleaned_count = manager.cleanup_unused_models()

        assert cleaned_count == 2  # Two unused models cleaned
        assert manager.cache.contains("used_model") is True
        assert manager.cache.contains("unused_model_1") is False
        assert manager.cache.contains("unused_model_2") is False

    def test_get_memory_usage(self):
        """Test getting total memory usage."""
        manager = ModelManager()

        # Add models with different memory usage
        memory_sizes = [256, 512, 1024]  # MB

        for i, memory_mb in enumerate(memory_sizes):
            config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
            model_info = ModelInfo(config, Mock(), 1.0)
            model_info.memory_usage = memory_mb * 1024 * 1024  # Convert to bytes
            manager.cache.put(f"model_{i}", model_info)

        total_usage = manager.get_memory_usage()

        expected_bytes = sum(memory_sizes) * 1024 * 1024
        assert total_usage == expected_bytes

    def test_health_check(self):
        """Test model manager health check."""
        manager = ModelManager()

        # Add models
        for i in range(3):
            config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu")
            model_info = ModelInfo(config, Mock(), 1.0)
            model_info.memory_usage = 100 * 1024 * 1024  # 100MB each
            manager.cache.put(f"model_{i}", model_info)

        health_report = manager.health_check()

        assert health_report["status"] == "healthy"
        assert health_report["loaded_models"] == 3
        assert health_report["total_memory_mb"] == 300
        assert health_report["cache_hit_rate"] >= 0


class TestModelManagerIntegration:
    """Integration tests for model manager."""

    def test_concurrent_model_loading(self):
        """Test concurrent model loading."""
        import threading
        import time

        manager = ModelManager()

        # Mock loader to simulate loading time
        def slow_load(config):
            time.sleep(0.1)  # Simulate loading time
            mock_model = Mock()
            mock_model.model_id = config.model_id
            return mock_model

        with patch.object(manager.loader, 'load_model', side_effect=slow_load):
            with patch.object(manager.loader, 'estimate_memory_usage', return_value=100*1024*1024):

                # Create multiple threads loading different models
                results = {}
                errors = []

                def load_model_thread(model_id):
                    try:
                        config = ModelConfig(
                            model_id=model_id,
                            model_path=f"/path/{model_id}.pt",
                            model_type="detection",
                            device="cpu"
                        )
                        model = manager.load_model(config)
                        results[model_id] = model
                    except Exception as e:
                        errors.append((model_id, str(e)))

                threads = []
                for i in range(5):
                    thread = threading.Thread(target=load_model_thread, args=(f"model_{i}",))
                    threads.append(thread)
                    thread.start()

                for thread in threads:
                    thread.join()

                # All models should be loaded successfully
                assert len(errors) == 0
                assert len(results) == 5
                assert len(manager.cache.models) == 5

    def test_memory_pressure_handling(self):
        """Test handling memory pressure."""
        # Create manager with small memory limit
        manager = ModelManager({
            "cache_max_memory_mb": 200  # 200MB limit
        })

        with patch.object(manager.loader, 'load_model') as mock_load:
            with patch.object(manager.loader, 'estimate_memory_usage', return_value=100*1024*1024):  # 100MB per model

                def create_mock_model(config):
                    mock_model = Mock()
                    mock_model.model_id = config.model_id
                    return mock_model

                mock_load.side_effect = create_mock_model

                # Load models that exceed memory limit
                for i in range(4):  # 4 * 100MB = 400MB > 200MB limit
                    config = ModelConfig(
                        model_id=f"large_model_{i}",
                        model_path=f"/path/large_model_{i}.pt",
                        model_type="detection",
                        device="cpu"
                    )
                    manager.load_model(config)

                # Should not exceed memory limit due to eviction
                total_memory = manager.get_memory_usage()
                memory_limit = 200 * 1024 * 1024
                assert total_memory <= memory_limit

    def test_model_lifecycle_management(self):
        """Test complete model lifecycle."""
        manager = ModelManager()

        with patch.object(manager.loader, 'load_model') as mock_load:
            with patch.object(manager.loader, 'estimate_memory_usage', return_value=50*1024*1024):

                mock_model = Mock()
                mock_load.return_value = mock_model

                config = ModelConfig(
                    model_id="lifecycle_model",
                    model_path="/path/lifecycle_model.pt",
                    model_type="detection",
                    device="cpu"
                )

                # 1. Load model
                loaded_model = manager.load_model(config)
                assert loaded_model == mock_model
                assert manager.cache.contains("lifecycle_model") is True

                # 2. Get model multiple times (increase usage)
                for _ in range(5):
                    model = manager.get_model("lifecycle_model")
                    assert model == mock_model

                # 3. Check model info
                info = manager.get_model_info("lifecycle_model")
                assert info["usage_count"] >= 5

                # 4. Simulate model still in use
                model_info = manager.cache.get("lifecycle_model")
                model_info.reference_count = 1

                # Should not unload while in use
                unloaded = manager.unload_model("lifecycle_model")
                assert unloaded is False
                assert manager.cache.contains("lifecycle_model") is True

                # 5. Release reference and unload
                model_info.reference_count = 0
                unloaded = manager.unload_model("lifecycle_model")
                assert unloaded is True
                assert manager.cache.contains("lifecycle_model") is False

    def test_error_recovery(self):
        """Test error recovery scenarios."""
        manager = ModelManager()

        # Test loading model that fails initially then succeeds
        call_count = 0

        def failing_then_success_load(config):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise ModelLoadError("First attempt failed")
            return Mock()

        with patch.object(manager.loader, 'load_model', side_effect=failing_then_success_load):
            with patch.object(manager.loader, 'estimate_memory_usage', return_value=50*1024*1024):

                config = ModelConfig(
                    model_id="retry_model",
                    model_path="/path/retry_model.pt",
                    model_type="detection",
                    device="cpu"
                )

                # First attempt should fail
                with pytest.raises(ModelLoadError):
                    manager.load_model(config)

                # Model should not be in cache
                assert manager.cache.contains("retry_model") is False

                # Second attempt should succeed
                model = manager.load_model(config)
                assert model is not None
                assert manager.cache.contains("retry_model") is True
959
tests/unit/pipeline/test_action_executor.py
Normal file
@@ -0,0 +1,959 @@
"""
Unit tests for action execution functionality.
"""
import pytest
import asyncio
import json
import base64
import numpy as np
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from datetime import datetime, timedelta

from detector_worker.pipeline.action_executor import (
    ActionExecutor,
    ActionResult,
    ActionType,
    RedisAction,
    PostgreSQLAction,
    FileAction
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
from detector_worker.core.exceptions import ActionError, RedisError, DatabaseError


class TestActionResult:
    """Test action execution result."""

    def test_creation_success(self):
        """Test successful action result creation."""
        result = ActionResult(
            action_type=ActionType.REDIS_SAVE,
            success=True,
            execution_time=0.05,
            metadata={"key": "saved_image_key", "expiry": 600}
        )

        assert result.action_type == ActionType.REDIS_SAVE
        assert result.success is True
        assert result.execution_time == 0.05
        assert result.metadata["key"] == "saved_image_key"
        assert result.error is None

    def test_creation_failure(self):
        """Test failed action result creation."""
        result = ActionResult(
            action_type=ActionType.POSTGRESQL_INSERT,
            success=False,
            error="Database connection failed",
            execution_time=0.02
        )

        assert result.action_type == ActionType.POSTGRESQL_INSERT
        assert result.success is False
        assert result.error == "Database connection failed"
        assert result.metadata == {}


class TestRedisAction:
    """Test Redis action implementations."""

    def test_creation(self):
        """Test Redis action creation."""
        action_config = {
            "type": "redis_save_image",
            "region": "car",
            "key": "inference:{display_id}:{timestamp}:{session_id}",
            "expire_seconds": 600
        }

        action = RedisAction(action_config)

        assert action.action_type == ActionType.REDIS_SAVE
        assert action.region == "car"
        assert action.key_template == "inference:{display_id}:{timestamp}:{session_id}"
        assert action.expire_seconds == 600

    def test_resolve_key_template(self):
        """Test key template resolution."""
        action_config = {
            "type": "redis_save_image",
            "region": "car",
            "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
            "expire_seconds": 600
        }

        action = RedisAction(action_config)

        context = {
            "display_id": "display_001",
            "timestamp": "1640995200000",
            "session_id": "session_123",
            "filename": "detection.jpg"
        }

        resolved_key = action.resolve_key(context)
        expected_key = "inference:display_001:1640995200000:session_123:detection.jpg"

        assert resolved_key == expected_key

    def test_resolve_key_missing_variable(self):
        """Test key resolution with missing variable."""
        action_config = {
            "type": "redis_save_image",
            "region": "car",
            "key": "inference:{display_id}:{missing_var}",
            "expire_seconds": 600
        }

        action = RedisAction(action_config)

        context = {"display_id": "display_001"}

        with pytest.raises(ActionError):
            action.resolve_key(context)


class TestPostgreSQLAction:
    """Test PostgreSQL action implementations."""

    def test_creation_insert(self):
        """Test PostgreSQL insert action creation."""
        action_config = {
            "type": "postgresql_insert",
            "table": "detections",
            "fields": {
                "camera_id": "{camera_id}",
                "session_id": "{session_id}",
                "detection_class": "{class}",
                "confidence": "{confidence}",
                "bbox_x1": "{bbox.x1}",
                "created_at": "NOW()"
            }
        }

        action = PostgreSQLAction(action_config)

        assert action.action_type == ActionType.POSTGRESQL_INSERT
        assert action.table == "detections"
        assert len(action.fields) == 6
        assert action.key_field is None

    def test_creation_update(self):
        """Test PostgreSQL update action creation."""
        action_config = {
            "type": "postgresql_update_combined",
            "table": "car_info",
            "key_field": "session_id",
            "fields": {
                "car_brand": "{car_brand_cls.brand}",
                "car_body_type": "{car_bodytype_cls.body_type}",
                "updated_at": "NOW()"
            },
            "waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
        }

        action = PostgreSQLAction(action_config)

        assert action.action_type == ActionType.POSTGRESQL_UPDATE
        assert action.table == "car_info"
        assert action.key_field == "session_id"
        assert action.wait_for_branches == ["car_brand_cls", "car_bodytype_cls"]

    def test_resolve_field_values(self):
        """Test field value resolution."""
        action_config = {
            "type": "postgresql_insert",
            "table": "detections",
            "fields": {
                "camera_id": "{camera_id}",
                "detection_class": "{class}",
                "confidence": "{confidence}",
                "brand": "{car_brand_cls.brand}"
            }
        }

        action = PostgreSQLAction(action_config)

        context = {
            "camera_id": "camera_001",
            "class": "car",
            "confidence": 0.85
        }

        branch_results = {
            "car_brand_cls": {"brand": "Toyota", "confidence": 0.78}
        }

        resolved_fields = action.resolve_field_values(context, branch_results)

        assert resolved_fields["camera_id"] == "camera_001"
        assert resolved_fields["detection_class"] == "car"
        assert resolved_fields["confidence"] == 0.85
        assert resolved_fields["brand"] == "Toyota"


class TestFileAction:
    """Test file action implementations."""

    def test_creation(self):
        """Test file action creation."""
        action_config = {
            "type": "save_image",
            "path": "/tmp/detections/{camera_id}_{timestamp}.jpg",
            "region": "car",
            "format": "jpeg",
            "quality": 85
        }

        action = FileAction(action_config)

        assert action.action_type == ActionType.SAVE_IMAGE
        assert action.path_template == "/tmp/detections/{camera_id}_{timestamp}.jpg"
        assert action.region == "car"
        assert action.format == "jpeg"
        assert action.quality == 85

    def test_resolve_path_template(self):
        """Test path template resolution."""
        action_config = {
            "type": "save_image",
            "path": "/tmp/detections/{camera_id}/{date}/{timestamp}.jpg"
        }

        action = FileAction(action_config)

        context = {
            "camera_id": "camera_001",
            "timestamp": "1640995200000",
            "date": "2022-01-01"
        }

        resolved_path = action.resolve_path(context)
        expected_path = "/tmp/detections/camera_001/2022-01-01/1640995200000.jpg"

        assert resolved_path == expected_path

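
# The key- and path-template behavior asserted above reduces to simple
# "{placeholder}" substitution that fails loudly on unknown names. A minimal
# standalone sketch of that contract (illustrative only; RedisAction.resolve_key
# and FileAction.resolve_path themselves may be implemented differently):
import re

def _resolve_template_sketch(template: str, context: dict) -> str:
    def _sub(match):
        name = match.group(1)
        if name not in context:
            # The real actions wrap this in ActionError, per the tests above.
            raise KeyError(f"template variable '{name}' missing from context")
        return str(context[name])
    return re.sub(r"\{([^{}]+)\}", _sub, template)

# _resolve_template_sketch("inference:{display_id}:{session_id}",
#                          {"display_id": "display_001", "session_id": "session_123"})
# -> "inference:display_001:session_123"
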
class TestActionExecutor:
    """Test action execution functionality."""

    def test_initialization(self):
        """Test action executor initialization."""
        executor = ActionExecutor()

        assert executor.redis_client is None
        assert executor.db_manager is None
        assert executor.max_concurrent_actions == 10
        assert executor.action_timeout == 30.0

    def test_initialization_with_clients(self, mock_redis_client, mock_database_connection):
        """Test initialization with client instances."""
        executor = ActionExecutor(
            redis_client=mock_redis_client,
            db_manager=mock_database_connection
        )

        assert executor.redis_client is mock_redis_client
        assert executor.db_manager is mock_database_connection

    @pytest.mark.asyncio
    async def test_execute_actions_empty_list(self):
        """Test executing empty action list."""
        executor = ActionExecutor()

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123"
        }

        results = await executor.execute_actions([], {}, context)

        assert results == []

    @pytest.mark.asyncio
    async def test_execute_redis_save_action(self, mock_redis_client, mock_frame):
        """Test executing Redis save image action."""
        executor = ActionExecutor(redis_client=mock_redis_client)

        actions = [
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:{camera_id}:{session_id}",
                "expire_seconds": 600
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123",
            "frame_data": mock_frame
        }

        # Mock successful Redis operations
        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.REDIS_SAVE

        # Verify Redis calls
        mock_redis_client.set.assert_called_once()
        mock_redis_client.expire.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_postgresql_insert_action(self, mock_database_connection):
        """Test executing PostgreSQL insert action."""
        # Mock database manager
        mock_db_manager = Mock()
        mock_db_manager.execute_query = AsyncMock(return_value=True)

        executor = ActionExecutor(db_manager=mock_db_manager)

        actions = [
            {
                "type": "postgresql_insert",
                "table": "detections",
                "fields": {
                    "camera_id": "{camera_id}",
                    "session_id": "{session_id}",
                    "detection_class": "{class}",
                    "confidence": "{confidence}",
                    "created_at": "NOW()"
                }
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123",
            "class": "car",
            "confidence": 0.9
        }

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.POSTGRESQL_INSERT

        # Verify database call
        mock_db_manager.execute_query.assert_called_once()
        call_args = mock_db_manager.execute_query.call_args[0]
        assert "INSERT INTO detections" in call_args[0]

    @pytest.mark.asyncio
    async def test_execute_postgresql_update_action(self, mock_database_connection):
        """Test executing PostgreSQL update action."""
        mock_db_manager = Mock()
        mock_db_manager.execute_query = AsyncMock(return_value=True)

        executor = ActionExecutor(db_manager=mock_db_manager)

        actions = [
            {
                "type": "postgresql_update_combined",
                "table": "car_info",
                "key_field": "session_id",
                "fields": {
                    "car_brand": "{car_brand_cls.brand}",
                    "car_body_type": "{car_bodytype_cls.body_type}",
                    "updated_at": "NOW()"
                },
                "waitForBranches": ["car_brand_cls", "car_bodytype_cls"]
            }
        ]

        regions = {}

        context = {
            "session_id": "session_123"
        }

        branch_results = {
            "car_brand_cls": {"brand": "Toyota"},
            "car_bodytype_cls": {"body_type": "Sedan"}
        }

        results = await executor.execute_actions(actions, regions, context, branch_results)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.POSTGRESQL_UPDATE

        # Verify database call
        mock_db_manager.execute_query.assert_called_once()
        call_args = mock_db_manager.execute_query.call_args[0]
        assert "UPDATE car_info SET" in call_args[0]
        assert "WHERE session_id" in call_args[0]

    @pytest.mark.asyncio
    async def test_execute_file_save_action(self, mock_frame):
        """Test executing file save action."""
        executor = ActionExecutor()

        actions = [
            {
                "type": "save_image",
                "path": "/tmp/test_{camera_id}_{timestamp}.jpg",
                "region": "car",
                "format": "jpeg",
                "quality": 85
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "timestamp": "1640995200000",
            "frame_data": mock_frame
        }

        with patch('cv2.imwrite') as mock_imwrite:
            mock_imwrite.return_value = True

            results = await executor.execute_actions(actions, regions, context)

            assert len(results) == 1
            assert results[0].success is True
            assert results[0].action_type == ActionType.SAVE_IMAGE

            # Verify file save call
            mock_imwrite.assert_called_once()
            call_args = mock_imwrite.call_args
            assert "/tmp/test_camera_001_1640995200000.jpg" in call_args[0][0]

    @pytest.mark.asyncio
    async def test_execute_actions_parallel(self, mock_redis_client):
        """Test parallel execution of multiple actions."""
        executor = ActionExecutor(redis_client=mock_redis_client)

        # Multiple Redis actions
        actions = [
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:car:{session_id}",
                "expire_seconds": 600
            },
            {
                "type": "redis_publish",
                "channel": "detections",
                "message": "{camera_id}:car_detected"
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123",
            "frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
        }

        # Mock Redis operations
        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True
        mock_redis_client.publish.return_value = 1

        import time
        start_time = time.time()

        results = await executor.execute_actions(actions, regions, context)

        execution_time = time.time() - start_time

        assert len(results) == 2
        assert all(result.success for result in results)

        # Should execute in parallel (faster than sequential)
        assert execution_time < 0.1  # Allow some overhead

    @pytest.mark.asyncio
    async def test_execute_actions_error_handling(self, mock_redis_client):
        """Test error handling in action execution."""
        executor = ActionExecutor(redis_client=mock_redis_client)

        actions = [
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:{session_id}",
                "expire_seconds": 600
            },
            {
                "type": "redis_save_image",  # This one will fail
                "region": "truck",  # Region not detected
                "key": "inference:truck:{session_id}",
                "expire_seconds": 600
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
            # No truck region
        }

        context = {
            "session_id": "session_123",
            "frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
        }

        # Mock Redis operations
        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 2
        assert results[0].success is True   # Car action succeeds
        assert results[1].success is False  # Truck action fails
        assert "Region 'truck' not found" in results[1].error

    @pytest.mark.asyncio
    async def test_execute_actions_timeout(self, mock_redis_client):
        """Test action execution timeout."""
        config = {"action_timeout": 0.001}  # Very short timeout
        executor = ActionExecutor(redis_client=mock_redis_client, config=config)

        def slow_redis_operation(*args, **kwargs):
            import time
            time.sleep(1)  # Longer than timeout
            return True

        mock_redis_client.set.side_effect = slow_redis_operation

        actions = [
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:{session_id}",
                "expire_seconds": 600
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "session_id": "session_123",
            "frame_data": np.zeros((480, 640, 3), dtype=np.uint8)
        }

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 1
        assert results[0].success is False
        assert "timeout" in results[0].error.lower()

    @pytest.mark.asyncio
    async def test_execute_redis_publish_action(self, mock_redis_client):
        """Test executing Redis publish action."""
        executor = ActionExecutor(redis_client=mock_redis_client)

        actions = [
            {
                "type": "redis_publish",
                "channel": "detections:{camera_id}",
                "message": {
                    "camera_id": "{camera_id}",
                    "detection_class": "{class}",
                    "confidence": "{confidence}",
                    "timestamp": "{timestamp}"
                }
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "class": "car",
            "confidence": 0.9,
            "timestamp": "1640995200000"
        }

        mock_redis_client.publish.return_value = 1

        results = await executor.execute_actions(actions, regions, context)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.REDIS_PUBLISH

        # Verify publish call
        mock_redis_client.publish.assert_called_once()
        call_args = mock_redis_client.publish.call_args
        assert call_args[0][0] == "detections:camera_001"  # Channel

        # Message should be JSON
        message = call_args[0][1]
        parsed_message = json.loads(message)
        assert parsed_message["camera_id"] == "camera_001"
        assert parsed_message["detection_class"] == "car"

    @pytest.mark.asyncio
    async def test_execute_conditional_action(self):
        """Test executing conditional actions."""
        executor = ActionExecutor()

        actions = [
            {
                "type": "conditional",
                "condition": "{confidence} > 0.8",
                "actions": [
                    {
                        "type": "log",
                        "message": "High confidence detection: {class} ({confidence})"
                    }
                ]
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.95,  # High confidence
                "detection": DetectionResult("car", 0.95, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "class": "car",
            "confidence": 0.95
        }

        with patch('logging.info') as mock_log:
            results = await executor.execute_actions(actions, regions, context)

            assert len(results) == 1
            assert results[0].success is True

            # Should have logged the message
            mock_log.assert_called_once()
            log_message = mock_log.call_args[0][0]
            assert "High confidence detection: car (0.95)" in log_message

    def test_crop_region_from_frame(self, mock_frame):
        """Test cropping region from frame."""
        executor = ActionExecutor()

        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)

        cropped = executor._crop_region_from_frame(mock_frame, detection.bbox)

        assert cropped.shape == (200, 200, 3)  # 400-200, 300-100

    def test_encode_image_base64(self, mock_frame):
        """Test encoding image to base64."""
        executor = ActionExecutor()

        # Crop a small region
        cropped_frame = mock_frame[200:400, 100:300]  # 200x200 region

        with patch('cv2.imencode') as mock_imencode:
            # Mock successful encoding
            mock_imencode.return_value = (True, np.array([1, 2, 3, 4], dtype=np.uint8))

            encoded = executor._encode_image_base64(cropped_frame, format="jpeg")

            # Should return base64 string
            assert isinstance(encoded, str)
            assert len(encoded) > 0

            # Verify encoding call
            mock_imencode.assert_called_once()
            assert mock_imencode.call_args[0][0] == '.jpg'

    def test_build_insert_query(self):
        """Test building INSERT SQL query."""
        executor = ActionExecutor()

        table = "detections"
        fields = {
            "camera_id": "camera_001",
            "detection_class": "car",
            "confidence": 0.9,
            "created_at": "NOW()"
        }

        query, values = executor._build_insert_query(table, fields)

        assert "INSERT INTO detections" in query
        assert "camera_id, detection_class, confidence, created_at" in query
        assert "VALUES (%s, %s, %s, NOW())" in query
        assert values == ["camera_001", "car", 0.9]

    def test_build_update_query(self):
        """Test building UPDATE SQL query."""
        executor = ActionExecutor()

        table = "car_info"
        fields = {
            "car_brand": "Toyota",
            "car_body_type": "Sedan",
            "updated_at": "NOW()"
        }
        key_field = "session_id"
        key_value = "session_123"

        query, values = executor._build_update_query(table, fields, key_field, key_value)

        assert "UPDATE car_info SET" in query
        assert "car_brand = %s" in query
        assert "car_body_type = %s" in query
        assert "updated_at = NOW()" in query
        assert "WHERE session_id = %s" in query
        assert values == ["Toyota", "Sedan", "session_123"]

    def test_evaluate_condition(self):
        """Test evaluating conditional expressions."""
        executor = ActionExecutor()

        context = {
            "confidence": 0.85,
            "class": "car",
            "area": 40000
        }

        # Simple comparisons
        assert executor._evaluate_condition("{confidence} > 0.8", context) is True
        assert executor._evaluate_condition("{confidence} < 0.8", context) is False
        assert executor._evaluate_condition("{confidence} >= 0.85", context) is True
        assert executor._evaluate_condition("{confidence} == 0.85", context) is True

        # String comparisons
        assert executor._evaluate_condition("{class} == 'car'", context) is True
        assert executor._evaluate_condition("{class} != 'truck'", context) is True

        # Complex conditions
        assert executor._evaluate_condition("{confidence} > 0.8 and {area} > 30000", context) is True
        assert executor._evaluate_condition("{confidence} > 0.9 or {area} > 30000", context) is True
        assert executor._evaluate_condition("{confidence} > 0.9 and {area} < 30000", context) is False

    def test_validate_action_config(self):
        """Test action configuration validation."""
        executor = ActionExecutor()

        # Valid Redis action
        valid_redis = {
            "type": "redis_save_image",
            "region": "car",
            "key": "inference:{session_id}",
            "expire_seconds": 600
        }
        assert executor._validate_action_config(valid_redis) is True

        # Invalid action (missing required fields)
        invalid_action = {
            "type": "redis_save_image"
            # Missing region and key
        }
        with pytest.raises(ActionError):
            executor._validate_action_config(invalid_action)

        # Unknown action type
        unknown_action = {
            "type": "unknown_action_type",
            "some_field": "value"
        }
        with pytest.raises(ActionError):
            executor._validate_action_config(unknown_action)

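
# The test_evaluate_condition cases above amount to: substitute "{name}"
# placeholders with repr()'d context values, then evaluate the resulting
# comparison. A minimal sketch of that contract (illustrative only; the real
# _evaluate_condition should prefer a proper expression parser over eval):
def _evaluate_condition_sketch(condition: str, context: dict) -> bool:
    expression = condition
    for name, value in context.items():
        expression = expression.replace("{" + name + "}", repr(value))
    # Empty builtins restrict the expression to comparisons and boolean logic.
    return bool(eval(expression, {"__builtins__": {}}, {}))

# _evaluate_condition_sketch("{confidence} > 0.8 and {area} > 30000",
#                            {"confidence": 0.85, "area": 40000})  # -> True
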
class TestActionExecutorIntegration:
    """Integration tests for action execution."""

    @pytest.mark.asyncio
    async def test_complete_detection_workflow(self, mock_redis_client, mock_frame):
        """Test complete detection workflow with multiple actions."""
        # Mock database manager
        mock_db_manager = Mock()
        mock_db_manager.execute_query = AsyncMock(return_value=True)

        executor = ActionExecutor(
            redis_client=mock_redis_client,
            db_manager=mock_db_manager
        )

        # Complete action workflow
        actions = [
            # Save cropped image to Redis
            {
                "type": "redis_save_image",
                "region": "car",
                "key": "inference:{camera_id}:{timestamp}:{session_id}:car",
                "expire_seconds": 600
            },
            # Insert initial detection record
            {
                "type": "postgresql_insert",
                "table": "car_detections",
                "fields": {
                    "camera_id": "{camera_id}",
                    "session_id": "{session_id}",
                    "detection_class": "{class}",
                    "confidence": "{confidence}",
                    "bbox_x1": "{bbox.x1}",
                    "bbox_y1": "{bbox.y1}",
                    "bbox_x2": "{bbox.x2}",
                    "bbox_y2": "{bbox.y2}",
                    "created_at": "NOW()"
                }
            },
            # Publish detection event
            {
                "type": "redis_publish",
                "channel": "detections:{camera_id}",
                "message": {
                    "event": "car_detected",
                    "camera_id": "{camera_id}",
                    "session_id": "{session_id}",
                    "timestamp": "{timestamp}"
                }
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.92,
                "detection": DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = {
            "camera_id": "camera_001",
            "session_id": "session_123",
            "timestamp": "1640995200000",
            "class": "car",
            "confidence": 0.92,
            "bbox": {"x1": 100, "y1": 200, "x2": 300, "y2": 400},
            "frame_data": mock_frame
        }

        # Mock all Redis operations
        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True
        mock_redis_client.publish.return_value = 1

        results = await executor.execute_actions(actions, regions, context)

        # All actions should succeed
        assert len(results) == 3
        assert all(result.success for result in results)

        # Verify all operations were called
        mock_redis_client.set.assert_called_once()          # Image save
        mock_redis_client.expire.assert_called_once()       # Set expiry
        mock_redis_client.publish.assert_called_once()      # Publish event
        mock_db_manager.execute_query.assert_called_once()  # Database insert

    @pytest.mark.asyncio
    async def test_branch_dependent_actions(self, mock_database_connection):
        """Test actions that depend on branch results."""
        mock_db_manager = Mock()
        mock_db_manager.execute_query = AsyncMock(return_value=True)

        executor = ActionExecutor(db_manager=mock_db_manager)

        # Action that waits for classification branches
        actions = [
            {
                "type": "postgresql_update_combined",
                "table": "car_info",
                "key_field": "session_id",
                "fields": {
                    "car_brand": "{car_brand_cls.brand}",
                    "car_body_type": "{car_bodytype_cls.body_type}",
                    "car_color": "{car_color_cls.color}",
                    "confidence_brand": "{car_brand_cls.confidence}",
                    "confidence_bodytype": "{car_bodytype_cls.confidence}",
                    "updated_at": "NOW()"
                },
                "waitForBranches": ["car_brand_cls", "car_bodytype_cls", "car_color_cls"]
            }
        ]

        regions = {}

        context = {
            "session_id": "session_123"
        }

        # Simulated branch results
        branch_results = {
            "car_brand_cls": {"brand": "Toyota", "confidence": 0.87},
            "car_bodytype_cls": {"body_type": "Sedan", "confidence": 0.82},
            "car_color_cls": {"color": "Red", "confidence": 0.79}
        }

        results = await executor.execute_actions(actions, regions, context, branch_results)

        assert len(results) == 1
        assert results[0].success is True
        assert results[0].action_type == ActionType.POSTGRESQL_UPDATE

        # Verify database call with all branch data
        mock_db_manager.execute_query.assert_called_once()
        call_args = mock_db_manager.execute_query.call_args
        query = call_args[0][0]
        values = call_args[0][1]

        assert "UPDATE car_info SET" in query
        assert "car_brand = %s" in query
        assert "car_body_type = %s" in query
        assert "car_color = %s" in query
        assert "WHERE session_id = %s" in query

        assert "Toyota" in values
        assert "Sedan" in values
        assert "Red" in values
        assert "session_123" in values
786
tests/unit/pipeline/test_field_mapper.py
Normal file
@@ -0,0 +1,786 @@
"""
Unit tests for field mapping and template resolution.
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime
import json

from detector_worker.pipeline.field_mapper import (
    FieldMapper,
    MappingContext,
    TemplateResolver,
    FieldMappingError,
    NestedFieldAccessor
)
from detector_worker.detection.detection_result import DetectionResult, BoundingBox


class TestNestedFieldAccessor:
    """Test nested field access functionality."""

    def test_get_nested_value_simple(self):
        """Test getting simple nested values."""
        data = {
            "user": {
                "name": "John",
                "age": 30,
                "address": {
                    "city": "New York",
                    "zip": "10001"
                }
            }
        }

        accessor = NestedFieldAccessor()

        assert accessor.get_nested_value(data, "user.name") == "John"
        assert accessor.get_nested_value(data, "user.age") == 30
        assert accessor.get_nested_value(data, "user.address.city") == "New York"
        assert accessor.get_nested_value(data, "user.address.zip") == "10001"

    def test_get_nested_value_array_access(self):
        """Test accessing array elements."""
        data = {
            "results": [
                {"score": 0.9, "label": "car"},
                {"score": 0.8, "label": "truck"}
            ],
            "bbox": [100, 200, 300, 400]
        }

        accessor = NestedFieldAccessor()

        assert accessor.get_nested_value(data, "results[0].score") == 0.9
        assert accessor.get_nested_value(data, "results[0].label") == "car"
        assert accessor.get_nested_value(data, "results[1].score") == 0.8
        assert accessor.get_nested_value(data, "bbox[0]") == 100
        assert accessor.get_nested_value(data, "bbox[3]") == 400

    def test_get_nested_value_nonexistent_path(self):
        """Test accessing non-existent paths."""
        data = {"user": {"name": "John"}}
        accessor = NestedFieldAccessor()

        assert accessor.get_nested_value(data, "user.nonexistent") is None
        assert accessor.get_nested_value(data, "nonexistent.field") is None
        assert accessor.get_nested_value(data, "user.address.city") is None

    def test_get_nested_value_with_default(self):
        """Test getting nested values with default fallback."""
        data = {"user": {"name": "John"}}
        accessor = NestedFieldAccessor()

        assert accessor.get_nested_value(data, "user.age", default=25) == 25
        assert accessor.get_nested_value(data, "user.name", default="Unknown") == "John"

    def test_set_nested_value(self):
        """Test setting nested values."""
        data = {}
        accessor = NestedFieldAccessor()

        accessor.set_nested_value(data, "user.name", "John")
        assert data["user"]["name"] == "John"

        accessor.set_nested_value(data, "user.address.city", "New York")
        assert data["user"]["address"]["city"] == "New York"

        accessor.set_nested_value(data, "scores[0]", 0.95)
        assert data["scores"][0] == 0.95

    def test_set_nested_value_overwrite(self):
        """Test overwriting existing nested values."""
        data = {"user": {"name": "John", "age": 30}}
        accessor = NestedFieldAccessor()

        accessor.set_nested_value(data, "user.name", "Jane")
        assert data["user"]["name"] == "Jane"
        assert data["user"]["age"] == 30  # Should not affect other fields

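
# The dotted-path lookups above ("user.address.city", "results[0].score",
# "bbox[3]") can be served by a small path walker. A sketch of the read-side
# behavior TestNestedFieldAccessor pins down (illustrative only; the shipped
# NestedFieldAccessor may implement this differently):
import re

def _get_nested_value_sketch(data, path, default=None):
    current = data
    for part in path.split("."):
        # Split "results[0]" into the key and any trailing [index] segments.
        match = re.fullmatch(r"([^\[\]]*)((?:\[\d+\])*)", part)
        key, indices = match.group(1), re.findall(r"\[(\d+)\]", match.group(2))
        if key:
            if not isinstance(current, dict) or key not in current:
                return default
            current = current[key]
        for idx in indices:
            if not isinstance(current, (list, tuple)) or int(idx) >= len(current):
                return default
            current = current[int(idx)]
    return current
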
class TestTemplateResolver:
|
||||
"""Test template string resolution."""
|
||||
|
||||
def test_resolve_simple_template(self):
|
||||
"""Test resolving simple template variables."""
|
||||
resolver = TemplateResolver()
|
||||
|
||||
template = "Hello {name}, you are {age} years old"
|
||||
context = {"name": "John", "age": 30}
|
||||
|
||||
result = resolver.resolve(template, context)
|
||||
assert result == "Hello John, you are 30 years old"
|
||||
|
||||
def test_resolve_nested_template(self):
|
||||
"""Test resolving nested field templates."""
|
||||
resolver = TemplateResolver()
|
||||
|
||||
template = "User: {user.name} from {user.address.city}"
|
||||
context = {
|
||||
"user": {
|
||||
"name": "John",
|
||||
"address": {"city": "New York", "zip": "10001"}
|
||||
}
|
||||
}
|
||||
|
||||
result = resolver.resolve(template, context)
|
||||
assert result == "User: John from New York"
|
||||
|
||||
def test_resolve_array_template(self):
|
||||
"""Test resolving array element templates."""
|
||||
resolver = TemplateResolver()
|
||||
|
||||
template = "First result: {results[0].label} ({results[0].score})"
|
||||
context = {
|
||||
"results": [
|
||||
{"label": "car", "score": 0.95},
|
||||
{"label": "truck", "score": 0.87}
|
||||
]
|
||||
}
|
||||
|
||||
result = resolver.resolve(template, context)
|
||||
assert result == "First result: car (0.95)"
|
||||
|
||||
def test_resolve_missing_variables(self):
|
||||
"""Test resolving templates with missing variables."""
|
||||
resolver = TemplateResolver()
|
||||
|
||||
template = "Hello {name}, you are {age} years old"
|
||||
context = {"name": "John"} # Missing age
|
||||
|
||||
with pytest.raises(FieldMappingError) as exc_info:
|
||||
resolver.resolve(template, context)
|
||||
|
||||
assert "Variable 'age' not found" in str(exc_info.value)
|
||||
|
||||
def test_resolve_with_defaults(self):
|
||||
"""Test resolving templates with default values."""
|
||||
resolver = TemplateResolver(allow_missing=True)
|
||||
|
||||
template = "Hello {name}, you are {age|25} years old"
|
||||
context = {"name": "John"} # Missing age, should use default
|
||||
|
||||
result = resolver.resolve(template, context)
|
||||
assert result == "Hello John, you are 25 years old"
|
||||
|
||||
def test_resolve_complex_template(self):
|
||||
"""Test resolving complex templates with multiple variable types."""
|
||||
resolver = TemplateResolver()
|
||||
|
||||
template = "{camera_id}:{timestamp}:{session_id}:{results[0].class}_{bbox[0]}_{bbox[1]}"
|
||||
context = {
|
||||
"camera_id": "cam001",
|
||||
"timestamp": 1640995200000,
|
||||
"session_id": "sess123",
|
||||
"results": [{"class": "car", "confidence": 0.95}],
|
||||
"bbox": [100, 200, 300, 400]
|
||||
}
|
||||
|
||||
result = resolver.resolve(template, context)
|
||||
assert result == "cam001:1640995200000:sess123:car_100_200"
|
||||
|
||||
def test_resolve_conditional_template(self):
|
||||
"""Test resolving conditional templates."""
|
||||
resolver = TemplateResolver()
|
||||
|
||||
# Simple conditional
|
||||
template = "{name} is {age > 18 ? 'adult' : 'minor'}"
|
||||
|
||||
context_adult = {"name": "John", "age": 25}
|
||||
result_adult = resolver.resolve(template, context_adult)
|
||||
assert result_adult == "John is adult"
|
||||
|
||||
context_minor = {"name": "Jane", "age": 16}
|
||||
result_minor = resolver.resolve(template, context_minor)
|
||||
assert result_minor == "Jane is minor"

    def test_escape_braces(self):
        """Test escaping braces in templates."""
        resolver = TemplateResolver()

        template = "Literal {{braces}} and variable {name}"
        context = {"name": "John"}

        result = resolver.resolve(template, context)
        assert result == "Literal {braces} and variable John"


class TestMappingContext:
    """Test mapping context data structure."""

    def test_creation(self):
        """Test mapping context creation."""
        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection,
            timestamp=1640995200000
        )

        assert context.camera_id == "camera_001"
        assert context.display_id == "display_001"
        assert context.session_id == "session_123"
        assert context.detection == detection
        assert context.timestamp == 1640995200000
        assert context.branch_results == {}
        assert context.metadata == {}

    def test_add_branch_result(self):
        """Test adding branch results to context."""
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        context.add_branch_result("car_brand_cls", {"brand": "Toyota", "confidence": 0.87})
        context.add_branch_result("car_bodytype_cls", {"body_type": "Sedan", "confidence": 0.82})

        assert len(context.branch_results) == 2
        assert context.branch_results["car_brand_cls"]["brand"] == "Toyota"
        assert context.branch_results["car_bodytype_cls"]["body_type"] == "Sedan"

    def test_to_dict(self):
        """Test converting context to dictionary."""
        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection,
            timestamp=1640995200000
        )

        context.add_branch_result("car_brand_cls", {"brand": "Toyota"})
        context.add_metadata("model_id", "yolo_v8")

        context_dict = context.to_dict()

        assert context_dict["camera_id"] == "camera_001"
        assert context_dict["display_id"] == "display_001"
        assert context_dict["session_id"] == "session_123"
        assert context_dict["timestamp"] == 1640995200000
        assert context_dict["class"] == "car"
        assert context_dict["confidence"] == 0.9
        assert context_dict["track_id"] == 1001
        assert context_dict["bbox"]["x1"] == 100
        assert context_dict["car_brand_cls"]["brand"] == "Toyota"
        assert context_dict["model_id"] == "yolo_v8"

    def test_add_metadata(self):
        """Test adding metadata to context."""
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        context.add_metadata("model_version", "v2.1")
        context.add_metadata("inference_time", 0.15)

        assert context.metadata["model_version"] == "v2.1"
        assert context.metadata["inference_time"] == 0.15


class TestFieldMapper:
    """Test field mapping functionality."""

    def test_initialization(self):
        """Test field mapper initialization."""
        mapper = FieldMapper()

        assert isinstance(mapper.template_resolver, TemplateResolver)
        assert isinstance(mapper.field_accessor, NestedFieldAccessor)

    def test_map_fields_simple(self):
        """Test simple field mapping."""
        mapper = FieldMapper()

        field_mappings = {
            "camera_id": "{camera_id}",
            "detection_class": "{class}",
            "confidence_score": "{confidence}",
            "track_identifier": "{track_id}"
        }

        detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection,
            timestamp=1640995200000
        )

        mapped_fields = mapper.map_fields(field_mappings, context)

        assert mapped_fields["camera_id"] == "camera_001"
        assert mapped_fields["detection_class"] == "car"
        assert mapped_fields["confidence_score"] == 0.92
        assert mapped_fields["track_identifier"] == 1001

    def test_map_fields_with_branch_results(self):
        """Test field mapping with branch results."""
        mapper = FieldMapper()

        field_mappings = {
            "car_brand": "{car_brand_cls.brand}",
            "car_model": "{car_brand_cls.model}",
            "body_type": "{car_bodytype_cls.body_type}",
            "brand_confidence": "{car_brand_cls.confidence}",
            "combined_info": "{car_brand_cls.brand} {car_bodytype_cls.body_type}"
        }

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        context.add_branch_result("car_brand_cls", {
            "brand": "Toyota",
            "model": "Camry",
            "confidence": 0.87
        })
        context.add_branch_result("car_bodytype_cls", {
            "body_type": "Sedan",
            "confidence": 0.82
        })

        mapped_fields = mapper.map_fields(field_mappings, context)

        assert mapped_fields["car_brand"] == "Toyota"
        assert mapped_fields["car_model"] == "Camry"
        assert mapped_fields["body_type"] == "Sedan"
        assert mapped_fields["brand_confidence"] == 0.87
        assert mapped_fields["combined_info"] == "Toyota Sedan"

    def test_map_fields_bbox_access(self):
        """Test field mapping with bounding box access."""
        mapper = FieldMapper()

        field_mappings = {
            "bbox_x1": "{bbox.x1}",
            "bbox_y1": "{bbox.y1}",
            "bbox_x2": "{bbox.x2}",
            "bbox_y2": "{bbox.y2}",
            "bbox_width": "{bbox.width}",
            "bbox_height": "{bbox.height}",
            "bbox_area": "{bbox.area}",
            "bbox_center_x": "{bbox.center_x}",
            "bbox_center_y": "{bbox.center_y}"
        }

        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection
        )

        mapped_fields = mapper.map_fields(field_mappings, context)

        assert mapped_fields["bbox_x1"] == 100
        assert mapped_fields["bbox_y1"] == 200
        assert mapped_fields["bbox_x2"] == 300
        assert mapped_fields["bbox_y2"] == 400
        assert mapped_fields["bbox_width"] == 200  # 300 - 100
        assert mapped_fields["bbox_height"] == 200  # 400 - 200
        assert mapped_fields["bbox_area"] == 40000  # 200 * 200
        assert mapped_fields["bbox_center_x"] == 200  # (100 + 300) / 2
        assert mapped_fields["bbox_center_y"] == 300  # (200 + 400) / 2
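        # Taken together, these asserts pin down the BoundingBox derived
        # properties: width = x2 - x1, height = y2 - y1, area = width * height,
        # and center_x / center_y as the midpoints of the corner coordinates.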

    def test_map_fields_with_sql_functions(self):
        """Test field mapping with SQL function templates."""
        mapper = FieldMapper()

        field_mappings = {
            "created_at": "NOW()",
            "updated_at": "CURRENT_TIMESTAMP",
            "uuid_field": "UUID()",
            "json_data": "JSON_OBJECT('class', '{class}', 'confidence', {confidence})"
        }

        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            detection=detection
        )

        mapped_fields = mapper.map_fields(field_mappings, context)

        # SQL functions should pass through unchanged
        assert mapped_fields["created_at"] == "NOW()"
        assert mapped_fields["updated_at"] == "CURRENT_TIMESTAMP"
        assert mapped_fields["uuid_field"] == "UUID()"
        assert mapped_fields["json_data"] == "JSON_OBJECT('class', 'car', 'confidence', 0.9)"

    def test_map_fields_missing_branch_data(self):
        """Test field mapping with missing branch data."""
        mapper = FieldMapper()

        field_mappings = {
            "car_brand": "{car_brand_cls.brand}",
            "car_model": "{nonexistent_branch.model}"
        }

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        # Only add one branch result
        context.add_branch_result("car_brand_cls", {"brand": "Toyota"})

        with pytest.raises(FieldMappingError) as exc_info:
            mapper.map_fields(field_mappings, context)

        assert "nonexistent_branch.model" in str(exc_info.value)

    def test_map_fields_with_defaults(self):
        """Test field mapping with default values."""
        mapper = FieldMapper(allow_missing=True)

        field_mappings = {
            "car_brand": "{car_brand_cls.brand|Unknown}",
            "car_model": "{car_brand_cls.model|N/A}",
            "confidence": "{confidence|0.0}"
        }

        context = MappingContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123"
        )

        # Don't add any branch results
        mapped_fields = mapper.map_fields(field_mappings, context)

        assert mapped_fields["car_brand"] == "Unknown"
        assert mapped_fields["car_model"] == "N/A"
        assert mapped_fields["confidence"] == "0.0"
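        # The "{field|default}" fallback substitutes the literal text after the
        # pipe, so a numeric-looking default such as 0.0 comes back as the
        # string "0.0" rather than a float.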

    def test_map_database_fields(self):
        """Test mapping fields for database operations."""
        mapper = FieldMapper()

        # Database field mapping
        db_field_mappings = {
            "camera_id": "{camera_id}",
            "session_id": "{session_id}",
            "detection_timestamp": "{timestamp}",
            "object_class": "{class}",
            "detection_confidence": "{confidence}",
            "track_id": "{track_id}",
            "bbox_json": "JSON_OBJECT('x1', {bbox.x1}, 'y1', {bbox.y1}, 'x2', {bbox.x2}, 'y2', {bbox.y2})",
            "car_brand": "{car_brand_cls.brand}",
            "car_body_type": "{car_bodytype_cls.body_type}",
            "license_plate": "{license_ocr.text}",
            "created_at": "NOW()",
            "updated_at": "NOW()"
        }

        detection = DetectionResult("car", 0.93, BoundingBox(150, 250, 350, 450), 2001, 1640995300000)
        context = MappingContext(
            camera_id="camera_002",
            display_id="display_002",
            session_id="session_456",
            detection=detection,
            timestamp=1640995300000
        )

        # Add branch results
        context.add_branch_result("car_brand_cls", {"brand": "Honda", "confidence": 0.89})
        context.add_branch_result("car_bodytype_cls", {"body_type": "SUV", "confidence": 0.85})
        context.add_branch_result("license_ocr", {"text": "ABC-123", "confidence": 0.76})

        mapped_fields = mapper.map_fields(db_field_mappings, context)

        assert mapped_fields["camera_id"] == "camera_002"
        assert mapped_fields["session_id"] == "session_456"
        assert mapped_fields["detection_timestamp"] == 1640995300000
        assert mapped_fields["object_class"] == "car"
        assert mapped_fields["detection_confidence"] == 0.93
        assert mapped_fields["track_id"] == 2001
        assert mapped_fields["bbox_json"] == "JSON_OBJECT('x1', 150, 'y1', 250, 'x2', 350, 'y2', 450)"
        assert mapped_fields["car_brand"] == "Honda"
        assert mapped_fields["car_body_type"] == "SUV"
        assert mapped_fields["license_plate"] == "ABC-123"
        assert mapped_fields["created_at"] == "NOW()"
        assert mapped_fields["updated_at"] == "NOW()"

    def test_map_redis_keys(self):
        """Test mapping Redis key templates."""
        mapper = FieldMapper()

        key_templates = [
            "inference:{camera_id}:{timestamp}:{session_id}:car",
            "detection:{display_id}:{track_id}",
            "cropped_image:{camera_id}:{session_id}:{class}",
            "metadata:{session_id}:brands:{car_brand_cls.brand}",
            "tracking:{camera_id}:active_tracks"
        ]

        detection = DetectionResult("car", 0.88, BoundingBox(200, 300, 400, 500), 3001, 1640995400000)
        context = MappingContext(
            camera_id="camera_003",
            display_id="display_003",
            session_id="session_789",
            detection=detection,
            timestamp=1640995400000
        )

        context.add_branch_result("car_brand_cls", {"brand": "Ford"})

        mapped_keys = [mapper.map_template(template, context) for template in key_templates]

        expected_keys = [
            "inference:camera_003:1640995400000:session_789:car",
            "detection:display_003:3001",
            "cropped_image:camera_003:session_789:car",
            "metadata:session_789:brands:Ford",
            "tracking:camera_003:active_tracks"
        ]

        assert mapped_keys == expected_keys

    def test_map_template(self):
        """Test single template mapping."""
        mapper = FieldMapper()

        template = "Camera {camera_id} detected {class} with {confidence:.2f} confidence at {timestamp}"

        detection = DetectionResult("truck", 0.876, BoundingBox(100, 150, 300, 350), 4001, 1640995500000)
        context = MappingContext(
            camera_id="camera_004",
            display_id="display_004",
            session_id="session_101",
            detection=detection,
            timestamp=1640995500000
        )

        result = mapper.map_template(template, context)
        expected = "Camera camera_004 detected truck with 0.88 confidence at 1640995500000"

        assert result == expected

    def test_validate_field_mappings(self):
        """Test field mapping validation."""
        mapper = FieldMapper()

        # Valid mappings
        valid_mappings = {
            "camera_id": "{camera_id}",
            "class": "{class}",
            "confidence": "{confidence}",
            "created_at": "NOW()"
        }

        assert mapper.validate_field_mappings(valid_mappings) is True

        # Invalid mappings (malformed templates)
        invalid_mappings = {
            "camera_id": "{camera_id",  # Missing closing brace
            "class": "class}",  # Missing opening brace
            "confidence": "{nonexistent_field}"  # This might be valid depending on context
        }

        with pytest.raises(FieldMappingError):
            mapper.validate_field_mappings(invalid_mappings)

    def test_create_context_from_detection(self):
        """Test creating mapping context from detection result."""
        mapper = FieldMapper()

        detection = DetectionResult("car", 0.95, BoundingBox(50, 100, 250, 300), 5001, 1640995600000)

        context = mapper.create_context_from_detection(
            detection,
            camera_id="camera_005",
            display_id="display_005",
            session_id="session_202"
        )

        assert context.camera_id == "camera_005"
        assert context.display_id == "display_005"
        assert context.session_id == "session_202"
        assert context.detection == detection
        assert context.timestamp == 1640995600000

    def test_format_sql_value(self):
        """Test SQL value formatting."""
        mapper = FieldMapper()

        # String values should be quoted
        assert mapper.format_sql_value("test_string") == "'test_string'"
        assert mapper.format_sql_value("John's car") == "'John''s car'"  # Escape quotes

        # Numeric values should not be quoted
        assert mapper.format_sql_value(42) == "42"
        assert mapper.format_sql_value(3.14) == "3.14"
        assert mapper.format_sql_value(0.95) == "0.95"

        # Boolean values
        assert mapper.format_sql_value(True) == "TRUE"
        assert mapper.format_sql_value(False) == "FALSE"

        # None/NULL values
        assert mapper.format_sql_value(None) == "NULL"

        # SQL functions should pass through
        assert mapper.format_sql_value("NOW()") == "NOW()"
        assert mapper.format_sql_value("CURRENT_TIMESTAMP") == "CURRENT_TIMESTAMP"
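        # A minimal sketch of the pass-through rule these asserts imply
        # (hypothetical names; the real logic lives in FieldMapper):
        #
        #     SQL_LITERALS = ("NOW()", "CURRENT_TIMESTAMP", "UUID()")
        #     if isinstance(value, str) and value in SQL_LITERALS:
        #         return value  # emit as a SQL literal, not a quoted string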


class TestFieldMapperIntegration:
    """Integration tests for field mapping."""

    def test_complete_mapping_workflow(self):
        """Test complete field mapping workflow."""
        mapper = FieldMapper()

        # Simulate complete detection workflow
        detection = DetectionResult("car", 0.91, BoundingBox(120, 180, 320, 380), 6001, 1640995700000)
        context = MappingContext(
            camera_id="camera_006",
            display_id="display_006",
            session_id="session_303",
            detection=detection,
            timestamp=1640995700000
        )

        # Add comprehensive branch results
        context.add_branch_result("car_brand_cls", {
            "brand": "BMW",
            "model": "X5",
            "confidence": 0.84,
            "top3_brands": ["BMW", "Audi", "Mercedes"]
        })

        context.add_branch_result("car_bodytype_cls", {
            "body_type": "SUV",
            "confidence": 0.79,
            "features": ["tall", "4_doors", "roof_rails"]
        })

        context.add_branch_result("car_color_cls", {
            "color": "Black",
            "confidence": 0.73,
            "rgb_values": [20, 25, 30]
        })

        context.add_branch_result("license_ocr", {
            "text": "XYZ-789",
            "confidence": 0.68,
            "region_bbox": [150, 320, 290, 360]
        })

        # Database field mapping
        db_mappings = {
            "camera_id": "{camera_id}",
            "display_id": "{display_id}",
            "session_id": "{session_id}",
            "detection_timestamp": "{timestamp}",
            "object_class": "{class}",
            "detection_confidence": "{confidence}",
            "track_id": "{track_id}",
            "bbox_x1": "{bbox.x1}",
            "bbox_y1": "{bbox.y1}",
            "bbox_x2": "{bbox.x2}",
            "bbox_y2": "{bbox.y2}",
            "bbox_area": "{bbox.area}",
            "car_brand": "{car_brand_cls.brand}",
            "car_model": "{car_brand_cls.model}",
            "car_body_type": "{car_bodytype_cls.body_type}",
            "car_color": "{car_color_cls.color}",
            "license_plate": "{license_ocr.text}",
            "brand_confidence": "{car_brand_cls.confidence}",
            "bodytype_confidence": "{car_bodytype_cls.confidence}",
            "color_confidence": "{car_color_cls.confidence}",
            "license_confidence": "{license_ocr.confidence}",
            "detection_summary": "{car_brand_cls.brand} {car_bodytype_cls.body_type} ({car_color_cls.color})",
            "created_at": "NOW()",
            "updated_at": "NOW()"
        }

        mapped_db_fields = mapper.map_fields(db_mappings, context)

        # Verify all mappings
        assert mapped_db_fields["camera_id"] == "camera_006"
        assert mapped_db_fields["session_id"] == "session_303"
        assert mapped_db_fields["object_class"] == "car"
        assert mapped_db_fields["detection_confidence"] == 0.91
        assert mapped_db_fields["track_id"] == 6001
        assert mapped_db_fields["bbox_area"] == 40000  # 200 * 200
        assert mapped_db_fields["car_brand"] == "BMW"
        assert mapped_db_fields["car_model"] == "X5"
        assert mapped_db_fields["car_body_type"] == "SUV"
        assert mapped_db_fields["car_color"] == "Black"
        assert mapped_db_fields["license_plate"] == "XYZ-789"
        assert mapped_db_fields["detection_summary"] == "BMW SUV (Black)"

        # Redis key mapping
        redis_key_templates = [
            "detection:{camera_id}:{session_id}:main",
            "cropped:{camera_id}:{session_id}:car_image",
            "metadata:{session_id}:brand:{car_brand_cls.brand}",
            "tracking:{camera_id}:track_{track_id}",
            "classification:{session_id}:results"
        ]

        mapped_redis_keys = [
            mapper.map_template(template, context)
            for template in redis_key_templates
        ]

        expected_redis_keys = [
            "detection:camera_006:session_303:main",
            "cropped:camera_006:session_303:car_image",
            "metadata:session_303:brand:BMW",
            "tracking:camera_006:track_6001",
            "classification:session_303:results"
        ]

        assert mapped_redis_keys == expected_redis_keys

    def test_error_handling_and_recovery(self):
        """Test error handling and recovery in field mapping."""
        mapper = FieldMapper(allow_missing=True)

        # Context with missing detection
        context = MappingContext(
            camera_id="camera_007",
            display_id="display_007",
            session_id="session_404"
        )

        # Partial branch results
        context.add_branch_result("car_brand_cls", {"brand": "Unknown"})
        # Missing car_bodytype_cls branch

        # Field mappings with some missing data
        mappings = {
            "camera_id": "{camera_id}",
            "detection_class": "{class|Unknown}",
            "confidence": "{confidence|0.0}",
            "car_brand": "{car_brand_cls.brand|N/A}",
            "car_body_type": "{car_bodytype_cls.body_type|Unknown}",
            "car_model": "{car_brand_cls.model|N/A}"
        }

        mapped_fields = mapper.map_fields(mappings, context)

        assert mapped_fields["camera_id"] == "camera_007"
        assert mapped_fields["detection_class"] == "Unknown"
        assert mapped_fields["confidence"] == "0.0"
        assert mapped_fields["car_brand"] == "Unknown"
        assert mapped_fields["car_body_type"] == "Unknown"
        assert mapped_fields["car_model"] == "N/A"
921
tests/unit/pipeline/test_pipeline_executor.py
Normal file
@ -0,0 +1,921 @@
"""
|
||||
Unit tests for pipeline execution functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import json
|
||||
|
||||
from detector_worker.pipeline.pipeline_executor import (
|
||||
PipelineExecutor,
|
||||
PipelineContext,
|
||||
PipelineResult,
|
||||
BranchResult,
|
||||
ExecutionMode
|
||||
)
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
from detector_worker.core.exceptions import PipelineError, ModelError, ActionError
|
||||
|
||||
|
||||
class TestPipelineContext:
|
||||
"""Test pipeline context data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test pipeline context creation."""
|
||||
context = PipelineContext(
|
||||
camera_id="camera_001",
|
||||
display_id="display_001",
|
||||
session_id="session_123",
|
||||
timestamp=1640995200000,
|
||||
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
)
|
||||
|
||||
assert context.camera_id == "camera_001"
|
||||
assert context.display_id == "display_001"
|
||||
assert context.session_id == "session_123"
|
||||
assert context.timestamp == 1640995200000
|
||||
assert context.frame_data.shape == (480, 640, 3)
|
||||
assert context.metadata == {}
|
||||
assert context.crop_region is None
|
||||
|
||||
def test_creation_with_crop_region(self):
|
||||
"""Test context creation with crop region."""
|
||||
crop_region = (100, 200, 300, 400)
|
||||
context = PipelineContext(
|
||||
camera_id="camera_001",
|
||||
display_id="display_001",
|
||||
session_id="session_123",
|
||||
timestamp=1640995200000,
|
||||
frame_data=np.zeros((480, 640, 3), dtype=np.uint8),
|
||||
crop_region=crop_region
|
||||
)
|
||||
|
||||
assert context.crop_region == crop_region
|
||||
|
||||
def test_add_metadata(self):
|
||||
"""Test adding metadata to context."""
|
||||
context = PipelineContext(
|
||||
camera_id="camera_001",
|
||||
display_id="display_001",
|
||||
session_id="session_123",
|
||||
timestamp=1640995200000,
|
||||
frame_data=np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
)
|
||||
|
||||
context.add_metadata("model_id", "yolo_v8")
|
||||
context.add_metadata("confidence_threshold", 0.8)
|
||||
|
||||
assert context.metadata["model_id"] == "yolo_v8"
|
||||
assert context.metadata["confidence_threshold"] == 0.8
|
||||
|
||||
def test_get_cropped_frame(self):
|
||||
"""Test getting cropped frame."""
|
||||
frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
|
||||
context = PipelineContext(
|
||||
camera_id="camera_001",
|
||||
display_id="display_001",
|
||||
session_id="session_123",
|
||||
timestamp=1640995200000,
|
||||
frame_data=frame,
|
||||
crop_region=(100, 200, 300, 400)
|
||||
)
|
||||
|
||||
cropped = context.get_cropped_frame()
|
||||
|
||||
assert cropped.shape == (200, 200, 3) # 400-200, 300-100
|
||||
assert np.all(cropped == 255)
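        # The (200, 200, 3) result implies crop_region is ordered
        # (x1, y1, x2, y2) and the crop is taken as frame[y1:y2, x1:x2],
        # i.e. height (rows) first in numpy's row-major layout.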

    def test_get_cropped_frame_no_crop(self):
        """Test getting frame when no crop region specified."""
        frame = np.ones((480, 640, 3), dtype=np.uint8) * 255
        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=frame
        )

        cropped = context.get_cropped_frame()

        assert np.array_equal(cropped, frame)


class TestBranchResult:
    """Test branch execution result."""

    def test_creation_success(self):
        """Test successful branch result creation."""
        detections = [
            DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
        ]

        result = BranchResult(
            branch_id="car_brand_cls",
            success=True,
            detections=detections,
            metadata={"brand": "Toyota"},
            execution_time=0.15
        )

        assert result.branch_id == "car_brand_cls"
        assert result.success is True
        assert len(result.detections) == 1
        assert result.metadata["brand"] == "Toyota"
        assert result.execution_time == 0.15
        assert result.error is None

    def test_creation_failure(self):
        """Test failed branch result creation."""
        result = BranchResult(
            branch_id="car_brand_cls",
            success=False,
            error="Model inference failed",
            execution_time=0.05
        )

        assert result.branch_id == "car_brand_cls"
        assert result.success is False
        assert result.detections == []
        assert result.metadata == {}
        assert result.error == "Model inference failed"


class TestPipelineResult:
    """Test pipeline execution result."""

    def test_creation_success(self):
        """Test successful pipeline result creation."""
        main_detections = [
            DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
        ]

        branch_results = {
            "car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
            "car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
        }

        result = PipelineResult(
            success=True,
            detections=main_detections,
            branch_results=branch_results,
            total_execution_time=0.5
        )

        assert result.success is True
        assert len(result.detections) == 1
        assert len(result.branch_results) == 2
        assert result.total_execution_time == 0.5
        assert result.error is None

    def test_creation_failure(self):
        """Test failed pipeline result creation."""
        result = PipelineResult(
            success=False,
            error="Pipeline execution failed",
            total_execution_time=0.1
        )

        assert result.success is False
        assert result.detections == []
        assert result.branch_results == {}
        assert result.error == "Pipeline execution failed"

    def test_get_combined_results(self):
        """Test getting combined results from all branches."""
        main_detections = [
            DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)
        ]

        branch_results = {
            "car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
            "car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12)
        }

        result = PipelineResult(
            success=True,
            detections=main_detections,
            branch_results=branch_results,
            total_execution_time=0.5
        )

        combined = result.get_combined_results()

        assert "brand" in combined
        assert "body_type" in combined
        assert combined["brand"] == "Toyota"
        assert combined["body_type"] == "Sedan"


class TestPipelineExecutor:
    """Test pipeline execution functionality."""

    def test_initialization(self):
        """Test pipeline executor initialization."""
        executor = PipelineExecutor()

        assert isinstance(executor.thread_pool, ThreadPoolExecutor)
        assert executor.max_workers == 4
        assert executor.execution_mode == ExecutionMode.PARALLEL
        assert executor.timeout == 30.0

    def test_initialization_custom_config(self):
        """Test initialization with custom configuration."""
        config = {
            "max_workers": 8,
            "execution_mode": "sequential",
            "timeout": 60.0
        }

        executor = PipelineExecutor(config)

        assert executor.max_workers == 8
        assert executor.execution_mode == ExecutionMode.SEQUENTIAL
        assert executor.timeout == 60.0

    @pytest.mark.asyncio
    async def test_execute_pipeline_simple(self, mock_yolo_model, mock_frame):
        """Test simple pipeline execution."""
        import torch  # needed for the mocked box tensors below

        # Mock pipeline configuration
        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [],
            "actions": []
        }

        # Mock detection result
        mock_result = Mock()
        mock_result.boxes = Mock()
        mock_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]
        ])
        mock_result.boxes.id = torch.tensor([1001])

        mock_yolo_model.track.return_value = [mock_result]

        executor = PipelineExecutor()

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_yolo_model

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is True
            assert len(result.detections) == 1
            assert result.detections[0].class_name == "0"  # Default class name
            assert result.detections[0].confidence == 0.9

    @pytest.mark.asyncio
    async def test_execute_pipeline_with_branches(self, mock_yolo_model, mock_frame):
        """Test pipeline execution with classification branches."""
        import torch

        # Mock main detection
        mock_detection_result = Mock()
        mock_detection_result.boxes = Mock()
        mock_detection_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]  # car detection
        ])
        mock_detection_result.boxes.id = torch.tensor([1001])

        # Mock classification results
        mock_brand_result = Mock()
        mock_brand_result.probs = Mock()
        mock_brand_result.probs.top1 = 2  # Toyota
        mock_brand_result.probs.top1conf = 0.85

        mock_bodytype_result = Mock()
        mock_bodytype_result.probs = Mock()
        mock_bodytype_result.probs.top1 = 1  # Sedan
        mock_bodytype_result.probs.top1conf = 0.78

        mock_yolo_model.track.return_value = [mock_detection_result]
        mock_yolo_model.predict.return_value = [mock_brand_result]

        mock_brand_model = Mock()
        mock_brand_model.predict.return_value = [mock_brand_result]
        mock_brand_model.names = {0: "Honda", 1: "Ford", 2: "Toyota"}

        mock_bodytype_model = Mock()
        mock_bodytype_model.predict.return_value = [mock_bodytype_result]
        mock_bodytype_model.names = {0: "SUV", 1: "Sedan", 2: "Hatchback"}

        # Pipeline configuration with branches
        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [
                {
                    "modelId": "car_brand_cls",
                    "modelFile": "car_brand.pt",
                    "triggerClasses": ["car"],
                    "minConfidence": 0.7,
                    "parallel": True,
                    "crop": True,
                    "cropClass": "car"
                },
                {
                    "modelId": "car_bodytype_cls",
                    "modelFile": "car_bodytype.pt",
                    "triggerClasses": ["car"],
                    "minConfidence": 0.7,
                    "parallel": True,
                    "crop": True,
                    "cropClass": "car"
                }
            ],
            "actions": []
        }

        executor = PipelineExecutor()

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            def get_model_side_effect(model_id, camera_id):
                if model_id == "car_detection_v1":
                    return mock_yolo_model
                elif model_id == "car_brand_cls":
                    return mock_brand_model
                elif model_id == "car_bodytype_cls":
                    return mock_bodytype_model
                return None

            mock_model_manager.return_value.get_model.side_effect = get_model_side_effect

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is True
            assert len(result.detections) == 1
            assert len(result.branch_results) == 2

            # Check branch results
            assert "car_brand_cls" in result.branch_results
            assert "car_bodytype_cls" in result.branch_results

            brand_result = result.branch_results["car_brand_cls"]
            assert brand_result.success is True
            assert brand_result.metadata.get("brand") == "Toyota"

            bodytype_result = result.branch_results["car_bodytype_cls"]
            assert bodytype_result.success is True
            assert bodytype_result.metadata.get("body_type") == "Sedan"

    @pytest.mark.asyncio
    async def test_execute_pipeline_sequential_mode(self, mock_yolo_model, mock_frame):
        """Test pipeline execution in sequential mode."""
        import torch

        config = {"execution_mode": "sequential"}
        executor = PipelineExecutor(config)

        # Mock detection result
        mock_result = Mock()
        mock_result.boxes = Mock()
        mock_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]
        ])
        mock_result.boxes.id = torch.tensor([1001])

        mock_yolo_model.track.return_value = [mock_result]

        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [
                {
                    "modelId": "car_brand_cls",
                    "modelFile": "car_brand.pt",
                    "triggerClasses": ["car"],
                    "minConfidence": 0.7,
                    "parallel": False  # Sequential execution
                }
            ],
            "actions": []
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_yolo_model

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is True
            assert executor.execution_mode == ExecutionMode.SEQUENTIAL

    @pytest.mark.asyncio
    async def test_execute_pipeline_with_actions(self, mock_yolo_model, mock_frame):
        """Test pipeline execution with actions."""
        import torch

        # Mock detection result
        mock_result = Mock()
        mock_result.boxes = Mock()
        mock_result.boxes.data = torch.tensor([
            [100, 200, 300, 400, 0.9, 0]
        ])
        mock_result.boxes.id = torch.tensor([1001])

        mock_yolo_model.track.return_value = [mock_result]

        # Pipeline configuration with actions
        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [],
            "actions": [
                {
                    "type": "redis_save_image",
                    "region": "car",
                    "key": "inference:{display_id}:{timestamp}:{session_id}",
                    "expire_seconds": 600
                },
                {
                    "type": "postgresql_insert",
                    "table": "detections",
                    "fields": {
                        "camera_id": "{camera_id}",
                        "detection_class": "{class}",
                        "confidence": "{confidence}"
                    }
                }
            ]
        }

        executor = PipelineExecutor()

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager, \
             patch('detector_worker.pipeline.action_executor.ActionExecutor') as mock_action_executor:

            mock_model_manager.return_value.get_model.return_value = mock_yolo_model
            mock_action_executor.return_value.execute_actions = AsyncMock(return_value=True)

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is True
            # Actions should be executed
            mock_action_executor.return_value.execute_actions.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_pipeline_model_error(self, mock_frame):
        """Test pipeline execution with model error."""
        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [],
            "actions": []
        }

        executor = PipelineExecutor()

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            # Model manager raises error
            mock_model_manager.return_value.get_model.side_effect = ModelError("Model not found")

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is False
            assert "Model not found" in result.error

    @pytest.mark.asyncio
    async def test_execute_pipeline_timeout(self, mock_yolo_model, mock_frame):
        """Test pipeline execution timeout."""
        # Configure short timeout
        config = {"timeout": 0.001}  # Very short timeout
        executor = PipelineExecutor(config)

        # Mock slow model inference
        def slow_inference(*args, **kwargs):
            import time
            time.sleep(1)  # Longer than timeout
            mock_result = Mock()
            mock_result.boxes = None
            return [mock_result]

        mock_yolo_model.track.side_effect = slow_inference

        pipeline_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8,
            "branches": [],
            "actions": []
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_yolo_model

            result = await executor.execute_pipeline(pipeline_config, context)

            assert result.success is False
            assert "timeout" in result.error.lower()

    @pytest.mark.asyncio
    async def test_execute_branch_parallel(self, mock_frame):
        """Test parallel branch execution."""
        # Mock classification model
        mock_brand_model = Mock()
        mock_result = Mock()
        mock_result.probs = Mock()
        mock_result.probs.top1 = 1
        mock_result.probs.top1conf = 0.85
        mock_brand_model.predict.return_value = [mock_result]
        mock_brand_model.names = {0: "Honda", 1: "Toyota", 2: "Ford"}

        executor = PipelineExecutor()

        # Branch configuration
        branch_config = {
            "modelId": "car_brand_cls",
            "modelFile": "car_brand.pt",
            "triggerClasses": ["car"],
            "minConfidence": 0.7,
            "parallel": True,
            "crop": True,
            "cropClass": "car"
        }

        # Mock detected regions
        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_brand_model

            result = await executor._execute_branch(branch_config, regions, context)

            assert result.success is True
            assert result.branch_id == "car_brand_cls"
            assert result.metadata.get("brand") == "Toyota"
            assert result.execution_time > 0

    @pytest.mark.asyncio
    async def test_execute_branch_no_trigger_class(self, mock_frame):
        """Test branch execution when trigger class not detected."""
        executor = PipelineExecutor()

        branch_config = {
            "modelId": "car_brand_cls",
            "modelFile": "car_brand.pt",
            "triggerClasses": ["car"],
            "minConfidence": 0.7
        }

        # No car detected
        regions = {
            "truck": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("truck", 0.9, BoundingBox(100, 200, 300, 400), 1002)
            }
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        result = await executor._execute_branch(branch_config, regions, context)

        assert result.success is False
        assert "trigger class not detected" in result.error.lower()

    def test_wait_for_branches(self):
        """Test waiting for specific branches to complete."""
        executor = PipelineExecutor()

        # Mock completed branch results
        branch_results = {
            "car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1),
            "car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12),
            "license_ocr": BranchResult("license_ocr", True, [], {"license": "ABC123"}, 0.2)
        }

        # Wait for specific branches
        wait_for = ["car_brand_cls", "car_bodytype_cls"]
        completed = executor._wait_for_branches(branch_results, wait_for, timeout=1.0)

        assert completed is True

        # Wait for a non-existent branch (should time out)
        wait_for_missing = ["car_brand_cls", "nonexistent_branch"]
        completed = executor._wait_for_branches(branch_results, wait_for_missing, timeout=0.1)

        assert completed is False

    def test_validate_pipeline_config(self):
        """Test pipeline configuration validation."""
        executor = PipelineExecutor()

        # Valid configuration
        valid_config = {
            "modelId": "car_detection_v1",
            "modelFile": "car_detection.pt",
            "expectedClasses": ["car"],
            "triggerClasses": ["car"],
            "minConfidence": 0.8
        }

        assert executor._validate_pipeline_config(valid_config) is True

        # Invalid configuration (missing required fields)
        invalid_config = {
            "modelFile": "car_detection.pt"
            # Missing modelId
        }

        with pytest.raises(PipelineError):
            executor._validate_pipeline_config(invalid_config)

    def test_crop_frame_for_detection(self, mock_frame):
        """Test frame cropping for detection."""
        executor = PipelineExecutor()

        detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)

        cropped = executor._crop_frame_for_detection(mock_frame, detection)

        assert cropped.shape == (200, 200, 3)  # 400-200, 300-100

    def test_crop_frame_invalid_bounds(self, mock_frame):
        """Test frame cropping with invalid bounds."""
        executor = PipelineExecutor()

        # Detection outside frame bounds
        detection = DetectionResult("car", 0.9, BoundingBox(-100, -200, 50, 100), 1001)

        cropped = executor._crop_frame_for_detection(mock_frame, detection)

        # Should handle bounds gracefully
        assert cropped.shape[0] > 0
        assert cropped.shape[1] > 0


class TestPipelineExecutorPerformance:
    """Test pipeline executor performance and optimization."""

    @pytest.mark.asyncio
    async def test_parallel_branch_execution_performance(self, mock_frame):
        """Test that parallel execution is faster than sequential."""
        import time

        def slow_inference(*args, **kwargs):
            time.sleep(0.1)  # Simulate slow inference
            mock_result = Mock()
            mock_result.probs = Mock()
            mock_result.probs.top1 = 1
            mock_result.probs.top1conf = 0.85
            return [mock_result]

        mock_model = Mock()
        mock_model.predict.side_effect = slow_inference
        mock_model.names = {0: "Class0", 1: "Class1"}

        # Test parallel execution
        parallel_executor = PipelineExecutor({"execution_mode": "parallel", "max_workers": 2})

        branch_configs = [
            {
                "modelId": f"branch_{i}",
                "modelFile": f"branch_{i}.pt",
                "triggerClasses": ["car"],
                "minConfidence": 0.7,
                "parallel": True
            }
            for i in range(3)  # 3 branches
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.return_value = mock_model

            start_time = time.time()

            # Execute branches in parallel
            tasks = [
                parallel_executor._execute_branch(config, regions, context)
                for config in branch_configs
            ]

            results = await asyncio.gather(*tasks)

            parallel_time = time.time() - start_time

            # Parallel execution should be faster than 3 * 0.1 seconds
            assert parallel_time < 0.25  # Allow some overhead
            assert len(results) == 3
            assert all(result.success for result in results)
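            # With max_workers=2 and three 0.1 s branches, the best case is
            # about 0.2 s (two branches run concurrently, the third waits),
            # so the 0.25 s bound leaves roughly 0.05 s of scheduling
            # overhead; this wall-clock assertion can be flaky on loaded CI.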

    def test_thread_pool_management(self):
        """Test thread pool creation and management."""
        # Test different worker counts
        for workers in [1, 2, 4, 8]:
            executor = PipelineExecutor({"max_workers": workers})
            assert executor.max_workers == workers
            assert executor.thread_pool._max_workers == workers

    def test_memory_management_large_frames(self):
        """Test memory management with large frames."""
        executor = PipelineExecutor()

        # Create large frame
        large_frame = np.ones((1080, 1920, 3), dtype=np.uint8) * 128

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=large_frame,
            crop_region=(500, 400, 1000, 800)
        )

        # Get cropped frame
        cropped = context.get_cropped_frame()

        # Should reduce memory usage
        assert cropped.shape == (400, 500, 3)  # Much smaller than original
        assert cropped.nbytes < large_frame.nbytes
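        # numpy basic slicing returns a view, and .nbytes reports the size of
        # the selection rather than of any shared buffer, so the full 1080p
        # frame is only truly released once callers copy the crop and drop
        # the original array.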


class TestPipelineExecutorErrorHandling:
    """Test comprehensive error handling."""

    @pytest.mark.asyncio
    async def test_branch_execution_error_isolation(self, mock_frame):
        """Test that errors in one branch don't affect others."""
        executor = PipelineExecutor()

        # Mock models - one fails, one succeeds
        failing_model = Mock()
        failing_model.predict.side_effect = Exception("Model crashed")

        success_model = Mock()
        mock_result = Mock()
        mock_result.probs = Mock()
        mock_result.probs.top1 = 1
        mock_result.probs.top1conf = 0.85
        success_model.predict.return_value = [mock_result]
        success_model.names = {0: "Class0", 1: "Class1"}

        branch_configs = [
            {
                "modelId": "failing_branch",
                "modelFile": "failing.pt",
                "triggerClasses": ["car"],
                "minConfidence": 0.7,
                "parallel": True
            },
            {
                "modelId": "success_branch",
                "modelFile": "success.pt",
                "triggerClasses": ["car"],
                "minConfidence": 0.7,
                "parallel": True
            }
        ]

        regions = {
            "car": {
                "bbox": [100, 200, 300, 400],
                "confidence": 0.9,
                "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001)
            }
        }

        context = PipelineContext(
            camera_id="camera_001",
            display_id="display_001",
            session_id="session_123",
            timestamp=1640995200000,
            frame_data=mock_frame
        )

        def get_model_side_effect(model_id, camera_id):
            if model_id == "failing_branch":
                return failing_model
            elif model_id == "success_branch":
                return success_model
            return None

        with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager:
            mock_model_manager.return_value.get_model.side_effect = get_model_side_effect

            # Execute branches
            tasks = [
                executor._execute_branch(config, regions, context)
                for config in branch_configs
            ]

            results = await asyncio.gather(*tasks, return_exceptions=True)

            # One should fail, one should succeed
            failing_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "failing_branch")
            success_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "success_branch")

            assert failing_result.success is False
            assert "Model crashed" in failing_result.error

            assert success_result.success is True
            assert success_result.error is None
976
tests/unit/storage/test_database_manager.py
Normal file
@ -0,0 +1,976 @@
"""
|
||||
Unit tests for database management functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from datetime import datetime, timedelta
|
||||
import psycopg2
|
||||
import uuid
|
||||
|
||||
from detector_worker.storage.database_manager import (
|
||||
DatabaseManager,
|
||||
DatabaseConfig,
|
||||
DatabaseConnection,
|
||||
QueryBuilder,
|
||||
TransactionManager,
|
||||
DatabaseError,
|
||||
ConnectionPoolError
|
||||
)
|
||||
from detector_worker.core.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestDatabaseConfig:
|
||||
"""Test database configuration."""
|
||||
|
||||
def test_creation_minimal(self):
|
||||
"""Test creating database config with minimal parameters."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
assert config.host == "localhost"
|
||||
assert config.port == 5432 # Default port
|
||||
assert config.database == "test_db"
|
||||
assert config.username == "test_user"
|
||||
assert config.password == "test_pass"
|
||||
assert config.schema == "public" # Default schema
|
||||
assert config.enabled is True
|
||||
|
||||
def test_creation_full(self):
|
||||
"""Test creating database config with all parameters."""
|
||||
config = DatabaseConfig(
|
||||
host="db.example.com",
|
||||
port=5433,
|
||||
database="production_db",
|
||||
username="prod_user",
|
||||
password="secure_pass",
|
||||
schema="gas_station_1",
|
||||
enabled=True,
|
||||
pool_min_conn=2,
|
||||
pool_max_conn=20,
|
||||
pool_timeout=30.0,
|
||||
connection_timeout=10.0,
|
||||
ssl_mode="require"
|
||||
)
|
||||
|
||||
assert config.host == "db.example.com"
|
||||
assert config.port == 5433
|
||||
assert config.database == "production_db"
|
||||
assert config.schema == "gas_station_1"
|
||||
assert config.pool_min_conn == 2
|
||||
assert config.pool_max_conn == 20
|
||||
assert config.ssl_mode == "require"
|
||||
|
||||
def test_get_connection_string(self):
|
||||
"""Test generating connection string."""
|
||||
config = DatabaseConfig(
|
||||
host="localhost",
|
||||
port=5432,
|
||||
database="test_db",
|
||||
username="test_user",
|
||||
password="test_pass"
|
||||
)
|
||||
|
||||
conn_string = config.get_connection_string()
|
||||
|
||||
expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass"
|
||||
assert conn_string == expected
|
||||
|
||||
def test_get_connection_string_with_ssl(self):
|
||||
"""Test generating connection string with SSL."""
|
||||
config = DatabaseConfig(
|
||||
host="db.example.com",
|
||||
database="secure_db",
|
||||
username="user",
|
||||
password="pass",
|
||||
ssl_mode="require"
|
||||
)
|
||||
|
||||
conn_string = config.get_connection_string()
|
||||
|
||||
assert "sslmode=require" in conn_string
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
config_dict = {
|
||||
"host": "test-host",
|
||||
"port": 5433,
|
||||
"database": "test-db",
|
||||
"username": "test-user",
|
||||
"password": "test-pass",
|
||||
"schema": "test_schema",
|
||||
"pool_max_conn": 15,
|
||||
"unknown_field": "ignored"
|
||||
}
|
||||
|
||||
config = DatabaseConfig.from_dict(config_dict)
|
||||
|
||||
assert config.host == "test-host"
|
||||
assert config.port == 5433
|
||||
assert config.database == "test-db"
|
||||
assert config.schema == "test_schema"
|
||||
assert config.pool_max_conn == 15
|
||||
|
||||
|
||||
class TestQueryBuilder:
|
||||
"""Test SQL query building functionality."""
|
||||
|
||||
def test_build_select_query(self):
|
||||
"""Test building SELECT queries."""
|
||||
builder = QueryBuilder("test_schema")
|
||||
|
||||
query, params = builder.build_select_query(
|
||||
table="users",
|
||||
columns=["id", "name", "email"],
|
||||
where={"status": "active", "age": 25},
|
||||
order_by="created_at DESC",
|
||||
limit=10
|
||||
)
|
||||
|
||||
expected_query = (
|
||||
"SELECT id, name, email FROM test_schema.users "
|
||||
"WHERE status = %s AND age = %s "
|
||||
"ORDER BY created_at DESC LIMIT 10"
|
||||
)
|
||||
|
||||
assert query == expected_query
|
||||
assert params == ["active", 25]
|
||||
|
||||
def test_build_select_all_columns(self):
|
||||
"""Test building SELECT * query."""
|
||||
builder = QueryBuilder("public")
|
||||
|
||||
query, params = builder.build_select_query("products")
|
||||
|
||||
expected_query = "SELECT * FROM public.products"
|
||||
assert query == expected_query
|
||||
assert params == []
|
||||
|
||||
def test_build_insert_query(self):
|
||||
"""Test building INSERT queries."""
|
||||
builder = QueryBuilder("inventory")
|
||||
|
||||
data = {
|
||||
"product_name": "Widget",
|
||||
"price": 19.99,
|
||||
"quantity": 100,
|
||||
"created_at": "NOW()"
|
||||
}
|
||||
|
||||
query, params = builder.build_insert_query("products", data)
|
||||
|
||||
expected_query = (
|
||||
"INSERT INTO inventory.products (product_name, price, quantity, created_at) "
|
||||
"VALUES (%s, %s, %s, NOW()) RETURNING id"
|
||||
)
|
||||
|
||||
assert query == expected_query
|
||||
assert params == ["Widget", 19.99, 100]
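        # "NOW()" is inlined into the VALUES list rather than bound as a
        # parameter, which is why params carries only the three user-supplied
        # values; restricting such inlining to a fixed allowlist of SQL
        # functions keeps the builder injection-safe.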
|
||||
|
||||
def test_build_update_query(self):
|
||||
"""Test building UPDATE queries."""
|
||||
builder = QueryBuilder("sales")
|
||||
|
||||
data = {
|
||||
"status": "shipped",
|
||||
"shipped_date": "NOW()",
|
||||
"tracking_number": "ABC123"
|
||||
}
|
||||
|
||||
where_conditions = {"order_id": 12345}
|
||||
|
||||
query, params = builder.build_update_query("orders", data, where_conditions)
|
||||
|
||||
expected_query = (
|
||||
"UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s "
|
||||
"WHERE order_id = %s"
|
||||
)
|
||||
|
||||
assert query == expected_query
|
||||
assert params == ["shipped", "ABC123", 12345]
|
||||
|
||||
def test_build_delete_query(self):
|
||||
"""Test building DELETE queries."""
|
||||
builder = QueryBuilder("logs")
|
||||
|
||||
where_conditions = {
|
||||
"level": "DEBUG",
|
||||
"created_at": "< NOW() - INTERVAL '7 days'"
|
||||
}
|
||||
|
||||
query, params = builder.build_delete_query("application_logs", where_conditions)
|
||||
|
||||
expected_query = (
|
||||
"DELETE FROM logs.application_logs "
|
||||
"WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'"
|
||||
)
|
||||
|
||||
assert query == expected_query
|
||||
assert params == ["DEBUG"]
|
||||
|
||||
def test_build_create_table_query(self):
|
||||
"""Test building CREATE TABLE queries."""
|
||||
builder = QueryBuilder("gas_station_1")
|
||||
|
||||
columns = {
|
||||
"id": "SERIAL PRIMARY KEY",
|
||||
"session_id": "VARCHAR(255) UNIQUE NOT NULL",
|
||||
"camera_id": "VARCHAR(255) NOT NULL",
|
||||
"detection_class": "VARCHAR(100)",
|
||||
"confidence": "DECIMAL(4,3)",
|
||||
"bbox_data": "JSON",
|
||||
"created_at": "TIMESTAMP DEFAULT NOW()",
|
||||
"updated_at": "TIMESTAMP DEFAULT NOW()"
|
||||
}
|
||||
|
||||
query = builder.build_create_table_query("detections", columns)
|
||||
|
||||
expected_parts = [
|
||||
"CREATE TABLE IF NOT EXISTS gas_station_1.detections",
|
||||
"id SERIAL PRIMARY KEY",
|
||||
"session_id VARCHAR(255) UNIQUE NOT NULL",
|
||||
"camera_id VARCHAR(255) NOT NULL",
|
||||
"bbox_data JSON",
|
||||
"created_at TIMESTAMP DEFAULT NOW()"
|
||||
]
|
||||
|
||||
for part in expected_parts:
|
||||
assert part in query
|
||||
|
||||
def test_escape_identifier(self):
|
||||
"""Test SQL identifier escaping."""
|
||||
builder = QueryBuilder("test")
|
||||
|
||||
assert builder.escape_identifier("table") == '"table"'
|
||||
assert builder.escape_identifier("column_name") == '"column_name"'
|
||||
assert builder.escape_identifier("user-table") == '"user-table"'
|
||||
|
||||
def test_format_value_for_sql(self):
|
||||
"""Test SQL value formatting."""
|
||||
builder = QueryBuilder("test")
|
||||
|
||||
# Regular values should use placeholder
|
||||
assert builder.format_value_for_sql("string") == ("%s", "string")
|
||||
assert builder.format_value_for_sql(42) == ("%s", 42)
|
||||
assert builder.format_value_for_sql(3.14) == ("%s", 3.14)
|
||||
|
||||
# SQL functions should be literal
|
||||
assert builder.format_value_for_sql("NOW()") == ("NOW()", None)
|
||||
assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None)
|
||||
assert builder.format_value_for_sql("UUID()") == ("UUID()", None)
|
||||
|
||||
|
||||
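# Note: the tests above pin down QueryBuilder's value-handling convention:
# plain Python values become "%s" placeholders with a matching bind parameter,
# while recognized SQL function literals (NOW(), CURRENT_TIMESTAMP, UUID())
# are inlined verbatim and contribute no parameter. A minimal sketch of that
# rule, assuming a simple whitelist (illustrative only, not the actual
# QueryBuilder internals):
#
#     SQL_FUNCTIONS = {"NOW()", "CURRENT_TIMESTAMP", "UUID()"}
#
#     def format_value_for_sql(value):
#         if isinstance(value, str) and value in SQL_FUNCTIONS:
#             return value, None   # inline literal, no bind parameter
#         return "%s", value       # parameterized to avoid SQL injection

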
class TestDatabaseConnection:
    """Test database connection management."""

    def test_creation(self, mock_database_connection):
        """Test connection creation."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)

        assert conn.config == config
        assert conn.connection == mock_database_connection
        assert conn.is_connected is True

    def test_execute_query(self, mock_database_connection):
        """Test query execution."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [
            (1, "John", "john@example.com"),
            (2, "Jane", "jane@example.com")
        ]
        mock_cursor.rowcount = 2

        conn = DatabaseConnection(config, mock_database_connection)

        query = "SELECT id, name, email FROM users WHERE status = %s"
        params = ["active"]

        result = conn.execute_query(query, params)

        assert result == [
            (1, "John", "john@example.com"),
            (2, "Jane", "jane@example.com")
        ]

        mock_cursor.execute.assert_called_once_with(query, params)
        mock_cursor.fetchall.assert_called_once()

    def test_execute_query_single_result(self, mock_database_connection):
        """Test query execution with single result."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (1, "John", "john@example.com")

        conn = DatabaseConnection(config, mock_database_connection)

        result = conn.execute_query("SELECT * FROM users WHERE id = %s", [1], fetch_one=True)

        assert result == (1, "John", "john@example.com")
        mock_cursor.fetchone.assert_called_once()

    def test_execute_query_no_fetch(self, mock_database_connection):
        """Test query execution without fetching results."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 1

        conn = DatabaseConnection(config, mock_database_connection)

        result = conn.execute_query(
            "INSERT INTO users (name) VALUES (%s)",
            ["John"],
            fetch_results=False
        )

        assert result == 1  # Row count
        mock_cursor.execute.assert_called_once()
        mock_cursor.fetchall.assert_not_called()
        mock_cursor.fetchone.assert_not_called()

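    # The three tests above pin down execute_query's result convention:
    # fetch_one=True -> cursor.fetchone(), the default -> cursor.fetchall(),
    # and fetch_results=False -> cursor.rowcount (useful for INSERT/UPDATE).
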
    def test_execute_query_error(self, mock_database_connection):
        """Test query execution error handling."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.execute.side_effect = psycopg2.Error("Database error")

        conn = DatabaseConnection(config, mock_database_connection)

        with pytest.raises(DatabaseError) as exc_info:
            conn.execute_query("SELECT * FROM invalid_table")

        assert "Database error" in str(exc_info.value)

    def test_commit_transaction(self, mock_database_connection):
        """Test transaction commit."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        conn.commit()

        mock_database_connection.commit.assert_called_once()

    def test_rollback_transaction(self, mock_database_connection):
        """Test transaction rollback."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        conn.rollback()

        mock_database_connection.rollback.assert_called_once()

    def test_close_connection(self, mock_database_connection):
        """Test connection closing."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        conn.close()

        assert conn.is_connected is False
        mock_database_connection.close.assert_called_once()


class TestTransactionManager:
    """Test transaction management."""

    def test_transaction_context_success(self, mock_database_connection):
        """Test successful transaction context."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        tx_manager = TransactionManager(conn)

        with tx_manager:
            # Simulate some database operations
            conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
            conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should commit on successful exit
        mock_database_connection.commit.assert_called_once()
        mock_database_connection.rollback.assert_not_called()

    def test_transaction_context_error(self, mock_database_connection):
        """Test transaction context with error."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        conn = DatabaseConnection(config, mock_database_connection)
        tx_manager = TransactionManager(conn)

        with pytest.raises(DatabaseError):
            with tx_manager:
                conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
                # Simulate an error
                raise DatabaseError("Something went wrong")

        # Should rollback on error
        mock_database_connection.rollback.assert_called_once()
        mock_database_connection.commit.assert_not_called()


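# Both tests assume the standard commit-on-success / rollback-on-error
# context-manager shape. A minimal sketch of what they imply (illustrative
# only; the real TransactionManager may differ):
#
#     class TransactionManager:
#         def __init__(self, conn):
#             self.conn = conn
#
#         def __enter__(self):
#             return self.conn
#
#         def __exit__(self, exc_type, exc, tb):
#             if exc_type is None:
#                 self.conn.commit()
#             else:
#                 self.conn.rollback()
#             return False  # re-raise any exception
#

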
class TestDatabaseManager:
    """Test main database manager functionality."""

    def test_initialization(self):
        """Test database manager initialization."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)

        assert manager.config == config
        assert isinstance(manager.query_builder, QueryBuilder)
        assert manager.query_builder.schema == "gas_station_1"
        assert manager.connection is None

    @pytest.mark.asyncio
    async def test_connect_success(self):
        """Test successful database connection."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        with patch('psycopg2.connect') as mock_connect:
            mock_connection = Mock()
            mock_connect.return_value = mock_connection

            await manager.connect()

            assert manager.connection is not None
            assert manager.is_connected is True
            mock_connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_failure(self):
        """Test database connection failure."""
        config = DatabaseConfig(
            host="nonexistent-host",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        with patch('psycopg2.connect') as mock_connect:
            mock_connect.side_effect = psycopg2.Error("Connection failed")

            with pytest.raises(DatabaseError) as exc_info:
                await manager.connect()

            assert "Connection failed" in str(exc_info.value)
            assert manager.is_connected is False

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test database disconnection."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        # Mock connection
        mock_connection = Mock()
        manager.connection = DatabaseConnection(config, mock_connection)

        await manager.disconnect()

        assert manager.connection is None
        mock_connection.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_query(self, mock_database_connection):
        """Test query execution through manager."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")]

        result = await manager.execute_query("SELECT * FROM test_table")

        assert result == [(1, "Test"), (2, "Data")]
        mock_cursor.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_execute_query_not_connected(self):
        """Test query execution when not connected."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        with pytest.raises(DatabaseError) as exc_info:
            await manager.execute_query("SELECT * FROM test_table")

        assert "not connected" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_insert_record(self, mock_database_connection):
        """Test inserting a record."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (123,)  # Returned ID

        data = {
            "session_id": "session_123",
            "camera_id": "camera_001",
            "detection_class": "car",
            "confidence": 0.95,
            "created_at": "NOW()"
        }

        record_id = await manager.insert_record("car_detections", data)

        assert record_id == 123
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_record(self, mock_database_connection):
        """Test updating a record."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 1

        data = {
            "car_brand": "Toyota",
            "car_body_type": "Sedan",
            "updated_at": "NOW()"
        }

        where_conditions = {"session_id": "session_123"}

        rows_affected = await manager.update_record("car_info", data, where_conditions)

        assert rows_affected == 1
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_records(self, mock_database_connection):
        """Test deleting records."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 3

        where_conditions = {
            "created_at": "< NOW() - INTERVAL '30 days'",
            "processed": True
        }

        rows_deleted = await manager.delete_records("old_detections", where_conditions)

        assert rows_deleted == 3
        mock_cursor.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_table(self, mock_database_connection):
        """Test creating a table."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        columns = {
            "id": "SERIAL PRIMARY KEY",
            "session_id": "VARCHAR(255) UNIQUE NOT NULL",
            "camera_id": "VARCHAR(255) NOT NULL",
            "detection_data": "JSON",
            "created_at": "TIMESTAMP DEFAULT NOW()"
        }

        await manager.create_table("test_detections", columns)

        mock_database_connection.cursor.return_value.execute.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_table_exists(self, mock_database_connection):
        """Test checking if table exists."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior - table exists
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchone.return_value = (1,)

        exists = await manager.table_exists("car_detections")

        assert exists is True
        mock_cursor.execute.assert_called_once()

        # Mock cursor behavior - table doesn't exist
        mock_cursor.fetchone.return_value = None

        exists = await manager.table_exists("nonexistent_table")

        assert exists is False

    @pytest.mark.asyncio
    async def test_transaction_context(self, mock_database_connection):
        """Test transaction context manager."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        async with manager.transaction():
            await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])
            await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should commit on successful completion
        mock_database_connection.commit.assert_called()

    @pytest.mark.asyncio
    async def test_get_table_schema(self, mock_database_connection):
        """Test getting table schema information."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behavior
        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.fetchall.return_value = [
            ("id", "integer", "NOT NULL"),
            ("session_id", "character varying", "NOT NULL"),
            ("created_at", "timestamp without time zone", "DEFAULT now()")
        ]

        schema = await manager.get_table_schema("car_detections")

        assert len(schema) == 3
        assert schema[0] == ("id", "integer", "NOT NULL")
        assert schema[1] == ("session_id", "character varying", "NOT NULL")

    @pytest.mark.asyncio
    async def test_bulk_insert(self, mock_database_connection):
        """Test bulk insert operation."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        records = [
            {"name": "John", "email": "john@example.com"},
            {"name": "Jane", "email": "jane@example.com"},
            {"name": "Bob", "email": "bob@example.com"}
        ]

        mock_cursor = mock_database_connection.cursor.return_value
        mock_cursor.rowcount = 3

        rows_inserted = await manager.bulk_insert("users", records)

        assert rows_inserted == 3
        mock_cursor.executemany.assert_called_once()
        mock_database_connection.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_connection_stats(self, mock_database_connection):
        """Test getting connection statistics."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        stats = manager.get_connection_stats()

        assert "connected" in stats
        assert "host" in stats
        assert "database" in stats
        assert "schema" in stats
        assert stats["connected"] is True
        assert stats["host"] == "localhost"
        assert stats["database"] == "test_db"


class TestDatabaseManagerIntegration:
    """Integration tests for database manager."""

    @pytest.mark.asyncio
    async def test_complete_car_detection_workflow(self, mock_database_connection):
        """Test complete car detection database workflow."""
        config = DatabaseConfig(
            host="localhost",
            database="gas_station_db",
            username="detector_user",
            password="detector_pass",
            schema="gas_station_1"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Mock cursor behaviors for different operations
        mock_cursor = mock_database_connection.cursor.return_value

        # 1. Create initial detection record
        mock_cursor.fetchone.return_value = (456,)  # Returned ID

        detection_data = {
            "session_id": str(uuid.uuid4()),
            "camera_id": "camera_001",
            "display_id": "display_001",
            "detection_class": "car",
            "confidence": 0.92,
            "bbox_x1": 100,
            "bbox_y1": 200,
            "bbox_x2": 300,
            "bbox_y2": 400,
            "track_id": 1001,
            "created_at": "NOW()"
        }

        detection_id = await manager.insert_record("car_detections", detection_data)
        assert detection_id == 456

        # 2. Update with classification results
        mock_cursor.rowcount = 1

        classification_data = {
            "car_brand": "Toyota",
            "car_model": "Camry",
            "car_body_type": "Sedan",
            "car_color": "Blue",
            "brand_confidence": 0.87,
            "bodytype_confidence": 0.82,
            "color_confidence": 0.79,
            "updated_at": "NOW()"
        }

        where_conditions = {"session_id": detection_data["session_id"]}

        rows_updated = await manager.update_record("car_detections", classification_data, where_conditions)
        assert rows_updated == 1

        # 3. Query final results
        mock_cursor.fetchall.return_value = [
            (456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan")
        ]

        results = await manager.execute_query(
            "SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type "
            "FROM gas_station_1.car_detections WHERE session_id = %s",
            [detection_data["session_id"]]
        )

        assert len(results) == 1
        assert results[0][0] == 456  # ID
        assert results[0][3] == "car"  # detection_class
        assert results[0][5] == "Toyota"  # car_brand

        # Verify all database operations were called
        assert mock_cursor.execute.call_count == 3
        assert mock_database_connection.commit.call_count == 2

    @pytest.mark.asyncio
    async def test_error_handling_and_recovery(self, mock_database_connection):
        """Test error handling and recovery scenarios."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)
        manager.connection = DatabaseConnection(config, mock_database_connection)

        # Test transaction rollback on error
        mock_cursor = mock_database_connection.cursor.return_value

        with pytest.raises(DatabaseError):
            async with manager.transaction():
                # First operation succeeds
                await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"])

                # Second operation fails
                mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation")
                await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"])

        # Should have rolled back
        mock_database_connection.rollback.assert_called_once()
        mock_database_connection.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_connection_recovery(self):
        """Test automatic connection recovery."""
        config = DatabaseConfig(
            host="localhost",
            database="test_db",
            username="test_user",
            password="test_pass"
        )

        manager = DatabaseManager(config)

        with patch('psycopg2.connect') as mock_connect:
            # First connection attempt fails
            mock_connect.side_effect = [
                psycopg2.Error("Connection refused"),
                Mock()  # Second attempt succeeds
            ]

            # First attempt should fail
            with pytest.raises(DatabaseError):
                await manager.connect()

            # Second attempt should succeed
            await manager.connect()
            assert manager.is_connected is True
964
tests/unit/storage/test_redis_client.py
Normal file
@@ -0,0 +1,964 @@
"""
|
||||
Unit tests for Redis client functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import time
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from datetime import datetime, timedelta
|
||||
import redis
|
||||
import numpy as np
|
||||
|
||||
from detector_worker.storage.redis_client import (
|
||||
RedisClient,
|
||||
RedisConfig,
|
||||
RedisConnectionPool,
|
||||
RedisPublisher,
|
||||
RedisSubscriber,
|
||||
RedisImageStorage,
|
||||
RedisError,
|
||||
ConnectionPoolError
|
||||
)
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
from detector_worker.core.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestRedisConfig:
    """Test Redis configuration."""

    def test_creation_minimal(self):
        """Test creating Redis config with minimal parameters."""
        config = RedisConfig(
            host="localhost"
        )

        assert config.host == "localhost"
        assert config.port == 6379  # Default port
        assert config.password is None
        assert config.db == 0  # Default database
        assert config.enabled is True

    def test_creation_full(self):
        """Test creating Redis config with all parameters."""
        config = RedisConfig(
            host="redis.example.com",
            port=6380,
            password="secure_pass",
            db=2,
            enabled=True,
            connection_timeout=5.0,
            socket_timeout=3.0,
            socket_connect_timeout=2.0,
            max_connections=50,
            retry_on_timeout=True,
            health_check_interval=30
        )

        assert config.host == "redis.example.com"
        assert config.port == 6380
        assert config.password == "secure_pass"
        assert config.db == 2
        assert config.connection_timeout == 5.0
        assert config.max_connections == 50
        assert config.retry_on_timeout is True

    def test_get_connection_params(self):
        """Test getting Redis connection parameters."""
        config = RedisConfig(
            host="localhost",
            port=6379,
            password="test_pass",
            db=1,
            connection_timeout=10.0
        )

        params = config.get_connection_params()

        assert params["host"] == "localhost"
        assert params["port"] == 6379
        assert params["password"] == "test_pass"
        assert params["db"] == 1
        assert params["socket_timeout"] == 10.0

    def test_from_dict(self):
        """Test creating config from dictionary."""
        config_dict = {
            "host": "redis-server",
            "port": 6380,
            "password": "secret",
            "db": 3,
            "max_connections": 100,
            "unknown_field": "ignored"
        }

        config = RedisConfig.from_dict(config_dict)

        assert config.host == "redis-server"
        assert config.port == 6380
        assert config.password == "secret"
        assert config.db == 3
        assert config.max_connections == 100


class TestRedisConnectionPool:
    """Test Redis connection pool management."""

    def test_creation(self):
        """Test connection pool creation."""
        config = RedisConfig(
            host="localhost",
            max_connections=20
        )

        pool = RedisConnectionPool(config)

        assert pool.config == config
        assert pool.pool is None
        assert pool.is_connected is False

    @pytest.mark.asyncio
    async def test_connect_success(self):
        """Test successful connection to Redis."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        with patch('redis.ConnectionPool') as mock_pool_class:
            mock_pool = Mock()
            mock_pool_class.return_value = mock_pool

            with patch('redis.Redis') as mock_redis_class:
                mock_redis = Mock()
                mock_redis.ping.return_value = True
                mock_redis_class.return_value = mock_redis

                await pool.connect()

                assert pool.is_connected is True
                assert pool.pool is not None
                mock_pool_class.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_failure(self):
        """Test Redis connection failure."""
        config = RedisConfig(host="nonexistent-redis")
        pool = RedisConnectionPool(config)

        with patch('redis.ConnectionPool') as mock_pool_class:
            mock_pool_class.side_effect = redis.ConnectionError("Connection failed")

            with pytest.raises(RedisError) as exc_info:
                await pool.connect()

            assert "Connection failed" in str(exc_info.value)
            assert pool.is_connected is False

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test Redis disconnection."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        # Mock connected state
        mock_pool = Mock()
        mock_redis = Mock()
        pool.pool = mock_pool
        pool._redis_client = mock_redis
        pool.is_connected = True

        await pool.disconnect()

        assert pool.is_connected is False
        assert pool.pool is None
        mock_pool.disconnect.assert_called_once()

    def test_get_client_connected(self):
        """Test getting Redis client when connected."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        mock_pool = Mock()
        mock_redis = Mock()
        pool.pool = mock_pool
        pool._redis_client = mock_redis
        pool.is_connected = True

        client = pool.get_client()
        assert client == mock_redis

    def test_get_client_not_connected(self):
        """Test getting Redis client when not connected."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        with pytest.raises(RedisError) as exc_info:
            pool.get_client()

        assert "not connected" in str(exc_info.value).lower()

    def test_health_check(self):
        """Test Redis health check."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        mock_redis = Mock()
        mock_redis.ping.return_value = True
        pool._redis_client = mock_redis
        pool.is_connected = True

        is_healthy = pool.health_check()

        assert is_healthy is True
        mock_redis.ping.assert_called_once()

    def test_health_check_failure(self):
        """Test Redis health check failure."""
        config = RedisConfig(host="localhost")
        pool = RedisConnectionPool(config)

        mock_redis = Mock()
        mock_redis.ping.side_effect = redis.ConnectionError("Connection lost")
        pool._redis_client = mock_redis
        pool.is_connected = True

        is_healthy = pool.health_check()

        assert is_healthy is False


class TestRedisImageStorage:
    """Test Redis image storage functionality."""

    def test_creation(self, mock_redis_client):
        """Test Redis image storage creation."""
        storage = RedisImageStorage(mock_redis_client)

        assert storage.redis_client == mock_redis_client
        assert storage.default_expiry == 3600  # 1 hour
        assert storage.compression_enabled is True

    @pytest.mark.asyncio
    async def test_store_image_success(self, mock_redis_client, mock_frame):
        """Test successful image storage."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True

        with patch('cv2.imencode') as mock_imencode:
            # Mock successful encoding
            encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            result = await storage.store_image("test_key", mock_frame, expire_seconds=600)

            assert result is True
            mock_redis_client.set.assert_called_once()
            mock_redis_client.expire.assert_called_once_with("test_key", 600)
            mock_imencode.assert_called_once()

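    # Note: cv2.imencode returns a (success, buffer) pair, which is why the
    # mock above returns a tuple; the retrieval test further down assumes the
    # encoded bytes are stored base64-encoded, so a store/retrieve round trip
    # goes through base64 plus cv2.imdecode.
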
    @pytest.mark.asyncio
    async def test_store_image_cropped(self, mock_redis_client, mock_frame):
        """Test storing cropped image."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.set.return_value = True
        mock_redis_client.expire.return_value = True

        bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400)

        with patch('cv2.imencode') as mock_imencode:
            encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            result = await storage.store_image("cropped_key", mock_frame, crop_bbox=bbox)

            assert result is True
            mock_redis_client.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_store_image_encoding_failure(self, mock_redis_client, mock_frame):
        """Test image storage with encoding failure."""
        storage = RedisImageStorage(mock_redis_client)

        with patch('cv2.imencode') as mock_imencode:
            # Mock encoding failure
            mock_imencode.return_value = (False, None)

            with pytest.raises(RedisError) as exc_info:
                await storage.store_image("test_key", mock_frame)

            assert "Failed to encode image" in str(exc_info.value)
            mock_redis_client.set.assert_not_called()

    @pytest.mark.asyncio
    async def test_store_image_redis_failure(self, mock_redis_client, mock_frame):
        """Test image storage with Redis failure."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.set.side_effect = redis.RedisError("Redis error")

        with patch('cv2.imencode') as mock_imencode:
            encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            with pytest.raises(RedisError) as exc_info:
                await storage.store_image("test_key", mock_frame)

            assert "Redis error" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_retrieve_image_success(self, mock_redis_client):
        """Test successful image retrieval."""
        storage = RedisImageStorage(mock_redis_client)

        # Mock encoded image data
        original_image = np.ones((100, 100, 3), dtype=np.uint8) * 128

        with patch('cv2.imencode') as mock_imencode:
            encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            # Mock Redis returning base64 encoded data
            base64_data = base64.b64encode(encoded_data.tobytes()).decode('utf-8')
            mock_redis_client.get.return_value = base64_data

            with patch('cv2.imdecode') as mock_imdecode:
                mock_imdecode.return_value = original_image

                retrieved_image = await storage.retrieve_image("test_key")

                assert retrieved_image is not None
                assert retrieved_image.shape == (100, 100, 3)
                mock_redis_client.get.assert_called_once_with("test_key")

    @pytest.mark.asyncio
    async def test_retrieve_image_not_found(self, mock_redis_client):
        """Test image retrieval when key not found."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.get.return_value = None

        retrieved_image = await storage.retrieve_image("nonexistent_key")

        assert retrieved_image is None
        mock_redis_client.get.assert_called_once_with("nonexistent_key")

    @pytest.mark.asyncio
    async def test_delete_image(self, mock_redis_client):
        """Test image deletion."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.delete.return_value = 1

        result = await storage.delete_image("test_key")

        assert result is True
        mock_redis_client.delete.assert_called_once_with("test_key")

    @pytest.mark.asyncio
    async def test_delete_image_not_found(self, mock_redis_client):
        """Test deleting non-existent image."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.delete.return_value = 0

        result = await storage.delete_image("nonexistent_key")

        assert result is False
        mock_redis_client.delete.assert_called_once_with("nonexistent_key")

    @pytest.mark.asyncio
    async def test_bulk_delete_images(self, mock_redis_client):
        """Test bulk image deletion."""
        storage = RedisImageStorage(mock_redis_client)

        keys = ["key1", "key2", "key3"]
        mock_redis_client.delete.return_value = 3

        deleted_count = await storage.bulk_delete_images(keys)

        assert deleted_count == 3
        mock_redis_client.delete.assert_called_once_with(*keys)

    @pytest.mark.asyncio
    async def test_cleanup_expired_images(self, mock_redis_client):
        """Test cleanup of expired images."""
        storage = RedisImageStorage(mock_redis_client)

        # Mock scan to return image keys
        mock_redis_client.scan_iter.return_value = [
            b"inference:camera1:image1",
            b"inference:camera2:image2",
            b"inference:camera1:image3"
        ]

        # Mock ttl to return different expiry times
        mock_redis_client.ttl.side_effect = [-1, 100, -2]  # No expiry, valid, expired
        mock_redis_client.delete.return_value = 1

        deleted_count = await storage.cleanup_expired_images("inference:*")

        assert deleted_count == 1  # Only expired images deleted
        mock_redis_client.delete.assert_called_once()

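    # Redis TTL semantics behind the side_effect above: TTL returns -2 when
    # the key no longer exists (treated here as expired), -1 when the key
    # exists but has no expiry set, and a non-negative number of seconds
    # otherwise. Only the -2 entry counts toward the cleanup total.
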
    def test_get_image_info(self, mock_redis_client):
        """Test getting image metadata."""
        storage = RedisImageStorage(mock_redis_client)

        mock_redis_client.exists.return_value = 1
        mock_redis_client.ttl.return_value = 1800  # 30 minutes
        mock_redis_client.memory_usage.return_value = 4096  # 4KB

        info = storage.get_image_info("test_key")

        assert info["exists"] is True
        assert info["ttl"] == 1800
        assert info["size_bytes"] == 4096

        mock_redis_client.exists.assert_called_once_with("test_key")
        mock_redis_client.ttl.assert_called_once_with("test_key")


class TestRedisPublisher:
    """Test Redis publisher functionality."""

    def test_creation(self, mock_redis_client):
        """Test Redis publisher creation."""
        publisher = RedisPublisher(mock_redis_client)

        assert publisher.redis_client == mock_redis_client

    @pytest.mark.asyncio
    async def test_publish_message_string(self, mock_redis_client):
        """Test publishing string message."""
        publisher = RedisPublisher(mock_redis_client)

        mock_redis_client.publish.return_value = 2  # 2 subscribers

        result = await publisher.publish("test_channel", "Hello, Redis!")

        assert result == 2
        mock_redis_client.publish.assert_called_once_with("test_channel", "Hello, Redis!")

    @pytest.mark.asyncio
    async def test_publish_message_json(self, mock_redis_client):
        """Test publishing JSON message."""
        publisher = RedisPublisher(mock_redis_client)

        mock_redis_client.publish.return_value = 1

        message_data = {
            "camera_id": "camera_001",
            "detection_class": "car",
            "confidence": 0.95,
            "timestamp": 1640995200000
        }

        result = await publisher.publish("detections", message_data)

        assert result == 1

        # Should have been JSON serialized
        expected_json = json.dumps(message_data)
        mock_redis_client.publish.assert_called_once_with("detections", expected_json)

    @pytest.mark.asyncio
    async def test_publish_detection_event(self, mock_redis_client):
        """Test publishing detection event."""
        publisher = RedisPublisher(mock_redis_client)

        mock_redis_client.publish.return_value = 3

        detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000)

        result = await publisher.publish_detection_event(
            "camera_detections",
            detection,
            camera_id="camera_001",
            session_id="session_123"
        )

        assert result == 3

        # Verify the published message structure
        call_args = mock_redis_client.publish.call_args
        channel = call_args[0][0]
        message_str = call_args[0][1]
        message_data = json.loads(message_str)

        assert channel == "camera_detections"
        assert message_data["event_type"] == "detection"
        assert message_data["camera_id"] == "camera_001"
        assert message_data["session_id"] == "session_123"
        assert message_data["detection"]["class"] == "car"
        assert message_data["detection"]["confidence"] == 0.92

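    # The assertions above imply the detection-event payload shape, roughly
    # (reconstructed from the test, not from the implementation):
    #
    #     {
    #         "event_type": "detection",
    #         "camera_id": "...",
    #         "session_id": "...",
    #         "detection": {"class": "...", "confidence": 0.0, ...}
    #     }
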
    @pytest.mark.asyncio
    async def test_publish_batch_messages(self, mock_redis_client):
        """Test publishing multiple messages in batch."""
        publisher = RedisPublisher(mock_redis_client)

        mock_pipeline = Mock()
        mock_redis_client.pipeline.return_value = mock_pipeline
        mock_pipeline.execute.return_value = [1, 2, 1]  # Subscriber counts

        messages = [
            ("channel1", "message1"),
            ("channel2", {"data": "message2"}),
            ("channel1", "message3")
        ]

        results = await publisher.publish_batch(messages)

        assert results == [1, 2, 1]
        mock_redis_client.pipeline.assert_called_once()
        assert mock_pipeline.publish.call_count == 3
        mock_pipeline.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_publish_error_handling(self, mock_redis_client):
        """Test error handling in publishing."""
        publisher = RedisPublisher(mock_redis_client)

        mock_redis_client.publish.side_effect = redis.RedisError("Publish failed")

        with pytest.raises(RedisError) as exc_info:
            await publisher.publish("test_channel", "test_message")

        assert "Publish failed" in str(exc_info.value)


class TestRedisSubscriber:
    """Test Redis subscriber functionality."""

    def test_creation(self, mock_redis_client):
        """Test Redis subscriber creation."""
        subscriber = RedisSubscriber(mock_redis_client)

        assert subscriber.redis_client == mock_redis_client
        assert subscriber.pubsub is None
        assert subscriber.subscriptions == set()

    @pytest.mark.asyncio
    async def test_subscribe_to_channel(self, mock_redis_client):
        """Test subscribing to a channel."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        mock_redis_client.pubsub.return_value = mock_pubsub

        await subscriber.subscribe("test_channel")

        assert "test_channel" in subscriber.subscriptions
        mock_pubsub.subscribe.assert_called_once_with("test_channel")

    @pytest.mark.asyncio
    async def test_subscribe_to_pattern(self, mock_redis_client):
        """Test subscribing to a pattern."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        mock_redis_client.pubsub.return_value = mock_pubsub

        await subscriber.subscribe_pattern("detection:*")

        assert "detection:*" in subscriber.subscriptions
        mock_pubsub.psubscribe.assert_called_once_with("detection:*")

    @pytest.mark.asyncio
    async def test_unsubscribe_from_channel(self, mock_redis_client):
        """Test unsubscribing from a channel."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        mock_redis_client.pubsub.return_value = mock_pubsub
        subscriber.pubsub = mock_pubsub
        subscriber.subscriptions.add("test_channel")

        await subscriber.unsubscribe("test_channel")

        assert "test_channel" not in subscriber.subscriptions
        mock_pubsub.unsubscribe.assert_called_once_with("test_channel")

    @pytest.mark.asyncio
    async def test_listen_for_messages(self, mock_redis_client):
        """Test listening for messages."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        mock_redis_client.pubsub.return_value = mock_pubsub

        # Mock message stream
        messages = [
            {"type": "subscribe", "channel": "test", "data": 1},
            {"type": "message", "channel": "test", "data": "Hello"},
            {"type": "message", "channel": "test", "data": '{"key": "value"}'}
        ]

        mock_pubsub.listen.return_value = iter(messages)

        received_messages = []
        message_count = 0

        async for message in subscriber.listen():
            received_messages.append(message)
            message_count += 1
            if message_count >= 2:  # Only process actual messages
                break

        # Should receive 2 actual messages (excluding subscribe confirmation)
        assert len(received_messages) == 2
        assert received_messages[0]["data"] == "Hello"
        assert received_messages[1]["data"] == {"key": "value"}  # Should be parsed as JSON

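    # The listen test encodes two behaviors the subscriber is expected to
    # provide on top of raw pub/sub: control frames such as "subscribe"
    # confirmations are filtered out, and string payloads that parse as JSON
    # are deserialized into Python objects before being yielded.
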
    @pytest.mark.asyncio
    async def test_close_subscription(self, mock_redis_client):
        """Test closing subscription."""
        subscriber = RedisSubscriber(mock_redis_client)

        mock_pubsub = Mock()
        subscriber.pubsub = mock_pubsub
        subscriber.subscriptions = {"channel1", "pattern:*"}

        await subscriber.close()

        assert len(subscriber.subscriptions) == 0
        mock_pubsub.close.assert_called_once()
        assert subscriber.pubsub is None


class TestRedisClient:
    """Test main Redis client functionality."""

    def test_initialization(self):
        """Test Redis client initialization."""
        config = RedisConfig(host="localhost", port=6379)
        client = RedisClient(config)

        assert client.config == config
        assert isinstance(client.connection_pool, RedisConnectionPool)
        assert client.image_storage is None
        assert client.publisher is None
        assert client.subscriber is None

    @pytest.mark.asyncio
    async def test_connect_and_initialize_components(self):
        """Test connecting and initializing all components."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
            mock_redis_client = Mock()
            client.connection_pool.get_client = Mock(return_value=mock_redis_client)
            client.connection_pool.is_connected = True

            await client.connect()

            assert client.image_storage is not None
            assert client.publisher is not None
            assert client.subscriber is not None
            assert isinstance(client.image_storage, RedisImageStorage)
            assert isinstance(client.publisher, RedisPublisher)
            assert isinstance(client.subscriber, RedisSubscriber)

            mock_connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test disconnection."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state; keep a local handle on the subscriber mock
        # because disconnect() is expected to null out client.subscriber
        client.connection_pool.is_connected = True
        mock_subscriber = Mock()
        mock_subscriber.close = AsyncMock()
        client.subscriber = mock_subscriber

        with patch.object(client.connection_pool, 'disconnect', new_callable=AsyncMock) as mock_disconnect:
            await client.disconnect()

            mock_subscriber.close.assert_called_once()
            mock_disconnect.assert_called_once()

        assert client.image_storage is None
        assert client.publisher is None
        assert client.subscriber is None

    @pytest.mark.asyncio
    async def test_store_and_retrieve_data(self, mock_redis_client):
        """Test storing and retrieving data."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        # Test storing data
        mock_redis_client.set.return_value = True
        result = await client.set("test_key", "test_value", expire_seconds=300)
        assert result is True
        mock_redis_client.set.assert_called_once_with("test_key", "test_value")
        mock_redis_client.expire.assert_called_once_with("test_key", 300)

        # Test retrieving data
        mock_redis_client.get.return_value = "test_value"
        value = await client.get("test_key")
        assert value == "test_value"
        mock_redis_client.get.assert_called_once_with("test_key")

    @pytest.mark.asyncio
    async def test_delete_keys(self, mock_redis_client):
        """Test deleting keys."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.delete.return_value = 2

        result = await client.delete("key1", "key2")

        assert result == 2
        mock_redis_client.delete.assert_called_once_with("key1", "key2")

    @pytest.mark.asyncio
    async def test_exists_check(self, mock_redis_client):
        """Test checking key existence."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.exists.return_value = 1

        exists = await client.exists("test_key")

        assert exists is True
        mock_redis_client.exists.assert_called_once_with("test_key")

    @pytest.mark.asyncio
    async def test_expire_key(self, mock_redis_client):
        """Test setting key expiration."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.expire.return_value = True

        result = await client.expire("test_key", 600)

        assert result is True
        mock_redis_client.expire.assert_called_once_with("test_key", 600)

    @pytest.mark.asyncio
    async def test_get_ttl(self, mock_redis_client):
        """Test getting key TTL."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.ttl.return_value = 300

        ttl = await client.ttl("test_key")

        assert ttl == 300
        mock_redis_client.ttl.assert_called_once_with("test_key")

    @pytest.mark.asyncio
    async def test_scan_keys(self, mock_redis_client):
        """Test scanning for keys."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.scan_iter.return_value = [b"key1", b"key2", b"key3"]

        keys = await client.scan_keys("test:*")

        assert keys == ["key1", "key2", "key3"]
        mock_redis_client.scan_iter.assert_called_once_with(match="test:*")

    @pytest.mark.asyncio
    async def test_flush_database(self, mock_redis_client):
        """Test flushing database."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_redis_client.flushdb.return_value = True

        result = await client.flush_db()

        assert result is True
        mock_redis_client.flushdb.assert_called_once()

    def test_get_connection_info(self):
        """Test getting connection information."""
        config = RedisConfig(
            host="redis.example.com",
            port=6380,
            db=2
        )
        client = RedisClient(config)
        client.connection_pool.is_connected = True

        info = client.get_connection_info()

        assert info["connected"] is True
        assert info["host"] == "redis.example.com"
        assert info["port"] == 6380
        assert info["database"] == 2

    @pytest.mark.asyncio
    async def test_pipeline_operations(self, mock_redis_client):
        """Test Redis pipeline operations."""
        config = RedisConfig(host="localhost")
        client = RedisClient(config)

        # Mock connected state
        client.connection_pool.get_client = Mock(return_value=mock_redis_client)
        client.connection_pool.is_connected = True

        mock_pipeline = Mock()
        mock_redis_client.pipeline.return_value = mock_pipeline
        mock_pipeline.execute.return_value = [True, True, 1]

        async with client.pipeline() as pipe:
            pipe.set("key1", "value1")
            pipe.set("key2", "value2")
            pipe.delete("key3")
            results = await pipe.execute()

        assert results == [True, True, 1]
        mock_redis_client.pipeline.assert_called_once()
        mock_pipeline.execute.assert_called_once()


class TestRedisClientIntegration:
|
||||
"""Integration tests for Redis client."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_image_workflow(self, mock_redis_client, mock_frame):
|
||||
"""Test complete image storage workflow."""
|
||||
config = RedisConfig(host="localhost")
|
||||
client = RedisClient(config)
|
||||
|
||||
# Mock connected state and components
|
||||
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
|
||||
client.connection_pool.is_connected = True
|
||||
client.image_storage = RedisImageStorage(mock_redis_client)
|
||||
client.publisher = RedisPublisher(mock_redis_client)
|
||||
|
||||
# Mock Redis operations
|
||||
mock_redis_client.set.return_value = True
|
||||
mock_redis_client.expire.return_value = True
|
||||
mock_redis_client.publish.return_value = 2
|
||||
|
||||
with patch('cv2.imencode') as mock_imencode:
|
||||
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
|
||||
mock_imencode.return_value = (True, encoded_data)
|
||||
|
||||
# Store image
|
||||
store_result = await client.image_storage.store_image(
|
||||
"detection:camera001:1640995200:session123",
|
||||
mock_frame,
|
||||
expire_seconds=600
|
||||
)
|
||||
|
||||
# Publish detection event
|
||||
detection_event = {
|
||||
"camera_id": "camera001",
|
||||
"session_id": "session123",
|
||||
"detection_class": "car",
|
||||
"confidence": 0.95,
|
||||
"timestamp": 1640995200000
|
||||
}
|
||||
|
||||
publish_result = await client.publisher.publish("detections:camera001", detection_event)
|
||||
|
||||
assert store_result is True
|
||||
assert publish_result == 2
|
||||
|
||||
# Verify Redis operations
|
||||
mock_redis_client.set.assert_called_once()
|
||||
mock_redis_client.expire.assert_called_once()
|
||||
mock_redis_client.publish.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_recovery_and_reconnection(self):
|
||||
"""Test error recovery and reconnection."""
|
||||
config = RedisConfig(host="localhost", retry_on_timeout=True)
|
||||
client = RedisClient(config)
|
||||
|
||||
with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect:
|
||||
with patch.object(client.connection_pool, 'health_check') as mock_health_check:
|
||||
# First health check fails, second succeeds
|
||||
mock_health_check.side_effect = [False, True]
|
||||
|
||||
# First connection attempt fails, second succeeds
|
||||
mock_connect.side_effect = [RedisError("Connection failed"), None]
|
||||
|
||||
# Simulate connection recovery
|
||||
try:
|
||||
await client.connect()
|
||||
except RedisError:
|
||||
# Retry connection
|
||||
await client.connect()
|
||||
|
||||
assert mock_connect.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_operations_performance(self, mock_redis_client):
|
||||
"""Test bulk operations for performance."""
|
||||
config = RedisConfig(host="localhost")
|
||||
client = RedisClient(config)
|
||||
|
||||
# Mock connected state
|
||||
client.connection_pool.get_client = Mock(return_value=mock_redis_client)
|
||||
client.connection_pool.is_connected = True
|
||||
client.publisher = RedisPublisher(mock_redis_client)
|
||||
|
||||
# Mock pipeline operations
|
||||
mock_pipeline = Mock()
|
||||
mock_redis_client.pipeline.return_value = mock_pipeline
|
||||
mock_pipeline.execute.return_value = [1] * 100 # 100 successful operations
|
||||
|
||||
# Prepare bulk messages
|
||||
messages = [
|
||||
(f"channel_{i}", f"message_{i}")
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
start_time = time.time()
|
||||
results = await client.publisher.publish_batch(messages)
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
assert len(results) == 100
|
||||
assert all(result == 1 for result in results)
|
||||
|
||||
# Should be faster than individual operations
|
||||
assert execution_time < 1.0 # Should complete in less than 1 second
|
||||
|
||||
# Pipeline should be used for efficiency
|
||||
mock_redis_client.pipeline.assert_called_once()
|
||||
assert mock_pipeline.publish.call_count == 100
|
||||
mock_pipeline.execute.assert_called_once()
|
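

# ---------------------------------------------------------------------------
# The bulk test above pins down the batching contract: every message goes
# through a single pipeline and one execute() round trip. A minimal sketch of
# that pattern, assuming a redis-py style client (hypothetical helper; the
# real RedisPublisher.publish_batch may differ):
def _sketch_publish_batch(redis_client, messages):
    """Queue (channel, payload) pairs on one pipeline, send in one round trip."""
    pipe = redis_client.pipeline()
    for channel, payload in messages:
        pipe.publish(channel, payload)  # queued locally, not sent yet
    return pipe.execute()  # returns one receiver count per publish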
883
tests/unit/storage/test_session_cache.py
Normal file
@@ -0,0 +1,883 @@
"""
|
||||
Unit tests for session cache management.
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from detector_worker.storage.session_cache import (
|
||||
SessionCache,
|
||||
SessionCacheManager,
|
||||
SessionData,
|
||||
CacheConfig,
|
||||
CacheEntry,
|
||||
CacheStats,
|
||||
SessionError,
|
||||
CacheError
|
||||
)
|
||||
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
|
||||
|
||||
|
||||
class TestCacheConfig:
|
||||
"""Test cache configuration."""
|
||||
|
||||
def test_creation_default(self):
|
||||
"""Test creating cache config with default values."""
|
||||
config = CacheConfig()
|
||||
|
||||
assert config.max_size == 1000
|
||||
assert config.ttl_seconds == 3600 # 1 hour
|
||||
assert config.cleanup_interval == 300 # 5 minutes
|
||||
assert config.eviction_policy == "lru"
|
||||
assert config.enable_persistence is False
|
||||
|
||||
def test_creation_custom(self):
|
||||
"""Test creating cache config with custom values."""
|
||||
config = CacheConfig(
|
||||
max_size=5000,
|
||||
ttl_seconds=7200,
|
||||
cleanup_interval=600,
|
||||
eviction_policy="lfu",
|
||||
enable_persistence=True,
|
||||
persistence_path="/tmp/cache"
|
||||
)
|
||||
|
||||
assert config.max_size == 5000
|
||||
assert config.ttl_seconds == 7200
|
||||
assert config.cleanup_interval == 600
|
||||
assert config.eviction_policy == "lfu"
|
||||
assert config.enable_persistence is True
|
||||
assert config.persistence_path == "/tmp/cache"
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
config_dict = {
|
||||
"max_size": 2000,
|
||||
"ttl_seconds": 1800,
|
||||
"eviction_policy": "fifo",
|
||||
"enable_persistence": True,
|
||||
"unknown_field": "ignored"
|
||||
}
|
||||
|
||||
config = CacheConfig.from_dict(config_dict)
|
||||
|
||||
assert config.max_size == 2000
|
||||
assert config.ttl_seconds == 1800
|
||||
assert config.eviction_policy == "fifo"
|
||||
assert config.enable_persistence is True
|
||||
|
||||
|
||||
class TestCacheEntry:
|
||||
"""Test cache entry data structure."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test cache entry creation."""
|
||||
data = {"key": "value", "number": 42}
|
||||
entry = CacheEntry(data, ttl_seconds=600)
|
||||
|
||||
assert entry.data == data
|
||||
assert entry.ttl_seconds == 600
|
||||
assert entry.created_at <= time.time()
|
||||
assert entry.last_accessed <= time.time()
|
||||
assert entry.access_count == 1
|
||||
assert entry.size > 0
|
||||
|
||||
def test_is_expired(self):
|
||||
"""Test expiration checking."""
|
||||
# Non-expired entry
|
||||
entry = CacheEntry({"data": "test"}, ttl_seconds=600)
|
||||
assert entry.is_expired() is False
|
||||
|
||||
# Expired entry (simulate by setting old creation time)
|
||||
entry.created_at = time.time() - 700 # Created 700 seconds ago
|
||||
assert entry.is_expired() is True
|
||||
|
||||
# Entry without expiration
|
||||
entry_no_ttl = CacheEntry({"data": "test"})
|
||||
assert entry_no_ttl.is_expired() is False
|
||||
|
||||
def test_touch(self):
|
||||
"""Test updating access time and count."""
|
||||
entry = CacheEntry({"data": "test"})
|
||||
|
||||
original_access_time = entry.last_accessed
|
||||
original_access_count = entry.access_count
|
||||
|
||||
time.sleep(0.01) # Small delay
|
||||
entry.touch()
|
||||
|
||||
assert entry.last_accessed > original_access_time
|
||||
assert entry.access_count == original_access_count + 1
|
||||
|
||||
def test_age(self):
|
||||
"""Test age calculation."""
|
||||
entry = CacheEntry({"data": "test"})
|
||||
|
||||
time.sleep(0.01) # Small delay
|
||||
age = entry.age()
|
||||
|
||||
assert age > 0
|
||||
assert age < 1 # Should be less than 1 second
|
||||
|
||||
def test_size_estimation(self):
|
||||
"""Test size estimation."""
|
||||
small_entry = CacheEntry({"key": "value"})
|
||||
large_entry = CacheEntry({"key": "x" * 1000, "data": list(range(100))})
|
||||
|
||||
assert large_entry.size > small_entry.size
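

# A minimal sketch of the entry shape the tests above assume (hypothetical;
# the real CacheEntry in detector_worker.storage.session_cache may differ):
class _SketchCacheEntry:
    def __init__(self, data, ttl_seconds=None):
        self.data = data
        self.ttl_seconds = ttl_seconds
        self.created_at = self.last_accessed = time.time()
        self.access_count = 1
        self.size = len(repr(data))  # crude size estimate in characters

    def is_expired(self):
        # Entries without a TTL never expire.
        return self.ttl_seconds is not None and self.age() > self.ttl_seconds

    def touch(self):
        self.last_accessed = time.time()
        self.access_count += 1

    def age(self):
        return time.time() - self.created_at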


class TestSessionData:
    """Test session data structure."""

    def test_creation(self):
        """Test session data creation."""
        session_data = SessionData(
            session_id="session_123",
            camera_id="camera_001",
            display_id="display_001"
        )

        assert session_data.session_id == "session_123"
        assert session_data.camera_id == "camera_001"
        assert session_data.display_id == "display_001"
        assert session_data.created_at <= time.time()
        assert session_data.last_activity <= time.time()
        assert session_data.detection_data == {}
        assert session_data.metadata == {}

    def test_update_activity(self):
        """Test updating last activity."""
        session_data = SessionData("session_123", "camera_001", "display_001")

        original_activity = session_data.last_activity
        time.sleep(0.01)
        session_data.update_activity()

        assert session_data.last_activity > original_activity

    def test_add_detection_data(self):
        """Test adding detection data."""
        session_data = SessionData("session_123", "camera_001", "display_001")

        detection_data = {
            "class": "car",
            "confidence": 0.95,
            "bbox": [100, 200, 300, 400]
        }

        session_data.add_detection_data("main_detection", detection_data)

        assert "main_detection" in session_data.detection_data
        assert session_data.detection_data["main_detection"] == detection_data

    def test_add_metadata(self):
        """Test adding metadata."""
        session_data = SessionData("session_123", "camera_001", "display_001")

        session_data.add_metadata("model_version", "v2.1")
        session_data.add_metadata("inference_time", 0.15)

        assert session_data.metadata["model_version"] == "v2.1"
        assert session_data.metadata["inference_time"] == 0.15

    def test_is_expired(self):
        """Test session expiration."""
        session_data = SessionData("session_123", "camera_001", "display_001")

        # Not expired with default timeout
        assert session_data.is_expired() is False

        # Expired with short timeout
        assert session_data.is_expired(timeout_seconds=0.001) is True

    def test_to_dict(self):
        """Test converting session to dictionary."""
        session_data = SessionData("session_123", "camera_001", "display_001")
        session_data.add_detection_data("detection", {"class": "car", "confidence": 0.9})
        session_data.add_metadata("model_id", "yolo_v8")

        data_dict = session_data.to_dict()

        assert data_dict["session_id"] == "session_123"
        assert data_dict["camera_id"] == "camera_001"
        assert data_dict["detection_data"]["detection"]["class"] == "car"
        assert data_dict["metadata"]["model_id"] == "yolo_v8"
        assert "created_at" in data_dict
        assert "last_activity" in data_dict


class TestCacheStats:
    """Test cache statistics."""

    def test_creation(self):
        """Test cache stats creation."""
        stats = CacheStats()

        assert stats.hits == 0
        assert stats.misses == 0
        assert stats.evictions == 0
        assert stats.size == 0
        assert stats.memory_usage == 0

    def test_hit_rate_calculation(self):
        """Test hit rate calculation."""
        stats = CacheStats()

        # No requests yet
        assert stats.hit_rate() == 0.0

        # Some hits and misses
        stats.hits = 8
        stats.misses = 2

        assert stats.hit_rate() == 0.8  # 8 / (8 + 2)

    def test_total_requests(self):
        """Test total requests calculation."""
        stats = CacheStats()

        stats.hits = 15
        stats.misses = 5

        assert stats.total_requests() == 20
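

# The statistics above reduce to one guarded ratio. A sketch of the assumed
# formula (hypothetical helper, not the real CacheStats implementation):
def _sketch_hit_rate(hits: int, misses: int) -> float:
    total = hits + misses
    return hits / total if total else 0.0  # 0.0 before any requests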


class TestSessionCache:
    """Test session cache functionality."""

    def test_creation(self):
        """Test session cache creation."""
        config = CacheConfig(max_size=100, ttl_seconds=300)
        cache = SessionCache(config)

        assert cache.config == config
        assert cache.max_size == 100
        assert cache.ttl_seconds == 300
        assert len(cache._cache) == 0
        assert len(cache._access_order) == 0

    def test_put_and_get_session(self):
        """Test putting and getting session data."""
        cache = SessionCache(CacheConfig(max_size=10))

        session_data = SessionData("session_123", "camera_001", "display_001")
        session_data.add_detection_data("main", {"class": "car", "confidence": 0.9})

        # Put session
        cache.put("session_123", session_data)

        # Get session
        retrieved_data = cache.get("session_123")

        assert retrieved_data is not None
        assert retrieved_data.session_id == "session_123"
        assert retrieved_data.camera_id == "camera_001"
        assert "main" in retrieved_data.detection_data

    def test_get_nonexistent_session(self):
        """Test getting non-existent session."""
        cache = SessionCache(CacheConfig(max_size=10))

        result = cache.get("nonexistent_session")

        assert result is None

    def test_contains_check(self):
        """Test checking if session exists."""
        cache = SessionCache(CacheConfig(max_size=10))

        session_data = SessionData("session_123", "camera_001", "display_001")
        cache.put("session_123", session_data)

        assert cache.contains("session_123") is True
        assert cache.contains("nonexistent_session") is False

    def test_remove_session(self):
        """Test removing session."""
        cache = SessionCache(CacheConfig(max_size=10))

        session_data = SessionData("session_123", "camera_001", "display_001")
        cache.put("session_123", session_data)

        assert cache.contains("session_123") is True

        removed_data = cache.remove("session_123")

        assert removed_data is not None
        assert removed_data.session_id == "session_123"
        assert cache.contains("session_123") is False

    def test_size_tracking(self):
        """Test cache size tracking."""
        cache = SessionCache(CacheConfig(max_size=10))

        assert cache.size() == 0
        assert cache.is_empty() is True

        # Add sessions
        for i in range(3):
            session_data = SessionData(f"session_{i}", "camera_001", "display_001")
            cache.put(f"session_{i}", session_data)

        assert cache.size() == 3
        assert cache.is_empty() is False

    def test_lru_eviction(self):
        """Test LRU eviction policy."""
        cache = SessionCache(CacheConfig(max_size=3, eviction_policy="lru"))

        # Fill cache to capacity
        for i in range(3):
            session_data = SessionData(f"session_{i}", "camera_001", "display_001")
            cache.put(f"session_{i}", session_data)

        # Access session_1 to make it recently used
        cache.get("session_1")

        # Add another session (should evict session_0, the least recently used)
        new_session = SessionData("session_3", "camera_001", "display_001")
        cache.put("session_3", new_session)

        assert cache.size() == 3
        assert cache.contains("session_0") is False  # Evicted
        assert cache.contains("session_1") is True  # Recently accessed
        assert cache.contains("session_2") is True
        assert cache.contains("session_3") is True  # Newly added
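
    # A sketch of the LRU bookkeeping this test relies on, assuming an
    # OrderedDict keyed by session id (hypothetical; the real SessionCache
    # may track access order differently):
    #
    #     from collections import OrderedDict
    #
    #     def get(self, key):
    #         entry = self._cache.get(key)
    #         if entry is not None:
    #             self._cache.move_to_end(key)      # mark as recently used
    #         return entry
    #
    #     def put(self, key, entry):
    #         self._cache[key] = entry
    #         self._cache.move_to_end(key)
    #         if len(self._cache) > self.max_size:
    #             self._cache.popitem(last=False)   # evict least recently used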

    def test_ttl_expiration(self):
        """Test TTL-based expiration."""
        cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))  # 100ms TTL

        session_data = SessionData("session_123", "camera_001", "display_001")
        cache.put("session_123", session_data)

        # Should exist immediately
        assert cache.contains("session_123") is True

        # Wait for expiration
        time.sleep(0.2)

        # Should be expired (but might still be in cache until cleanup)
        entry = cache._cache.get("session_123")
        if entry:
            assert entry.is_expired() is True

        # Getting expired entry should return None and clean it up
        retrieved = cache.get("session_123")
        assert retrieved is None
        assert cache.contains("session_123") is False

    def test_cleanup_expired_entries(self):
        """Test cleanup of expired entries."""
        cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))

        # Add multiple sessions
        for i in range(3):
            session_data = SessionData(f"session_{i}", "camera_001", "display_001")
            cache.put(f"session_{i}", session_data)

        assert cache.size() == 3

        # Wait for expiration
        time.sleep(0.2)

        # Cleanup expired entries
        cleaned_count = cache.cleanup_expired()

        assert cleaned_count == 3
        assert cache.size() == 0

    def test_clear_cache(self):
        """Test clearing entire cache."""
        cache = SessionCache(CacheConfig(max_size=10))

        # Add sessions
        for i in range(5):
            session_data = SessionData(f"session_{i}", "camera_001", "display_001")
            cache.put(f"session_{i}", session_data)

        assert cache.size() == 5

        cache.clear()

        assert cache.size() == 0
        assert cache.is_empty() is True

    def test_get_all_sessions(self):
        """Test getting all sessions."""
        cache = SessionCache(CacheConfig(max_size=10))

        sessions = []
        for i in range(3):
            session_data = SessionData(f"session_{i}", f"camera_{i}", "display_001")
            cache.put(f"session_{i}", session_data)
            sessions.append(session_data)

        all_sessions = cache.get_all()

        assert len(all_sessions) == 3
        for session_id, session_data in all_sessions.items():
            assert session_id.startswith("session_")
            assert session_data.session_id == session_id

    def test_get_sessions_by_camera(self):
        """Test getting sessions by camera ID."""
        cache = SessionCache(CacheConfig(max_size=10))

        # Add sessions for different cameras
        for i in range(2):
            session_data1 = SessionData(f"session_cam1_{i}", "camera_001", "display_001")
            session_data2 = SessionData(f"session_cam2_{i}", "camera_002", "display_001")
            cache.put(f"session_cam1_{i}", session_data1)
            cache.put(f"session_cam2_{i}", session_data2)

        camera1_sessions = cache.get_by_camera("camera_001")
        camera2_sessions = cache.get_by_camera("camera_002")

        assert len(camera1_sessions) == 2
        assert len(camera2_sessions) == 2

        for session_data in camera1_sessions:
            assert session_data.camera_id == "camera_001"

        for session_data in camera2_sessions:
            assert session_data.camera_id == "camera_002"

    def test_statistics_tracking(self):
        """Test cache statistics tracking."""
        cache = SessionCache(CacheConfig(max_size=10))

        session_data = SessionData("session_123", "camera_001", "display_001")
        cache.put("session_123", session_data)

        # Cache miss
        cache.get("nonexistent_session")

        # Cache hit
        cache.get("session_123")
        cache.get("session_123")  # Another hit

        stats = cache.get_stats()

        assert stats.hits == 2
        assert stats.misses == 1
        assert stats.size == 1
        assert stats.hit_rate() == 2/3  # 2 hits out of 3 total requests

    def test_memory_usage_estimation(self):
        """Test memory usage estimation."""
        cache = SessionCache(CacheConfig(max_size=10))

        initial_memory = cache.get_memory_usage()

        # Add large session
        session_data = SessionData("session_123", "camera_001", "display_001")
        session_data.add_detection_data("large_data", {"data": "x" * 1000})
        cache.put("session_123", session_data)

        after_memory = cache.get_memory_usage()

        assert after_memory > initial_memory


class TestSessionCacheManager:
    """Test session cache manager."""

    def test_singleton_behavior(self):
        """Test that SessionCacheManager is a singleton."""
        manager1 = SessionCacheManager()
        manager2 = SessionCacheManager()

        assert manager1 is manager2

    def test_initialization(self):
        """Test session cache manager initialization."""
        manager = SessionCacheManager()

        assert manager.detection_cache is not None
        assert manager.pipeline_cache is not None
        assert manager.session_cache is not None
        assert isinstance(manager.detection_cache, SessionCache)
        assert isinstance(manager.pipeline_cache, SessionCache)
        assert isinstance(manager.session_cache, SessionCache)

    def test_cache_detection_result(self):
        """Test caching detection results."""
        manager = SessionCacheManager()
        manager.clear_all()  # Start fresh

        detection_data = {
            "class": "car",
            "confidence": 0.95,
            "bbox": [100, 200, 300, 400],
            "track_id": 1001
        }

        manager.cache_detection("camera_001", detection_data)

        cached_detection = manager.get_cached_detection("camera_001")

        assert cached_detection is not None
        assert cached_detection["class"] == "car"
        assert cached_detection["confidence"] == 0.95
        assert cached_detection["track_id"] == 1001

    def test_cache_pipeline_result(self):
        """Test caching pipeline results."""
        manager = SessionCacheManager()
        manager.clear_all()

        pipeline_result = {
            "status": "success",
            "detections": [{"class": "car", "confidence": 0.9}],
            "execution_time": 0.15,
            "model_id": "yolo_v8"
        }

        manager.cache_pipeline_result("camera_001", pipeline_result)

        cached_result = manager.get_cached_pipeline_result("camera_001")

        assert cached_result is not None
        assert cached_result["status"] == "success"
        assert cached_result["execution_time"] == 0.15
        assert len(cached_result["detections"]) == 1

    def test_manage_session_data(self):
        """Test session data management."""
        manager = SessionCacheManager()
        manager.clear_all()

        session_id = str(uuid.uuid4())

        # Create session
        manager.create_session(session_id, "camera_001", {"initial": "data"})

        # Update session
        manager.update_session_detection(session_id, {"car_brand": "Toyota"})

        # Get session
        session_data = manager.get_session_detection(session_id)

        assert session_data is not None
        assert "initial" in session_data
        assert session_data["car_brand"] == "Toyota"

    def test_set_latest_frame(self):
        """Test setting and getting latest frame."""
        manager = SessionCacheManager()
        manager.clear_all()

        frame_data = b"fake_frame_data"

        manager.set_latest_frame("camera_001", frame_data)

        retrieved_frame = manager.get_latest_frame("camera_001")

        assert retrieved_frame == frame_data

    def test_frame_skip_flag_management(self):
        """Test frame skip flag management."""
        manager = SessionCacheManager()
        manager.clear_all()

        # Initially should be False
        assert manager.get_frame_skip_flag("camera_001") is False

        # Set to True
        manager.set_frame_skip_flag("camera_001", True)
        assert manager.get_frame_skip_flag("camera_001") is True

        # Set back to False
        manager.set_frame_skip_flag("camera_001", False)
        assert manager.get_frame_skip_flag("camera_001") is False

    def test_cleanup_expired_sessions(self):
        """Test cleanup of expired sessions."""
        manager = SessionCacheManager()
        manager.clear_all()

        # Create sessions with short TTL
        manager.session_cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1))

        # Add sessions
        for i in range(3):
            session_id = f"session_{i}"
            manager.create_session(session_id, "camera_001", {"test": "data"})

        assert manager.session_cache.size() == 3

        # Wait for expiration
        time.sleep(0.2)

        # Cleanup
        expired_count = manager.cleanup_expired_sessions()

        assert expired_count == 3
        assert manager.session_cache.size() == 0

    def test_clear_camera_cache(self):
        """Test clearing cache for specific camera."""
        manager = SessionCacheManager()
        manager.clear_all()

        # Add data for multiple cameras
        manager.cache_detection("camera_001", {"class": "car"})
        manager.cache_detection("camera_002", {"class": "truck"})
        manager.cache_pipeline_result("camera_001", {"status": "success"})
        manager.set_latest_frame("camera_001", b"frame1")
        manager.set_latest_frame("camera_002", b"frame2")

        # Clear camera_001 cache
        manager.clear_camera_cache("camera_001")

        # camera_001 data should be gone
        assert manager.get_cached_detection("camera_001") is None
        assert manager.get_cached_pipeline_result("camera_001") is None
        assert manager.get_latest_frame("camera_001") is None

        # camera_002 data should remain
        assert manager.get_cached_detection("camera_002") is not None
        assert manager.get_latest_frame("camera_002") is not None

    def test_get_cache_statistics(self):
        """Test getting cache statistics."""
        manager = SessionCacheManager()
        manager.clear_all()

        # Add some data to generate statistics
        manager.cache_detection("camera_001", {"class": "car"})
        manager.cache_pipeline_result("camera_001", {"status": "success"})
        manager.create_session("session_123", "camera_001", {"initial": "data"})

        # Access data to generate hits/misses
        manager.get_cached_detection("camera_001")  # Hit
        manager.get_cached_detection("camera_002")  # Miss

        stats = manager.get_cache_statistics()

        assert "detection_cache" in stats
        assert "pipeline_cache" in stats
        assert "session_cache" in stats
        assert "total_memory_usage" in stats

        detection_stats = stats["detection_cache"]
        assert detection_stats["size"] >= 1
        assert detection_stats["hits"] >= 1
        assert detection_stats["misses"] >= 1

    def test_memory_pressure_handling(self):
        """Test handling memory pressure."""
        # Create manager with small cache sizes
        config = CacheConfig(max_size=3)
        manager = SessionCacheManager()
        manager.detection_cache = SessionCache(config)
        manager.pipeline_cache = SessionCache(config)
        manager.session_cache = SessionCache(config)

        # Fill caches beyond capacity
        for i in range(5):
            manager.cache_detection(f"camera_{i}", {"class": "car", "data": "x" * 100})
            manager.cache_pipeline_result(f"camera_{i}", {"status": "success", "data": "y" * 100})
            manager.create_session(f"session_{i}", f"camera_{i}", {"data": "z" * 100})

        # Caches should not exceed max size due to eviction
        assert manager.detection_cache.size() <= 3
        assert manager.pipeline_cache.size() <= 3
        assert manager.session_cache.size() <= 3

    def test_concurrent_access_thread_safety(self):
        """Test thread safety of concurrent cache access."""
        import threading
        import concurrent.futures

        manager = SessionCacheManager()
        manager.clear_all()

        results = []
        errors = []

        def cache_operation(thread_id):
            try:
                # Each thread performs multiple cache operations
                for i in range(10):
                    session_id = f"thread_{thread_id}_session_{i}"

                    # Create session
                    manager.create_session(session_id, f"camera_{thread_id}", {"thread": thread_id, "index": i})

                    # Update session
                    manager.update_session_detection(session_id, {"updated": True})

                    # Read session
                    data = manager.get_session_detection(session_id)
                    if data and data.get("thread") == thread_id:
                        results.append((thread_id, i))

            except Exception as e:
                errors.append((thread_id, str(e)))

        # Run operations concurrently
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(cache_operation, i) for i in range(5)]
            concurrent.futures.wait(futures)

        # Should have no errors and successful operations
        assert len(errors) == 0
        assert len(results) >= 25  # At least some operations should succeed


class TestSessionCacheIntegration:
    """Integration tests for session cache."""

    def test_complete_detection_workflow(self):
        """Test complete detection workflow with caching."""
        manager = SessionCacheManager()
        manager.clear_all()

        camera_id = "camera_001"
        session_id = str(uuid.uuid4())

        # 1. Cache initial detection
        detection_data = {
            "class": "car",
            "confidence": 0.92,
            "bbox": [100, 200, 300, 400],
            "track_id": 1001,
            "timestamp": int(time.time() * 1000)
        }

        manager.cache_detection(camera_id, detection_data)

        # 2. Create session for tracking
        initial_session_data = {
            "detection_class": detection_data["class"],
            "confidence": detection_data["confidence"],
            "track_id": detection_data["track_id"]
        }

        manager.create_session(session_id, camera_id, initial_session_data)

        # 3. Cache pipeline processing result
        pipeline_result = {
            "status": "processing",
            "stage": "classification",
            "detections": [detection_data],
            "branches_completed": [],
            "branches_pending": ["car_brand_cls", "car_bodytype_cls"]
        }

        manager.cache_pipeline_result(camera_id, pipeline_result)

        # 4. Update session with classification results
        classification_updates = [
            {"car_brand": "Toyota", "brand_confidence": 0.87},
            {"car_body_type": "Sedan", "bodytype_confidence": 0.82}
        ]

        for update in classification_updates:
            manager.update_session_detection(session_id, update)

        # 5. Update pipeline result to completed
        final_pipeline_result = {
            "status": "completed",
            "stage": "finished",
            "detections": [detection_data],
            "branches_completed": ["car_brand_cls", "car_bodytype_cls"],
            "branches_pending": [],
            "execution_time": 0.25
        }

        manager.cache_pipeline_result(camera_id, final_pipeline_result)

        # 6. Verify all cached data
        cached_detection = manager.get_cached_detection(camera_id)
        cached_pipeline = manager.get_cached_pipeline_result(camera_id)
        cached_session = manager.get_session_detection(session_id)

        # Assertions
        assert cached_detection["class"] == "car"
        assert cached_detection["track_id"] == 1001

        assert cached_pipeline["status"] == "completed"
        assert len(cached_pipeline["branches_completed"]) == 2

        assert cached_session["detection_class"] == "car"
        assert cached_session["car_brand"] == "Toyota"
        assert cached_session["car_body_type"] == "Sedan"
        assert cached_session["brand_confidence"] == 0.87

    def test_cache_performance_under_load(self):
        """Test cache performance under load."""
        manager = SessionCacheManager()
        manager.clear_all()

        import time

        # Measure performance of cache operations
        start_time = time.time()

        # Perform many cache operations
        for i in range(1000):
            camera_id = f"camera_{i % 10}"  # 10 different cameras
            session_id = f"session_{i}"

            # Cache detection
            detection_data = {
                "class": "car",
                "confidence": 0.9 + (i % 10) * 0.01,
                "track_id": i,
                "bbox": [i % 100, i % 100, (i % 100) + 200, (i % 100) + 200]
            }
            manager.cache_detection(camera_id, detection_data)

            # Create session
            manager.create_session(session_id, camera_id, {"index": i})

            # Read back (every 10th operation)
            if i % 10 == 0:
                manager.get_cached_detection(camera_id)
                manager.get_session_detection(session_id)

        end_time = time.time()
        total_time = end_time - start_time

        # Should complete in reasonable time (less than 1 second)
        assert total_time < 1.0

        # Verify cache statistics
        stats = manager.get_cache_statistics()
        assert stats["detection_cache"]["size"] > 0
        assert stats["session_cache"]["size"] > 0
        assert stats["detection_cache"]["hits"] > 0

    def test_cache_persistence_and_recovery(self):
        """Test cache persistence and recovery (if enabled)."""
        # This test would be more meaningful with actual persistence
        # For now, test the configuration and structure

        persistence_config = CacheConfig(
            max_size=100,
            enable_persistence=True,
            persistence_path="/tmp/detector_cache_test"
        )

        cache = SessionCache(persistence_config)

        # Add some data
        session_data = SessionData("session_123", "camera_001", "display_001")
        session_data.add_detection_data("main", {"class": "car", "confidence": 0.95})

        cache.put("session_123", session_data)

        # Verify data exists
        assert cache.contains("session_123") is True

        # In a real implementation, this would test:
        # 1. Saving cache to disk
        # 2. Loading cache from disk
        # 3. Verifying data integrity after reload
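

# A sketch of the save/load round trip the comments above describe, using
# pickle as the on-disk format (an assumption; the real persistence layer may
# use a different serializer or file layout):
import pickle
from pathlib import Path


def _sketch_save_cache(cache_dict, path):
    """Persist the raw session mapping to disk."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as fh:
        pickle.dump(cache_dict, fh)


def _sketch_load_cache(path):
    """Reload the mapping; callers should drop entries that expired on disk."""
    with open(path, "rb") as fh:
        return pickle.load(fh)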
818
tests/unit/streams/test_stream_manager.py
Normal file
@@ -0,0 +1,818 @@
"""
|
||||
Unit tests for stream management functionality.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from detector_worker.streams.stream_manager import (
|
||||
StreamManager,
|
||||
StreamInfo,
|
||||
StreamConfig,
|
||||
StreamReader,
|
||||
StreamError,
|
||||
ConnectionError as StreamConnectionError
|
||||
)
|
||||
from detector_worker.streams.frame_reader import FrameReader
|
||||
from detector_worker.core.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestStreamConfig:
|
||||
"""Test stream configuration."""
|
||||
|
||||
def test_creation_rtsp(self):
|
||||
"""Test creating RTSP stream config."""
|
||||
config = StreamConfig(
|
||||
stream_url="rtsp://example.com/stream1",
|
||||
stream_type="rtsp",
|
||||
target_fps=15,
|
||||
reconnect_interval=5.0,
|
||||
max_retries=3
|
||||
)
|
||||
|
||||
assert config.stream_url == "rtsp://example.com/stream1"
|
||||
assert config.stream_type == "rtsp"
|
||||
assert config.target_fps == 15
|
||||
assert config.reconnect_interval == 5.0
|
||||
assert config.max_retries == 3
|
||||
|
||||
def test_creation_http_snapshot(self):
|
||||
"""Test creating HTTP snapshot config."""
|
||||
config = StreamConfig(
|
||||
stream_url="http://example.com/snapshot.jpg",
|
||||
stream_type="http_snapshot",
|
||||
snapshot_interval=1.0,
|
||||
timeout=10.0
|
||||
)
|
||||
|
||||
assert config.stream_url == "http://example.com/snapshot.jpg"
|
||||
assert config.stream_type == "http_snapshot"
|
||||
assert config.snapshot_interval == 1.0
|
||||
assert config.timeout == 10.0
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creating config from dictionary."""
|
||||
config_dict = {
|
||||
"stream_url": "rtsp://camera.example.com/live",
|
||||
"stream_type": "rtsp",
|
||||
"target_fps": 20,
|
||||
"reconnect_interval": 3.0,
|
||||
"max_retries": 5,
|
||||
"crop_region": [100, 200, 300, 400],
|
||||
"unknown_field": "ignored"
|
||||
}
|
||||
|
||||
config = StreamConfig.from_dict(config_dict)
|
||||
|
||||
assert config.stream_url == "rtsp://camera.example.com/live"
|
||||
assert config.target_fps == 20
|
||||
assert config.crop_region == [100, 200, 300, 400]
|
||||
|
||||
def test_validation(self):
|
||||
"""Test config validation."""
|
||||
# Valid config
|
||||
valid_config = StreamConfig(
|
||||
stream_url="rtsp://example.com/stream",
|
||||
stream_type="rtsp"
|
||||
)
|
||||
assert valid_config.is_valid() is True
|
||||
|
||||
# Invalid config (empty URL)
|
||||
invalid_config = StreamConfig(
|
||||
stream_url="",
|
||||
stream_type="rtsp"
|
||||
)
|
||||
assert invalid_config.is_valid() is False
|
||||
|
||||
|
||||
class TestStreamInfo:
|
||||
"""Test stream information."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test stream info creation."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
info = StreamInfo(
|
||||
stream_id="stream_001",
|
||||
config=config,
|
||||
camera_id="camera_001"
|
||||
)
|
||||
|
||||
assert info.stream_id == "stream_001"
|
||||
assert info.config == config
|
||||
assert info.camera_id == "camera_001"
|
||||
assert info.status == "inactive"
|
||||
assert info.reference_count == 0
|
||||
assert info.created_at <= time.time()
|
||||
|
||||
def test_increment_reference(self):
|
||||
"""Test incrementing reference count."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
info = StreamInfo("stream_001", config, "camera_001")
|
||||
|
||||
assert info.reference_count == 0
|
||||
|
||||
info.increment_reference()
|
||||
assert info.reference_count == 1
|
||||
|
||||
info.increment_reference()
|
||||
assert info.reference_count == 2
|
||||
|
||||
def test_decrement_reference(self):
|
||||
"""Test decrementing reference count."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
info = StreamInfo("stream_001", config, "camera_001")
|
||||
|
||||
info.reference_count = 3
|
||||
|
||||
assert info.decrement_reference() == 2
|
||||
assert info.reference_count == 2
|
||||
|
||||
assert info.decrement_reference() == 1
|
||||
assert info.decrement_reference() == 0
|
||||
|
||||
# Should not go below 0
|
||||
assert info.decrement_reference() == 0
|
||||
|
||||
def test_update_status(self):
|
||||
"""Test updating stream status."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
info = StreamInfo("stream_001", config, "camera_001")
|
||||
|
||||
info.update_status("connecting")
|
||||
assert info.status == "connecting"
|
||||
assert info.last_update <= time.time()
|
||||
|
||||
info.update_status("active", frame_count=100)
|
||||
assert info.status == "active"
|
||||
assert info.frame_count == 100
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test getting stream statistics."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
info = StreamInfo("stream_001", config, "camera_001")
|
||||
|
||||
info.frame_count = 1000
|
||||
info.error_count = 5
|
||||
info.reference_count = 2
|
||||
|
||||
stats = info.get_stats()
|
||||
|
||||
assert stats["stream_id"] == "stream_001"
|
||||
assert stats["status"] == "inactive"
|
||||
assert stats["frame_count"] == 1000
|
||||
assert stats["error_count"] == 5
|
||||
assert stats["reference_count"] == 2
|
||||
assert "uptime" in stats
|
||||
|
||||
|
||||
class TestStreamReader:
|
||||
"""Test stream reader functionality."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test stream reader creation."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
reader = StreamReader("stream_001", config)
|
||||
|
||||
assert reader.stream_id == "stream_001"
|
||||
assert reader.config == config
|
||||
assert reader.is_running is False
|
||||
assert reader.latest_frame is None
|
||||
assert reader.frame_queue.qsize() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_rtsp_stream(self):
|
||||
"""Test starting RTSP stream."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp", target_fps=10)
|
||||
reader = StreamReader("stream_001", config)
|
||||
|
||||
# Mock cv2.VideoCapture
|
||||
with patch('cv2.VideoCapture') as mock_cap:
|
||||
mock_cap_instance = Mock()
|
||||
mock_cap.return_value = mock_cap_instance
|
||||
mock_cap_instance.isOpened.return_value = True
|
||||
mock_cap_instance.read.return_value = (True, np.zeros((480, 640, 3), dtype=np.uint8))
|
||||
|
||||
await reader.start()
|
||||
|
||||
assert reader.is_running is True
|
||||
assert reader.capture is not None
|
||||
mock_cap.assert_called_once_with("rtsp://example.com/stream")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_rtsp_connection_failure(self):
|
||||
"""Test RTSP connection failure."""
|
||||
config = StreamConfig("rtsp://invalid.com/stream", "rtsp")
|
||||
reader = StreamReader("stream_001", config)
|
||||
|
||||
with patch('cv2.VideoCapture') as mock_cap:
|
||||
mock_cap_instance = Mock()
|
||||
mock_cap.return_value = mock_cap_instance
|
||||
mock_cap_instance.isOpened.return_value = False
|
||||
|
||||
with pytest.raises(StreamConnectionError):
|
||||
await reader.start()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_http_snapshot(self):
|
||||
"""Test starting HTTP snapshot stream."""
|
||||
config = StreamConfig("http://example.com/snapshot.jpg", "http_snapshot", snapshot_interval=1.0)
|
||||
reader = StreamReader("stream_001", config)
|
||||
|
||||
with patch('requests.get') as mock_get:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b"fake_image_data"
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with patch('cv2.imdecode') as mock_decode:
|
||||
mock_decode.return_value = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
|
||||
await reader.start()
|
||||
|
||||
assert reader.is_running is True
|
||||
mock_get.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_stream(self):
|
||||
"""Test stopping stream."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
reader = StreamReader("stream_001", config)
|
||||
|
||||
# Simulate running state
|
||||
reader.is_running = True
|
||||
reader.capture = Mock()
|
||||
reader.capture.release = Mock()
|
||||
reader._reader_task = Mock()
|
||||
reader._reader_task.cancel = Mock()
|
||||
|
||||
await reader.stop()
|
||||
|
||||
assert reader.is_running is False
|
||||
reader.capture.release.assert_called_once()
|
||||
reader._reader_task.cancel.assert_called_once()
|
||||
|
||||
def test_get_latest_frame(self):
|
||||
"""Test getting latest frame."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
reader = StreamReader("stream_001", config)
|
||||
|
||||
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
|
||||
reader.latest_frame = test_frame
|
||||
|
||||
frame = reader.get_latest_frame()
|
||||
|
||||
assert np.array_equal(frame, test_frame)
|
||||
|
||||
def test_get_frame_from_queue(self):
|
||||
"""Test getting frame from queue."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
reader = StreamReader("stream_001", config)
|
||||
|
||||
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
|
||||
reader.frame_queue.put(test_frame)
|
||||
|
||||
frame = reader.get_frame(timeout=0.1)
|
||||
|
||||
assert np.array_equal(frame, test_frame)
|
||||
|
||||
def test_get_frame_timeout(self):
|
||||
"""Test getting frame with timeout."""
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
reader = StreamReader("stream_001", config)
|
||||
|
||||
# Queue is empty, should timeout
|
||||
frame = reader.get_frame(timeout=0.1)
|
||||
|
||||
assert frame is None
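
    # The timeout behavior above maps naturally onto queue.Queue (an assumed
    # implementation; the real StreamReader.get_frame may differ):
    #
    #     def get_frame(self, timeout=None):
    #         try:
    #             return self.frame_queue.get(timeout=timeout)
    #         except queue.Empty:
    #             return None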

    def test_get_stats(self):
        """Test getting reader statistics."""
        config = StreamConfig("rtsp://example.com/stream", "rtsp")
        reader = StreamReader("stream_001", config)

        reader.frame_count = 500
        reader.error_count = 2

        stats = reader.get_stats()

        assert stats["stream_id"] == "stream_001"
        assert stats["frame_count"] == 500
        assert stats["error_count"] == 2
        assert stats["is_running"] is False


class TestStreamManager:
    """Test stream manager functionality."""

    def test_initialization(self):
        """Test stream manager initialization."""
        manager = StreamManager()

        assert len(manager.streams) == 0
        assert len(manager.readers) == 0
        assert manager.max_streams == 10
        assert manager.default_timeout == 30.0

    def test_initialization_with_config(self):
        """Test initialization with custom configuration."""
        config = {
            "max_streams": 20,
            "default_timeout": 60.0,
            "frame_buffer_size": 5
        }

        manager = StreamManager(config)

        assert manager.max_streams == 20
        assert manager.default_timeout == 60.0
        assert manager.frame_buffer_size == 5

    @pytest.mark.asyncio
    async def test_create_stream_new(self):
        """Test creating new stream."""
        manager = StreamManager()

        config = StreamConfig("rtsp://example.com/stream", "rtsp")

        with patch.object(StreamReader, 'start', new_callable=AsyncMock):
            stream_info = await manager.create_stream("camera_001", config, "sub_001")

            assert "camera_001" in manager.streams
            assert manager.streams["camera_001"].reference_count == 1
            assert manager.streams["camera_001"].camera_id == "camera_001"

    @pytest.mark.asyncio
    async def test_create_stream_shared(self):
        """Test creating shared stream (same URL)."""
        manager = StreamManager()

        config = StreamConfig("rtsp://example.com/stream", "rtsp")

        with patch.object(StreamReader, 'start', new_callable=AsyncMock):
            # Create first stream
            stream_info1 = await manager.create_stream("camera_001", config, "sub_001")

            # Create second stream with same URL
            stream_info2 = await manager.create_stream("camera_001", config, "sub_002")

            assert stream_info1 == stream_info2  # Should be same stream
            assert manager.streams["camera_001"].reference_count == 2

    @pytest.mark.asyncio
    async def test_create_stream_max_limit(self):
        """Test creating stream when at max limit."""
        manager = StreamManager({"max_streams": 1})

        config1 = StreamConfig("rtsp://example.com/stream1", "rtsp")
        config2 = StreamConfig("rtsp://example.com/stream2", "rtsp")

        with patch.object(StreamReader, 'start', new_callable=AsyncMock):
            # Create first stream (should succeed)
            await manager.create_stream("camera_001", config1, "sub_001")

            # Try to create second stream (should fail)
            with pytest.raises(StreamError) as exc_info:
                await manager.create_stream("camera_002", config2, "sub_002")

            assert "maximum number of streams" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_remove_stream_single_reference(self):
        """Test removing stream with single reference."""
        manager = StreamManager()

        config = StreamConfig("rtsp://example.com/stream", "rtsp")

        with patch.object(StreamReader, 'start', new_callable=AsyncMock):
            with patch.object(StreamReader, 'stop', new_callable=AsyncMock):
                # Create stream
                await manager.create_stream("camera_001", config, "sub_001")

                # Remove stream
                removed = await manager.remove_stream("camera_001", "sub_001")

                assert removed is True
                assert "camera_001" not in manager.streams

    @pytest.mark.asyncio
    async def test_remove_stream_multiple_references(self):
        """Test removing stream with multiple references."""
        manager = StreamManager()

        config = StreamConfig("rtsp://example.com/stream", "rtsp")

        with patch.object(StreamReader, 'start', new_callable=AsyncMock):
            # Create shared stream
            await manager.create_stream("camera_001", config, "sub_001")
            await manager.create_stream("camera_001", config, "sub_002")

            assert manager.streams["camera_001"].reference_count == 2

            # Remove one reference
            removed = await manager.remove_stream("camera_001", "sub_001")

            assert removed is True
            assert "camera_001" in manager.streams  # Still exists
            assert manager.streams["camera_001"].reference_count == 1
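
    # These two tests pin down the reference-counting contract. A sketch of
    # the assumed removal logic (hypothetical; the real StreamManager may
    # differ in detail):
    #
    #     async def remove_stream(self, camera_id, subscription_id):
    #         info = self.streams.get(camera_id)
    #         if info is None:
    #             return False
    #         if info.decrement_reference() == 0:
    #             await self.readers[camera_id].stop()  # last subscriber left
    #             del self.readers[camera_id]
    #             del self.streams[camera_id]
    #         return True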
|
||||
|
||||
def test_get_stream_info(self):
|
||||
"""Test getting stream information."""
|
||||
manager = StreamManager()
|
||||
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
stream_info = StreamInfo("camera_001", config, "camera_001")
|
||||
manager.streams["camera_001"] = stream_info
|
||||
|
||||
retrieved_info = manager.get_stream_info("camera_001")
|
||||
|
||||
assert retrieved_info == stream_info
|
||||
|
||||
def test_get_nonexistent_stream_info(self):
|
||||
"""Test getting info for non-existent stream."""
|
||||
manager = StreamManager()
|
||||
|
||||
info = manager.get_stream_info("nonexistent_camera")
|
||||
|
||||
assert info is None
|
||||
|
||||
def test_get_latest_frame(self):
|
||||
"""Test getting latest frame from stream."""
|
||||
manager = StreamManager()
|
||||
|
||||
# Create mock reader
|
||||
mock_reader = Mock()
|
||||
test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
|
||||
mock_reader.get_latest_frame.return_value = test_frame
|
||||
|
||||
manager.readers["camera_001"] = mock_reader
|
||||
|
||||
frame = manager.get_latest_frame("camera_001")
|
||||
|
||||
assert np.array_equal(frame, test_frame)
|
||||
mock_reader.get_latest_frame.assert_called_once()
|
||||
|
||||
def test_get_frame_from_nonexistent_stream(self):
|
||||
"""Test getting frame from non-existent stream."""
|
||||
manager = StreamManager()
|
||||
|
||||
frame = manager.get_latest_frame("nonexistent_camera")
|
||||
|
||||
assert frame is None
|
||||
|
||||
def test_list_active_streams(self):
|
||||
"""Test listing active streams."""
|
||||
manager = StreamManager()
|
||||
|
||||
# Add streams
|
||||
config1 = StreamConfig("rtsp://example.com/stream1", "rtsp")
|
||||
config2 = StreamConfig("rtsp://example.com/stream2", "rtsp")
|
||||
|
||||
stream1 = StreamInfo("camera_001", config1, "camera_001")
|
||||
stream1.update_status("active")
|
||||
|
||||
stream2 = StreamInfo("camera_002", config2, "camera_002")
|
||||
stream2.update_status("inactive")
|
||||
|
||||
manager.streams["camera_001"] = stream1
|
||||
manager.streams["camera_002"] = stream2
|
||||
|
||||
active_streams = manager.list_active_streams()
|
||||
|
||||
assert len(active_streams) == 1
|
||||
assert active_streams[0]["camera_id"] == "camera_001"
|
||||
assert active_streams[0]["status"] == "active"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_streams(self):
|
||||
"""Test stopping all streams."""
|
||||
manager = StreamManager()
|
||||
|
||||
# Add mock streams
|
||||
mock_reader1 = Mock()
|
||||
mock_reader1.stop = AsyncMock()
|
||||
mock_reader2 = Mock()
|
||||
mock_reader2.stop = AsyncMock()
|
||||
|
||||
manager.readers["camera_001"] = mock_reader1
|
||||
manager.readers["camera_002"] = mock_reader2
|
||||
|
||||
stopped_count = await manager.stop_all_streams()
|
||||
|
||||
assert stopped_count == 2
|
||||
mock_reader1.stop.assert_called_once()
|
||||
mock_reader2.stop.assert_called_once()
|
||||
assert len(manager.readers) == 0
|
||||
assert len(manager.streams) == 0
|
||||
|
||||
def test_get_stream_statistics(self):
|
||||
"""Test getting stream statistics."""
|
||||
manager = StreamManager()
|
||||
|
||||
# Add streams
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
|
||||
stream1 = StreamInfo("camera_001", config, "camera_001")
|
||||
stream1.update_status("active")
|
||||
stream1.frame_count = 1000
|
||||
stream1.reference_count = 2
|
||||
|
||||
stream2 = StreamInfo("camera_002", config, "camera_002")
|
||||
stream2.update_status("error")
|
||||
stream2.error_count = 5
|
||||
|
||||
manager.streams["camera_001"] = stream1
|
||||
manager.streams["camera_002"] = stream2
|
||||
|
||||
stats = manager.get_stream_statistics()
|
||||
|
||||
assert stats["total_streams"] == 2
|
||||
assert stats["active_streams"] == 1
|
||||
assert stats["error_streams"] == 1
|
||||
assert stats["total_references"] == 2
|
||||
assert "status_breakdown" in stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_stream(self):
|
||||
"""Test reconnecting failed stream."""
|
||||
manager = StreamManager()
|
||||
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
stream_info = StreamInfo("camera_001", config, "camera_001")
|
||||
stream_info.update_status("error")
|
||||
manager.streams["camera_001"] = stream_info
|
||||
|
||||
# Mock reader
|
||||
mock_reader = Mock()
|
||||
mock_reader.start = AsyncMock()
|
||||
mock_reader.stop = AsyncMock()
|
||||
manager.readers["camera_001"] = mock_reader
|
||||
|
||||
result = await manager.reconnect_stream("camera_001")
|
||||
|
||||
assert result is True
|
||||
mock_reader.stop.assert_called_once()
|
||||
mock_reader.start.assert_called_once()
|
||||
assert stream_info.status != "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_streams(self):
|
||||
"""Test health check of all streams."""
|
||||
manager = StreamManager()
|
||||
|
||||
# Add streams with different states
|
||||
config = StreamConfig("rtsp://example.com/stream", "rtsp")
|
||||
|
||||
stream1 = StreamInfo("camera_001", config, "camera_001")
|
||||
stream1.update_status("active")
|
||||
|
||||
stream2 = StreamInfo("camera_002", config, "camera_002")
|
||||
stream2.update_status("error")
|
||||
|
||||
manager.streams["camera_001"] = stream1
|
||||
manager.streams["camera_002"] = stream2
|
||||
|
||||
# Mock readers
|
||||
mock_reader1 = Mock()
|
||||
mock_reader1.is_running = True
|
||||
mock_reader2 = Mock()
|
||||
mock_reader2.is_running = False
|
||||
|
||||
manager.readers["camera_001"] = mock_reader1
|
||||
manager.readers["camera_002"] = mock_reader2
|
||||
|
||||
health_report = await manager.health_check()
|
||||
|
||||
assert health_report["total_streams"] == 2
|
||||
assert health_report["healthy_streams"] == 1
|
||||
assert health_report["unhealthy_streams"] == 1
|
||||
assert len(health_report["unhealthy_stream_ids"]) == 1
|
||||
|
||||
|
||||
class TestStreamManagerIntegration:
|
||||
"""Integration tests for stream manager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_subscribers_same_stream(self):
|
||||
"""Test multiple subscribers to same stream."""
|
||||
manager = StreamManager()
|
||||
|
||||
config = StreamConfig("rtsp://example.com/shared_stream", "rtsp")
|
||||
|
||||
with patch.object(StreamReader, 'start', new_callable=AsyncMock):
|
||||
# Multiple subscribers to same stream
|
||||
stream1 = await manager.create_stream("camera_001", config, "sub_001")
|
||||
stream2 = await manager.create_stream("camera_001", config, "sub_002")
|
||||
stream3 = await manager.create_stream("camera_001", config, "sub_003")
|
||||
|
||||
# All should reference same stream
|
||||
assert stream1 == stream2 == stream3
|
||||
assert manager.streams["camera_001"].reference_count == 3
|
||||
assert len(manager.readers) == 1 # Only one actual reader
|
||||
|
||||
# Remove subscribers one by one
|
||||
with patch.object(StreamReader, 'stop', new_callable=AsyncMock) as mock_stop:
|
||||
await manager.remove_stream("camera_001", "sub_001") # ref_count = 2
|
||||
await manager.remove_stream("camera_001", "sub_002") # ref_count = 1
|
||||
|
||||
# Stream should still exist
|
||||
assert "camera_001" in manager.streams
|
||||
mock_stop.assert_not_called()
|
||||
|
||||
await manager.remove_stream("camera_001", "sub_003") # ref_count = 0
|
||||
|
||||
# Now stream should be stopped and removed
|
||||
assert "camera_001" not in manager.streams
|
||||
mock_stop.assert_called_once()

    @pytest.mark.asyncio
    async def test_stream_failure_and_recovery(self):
        """Test stream failure and recovery workflow."""
        manager = StreamManager()

        config = StreamConfig("rtsp://unreliable.com/stream", "rtsp", max_retries=2)

        # Mock reader that fails initially then succeeds
        with patch.object(StreamReader, 'start', new_callable=AsyncMock) as mock_start:
            mock_start.side_effect = [
                StreamConnectionError("Connection failed"),  # First attempt fails
                None  # Second attempt succeeds
            ]

            # First attempt should fail
            with pytest.raises(StreamConnectionError):
                await manager.create_stream("camera_001", config, "sub_001")

            # Retry should succeed
            stream_info = await manager.create_stream("camera_001", config, "sub_001")

            assert stream_info is not None
            assert mock_start.call_count == 2
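
    # Note the assumption baked into call_count == 2: create_stream() does not
    # retry internally, so the test drives the retry itself. A hypothetical
    # caller-side retry loop would look roughly like this:
    #
    #     for attempt in range(config.max_retries):
    #         try:
    #             stream_info = await manager.create_stream("camera_001", config, "sub_001")
    #             break
    #         except StreamConnectionError:
    #             if attempt == config.max_retries - 1:
    #                 raise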

    @pytest.mark.asyncio
    async def test_concurrent_stream_operations(self):
        """Test concurrent stream operations."""
        manager = StreamManager()

        configs = [
            StreamConfig(f"rtsp://example.com/stream{i}", "rtsp")
            for i in range(5)
        ]

        with patch.object(StreamReader, 'start', new_callable=AsyncMock):
            with patch.object(StreamReader, 'stop', new_callable=AsyncMock):
                # Create streams concurrently
                create_tasks = [
                    manager.create_stream(f"camera_{i}", configs[i], f"sub_{i}")
                    for i in range(5)
                ]

                results = await asyncio.gather(*create_tasks)

                assert len(results) == 5
                assert len(manager.streams) == 5

                # Remove streams concurrently
                remove_tasks = [
                    manager.remove_stream(f"camera_{i}", f"sub_{i}")
                    for i in range(5)
                ]

                remove_results = await asyncio.gather(*remove_tasks)

                assert all(remove_results)
                assert len(manager.streams) == 0
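
    # These operations run interleaved on one event loop via asyncio.gather(),
    # so the test implicitly assumes StreamManager guards its shared dicts
    # across await points, e.g. with a hypothetical manager-wide lock:
    #
    #     async def create_stream(self, stream_id, config, subscription_id):
    #         async with self._lock:  # assumed asyncio.Lock, not confirmed
    #             ...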

    @pytest.mark.asyncio
    async def test_memory_management_large_scale(self):
        """Test memory management with many streams."""
        manager = StreamManager({"max_streams": 50})

        # Create many streams
        with patch.object(StreamReader, 'start', new_callable=AsyncMock):
            for i in range(30):
                config = StreamConfig(f"rtsp://example.com/stream{i}", "rtsp")
                await manager.create_stream(f"camera_{i}", config, f"sub_{i}")

        # Verify memory usage is reasonable
        stats = manager.get_stream_statistics()
        assert stats["total_streams"] == 30
        assert stats["active_streams"] <= 30

        # Test bulk cleanup
        with patch.object(StreamReader, 'stop', new_callable=AsyncMock):
            stopped_count = await manager.stop_all_streams()

        assert stopped_count == 30
        assert len(manager.streams) == 0
        assert len(manager.readers) == 0
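

# Sketch of the bulk shutdown the previous test relies on; the actual
# stop_all_streams() may differ in ordering and error handling.
async def _sketch_stop_all_streams(readers: dict, streams: dict) -> int:
    """Stop every reader, clear the registries, and return the count (sketch)."""
    await asyncio.gather(*(reader.stop() for reader in readers.values()))
    stopped = len(readers)
    readers.clear()
    streams.clear()
    return stopped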


class TestFrameReaderIntegration:
    """Integration tests for frame reader."""

    @pytest.mark.asyncio
    async def test_rtsp_frame_processing(self):
        """Test RTSP frame processing pipeline."""
        config = StreamConfig(
            stream_url="rtsp://example.com/stream",
            stream_type="rtsp",
            target_fps=10,
            crop_region=[100, 100, 400, 300]
        )

        reader = StreamReader("test_stream", config)

        # Mock cv2.VideoCapture
        with patch('cv2.VideoCapture') as mock_cap:
            mock_cap_instance = Mock()
            mock_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            # Mock frame sequence
            test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128
            mock_cap_instance.read.side_effect = [
                (True, test_frame),        # First frame
                (True, test_frame * 0.8),  # Second frame
                (False, None),             # Connection lost
                (True, test_frame * 1.2),  # Reconnected
            ]

            await reader.start()

            # Let reader process some frames
            await asyncio.sleep(0.1)

            # Verify frame processing
            latest_frame = reader.get_latest_frame()
            assert latest_frame is not None
            assert latest_frame.shape == (480, 640, 3)

            await reader.stop()
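
    # The capture loop this test exercises, sketched from the config fields
    # above (target_fps, crop_region) and the mocked read() sequence. Whether
    # crop_region is (x, y, w, h) or corner coordinates is an assumption:
    #
    #     while self.is_running:
    #         ok, frame = capture.read()
    #         if not ok:
    #             capture = cv2.VideoCapture(self.config.stream_url)  # reconnect
    #             continue
    #         x, y, w, h = self.config.crop_region
    #         self._add_frame_to_queue(frame[y:y + h, x:x + w])
    #         await asyncio.sleep(1.0 / self.config.target_fps)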

    @pytest.mark.asyncio
    async def test_http_snapshot_processing(self):
        """Test HTTP snapshot processing."""
        config = StreamConfig(
            stream_url="http://camera.example.com/snapshot.jpg",
            stream_type="http_snapshot",
            snapshot_interval=0.5,
            timeout=5.0
        )

        reader = StreamReader("snapshot_stream", config)

        with patch('requests.get') as mock_get:
            # Mock HTTP responses
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.content = b"fake_jpeg_data"
            mock_get.return_value = mock_response

            with patch('cv2.imdecode') as mock_decode:
                test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 200
                mock_decode.return_value = test_frame

                await reader.start()

                # Wait for snapshot capture
                await asyncio.sleep(0.6)

                # Verify snapshot processing
                latest_frame = reader.get_latest_frame()
                assert latest_frame is not None
                assert np.array_equal(latest_frame, test_frame)

                await reader.stop()
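
    # Snapshot polling as this test assumes it: fetch the JPEG over HTTP every
    # snapshot_interval seconds and decode it with OpenCV. A sketch, not the
    # actual reader loop:
    #
    #     response = requests.get(self.config.stream_url, timeout=self.config.timeout)
    #     if response.status_code == 200:
    #         buf = np.frombuffer(response.content, dtype=np.uint8)
    #         self._add_frame_to_queue(cv2.imdecode(buf, cv2.IMREAD_COLOR))
    #     await asyncio.sleep(self.config.snapshot_interval)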

    def test_frame_queue_management(self):
        """Test frame queue management and buffering."""
        config = StreamConfig("rtsp://example.com/stream", "rtsp")
        reader = StreamReader("queue_test", config, frame_buffer_size=3)

        # Add frames to queue
        frames = [
            np.ones((100, 100, 3), dtype=np.uint8) * i
            for i in range(50, 250, 50)  # 4 distinct frames
        ]

        for frame in frames[:3]:  # Fill buffer
            reader._add_frame_to_queue(frame)

        assert reader.frame_queue.qsize() == 3

        # Add one more (should drop oldest)
        reader._add_frame_to_queue(frames[3])
        assert reader.frame_queue.qsize() == 3

        # Verify frame order (oldest should be dropped)
        retrieved_frames = []
        while not reader.frame_queue.empty():
            retrieved_frames.append(reader.get_frame(timeout=0.1))

        assert len(retrieved_frames) == 3
        # frames[0] was dropped, so the buffer should hold frames[1:4]
        assert not np.array_equal(retrieved_frames[0], frames[0])
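

# Self-contained sketch of the drop-oldest buffering verified above. The real
# StreamReader._add_frame_to_queue() is only assumed to behave this way.
def _sketch_add_frame_dropping_oldest(frame_queue, frame) -> None:
    """Enqueue a frame; when the buffer is full, discard the oldest one first."""
    import queue  # stdlib; imported locally to keep the sketch self-contained
    if frame_queue.full():
        try:
            frame_queue.get_nowait()  # drop the oldest buffered frame
        except queue.Empty:
            pass  # a concurrent consumer drained it first
    frame_queue.put_nowait(frame)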