# python-detector-worker/tests/unit/communication/test_websocket_handler.py
# 2025-09-12 18:55:23 +07:00
#
# 856 lines
# No EOL
# 32 KiB
# Python

"""
Unit tests for WebSocket handling functionality.
"""
import pytest
import asyncio
import json
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from fastapi.websockets import WebSocket, WebSocketDisconnect
import uuid
from detector_worker.communication.websocket_handler import (
WebSocketHandler,
ConnectionManager,
WebSocketConnection,
MessageHandler,
WebSocketError,
ConnectionError as WSConnectionError
)
from detector_worker.communication.message_processor import MessageType
from detector_worker.core.exceptions import MessageProcessingError
class TestWebSocketConnection:
    """Tests for the ``WebSocketConnection`` wrapper around a raw WebSocket."""

    @staticmethod
    def _wrap(websocket, *, connected=False):
        """Wrap *websocket* as ``client_001``; optionally flag it as connected."""
        conn = WebSocketConnection(websocket, "client_001")
        if connected:
            conn.is_connected = True
        return conn

    def test_creation(self, mock_websocket):
        """A fresh wrapper holds the socket and client id and has no state yet."""
        conn = self._wrap(mock_websocket)

        assert conn.websocket == mock_websocket
        assert conn.client_id == "client_001"
        assert conn.is_connected is False
        for unset_attr in ("connected_at", "last_ping", "subscription_id"):
            assert getattr(conn, unset_attr) is None

    @pytest.mark.asyncio
    async def test_accept_connection(self, mock_websocket):
        """``accept`` delegates to the socket and marks the wrapper as live."""
        conn = self._wrap(mock_websocket)
        mock_websocket.accept = AsyncMock()

        await conn.accept()

        mock_websocket.accept.assert_called_once()
        assert conn.is_connected is True
        assert conn.connected_at is not None

    @pytest.mark.asyncio
    async def test_send_message_json(self, mock_websocket):
        """Dict payloads are sent with ``send_json``."""
        conn = self._wrap(mock_websocket, connected=True)
        mock_websocket.send_json = AsyncMock()
        payload = {"type": "test", "data": "hello"}

        await conn.send_message(payload)

        mock_websocket.send_json.assert_called_once_with(payload)

    @pytest.mark.asyncio
    async def test_send_message_text(self, mock_websocket):
        """String payloads are sent with ``send_text``."""
        conn = self._wrap(mock_websocket, connected=True)
        mock_websocket.send_text = AsyncMock()

        await conn.send_message("hello world")

        mock_websocket.send_text.assert_called_once_with("hello world")

    @pytest.mark.asyncio
    async def test_send_message_not_connected(self, mock_websocket):
        """Sending on a wrapper that was never connected raises ``WebSocketError``."""
        conn = self._wrap(mock_websocket)  # deliberately left disconnected

        with pytest.raises(WebSocketError) as raised:
            await conn.send_message({"type": "test"})

        assert "not connected" in str(raised.value).lower()

    @pytest.mark.asyncio
    async def test_receive_message_json(self, mock_websocket):
        """A JSON frame is returned as the decoded object."""
        conn = self._wrap(mock_websocket, connected=True)
        mock_websocket.receive_json = AsyncMock(
            return_value={"type": "test", "data": "received"}
        )

        received = await conn.receive_message()

        assert received == {"type": "test", "data": "received"}
        mock_websocket.receive_json.assert_called_once()

    @pytest.mark.asyncio
    async def test_receive_message_text(self, mock_websocket):
        """When JSON decoding fails, the text read is what gets returned."""
        conn = self._wrap(mock_websocket, connected=True)
        # receive_json raises a decode error; receive_text then yields the frame.
        mock_websocket.receive_json = AsyncMock(
            side_effect=json.JSONDecodeError("Invalid JSON", "", 0)
        )
        mock_websocket.receive_text = AsyncMock(return_value="plain text message")

        received = await conn.receive_message()

        assert received == "plain text message"
        mock_websocket.receive_text.assert_called_once()

    @pytest.mark.asyncio
    async def test_ping_pong(self, mock_websocket):
        """``ping`` delegates to the socket and records when it happened."""
        conn = self._wrap(mock_websocket, connected=True)
        mock_websocket.ping = AsyncMock()

        await conn.ping()

        mock_websocket.ping.assert_called_once()
        assert conn.last_ping is not None

    @pytest.mark.asyncio
    async def test_close_connection(self, mock_websocket):
        """``close`` forwards code/reason to the socket and clears the live flag."""
        conn = self._wrap(mock_websocket, connected=True)
        mock_websocket.close = AsyncMock()

        await conn.close(code=1000, reason="Normal closure")

        mock_websocket.close.assert_called_once_with(code=1000, reason="Normal closure")
        assert conn.is_connected is False

    def test_connection_info(self, mock_websocket):
        """``get_connection_info`` reports identity, state and timestamps."""
        conn = self._wrap(mock_websocket, connected=True)
        conn.subscription_id = "sub_123"

        info = conn.get_connection_info()

        assert info["client_id"] == "client_001"
        assert info["is_connected"] is True
        assert info["subscription_id"] == "sub_123"
        assert all(key in info for key in ("connected_at", "last_ping"))
