python-detector-worker/tests/integration/test_complete_detection_workflow.py
"""
Integration tests for complete detection workflow.
This module tests the full end-to-end detection pipeline from stream
to database update, ensuring all components work together correctly.
"""

import asyncio
import json
import tempfile
import time
import uuid
from pathlib import Path
from unittest.mock import Mock, AsyncMock, patch, MagicMock

import cv2
import numpy as np
import pytest

from detector_worker.app import create_app
from detector_worker.core.config import Configuration
from detector_worker.core.dependency_injection import ServiceContainer
from detector_worker.streams.stream_manager import StreamManager, StreamConfig
from detector_worker.models.model_manager import ModelManager
from detector_worker.pipeline.pipeline_executor import PipelineExecutor
from detector_worker.storage.database_manager import DatabaseManager
from detector_worker.storage.redis_client import RedisClient, RedisConfig
from detector_worker.communication.websocket_handler import WebSocketHandler
from detector_worker.communication.message_processor import MessageProcessor


@pytest.fixture
def temp_config_file():
    """Create temporary configuration file."""
    config_data = {
        "poll_interval_ms": 100,
        "max_streams": 5,
        "target_fps": 10,
        "reconnect_interval_sec": 1,
        "max_retries": 2,
        "database": {
            "enabled": True,
            "host": "localhost",
            "port": 5432,
            "database": "test_gas_station_1",
            "user": "test_user",
            "password": "test_pass"
        },
        "redis": {
            "enabled": True,
            "host": "localhost",
            "port": 6379,
            "db": 0
        }
    }

    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json.dump(config_data, f)
        temp_path = f.name

    yield temp_path

    # Cleanup
    Path(temp_path).unlink(missing_ok=True)


@pytest.fixture
def mock_frame():
    """Create a mock frame for testing."""
    return np.ones((480, 640, 3), dtype=np.uint8) * 128


@pytest.fixture
def sample_mpta_file():
    """Create a sample MPTA (pipeline) file."""
    pipeline_config = {
        "modelId": "car_frontal_detection_v1",
        "modelFile": "car_frontal_detection_v1.pt",
        "multiClass": True,
        "expectedClasses": ["Car", "Frontal"],
        "triggerClasses": ["Car", "Frontal"],
        "minConfidence": 0.8,
        "actions": [
            {
                "type": "redis_save_image",
                "region": "Frontal",
                "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}",
                "expire_seconds": 600
            },
            {
                "type": "postgresql_create_record",
                "table": "car_frontal_info",
                "fields": {
                    "display_id": "{display_id}",
                    "captured_timestamp": "{timestamp}",
                    "session_id": "{session_id}",
                    "license_character": None,
                    "license_type": "No model available"
                }
            }
        ],
        "branches": [
            {
                "modelId": "car_brand_cls_v1",
                "modelFile": "car_brand_cls_v1.pt",
                "parallel": True,
                "crop": True,
                "cropClass": "Frontal",
                "triggerClasses": ["Frontal"],
                "minConfidence": 0.85
            },
            {
                "modelId": "car_bodytype_cls_v1",
                "modelFile": "car_bodytype_cls_v1.pt",
                "parallel": True,
                "crop": True,
                "cropClass": "Frontal",
                "triggerClasses": ["Frontal"],
                "minConfidence": 0.80
            }
        ],
        "parallelActions": [
            {
                "type": "postgresql_update_combined",
                "table": "car_frontal_info",
                "key_field": "session_id",
                "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
                "fields": {
                    "car_brand": "{car_brand_cls_v1.brand}",
                    "car_body_type": "{car_bodytype_cls_v1.body_type}"
                }
            }
        ]
    }

    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json.dump(pipeline_config, f)
        temp_path = f.name

    yield temp_path

    # Cleanup
    Path(temp_path).unlink(missing_ok=True)
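
# Note: the "key" template in the redis_save_image action above is expected to
# be expanded with str.format-style placeholders at runtime. A minimal
# illustration with hypothetical values:
#
#   "inference:{display_id}:{timestamp}:{session_id}:{filename}".format(
#       display_id="display_001",
#       timestamp=1700000000000,
#       session_id="abc-123",
#       filename="frontal.jpg",
#   )
#   -> "inference:display_001:1700000000000:abc-123:frontal.jpg"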


class TestCompleteDetectionWorkflow:
    """Test complete detection workflow from stream to database."""

    @pytest.mark.asyncio
    async def test_complete_rtsp_detection_workflow(self, temp_config_file, sample_mpta_file, mock_frame):
        """Test complete workflow: RTSP stream -> detection -> classification -> database."""
        # Initialize configuration
        config = Configuration()
        config.load_from_file(temp_config_file)

        # Create service container
        container = ServiceContainer()

        # Mock all external dependencies
        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('torch.load') as mock_torch_load, \
             patch('psycopg2.connect') as mock_db_connect, \
             patch('redis.Redis') as mock_redis:

            # Setup video capture mock
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, mock_frame)

            # Setup model loading mock
            mock_detection_model = Mock()
            mock_brand_model = Mock()
            mock_bodytype_model = Mock()

            def mock_model_load(path, **kwargs):
                if "detection" in path:
                    return mock_detection_model
                elif "brand" in path:
                    return mock_brand_model
                elif "bodytype" in path:
                    return mock_bodytype_model
                return Mock()

            mock_torch_load.side_effect = mock_model_load

            # Setup detection model predictions
            mock_detection_model.return_value = Mock()
            mock_detection_model.return_value.boxes = Mock()
            mock_detection_model.return_value.boxes.xyxy = Mock()
            mock_detection_model.return_value.boxes.conf = Mock()
            mock_detection_model.return_value.boxes.cls = Mock()
            mock_detection_model.return_value.names = {0: "Car", 1: "Frontal"}

            # Mock detection results - Car and Frontal detected
            mock_detection_model.return_value.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([
                [100, 200, 300, 400],  # Car bbox
                [150, 250, 250, 350]   # Frontal bbox
            ])
            mock_detection_model.return_value.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89])
            mock_detection_model.return_value.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1])
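
            # The chained return_value mocks above are shaped so that a caller
            # using the usual Ultralytics-style access pattern
            # results.boxes.xyxy.cpu().numpy() receives the arrays defined
            # here; the same applies to .conf and .cls.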

            # Setup classification model predictions
            mock_brand_result = Mock()
            mock_brand_result.probs = Mock()
            mock_brand_result.probs.top1 = 5  # Toyota index
            mock_brand_result.probs.top1conf = Mock()
            mock_brand_result.probs.top1conf.item.return_value = 0.87
            mock_brand_result.names = {5: "Toyota"}
            mock_brand_model.return_value = mock_brand_result

            mock_bodytype_result = Mock()
            mock_bodytype_result.probs = Mock()
            mock_bodytype_result.probs.top1 = 2  # Sedan index
            mock_bodytype_result.probs.top1conf = Mock()
            mock_bodytype_result.probs.top1conf.item.return_value = 0.82
            mock_bodytype_result.names = {2: "Sedan"}
            mock_bodytype_model.return_value = mock_bodytype_result
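
            # probs.top1 / probs.top1conf mirror the attributes exposed by
            # Ultralytics classification results, so a branch reading
            # result.probs.top1conf.item() sees 0.87 (brand) and 0.82 (body type).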

            # Setup database mock
            mock_db_conn = Mock()
            mock_db_connect.return_value = mock_db_conn
            mock_cursor = Mock()
            mock_db_conn.cursor.return_value = mock_cursor

            # Setup Redis mock
            mock_redis_instance = Mock()
            mock_redis.return_value = mock_redis_instance
            mock_redis_instance.ping.return_value = True
            mock_redis_instance.set.return_value = True
            mock_redis_instance.expire.return_value = True

            # Initialize managers
            stream_manager = StreamManager()
            model_manager = ModelManager()
            pipeline_executor = PipelineExecutor()

            # Register services in container
            container.register_singleton(StreamManager, lambda: stream_manager)
            container.register_singleton(ModelManager, lambda: model_manager)
            container.register_singleton(PipelineExecutor, lambda: pipeline_executor)

            try:
                # 1. Create RTSP stream
                stream_config = StreamConfig(
                    stream_url="rtsp://example.com/stream",
                    stream_type="rtsp",
                    target_fps=10
                )
                stream_info = await stream_manager.create_stream(
                    camera_id="camera_001",
                    config=stream_config,
                    subscription_id="sub_001"
                )

                # 2. Load pipeline and models
                pipeline_config = json.loads(Path(sample_mpta_file).read_text())

                # Mock model file paths exist
                with patch('os.path.exists', return_value=True):
                    await model_manager.load_models_from_config(pipeline_config)

                # 3. Get frame from stream
                frame = stream_manager.get_latest_frame("camera_001")
                assert frame is not None

                # 4. Run detection pipeline
                detection_context = {
                    "camera_id": "camera_001",
                    "display_id": "display_001",
                    "frame": mock_frame,
                    "timestamp": int(time.time() * 1000),
                    "session_id": str(uuid.uuid4())
                }

                # Mock cv2.imencode for Redis image storage
                with patch('cv2.imencode') as mock_imencode:
                    encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8)
                    mock_imencode.return_value = (True, encoded_data)

                    pipeline_result = await pipeline_executor.execute_pipeline(
                        pipeline_config,
                        detection_context
                    )

                # 5. Verify pipeline execution
                assert pipeline_result is not None
                assert pipeline_result.get("status") == "completed"

                # 6. Verify database operations
                # Should have called create record
                create_calls = [call for call in mock_cursor.execute.call_args_list
                                if "INSERT" in str(call)]
                assert len(create_calls) >= 1

                # Should have called update with classification results
                update_calls = [call for call in mock_cursor.execute.call_args_list
                                if "UPDATE" in str(call)]
                assert len(update_calls) >= 1

                # 7. Verify Redis operations
                # Should have saved cropped image
                assert mock_redis_instance.set.called
                assert mock_redis_instance.expire.called

                # 8. Verify final results
                assert "detections" in pipeline_result
                detections = pipeline_result["detections"]
                assert len(detections) >= 2  # Car and Frontal

                # Check detection classes
                detected_classes = [d.get("class") for d in detections]
                assert "Car" in detected_classes
                assert "Frontal" in detected_classes

                # 9. Verify classification results in context
                classification_results = pipeline_result.get("classification_results", {})
                assert "car_brand_cls_v1" in classification_results
                assert "car_bodytype_cls_v1" in classification_results

                brand_result = classification_results["car_brand_cls_v1"]
                assert brand_result.get("brand") == "Toyota"
                assert brand_result.get("confidence") == 0.87

                bodytype_result = classification_results["car_bodytype_cls_v1"]
                assert bodytype_result.get("body_type") == "Sedan"
                assert bodytype_result.get("confidence") == 0.82
            finally:
                # Cleanup
                await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_websocket_subscription_workflow(self, temp_config_file, sample_mpta_file):
        """Test complete WebSocket subscription workflow."""
        # Mock WebSocket for testing
        mock_websocket = Mock()
        mock_websocket.accept = AsyncMock()
        mock_websocket.send_json = AsyncMock()
        mock_websocket.receive_json = AsyncMock()

        # Initialize configuration
        config = Configuration()
        config.load_from_file(temp_config_file)

        # Initialize message processor and WebSocket handler
        message_processor = MessageProcessor()
        websocket_handler = WebSocketHandler(message_processor)

        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('torch.load') as mock_torch_load, \
             patch('psycopg2.connect') as mock_db_connect, \
             patch('redis.Redis') as mock_redis:

            # Setup mocks
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))
            mock_torch_load.return_value = Mock()
            mock_db_connect.return_value = Mock()
            mock_redis.return_value = Mock()

            # Simulate WebSocket message sequence
            subscription_message = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                    "rtspUrl": "rtsp://example.com/stream",
                    "modelUrl": f"file://{sample_mpta_file}",
                    "modelId": 101,
                    "modelName": "Vehicle Detection",
                    "cropX1": 100, "cropY1": 200,
                    "cropX2": 300, "cropY2": 400
                }
            }

            mock_websocket.receive_json.side_effect = [
                subscription_message,
                {"type": "requestState"},
                {"type": "unsubscribe", "payload": {"subscriptionIdentifier": "display-001;cam-001"}}
            ]
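
            # Once the side_effect list above is exhausted, the next
            # receive_json() call raises (StopAsyncIteration under
            # unittest.mock's AsyncMock); the handler is expected to treat
            # this as a client disconnect, hence the broad except below.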

            # Mock file operations
            with patch('builtins.open', mock_open(read_data=Path(sample_mpta_file).read_text())):
                with patch('os.path.exists', return_value=True):
                    try:
                        # Start WebSocket handler (will process messages)
                        await websocket_handler.handle_websocket(mock_websocket, "client_001")
                    except Exception:
                        # WebSocket disconnect is expected once the mocked
                        # message sequence is exhausted.
                        pass

                    # Verify WebSocket interactions
                    mock_websocket.accept.assert_called_once()

                    # Should have sent subscription acknowledgment
                    send_calls = mock_websocket.send_json.call_args_list
                    assert len(send_calls) >= 1

                    # Check subscription acknowledgment
                    first_response = send_calls[0][0][0]
                    assert first_response.get("type") == "subscribeAck"
                    assert first_response.get("status") == "success"

                    # Should have sent state report
                    state_responses = [call[0][0] for call in send_calls
                                       if call[0][0].get("type") == "stateReport"]
                    assert len(state_responses) >= 1

    @pytest.mark.asyncio
    async def test_http_snapshot_workflow(self, temp_config_file, mock_frame):
        """Test HTTP snapshot workflow."""
        config = Configuration()
        config.load_from_file(temp_config_file)

        stream_manager = StreamManager()

        with patch('requests.get') as mock_requests:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.content = b"fake_jpeg_data"
            mock_requests.return_value = mock_response

            with patch('cv2.imdecode') as mock_imdecode:
                mock_imdecode.return_value = mock_frame

                try:
                    # Create HTTP snapshot stream
                    stream_config = StreamConfig(
                        stream_url="http://camera.example.com/snapshot.jpg",
                        stream_type="http_snapshot",
                        snapshot_interval=1.0
                    )
                    stream_info = await stream_manager.create_stream(
                        camera_id="camera_002",
                        config=stream_config,
                        subscription_id="sub_002"
                    )

                    # Wait for snapshot capture
                    await asyncio.sleep(1.2)

                    # Verify frame was captured
                    frame = stream_manager.get_latest_frame("camera_002")
                    assert frame is not None
                    assert np.array_equal(frame, mock_frame)

                    # Verify HTTP request was made
                    mock_requests.assert_called()
                    mock_imdecode.assert_called()
                finally:
                    await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_error_recovery_workflow(self, temp_config_file):
        """Test error recovery and resilience."""
        config = Configuration()
        config.load_from_file(temp_config_file)

        stream_manager = StreamManager()

        with patch('cv2.VideoCapture') as mock_video_cap:
            # Simulate connection failures then success
            mock_cap_instances = []

            # First attempt fails
            mock_cap_fail = Mock()
            mock_cap_fail.isOpened.return_value = False
            mock_cap_instances.append(mock_cap_fail)

            # Second attempt succeeds
            mock_cap_success = Mock()
            mock_cap_success.isOpened.return_value = True
            mock_cap_success.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8))
            mock_cap_instances.append(mock_cap_success)

            mock_video_cap.side_effect = mock_cap_instances
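
            # With a list assigned to side_effect, consecutive
            # cv2.VideoCapture(...) calls return the next instance in order:
            # the failing capture first, then the working one, which is what
            # drives the retry scenario below.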

            try:
                stream_config = StreamConfig(
                    stream_url="rtsp://unreliable.example.com/stream",
                    stream_type="rtsp",
                    max_retries=2
                )

                # First attempt should fail
                with pytest.raises(Exception):
                    await stream_manager.create_stream(
                        camera_id="camera_003",
                        config=stream_config,
                        subscription_id="sub_003"
                    )

                # Retry should succeed
                stream_info = await stream_manager.create_stream(
                    camera_id="camera_003",
                    config=stream_config,
                    subscription_id="sub_003"
                )
                assert stream_info is not None
                assert mock_video_cap.call_count == 2
            finally:
                await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_concurrent_streams_workflow(self, temp_config_file, mock_frame):
        """Test handling multiple concurrent streams."""
        config = Configuration()
        config.load_from_file(temp_config_file)

        stream_manager = StreamManager()

        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('requests.get') as mock_requests:

            # Setup RTSP mock
            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, mock_frame)

            # Setup HTTP mock
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.content = b"fake_jpeg_data"
            mock_requests.return_value = mock_response

            with patch('cv2.imdecode', return_value=mock_frame):
                try:
                    # Create multiple streams concurrently
                    stream_tasks = []

                    # RTSP streams
                    for i in range(3):
                        stream_config = StreamConfig(
                            stream_url=f"rtsp://camera{i}.example.com/stream",
                            stream_type="rtsp"
                        )
                        task = stream_manager.create_stream(
                            camera_id=f"camera_rtsp_{i}",
                            config=stream_config,
                            subscription_id=f"sub_rtsp_{i}"
                        )
                        stream_tasks.append(task)

                    # HTTP snapshot streams
                    for i in range(2):
                        stream_config = StreamConfig(
                            stream_url=f"http://camera{i}.example.com/snapshot.jpg",
                            stream_type="http_snapshot"
                        )
                        task = stream_manager.create_stream(
                            camera_id=f"camera_http_{i}",
                            config=stream_config,
                            subscription_id=f"sub_http_{i}"
                        )
                        stream_tasks.append(task)

                    # Wait for all streams to be created
                    stream_results = await asyncio.gather(*stream_tasks, return_exceptions=True)

                    # Verify all streams were created successfully
                    successful_streams = [r for r in stream_results if not isinstance(r, Exception)]
                    assert len(successful_streams) == 5

                    # Verify frames can be retrieved from all streams
                    await asyncio.sleep(0.5)  # Allow time for frame capture

                    for i in range(3):
                        frame = stream_manager.get_latest_frame(f"camera_rtsp_{i}")
                        assert frame is not None

                    for i in range(2):
                        frame = stream_manager.get_latest_frame(f"camera_http_{i}")
                        # HTTP snapshots might not have frames immediately, so no assertion here

                    # Verify stream statistics
                    stats = stream_manager.get_stream_statistics()
                    assert stats["total_streams"] == 5
                    assert stats["active_streams"] >= 3  # At least RTSP streams should be active
                finally:
                    await stream_manager.stop_all_streams()

    @pytest.mark.asyncio
    async def test_memory_usage_workflow(self, temp_config_file):
        """Test memory usage tracking and cleanup."""
        config = Configuration()
        config.load_from_file(temp_config_file)

        # Create managers with small limits for testing
        stream_manager = StreamManager({"max_streams": 10})
        model_manager = ModelManager({"cache_max_size": 5})

        with patch('cv2.VideoCapture') as mock_video_cap, \
             patch('torch.load') as mock_torch_load:

            mock_cap_instance = Mock()
            mock_video_cap.return_value = mock_cap_instance
            mock_cap_instance.isOpened.return_value = True
            mock_cap_instance.read.return_value = (True, np.ones((100, 100, 3), dtype=np.uint8))

            # Mock model loading
            def create_mock_model():
                model = Mock()
                # Mock model parameters for memory estimation
                param1 = Mock()
                param1.numel.return_value = 1000
                param1.element_size.return_value = 4
                model.parameters.return_value = [param1]
                return model

            mock_torch_load.side_effect = lambda *args, **kwargs: create_mock_model()
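
            # Each fake model reports a single parameter tensor of
            # 1000 elements x 4 bytes, i.e. roughly 4 KB, presumably what
            # ModelManager uses when estimating per-model memory for its
            # cache accounting.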

            try:
                # Create streams up to limit
                for i in range(8):
                    stream_config = StreamConfig(
                        stream_url=f"rtsp://test{i}.example.com/stream",
                        stream_type="rtsp"
                    )
                    await stream_manager.create_stream(
                        camera_id=f"test_camera_{i}",
                        config=stream_config,
                        subscription_id=f"test_sub_{i}"
                    )

                # Load models up to cache limit
                with patch('os.path.exists', return_value=True):
                    for i in range(7):
                        model_config = {
                            "model_id": f"test_model_{i}",
                            "model_path": f"/fake/path/model_{i}.pt",
                            "model_type": "detection",
                            "device": "cpu"
                        }
                        await model_manager.load_model_from_dict(model_config)

                # Check memory usage tracking
                stream_stats = stream_manager.get_stream_statistics()
                model_stats = model_manager.get_cache_statistics()

                assert stream_stats["total_streams"] == 8
                assert model_stats["size"] <= 5  # Should be limited by cache size

                # Test cleanup
                cleaned_models = model_manager.cleanup_unused_models()
                assert cleaned_models >= 0

                stopped_streams = await stream_manager.stop_all_streams()
                assert stopped_streams == 8

                # Verify cleanup
                final_stream_stats = stream_manager.get_stream_statistics()
                assert final_stream_stats["total_streams"] == 0
            finally:
                await stream_manager.stop_all_streams()
def mock_open(read_data=""):
"""Create a mock file opener."""
from unittest.mock import mock_open as _mock_open
return _mock_open(read_data=read_data)
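

if __name__ == "__main__":
    # Convenience entry point (optional sketch): run this module directly with
    # python; assumes pytest and pytest-asyncio are installed in the environment.
    import sys

    sys.exit(pytest.main([__file__, "-v"]))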