""" Unit tests for WebSocket handling functionality. """ import pytest import asyncio import json from unittest.mock import Mock, AsyncMock, patch, MagicMock from fastapi.websockets import WebSocket, WebSocketDisconnect import uuid from detector_worker.communication.websocket_handler import ( WebSocketHandler, ConnectionManager, WebSocketConnection, MessageHandler, WebSocketError, ConnectionError as WSConnectionError ) from detector_worker.communication.message_processor import MessageType from detector_worker.core.exceptions import MessageProcessingError class TestWebSocketConnection: """Test WebSocket connection wrapper.""" def test_creation(self, mock_websocket): """Test WebSocket connection creation.""" connection = WebSocketConnection(mock_websocket, "client_001") assert connection.websocket == mock_websocket assert connection.client_id == "client_001" assert connection.is_connected is False assert connection.connected_at is None assert connection.last_ping is None assert connection.subscription_id is None @pytest.mark.asyncio async def test_accept_connection(self, mock_websocket): """Test accepting WebSocket connection.""" connection = WebSocketConnection(mock_websocket, "client_001") mock_websocket.accept = AsyncMock() await connection.accept() assert connection.is_connected is True assert connection.connected_at is not None mock_websocket.accept.assert_called_once() @pytest.mark.asyncio async def test_send_message_json(self, mock_websocket): """Test sending JSON message.""" connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True mock_websocket.send_json = AsyncMock() message = {"type": "test", "data": "hello"} await connection.send_message(message) mock_websocket.send_json.assert_called_once_with(message) @pytest.mark.asyncio async def test_send_message_text(self, mock_websocket): """Test sending text message.""" connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True mock_websocket.send_text = AsyncMock() await connection.send_message("hello world") mock_websocket.send_text.assert_called_once_with("hello world") @pytest.mark.asyncio async def test_send_message_not_connected(self, mock_websocket): """Test sending message when not connected.""" connection = WebSocketConnection(mock_websocket, "client_001") # Don't set is_connected = True with pytest.raises(WebSocketError) as exc_info: await connection.send_message({"type": "test"}) assert "not connected" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_receive_message_json(self, mock_websocket): """Test receiving JSON message.""" connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True mock_websocket.receive_json = AsyncMock(return_value={"type": "test", "data": "received"}) message = await connection.receive_message() assert message == {"type": "test", "data": "received"} mock_websocket.receive_json.assert_called_once() @pytest.mark.asyncio async def test_receive_message_text(self, mock_websocket): """Test receiving text message.""" connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True # Mock receive_json to fail, then receive_text to succeed mock_websocket.receive_json = AsyncMock(side_effect=json.JSONDecodeError("Invalid JSON", "", 0)) mock_websocket.receive_text = AsyncMock(return_value="plain text message") message = await connection.receive_message() assert message == "plain text message" mock_websocket.receive_text.assert_called_once() @pytest.mark.asyncio async def test_ping_pong(self, mock_websocket): """Test ping/pong functionality.""" connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True mock_websocket.ping = AsyncMock() await connection.ping() assert connection.last_ping is not None mock_websocket.ping.assert_called_once() @pytest.mark.asyncio async def test_close_connection(self, mock_websocket): """Test closing connection.""" connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True mock_websocket.close = AsyncMock() await connection.close(code=1000, reason="Normal closure") assert connection.is_connected is False mock_websocket.close.assert_called_once_with(code=1000, reason="Normal closure") def test_connection_info(self, mock_websocket): """Test getting connection information.""" connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True connection.subscription_id = "sub_123" info = connection.get_connection_info() assert info["client_id"] == "client_001" assert info["is_connected"] is True assert info["subscription_id"] == "sub_123" assert "connected_at" in info assert "last_ping" in info class TestConnectionManager: """Test WebSocket connection management.""" def test_initialization(self): """Test connection manager initialization.""" manager = ConnectionManager() assert len(manager.connections) == 0 assert len(manager.subscriptions) == 0 assert manager.max_connections == 100 @pytest.mark.asyncio async def test_add_connection(self, mock_websocket): """Test adding a connection.""" manager = ConnectionManager() client_id = "client_001" connection = await manager.add_connection(mock_websocket, client_id) assert connection.client_id == client_id assert client_id in manager.connections assert manager.get_connection_count() == 1 @pytest.mark.asyncio async def test_remove_connection(self, mock_websocket): """Test removing a connection.""" manager = ConnectionManager() client_id = "client_001" await manager.add_connection(mock_websocket, client_id) assert client_id in manager.connections removed_connection = await manager.remove_connection(client_id) assert removed_connection is not None assert removed_connection.client_id == client_id assert client_id not in manager.connections assert manager.get_connection_count() == 0 def test_get_connection(self, mock_websocket): """Test getting a connection.""" manager = ConnectionManager() client_id = "client_001" # Manually add connection for testing connection = WebSocketConnection(mock_websocket, client_id) manager.connections[client_id] = connection retrieved_connection = manager.get_connection(client_id) assert retrieved_connection == connection assert retrieved_connection.client_id == client_id def test_get_nonexistent_connection(self): """Test getting non-existent connection.""" manager = ConnectionManager() connection = manager.get_connection("nonexistent_client") assert connection is None @pytest.mark.asyncio async def test_broadcast_message(self, mock_websocket): """Test broadcasting message to all connections.""" manager = ConnectionManager() # Add multiple connections connections = [] for i in range(3): client_id = f"client_{i}" ws = Mock() ws.send_json = AsyncMock() connection = WebSocketConnection(ws, client_id) connection.is_connected = True manager.connections[client_id] = connection connections.append(connection) message = {"type": "broadcast", "data": "hello all"} await manager.broadcast(message) # All connections should have received the message for connection in connections: connection.websocket.send_json.assert_called_once_with(message) @pytest.mark.asyncio async def test_broadcast_to_subscription(self, mock_websocket): """Test broadcasting to specific subscription.""" manager = ConnectionManager() # Add connections with different subscriptions subscription_id = "camera_001" # Connection with target subscription ws1 = Mock() ws1.send_json = AsyncMock() connection1 = WebSocketConnection(ws1, "client_001") connection1.is_connected = True connection1.subscription_id = subscription_id manager.connections["client_001"] = connection1 manager.subscriptions[subscription_id] = {"client_001"} # Connection with different subscription ws2 = Mock() ws2.send_json = AsyncMock() connection2 = WebSocketConnection(ws2, "client_002") connection2.is_connected = True connection2.subscription_id = "camera_002" manager.connections["client_002"] = connection2 manager.subscriptions["camera_002"] = {"client_002"} message = {"type": "detection", "data": "camera detection"} await manager.broadcast_to_subscription(subscription_id, message) # Only connection1 should have received the message ws1.send_json.assert_called_once_with(message) ws2.send_json.assert_not_called() def test_add_subscription(self): """Test adding subscription mapping.""" manager = ConnectionManager() client_id = "client_001" subscription_id = "camera_001" manager.add_subscription(client_id, subscription_id) assert subscription_id in manager.subscriptions assert client_id in manager.subscriptions[subscription_id] def test_remove_subscription(self): """Test removing subscription mapping.""" manager = ConnectionManager() client_id = "client_001" subscription_id = "camera_001" # Add subscription first manager.add_subscription(client_id, subscription_id) assert client_id in manager.subscriptions[subscription_id] # Remove subscription manager.remove_subscription(client_id, subscription_id) assert client_id not in manager.subscriptions.get(subscription_id, set()) def test_get_subscription_clients(self): """Test getting clients for a subscription.""" manager = ConnectionManager() subscription_id = "camera_001" clients = ["client_001", "client_002", "client_003"] for client_id in clients: manager.add_subscription(client_id, subscription_id) subscription_clients = manager.get_subscription_clients(subscription_id) assert subscription_clients == set(clients) def test_get_client_subscriptions(self): """Test getting subscriptions for a client.""" manager = ConnectionManager() client_id = "client_001" subscriptions = ["camera_001", "camera_002", "camera_003"] for subscription_id in subscriptions: manager.add_subscription(client_id, subscription_id) client_subscriptions = manager.get_client_subscriptions(client_id) assert client_subscriptions == set(subscriptions) @pytest.mark.asyncio async def test_cleanup_disconnected_connections(self): """Test cleanup of disconnected connections.""" manager = ConnectionManager() # Add connected and disconnected connections ws1 = Mock() connection1 = WebSocketConnection(ws1, "client_001") connection1.is_connected = True manager.connections["client_001"] = connection1 ws2 = Mock() connection2 = WebSocketConnection(ws2, "client_002") connection2.is_connected = False # Disconnected manager.connections["client_002"] = connection2 # Add subscriptions manager.add_subscription("client_001", "camera_001") manager.add_subscription("client_002", "camera_002") cleaned_count = await manager.cleanup_disconnected() assert cleaned_count == 1 assert "client_001" in manager.connections # Still connected assert "client_002" not in manager.connections # Cleaned up # Subscriptions should also be cleaned up assert manager.get_client_subscriptions("client_002") == set() def test_get_connection_stats(self): """Test getting connection statistics.""" manager = ConnectionManager() # Add various connections and subscriptions for i in range(3): client_id = f"client_{i}" ws = Mock() connection = WebSocketConnection(ws, client_id) connection.is_connected = i < 2 # First 2 connected, last one disconnected manager.connections[client_id] = connection if i < 2: # Add subscriptions for connected clients manager.add_subscription(client_id, f"camera_{i}") stats = manager.get_connection_stats() assert stats["total_connections"] == 3 assert stats["active_connections"] == 2 assert stats["total_subscriptions"] == 2 assert "uptime" in stats class TestMessageHandler: """Test message handling functionality.""" def test_creation(self): """Test message handler creation.""" mock_processor = Mock() handler = MessageHandler(mock_processor) assert handler.message_processor == mock_processor assert handler.connection_manager is None def test_set_connection_manager(self): """Test setting connection manager.""" mock_processor = Mock() mock_manager = Mock() handler = MessageHandler(mock_processor) handler.set_connection_manager(mock_manager) assert handler.connection_manager == mock_manager @pytest.mark.asyncio async def test_handle_message_success(self, mock_websocket): """Test successful message handling.""" mock_processor = Mock() mock_processor.process_message = AsyncMock(return_value={"type": "response", "status": "success"}) handler = MessageHandler(mock_processor) connection = WebSocketConnection(mock_websocket, "client_001") message = {"type": "subscribe", "payload": {"camera_id": "camera_001"}} response = await handler.handle_message(connection, message) assert response["status"] == "success" mock_processor.process_message.assert_called_once_with(message, "client_001") @pytest.mark.asyncio async def test_handle_message_processing_error(self, mock_websocket): """Test message handling with processing error.""" mock_processor = Mock() mock_processor.process_message = AsyncMock(side_effect=MessageProcessingError("Invalid message")) handler = MessageHandler(mock_processor) connection = WebSocketConnection(mock_websocket, "client_001") message = {"type": "invalid", "payload": {}} response = await handler.handle_message(connection, message) assert response["type"] == "error" assert "Invalid message" in response["message"] @pytest.mark.asyncio async def test_handle_message_unexpected_error(self, mock_websocket): """Test message handling with unexpected error.""" mock_processor = Mock() mock_processor.process_message = AsyncMock(side_effect=Exception("Unexpected error")) handler = MessageHandler(mock_processor) connection = WebSocketConnection(mock_websocket, "client_001") message = {"type": "test", "payload": {}} response = await handler.handle_message(connection, message) assert response["type"] == "error" assert "internal error" in response["message"].lower() @pytest.mark.asyncio async def test_send_response(self, mock_websocket): """Test sending response to client.""" mock_processor = Mock() handler = MessageHandler(mock_processor) connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True mock_websocket.send_json = AsyncMock() response = {"type": "response", "data": "test response"} await handler.send_response(connection, response) mock_websocket.send_json.assert_called_once_with(response) @pytest.mark.asyncio async def test_send_error_response(self, mock_websocket): """Test sending error response.""" mock_processor = Mock() handler = MessageHandler(mock_processor) connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True mock_websocket.send_json = AsyncMock() await handler.send_error_response(connection, "Test error message", "TEST_ERROR") mock_websocket.send_json.assert_called_once() call_args = mock_websocket.send_json.call_args[0][0] assert call_args["type"] == "error" assert call_args["message"] == "Test error message" assert call_args["error_code"] == "TEST_ERROR" class TestWebSocketHandler: """Test main WebSocket handler functionality.""" def test_initialization(self): """Test WebSocket handler initialization.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) assert isinstance(handler.connection_manager, ConnectionManager) assert isinstance(handler.message_handler, MessageHandler) assert handler.message_handler.connection_manager == handler.connection_manager assert handler.heartbeat_interval == 30.0 assert handler.max_connections == 100 def test_initialization_with_config(self): """Test initialization with custom configuration.""" mock_processor = Mock() config = { "heartbeat_interval": 60.0, "max_connections": 200, "connection_timeout": 300.0 } handler = WebSocketHandler(mock_processor, config) assert handler.heartbeat_interval == 60.0 assert handler.max_connections == 200 assert handler.connection_timeout == 300.0 @pytest.mark.asyncio async def test_handle_websocket_connection(self, mock_websocket): """Test handling WebSocket connection.""" mock_processor = Mock() mock_processor.process_message = AsyncMock(return_value={"type": "ack", "status": "success"}) handler = WebSocketHandler(mock_processor) # Mock WebSocket behavior mock_websocket.accept = AsyncMock() mock_websocket.receive_json = AsyncMock(side_effect=[ {"type": "subscribe", "payload": {"camera_id": "camera_001"}}, WebSocketDisconnect() # Simulate disconnection ]) mock_websocket.send_json = AsyncMock() client_id = "test_client_001" # Handle connection (should not raise exception) await handler.handle_websocket(mock_websocket, client_id) # Verify connection was accepted mock_websocket.accept.assert_called_once() # Verify message was processed mock_processor.process_message.assert_called_once() @pytest.mark.asyncio async def test_handle_websocket_max_connections(self, mock_websocket): """Test handling max connections limit.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor, {"max_connections": 1}) # Add one connection to reach limit client1_ws = Mock() connection1 = WebSocketConnection(client1_ws, "client_001") handler.connection_manager.connections["client_001"] = connection1 mock_websocket.close = AsyncMock() # Try to add second connection await handler.handle_websocket(mock_websocket, "client_002") # Should close connection due to limit mock_websocket.close.assert_called_once() @pytest.mark.asyncio async def test_broadcast_message(self): """Test broadcasting message to all connections.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) # Mock connection manager handler.connection_manager.broadcast = AsyncMock() message = {"type": "system", "data": "Server maintenance in 10 minutes"} await handler.broadcast_message(message) handler.connection_manager.broadcast.assert_called_once_with(message) @pytest.mark.asyncio async def test_send_to_client(self): """Test sending message to specific client.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) # Create mock connection mock_websocket = Mock() mock_websocket.send_json = AsyncMock() connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True handler.connection_manager.connections["client_001"] = connection message = {"type": "notification", "data": "Personal message"} result = await handler.send_to_client("client_001", message) assert result is True mock_websocket.send_json.assert_called_once_with(message) @pytest.mark.asyncio async def test_send_to_nonexistent_client(self): """Test sending message to non-existent client.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) message = {"type": "notification", "data": "Message"} result = await handler.send_to_client("nonexistent_client", message) assert result is False @pytest.mark.asyncio async def test_send_to_subscription(self): """Test sending message to subscription.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) # Mock connection manager handler.connection_manager.broadcast_to_subscription = AsyncMock() subscription_id = "camera_001" message = {"type": "detection", "data": {"class": "car", "confidence": 0.95}} await handler.send_to_subscription(subscription_id, message) handler.connection_manager.broadcast_to_subscription.assert_called_once_with(subscription_id, message) @pytest.mark.asyncio async def test_start_heartbeat_task(self): """Test starting heartbeat task.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor, {"heartbeat_interval": 0.1}) # Mock connection with ping capability mock_websocket = Mock() mock_websocket.ping = AsyncMock() connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True handler.connection_manager.connections["client_001"] = connection # Start heartbeat task heartbeat_task = asyncio.create_task(handler._heartbeat_loop()) # Let it run briefly await asyncio.sleep(0.2) # Cancel task heartbeat_task.cancel() try: await heartbeat_task except asyncio.CancelledError: pass # Should have sent at least one ping assert mock_websocket.ping.called def test_get_connection_stats(self): """Test getting connection statistics.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) # Add some mock connections for i in range(3): client_id = f"client_{i}" ws = Mock() connection = WebSocketConnection(ws, client_id) connection.is_connected = True handler.connection_manager.connections[client_id] = connection stats = handler.get_connection_stats() assert stats["total_connections"] == 3 assert stats["active_connections"] == 3 def test_get_client_info(self): """Test getting client information.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) # Add mock connection mock_websocket = Mock() connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True connection.subscription_id = "camera_001" handler.connection_manager.connections["client_001"] = connection info = handler.get_client_info("client_001") assert info is not None assert info["client_id"] == "client_001" assert info["is_connected"] is True assert info["subscription_id"] == "camera_001" @pytest.mark.asyncio async def test_disconnect_client(self): """Test disconnecting specific client.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) # Add mock connection mock_websocket = Mock() mock_websocket.close = AsyncMock() connection = WebSocketConnection(mock_websocket, "client_001") connection.is_connected = True handler.connection_manager.connections["client_001"] = connection result = await handler.disconnect_client("client_001", code=1000, reason="Admin disconnect") assert result is True mock_websocket.close.assert_called_once_with(code=1000, reason="Admin disconnect") @pytest.mark.asyncio async def test_cleanup_connections(self): """Test cleanup of disconnected connections.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) # Mock connection manager cleanup handler.connection_manager.cleanup_disconnected = AsyncMock(return_value=2) cleaned_count = await handler.cleanup_connections() assert cleaned_count == 2 handler.connection_manager.cleanup_disconnected.assert_called_once() class TestWebSocketHandlerIntegration: """Integration tests for WebSocket handler.""" @pytest.mark.asyncio async def test_complete_subscription_workflow(self, mock_websocket): """Test complete subscription workflow.""" mock_processor = Mock() # Mock processor responses mock_processor.process_message = AsyncMock(side_effect=[ {"type": "subscribeAck", "status": "success", "subscription_id": "camera_001"}, {"type": "unsubscribeAck", "status": "success"} ]) handler = WebSocketHandler(mock_processor) # Mock WebSocket behavior mock_websocket.accept = AsyncMock() mock_websocket.send_json = AsyncMock() mock_websocket.receive_json = AsyncMock(side_effect=[ {"type": "subscribe", "payload": {"camera_id": "camera_001", "rtsp_url": "rtsp://example.com"}}, {"type": "unsubscribe", "payload": {"subscription_id": "camera_001"}}, WebSocketDisconnect() ]) client_id = "test_client" # Handle complete workflow await handler.handle_websocket(mock_websocket, client_id) # Verify both messages were processed assert mock_processor.process_message.call_count == 2 # Verify responses were sent assert mock_websocket.send_json.call_count == 2 @pytest.mark.asyncio async def test_multiple_client_management(self): """Test managing multiple concurrent clients.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor) clients = [] for i in range(5): client_id = f"client_{i}" mock_ws = Mock() mock_ws.send_json = AsyncMock() connection = WebSocketConnection(mock_ws, client_id) connection.is_connected = True handler.connection_manager.connections[client_id] = connection clients.append(connection) # Test broadcasting to all clients message = {"type": "broadcast", "data": "Hello all clients"} await handler.broadcast_message(message) # All clients should receive the message for connection in clients: connection.websocket.send_json.assert_called_once_with(message) # Test subscription-specific messaging subscription_id = "camera_001" handler.connection_manager.add_subscription("client_0", subscription_id) handler.connection_manager.add_subscription("client_2", subscription_id) subscription_message = {"type": "detection", "camera_id": "camera_001"} await handler.send_to_subscription(subscription_id, subscription_message) # Only subscribed clients should receive the message # Note: This would require additional mocking of broadcast_to_subscription @pytest.mark.asyncio async def test_error_handling_and_recovery(self, mock_websocket): """Test error handling and recovery scenarios.""" mock_processor = Mock() # First message causes error, second succeeds mock_processor.process_message = AsyncMock(side_effect=[ MessageProcessingError("Invalid message format"), {"type": "ack", "status": "success"} ]) handler = WebSocketHandler(mock_processor) mock_websocket.accept = AsyncMock() mock_websocket.send_json = AsyncMock() mock_websocket.receive_json = AsyncMock(side_effect=[ {"type": "invalid", "malformed": True}, {"type": "valid", "payload": {"test": True}}, WebSocketDisconnect() ]) client_id = "error_test_client" # Should handle errors gracefully and continue processing await handler.handle_websocket(mock_websocket, client_id) # Both messages should have been processed assert mock_processor.process_message.call_count == 2 # Should have sent error response and success response assert mock_websocket.send_json.call_count == 2 # First call should be error response first_response = mock_websocket.send_json.call_args_list[0][0][0] assert first_response["type"] == "error" @pytest.mark.asyncio async def test_connection_timeout_handling(self): """Test connection timeout handling.""" mock_processor = Mock() handler = WebSocketHandler(mock_processor, {"connection_timeout": 0.1}) # Add connection that hasn't been active mock_websocket = Mock() connection = WebSocketConnection(mock_websocket, "timeout_client") connection.is_connected = True # Don't update last_ping to simulate timeout handler.connection_manager.connections["timeout_client"] = connection # Wait longer than timeout await asyncio.sleep(0.2) # Manual cleanup (in real implementation this would be automatic) cleaned = await handler.cleanup_connections() # Connection should be identified for cleanup # (Actual timeout logic would need to be implemented in the cleanup method)