"""
Integration tests for WebSocket protocol compliance.

Exercises the complete WebSocket communication protocol to ensure
compatibility with existing clients and correct message handling.
"""

import asyncio
import json
import time
import uuid
from unittest.mock import Mock, AsyncMock, patch

import pytest
from fastapi.testclient import TestClient
from fastapi.websockets import WebSocket

from detector_worker.app import create_app
from detector_worker.communication.message_processor import MessageProcessor, MessageType
from detector_worker.communication.websocket_handler import WebSocketHandler
from detector_worker.core.exceptions import MessageProcessingError


@pytest.fixture
def test_app():
    """Provide a freshly constructed FastAPI application for a test."""
    application = create_app()
    return application


@pytest.fixture
def mock_websocket():
    """Provide a ``WebSocket``-spec'd mock with awaitable I/O methods.

    Each connection method used by these tests is replaced with an
    ``AsyncMock`` so a test can drive it (via ``side_effect``) and inspect
    its calls afterwards.
    """
    fake_ws = Mock(spec=WebSocket)
    awaitable_methods = (
        "accept",
        "send_json",
        "send_text",
        "receive_json",
        "receive_text",
        "close",
        "ping",
    )
    for method_name in awaitable_methods:
        setattr(fake_ws, method_name, AsyncMock())
    return fake_ws


