python-detector-worker/tests/integration/test_websocket_protocol.py
2025-09-12 18:55:23 +07:00

579 lines
No EOL
21 KiB
Python

"""
Integration tests for WebSocket protocol compliance.
Tests the complete WebSocket communication protocol to ensure
compatibility with existing clients and proper message handling.
"""
import pytest
import asyncio
import json
import uuid
from unittest.mock import Mock, AsyncMock, patch
from fastapi.websockets import WebSocket
from fastapi.testclient import TestClient
from detector_worker.app import create_app
from detector_worker.communication.websocket_handler import WebSocketHandler
from detector_worker.communication.message_processor import MessageProcessor, MessageType
from detector_worker.core.exceptions import MessageProcessingError
@pytest.fixture
def test_app():
    """Provide a FastAPI application built by ``create_app`` for a test."""
    application = create_app()
    return application
@pytest.fixture
def mock_websocket():
    """Provide a ``WebSocket``-spec'd mock with awaitable I/O methods.

    Each of ``accept``, ``send_json``, ``send_text``, ``receive_json``,
    ``receive_text``, ``close`` and ``ping`` is replaced by its own
    ``AsyncMock`` so tests can await them and inspect the recorded calls.
    """
    socket = Mock(spec=WebSocket)
    awaitable_methods = (
        "accept",
        "send_json",
        "send_text",
        "receive_json",
        "receive_text",
        "close",
        "ping",
    )
    for method_name in awaitable_methods:
        setattr(socket, method_name, AsyncMock())
    return socket
