Refactor: PHASE 8: Testing & Integration

This commit is contained in:
ziesorx 2025-09-12 18:55:23 +07:00
parent af34f4fd08
commit 9e8c6804a7
32 changed files with 17128 additions and 0 deletions

View file

@ -0,0 +1,19 @@
"""
Integration tests for the detector worker application.
This package contains integration tests that verify the interaction
between multiple components and end-to-end workflows.
"""
# Integration test modules
from . import (
test_complete_detection_workflow,
test_websocket_protocol,
test_pipeline_integration
)
__all__ = [
"test_complete_detection_workflow",
"test_websocket_protocol",
"test_pipeline_integration"
]

View file

@ -0,0 +1,681 @@
"""
Integration tests for complete detection workflow.
This module tests the full end-to-end detection pipeline from stream
to database update, ensuring all components work together correctly.
"""
import pytest
import asyncio
import uuid
import time
import json
import tempfile
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from pathlib import Path
import numpy as np
import cv2
from detector_worker.app import create_app
from detector_worker.core.config import Configuration
from detector_worker.core.dependency_injection import ServiceContainer
from detector_worker.streams.stream_manager import StreamManager, StreamConfig
from detector_worker.models.model_manager import ModelManager
from detector_worker.pipeline.pipeline_executor import PipelineExecutor
from detector_worker.storage.database_manager import DatabaseManager
from detector_worker.storage.redis_client import RedisClient, RedisConfig
from detector_worker.communication.websocket_handler import WebSocketHandler
from detector_worker.communication.message_processor import MessageProcessor
@pytest.fixture
def temp_config_file():
"""Create temporary configuration file."""
config_data = {
"poll_interval_ms": 100,
"max_streams": 5,
"target_fps": 10,
"reconnect_interval_sec": 1,
"max_retries": 2,
"database": {
"enabled": True,
"host": "localhost",
"port": 5432,
"database": "test_gas_station_1",
"user": "test_user",
"password": "test_pass"
},
"redis": {
"enabled": True,
"host": "localhost",
"port": 6379,
"db": 0
}
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(config_data, f)
temp_path = f.name
yield temp_path
# Cleanup
Path(temp_path).unlink(missing_ok=True)
@pytest.fixture
def mock_frame():
"""Create a mock frame for testing."""
return np.ones((480, 640, 3), dtype=np.uint8) * 128
@pytest.fixture
def sample_mpta_file():
"""Create a sample MPTA (pipeline) file."""
pipeline_config = {
"modelId": "car_frontal_detection_v1",
"modelFile": "car_frontal_detection_v1.pt",
"multiClass": True,
"expectedClasses": ["Car", "Frontal"],
"triggerClasses": ["Car", "Frontal"],
"minConfidence": 0.8,
"actions": [
{
"type": "redis_save_image",
"region": "Frontal",
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
"expire_seconds": 600
},
{
"type": "postgresql_create_record",
"table": "car_frontal_info",
"fields": {
"display_id": "{display_id}",
"captured_timestamp": "{timestamp}",
"session_id": "{session_id}",
"license_character": None,
"license_type": "No model available"
}
}
],
"branches": [
{
"modelId": "car_brand_cls_v1",
"modelFile": "car_brand_cls_v1.pt",
"parallel": True,
"crop": True,
"cropClass": "Frontal",
"triggerClasses": ["Frontal"],
"minConfidence": 0.85
},
{
"modelId": "car_bodytype_cls_v1",
"modelFile": "car_bodytype_cls_v1.pt",
"parallel": True,
"crop": True,
"cropClass": "Frontal",
"triggerClasses": ["Frontal"],
"minConfidence": 0.80
}
],
"parallelActions": [
{
"type": "postgresql_update_combined",
"table": "car_frontal_info",
"key_field": "session_id",
"waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
"fields": {
"car_brand": "{car_brand_cls_v1.brand}",
"car_body_type": "{car_bodytype_cls_v1.body_type}"
}
}
]
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(pipeline_config, f)
temp_path = f.name
yield temp_path
# Cleanup
Path(temp_path).unlink(missing_ok=True)
class TestCompleteDetectionWorkflow:
"""Test complete detection workflow from stream to database."""
@pytest.mark.asyncio
async def test_complete_rtsp_detection_workflow(self, temp_config_file, sample_mpta_file, mock_frame):
"""Test complete workflow: RTSP stream -> detection -> classification -> database."""
# Initialize configuration
config = Configuration()
config.load_from_file(temp_config_file)
# Create service container
container = ServiceContainer()
# Mock all external dependencies
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('torch.load') as mock_torch_load, \
patch('psycopg2.connect') as mock_db_connect, \
patch('redis.Redis') as mock_redis:
# Setup video capture mock
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, mock_frame)
# Setup model loading mock
mock_detection_model = Mock()
mock_brand_model = Mock()
mock_bodytype_model = Mock()
def mock_model_load(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return mock_brand_model
elif "bodytype" in path:
return mock_bodytype_model
return Mock()
mock_torch_load.side_effect = mock_model_load
# Setup detection model predictions
mock_detection_model.return_value = Mock()
mock_detection_model.return_value.boxes = Mock()
mock_detection_model.return_value.boxes.xyxy = Mock()
mock_detection_model.return_value.boxes.conf = Mock()
mock_detection_model.return_value.boxes.cls = Mock()
mock_detection_model.return_value.names = {0: "Car", 1: "Frontal"}
# Mock detection results - Car and Frontal detected
mock_detection_model.return_value.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[100, 200, 300, 400], # Car bbox
[150, 250, 250, 350] # Frontal bbox
])
mock_detection_model.return_value.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_model.return_value.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
# Setup classification model predictions
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 5 # Toyota index
mock_brand_result.probs.top1conf = Mock()
mock_brand_result.probs.top1conf.item.return_value = 0.87
mock_brand_result.names = {5: "Toyota"}
mock_brand_model.return_value = mock_brand_result
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 2 # Sedan index
mock_bodytype_result.probs.top1conf = Mock()
mock_bodytype_result.probs.top1conf.item.return_value = 0.82
mock_bodytype_result.names = {2: "Sedan"}
mock_bodytype_model.return_value = mock_bodytype_result
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
# Setup Redis mock
mock_redis_instance = Mock()
mock_redis.return_value = mock_redis_instance
mock_redis_instance.ping.return_value = True
mock_redis_instance.set.return_value = True
mock_redis_instance.expire.return_value = True
# Initialize managers
stream_manager = StreamManager()
model_manager = ModelManager()
pipeline_executor = PipelineExecutor()
# Register services in container
container.register_singleton(StreamManager, lambda: stream_manager)
container.register_singleton(ModelManager, lambda: model_manager)
container.register_singleton(PipelineExecutor, lambda: pipeline_executor)
try:
# 1. Create RTSP stream
stream_config = StreamConfig(
stream_url="rtsp://example.com/stream",
stream_type="rtsp",
target_fps=10
)
stream_info = await stream_manager.create_stream(
camera_id="camera_001",
config=stream_config,
subscription_id="sub_001"
)
# 2. Load pipeline and models
pipeline_config = json.loads(Path(sample_mpta_file).read_text())
# Mock model file paths exist
with patch('os.path.exists', return_value=True):
await model_manager.load_models_from_config(pipeline_config)
# 3. Get frame from stream
frame = stream_manager.get_latest_frame("camera_001")
assert frame is not None
# 4. Run detection pipeline
detection_context = {
"camera_id": "camera_001",
"display_id": "display_001",
"frame": mock_frame,
"timestamp": int(time.time() * 1000),
"session_id": str(uuid.uuid4())
}
# Mock cv2.imencode for Redis image storage
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
pipeline_result = await pipeline_executor.execute_pipeline(
pipeline_config,
detection_context
)
# 5. Verify pipeline execution
assert pipeline_result is not None
assert pipeline_result.get("status") == "completed"
# 6. Verify database operations
# Should have called create record
create_calls = [call for call in mock_cursor.execute.call_args_list
if "INSERT" in str(call)]
assert len(create_calls) >= 1
# Should have called update with classification results
update_calls = [call for call in mock_cursor.execute.call_args_list
if "UPDATE" in str(call)]
assert len(update_calls) >= 1
# 7. Verify Redis operations
# Should have saved cropped image
assert mock_redis_instance.set.called
assert mock_redis_instance.expire.called
# 8. Verify final results
assert "detections" in pipeline_result
detections = pipeline_result["detections"]
assert len(detections) >= 2 # Car and Frontal
# Check detection classes
detected_classes = [d.get("class") for d in detections]
assert "Car" in detected_classes
assert "Frontal" in detected_classes
# 9. Verify classification results in context
classification_results = pipeline_result.get("classification_results", {})
assert "car_brand_cls_v1" in classification_results
assert "car_bodytype_cls_v1" in classification_results
brand_result = classification_results["car_brand_cls_v1"]
assert brand_result.get("brand") == "Toyota"
assert brand_result.get("confidence") == 0.87
bodytype_result = classification_results["car_bodytype_cls_v1"]
assert bodytype_result.get("body_type") == "Sedan"
assert bodytype_result.get("confidence") == 0.82
finally:
# Cleanup
await stream_manager.stop_all_streams()
@pytest.mark.asyncio
async def test_websocket_subscription_workflow(self, temp_config_file, sample_mpta_file):
"""Test complete WebSocket subscription workflow."""
# Mock WebSocket for testing
mock_websocket = Mock()
mock_websocket.accept = AsyncMock()
mock_websocket.send_json = AsyncMock()
mock_websocket.receive_json = AsyncMock()
# Initialize configuration
config = Configuration()
config.load_from_file(temp_config_file)
# Initialize message processor and WebSocket handler
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('torch.load') as mock_torch_load, \
patch('psycopg2.connect') as mock_db_connect, \
patch('redis.Redis') as mock_redis:
# Setup mocks
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))
mock_torch_load.return_value = Mock()
mock_db_connect.return_value = Mock()
mock_redis.return_value = Mock()
# Simulate WebSocket message sequence
subscription_message = {
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001",
"rtspUrl": "rtsp://example.com/stream",
"modelUrl": f"file://{sample_mpta_file}",
"modelId": 101,
"modelName": "Vehicle Detection",
"cropX1": 100, "cropY1": 200,
"cropX2": 300, "cropY2": 400
}
}
mock_websocket.receive_json.side_effect = [
subscription_message,
{"type": "requestState"},
{"type": "unsubscribe", "payload": {"subscriptionIdentifier": "display-001;cam-001"}}
]
# Mock file operations
with patch('builtins.open', mock_open(read_data=json.dumps(json.loads(Path(sample_mpta_file).read_text())))):
with patch('os.path.exists', return_value=True):
try:
# Start WebSocket handler (will process messages)
await websocket_handler.handle_websocket(mock_websocket, "client_001")
# Verify WebSocket interactions
mock_websocket.accept.assert_called_once()
# Should have sent subscription acknowledgment
send_calls = mock_websocket.send_json.call_args_list
assert len(send_calls) >= 1
# Check subscription acknowledgment
first_response = send_calls[0][0][0]
assert first_response.get("type") == "subscribeAck"
assert first_response.get("status") == "success"
# Should have sent state report
state_responses = [call[0][0] for call in send_calls
if call[0][0].get("type") == "stateReport"]
assert len(state_responses) >= 1
except Exception as e:
# WebSocket disconnect is expected at end of message sequence
pass
@pytest.mark.asyncio
async def test_http_snapshot_workflow(self, temp_config_file, mock_frame):
"""Test HTTP snapshot workflow."""
config = Configuration()
config.load_from_file(temp_config_file)
stream_manager = StreamManager()
with patch('requests.get') as mock_requests:
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b"fake_jpeg_data"
mock_requests.return_value = mock_response
with patch('cv2.imdecode') as mock_imdecode:
mock_imdecode.return_value = mock_frame
try:
# Create HTTP snapshot stream
stream_config = StreamConfig(
stream_url="http://camera.example.com/snapshot.jpg",
stream_type="http_snapshot",
snapshot_interval=1.0
)
stream_info = await stream_manager.create_stream(
camera_id="camera_002",
config=stream_config,
subscription_id="sub_002"
)
# Wait for snapshot capture
await asyncio.sleep(1.2)
# Verify frame was captured
frame = stream_manager.get_latest_frame("camera_002")
assert frame is not None
assert np.array_equal(frame, mock_frame)
# Verify HTTP request was made
mock_requests.assert_called()
mock_imdecode.assert_called()
finally:
await stream_manager.stop_all_streams()
@pytest.mark.asyncio
async def test_error_recovery_workflow(self, temp_config_file):
"""Test error recovery and resilience."""
config = Configuration()
config.load_from_file(temp_config_file)
stream_manager = StreamManager()
with patch('cv2.VideoCapture') as mock_video_cap:
# Simulate connection failures then success
mock_cap_instances = []
# First attempt fails
mock_cap_fail = Mock()
mock_cap_fail.isOpened.return_value = False
mock_cap_instances.append(mock_cap_fail)
# Second attempt succeeds
mock_cap_success = Mock()
mock_cap_success.isOpened.return_value = True
mock_cap_success.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))
mock_cap_instances.append(mock_cap_success)
mock_video_cap.side_effect = mock_cap_instances
try:
stream_config = StreamConfig(
stream_url="rtsp://unreliable.example.com/stream",
stream_type="rtsp",
max_retries=2
)
# First attempt should fail
try:
await stream_manager.create_stream(
camera_id="camera_003",
config=stream_config,
subscription_id="sub_003"
)
assert False, "Should have failed on first attempt"
except Exception:
pass
# Retry should succeed
stream_info = await stream_manager.create_stream(
camera_id="camera_003",
config=stream_config,
subscription_id="sub_003"
)
assert stream_info is not None
assert mock_video_cap.call_count == 2
finally:
await stream_manager.stop_all_streams()
@pytest.mark.asyncio
async def test_concurrent_streams_workflow(self, temp_config_file, mock_frame):
"""Test handling multiple concurrent streams."""
config = Configuration()
config.load_from_file(temp_config_file)
stream_manager = StreamManager()
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('requests.get') as mock_requests:
# Setup RTSP mock
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, mock_frame)
# Setup HTTP mock
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = b"fake_jpeg_data"
mock_requests.return_value = mock_response
with patch('cv2.imdecode', return_value=mock_frame):
try:
# Create multiple streams concurrently
stream_tasks = []
# RTSP streams
for i in range(3):
config = StreamConfig(
stream_url=f"rtsp://camera{i}.example.com/stream",
stream_type="rtsp"
)
task = stream_manager.create_stream(
camera_id=f"camera_rtsp_{i}",
config=config,
subscription_id=f"sub_rtsp_{i}"
)
stream_tasks.append(task)
# HTTP snapshot streams
for i in range(2):
config = StreamConfig(
stream_url=f"http://camera{i}.example.com/snapshot.jpg",
stream_type="http_snapshot"
)
task = stream_manager.create_stream(
camera_id=f"camera_http_{i}",
config=config,
subscription_id=f"sub_http_{i}"
)
stream_tasks.append(task)
# Wait for all streams to be created
stream_results = await asyncio.gather(*stream_tasks, return_exceptions=True)
# Verify all streams were created successfully
successful_streams = [r for r in stream_results if not isinstance(r, Exception)]
assert len(successful_streams) == 5
# Verify frames can be retrieved from all streams
await asyncio.sleep(0.5) # Allow time for frame capture
for i in range(3):
frame = stream_manager.get_latest_frame(f"camera_rtsp_{i}")
assert frame is not None
for i in range(2):
frame = stream_manager.get_latest_frame(f"camera_http_{i}")
# HTTP snapshots might not have frames immediately
# Verify stream statistics
stats = stream_manager.get_stream_statistics()
assert stats["total_streams"] == 5
assert stats["active_streams"] >= 3 # At least RTSP streams should be active
finally:
await stream_manager.stop_all_streams()
@pytest.mark.asyncio
async def test_memory_usage_workflow(self, temp_config_file):
"""Test memory usage tracking and cleanup."""
config = Configuration()
config.load_from_file(temp_config_file)
# Create managers with small limits for testing
stream_manager = StreamManager({"max_streams": 10})
model_manager = ModelManager({"cache_max_size": 5})
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('torch.load') as mock_torch_load:
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
mock_cap_instance.read.return_value = (True, np.ones((100, 100, 3), dtype=np.uint8))
# Mock model loading
def create_mock_model():
model = Mock()
# Mock model parameters for memory estimation
param1 = Mock()
param1.numel.return_value = 1000
param1.element_size.return_value = 4
model.parameters.return_value = [param1]
return model
mock_torch_load.side_effect = lambda *args, **kwargs: create_mock_model()
try:
# Create streams up to limit
for i in range(8):
stream_config = StreamConfig(
stream_url=f"rtsp://test{i}.example.com/stream",
stream_type="rtsp"
)
await stream_manager.create_stream(
camera_id=f"test_camera_{i}",
config=stream_config,
subscription_id=f"test_sub_{i}"
)
# Load models up to cache limit
with patch('os.path.exists', return_value=True):
for i in range(7):
config = {
"model_id": f"test_model_{i}",
"model_path": f"/fake/path/model_{i}.pt",
"model_type": "detection",
"device": "cpu"
}
await model_manager.load_model_from_dict(config)
# Check memory usage tracking
stream_stats = stream_manager.get_stream_statistics()
model_stats = model_manager.get_cache_statistics()
assert stream_stats["total_streams"] == 8
assert model_stats["size"] <= 5 # Should be limited by cache size
# Test cleanup
cleaned_models = model_manager.cleanup_unused_models()
assert cleaned_models >= 0
stopped_streams = await stream_manager.stop_all_streams()
assert stopped_streams == 8
# Verify cleanup
final_stream_stats = stream_manager.get_stream_statistics()
assert final_stream_stats["total_streams"] == 0
finally:
await stream_manager.stop_all_streams()
def mock_open(read_data=""):
"""Create a mock file opener."""
from unittest.mock import mock_open as _mock_open
return _mock_open(read_data=read_data)