class TestConnectionManager:
    """Test WebSocket connection management."""

    def test_initialization(self):
        """Test connection manager initialization."""
        manager = ConnectionManager()
        assert len(manager.connections) == 0
        assert len(manager.subscriptions) == 0
        assert manager.max_connections == 100

    @pytest.mark.asyncio
    async def test_add_connection(self, mock_websocket):
        """Test adding a connection."""
        manager = ConnectionManager()
        client_id = "client_001"

        connection = await manager.add_connection(mock_websocket, client_id)

        assert connection.client_id == client_id
        assert client_id in manager.connections
        assert manager.get_connection_count() == 1

    @pytest.mark.asyncio
    async def test_remove_connection(self, mock_websocket):
        """Test removing a connection."""
        manager = ConnectionManager()
        client_id = "client_001"
        await manager.add_connection(mock_websocket, client_id)
        assert client_id in manager.connections

        removed_connection = await manager.remove_connection(client_id)

        assert removed_connection is not None
        assert removed_connection.client_id == client_id
        assert client_id not in manager.connections
        assert manager.get_connection_count() == 0

    def test_get_connection(self, mock_websocket):
        """Test getting a connection."""
        manager = ConnectionManager()
        client_id = "client_001"
        # Register the connection directly so only the lookup path is exercised.
        connection = WebSocketConnection(mock_websocket, client_id)
        manager.connections[client_id] = connection

        retrieved_connection = manager.get_connection(client_id)

        assert retrieved_connection == connection
        assert retrieved_connection.client_id == client_id

    def test_get_nonexistent_connection(self):
        """Test getting non-existent connection."""
        manager = ConnectionManager()
        connection = manager.get_connection("nonexistent_client")
        assert connection is None

    @pytest.mark.asyncio
    async def test_broadcast_message(self):
        """Test broadcasting message to all connections.

        Each client gets its own socket mock so deliveries can be checked per
        client; the shared ``mock_websocket`` fixture is therefore not requested.
        """
        manager = ConnectionManager()
        connections = []
        for i in range(3):
            client_id = f"client_{i}"
            ws = Mock()
            ws.send_json = AsyncMock()
            connection = WebSocketConnection(ws, client_id)
            connection.is_connected = True
            manager.connections[client_id] = connection
            connections.append(connection)

        message = {"type": "broadcast", "data": "hello all"}
        await manager.broadcast(message)

        # All connections should have received the message.
        for connection in connections:
            connection.websocket.send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_broadcast_to_subscription(self):
        """Test broadcasting to specific subscription.

        Uses dedicated socket mocks per client (no ``mock_websocket`` fixture).
        """
        manager = ConnectionManager()
        subscription_id = "camera_001"

        # Connection with the target subscription.
        ws1 = Mock()
        ws1.send_json = AsyncMock()
        connection1 = WebSocketConnection(ws1, "client_001")
        connection1.is_connected = True
        connection1.subscription_id = subscription_id
        manager.connections["client_001"] = connection1
        manager.subscriptions[subscription_id] = {"client_001"}

        # Connection with a different subscription.
        ws2 = Mock()
        ws2.send_json = AsyncMock()
        connection2 = WebSocketConnection(ws2, "client_002")
        connection2.is_connected = True
        connection2.subscription_id = "camera_002"
        manager.connections["client_002"] = connection2
        manager.subscriptions["camera_002"] = {"client_002"}

        message = {"type": "detection", "data": "camera detection"}
        await manager.broadcast_to_subscription(subscription_id, message)

        # Only connection1 should have received the message.
        ws1.send_json.assert_called_once_with(message)
        ws2.send_json.assert_not_called()

    def test_add_subscription(self):
        """Test adding subscription mapping."""
        manager = ConnectionManager()
        client_id = "client_001"
        subscription_id = "camera_001"

        manager.add_subscription(client_id, subscription_id)

        assert subscription_id in manager.subscriptions
        assert client_id in manager.subscriptions[subscription_id]

    def test_remove_subscription(self):
        """Test removing subscription mapping."""
        manager = ConnectionManager()
        client_id = "client_001"
        subscription_id = "camera_001"
        manager.add_subscription(client_id, subscription_id)
        assert client_id in manager.subscriptions[subscription_id]

        manager.remove_subscription(client_id, subscription_id)

        # The subscription key may be dropped entirely once it is empty.
        assert client_id not in manager.subscriptions.get(subscription_id, set())

    def test_get_subscription_clients(self):
        """Test getting clients for a subscription."""
        manager = ConnectionManager()
        subscription_id = "camera_001"
        clients = ["client_001", "client_002", "client_003"]
        for client_id in clients:
            manager.add_subscription(client_id, subscription_id)

        subscription_clients = manager.get_subscription_clients(subscription_id)

        assert subscription_clients == set(clients)

    def test_get_client_subscriptions(self):
        """Test getting subscriptions for a client."""
        manager = ConnectionManager()
        client_id = "client_001"
        subscriptions = ["camera_001", "camera_002", "camera_003"]
        for subscription_id in subscriptions:
            manager.add_subscription(client_id, subscription_id)

        client_subscriptions = manager.get_client_subscriptions(client_id)

        assert client_subscriptions == set(subscriptions)

    @pytest.mark.asyncio
    async def test_cleanup_disconnected_connections(self):
        """Test cleanup of disconnected connections.

        Only the disconnected client (and its subscriptions) is removed; the
        connected client and its subscription must survive.
        """
        manager = ConnectionManager()

        ws1 = Mock()
        connection1 = WebSocketConnection(ws1, "client_001")
        connection1.is_connected = True
        manager.connections["client_001"] = connection1

        ws2 = Mock()
        connection2 = WebSocketConnection(ws2, "client_002")
        connection2.is_connected = False  # Disconnected
        manager.connections["client_002"] = connection2

        manager.add_subscription("client_001", "camera_001")
        manager.add_subscription("client_002", "camera_002")

        cleaned_count = await manager.cleanup_disconnected()

        assert cleaned_count == 1
        assert "client_001" in manager.connections  # Still connected
        assert "client_002" not in manager.connections  # Cleaned up
        # The removed client's subscriptions go with it...
        assert manager.get_client_subscriptions("client_002") == set()
        assert "client_002" not in manager.subscriptions.get("camera_002", set())
        # ...while the surviving client's subscription is left untouched.
        assert manager.get_client_subscriptions("client_001") == {"camera_001"}

    def test_get_connection_stats(self):
        """Test getting connection statistics."""
        manager = ConnectionManager()
        for i in range(3):
            client_id = f"client_{i}"
            ws = Mock()
            connection = WebSocketConnection(ws, client_id)
            connection.is_connected = i < 2  # First 2 connected, last one disconnected
            manager.connections[client_id] = connection
            if i < 2:  # Subscriptions only for connected clients
                manager.add_subscription(client_id, f"camera_{i}")

        stats = manager.get_connection_stats()

        assert stats["total_connections"] == 3
        assert stats["active_connections"] == 2
        assert stats["total_subscriptions"] == 2
        assert "uptime" in stats
