"""
|
|
Global pytest configuration and fixtures.
|
|
|
|
This file provides shared fixtures and configuration for all test modules.
|
|
"""
|
|
import pytest
|
|
import asyncio
|
|
import tempfile
|
|
import os
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, AsyncMock
|
|
import numpy as np
|
|
|
|


# Configure asyncio event loop for async tests
@pytest.fixture(scope="session")
def event_loop():
    """Create an instance of the default event loop for the test session."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
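

# Overriding `event_loop` at session scope lets pytest-asyncio (assumed to be
# the async test plugin in use here) run every async test on one shared loop,
# which is what session-scoped async fixtures require. A minimal sketch of a
# test that would run on this loop:
#
#     @pytest.mark.asyncio
#     async def test_event_loop_runs_coroutines():
#         await asyncio.sleep(0)  # trivially exercises the shared loop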


@pytest.fixture
def mock_websocket():
    """Create a mock WebSocket for testing."""
    websocket = Mock()
    websocket.accept = AsyncMock()
    websocket.send_json = AsyncMock()
    websocket.send_text = AsyncMock()
    websocket.receive_json = AsyncMock()
    websocket.receive_text = AsyncMock()
    websocket.close = AsyncMock()
    websocket.ping = AsyncMock()
    return websocket
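

# Sketch of how a test can drive this mock: every handler method is an
# AsyncMock, so awaited calls can be asserted directly (assumes pytest-asyncio
# for the async test function):
#
#     @pytest.mark.asyncio
#     async def test_ack_is_sent(mock_websocket):
#         await mock_websocket.send_json({"type": "ack"})
#         mock_websocket.send_json.assert_awaited_once_with({"type": "ack"})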


@pytest.fixture
def mock_redis_client():
    """Create a mock Redis client."""
    redis_client = Mock()
    redis_client.ping.return_value = True
    redis_client.set.return_value = True
    redis_client.get.return_value = "test_value"
    redis_client.delete.return_value = 1
    redis_client.exists.return_value = 1
    redis_client.expire.return_value = True
    redis_client.ttl.return_value = 300
    redis_client.scan_iter.return_value = []
    return redis_client
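

# The canned return values above let cache-dependent code run without a live
# Redis server. A minimal usage sketch:
#
#     def test_cache_lookup(mock_redis_client):
#         assert mock_redis_client.ping() is True
#         assert mock_redis_client.get("any_key") == "test_value"  # stubbed value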


@pytest.fixture
def mock_frame():
    """Create a mock frame for testing."""
    # Uniform mid-gray 480x640, 3-channel image
    return np.ones((480, 640, 3), dtype=np.uint8) * 128


@pytest.fixture
def sample_detection_result():
    """Create a sample detection result."""
    return {
        "class": "car",
        "confidence": 0.92,
        "bbox": [100, 200, 300, 400],
        "track_id": 1001,
        "timestamp": 1640995200000,
    }
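

# Sketch of a consuming test; reading "bbox" as [x1, y1, x2, y2] is an
# assumption about this project's convention, not confirmed by the fixture:
#
#     def test_detection_result_shape(sample_detection_result):
#         x1, y1, x2, y2 = sample_detection_result["bbox"]
#         assert x1 < x2 and y1 < y2
#         assert 0.0 <= sample_detection_result["confidence"] <= 1.0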


@pytest.fixture
def temp_directory():
    """Create a temporary directory for testing."""
    with tempfile.TemporaryDirectory() as temp_dir:
        yield Path(temp_dir)


@pytest.fixture
def temp_config_file():
    """Create a temporary configuration file."""
    config_data = """{
        "poll_interval_ms": 100,
        "max_streams": 10,
        "target_fps": 30,
        "reconnect_interval_sec": 5,
        "max_retries": 3,
        "database": {
            "enabled": false,
            "host": "localhost",
            "port": 5432,
            "database": "test_db",
            "user": "test_user",
            "password": "test_pass"
        },
        "redis": {
            "enabled": false,
            "host": "localhost",
            "port": 6379,
            "db": 0
        }
    }"""

    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        f.write(config_data)
        temp_path = f.name

    yield temp_path

    # Cleanup
    try:
        os.unlink(temp_path)
    except FileNotFoundError:
        pass
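

# Because the fixture yields a path to valid JSON, tests can parse it with
# the standard library (the test module needs its own `import json`):
#
#     def test_config_parses(temp_config_file):
#         import json
#         with open(temp_config_file) as f:
#             config = json.load(f)
#         assert config["max_streams"] == 10
#         assert config["database"]["enabled"] is False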


@pytest.fixture
def sample_pipeline_config():
    """Create a sample pipeline configuration."""
    return {
        "modelId": "test_detection_model",
        "modelFile": "test_model.pt",
        "multiClass": True,
        "expectedClasses": ["Car", "Person"],
        "triggerClasses": ["Car", "Person"],
        "minConfidence": 0.8,
        "actions": [
            {
                "type": "redis_save_image",
                "region": "Car",
                "key": "detection:{display_id}:{timestamp}:{session_id}",
                "expire_seconds": 600
            }
        ],
        "branches": [
            {
                "modelId": "classification_model",
                "modelFile": "classifier.pt",
                "parallel": True,
                "crop": True,
                "cropClass": "Car",
                "triggerClasses": ["Car"],
                "minConfidence": 0.85
            }
        ]
    }
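

# A minimal structural check a test could run against this config, e.g. that
# every branch only triggers on classes the root model expects:
#
#     def test_branches_trigger_expected_classes(sample_pipeline_config):
#         expected = set(sample_pipeline_config["expectedClasses"])
#         for branch in sample_pipeline_config["branches"]:
#             assert set(branch["triggerClasses"]) <= expected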


def pytest_configure(config):
    """Configure pytest with custom settings."""
    # Register custom markers
    config.addinivalue_line("markers", "unit: mark test as a unit test")
    config.addinivalue_line("markers", "integration: mark test as an integration test")
    config.addinivalue_line("markers", "performance: mark test as a performance benchmark")
    config.addinivalue_line("markers", "slow: mark test as slow running")
    config.addinivalue_line("markers", "network: mark test as requiring network access")
    config.addinivalue_line("markers", "database: mark test as requiring database access")
    config.addinivalue_line("markers", "redis: mark test as requiring Redis access")
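

# Registering the markers above keeps `--strict-markers` happy and enables
# CLI filtering, for example:
#
#     pytest -m unit                      # only unit tests
#     pytest -m "not slow"                # skip slow tests
#     pytest -m "integration and redis"   # combined marker expressions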


def pytest_collection_modifyitems(config, items):
    """Modify test collection to add markers automatically."""
    for item in items:
        # Auto-mark tests based on file path
        if "unit" in str(item.fspath):
            item.add_marker(pytest.mark.unit)
        elif "integration" in str(item.fspath):
            item.add_marker(pytest.mark.integration)
        elif "performance" in str(item.fspath):
            item.add_marker(pytest.mark.performance)

        # Auto-mark slow tests
        if "performance" in str(item.fspath) or "large" in item.name.lower():
            item.add_marker(pytest.mark.slow)

        # Auto-mark tests requiring external services
        if "database" in item.name.lower() or "db" in item.name.lower():
            item.add_marker(pytest.mark.database)
        if "redis" in item.name.lower():
            item.add_marker(pytest.mark.redis)
        if "websocket" in item.name.lower() or "network" in item.name.lower():
            item.add_marker(pytest.mark.network)


@pytest.fixture(autouse=True)
def cleanup_singletons():
    """Clean up singleton instances between tests."""
    yield

    # Reset singleton managers to prevent test interference
    try:
        from detector_worker.core.singleton_managers import (
            ModelStateManager, StreamStateManager, SessionStateManager,
            CacheStateManager, CameraStateManager, PipelineStateManager
        )

        # Clear singleton instances
        for manager_class in [
            ModelStateManager, StreamStateManager, SessionStateManager,
            CacheStateManager, CameraStateManager, PipelineStateManager
        ]:
            if hasattr(manager_class, '_instances'):
                manager_class._instances.clear()
    except ImportError:
        # Modules may not be available in all test contexts
        pass


@pytest.fixture(autouse=True)
def reset_asyncio_loop():
    """Reset asyncio event loop between tests."""
    # This helps prevent asyncio-related issues between tests
    yield

    # Close any remaining tasks
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            # Cancel all remaining tasks
            pending_tasks = asyncio.all_tasks(loop)
            for task in pending_tasks:
                if not task.done():
                    task.cancel()
    except RuntimeError:
        # No event loop in current thread
        pass