Refactor: PHASE 8: Testing & Integration

parent af34f4fd08
commit 9e8c6804a7

32 changed files with 17128 additions and 0 deletions
19  tests/integration/__init__.py  Normal file

@@ -0,0 +1,19 @@

"""
Integration tests for the detector worker application.

This package contains integration tests that verify the interaction
between multiple components and end-to-end workflows.
"""

# Integration test modules
from . import (
    test_complete_detection_workflow,
    test_websocket_protocol,
    test_pipeline_integration
)

__all__ = [
    "test_complete_detection_workflow",
    "test_websocket_protocol",
    "test_pipeline_integration"
]
681  tests/integration/test_complete_detection_workflow.py  Normal file

@@ -0,0 +1,681 @@

"""
Integration tests for complete detection workflow.

This module tests the full end-to-end detection pipeline from stream
to database update, ensuring all components work together correctly.
"""
import pytest
import asyncio
import uuid
import time
import json
import tempfile
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from pathlib import Path
import numpy as np
import cv2

from detector_worker.app import create_app
from detector_worker.core.config import Configuration
from detector_worker.core.dependency_injection import ServiceContainer
from detector_worker.streams.stream_manager import StreamManager, StreamConfig
from detector_worker.models.model_manager import ModelManager
from detector_worker.pipeline.pipeline_executor import PipelineExecutor
from detector_worker.storage.database_manager import DatabaseManager
from detector_worker.storage.redis_client import RedisClient, RedisConfig
from detector_worker.communication.websocket_handler import WebSocketHandler
from detector_worker.communication.message_processor import MessageProcessor


@pytest.fixture
def temp_config_file():
    """Create temporary configuration file."""
    config_data = {
        "poll_interval_ms": 100,
        "max_streams": 5,
        "target_fps": 10,
        "reconnect_interval_sec": 1,
        "max_retries": 2,
        "database": {
            "enabled": True,
            "host": "localhost",
            "port": 5432,
            "database": "test_gas_station_1",
            "user": "test_user",
            "password": "test_pass"
        },
        "redis": {
            "enabled": True,
            "host": "localhost",
            "port": 6379,
            "db": 0
        }
    }

    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json.dump(config_data, f)
        temp_path = f.name

    yield temp_path

    # Cleanup
    Path(temp_path).unlink(missing_ok=True)


@pytest.fixture
def mock_frame():
    """Create a mock frame for testing."""
    return np.ones((480, 640, 3), dtype=np.uint8) * 128


@pytest.fixture
def sample_mpta_file():
    """Create a sample MPTA (pipeline) file."""
    pipeline_config = {
        "modelId": "car_frontal_detection_v1",
        "modelFile": "car_frontal_detection_v1.pt",
        "multiClass": True,
        "expectedClasses": ["Car", "Frontal"],
        "triggerClasses": ["Car", "Frontal"],
        "minConfidence": 0.8,
        "actions": [
            {
                "type": "redis_save_image",
                "region": "Frontal",
                "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
                "expire_seconds": 600
            },
            {
                "type": "postgresql_create_record",
                "table": "car_frontal_info",
                "fields": {
                    "display_id": "{display_id}",
                    "captured_timestamp": "{timestamp}",
                    "session_id": "{session_id}",
                    "license_character": None,
                    "license_type": "No model available"
                }
            }
        ],
        "branches": [
            {
                "modelId": "car_brand_cls_v1",
                "modelFile": "car_brand_cls_v1.pt",
                "parallel": True,
                "crop": True,
                "cropClass": "Frontal",
                "triggerClasses": ["Frontal"],
                "minConfidence": 0.85
            },
            {
                "modelId": "car_bodytype_cls_v1",
                "modelFile": "car_bodytype_cls_v1.pt",
                "parallel": True,
                "crop": True,
                "cropClass": "Frontal",
                "triggerClasses": ["Frontal"],
                "minConfidence": 0.80
            }
        ],
        "parallelActions": [
            {
                "type": "postgresql_update_combined",
                "table": "car_frontal_info",
                "key_field": "session_id",
                "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
                "fields": {
                    "car_brand": "{car_brand_cls_v1.brand}",
                    "car_body_type": "{car_bodytype_cls_v1.body_type}"
                }
            }
        ]
    }

    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json.dump(pipeline_config, f)
        temp_path = f.name

    yield temp_path

    # Cleanup
    Path(temp_path).unlink(missing_ok=True)


class TestCompleteDetectionWorkflow:
    """Test complete detection workflow from stream to database."""

    @pytest.mark.asyncio
    async def test_complete_rtsp_detection_workflow(self, temp_config_file, sample_mpta_file, mock_frame):
        """Test complete workflow: RTSP stream -> detection -> classification -> database."""

        # Initialize configuration
        config = Configuration()
        config.load_from_file(temp_config_file)

        # Create service container
        container = ServiceContainer()

        # Mock all external dependencies
        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('torch.load') as mock_torch_load, \
             patch('psycopg2.connect') as mock_db_connect, \
             patch('redis.Redis') as mock_redis:

            # Setup video capture mock
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, mock_frame)

            # Setup model loading mock
            mock_detection_model = Mock()
            mock_brand_model = Mock()
            mock_bodytype_model = Mock()

            def mock_model_load(path, **kwargs):
                if "detection" in path:
                    return mock_detection_model
                elif "brand" in path:
                    return mock_brand_model
                elif "bodytype" in path:
                    return mock_bodytype_model
                return Mock()

            mock_torch_load.side_effect = mock_model_load

            # Setup detection model predictions
            mock_detection_model.return_value = Mock()
            mock_detection_model.return_value.boxes = Mock()
            mock_detection_model.return_value.boxes.xyxy = Mock()
            mock_detection_model.return_value.boxes.conf = Mock()
            mock_detection_model.return_value.boxes.cls = Mock()
            mock_detection_model.return_value.names = {0: "Car", 1: "Frontal"}

            # Mock detection results - Car and Frontal detected
            mock_detection_model.return_value.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [100, 200, 300, 400],  # Car bbox
                [150, 250, 250, 350]   # Frontal bbox
            ])
            mock_detection_model.return_value.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
            mock_detection_model.return_value.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])

            # Setup classification model predictions
            mock_brand_result = Mock()
            mock_brand_result.probs = Mock()
            mock_brand_result.probs.top1 = 5  # Toyota index
            mock_brand_result.probs.top1conf = Mock()
            mock_brand_result.probs.top1conf.item.return_value = 0.87
            mock_brand_result.names = {5: "Toyota"}
            mock_brand_model.return_value = mock_brand_result

            mock_bodytype_result = Mock()
            mock_bodytype_result.probs = Mock()
            mock_bodytype_result.probs.top1 = 2  # Sedan index
            mock_bodytype_result.probs.top1conf = Mock()
            mock_bodytype_result.probs.top1conf.item.return_value = 0.82
            mock_bodytype_result.names = {2: "Sedan"}
            mock_bodytype_model.return_value = mock_bodytype_result

            # Setup database mock
            mock_db_conn = Mock()
            mock_db_connect.return_value = mock_db_conn
            mock_cursor = Mock()
            mock_db_conn.cursor.return_value = mock_cursor

            # Setup Redis mock
            mock_redis_instance = Mock()
            mock_redis.return_value = mock_redis_instance
            mock_redis_instance.ping.return_value = True
            mock_redis_instance.set.return_value = True
            mock_redis_instance.expire.return_value = True

            # Initialize managers
            stream_manager = StreamManager()
            model_manager = ModelManager()
            pipeline_executor = PipelineExecutor()

            # Register services in container
            container.register_singleton(StreamManager, lambda: stream_manager)
            container.register_singleton(ModelManager, lambda: model_manager)
            container.register_singleton(PipelineExecutor, lambda: pipeline_executor)

            try:
                # 1. Create RTSP stream
                stream_config = StreamConfig(
                    stream_url="rtsp://example.com/stream",
                    stream_type="rtsp",
                    target_fps=10
                )

                stream_info = await stream_manager.create_stream(
                    camera_id="camera_001",
                    config=stream_config,
                    subscription_id="sub_001"
                )

                # 2. Load pipeline and models
                pipeline_config = json.loads(Path(sample_mpta_file).read_text())

                # Mock model file paths exist
                with patch('os.path.exists', return_value=True):
                    await model_manager.load_models_from_config(pipeline_config)

                # 3. Get frame from stream
                frame = stream_manager.get_latest_frame("camera_001")
                assert frame is not None

                # 4. Run detection pipeline
                detection_context = {
                    "camera_id": "camera_001",
                    "display_id": "display_001",
                    "frame": mock_frame,
                    "timestamp": int(time.time() * 1000),
                    "session_id": str(uuid.uuid4())
                }

                # Mock cv2.imencode for Redis image storage
                with patch('cv2.imencode') as mock_imencode:
                    encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
                    mock_imencode.return_value = (True, encoded_data)

                    pipeline_result = await pipeline_executor.execute_pipeline(
                        pipeline_config,
                        detection_context
                    )

                # 5. Verify pipeline execution
                assert pipeline_result is not None
                assert pipeline_result.get("status") == "completed"

                # 6. Verify database operations
                # Should have called create record
                create_calls = [call for call in mock_cursor.execute.call_args_list
                                if "INSERT" in str(call)]
                assert len(create_calls) >= 1

                # Should have called update with classification results
                update_calls = [call for call in mock_cursor.execute.call_args_list
                                if "UPDATE" in str(call)]
                assert len(update_calls) >= 1

                # 7. Verify Redis operations
                # Should have saved cropped image
                assert mock_redis_instance.set.called
                assert mock_redis_instance.expire.called

                # 8. Verify final results
                assert "detections" in pipeline_result
                detections = pipeline_result["detections"]
                assert len(detections) >= 2  # Car and Frontal

                # Check detection classes
                detected_classes = [d.get("class") for d in detections]
                assert "Car" in detected_classes
                assert "Frontal" in detected_classes

                # 9. Verify classification results in context
                classification_results = pipeline_result.get("classification_results", {})
                assert "car_brand_cls_v1" in classification_results
                assert "car_bodytype_cls_v1" in classification_results

                brand_result = classification_results["car_brand_cls_v1"]
                assert brand_result.get("brand") == "Toyota"
                assert brand_result.get("confidence") == 0.87

                bodytype_result = classification_results["car_bodytype_cls_v1"]
                assert bodytype_result.get("body_type") == "Sedan"
                assert bodytype_result.get("confidence") == 0.82

            finally:
                # Cleanup
                await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_websocket_subscription_workflow(self, temp_config_file, sample_mpta_file):
        """Test complete WebSocket subscription workflow."""

        # Mock WebSocket for testing
        mock_websocket = Mock()
        mock_websocket.accept = AsyncMock()
        mock_websocket.send_json = AsyncMock()
        mock_websocket.receive_json = AsyncMock()

        # Initialize configuration
        config = Configuration()
        config.load_from_file(temp_config_file)

        # Initialize message processor and WebSocket handler
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('torch.load') as mock_torch_load, \
             patch('psycopg2.connect') as mock_db_connect, \
             patch('redis.Redis') as mock_redis:

            # Setup mocks
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))

            mock_torch_load.return_value = Mock()
            mock_db_connect.return_value = Mock()
            mock_redis.return_value = Mock()

            # Simulate WebSocket message sequence
            subscription_message = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                    "rtspUrl": "rtsp://example.com/stream",
                    "modelUrl": f"file://{sample_mpta_file}",
                    "modelId": 101,
                    "modelName": "Vehicle Detection",
                    "cropX1": 100, "cropY1": 200,
                    "cropX2": 300, "cropY2": 400
                }
            }

            mock_websocket.receive_json.side_effect = [
                subscription_message,
                {"type": "requestState"},
                {"type": "unsubscribe", "payload": {"subscriptionIdentifier": "display-001;cam-001"}}
            ]

            # Mock file operations
            with patch('builtins.open', mock_open(read_data=json.dumps(json.loads(Path(sample_mpta_file).read_text())))):
                with patch('os.path.exists', return_value=True):

                    try:
                        # Start WebSocket handler (will process messages)
                        await websocket_handler.handle_websocket(mock_websocket, "client_001")

                        # Verify WebSocket interactions
                        mock_websocket.accept.assert_called_once()

                        # Should have sent subscription acknowledgment
                        send_calls = mock_websocket.send_json.call_args_list
                        assert len(send_calls) >= 1

                        # Check subscription acknowledgment
                        first_response = send_calls[0][0][0]
                        assert first_response.get("type") == "subscribeAck"
                        assert first_response.get("status") == "success"

                        # Should have sent state report
                        state_responses = [call[0][0] for call in send_calls
                                           if call[0][0].get("type") == "stateReport"]
                        assert len(state_responses) >= 1

                    except Exception as e:
                        # WebSocket disconnect is expected at end of message sequence
                        pass

    @pytest.mark.asyncio
    async def test_http_snapshot_workflow(self, temp_config_file, mock_frame):
        """Test HTTP snapshot workflow."""

        config = Configuration()
        config.load_from_file(temp_config_file)

        stream_manager = StreamManager()

        with patch('requests.get') as mock_requests:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.content = b"fake_jpeg_data"
            mock_requests.return_value = mock_response

            with patch('cv2.imdecode') as mock_imdecode:
                mock_imdecode.return_value = mock_frame

                try:
                    # Create HTTP snapshot stream
                    stream_config = StreamConfig(
                        stream_url="http://camera.example.com/snapshot.jpg",
                        stream_type="http_snapshot",
                        snapshot_interval=1.0
                    )

                    stream_info = await stream_manager.create_stream(
                        camera_id="camera_002",
                        config=stream_config,
                        subscription_id="sub_002"
                    )

                    # Wait for snapshot capture
                    await asyncio.sleep(1.2)

                    # Verify frame was captured
                    frame = stream_manager.get_latest_frame("camera_002")
                    assert frame is not None
                    assert np.array_equal(frame, mock_frame)

                    # Verify HTTP request was made
                    mock_requests.assert_called()
                    mock_imdecode.assert_called()

                finally:
                    await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_error_recovery_workflow(self, temp_config_file):
        """Test error recovery and resilience."""

        config = Configuration()
        config.load_from_file(temp_config_file)

        stream_manager = StreamManager()

        with patch('cv2.VideoCapture') as mock_video_cap:
            # Simulate connection failures then success
            mock_cap_instances = []

            # First attempt fails
            mock_cap_fail = Mock()
            mock_cap_fail.isOpened.return_value = False
            mock_cap_instances.append(mock_cap_fail)

            # Second attempt succeeds
            mock_cap_success = Mock()
            mock_cap_success.isOpened.return_value = True
            mock_cap_success.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))
            mock_cap_instances.append(mock_cap_success)

            mock_video_cap.side_effect = mock_cap_instances

            try:
                stream_config = StreamConfig(
                    stream_url="rtsp://unreliable.example.com/stream",
                    stream_type="rtsp",
                    max_retries=2
                )

                # First attempt should fail
                try:
                    await stream_manager.create_stream(
                        camera_id="camera_003",
                        config=stream_config,
                        subscription_id="sub_003"
                    )
                    assert False, "Should have failed on first attempt"
                except Exception:
                    pass

                # Retry should succeed
                stream_info = await stream_manager.create_stream(
                    camera_id="camera_003",
                    config=stream_config,
                    subscription_id="sub_003"
                )

                assert stream_info is not None
                assert mock_video_cap.call_count == 2

            finally:
                await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_concurrent_streams_workflow(self, temp_config_file, mock_frame):
        """Test handling multiple concurrent streams."""

        config = Configuration()
        config.load_from_file(temp_config_file)

        stream_manager = StreamManager()

        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('requests.get') as mock_requests:

            # Setup RTSP mock
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, mock_frame)

            # Setup HTTP mock
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.content = b"fake_jpeg_data"
            mock_requests.return_value = mock_response

            with patch('cv2.imdecode', return_value=mock_frame):

                try:
                    # Create multiple streams concurrently
                    stream_tasks = []

                    # RTSP streams
                    for i in range(3):
                        config = StreamConfig(
                            stream_url=f"rtsp://camera{i}.example.com/stream",
                            stream_type="rtsp"
                        )
                        task = stream_manager.create_stream(
                            camera_id=f"camera_rtsp_{i}",
                            config=config,
                            subscription_id=f"sub_rtsp_{i}"
                        )
                        stream_tasks.append(task)

                    # HTTP snapshot streams
                    for i in range(2):
                        config = StreamConfig(
                            stream_url=f"http://camera{i}.example.com/snapshot.jpg",
                            stream_type="http_snapshot"
                        )
                        task = stream_manager.create_stream(
                            camera_id=f"camera_http_{i}",
                            config=config,
                            subscription_id=f"sub_http_{i}"
                        )
                        stream_tasks.append(task)

                    # Wait for all streams to be created
                    stream_results = await asyncio.gather(*stream_tasks, return_exceptions=True)

                    # Verify all streams were created successfully
                    successful_streams = [r for r in stream_results if not isinstance(r, Exception)]
                    assert len(successful_streams) == 5

                    # Verify frames can be retrieved from all streams
                    await asyncio.sleep(0.5)  # Allow time for frame capture

                    for i in range(3):
                        frame = stream_manager.get_latest_frame(f"camera_rtsp_{i}")
                        assert frame is not None

                    for i in range(2):
                        frame = stream_manager.get_latest_frame(f"camera_http_{i}")
                        # HTTP snapshots might not have frames immediately

                    # Verify stream statistics
                    stats = stream_manager.get_stream_statistics()
                    assert stats["total_streams"] == 5
                    assert stats["active_streams"] >= 3  # At least RTSP streams should be active

                finally:
                    await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_memory_usage_workflow(self, temp_config_file):
        """Test memory usage tracking and cleanup."""

        config = Configuration()
        config.load_from_file(temp_config_file)

        # Create managers with small limits for testing
        stream_manager = StreamManager({"max_streams": 10})
        model_manager = ModelManager({"cache_max_size": 5})

        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('torch.load') as mock_torch_load:

            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, np.ones((100, 100, 3), dtype=np.uint8))

            # Mock model loading
            def create_mock_model():
                model = Mock()
                # Mock model parameters for memory estimation
                param1 = Mock()
                param1.numel.return_value = 1000
                param1.element_size.return_value = 4
                model.parameters.return_value = [param1]
                return model

            mock_torch_load.side_effect = lambda *args, **kwargs: create_mock_model()

            try:
                # Create streams up to limit
                for i in range(8):
                    stream_config = StreamConfig(
                        stream_url=f"rtsp://test{i}.example.com/stream",
                        stream_type="rtsp"
                    )
                    await stream_manager.create_stream(
                        camera_id=f"test_camera_{i}",
                        config=stream_config,
                        subscription_id=f"test_sub_{i}"
                    )

                # Load models up to cache limit
                with patch('os.path.exists', return_value=True):
                    for i in range(7):
                        config = {
                            "model_id": f"test_model_{i}",
                            "model_path": f"/fake/path/model_{i}.pt",
                            "model_type": "detection",
                            "device": "cpu"
                        }
                        await model_manager.load_model_from_dict(config)

                # Check memory usage tracking
                stream_stats = stream_manager.get_stream_statistics()
                model_stats = model_manager.get_cache_statistics()

                assert stream_stats["total_streams"] == 8
                assert model_stats["size"] <= 5  # Should be limited by cache size

                # Test cleanup
                cleaned_models = model_manager.cleanup_unused_models()
                assert cleaned_models >= 0

                stopped_streams = await stream_manager.stop_all_streams()
                assert stopped_streams == 8

                # Verify cleanup
                final_stream_stats = stream_manager.get_stream_statistics()
                assert final_stream_stats["total_streams"] == 0

            finally:
                await stream_manager.stop_all_streams()


def mock_open(read_data=""):
    """Create a mock file opener."""
    from unittest.mock import mock_open as _mock_open
    return _mock_open(read_data=read_data)
738  tests/integration/test_pipeline_integration.py  Normal file

@@ -0,0 +1,738 @@

"""
Integration tests for pipeline execution workflows.

Tests the complete machine learning pipeline execution including
detection, classification, database updates, and Redis actions.
"""
import pytest
import asyncio
import json
import tempfile
import uuid
import time
from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock
import numpy as np

from detector_worker.pipeline.pipeline_executor import PipelineExecutor
from detector_worker.pipeline.action_executor import ActionExecutor
from detector_worker.pipeline.field_mapper import FieldMapper
from detector_worker.models.model_manager import ModelManager
from detector_worker.storage.database_manager import DatabaseManager
from detector_worker.storage.redis_client import RedisClient, RedisConfig
from detector_worker.detection.detection_result import DetectionResult, BoundingBox


@pytest.fixture
def sample_detection_pipeline():
    """Create sample detection pipeline configuration."""
    return {
        "modelId": "car_frontal_detection_v1",
        "modelFile": "car_frontal_detection_v1.pt",
        "multiClass": True,
        "expectedClasses": ["Car", "Frontal"],
        "triggerClasses": ["Car", "Frontal"],
        "minConfidence": 0.8,
        "actions": [
            {
                "type": "redis_save_image",
                "region": "Frontal",
                "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
                "expire_seconds": 600
            },
            {
                "type": "postgresql_create_record",
                "table": "car_frontal_info",
                "fields": {
                    "display_id": "{display_id}",
                    "captured_timestamp": "{timestamp}",
                    "session_id": "{session_id}",
                    "license_character": None,
                    "license_type": "No model available"
                }
            }
        ],
        "branches": [
            {
                "modelId": "car_brand_cls_v1",
                "modelFile": "car_brand_cls_v1.pt",
                "parallel": True,
                "crop": True,
                "cropClass": "Frontal",
                "triggerClasses": ["Frontal"],
                "minConfidence": 0.85
            },
            {
                "modelId": "car_bodytype_cls_v1",
                "modelFile": "car_bodytype_cls_v1.pt",
                "parallel": True,
                "crop": True,
                "cropClass": "Frontal",
                "triggerClasses": ["Frontal"],
                "minConfidence": 0.80
            }
        ],
        "parallelActions": [
            {
                "type": "postgresql_update_combined",
                "table": "car_frontal_info",
                "key_field": "session_id",
                "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
                "fields": {
                    "car_brand": "{car_brand_cls_v1.brand}",
                    "car_body_type": "{car_bodytype_cls_v1.body_type}"
                }
            }
        ]
    }


@pytest.fixture
def sample_frame():
    """Create sample frame for testing."""
    return np.ones((480, 640, 3), dtype=np.uint8) * 128


@pytest.fixture
def detection_context():
    """Create sample detection context."""
    return {
        "camera_id": "camera_001",
        "display_id": "display_001",
        "timestamp": int(time.time() * 1000),
        "session_id": str(uuid.uuid4()),
        "frame": np.ones((480, 640, 3), dtype=np.uint8) * 128,
        "filename": "detection_image.jpg"
    }


class TestPipelineIntegration:
    """Test complete pipeline integration workflows."""

    @pytest.mark.asyncio
    async def test_complete_detection_classification_pipeline(self, sample_detection_pipeline, detection_context):
        """Test complete detection to classification pipeline."""

        pipeline_executor = PipelineExecutor()
        model_manager = ModelManager()

        with patch('torch.load') as mock_torch_load, \
             patch('os.path.exists', return_value=True), \
             patch('psycopg2.connect') as mock_db_connect, \
             patch('redis.Redis') as mock_redis:

            # Setup detection model mock
            mock_detection_model = Mock()
            mock_detection_result = Mock()

            # Mock successful multi-class detection
            mock_detection_result.boxes = Mock()
            mock_detection_result.boxes.xyxy = Mock()
            mock_detection_result.boxes.conf = Mock()
            mock_detection_result.boxes.cls = Mock()
            mock_detection_result.names = {0: "Car", 1: "Frontal"}

            # Detection results: Car and Frontal detected with high confidence
            mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [50, 100, 350, 450],   # Car bbox
                [150, 200, 300, 400]   # Frontal bbox (within Car)
            ])
            mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
            mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])

            mock_detection_model.return_value = mock_detection_result

            # Setup classification models
            mock_brand_model = Mock()
            mock_brand_result = Mock()
            mock_brand_result.probs = Mock()
            mock_brand_result.probs.top1 = 3  # Toyota index
            mock_brand_result.probs.top1conf = Mock()
            mock_brand_result.probs.top1conf.item.return_value = 0.87
            mock_brand_result.names = {3: "Toyota"}
            mock_brand_model.return_value = mock_brand_result

            mock_bodytype_model = Mock()
            mock_bodytype_result = Mock()
            mock_bodytype_result.probs = Mock()
            mock_bodytype_result.probs.top1 = 1  # Sedan index
            mock_bodytype_result.probs.top1conf = Mock()
            mock_bodytype_result.probs.top1conf.item.return_value = 0.82
            mock_bodytype_result.names = {1: "Sedan"}
            mock_bodytype_model.return_value = mock_bodytype_result

            # Route model loading to appropriate mocks
            def model_loader(path, **kwargs):
                if "detection" in path:
                    return mock_detection_model
                elif "brand" in path:
                    return mock_brand_model
                elif "bodytype" in path:
                    return mock_bodytype_model
                return Mock()

            mock_torch_load.side_effect = model_loader

            # Setup database mock
            mock_db_conn = Mock()
            mock_db_connect.return_value = mock_db_conn
            mock_cursor = Mock()
            mock_db_conn.cursor.return_value = mock_cursor
            mock_cursor.fetchone.return_value = None

            # Setup Redis mock
            mock_redis_instance = Mock()
            mock_redis.return_value = mock_redis_instance
            mock_redis_instance.ping.return_value = True
            mock_redis_instance.set.return_value = True
            mock_redis_instance.expire.return_value = True

            # Mock image encoding for Redis storage
            with patch('cv2.imencode') as mock_imencode:
                encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
                mock_imencode.return_value = (True, encoded_data)

                # Execute complete pipeline
                result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)

                # Verify pipeline execution
                assert result is not None
                assert result.get("status") == "completed"
                assert "detections" in result

                # Verify detection results
                detections = result["detections"]
                assert len(detections) == 2  # Car and Frontal

                detection_classes = [d.get("class") for d in detections]
                assert "Car" in detection_classes
                assert "Frontal" in detection_classes

                # Verify classification results
                assert "classification_results" in result
                classification_results = result["classification_results"]

                assert "car_brand_cls_v1" in classification_results
                brand_result = classification_results["car_brand_cls_v1"]
                assert brand_result.get("brand") == "Toyota"
                assert brand_result.get("confidence") == 0.87

                assert "car_bodytype_cls_v1" in classification_results
                bodytype_result = classification_results["car_bodytype_cls_v1"]
                assert bodytype_result.get("body_type") == "Sedan"
                assert bodytype_result.get("confidence") == 0.82

                # Verify database operations
                db_calls = mock_cursor.execute.call_args_list

                # Should have INSERT for initial record creation
                insert_calls = [call for call in db_calls if "INSERT" in str(call[0])]
                assert len(insert_calls) >= 1

                # Should have UPDATE for classification results
                update_calls = [call for call in db_calls if "UPDATE" in str(call[0])]
                assert len(update_calls) >= 1

                # Verify Redis operations
                assert mock_redis_instance.set.called
                assert mock_redis_instance.expire.called

    @pytest.mark.asyncio
    async def test_pipeline_with_missing_detections(self, sample_detection_pipeline, detection_context):
        """Test pipeline behavior when expected detections are missing."""

        pipeline_executor = PipelineExecutor()

        with patch('torch.load') as mock_torch_load, \
             patch('os.path.exists', return_value=True):

            # Setup detection model that doesn't find expected classes
            mock_detection_model = Mock()
            mock_detection_result = Mock()
            mock_detection_result.boxes = Mock()
            mock_detection_result.boxes.xyxy = Mock()
            mock_detection_result.boxes.conf = Mock()
            mock_detection_result.boxes.cls = Mock()
            mock_detection_result.names = {0: "Car", 1: "Frontal"}

            # Only detect Car, no Frontal
            mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [50, 100, 350, 450]  # Only Car bbox
            ])
            mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92])
            mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0])

            mock_detection_model.return_value = mock_detection_result
            mock_torch_load.return_value = mock_detection_model

            # Execute pipeline
            result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)

            # Pipeline should complete but skip classification branches
            assert result is not None
            assert "detections" in result

            detections = result["detections"]
            assert len(detections) == 1  # Only Car detected
            assert detections[0].get("class") == "Car"

            # Classification should not have run (no Frontal detected)
            classification_results = result.get("classification_results", {})
            assert len(classification_results) == 0 or all(
                not res for res in classification_results.values()
            )

    @pytest.mark.asyncio
    async def test_pipeline_with_low_confidence_detections(self, sample_detection_pipeline, detection_context):
        """Test pipeline with detections below confidence threshold."""

        pipeline_executor = PipelineExecutor()

        with patch('torch.load') as mock_torch_load, \
             patch('os.path.exists', return_value=True):

            mock_detection_model = Mock()
            mock_detection_result = Mock()
            mock_detection_result.boxes = Mock()
            mock_detection_result.boxes.xyxy = Mock()
            mock_detection_result.boxes.conf = Mock()
            mock_detection_result.boxes.cls = Mock()
            mock_detection_result.names = {0: "Car", 1: "Frontal"}

            # Detections with low confidence (below 0.8 threshold)
            mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [50, 100, 350, 450],   # Car bbox
                [150, 200, 300, 400]   # Frontal bbox
            ])
            mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.75, 0.70])  # Below threshold
            mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])

            mock_detection_model.return_value = mock_detection_result
            mock_torch_load.return_value = mock_detection_model

            # Execute pipeline
            result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)

            # Should complete but with filtered detections
            assert result is not None

            # Low confidence detections should be filtered out
            detections = result.get("detections", [])
            high_conf_detections = [d for d in detections if d.get("confidence", 0) >= 0.8]
            assert len(high_conf_detections) == 0

    @pytest.mark.asyncio
    async def test_pipeline_branch_execution_order(self, sample_detection_pipeline, detection_context):
        """Test that pipeline branches execute in correct order and parallel mode works."""

        pipeline_executor = PipelineExecutor()

        with patch('torch.load') as mock_torch_load, \
             patch('os.path.exists', return_value=True), \
             patch('psycopg2.connect') as mock_db_connect:

            # Track execution order
            execution_order = []

            # Setup detection model
            mock_detection_model = Mock()
            mock_detection_result = Mock()
            mock_detection_result.boxes = Mock()
            mock_detection_result.boxes.xyxy = Mock()
            mock_detection_result.boxes.conf = Mock()
            mock_detection_result.boxes.cls = Mock()
            mock_detection_result.names = {0: "Car", 1: "Frontal"}

            mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [50, 100, 350, 450], [150, 200, 300, 400]
            ])
            mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
            mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])

            def track_detection_execution(*args, **kwargs):
                execution_order.append("detection")
                return mock_detection_result

            mock_detection_model.side_effect = track_detection_execution

            # Setup classification models with execution tracking
            def create_tracked_model(model_id):
                def track_execution(*args, **kwargs):
                    execution_order.append(model_id)
                    result = Mock()
                    result.probs = Mock()
                    result.probs.top1 = 0
                    result.probs.top1conf = Mock()
                    result.probs.top1conf.item.return_value = 0.90
                    result.names = {0: "TestResult"}
                    return result

                model = Mock()
                model.side_effect = track_execution
                return model

            # Route models with execution tracking
            def model_loader(path, **kwargs):
                if "detection" in path:
                    return mock_detection_model
                elif "brand" in path:
                    return create_tracked_model("car_brand_cls_v1")
                elif "bodytype" in path:
                    return create_tracked_model("car_bodytype_cls_v1")
                return Mock()

            mock_torch_load.side_effect = model_loader

            # Setup database mock
            mock_db_conn = Mock()
            mock_db_connect.return_value = mock_db_conn
            mock_cursor = Mock()
            mock_db_conn.cursor.return_value = mock_cursor

            # Execute pipeline
            result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)

            # Verify execution order
            assert "detection" in execution_order
            assert execution_order[0] == "detection"  # Detection should run first

            # Classification models should run after detection
            brand_index = execution_order.index("car_brand_cls_v1") if "car_brand_cls_v1" in execution_order else -1
            bodytype_index = execution_order.index("car_bodytype_cls_v1") if "car_bodytype_cls_v1" in execution_order else -1
            detection_index = execution_order.index("detection")

            if brand_index >= 0:
                assert brand_index > detection_index
            if bodytype_index >= 0:
                assert bodytype_index > detection_index

            # Since branches are parallel, they could run in any order relative to each other
            # but both should run after detection

    @pytest.mark.asyncio
    async def test_pipeline_error_recovery(self, sample_detection_pipeline, detection_context):
        """Test pipeline error handling and recovery."""

        pipeline_executor = PipelineExecutor()

        with patch('torch.load') as mock_torch_load, \
             patch('os.path.exists', return_value=True), \
             patch('psycopg2.connect') as mock_db_connect:

            # Setup detection model that works
            mock_detection_model = Mock()
            mock_detection_result = Mock()
            mock_detection_result.boxes = Mock()
            mock_detection_result.boxes.xyxy = Mock()
            mock_detection_result.boxes.conf = Mock()
            mock_detection_result.boxes.cls = Mock()
            mock_detection_result.names = {0: "Car", 1: "Frontal"}

            mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [50, 100, 350, 450], [150, 200, 300, 400]
            ])
            mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
            mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])

            mock_detection_model.return_value = mock_detection_result

            # Setup classification models - one fails, one succeeds
            mock_brand_model = Mock()
            mock_brand_model.side_effect = RuntimeError("Model inference failed")

            mock_bodytype_model = Mock()
            mock_bodytype_result = Mock()
            mock_bodytype_result.probs = Mock()
            mock_bodytype_result.probs.top1 = 1
            mock_bodytype_result.probs.top1conf = Mock()
            mock_bodytype_result.probs.top1conf.item.return_value = 0.85
            mock_bodytype_result.names = {1: "SUV"}
            mock_bodytype_model.return_value = mock_bodytype_result

            def model_loader(path, **kwargs):
                if "detection" in path:
                    return mock_detection_model
                elif "brand" in path:
                    return mock_brand_model
                elif "bodytype" in path:
                    return mock_bodytype_model
                return Mock()

            mock_torch_load.side_effect = model_loader

            # Setup database mock
            mock_db_conn = Mock()
            mock_db_connect.return_value = mock_db_conn
            mock_cursor = Mock()
            mock_db_conn.cursor.return_value = mock_cursor

            # Execute pipeline
            result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)

            # Pipeline should complete despite one branch failing
            assert result is not None

            # Detection should succeed
            assert "detections" in result
            detections = result["detections"]
            assert len(detections) == 2

            # Classification results should be partial
            classification_results = result.get("classification_results", {})

            # Brand classification should have failed
            brand_result = classification_results.get("car_brand_cls_v1")
            assert brand_result is None or brand_result.get("error") is not None

            # Body type classification should have succeeded
            bodytype_result = classification_results.get("car_bodytype_cls_v1")
            assert bodytype_result is not None
            assert bodytype_result.get("body_type") == "SUV"
            assert bodytype_result.get("confidence") == 0.85

    @pytest.mark.asyncio
    async def test_field_mapping_and_database_update(self, sample_detection_pipeline, detection_context):
        """Test field mapping and database update integration."""

        pipeline_executor = PipelineExecutor()
        field_mapper = FieldMapper()

        with patch('torch.load') as mock_torch_load, \
             patch('os.path.exists', return_value=True), \
             patch('psycopg2.connect') as mock_db_connect:

            # Setup successful detection and classification
            mock_detection_model = Mock()
            mock_detection_result = Mock()
            mock_detection_result.boxes = Mock()
            mock_detection_result.boxes.xyxy = Mock()
            mock_detection_result.boxes.conf = Mock()
            mock_detection_result.boxes.cls = Mock()
            mock_detection_result.names = {0: "Car", 1: "Frontal"}

            mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [50, 100, 350, 450], [150, 200, 300, 400]
            ])
            mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
            mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
            mock_detection_model.return_value = mock_detection_result

            # Setup classification models
            mock_brand_model = Mock()
            mock_brand_result = Mock()
            mock_brand_result.probs = Mock()
            mock_brand_result.probs.top1 = 2
            mock_brand_result.probs.top1conf = Mock()
            mock_brand_result.probs.top1conf.item.return_value = 0.88
            mock_brand_result.names = {2: "Honda"}
            mock_brand_model.return_value = mock_brand_result

            mock_bodytype_model = Mock()
            mock_bodytype_result = Mock()
            mock_bodytype_result.probs = Mock()
            mock_bodytype_result.probs.top1 = 0
            mock_bodytype_result.probs.top1conf = Mock()
            mock_bodytype_result.probs.top1conf.item.return_value = 0.91
            mock_bodytype_result.names = {0: "Hatchback"}
            mock_bodytype_model.return_value = mock_bodytype_result

            def model_loader(path, **kwargs):
                if "detection" in path:
                    return mock_detection_model
                elif "brand" in path:
                    return mock_brand_model
                elif "bodytype" in path:
                    return mock_bodytype_model
                return Mock()

            mock_torch_load.side_effect = model_loader

            # Setup database mock
            mock_db_conn = Mock()
            mock_db_connect.return_value = mock_db_conn
            mock_cursor = Mock()
            mock_db_conn.cursor.return_value = mock_cursor

            # Execute pipeline
            result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)

            # Verify pipeline completed successfully
            assert result is not None
            assert result.get("status") == "completed"

            # Check database operations
            db_calls = mock_cursor.execute.call_args_list

            # Should have INSERT and UPDATE operations
            insert_calls = [call for call in db_calls if "INSERT" in str(call[0])]
            update_calls = [call for call in db_calls if "UPDATE" in str(call[0])]

            assert len(insert_calls) >= 1
            assert len(update_calls) >= 1

            # Check that UPDATE includes field mapping results
            update_sql = str(update_calls[0][0])
            assert "car_brand" in update_sql.lower()
            assert "car_body_type" in update_sql.lower()

            # Check that classification results were properly mapped
            classification_results = result.get("classification_results", {})
            assert "car_brand_cls_v1" in classification_results
            assert "car_bodytype_cls_v1" in classification_results

            brand_result = classification_results["car_brand_cls_v1"]
            bodytype_result = classification_results["car_bodytype_cls_v1"]

            assert brand_result.get("brand") == "Honda"
            assert brand_result.get("confidence") == 0.88
            assert bodytype_result.get("body_type") == "Hatchback"
            assert bodytype_result.get("confidence") == 0.91

    @pytest.mark.asyncio
    async def test_redis_image_storage_integration(self, sample_detection_pipeline, detection_context):
        """Test Redis image storage integration in pipeline."""

        pipeline_executor = PipelineExecutor()

        with patch('torch.load') as mock_torch_load, \
             patch('os.path.exists', return_value=True), \
             patch('redis.Redis') as mock_redis, \
             patch('cv2.imencode') as mock_imencode:

            # Setup successful detection
            mock_detection_model = Mock()
            mock_detection_result = Mock()
            mock_detection_result.boxes = Mock()
            mock_detection_result.boxes.xyxy = Mock()
            mock_detection_result.boxes.conf = Mock()
            mock_detection_result.boxes.cls = Mock()
            mock_detection_result.names = {0: "Car", 1: "Frontal"}

            mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [50, 100, 350, 450], [150, 200, 300, 400]
            ])
            mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
            mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
            mock_detection_model.return_value = mock_detection_result

            mock_torch_load.return_value = mock_detection_model

            # Setup Redis mock
            mock_redis_instance = Mock()
            mock_redis.return_value = mock_redis_instance
            mock_redis_instance.ping.return_value = True
            mock_redis_instance.set.return_value = True
            mock_redis_instance.expire.return_value = True

            # Setup image encoding mock
            encoded_data = np.array([1, 2, 3, 4, 5], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            # Execute pipeline
            result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)

            # Verify Redis operations
            assert mock_redis_instance.set.called
            assert mock_redis_instance.expire.called

            # Check that image was encoded
            assert mock_imencode.called

            # Verify correct key format was used
            set_call = mock_redis_instance.set.call_args
            redis_key = set_call[0][0]

            # Key should contain display_id, timestamp, session_id
            assert detection_context["display_id"] in redis_key
            assert detection_context["session_id"] in redis_key
            assert str(detection_context["timestamp"]) in redis_key

            # Should set expiration
            expire_call = mock_redis_instance.expire.call_args
            expire_key = expire_call[0][0]
            expire_seconds = expire_call[0][1]

            assert expire_key == redis_key
            assert expire_seconds == 600  # As configured in pipeline

    @pytest.mark.asyncio
    async def test_pipeline_performance_timing(self, sample_detection_pipeline, detection_context):
        """Test pipeline execution timing and performance."""

        pipeline_executor = PipelineExecutor()

        with patch('torch.load') as mock_torch_load, \
             patch('os.path.exists', return_value=True), \
             patch('psycopg2.connect') as mock_db_connect, \
             patch('redis.Redis') as mock_redis, \
             patch('cv2.imencode') as mock_imencode:

            # Setup fast mocks
            mock_detection_model = Mock()
            mock_detection_result = Mock()
            mock_detection_result.boxes = Mock()
            mock_detection_result.boxes.xyxy = Mock()
            mock_detection_result.boxes.conf = Mock()
            mock_detection_result.boxes.cls = Mock()
            mock_detection_result.names = {0: "Car", 1: "Frontal"}

            mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [50, 100, 350, 450], [150, 200, 300, 400]
            ])
            mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
            mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
            mock_detection_model.return_value = mock_detection_result

            # Setup fast classification models
            def create_fast_model():
                model = Mock()
                result = Mock()
                result.probs = Mock()
                result.probs.top1 = 0
                result.probs.top1conf = Mock()
                result.probs.top1conf.item.return_value = 0.90
                result.names = {0: "TestClass"}
                model.return_value = result
                return model

            def model_loader(path, **kwargs):
                if "detection" in path:
                    return mock_detection_model
                else:
                    return create_fast_model()

            mock_torch_load.side_effect = model_loader

            # Setup fast database and Redis
            mock_db_conn = Mock()
            mock_db_connect.return_value = mock_db_conn
            mock_cursor = Mock()
            mock_db_conn.cursor.return_value = mock_cursor

            mock_redis_instance = Mock()
            mock_redis.return_value = mock_redis_instance
            mock_redis_instance.ping.return_value = True
            mock_redis_instance.set.return_value = True
            mock_redis_instance.expire.return_value = True

            encoded_data = np.array([1, 2, 3], dtype=np.uint8)
            mock_imencode.return_value = (True, encoded_data)

            # Measure execution time
            start_time = time.time()

            result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context)

            end_time = time.time()
            execution_time = end_time - start_time

            # Pipeline should complete quickly (less than 1 second with mocks)
            assert execution_time < 1.0

            # Should have timing information in result
            assert result is not None
            if "execution_time" in result:
                assert result["execution_time"] > 0

            # Verify pipeline completed successfully
            assert result.get("status") == "completed"
579
tests/integration/test_websocket_protocol.py
Normal file
579
tests/integration/test_websocket_protocol.py
Normal file
|
@ -0,0 +1,579 @@
|
|||
"""
|
||||
Integration tests for WebSocket protocol compliance.
|
||||
|
||||
Tests the complete WebSocket communication protocol to ensure
|
||||
compatibility with existing clients and proper message handling.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from fastapi.websockets import WebSocket
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from detector_worker.app import create_app
|
||||
from detector_worker.communication.websocket_handler import WebSocketHandler
|
||||
from detector_worker.communication.message_processor import MessageProcessor, MessageType
|
||||
from detector_worker.core.exceptions import MessageProcessingError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app():
|
||||
"""Create test FastAPI application."""
|
||||
return create_app()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
"""Create mock WebSocket for testing."""
|
||||
websocket = Mock(spec=WebSocket)
|
||||
websocket.accept = AsyncMock()
|
||||
websocket.send_json = AsyncMock()
|
||||
websocket.send_text = AsyncMock()
|
||||
websocket.receive_json = AsyncMock()
|
||||
websocket.receive_text = AsyncMock()
|
||||
websocket.close = AsyncMock()
|
||||
websocket.ping = AsyncMock()
|
||||
return websocket
|
||||
|
||||
|
||||
class TestWebSocketProtocol:
    """Test WebSocket protocol compliance."""

    @pytest.mark.asyncio
    async def test_subscription_message_protocol(self, mock_websocket):
        """Test subscription message handling protocol."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Mock external dependencies
        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('torch.load') as mock_torch_load, \
             patch('builtins.open') as mock_file_open:

            # Setup video capture mock
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            # Setup model loading mock
            mock_torch_load.return_value = Mock()

            # Setup pipeline file mock
            pipeline_config = {
                "modelId": "test_detection_model",
                "modelFile": "test_model.pt",
                "expectedClasses": ["Car"],
                "minConfidence": 0.8
            }
            mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)

            # Test message sequence
            subscription_message = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                    "rtspUrl": "rtsp://example.com/stream",
                    "modelUrl": "http://example.com/model.mpta",
                    "modelId": 101,
                    "modelName": "Test Detection",
                    "cropX1": 0,
                    "cropY1": 0,
                    "cropX2": 640,
                    "cropY2": 480
                }
            }

            request_state_message = {
                "type": "requestState"
            }

            unsubscribe_message = {
                "type": "unsubscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001"
                }
            }

            # Mock WebSocket message sequence
            mock_websocket.receive_json.side_effect = [
                subscription_message,
                request_state_message,
                unsubscribe_message,
                asyncio.CancelledError()  # Simulate client disconnect
            ]

            client_id = "test_client_001"

            try:
                # Handle WebSocket connection
                await websocket_handler.handle_websocket(mock_websocket, client_id)
            except asyncio.CancelledError:
                pass  # Expected when client disconnects

            # Verify protocol compliance
            mock_websocket.accept.assert_called_once()

            # Check sent messages
            sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]

            # Should receive subscription acknowledgment
            subscribe_acks = [msg for msg in sent_messages if msg.get("type") == "subscribeAck"]
            assert len(subscribe_acks) >= 1

            subscribe_ack = subscribe_acks[0]
            assert subscribe_ack["status"] in ["success", "error"]
            if subscribe_ack["status"] == "success":
                assert "subscriptionId" in subscribe_ack

            # Should receive state report
            state_reports = [msg for msg in sent_messages if msg.get("type") == "stateReport"]
            assert len(state_reports) >= 1

            state_report = state_reports[0]
            assert "payload" in state_report
            assert "subscriptions" in state_report["payload"]
            assert "system" in state_report["payload"]

            # Should receive unsubscribe acknowledgment
            unsubscribe_acks = [msg for msg in sent_messages if msg.get("type") == "unsubscribeAck"]
            assert len(unsubscribe_acks) >= 1

            unsubscribe_ack = unsubscribe_acks[0]
            assert unsubscribe_ack["status"] in ["success", "error"]

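    # Shape implied by the assertions above (illustrative summary, not a formal
    # schema):
    #   subscribeAck / unsubscribeAck: {"type": ..., "status": "success" | "error",
    #                                   "subscriptionId": ...}  # id present on success
    #   stateReport: {"type": "stateReport",
    #                 "payload": {"subscriptions": [...], "system": {...}}}
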
    @pytest.mark.asyncio
    async def test_invalid_message_handling(self, mock_websocket):
        """Test handling of invalid messages."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Test invalid message types
        invalid_messages = [
            {"invalid": "message"},  # Missing type
            {"type": "unknown_type", "payload": {}},  # Unknown type
            {"type": "subscribe"},  # Missing payload
            {"type": "subscribe", "payload": {}},  # Missing required fields
            {"type": "subscribe", "payload": {"subscriptionIdentifier": "test"}},  # Missing URL
        ]

        mock_websocket.receive_json.side_effect = invalid_messages + [asyncio.CancelledError()]

        client_id = "test_client_error"

        try:
            await websocket_handler.handle_websocket(mock_websocket, client_id)
        except asyncio.CancelledError:
            pass

        # Verify error responses
        sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]
        error_messages = [msg for msg in sent_messages if msg.get("type") == "error"]

        # Should receive error responses for invalid messages
        assert len(error_messages) >= len(invalid_messages)

        for error_msg in error_messages:
            assert "message" in error_msg
            assert error_msg["message"]  # Non-empty error message

    @pytest.mark.asyncio
    async def test_session_management_protocol(self, mock_websocket):
        """Test session management protocol."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        session_id = str(uuid.uuid4())

        # Test session messages
        set_session_message = {
            "type": "setSessionId",
            "payload": {
                "sessionId": session_id,
                "displayId": "display-001"
            }
        }

        patch_session_message = {
            "type": "patchSession",
            "payload": {
                "sessionId": session_id,
                "data": {
                    "car_brand": "Toyota",
                    "confidence": 0.92
                }
            }
        }

        mock_websocket.receive_json.side_effect = [
            set_session_message,
            patch_session_message,
            asyncio.CancelledError()
        ]

        client_id = "test_client_session"

        try:
            await websocket_handler.handle_websocket(mock_websocket, client_id)
        except asyncio.CancelledError:
            pass

        # Verify session responses
        sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]

        # Should receive acknowledgments for session operations
        set_session_acks = [msg for msg in sent_messages
                            if msg.get("type") == "setSessionIdAck"]
        assert len(set_session_acks) >= 1

        patch_session_acks = [msg for msg in sent_messages
                              if msg.get("type") == "patchSessionAck"]
        assert len(patch_session_acks) >= 1

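    # Minimal sketch of the session bookkeeping the two messages above imply,
    # assuming a plain dict keyed by sessionId (the real MessageProcessor may
    # store and validate this differently):
    #     sessions: dict[str, dict] = {}
    #     def set_session(session_id, display_id):
    #         sessions.setdefault(session_id, {})["displayId"] = display_id
    #     def patch_session(session_id, data):
    #         sessions.setdefault(session_id, {}).update(data)
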
    @pytest.mark.asyncio
    async def test_heartbeat_protocol(self, mock_websocket):
        """Test heartbeat/ping-pong protocol."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor, {"heartbeat_interval": 0.1})

        # Simulate long-running connection
        mock_websocket.receive_json.side_effect = [
            asyncio.TimeoutError(),  # No messages for a while
            asyncio.TimeoutError(),
            asyncio.CancelledError()  # Then disconnect
        ]

        client_id = "test_client_heartbeat"

        # Start heartbeat task
        heartbeat_task = asyncio.create_task(websocket_handler._heartbeat_loop())

        try:
            # Let heartbeat run briefly
            await asyncio.sleep(0.2)

            # Cancel heartbeat
            heartbeat_task.cancel()
            await heartbeat_task

        except asyncio.CancelledError:
            pass

        # Verify ping was sent
        assert mock_websocket.ping.called

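    # Sketch of the heartbeat behaviour this test relies on, assuming the handler
    # pings every connection once per heartbeat_interval (the actual
    # _heartbeat_loop implementation may differ in detail):
    #     async def _heartbeat_loop(self):
    #         while True:
    #             await asyncio.sleep(self.heartbeat_interval)
    #             for ws in self.connections.values():
    #                 await ws.ping()
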
    @pytest.mark.asyncio
    async def test_detection_result_protocol(self, mock_websocket):
        """Test detection result message protocol."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Mock a detection result being sent
        detection_result = {
            "type": "imageDetection",
            "payload": {
                "subscriptionId": "display-001;cam-001",
                "detections": [
                    {
                        "class": "Car",
                        "confidence": 0.95,
                        "bbox": [100, 200, 300, 400],
                        "trackId": 1001
                    }
                ],
                "timestamp": 1640995200000,
                "modelInfo": {
                    "modelId": 101,
                    "modelName": "Vehicle Detection"
                }
            }
        }

        # Send detection result to client
        await websocket_handler.send_to_client("test_client", detection_result)

        # Verify message was sent
        mock_websocket.send_json.assert_called_with(detection_result)

    @pytest.mark.asyncio
    async def test_error_recovery_protocol(self, mock_websocket):
        """Test error recovery and graceful degradation."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Simulate WebSocket errors
        mock_websocket.send_json.side_effect = [
            None,  # First message succeeds
            ConnectionError("Connection lost"),  # Second fails
            None  # Third succeeds after recovery
        ]

        # Try to send multiple messages
        messages = [
            {"type": "test", "message": "1"},
            {"type": "test", "message": "2"},
            {"type": "test", "message": "3"}
        ]

        results = []
        for msg in messages:
            try:
                await websocket_handler.send_to_client("test_client", msg)
                results.append("success")
            except Exception:
                results.append("error")

        # Should handle errors gracefully
        assert "error" in results
        # But should still be able to send other messages
        assert "success" in results

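    # Sketch of the send-side error handling this test expects: send_to_client
    # propagates a failed send to the caller without breaking later sends.
    # Illustrative only; "connections" and "_mark_disconnected" are hypothetical
    # names, not the handler's actual attributes:
    #     async def send_to_client(self, client_id, message):
    #         ws = self.connections.get(client_id)
    #         if ws is None:
    #             return
    #         try:
    #             await ws.send_json(message)
    #         except ConnectionError:
    #             self._mark_disconnected(client_id)
    #             raise
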
    @pytest.mark.asyncio
    async def test_concurrent_client_protocol(self, mock_websocket):
        """Test handling multiple concurrent clients."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Create multiple mock WebSocket connections
        mock_websockets = []
        for i in range(3):
            ws = Mock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            ws.receive_json = AsyncMock(side_effect=[asyncio.CancelledError()])
            mock_websockets.append(ws)

        # Handle multiple clients concurrently
        client_tasks = []
        for i, ws in enumerate(mock_websockets):
            task = asyncio.create_task(
                websocket_handler.handle_websocket(ws, f"client_{i}")
            )
            client_tasks.append(task)

        # Wait briefly then cancel all
        await asyncio.sleep(0.1)
        for task in client_tasks:
            task.cancel()

        try:
            await asyncio.gather(*client_tasks)
        except asyncio.CancelledError:
            pass

        # Verify all connections were accepted
        for ws in mock_websockets:
            ws.accept.assert_called_once()

    @pytest.mark.asyncio
    async def test_subscription_sharing_protocol(self, mock_websocket):
        """Test shared subscription protocol."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Mock multiple clients subscribing to same camera
        with patch('cv2.VideoCapture') as mock_video_cap:
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            # First client subscribes
            subscription_msg1 = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                    "rtspUrl": "rtsp://shared.example.com/stream",
                    "modelUrl": "http://example.com/model.mpta"
                }
            }

            # Second client subscribes to same camera
            subscription_msg2 = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-002;cam-001",
                    "rtspUrl": "rtsp://shared.example.com/stream",  # Same URL
                    "modelUrl": "http://example.com/model.mpta"
                }
            }

            # Mock file operations
            with patch('builtins.open') as mock_file_open:
                pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
                mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)

                # Process both subscriptions
                response1 = await message_processor.process_message(subscription_msg1, "client_1")
                response2 = await message_processor.process_message(subscription_msg2, "client_2")

                # Both should succeed and reference same underlying stream
                assert response1.get("status") == "success"
                assert response2.get("status") == "success"

                # Should only create one video capture instance (shared stream)
                assert mock_video_cap.call_count == 1

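    # Sketch of the stream sharing the final assertion relies on: captures are
    # keyed by RTSP URL so two subscriptions to the same URL reuse one
    # cv2.VideoCapture (illustrative; the real stream manager may also track
    # reference counts and per-subscription crops):
    #     streams: dict[str, "cv2.VideoCapture"] = {}
    #     def get_stream(rtsp_url):
    #         if rtsp_url not in streams:
    #             streams[rtsp_url] = cv2.VideoCapture(rtsp_url)
    #         return streams[rtsp_url]
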
    @pytest.mark.asyncio
    async def test_message_ordering_protocol(self, mock_websocket):
        """Test message ordering and sequencing."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Test sequence of related messages
        messages = [
            {"type": "subscribe", "payload": {"subscriptionIdentifier": "test", "rtspUrl": "rtsp://test.com"}},
            {"type": "setSessionId", "payload": {"sessionId": "session_123", "displayId": "display_001"}},
            {"type": "requestState"},
            {"type": "patchSession", "payload": {"sessionId": "session_123", "data": {"test": "data"}}},
            {"type": "unsubscribe", "payload": {"subscriptionIdentifier": "test"}}
        ]

        mock_websocket.receive_json.side_effect = messages + [asyncio.CancelledError()]

        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('builtins.open') as mock_file_open:

            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True

            pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]}
            mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config)

            client_id = "test_client_ordering"

            try:
                await websocket_handler.handle_websocket(mock_websocket, client_id)
            except asyncio.CancelledError:
                pass

            # Verify responses were sent in correct order
            sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list]

            # Should receive responses for each message type
            response_types = [msg.get("type") for msg in sent_messages]

            expected_types = ["subscribeAck", "setSessionIdAck", "stateReport", "patchSessionAck", "unsubscribeAck"]

            # Check that we got appropriate responses (order may vary slightly)
            for expected_type in expected_types:
                assert expected_type in response_types, f"Missing response type: {expected_type}"


class TestWebSocketPerformance:
    """Test WebSocket performance characteristics."""

    @pytest.mark.asyncio
    async def test_message_throughput(self, mock_websocket):
        """Test message processing throughput."""

        message_processor = MessageProcessor()

        # Prepare batch of simple messages
        state_request = {"type": "requestState"}

        start_time = time.time()

        # Process many messages quickly
        for _ in range(100):
            await message_processor.process_message(state_request, "test_client")

        end_time = time.time()
        processing_time = end_time - start_time

        # Should process messages quickly (less than 1 second for 100 messages)
        assert processing_time < 1.0

        # Calculate throughput
        throughput = 100 / processing_time
        assert throughput > 100  # Should handle > 100 messages/second

    @pytest.mark.asyncio
    async def test_concurrent_message_handling(self, mock_websocket):
        """Test concurrent message handling."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Create multiple mock clients
        num_clients = 10
        mock_websockets = []

        for i in range(num_clients):
            ws = Mock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            ws.receive_json = AsyncMock(side_effect=[
                {"type": "requestState"},
                asyncio.CancelledError()
            ])
            mock_websockets.append(ws)

        # Handle all clients concurrently
        client_tasks = []
        for i, ws in enumerate(mock_websockets):
            task = asyncio.create_task(
                websocket_handler.handle_websocket(ws, f"perf_client_{i}")
            )
            client_tasks.append(task)

        start_time = time.time()

        # Wait for all to complete
        try:
            await asyncio.gather(*client_tasks)
        except asyncio.CancelledError:
            pass

        end_time = time.time()
        total_time = end_time - start_time

        # Should handle all clients efficiently
        assert total_time < 2.0  # Should complete in less than 2 seconds

        # All clients should have been accepted
        for ws in mock_websockets:
            ws.accept.assert_called_once()

    @pytest.mark.asyncio
    async def test_memory_usage_stability(self, mock_websocket):
        """Test memory usage remains stable."""

        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        # Simulate many connection/disconnection cycles
        for cycle in range(10):
            # Create client
            client_id = f"memory_test_client_{cycle}"

            mock_websocket.receive_json.side_effect = [
                {"type": "requestState"},
                asyncio.CancelledError()
            ]

            try:
                await websocket_handler.handle_websocket(mock_websocket, client_id)
            except asyncio.CancelledError:
                pass

            # Reset mock for next cycle
            mock_websocket.reset_mock()
            mock_websocket.accept = AsyncMock()
            mock_websocket.send_json = AsyncMock()
            mock_websocket.receive_json = AsyncMock()

        # Connection manager should not accumulate stale connections
        stats = websocket_handler.get_connection_stats()
        assert stats["total_connections"] == 0  # All should be cleaned up
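# Running these tests: the async tests above use @pytest.mark.asyncio, which
# assumes the pytest-asyncio plugin is installed, e.g.:
#     pytest tests/integration/test_websocket_protocol.py -v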