class TestMessageHandler:
    """Tests for ``MessageHandler``: processor delegation, error replies and sending."""

    @staticmethod
    def _connection(websocket, *, connected=False):
        """Wrap *websocket* as ``client_001``; optionally flag it as connected."""
        conn = WebSocketConnection(websocket, "client_001")
        if connected:
            conn.is_connected = True
        return conn

    def test_creation(self):
        """A new handler keeps its processor and starts without a manager."""
        processor = Mock()
        handler = MessageHandler(processor)

        assert handler.message_processor == processor
        assert handler.connection_manager is None

    def test_set_connection_manager(self):
        """``set_connection_manager`` stores the given manager."""
        processor, manager = Mock(), Mock()
        handler = MessageHandler(processor)

        handler.set_connection_manager(manager)

        assert handler.connection_manager == manager

    @pytest.mark.asyncio
    async def test_handle_message_success(self, mock_websocket):
        """The processor gets the message plus client id; its reply is returned."""
        processor = Mock()
        processor.process_message = AsyncMock(
            return_value={"type": "response", "status": "success"}
        )
        handler = MessageHandler(processor)
        conn = self._connection(mock_websocket)
        incoming = {"type": "subscribe", "payload": {"camera_id": "camera_001"}}

        reply = await handler.handle_message(conn, incoming)

        assert reply["status"] == "success"
        processor.process_message.assert_called_once_with(incoming, "client_001")

    @pytest.mark.asyncio
    async def test_handle_message_processing_error(self, mock_websocket):
        """A ``MessageProcessingError`` becomes an error reply carrying its text."""
        processor = Mock()
        processor.process_message = AsyncMock(
            side_effect=MessageProcessingError("Invalid message")
        )
        handler = MessageHandler(processor)
        conn = self._connection(mock_websocket)

        reply = await handler.handle_message(conn, {"type": "invalid", "payload": {}})

        assert reply["type"] == "error"
        assert "Invalid message" in reply["message"]

    @pytest.mark.asyncio
    async def test_handle_message_unexpected_error(self, mock_websocket):
        """Any other exception becomes a generic internal-error reply."""
        processor = Mock()
        processor.process_message = AsyncMock(side_effect=Exception("Unexpected error"))
        handler = MessageHandler(processor)
        conn = self._connection(mock_websocket)

        reply = await handler.handle_message(conn, {"type": "test", "payload": {}})

        assert reply["type"] == "error"
        assert "internal error" in reply["message"].lower()

    @pytest.mark.asyncio
    async def test_send_response(self, mock_websocket):
        """``send_response`` pushes the reply through the socket's ``send_json``."""
        handler = MessageHandler(Mock())
        conn = self._connection(mock_websocket, connected=True)
        mock_websocket.send_json = AsyncMock()
        reply = {"type": "response", "data": "test response"}

        await handler.send_response(conn, reply)

        mock_websocket.send_json.assert_called_once_with(reply)

    @pytest.mark.asyncio
    async def test_send_error_response(self, mock_websocket):
        """``send_error_response`` sends an error frame with message and code."""
        handler = MessageHandler(Mock())
        conn = self._connection(mock_websocket, connected=True)
        mock_websocket.send_json = AsyncMock()

        await handler.send_error_response(conn, "Test error message", "TEST_ERROR")

        mock_websocket.send_json.assert_called_once()
        sent = mock_websocket.send_json.call_args.args[0]
        assert sent["type"] == "error"
        assert sent["message"] == "Test error message"
        assert sent["error_code"] == "TEST_ERROR"