View file

@ -0,0 +1,738 @@
"""
Integration tests for pipeline execution workflows.
Tests the complete machine learning pipeline execution including
detection, classification, database updates, and Redis actions.
"""
import pytest
import asyncio
import json
import tempfile
import uuid
import time
from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock
import numpy as np
from detector_worker.pipeline.pipeline_executor import PipelineExecutor
from detector_worker.pipeline.action_executor import ActionExecutor
from detector_worker.pipeline.field_mapper import FieldMapper
from detector_worker.models.model_manager import ModelManager
from detector_worker.storage.database_manager import DatabaseManager
from detector_worker.storage.redis_client import RedisClient, RedisConfig
from detector_worker.detection.detection_result import DetectionResult, BoundingBox
@pytest.fixture
def sample_detection_pipeline():
"""Create sample detection pipeline configuration."""
return {
"modelId": "car_frontal_detection_v1",
"modelFile": "car_frontal_detection_v1.pt",
"multiClass": True,
"expectedClasses": ["Car", "Frontal"],
"triggerClasses": ["Car", "Frontal"],
"minConfidence": 0.8,
"actions": [
{
"type": "redis_save_image",
"region": "Frontal",
"key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
"expire_seconds": 600
},
{
"type": "postgresql_create_record",
"table": "car_frontal_info",
"fields": {
"display_id": "{display_id}",
"captured_timestamp": "{timestamp}",
"session_id": "{session_id}",
"license_character": None,
"license_type": "No model available"
}
}
],
"branches": [
{
"modelId": "car_brand_cls_v1",
"modelFile": "car_brand_cls_v1.pt",
"parallel": True,
"crop": True,
"cropClass": "Frontal",
"triggerClasses": ["Frontal"],
"minConfidence": 0.85
},
{
"modelId": "car_bodytype_cls_v1",
"modelFile": "car_bodytype_cls_v1.pt",
"parallel": True,
"crop": True,
"cropClass": "Frontal",
"triggerClasses": ["Frontal"],
"minConfidence": 0.80
}
],
"parallelActions": [
{
"type": "postgresql_update_combined",
"table": "car_frontal_info",
"key_field": "session_id",
"waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
"fields": {
"car_brand": "{car_brand_cls_v1.brand}",
"car_body_type": "{car_bodytype_cls_v1.body_type}"
}
}
]
}
@pytest.fixture
def sample_frame():
"""Create sample frame for testing."""
return np.ones((480, 640, 3), dtype=np.uint8) * 128
@pytest.fixture
def detection_context():
"""Create sample detection context."""
return {
"camera_id": "camera_001",
"display_id": "display_001",
"timestamp": int(time.time() * 1000),
"session_id": str(uuid.uuid4()),
"frame": np.ones((480, 640, 3), dtype=np.uint8) * 128,
"filename": "detection_image.jpg"
}
class TestPipelineIntegration:
"""Test complete pipeline integration workflows."""
@pytest.mark.asyncio
async def test_complete_detection_classification_pipeline(self, sample_detection_pipeline, detection_context):
"""Test complete detection to classification pipeline."""
pipeline_executor = PipelineExecutor()
model_manager = ModelManager()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect, \
patch('redis.Redis') as mock_redis:
# Setup detection model mock
mock_detection_model = Mock()
mock_detection_result = Mock()
# Mock successful multi-class detection
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
# Detection results: Car and Frontal detected with high confidence
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], # Car bbox
[150, 200, 300, 400] # Frontal bbox (within Car)
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
# Setup classification models
mock_brand_model = Mock()
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 3 # Toyota index
mock_brand_result.probs.top1conf = Mock()
mock_brand_result.probs.top1conf.item.return_value = 0.87
mock_brand_result.names = {3: "Toyota"}
mock_brand_model.return_value = mock_brand_result
mock_bodytype_model = Mock()
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 1 # Sedan index
mock_bodytype_result.probs.top1conf = Mock()
mock_bodytype_result.probs.top1conf.item.return_value = 0.82
mock_bodytype_result.names = {1: "Sedan"}
mock_bodytype_model.return_value = mock_bodytype_result
# Route model loading to appropriate mocks
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return mock_brand_model
elif "bodytype" in path:
return mock_bodytype_model
return Mock()
mock_torch_load.side_effect = model_loader
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
mock_cursor.fetchone.return_value = None
# Setup Redis mock
mock_redis_instance = Mock()
mock_redis.return_value = mock_redis_instance
mock_redis_instance.ping.return_value = True
mock_redis_instance.set.return_value = True
mock_redis_instance.expire.return_value = True
# Mock image encoding for Redis storage
with patch('cv2.imencode') as mock_imencode:
encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Execute complete pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Verify pipeline execution
assert result is not None
assert result.get("status") == "completed"
assert "detections" in result
# Verify detection results
detections = result["detections"]
assert len(detections) == 2 # Car and Frontal
detection_classes = [d.get("class") for d in detections]
assert "Car" in detection_classes
assert "Frontal" in detection_classes
# Verify classification results
assert "classification_results" in result
classification_results = result["classification_results"]
assert "car_brand_cls_v1" in classification_results
brand_result = classification_results["car_brand_cls_v1"]
assert brand_result.get("brand") == "Toyota"
assert brand_result.get("confidence") == 0.87
assert "car_bodytype_cls_v1" in classification_results
bodytype_result = classification_results["car_bodytype_cls_v1"]
assert bodytype_result.get("body_type") == "Sedan"
assert bodytype_result.get("confidence") == 0.82
# Verify database operations
db_calls = mock_cursor.execute.call_args_list
# Should have INSERT for initial record creation
insert_calls = [call for call in db_calls if "INSERT" in str(call[0])]
assert len(insert_calls) >= 1
# Should have UPDATE for classification results
update_calls = [call for call in db_calls if "UPDATE" in str(call[0])]
assert len(update_calls) >= 1
# Verify Redis operations
assert mock_redis_instance.set.called
assert mock_redis_instance.expire.called
@pytest.mark.asyncio
async def test_pipeline_with_missing_detections(self, sample_detection_pipeline, detection_context):
"""Test pipeline behavior when expected detections are missing."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True):
# Setup detection model that doesn't find expected classes
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
# Only detect Car, no Frontal
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450] # Only Car bbox
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0])
mock_detection_model.return_value = mock_detection_result
mock_torch_load.return_value = mock_detection_model
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Pipeline should complete but skip classification branches
assert result is not None
assert "detections" in result
detections = result["detections"]
assert len(detections) == 1 # Only Car detected
assert detections[0].get("class") == "Car"
# Classification should not have run (no Frontal detected)
classification_results = result.get("classification_results", {})
assert len(classification_results) == 0 or all(
not res for res in classification_results.values()
)
@pytest.mark.asyncio
async def test_pipeline_with_low_confidence_detections(self, sample_detection_pipeline, detection_context):
"""Test pipeline with detections below confidence threshold."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True):
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
# Detections with low confidence (below 0.8 threshold)
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], # Car bbox
[150, 200, 300, 400] # Frontal bbox
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.75, 0.70]) # Below threshold
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
mock_torch_load.return_value = mock_detection_model
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Should complete but with filtered detections
assert result is not None
# Low confidence detections should be filtered out
detections = result.get("detections", [])
high_conf_detections = [d for d in detections if d.get("confidence", 0) >= 0.8]
assert len(high_conf_detections) == 0
@pytest.mark.asyncio
async def test_pipeline_branch_execution_order(self, sample_detection_pipeline, detection_context):
"""Test that pipeline branches execute in correct order and parallel mode works."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect:
# Track execution order
execution_order = []
# Setup detection model
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
def track_detection_execution(*args, **kwargs):
execution_order.append("detection")
return mock_detection_result
mock_detection_model.side_effect = track_detection_execution
# Setup classification models with execution tracking
def create_tracked_model(model_id):
def track_execution(*args, **kwargs):
execution_order.append(model_id)
result = Mock()
result.probs = Mock()
result.probs.top1 = 0
result.probs.top1conf = Mock()
result.probs.top1conf.item.return_value = 0.90
result.names = {0: "TestResult"}
return result
model = Mock()
model.side_effect = track_execution
return model
# Route models with execution tracking
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return create_tracked_model("car_brand_cls_v1")
elif "bodytype" in path:
return create_tracked_model("car_bodytype_cls_v1")
return Mock()
mock_torch_load.side_effect = model_loader
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Verify execution order
assert "detection" in execution_order
assert execution_order[0] == "detection" # Detection should run first
# Classification models should run after detection
brand_index = execution_order.index("car_brand_cls_v1") if "car_brand_cls_v1" in execution_order else -1
bodytype_index = execution_order.index("car_bodytype_cls_v1") if "car_bodytype_cls_v1" in execution_order else -1
detection_index = execution_order.index("detection")
if brand_index >= 0:
assert brand_index > detection_index
if bodytype_index >= 0:
assert bodytype_index > detection_index
# Since branches are parallel, they could run in any order relative to each other
# but both should run after detection
@pytest.mark.asyncio
async def test_pipeline_error_recovery(self, sample_detection_pipeline, detection_context):
"""Test pipeline error handling and recovery."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect:
# Setup detection model that works
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
# Setup classification models - one fails, one succeeds
mock_brand_model = Mock()
mock_brand_model.side_effect = RuntimeError("Model inference failed")
mock_bodytype_model = Mock()
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 1
mock_bodytype_result.probs.top1conf = Mock()
mock_bodytype_result.probs.top1conf.item.return_value = 0.85
mock_bodytype_result.names = {1: "SUV"}
mock_bodytype_model.return_value = mock_bodytype_result
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return mock_brand_model
elif "bodytype" in path:
return mock_bodytype_model
return Mock()
mock_torch_load.side_effect = model_loader
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Pipeline should complete despite one branch failing
assert result is not None
# Detection should succeed
assert "detections" in result
detections = result["detections"]
assert len(detections) == 2
# Classification results should be partial
classification_results = result.get("classification_results", {})
# Brand classification should have failed
brand_result = classification_results.get("car_brand_cls_v1")
assert brand_result is None or brand_result.get("error") is not None
# Body type classification should have succeeded
bodytype_result = classification_results.get("car_bodytype_cls_v1")
assert bodytype_result is not None
assert bodytype_result.get("body_type") == "SUV"
assert bodytype_result.get("confidence") == 0.85
@pytest.mark.asyncio
async def test_field_mapping_and_database_update(self, sample_detection_pipeline, detection_context):
"""Test field mapping and database update integration."""
pipeline_executor = PipelineExecutor()
field_mapper = FieldMapper()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect:
# Setup successful detection and classification
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
# Setup classification models
mock_brand_model = Mock()
mock_brand_result = Mock()
mock_brand_result.probs = Mock()
mock_brand_result.probs.top1 = 2
mock_brand_result.probs.top1conf = Mock()
mock_brand_result.probs.top1conf.item.return_value = 0.88
mock_brand_result.names = {2: "Honda"}
mock_brand_model.return_value = mock_brand_result
mock_bodytype_model = Mock()
mock_bodytype_result = Mock()
mock_bodytype_result.probs = Mock()
mock_bodytype_result.probs.top1 = 0
mock_bodytype_result.probs.top1conf = Mock()
mock_bodytype_result.probs.top1conf.item.return_value = 0.91
mock_bodytype_result.names = {0: "Hatchback"}
mock_bodytype_model.return_value = mock_bodytype_result
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
elif "brand" in path:
return mock_brand_model
elif "bodytype" in path:
return mock_bodytype_model
return Mock()
mock_torch_load.side_effect = model_loader
# Setup database mock
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Verify pipeline completed successfully
assert result is not None
assert result.get("status") == "completed"
# Check database operations
db_calls = mock_cursor.execute.call_args_list
# Should have INSERT and UPDATE operations
insert_calls = [call for call in db_calls if "INSERT" in str(call[0])]
update_calls = [call for call in db_calls if "UPDATE" in str(call[0])]
assert len(insert_calls) >= 1
assert len(update_calls) >= 1
# Check that UPDATE includes field mapping results
update_sql = str(update_calls[0][0])
assert "car_brand" in update_sql.lower()
assert "car_body_type" in update_sql.lower()
# Check that classification results were properly mapped
classification_results = result.get("classification_results", {})
assert "car_brand_cls_v1" in classification_results
assert "car_bodytype_cls_v1" in classification_results
brand_result = classification_results["car_brand_cls_v1"]
bodytype_result = classification_results["car_bodytype_cls_v1"]
assert brand_result.get("brand") == "Honda"
assert brand_result.get("confidence") == 0.88
assert bodytype_result.get("body_type") == "Hatchback"
assert bodytype_result.get("confidence") == 0.91
@pytest.mark.asyncio
async def test_redis_image_storage_integration(self, sample_detection_pipeline, detection_context):
"""Test Redis image storage integration in pipeline."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('redis.Redis') as mock_redis, \
patch('cv2.imencode') as mock_imencode:
# Setup successful detection
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
mock_torch_load.return_value = mock_detection_model
# Setup Redis mock
mock_redis_instance = Mock()
mock_redis.return_value = mock_redis_instance
mock_redis_instance.ping.return_value = True
mock_redis_instance.set.return_value = True
mock_redis_instance.expire.return_value = True
# Setup image encoding mock
encoded_data = np.array([1, 2, 3, 4, 5], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Execute pipeline
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
# Verify Redis operations
assert mock_redis_instance.set.called
assert mock_redis_instance.expire.called
# Check that image was encoded
assert mock_imencode.called
# Verify correct key format was used
set_call = mock_redis_instance.set.call_args
redis_key = set_call[0][0]
# Key should contain display_id, timestamp, session_id
assert detection_context["display_id"] in redis_key
assert detection_context["session_id"] in redis_key
assert str(detection_context["timestamp"]) in redis_key
# Should set expiration
expire_call = mock_redis_instance.expire.call_args
expire_key = expire_call[0][0]
expire_seconds = expire_call[0][1]
assert expire_key == redis_key
assert expire_seconds == 600 # As configured in pipeline
@pytest.mark.asyncio
async def test_pipeline_performance_timing(self, sample_detection_pipeline, detection_context):
"""Test pipeline execution timing and performance."""
pipeline_executor = PipelineExecutor()
with patch('torch.load') as mock_torch_load, \
patch('os.path.exists', return_value=True), \
patch('psycopg2.connect') as mock_db_connect, \
patch('redis.Redis') as mock_redis, \
patch('cv2.imencode') as mock_imencode:
# Setup fast mocks
mock_detection_model = Mock()
mock_detection_result = Mock()
mock_detection_result.boxes = Mock()
mock_detection_result.boxes.xyxy = Mock()
mock_detection_result.boxes.conf = Mock()
mock_detection_result.boxes.cls = Mock()
mock_detection_result.names = {0: "Car", 1: "Frontal"}
mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
[50, 100, 350, 450], [150, 200, 300, 400]
])
mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
mock_detection_model.return_value = mock_detection_result
# Setup fast classification models
def create_fast_model():
model = Mock()
result = Mock()
result.probs = Mock()
result.probs.top1 = 0
result.probs.top1conf = Mock()
result.probs.top1conf.item.return_value = 0.90
result.names = {0: "TestClass"}
model.return_value = result
return model
def model_loader(path, **kwargs):
if "detection" in path:
return mock_detection_model
else:
return create_fast_model()
mock_torch_load.side_effect = model_loader
# Setup fast database and Redis
mock_db_conn = Mock()
mock_db_connect.return_value = mock_db_conn
mock_cursor = Mock()
mock_db_conn.cursor.return_value = mock_cursor
mock_redis_instance = Mock()
mock_redis.return_value = mock_redis_instance
mock_redis_instance.ping.return_value = True
mock_redis_instance.set.return_value = True
mock_redis_instance.expire.return_value = True
encoded_data = np.array([1, 2, 3], dtype=np.uint8)
mock_imencode.return_value = (True, encoded_data)
# Measure execution time
start_time = time.time()
result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)
end_time = time.time()
execution_time = end_time - start_time
# Pipeline should complete quickly (less than 1 second with mocks)
assert execution_time < 1.0
# Should have timing information in result
assert result is not None
if "execution_time" in result:
assert result["execution_time"] > 0
# Verify pipeline completed successfully
assert result.get("status") == "completed"

View file

@ -0,0 +1,579 @@
"""
Integration tests for WebSocket protocol compliance.
Tests the complete WebSocket communication protocol to ensure
compatibility with existing clients and proper message handling.
"""
import pytest
import asyncio
import json
import uuid
from unittest.mock import Mock, AsyncMock, patch
from fastapi.websockets import WebSocket
from fastapi.testclient import TestClient
from detector_worker.app import create_app
from detector_worker.communication.websocket_handler import WebSocketHandler
from detector_worker.communication.message_processor import MessageProcessor, MessageType
from detector_worker.core.exceptions import MessageProcessingError
@pytest.fixture
def test_app():
"""Create test FastAPI application."""
return create_app()
@pytest.fixture
def mock_websocket():
"""Create mock WebSocket for testing."""
websocket = Mock(spec=WebSocket)
websocket.accept = AsyncMock()
websocket.send_json = AsyncMock()
websocket.send_text = AsyncMock()
websocket.receive_json = AsyncMock()
websocket.receive_text = AsyncMock()
websocket.close = AsyncMock()
websocket.ping = AsyncMock()
return websocket
class TestWebSocketProtocol:
"""Test WebSocket protocol compliance."""
@pytest.mark.asyncio
async def test_subscription_message_protocol(self, mock_websocket):
"""Test subscription message handling protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Mock external dependencies
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('torch.load') as mock_torch_load, \
patch('builtins.open') as mock_file_open:
# Setup video capture mock
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
# Setup model loading mock
mock_torch_load.return_value = Mock()
# Setup pipeline file mock
pipeline_config = {
"modelId": "test_detection_model",
"modelFile": "test_model.pt",
"expectedClasses": ["Car"],
"minConfidence": 0.8
}
mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)
# Test message sequence
subscription_message = {
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001",
"rtspUrl": "rtsp://example.com/stream",
"modelUrl": "http://example.com/model.mpta",
"modelId": 101,
"modelName": "Test Detection",
"cropX1": 0,
"cropY1": 0,
"cropX2": 640,
"cropY2": 480
}
}
request_state_message = {
"type": "requestState"
}
unsubscribe_message = {
"type": "unsubscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001"
}
}
# Mock WebSocket message sequence
mock_websocket.receive_json.side_effect = [
subscription_message,
request_state_message,
unsubscribe_message,
asyncio.CancelledError() # Simulate client disconnect
]
client_id = "test_client_001"
try:
# Handle WebSocket connection
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass # Expected when client disconnects
# Verify protocol compliance
mock_websocket.accept.assert_called_once()
# Check sent messages
sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]
# Should receive subscription acknowledgment
subscribe_acks = [msg for msg in sent_messages if msg.get("type") == "subscribeAck"]
assert len(subscribe_acks) >= 1
subscribe_ack = subscribe_acks[0]
assert subscribe_ack["status"] in ["success", "error"]
if subscribe_ack["status"] == "success":
assert "subscriptionId" in subscribe_ack
# Should receive state report
state_reports = [msg for msg in sent_messages if msg.get("type") == "stateReport"]
assert len(state_reports) >= 1
state_report = state_reports[0]
assert "payload" in state_report
assert "subscriptions" in state_report["payload"]
assert "system" in state_report["payload"]
# Should receive unsubscribe acknowledgment
unsubscribe_acks = [msg for msg in sent_messages if msg.get("type") == "unsubscribeAck"]
assert len(unsubscribe_acks) >= 1
unsubscribe_ack = unsubscribe_acks[0]
assert unsubscribe_ack["status"] in ["success", "error"]
@pytest.mark.asyncio
async def test_invalid_message_handling(self, mock_websocket):
"""Test handling of invalid messages."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Test invalid message types
invalid_messages = [
{"invalid": "message"}, # Missing type
{"type": "unknown_type", "payload": {}}, # Unknown type
{"type": "subscribe"}, # Missing payload
{"type": "subscribe", "payload": {}}, # Missing required fields
{"type": "subscribe", "payload": {"subscriptionIdentifier": "test"}}, # Missing URL
]
mock_websocket.receive_json.side_effect = invalid_messages + [asyncio.CancelledError()]
client_id = "test_client_error"
try:
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass
# Verify error responses
sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]
error_messages = [msg for msg in sent_messages if msg.get("type") == "error"]
# Should receive error responses for invalid messages
assert len(error_messages) >= len(invalid_messages)
for error_msg in error_messages:
assert "message" in error_msg
assert error_msg["message"] # Non-empty error message
@pytest.mark.asyncio
async def test_session_management_protocol(self, mock_websocket):
"""Test session management protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
session_id = str(uuid.uuid4())
# Test session messages
set_session_message = {
"type": "setSessionId",
"payload": {
"sessionId": session_id,
"displayId": "display-001"
}
}
patch_session_message = {
"type": "patchSession",
"payload": {
"sessionId": session_id,
"data": {
"car_brand": "Toyota",
"confidence": 0.92
}
}
}
mock_websocket.receive_json.side_effect = [
set_session_message,
patch_session_message,
asyncio.CancelledError()
]
client_id = "test_client_session"
try:
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass
# Verify session responses
sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]
# Should receive acknowledgments for session operations
set_session_acks = [msg for msg in sent_messages
if msg.get("type") == "setSessionIdAck"]
assert len(set_session_acks) >= 1
patch_session_acks = [msg for msg in sent_messages
if msg.get("type") == "patchSessionAck"]
assert len(patch_session_acks) >= 1
@pytest.mark.asyncio
async def test_heartbeat_protocol(self, mock_websocket):
"""Test heartbeat/ping-pong protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor, {"heartbeat_interval": 0.1})
# Simulate long-running connection
mock_websocket.receive_json.side_effect = [
asyncio.TimeoutError(), # No messages for a while
asyncio.TimeoutError(),
asyncio.CancelledError() # Then disconnect
]
client_id = "test_client_heartbeat"
# Start heartbeat task
heartbeat_task = asyncio.create_task(websocket_handler._heartbeat_loop())
try:
# Let heartbeat run briefly
await asyncio.sleep(0.2)
# Cancel heartbeat
heartbeat_task.cancel()
await heartbeat_task
except asyncio.CancelledError:
pass
# Verify ping was sent
assert mock_websocket.ping.called
@pytest.mark.asyncio
async def test_detection_result_protocol(self, mock_websocket):
"""Test detection result message protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Mock a detection result being sent
detection_result = {
"type": "imageDetection",
"payload": {
"subscriptionId": "display-001;cam-001",
"detections": [
{
"class": "Car",
"confidence": 0.95,
"bbox": [100, 200, 300, 400],
"trackId": 1001
}
],
"timestamp": 1640995200000,
"modelInfo": {
"modelId": 101,
"modelName": "Vehicle Detection"
}
}
}
# Send detection result to client
await websocket_handler.send_to_client("test_client", detection_result)
# Verify message was sent
mock_websocket.send_json.assert_called_with(detection_result)
@pytest.mark.asyncio
async def test_error_recovery_protocol(self, mock_websocket):
"""Test error recovery and graceful degradation."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Simulate WebSocket errors
mock_websocket.send_json.side_effect = [
None, # First message succeeds
ConnectionError("Connection lost"), # Second fails
None # Third succeeds after recovery
]
# Try to send multiple messages
messages = [
{"type": "test", "message": "1"},
{"type": "test", "message": "2"},
{"type": "test", "message": "3"}
]
results = []
for msg in messages:
try:
await websocket_handler.send_to_client("test_client", msg)
results.append("success")
except Exception:
results.append("error")
# Should handle errors gracefully
assert "error" in results
# But should still be able to send other messages
assert "success" in results
@pytest.mark.asyncio
async def test_concurrent_client_protocol(self, mock_websocket):
"""Test handling multiple concurrent clients."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Create multiple mock WebSocket connections
mock_websockets = []
for i in range(3):
ws = Mock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
ws.receive_json = AsyncMock(side_effect=[asyncio.CancelledError()])
mock_websockets.append(ws)
# Handle multiple clients concurrently
client_tasks = []
for i, ws in enumerate(mock_websockets):
task = asyncio.create_task(
websocket_handler.handle_websocket(ws, f"client_{i}")
)
client_tasks.append(task)
# Wait briefly then cancel all
await asyncio.sleep(0.1)
for task in client_tasks:
task.cancel()
try:
await asyncio.gather(*client_tasks)
except asyncio.CancelledError:
pass
# Verify all connections were accepted
for ws in mock_websockets:
ws.accept.assert_called_once()
@pytest.mark.asyncio
async def test_subscription_sharing_protocol(self, mock_websocket):
"""Test shared subscription protocol."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Mock multiple clients subscribing to same camera
with patch('cv2.VideoCapture') as mock_video_cap:
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
# First client subscribes
subscription_msg1 = {
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-001;cam-001",
"rtspUrl": "rtsp://shared.example.com/stream",
"modelUrl": "http://example.com/model.mpta"
}
}
# Second client subscribes to same camera
subscription_msg2 = {
"type": "subscribe",
"payload": {
"subscriptionIdentifier": "display-002;cam-001",
"rtspUrl": "rtsp://shared.example.com/stream", # Same URL
"modelUrl": "http://example.com/model.mpta"
}
}
# Mock file operations
with patch('builtins.open') as mock_file_open:
pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)
# Process both subscriptions
response1 = await message_processor.process_message(subscription_msg1, "client_1")
response2 = await message_processor.process_message(subscription_msg2, "client_2")
# Both should succeed and reference same underlying stream
assert response1.get("status") == "success"
assert response2.get("status") == "success"
# Should only create one video capture instance (shared stream)
assert mock_video_cap.call_count == 1
@pytest.mark.asyncio
async def test_message_ordering_protocol(self, mock_websocket):
"""Test message ordering and sequencing."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Test sequence of related messages
messages = [
{"type": "subscribe", "payload": {"subscriptionIdentifier": "test", "rtspUrl": "rtsp://test.com"}},
{"type": "setSessionId", "payload": {"sessionId": "session_123", "displayId": "display_001"}},
{"type": "requestState"},
{"type": "patchSession", "payload": {"sessionId": "session_123", "data": {"test": "data"}}},
{"type": "unsubscribe", "payload": {"subscriptionIdentifier": "test"}}
]
mock_websocket.receive_json.side_effect = messages + [asyncio.CancelledError()]
with patch('cv2.VideoCapture') as mock_video_cap, \
patch('builtins.open') as mock_file_open:
mock_cap_instance = Mock()
mock_video_cap.return_value = mock_cap_instance
mock_cap_instance.isOpened.return_value = True
pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)
client_id = "test_client_ordering"
try:
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass
# Verify responses were sent in correct order
sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]
# Should receive responses for each message type
response_types = [msg.get("type") for msg in sent_messages]
expected_types = ["subscribeAck", "setSessionIdAck", "stateReport", "patchSessionAck", "unsubscribeAck"]
# Check that we got appropriate responses (order may vary slightly)
for expected_type in expected_types:
assert any(expected_type in response_types), f"Missing response type: {expected_type}"
class TestWebSocketPerformance:
"""Test WebSocket performance characteristics."""
@pytest.mark.asyncio
async def test_message_throughput(self, mock_websocket):
"""Test message processing throughput."""
message_processor = MessageProcessor()
# Prepare batch of simple messages
state_request = {"type": "requestState"}
import time
start_time = time.time()
# Process many messages quickly
for _ in range(100):
await message_processor.process_message(state_request, "test_client")
end_time = time.time()
processing_time = end_time - start_time
# Should process messages quickly (less than 1 second for 100 messages)
assert processing_time < 1.0
# Calculate throughput
throughput = 100 / processing_time
assert throughput > 100 # Should handle > 100 messages/second
@pytest.mark.asyncio
async def test_concurrent_message_handling(self, mock_websocket):
"""Test concurrent message handling."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Create multiple mock clients
num_clients = 10
mock_websockets = []
for i in range(num_clients):
ws = Mock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
ws.receive_json = AsyncMock(side_effect=[
{"type": "requestState"},
asyncio.CancelledError()
])
mock_websockets.append(ws)
# Handle all clients concurrently
client_tasks = []
for i, ws in enumerate(mock_websockets):
task = asyncio.create_task(
websocket_handler.handle_websocket(ws, f"perf_client_{i}")
)
client_tasks.append(task)
start_time = time.time()
# Wait for all to complete
try:
await asyncio.gather(*client_tasks)
except asyncio.CancelledError:
pass
end_time = time.time()
total_time = end_time - start_time
# Should handle all clients efficiently
assert total_time < 2.0 # Should complete in less than 2 seconds
# All clients should have been accepted
for ws in mock_websockets:
ws.accept.assert_called_once()
@pytest.mark.asyncio
async def test_memory_usage_stability(self, mock_websocket):
"""Test memory usage remains stable."""
message_processor = MessageProcessor()
websocket_handler = WebSocketHandler(message_processor)
# Simulate many connection/disconnection cycles
for cycle in range(10):
# Create client
client_id = f"memory_test_client_{cycle}"
mock_websocket.receive_json.side_effect = [
{"type": "requestState"},
asyncio.CancelledError()
]
try:
await websocket_handler.handle_websocket(mock_websocket, client_id)
except asyncio.CancelledError:
pass
# Reset mock for next cycle
mock_websocket.reset_mock()
mock_websocket.accept = AsyncMock()
mock_websocket.send_json = AsyncMock()
mock_websocket.receive_json = AsyncMock()
# Connection manager should not accumulate stale connections
stats = websocket_handler.get_connection_stats()
assert stats["total_connections"] == 0 # All should be cleaned up