class TestWebSocketProtocol:
    """Test WebSocket protocol compliance."""

    @staticmethod
    def _sent_json(websocket):
        """Return every payload passed to ``websocket.send_json``, in call order."""
        return [sent.args[0] for sent in websocket.send_json.call_args_list]

    @pytest.mark.asyncio
    async def test_subscription_message_protocol(self, mock_websocket):
        """Test subscription message handling protocol.

        Feeds subscribe -> requestState -> unsubscribe through
        ``handle_websocket`` and checks that each request produced its
        matching acknowledgement or state report.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Mock external dependencies
        with patch('cv2.VideoCapture') as mock_video_cap, \
                patch('torch.load') as mock_torch_load, \
                patch('builtins.open') as mock_file_open:

            # Setup video capture mock
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            # Setup model loading mock
            mock_torch_load.return_value = Mock()

            # Setup pipeline file mock: a file opened as a context manager and
            # read with .read() yields this JSON document.
            pipeline_config = {
                "modelId": "test_detection_model",
                "modelFile": "test_model.pt",
                "expectedClasses": ["Car"],
                "minConfidence": 0.8,
            }
            mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)

            # Test message sequence
            subscription_message = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                    "rtspUrl": "rtsp://example.com/stream",
                    "modelUrl": "http://example.com/model.mpta",
                    "modelId": 101,
                    "modelName": "Test Detection",
                    "cropX1": 0,
                    "cropY1": 0,
                    "cropX2": 640,
                    "cropY2": 480,
                },
            }
            request_state_message = {
                "type": "requestState",
            }
            unsubscribe_message = {
                "type": "unsubscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                },
            }

            # Mock WebSocket message sequence
            mock_websocket.receive_json.side_effect = [
                subscription_message,
                request_state_message,
                unsubscribe_message,
                asyncio.CancelledError(),  # Simulate client disconnect
            ]

            client_id = "test_client_001"

            try:
                # Handle WebSocket connection
                await websocket_handler.handle_websocket(mock_websocket, client_id)
            except asyncio.CancelledError:
                pass  # Expected when client disconnects

        # Verify protocol compliance
        mock_websocket.accept.assert_called_once()

        # Check sent messages
        sent_messages = self._sent_json(mock_websocket)

        # Should receive subscription acknowledgment
        subscribe_acks = [msg for msg in sent_messages if msg.get("type") == "subscribeAck"]
        assert len(subscribe_acks) >= 1

        subscribe_ack = subscribe_acks[0]
        assert subscribe_ack["status"] in ["success", "error"]
        if subscribe_ack["status"] == "success":
            assert "subscriptionId" in subscribe_ack

        # Should receive state report
        state_reports = [msg for msg in sent_messages if msg.get("type") == "stateReport"]
        assert len(state_reports) >= 1

        state_report = state_reports[0]
        assert "payload" in state_report
        assert "subscriptions" in state_report["payload"]
        assert "system" in state_report["payload"]

        # Should receive unsubscribe acknowledgment
        unsubscribe_acks = [msg for msg in sent_messages if msg.get("type") == "unsubscribeAck"]
        assert len(unsubscribe_acks) >= 1

        unsubscribe_ack = unsubscribe_acks[0]
        assert unsubscribe_ack["status"] in ["success", "error"]

    @pytest.mark.asyncio
    async def test_invalid_message_handling(self, mock_websocket):
        """Test handling of invalid messages.

        Every malformed message should yield at least one ``error`` reply
        carrying a non-empty ``message``.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Test invalid message types
        invalid_messages = [
            {"invalid": "message"},  # Missing type
            {"type": "unknown_type", "payload": {}},  # Unknown type
            {"type": "subscribe"},  # Missing payload
            {"type": "subscribe", "payload": {}},  # Missing required fields
            {"type": "subscribe", "payload": {"subscriptionIdentifier": "test"}},  # Missing URL
        ]

        mock_websocket.receive_json.side_effect = invalid_messages + [asyncio.CancelledError()]

        client_id = "test_client_error"

        try:
            await websocket_handler.handle_websocket(mock_websocket, client_id)
        except asyncio.CancelledError:
            pass

        # Verify error responses
        sent_messages = self._sent_json(mock_websocket)
        error_messages = [msg for msg in sent_messages if msg.get("type") == "error"]

        # Should receive error responses for invalid messages
        assert len(error_messages) >= len(invalid_messages)

        for error_msg in error_messages:
            assert "message" in error_msg
            assert error_msg["message"]  # Non-empty error message

    @pytest.mark.asyncio
    async def test_session_management_protocol(self, mock_websocket):
        """Test session management protocol.

        ``setSessionId`` and ``patchSession`` requests should each be
        acknowledged.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        session_id = str(uuid.uuid4())

        # Test session messages
        set_session_message = {
            "type": "setSessionId",
            "payload": {
                "sessionId": session_id,
                "displayId": "display-001",
            },
        }

        patch_session_message = {
            "type": "patchSession",
            "payload": {
                "sessionId": session_id,
                "data": {
                    "car_brand": "Toyota",
                    "confidence": 0.92,
                },
            },
        }

        mock_websocket.receive_json.side_effect = [
            set_session_message,
            patch_session_message,
            asyncio.CancelledError(),
        ]

        client_id = "test_client_session"

        try:
            await websocket_handler.handle_websocket(mock_websocket, client_id)
        except asyncio.CancelledError:
            pass

        # Verify session responses
        sent_messages = self._sent_json(mock_websocket)

        # Should receive acknowledgments for session operations
        set_session_acks = [msg for msg in sent_messages
                            if msg.get("type") == "setSessionIdAck"]
        assert len(set_session_acks) >= 1

        patch_session_acks = [msg for msg in sent_messages
                              if msg.get("type") == "patchSessionAck"]
        assert len(patch_session_acks) >= 1

    @pytest.mark.asyncio
    async def test_heartbeat_protocol(self, mock_websocket):
        """Test heartbeat/ping-pong protocol.

        Runs the handler's heartbeat loop for about two 0.1 s intervals and
        checks that a ping was sent.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor, {"heartbeat_interval": 0.1})

        # Simulate long-running connection
        mock_websocket.receive_json.side_effect = [
            asyncio.TimeoutError(),  # No messages for a while
            asyncio.TimeoutError(),
            asyncio.CancelledError(),  # Then disconnect
        ]

        # NOTE(review): mock_websocket is never registered with the handler in
        # this test, so _heartbeat_loop() can only ping it if the handler finds
        # it some other way -- confirm against WebSocketHandler.
        heartbeat_task = asyncio.create_task(websocket_handler._heartbeat_loop())

        try:
            # Let heartbeat run briefly
            await asyncio.sleep(0.2)
        finally:
            # Always stop the background loop, even if the sleep was
            # interrupted, so the task does not outlive the test.
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass

        # Verify ping was sent
        assert mock_websocket.ping.called

    @pytest.mark.asyncio
    async def test_detection_result_protocol(self, mock_websocket):
        """Test detection result message protocol.

        An ``imageDetection`` message pushed via ``send_to_client`` should be
        forwarded unchanged through ``send_json``.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Mock a detection result being sent
        detection_result = {
            "type": "imageDetection",
            "payload": {
                "subscriptionId": "display-001;cam-001",
                "detections": [
                    {
                        "class": "Car",
                        "confidence": 0.95,
                        "bbox": [100, 200, 300, 400],
                        "trackId": 1001,
                    },
                ],
                "timestamp": 1640995200000,
                "modelInfo": {
                    "modelId": 101,
                    "modelName": "Vehicle Detection",
                },
            },
        }

        # NOTE(review): "test_client" is not connected to mock_websocket in
        # this test; the assertion below only holds if the handler resolves
        # that client id to this mock -- confirm against WebSocketHandler.
        await websocket_handler.send_to_client("test_client", detection_result)

        # Verify message was sent
        mock_websocket.send_json.assert_called_with(detection_result)

    @pytest.mark.asyncio
    async def test_error_recovery_protocol(self, mock_websocket):
        """Test error recovery and graceful degradation.

        One failing ``send_json`` call must not prevent later sends from
        succeeding.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Simulate WebSocket errors
        mock_websocket.send_json.side_effect = [
            None,  # First message succeeds
            ConnectionError("Connection lost"),  # Second fails
            None,  # Third succeeds after recovery
        ]

        # Try to send multiple messages
        messages = [
            {"type": "test", "message": "1"},
            {"type": "test", "message": "2"},
            {"type": "test", "message": "3"},
        ]

        # NOTE(review): as in test_detection_result_protocol, "test_client"
        # is not explicitly bound to mock_websocket here -- confirm.
        results = []
        for msg in messages:
            try:
                await websocket_handler.send_to_client("test_client", msg)
                results.append("success")
            except Exception:
                results.append("error")

        # Should handle errors gracefully
        assert "error" in results
        # But should still be able to send other messages
        assert "success" in results

    @pytest.mark.asyncio
    async def test_concurrent_client_protocol(self, mock_websocket):
        """Test handling multiple concurrent clients.

        Every concurrently handled connection must be accepted exactly once.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Create multiple mock WebSocket connections
        mock_websockets = []
        for _ in range(3):
            ws = Mock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            ws.receive_json = AsyncMock(side_effect=[asyncio.CancelledError()])
            mock_websockets.append(ws)

        # Handle multiple clients concurrently
        client_tasks = [
            asyncio.create_task(websocket_handler.handle_websocket(ws, f"client_{i}"))
            for i, ws in enumerate(mock_websockets)
        ]

        # Wait briefly then cancel all
        await asyncio.sleep(0.1)
        for task in client_tasks:
            task.cancel()

        # return_exceptions=True waits for every task instead of stopping at
        # the first cancelled one, and keeps each outcome for inspection.
        results = await asyncio.gather(*client_tasks, return_exceptions=True)

        # Disconnects/cancellations surface as CancelledError (a
        # BaseException); anything derived from Exception is a real failure.
        failures = [result for result in results if isinstance(result, Exception)]
        assert not failures, f"Client handlers raised: {failures!r}"

        # Verify all connections were accepted
        for ws in mock_websockets:
            ws.accept.assert_called_once()

    @pytest.mark.asyncio
    async def test_subscription_sharing_protocol(self, mock_websocket):
        """Test shared subscription protocol.

        Two subscriptions to the same RTSP URL should both succeed while
        opening only one ``cv2.VideoCapture``.
        """
        message_processor = MessageProcessor()
        # The handler is not driven directly here; messages go straight to
        # the processor below.
        websocket_handler = WebSocketHandler(message_processor)

        # Mock multiple clients subscribing to same camera
        with patch('cv2.VideoCapture') as mock_video_cap:
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            # First client subscribes
            subscription_msg1 = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                    "rtspUrl": "rtsp://shared.example.com/stream",
                    "modelUrl": "http://example.com/model.mpta",
                },
            }

            # Second client subscribes to same camera
            subscription_msg2 = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-002;cam-001",
                    "rtspUrl": "rtsp://shared.example.com/stream",  # Same URL
                    "modelUrl": "http://example.com/model.mpta",
                },
            }

            # Mock file operations
            with patch('builtins.open') as mock_file_open:
                pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
                mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)

                # Process both subscriptions
                response1 = await message_processor.process_message(subscription_msg1, "client_1")
                response2 = await message_processor.process_message(subscription_msg2, "client_2")

            # Both should succeed and reference same underlying stream
            assert response1.get("status") == "success"
            assert response2.get("status") == "success"

            # Should only create one video capture instance (shared stream)
            assert mock_video_cap.call_count == 1

    @pytest.mark.asyncio
    async def test_message_ordering_protocol(self, mock_websocket):
        """Test message ordering and sequencing.

        Sends a mixed sequence of requests and checks that a response of each
        expected type was sent back; relative order is not asserted.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Test sequence of related messages
        messages = [
            {"type": "subscribe", "payload": {"subscriptionIdentifier": "test", "rtspUrl": "rtsp://test.com"}},
            {"type": "setSessionId", "payload": {"sessionId": "session_123", "displayId": "display_001"}},
            {"type": "requestState"},
            {"type": "patchSession", "payload": {"sessionId": "session_123", "data": {"test": "data"}}},
            {"type": "unsubscribe", "payload": {"subscriptionIdentifier": "test"}},
        ]

        mock_websocket.receive_json.side_effect = messages + [asyncio.CancelledError()]

        with patch('cv2.VideoCapture') as mock_video_cap, \
                patch('builtins.open') as mock_file_open:

            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
            mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)

            client_id = "test_client_ordering"

            try:
                await websocket_handler.handle_websocket(mock_websocket, client_id)
            except asyncio.CancelledError:
                pass

        # Collect the type of every response that was sent
        response_types = [msg.get("type") for msg in self._sent_json(mock_websocket)]

        expected_types = ["subscribeAck", "setSessionIdAck", "stateReport", "patchSessionAck", "unsubscribeAck"]

        # Check that we got appropriate responses (order may vary slightly,
        # so only presence is checked). A plain membership test is used:
        # any() requires an iterable, not a single bool.
        for expected_type in expected_types:
            assert expected_type in response_types, f"Missing response type: {expected_type}"