class TestWebSocketHandler:
    """Test main WebSocket handler functionality."""

    def test_initialization(self):
        """Test WebSocket handler initialization."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        assert isinstance(handler.connection_manager, ConnectionManager)
        assert isinstance(handler.message_handler, MessageHandler)
        # The handler wires its own connection manager into its message handler.
        assert handler.message_handler.connection_manager == handler.connection_manager
        assert handler.heartbeat_interval == 30.0
        assert handler.max_connections == 100

    def test_initialization_with_config(self):
        """Test initialization with custom configuration."""
        mock_processor = Mock()
        config = {
            "heartbeat_interval": 60.0,
            "max_connections": 200,
            "connection_timeout": 300.0,
        }
        handler = WebSocketHandler(mock_processor, config)
        assert handler.heartbeat_interval == 60.0
        assert handler.max_connections == 200
        assert handler.connection_timeout == 300.0

    @pytest.mark.asyncio
    async def test_handle_websocket_connection(self, mock_websocket):
        """Test handling WebSocket connection."""
        mock_processor = Mock()
        mock_processor.process_message = AsyncMock(return_value={"type": "ack", "status": "success"})
        handler = WebSocketHandler(mock_processor)
        # One subscribe message, then the client disconnects.
        mock_websocket.accept = AsyncMock()
        mock_websocket.receive_json = AsyncMock(side_effect=[
            {"type": "subscribe", "payload": {"camera_id": "camera_001"}},
            WebSocketDisconnect(),
        ])
        mock_websocket.send_json = AsyncMock()
        client_id = "test_client_001"

        # The disconnect must be handled internally rather than raised.
        await handler.handle_websocket(mock_websocket, client_id)

        mock_websocket.accept.assert_called_once()
        mock_processor.process_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_handle_websocket_max_connections(self, mock_websocket):
        """Test handling max connections limit."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor, {"max_connections": 1})
        # Occupy the single slot with a *live* connection so the limit is reached
        # whether the handler counts all registered connections or only active ones.
        client1_ws = Mock()
        connection1 = WebSocketConnection(client1_ws, "client_001")
        connection1.is_connected = True
        handler.connection_manager.connections["client_001"] = connection1
        mock_websocket.close = AsyncMock()

        await handler.handle_websocket(mock_websocket, "client_002")

        # The second client is rejected by closing its socket.
        mock_websocket.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_broadcast_message(self):
        """Test broadcasting message to all connections."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        handler.connection_manager.broadcast = AsyncMock()
        message = {"type": "system", "data": "Server maintenance in 10 minutes"}

        await handler.broadcast_message(message)

        handler.connection_manager.broadcast.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_send_to_client(self):
        """Test sending message to specific client."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        mock_websocket = Mock()
        mock_websocket.send_json = AsyncMock()
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        handler.connection_manager.connections["client_001"] = connection
        message = {"type": "notification", "data": "Personal message"}

        result = await handler.send_to_client("client_001", message)

        assert result is True
        mock_websocket.send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_send_to_nonexistent_client(self):
        """Test sending message to non-existent client."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        message = {"type": "notification", "data": "Message"}

        result = await handler.send_to_client("nonexistent_client", message)

        assert result is False

    @pytest.mark.asyncio
    async def test_send_to_subscription(self):
        """Test sending message to subscription."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        handler.connection_manager.broadcast_to_subscription = AsyncMock()
        subscription_id = "camera_001"
        message = {"type": "detection", "data": {"class": "car", "confidence": 0.95}}

        await handler.send_to_subscription(subscription_id, message)

        handler.connection_manager.broadcast_to_subscription.assert_called_once_with(subscription_id, message)

    @pytest.mark.asyncio
    async def test_start_heartbeat_task(self):
        """Test starting heartbeat task."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor, {"heartbeat_interval": 0.1})
        mock_websocket = Mock()
        mock_websocket.ping = AsyncMock()
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        handler.connection_manager.connections["client_001"] = connection

        heartbeat_task = asyncio.create_task(handler._heartbeat_loop())
        try:
            # Let the loop run for roughly two heartbeat intervals.
            await asyncio.sleep(0.2)
        finally:
            # Always stop the background loop so it cannot outlive this test.
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass

        # Should have sent at least one ping.
        assert mock_websocket.ping.called

    def test_get_connection_stats(self):
        """Test getting connection statistics."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        for i in range(3):
            client_id = f"client_{i}"
            ws = Mock()
            connection = WebSocketConnection(ws, client_id)
            connection.is_connected = True
            handler.connection_manager.connections[client_id] = connection

        stats = handler.get_connection_stats()

        assert stats["total_connections"] == 3
        assert stats["active_connections"] == 3

    def test_get_client_info(self):
        """Test getting client information."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        mock_websocket = Mock()
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        connection.subscription_id = "camera_001"
        handler.connection_manager.connections["client_001"] = connection

        info = handler.get_client_info("client_001")

        assert info is not None
        assert info["client_id"] == "client_001"
        assert info["is_connected"] is True
        assert info["subscription_id"] == "camera_001"

    @pytest.mark.asyncio
    async def test_disconnect_client(self):
        """Test disconnecting specific client."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        mock_websocket = Mock()
        mock_websocket.close = AsyncMock()
        connection = WebSocketConnection(mock_websocket, "client_001")
        connection.is_connected = True
        handler.connection_manager.connections["client_001"] = connection

        result = await handler.disconnect_client("client_001", code=1000, reason="Admin disconnect")

        assert result is True
        mock_websocket.close.assert_called_once_with(code=1000, reason="Admin disconnect")

    @pytest.mark.asyncio
    async def test_cleanup_connections(self):
        """Test cleanup of disconnected connections."""
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        handler.connection_manager.cleanup_disconnected = AsyncMock(return_value=2)

        cleaned_count = await handler.cleanup_connections()

        assert cleaned_count == 2
        handler.connection_manager.cleanup_disconnected.assert_called_once()