class TestWebSocketProtocol:
    """Test WebSocket protocol compliance.

    Every test builds its own ``MessageProcessor`` / ``WebSocketHandler`` pair
    and drives it with mocked sockets. A scripted ``asyncio.CancelledError``
    at the end of a ``receive_json`` sequence stands in for the client
    disconnecting.
    """

    @staticmethod
    def _sent_messages(websocket):
        """Return the first positional argument of each ``send_json`` call, in call order."""
        return [call.args[0] for call in websocket.send_json.call_args_list]

    @pytest.mark.asyncio
    async def test_subscription_message_protocol(self, mock_websocket):
        """Test subscription message handling protocol.

        Scripts subscribe -> requestState -> unsubscribe and checks that a
        ``subscribeAck``, a ``stateReport`` and an ``unsubscribeAck`` were sent.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Test message sequence
        subscription_message = {
            "type": "subscribe",
            "payload": {
                "subscriptionIdentifier": "display-001;cam-001",
                "rtspUrl": "rtsp://example.com/stream",
                "modelUrl": "http://example.com/model.mpta",
                "modelId": 101,
                "modelName": "Test Detection",
                "cropX1": 0,
                "cropY1": 0,
                "cropX2": 640,
                "cropY2": 480,
            },
        }
        request_state_message = {"type": "requestState"}
        unsubscribe_message = {
            "type": "unsubscribe",
            "payload": {"subscriptionIdentifier": "display-001;cam-001"},
        }

        # Mock external dependencies (video capture, model loading, file reads)
        with patch('cv2.VideoCapture') as mock_video_cap, \
                patch('torch.load') as mock_torch_load, \
                patch('builtins.open') as mock_file_open:
            # Setup video capture mock
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            # Setup model loading mock
            mock_torch_load.return_value = Mock()

            # Setup pipeline file mock: any ``with open(...)`` read yields this JSON
            pipeline_config = {
                "modelId": "test_detection_model",
                "modelFile": "test_model.pt",
                "expectedClasses": ["Car"],
                "minConfidence": 0.8,
            }
            mock_file_open.return_value.__enter__.return_value.read.return_value = (
                json.dumps(pipeline_config)
            )

            # Mock WebSocket message sequence
            mock_websocket.receive_json.side_effect = [
                subscription_message,
                request_state_message,
                unsubscribe_message,
                asyncio.CancelledError(),  # Simulate client disconnect
            ]

            client_id = "test_client_001"
            try:
                # Handle WebSocket connection
                await websocket_handler.handle_websocket(mock_websocket, client_id)
            except asyncio.CancelledError:
                pass  # Expected when client disconnects

        # Verify protocol compliance
        mock_websocket.accept.assert_called_once()

        # Check sent messages
        sent_messages = self._sent_messages(mock_websocket)

        # Should receive subscription acknowledgment
        subscribe_acks = [msg for msg in sent_messages if msg.get("type") == "subscribeAck"]
        assert len(subscribe_acks) >= 1
        subscribe_ack = subscribe_acks[0]
        assert subscribe_ack["status"] in ["success", "error"]
        if subscribe_ack["status"] == "success":
            assert "subscriptionId" in subscribe_ack

        # Should receive state report
        state_reports = [msg for msg in sent_messages if msg.get("type") == "stateReport"]
        assert len(state_reports) >= 1
        state_report = state_reports[0]
        assert "payload" in state_report
        assert "subscriptions" in state_report["payload"]
        assert "system" in state_report["payload"]

        # Should receive unsubscribe acknowledgment
        unsubscribe_acks = [msg for msg in sent_messages if msg.get("type") == "unsubscribeAck"]
        assert len(unsubscribe_acks) >= 1
        unsubscribe_ack = unsubscribe_acks[0]
        assert unsubscribe_ack["status"] in ["success", "error"]

    @pytest.mark.asyncio
    async def test_invalid_message_handling(self, mock_websocket):
        """Test handling of invalid messages.

        Each malformed message is expected to produce at least one ``error``
        response carrying a non-empty ``message`` field.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Test invalid message types
        invalid_messages = [
            {"invalid": "message"},  # Missing type
            {"type": "unknown_type", "payload": {}},  # Unknown type
            {"type": "subscribe"},  # Missing payload
            {"type": "subscribe", "payload": {}},  # Missing required fields
            {"type": "subscribe", "payload": {"subscriptionIdentifier": "test"}},  # Missing URL
        ]
        mock_websocket.receive_json.side_effect = invalid_messages + [asyncio.CancelledError()]

        client_id = "test_client_error"
        try:
            await websocket_handler.handle_websocket(mock_websocket, client_id)
        except asyncio.CancelledError:
            pass

        # Verify error responses
        sent_messages = self._sent_messages(mock_websocket)
        error_messages = [msg for msg in sent_messages if msg.get("type") == "error"]

        # Should receive error responses for invalid messages
        assert len(error_messages) >= len(invalid_messages)
        for error_msg in error_messages:
            assert "message" in error_msg
            assert error_msg["message"]  # Non-empty error message

    @pytest.mark.asyncio
    async def test_session_management_protocol(self, mock_websocket):
        """Test session management protocol.

        Sends ``setSessionId`` then ``patchSession`` for the same generated
        session id and expects one acknowledgment of each kind.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)
        session_id = str(uuid.uuid4())

        # Test session messages
        set_session_message = {
            "type": "setSessionId",
            "payload": {
                "sessionId": session_id,
                "displayId": "display-001",
            },
        }
        patch_session_message = {
            "type": "patchSession",
            "payload": {
                "sessionId": session_id,
                "data": {
                    "car_brand": "Toyota",
                    "confidence": 0.92,
                },
            },
        }
        mock_websocket.receive_json.side_effect = [
            set_session_message,
            patch_session_message,
            asyncio.CancelledError(),
        ]

        client_id = "test_client_session"
        try:
            await websocket_handler.handle_websocket(mock_websocket, client_id)
        except asyncio.CancelledError:
            pass

        # Verify session responses
        sent_messages = self._sent_messages(mock_websocket)

        # Should receive acknowledgments for session operations
        set_session_acks = [msg for msg in sent_messages
                            if msg.get("type") == "setSessionIdAck"]
        assert len(set_session_acks) >= 1
        patch_session_acks = [msg for msg in sent_messages
                              if msg.get("type") == "patchSessionAck"]
        assert len(patch_session_acks) >= 1

    @pytest.mark.asyncio
    async def test_heartbeat_protocol(self, mock_websocket):
        """Test heartbeat/ping-pong protocol.

        Runs the handler's heartbeat loop for 0.2 s with a 0.1 s interval and
        expects at least one ``ping`` on the mock socket.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor, {"heartbeat_interval": 0.1})

        # Simulate long-running connection
        mock_websocket.receive_json.side_effect = [
            asyncio.TimeoutError(),  # No messages for a while
            asyncio.TimeoutError(),
            asyncio.CancelledError(),  # Then disconnect
        ]

        # NOTE(review): mock_websocket is never visibly registered with the
        # handler here; confirm _heartbeat_loop can actually reach it.
        heartbeat_task = asyncio.create_task(websocket_handler._heartbeat_loop())
        try:
            # Let heartbeat run briefly
            await asyncio.sleep(0.2)
            # Cancel heartbeat and wait for the cancellation to land
            heartbeat_task.cancel()
            await heartbeat_task
        except asyncio.CancelledError:
            pass

        # Verify ping was sent
        assert mock_websocket.ping.called

    @pytest.mark.asyncio
    async def test_detection_result_protocol(self, mock_websocket):
        """Test detection result message protocol.

        Pushes an ``imageDetection`` message through ``send_to_client`` and
        expects it to reach ``send_json`` unchanged.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Mock a detection result being sent
        detection_result = {
            "type": "imageDetection",
            "payload": {
                "subscriptionId": "display-001;cam-001",
                "detections": [
                    {
                        "class": "Car",
                        "confidence": 0.95,
                        "bbox": [100, 200, 300, 400],
                        "trackId": 1001,
                    }
                ],
                "timestamp": 1640995200000,
                "modelInfo": {
                    "modelId": 101,
                    "modelName": "Vehicle Detection",
                },
            },
        }

        # NOTE(review): no connection for "test_client" is visibly registered
        # with the handler; confirm send_to_client routes to mock_websocket.
        await websocket_handler.send_to_client("test_client", detection_result)

        # Verify message was sent
        mock_websocket.send_json.assert_called_with(detection_result)

    @pytest.mark.asyncio
    async def test_error_recovery_protocol(self, mock_websocket):
        """Test error recovery and graceful degradation.

        ``send_json`` is scripted to succeed, fail with ``ConnectionError``,
        then succeed again; the test expects both outcomes to be observed.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Simulate WebSocket errors
        mock_websocket.send_json.side_effect = [
            None,  # First message succeeds
            ConnectionError("Connection lost"),  # Second fails
            None,  # Third succeeds after recovery
        ]

        # Try to send multiple messages
        messages = [
            {"type": "test", "message": "1"},
            {"type": "test", "message": "2"},
            {"type": "test", "message": "3"},
        ]
        results = []
        for msg in messages:
            try:
                # NOTE(review): as above, "test_client" is not visibly
                # registered; confirm these sends reach mock_websocket.
                await websocket_handler.send_to_client("test_client", msg)
                results.append("success")
            except Exception:
                # Deliberately broad: any failure is recorded, not raised.
                results.append("error")

        # Should handle errors gracefully
        assert "error" in results
        # But should still be able to send other messages
        assert "success" in results

    @pytest.mark.asyncio
    async def test_concurrent_client_protocol(self, mock_websocket):
        """Test handling multiple concurrent clients.

        Three independent mock sockets are handled as concurrent tasks; each
        must have been accepted exactly once.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Create multiple mock WebSocket connections
        mock_websockets = []
        for _ in range(3):
            ws = Mock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            ws.receive_json = AsyncMock(side_effect=[asyncio.CancelledError()])
            mock_websockets.append(ws)

        # Handle multiple clients concurrently
        client_tasks = [
            asyncio.create_task(websocket_handler.handle_websocket(ws, f"client_{i}"))
            for i, ws in enumerate(mock_websockets)
        ]

        # Wait briefly then cancel all
        await asyncio.sleep(0.1)
        for task in client_tasks:
            task.cancel()
        try:
            await asyncio.gather(*client_tasks)
        except asyncio.CancelledError:
            pass

        # Verify all connections were accepted
        for ws in mock_websockets:
            ws.accept.assert_called_once()

    @pytest.mark.asyncio
    async def test_subscription_sharing_protocol(self, mock_websocket):
        """Test shared subscription protocol.

        Two clients subscribe to the same RTSP URL; both subscriptions must
        succeed while ``cv2.VideoCapture`` is constructed only once.
        """
        message_processor = MessageProcessor()
        # Not used below; kept in case construction has side effects on the
        # processor. TODO confirm whether it can be dropped.
        websocket_handler = WebSocketHandler(message_processor)

        # First client subscribes
        subscription_msg1 = {
            "type": "subscribe",
            "payload": {
                "subscriptionIdentifier": "display-001;cam-001",
                "rtspUrl": "rtsp://shared.example.com/stream",
                "modelUrl": "http://example.com/model.mpta",
            },
        }
        # Second client subscribes to same camera
        subscription_msg2 = {
            "type": "subscribe",
            "payload": {
                "subscriptionIdentifier": "display-002;cam-001",
                "rtspUrl": "rtsp://shared.example.com/stream",  # Same URL
                "modelUrl": "http://example.com/model.mpta",
            },
        }

        # Mock video capture and file operations
        with patch('cv2.VideoCapture') as mock_video_cap, \
                patch('builtins.open') as mock_file_open:
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
            mock_file_open.return_value.__enter__.return_value.read.return_value = (
                json.dumps(pipeline_config)
            )

            # Process both subscriptions
            response1 = await message_processor.process_message(subscription_msg1, "client_1")
            response2 = await message_processor.process_message(subscription_msg2, "client_2")

            # Both should succeed and reference same underlying stream
            assert response1.get("status") == "success"
            assert response2.get("status") == "success"

            # Should only create one video capture instance (shared stream)
            assert mock_video_cap.call_count == 1

    @pytest.mark.asyncio
    async def test_message_ordering_protocol(self, mock_websocket):
        """Test message ordering and sequencing.

        Sends five related messages and verifies that a response of every
        expected type was emitted. Only presence is checked, not position.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Test sequence of related messages
        messages = [
            {"type": "subscribe", "payload": {"subscriptionIdentifier": "test", "rtspUrl": "rtsp://test.com"}},
            {"type": "setSessionId", "payload": {"sessionId": "session_123", "displayId": "display_001"}},
            {"type": "requestState"},
            {"type": "patchSession", "payload": {"sessionId": "session_123", "data": {"test": "data"}}},
            {"type": "unsubscribe", "payload": {"subscriptionIdentifier": "test"}},
        ]
        mock_websocket.receive_json.side_effect = messages + [asyncio.CancelledError()]

        with patch('cv2.VideoCapture') as mock_video_cap, \
                patch('builtins.open') as mock_file_open:
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
            mock_file_open.return_value.__enter__.return_value.read.return_value = (
                json.dumps(pipeline_config)
            )

            client_id = "test_client_ordering"
            try:
                await websocket_handler.handle_websocket(mock_websocket, client_id)
            except asyncio.CancelledError:
                pass

        # Collect the types of everything that was sent back
        sent_messages = self._sent_messages(mock_websocket)
        response_types = [msg.get("type") for msg in sent_messages]
        expected_types = ["subscribeAck", "setSessionIdAck", "stateReport", "patchSessionAck", "unsubscribeAck"]

        # Check that we got appropriate responses (order may vary slightly).
        # Plain membership test: wrapping the boolean in any() would raise
        # TypeError because a bool is not iterable.
        for expected_type in expected_types:
            assert expected_type in response_types, f"Missing response type: {expected_type}"
