""" Integration tests for WebSocket protocol compliance. Tests the complete WebSocket communication protocol to ensure compatibility with existing clients and proper message handling. """ import pytest import asyncio import json import uuid from unittest.mock import Mock, AsyncMock, patch from fastapi.websockets import WebSocket from fastapi.testclient import TestClient from detector_worker.app import create_app from detector_worker.communication.websocket_handler import WebSocketHandler from detector_worker.communication.message_processor import MessageProcessor, MessageType from detector_worker.core.exceptions import MessageProcessingError @pytest.fixture def test_app(): """Create test FastAPI application.""" return create_app() @pytest.fixture def mock_websocket(): """Create mock WebSocket for testing.""" websocket = Mock(spec=WebSocket) websocket.accept = AsyncMock() websocket.send_json = AsyncMock() websocket.send_text = AsyncMock() websocket.receive_json = AsyncMock() websocket.receive_text = AsyncMock() websocket.close = AsyncMock() websocket.ping = AsyncMock() return websocket class TestWebSocketProtocol: """Test WebSocket protocol compliance.""" @pytest.mark.asyncio async def test_subscription_message_protocol(self, mock_websocket): """Test subscription message handling protocol.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) # Mock external dependencies with patch('cv2.VideoCapture') as mock_video_cap, \ patch('torch.load') as mock_torch_load, \ patch('builtins.open') as mock_file_open: # Setup video capture mock mock_cap_instance = Mock() mock_video_cap.return_value = mock_cap_instance mock_cap_instance.isOpened.return_value = True # Setup model loading mock mock_torch_load.return_value = Mock() # Setup pipeline file mock pipeline_config = { "modelId": "test_detection_model", "modelFile": "test_model.pt", "expectedClasses": ["Car"], "minConfidence": 0.8 } mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config) # Test message sequence subscription_message = { "type": "subscribe", "payload": { "subscriptionIdentifier": "display-001;cam-001", "rtspUrl": "rtsp://example.com/stream", "modelUrl": "http://example.com/model.mpta", "modelId": 101, "modelName": "Test Detection", "cropX1": 0, "cropY1": 0, "cropX2": 640, "cropY2": 480 } } request_state_message = { "type": "requestState" } unsubscribe_message = { "type": "unsubscribe", "payload": { "subscriptionIdentifier": "display-001;cam-001" } } # Mock WebSocket message sequence mock_websocket.receive_json.side_effect = [ subscription_message, request_state_message, unsubscribe_message, asyncio.CancelledError() # Simulate client disconnect ] client_id = "test_client_001" try: # Handle WebSocket connection await websocket_handler.handle_websocket(mock_websocket, client_id) except asyncio.CancelledError: pass # Expected when client disconnects # Verify protocol compliance mock_websocket.accept.assert_called_once() # Check sent messages sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list] # Should receive subscription acknowledgment subscribe_acks = [msg for msg in sent_messages if msg.get("type") == "subscribeAck"] assert len(subscribe_acks) >= 1 subscribe_ack = subscribe_acks[0] assert subscribe_ack["status"] in ["success", "error"] if subscribe_ack["status"] == "success": assert "subscriptionId" in subscribe_ack # Should receive state report state_reports = [msg for msg in sent_messages if msg.get("type") == "stateReport"] assert len(state_reports) >= 1 state_report = state_reports[0] assert "payload" in state_report assert "subscriptions" in state_report["payload"] assert "system" in state_report["payload"] # Should receive unsubscribe acknowledgment unsubscribe_acks = [msg for msg in sent_messages if msg.get("type") == "unsubscribeAck"] assert len(unsubscribe_acks) >= 1 unsubscribe_ack = unsubscribe_acks[0] assert unsubscribe_ack["status"] in ["success", "error"] @pytest.mark.asyncio async def test_invalid_message_handling(self, mock_websocket): """Test handling of invalid messages.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) # Test invalid message types invalid_messages = [ {"invalid": "message"}, # Missing type {"type": "unknown_type", "payload": {}}, # Unknown type {"type": "subscribe"}, # Missing payload {"type": "subscribe", "payload": {}}, # Missing required fields {"type": "subscribe", "payload": {"subscriptionIdentifier": "test"}}, # Missing URL ] mock_websocket.receive_json.side_effect = invalid_messages + [asyncio.CancelledError()] client_id = "test_client_error" try: await websocket_handler.handle_websocket(mock_websocket, client_id) except asyncio.CancelledError: pass # Verify error responses sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list] error_messages = [msg for msg in sent_messages if msg.get("type") == "error"] # Should receive error responses for invalid messages assert len(error_messages) >= len(invalid_messages) for error_msg in error_messages: assert "message" in error_msg assert error_msg["message"] # Non-empty error message @pytest.mark.asyncio async def test_session_management_protocol(self, mock_websocket): """Test session management protocol.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) session_id = str(uuid.uuid4()) # Test session messages set_session_message = { "type": "setSessionId", "payload": { "sessionId": session_id, "displayId": "display-001" } } patch_session_message = { "type": "patchSession", "payload": { "sessionId": session_id, "data": { "car_brand": "Toyota", "confidence": 0.92 } } } mock_websocket.receive_json.side_effect = [ set_session_message, patch_session_message, asyncio.CancelledError() ] client_id = "test_client_session" try: await websocket_handler.handle_websocket(mock_websocket, client_id) except asyncio.CancelledError: pass # Verify session responses sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list] # Should receive acknowledgments for session operations set_session_acks = [msg for msg in sent_messages if msg.get("type") == "setSessionIdAck"] assert len(set_session_acks) >= 1 patch_session_acks = [msg for msg in sent_messages if msg.get("type") == "patchSessionAck"] assert len(patch_session_acks) >= 1 @pytest.mark.asyncio async def test_heartbeat_protocol(self, mock_websocket): """Test heartbeat/ping-pong protocol.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor, {"heartbeat_interval": 0.1}) # Simulate long-running connection mock_websocket.receive_json.side_effect = [ asyncio.TimeoutError(), # No messages for a while asyncio.TimeoutError(), asyncio.CancelledError() # Then disconnect ] client_id = "test_client_heartbeat" # Start heartbeat task heartbeat_task = asyncio.create_task(websocket_handler._heartbeat_loop()) try: # Let heartbeat run briefly await asyncio.sleep(0.2) # Cancel heartbeat heartbeat_task.cancel() await heartbeat_task except asyncio.CancelledError: pass # Verify ping was sent assert mock_websocket.ping.called @pytest.mark.asyncio async def test_detection_result_protocol(self, mock_websocket): """Test detection result message protocol.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) # Mock a detection result being sent detection_result = { "type": "imageDetection", "payload": { "subscriptionId": "display-001;cam-001", "detections": [ { "class": "Car", "confidence": 0.95, "bbox": [100, 200, 300, 400], "trackId": 1001 } ], "timestamp": 1640995200000, "modelInfo": { "modelId": 101, "modelName": "Vehicle Detection" } } } # Send detection result to client await websocket_handler.send_to_client("test_client", detection_result) # Verify message was sent mock_websocket.send_json.assert_called_with(detection_result) @pytest.mark.asyncio async def test_error_recovery_protocol(self, mock_websocket): """Test error recovery and graceful degradation.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) # Simulate WebSocket errors mock_websocket.send_json.side_effect = [ None, # First message succeeds ConnectionError("Connection lost"), # Second fails None # Third succeeds after recovery ] # Try to send multiple messages messages = [ {"type": "test", "message": "1"}, {"type": "test", "message": "2"}, {"type": "test", "message": "3"} ] results = [] for msg in messages: try: await websocket_handler.send_to_client("test_client", msg) results.append("success") except Exception: results.append("error") # Should handle errors gracefully assert "error" in results # But should still be able to send other messages assert "success" in results @pytest.mark.asyncio async def test_concurrent_client_protocol(self, mock_websocket): """Test handling multiple concurrent clients.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) # Create multiple mock WebSocket connections mock_websockets = [] for i in range(3): ws = Mock(spec=WebSocket) ws.accept = AsyncMock() ws.send_json = AsyncMock() ws.receive_json = AsyncMock(side_effect=[asyncio.CancelledError()]) mock_websockets.append(ws) # Handle multiple clients concurrently client_tasks = [] for i, ws in enumerate(mock_websockets): task = asyncio.create_task( websocket_handler.handle_websocket(ws, f"client_{i}") ) client_tasks.append(task) # Wait briefly then cancel all await asyncio.sleep(0.1) for task in client_tasks: task.cancel() try: await asyncio.gather(*client_tasks) except asyncio.CancelledError: pass # Verify all connections were accepted for ws in mock_websockets: ws.accept.assert_called_once() @pytest.mark.asyncio async def test_subscription_sharing_protocol(self, mock_websocket): """Test shared subscription protocol.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) # Mock multiple clients subscribing to same camera with patch('cv2.VideoCapture') as mock_video_cap: mock_cap_instance = Mock() mock_video_cap.return_value = mock_cap_instance mock_cap_instance.isOpened.return_value = True # First client subscribes subscription_msg1 = { "type": "subscribe", "payload": { "subscriptionIdentifier": "display-001;cam-001", "rtspUrl": "rtsp://shared.example.com/stream", "modelUrl": "http://example.com/model.mpta" } } # Second client subscribes to same camera subscription_msg2 = { "type": "subscribe", "payload": { "subscriptionIdentifier": "display-002;cam-001", "rtspUrl": "rtsp://shared.example.com/stream", # Same URL "modelUrl": "http://example.com/model.mpta" } } # Mock file operations with patch('builtins.open') as mock_file_open: pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]} mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config) # Process both subscriptions response1 = await message_processor.process_message(subscription_msg1, "client_1") response2 = await message_processor.process_message(subscription_msg2, "client_2") # Both should succeed and reference same underlying stream assert response1.get("status") == "success" assert response2.get("status") == "success" # Should only create one video capture instance (shared stream) assert mock_video_cap.call_count == 1 @pytest.mark.asyncio async def test_message_ordering_protocol(self, mock_websocket): """Test message ordering and sequencing.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) # Test sequence of related messages messages = [ {"type": "subscribe", "payload": {"subscriptionIdentifier": "test", "rtspUrl": "rtsp://test.com"}}, {"type": "setSessionId", "payload": {"sessionId": "session_123", "displayId": "display_001"}}, {"type": "requestState"}, {"type": "patchSession", "payload": {"sessionId": "session_123", "data": {"test": "data"}}}, {"type": "unsubscribe", "payload": {"subscriptionIdentifier": "test"}} ] mock_websocket.receive_json.side_effect = messages + [asyncio.CancelledError()] with patch('cv2.VideoCapture') as mock_video_cap, \ patch('builtins.open') as mock_file_open: mock_cap_instance = Mock() mock_video_cap.return_value = mock_cap_instance mock_cap_instance.isOpened.return_value = True pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]} mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config) client_id = "test_client_ordering" try: await websocket_handler.handle_websocket(mock_websocket, client_id) except asyncio.CancelledError: pass # Verify responses were sent in correct order sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list] # Should receive responses for each message type response_types = [msg.get("type") for msg in sent_messages] expected_types = ["subscribeAck", "setSessionIdAck", "stateReport", "patchSessionAck", "unsubscribeAck"] # Check that we got appropriate responses (order may vary slightly) for expected_type in expected_types: assert any(expected_type in response_types), f"Missing response type: {expected_type}" class TestWebSocketPerformance: """Test WebSocket performance characteristics.""" @pytest.mark.asyncio async def test_message_throughput(self, mock_websocket): """Test message processing throughput.""" message_processor = MessageProcessor() # Prepare batch of simple messages state_request = {"type": "requestState"} import time start_time = time.time() # Process many messages quickly for _ in range(100): await message_processor.process_message(state_request, "test_client") end_time = time.time() processing_time = end_time - start_time # Should process messages quickly (less than 1 second for 100 messages) assert processing_time < 1.0 # Calculate throughput throughput = 100 / processing_time assert throughput > 100 # Should handle > 100 messages/second @pytest.mark.asyncio async def test_concurrent_message_handling(self, mock_websocket): """Test concurrent message handling.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) # Create multiple mock clients num_clients = 10 mock_websockets = [] for i in range(num_clients): ws = Mock(spec=WebSocket) ws.accept = AsyncMock() ws.send_json = AsyncMock() ws.receive_json = AsyncMock(side_effect=[ {"type": "requestState"}, asyncio.CancelledError() ]) mock_websockets.append(ws) # Handle all clients concurrently client_tasks = [] for i, ws in enumerate(mock_websockets): task = asyncio.create_task( websocket_handler.handle_websocket(ws, f"perf_client_{i}") ) client_tasks.append(task) start_time = time.time() # Wait for all to complete try: await asyncio.gather(*client_tasks) except asyncio.CancelledError: pass end_time = time.time() total_time = end_time - start_time # Should handle all clients efficiently assert total_time < 2.0 # Should complete in less than 2 seconds # All clients should have been accepted for ws in mock_websockets: ws.accept.assert_called_once() @pytest.mark.asyncio async def test_memory_usage_stability(self, mock_websocket): """Test memory usage remains stable.""" message_processor = MessageProcessor() websocket_handler = WebSocketHandler(message_processor) # Simulate many connection/disconnection cycles for cycle in range(10): # Create client client_id = f"memory_test_client_{cycle}" mock_websocket.receive_json.side_effect = [ {"type": "requestState"}, asyncio.CancelledError() ] try: await websocket_handler.handle_websocket(mock_websocket, client_id) except asyncio.CancelledError: pass # Reset mock for next cycle mock_websocket.reset_mock() mock_websocket.accept = AsyncMock() mock_websocket.send_json = AsyncMock() mock_websocket.receive_json = AsyncMock() # Connection manager should not accumulate stale connections stats = websocket_handler.get_connection_stats() assert stats["total_connections"] == 0 # All should be cleaned up