class TestWebSocketHandlerIntegration:
    """Integration tests for WebSocket handler."""

    @pytest.mark.asyncio
    async def test_complete_subscription_workflow(self, mock_websocket):
        """Test complete subscription workflow."""
        mock_processor = Mock()
        # One processor reply per incoming message, in order.
        mock_processor.process_message = AsyncMock(side_effect=[
            {"type": "subscribeAck", "status": "success", "subscription_id": "camera_001"},
            {"type": "unsubscribeAck", "status": "success"},
        ])
        handler = WebSocketHandler(mock_processor)
        mock_websocket.accept = AsyncMock()
        mock_websocket.send_json = AsyncMock()
        mock_websocket.receive_json = AsyncMock(side_effect=[
            {"type": "subscribe", "payload": {"camera_id": "camera_001", "rtsp_url": "rtsp://example.com"}},
            {"type": "unsubscribe", "payload": {"subscription_id": "camera_001"}},
            WebSocketDisconnect(),
        ])
        client_id = "test_client"

        await handler.handle_websocket(mock_websocket, client_id)

        # Both messages were processed and each produced one reply.
        assert mock_processor.process_message.call_count == 2
        assert mock_websocket.send_json.call_count == 2

    @pytest.mark.asyncio
    async def test_multiple_client_management(self):
        """Test managing multiple concurrent clients.

        Broadcasts must reach every client; subscription messages must reach
        only the subscribed ones.
        """
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor)
        clients = []
        for i in range(5):
            client_id = f"client_{i}"
            mock_ws = Mock()
            mock_ws.send_json = AsyncMock()
            connection = WebSocketConnection(mock_ws, client_id)
            connection.is_connected = True
            handler.connection_manager.connections[client_id] = connection
            clients.append(connection)

        # Broadcasting reaches every client.
        message = {"type": "broadcast", "data": "Hello all clients"}
        await handler.broadcast_message(message)
        for connection in clients:
            connection.websocket.send_json.assert_called_once_with(message)

        # Subscription-specific messaging reaches only subscribed clients.
        subscription_id = "camera_001"
        subscribed = {"client_0", "client_2"}
        for connection in clients:
            # Forget the broadcast deliveries so only the next send is observed.
            connection.websocket.send_json.reset_mock()
            if connection.client_id in subscribed:
                # Set both the manager's subscription map and the connection's own
                # subscription_id, so the check does not depend on which of the two
                # broadcast_to_subscription consults.
                connection.subscription_id = subscription_id
                handler.connection_manager.add_subscription(connection.client_id, subscription_id)

        subscription_message = {"type": "detection", "camera_id": "camera_001"}
        await handler.send_to_subscription(subscription_id, subscription_message)

        for connection in clients:
            send_json = connection.websocket.send_json
            if connection.client_id in subscribed:
                send_json.assert_called_once_with(subscription_message)
            else:
                send_json.assert_not_called()

    @pytest.mark.asyncio
    async def test_error_handling_and_recovery(self, mock_websocket):
        """Test error handling and recovery scenarios."""
        mock_processor = Mock()
        # The first message causes an error; the second succeeds.
        mock_processor.process_message = AsyncMock(side_effect=[
            MessageProcessingError("Invalid message format"),
            {"type": "ack", "status": "success"},
        ])
        handler = WebSocketHandler(mock_processor)
        mock_websocket.accept = AsyncMock()
        mock_websocket.send_json = AsyncMock()
        mock_websocket.receive_json = AsyncMock(side_effect=[
            {"type": "invalid", "malformed": True},
            {"type": "valid", "payload": {"test": True}},
            WebSocketDisconnect(),
        ])
        client_id = "error_test_client"

        # Errors must be handled gracefully without stopping the receive loop.
        await handler.handle_websocket(mock_websocket, client_id)

        # Both messages were processed, and each produced one reply.
        assert mock_processor.process_message.call_count == 2
        assert mock_websocket.send_json.call_count == 2
        # The first reply is the error for the bad message...
        first_response = mock_websocket.send_json.call_args_list[0][0][0]
        assert first_response["type"] == "error"
        # ...and the second shows the handler recovered and relayed the ack.
        second_response = mock_websocket.send_json.call_args_list[1][0][0]
        assert second_response["status"] == "success"

    @pytest.mark.asyncio
    async def test_connection_timeout_handling(self):
        """Test connection timeout handling.

        Only the configured timeout and the return contract of
        ``cleanup_connections`` (an integer count) are checked. Whether an idle
        connection is actually evicted depends on timeout logic in the cleanup
        path, which this test does not assume.
        """
        mock_processor = Mock()
        handler = WebSocketHandler(mock_processor, {"connection_timeout": 0.1})
        assert handler.connection_timeout == 0.1

        # A connection that never pings, to simulate an idle client.
        mock_websocket = Mock()
        connection = WebSocketConnection(mock_websocket, "timeout_client")
        connection.is_connected = True
        handler.connection_manager.connections["timeout_client"] = connection

        # Wait longer than the timeout.
        await asyncio.sleep(0.2)

        # Manual cleanup (in a real deployment this would run periodically).
        cleaned = await handler.cleanup_connections()

        assert isinstance(cleaned, int)
        assert cleaned >= 0