class TestWebSocketPerformance:
    """Test WebSocket performance characteristics.

    Elapsed time is measured with ``time.perf_counter`` (monotonic), and each
    timing test imports ``time`` locally so it does not rely on another
    test's import.
    """

    @pytest.mark.asyncio
    async def test_message_throughput(self, mock_websocket):
        """Test message processing throughput.

        Processes 100 ``requestState`` messages back to back and requires the
        batch to finish in under one second (more than 100 messages/second).
        """
        import time

        message_processor = MessageProcessor()

        # Prepare batch of simple messages
        state_request = {"type": "requestState"}
        message_count = 100

        start_time = time.perf_counter()
        # Process many messages quickly
        for _ in range(message_count):
            await message_processor.process_message(state_request, "test_client")
        processing_time = time.perf_counter() - start_time

        # Should process messages quickly (less than 1 second for 100 messages)
        assert processing_time < 1.0

        # Calculate throughput; guard against a zero-length interval so the
        # division cannot raise ZeroDivisionError on a coarse clock.
        throughput = message_count / processing_time if processing_time > 0 else float("inf")
        assert throughput > 100  # Should handle > 100 messages/second

    @pytest.mark.asyncio
    async def test_concurrent_message_handling(self, mock_websocket):
        """Test concurrent message handling.

        Ten mock clients each send one ``requestState`` and then disconnect;
        all must be accepted once and the whole run must take under 2 seconds.
        """
        # Local import: without it ``time`` is undefined in this test, since
        # the module does not import it at top level.
        import time

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Create multiple mock clients
        num_clients = 10
        mock_websockets = []
        for _ in range(num_clients):
            ws = Mock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            ws.receive_json = AsyncMock(side_effect=[
                {"type": "requestState"},
                asyncio.CancelledError(),
            ])
            mock_websockets.append(ws)

        # Handle all clients concurrently
        client_tasks = [
            asyncio.create_task(websocket_handler.handle_websocket(ws, f"perf_client_{i}"))
            for i, ws in enumerate(mock_websockets)
        ]

        start_time = time.perf_counter()
        # Wait for all to complete
        try:
            await asyncio.gather(*client_tasks)
        except asyncio.CancelledError:
            pass
        total_time = time.perf_counter() - start_time

        # Should handle all clients efficiently
        assert total_time < 2.0  # Should complete in less than 2 seconds

        # All clients should have been accepted
        for ws in mock_websockets:
            ws.accept.assert_called_once()

    @pytest.mark.asyncio
    async def test_memory_usage_stability(self, mock_websocket):
        """Test memory usage remains stable.

        Runs ten connect/disconnect cycles through one handler and then
        expects ``get_connection_stats()["total_connections"]`` to be 0.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Simulate many connection/disconnection cycles
        for cycle in range(10):
            # Create client
            client_id = f"memory_test_client_{cycle}"
            mock_websocket.receive_json.side_effect = [
                {"type": "requestState"},
                asyncio.CancelledError(),
            ]
            try:
                await websocket_handler.handle_websocket(mock_websocket, client_id)
            except asyncio.CancelledError:
                pass

            # Reset mock for next cycle; fresh AsyncMocks replace the ones
            # whose scripted side effects were consumed above.
            mock_websocket.reset_mock()
            mock_websocket.accept = AsyncMock()
            mock_websocket.send_json = AsyncMock()
            mock_websocket.receive_json = AsyncMock()

        # Connection manager should not accumulate stale connections
        stats = websocket_handler.get_connection_stats()
        assert stats["total_connections"] == 0  # All should be cleaned up