class TestWebSocketPerformance:
    """Test WebSocket performance characteristics."""

    @pytest.mark.asyncio
    async def test_message_throughput(self, mock_websocket):
        """Test message processing throughput.

        Processes a batch of ``requestState`` messages sequentially and
        checks that the whole batch finishes in under one second.
        """
        message_processor = MessageProcessor()

        # Prepare batch of simple messages
        num_messages = 100
        state_request = {"type": "requestState"}

        # perf_counter() is monotonic and high-resolution, so elapsed times
        # are not skewed by wall-clock adjustments or coarse clock ticks.
        start_time = time.perf_counter()

        # Process many messages quickly
        for _ in range(num_messages):
            await message_processor.process_message(state_request, "test_client")

        processing_time = time.perf_counter() - start_time

        # Should process messages quickly (less than 1 second for 100 messages)
        assert processing_time < 1.0

        # Calculate throughput; the floor avoids dividing by a zero elapsed time.
        throughput = num_messages / max(processing_time, 1e-9)
        assert throughput > 100  # Should handle > 100 messages/second

    @pytest.mark.asyncio
    async def test_concurrent_message_handling(self, mock_websocket):
        """Test concurrent message handling.

        Ten clients each send one ``requestState`` and then disconnect; all
        must be accepted and handled within two seconds.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Create multiple mock clients
        num_clients = 10
        mock_websockets = []

        for _ in range(num_clients):
            ws = Mock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            ws.receive_json = AsyncMock(side_effect=[
                {"type": "requestState"},
                asyncio.CancelledError(),
            ])
            mock_websockets.append(ws)

        # Start timing before the tasks exist so the measurement covers all
        # of their work.
        start_time = time.perf_counter()

        # Handle all clients concurrently
        client_tasks = [
            asyncio.create_task(websocket_handler.handle_websocket(ws, f"perf_client_{i}"))
            for i, ws in enumerate(mock_websockets)
        ]

        # Wait for all to complete. return_exceptions=True waits for every
        # client instead of returning at the first task that ends with the
        # simulated-disconnect CancelledError.
        results = await asyncio.gather(*client_tasks, return_exceptions=True)

        total_time = time.perf_counter() - start_time

        # Simulated disconnects surface as CancelledError (a BaseException);
        # anything derived from Exception is a genuine handler failure.
        failures = [result for result in results if isinstance(result, Exception)]
        assert not failures, f"Client handlers raised: {failures!r}"

        # Should handle all clients efficiently
        assert total_time < 2.0  # Should complete in less than 2 seconds

        # All clients should have been accepted
        for ws in mock_websockets:
            ws.accept.assert_called_once()

    @pytest.mark.asyncio
    async def test_memory_usage_stability(self, mock_websocket):
        """Test memory usage remains stable.

        Repeatedly connects and disconnects a client through the same handler,
        then checks the handler reports no remaining connections.
        """
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Simulate many connection/disconnection cycles
        for cycle in range(10):
            # Create client
            client_id = f"memory_test_client_{cycle}"

            mock_websocket.receive_json.side_effect = [
                {"type": "requestState"},
                asyncio.CancelledError(),
            ]

            try:
                await websocket_handler.handle_websocket(mock_websocket, client_id)
            except asyncio.CancelledError:
                pass

            # Reset mock for next cycle. Fresh AsyncMocks also discard the
            # side_effect consumed during this cycle.
            mock_websocket.reset_mock()
            mock_websocket.accept = AsyncMock()
            mock_websocket.send_json = AsyncMock()
            mock_websocket.receive_json = AsyncMock()

        # Connection manager should not accumulate stale connections.
        # NOTE(review): assumes "total_connections" counts currently open
        # connections rather than a cumulative total -- confirm against
        # WebSocketHandler.get_connection_stats().
        stats = websocket_handler.get_connection_stats()
        assert stats["total_connections"] == 0  # All should be cleaned up