diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c48fd70 --- /dev/null +++ b/Makefile @@ -0,0 +1,160 @@ +# Detector Worker Makefile +# Provides convenient commands for development, testing, and deployment + +.PHONY: help install install-dev test test-unit test-integration test-performance test-all test-fast test-coverage lint format clean run docker-build docker-run + +# Default target +help: + @echo "Detector Worker - Available Commands:" + @echo "" + @echo "Development Setup:" + @echo " make install Install production dependencies" + @echo " make install-dev Install development dependencies" + @echo "" + @echo "Testing:" + @echo " make test Run all tests" + @echo " make test-unit Run unit tests only" + @echo " make test-integration Run integration tests" + @echo " make test-performance Run performance benchmarks" + @echo " make test-fast Run fast tests only" + @echo " make test-coverage Generate coverage report" + @echo "" + @echo "Code Quality:" + @echo " make lint Run code linting" + @echo " make format Format code with black and isort" + @echo " make quality Run all quality checks" + @echo "" + @echo "Development:" + @echo " make run Run the detector worker" + @echo " make clean Clean build artifacts" + @echo "" + @echo "Docker:" + @echo " make docker-build Build Docker image" + @echo " make docker-run Run Docker container" + +# Installation targets +install: + pip install -r requirements.txt + +install-dev: install + pip install -r requirements-dev.txt + +# Testing targets +test: + python scripts/run_tests.py --all + +test-unit: + python scripts/run_tests.py --unit --verbose + +test-integration: + python scripts/run_tests.py --integration --verbose + +test-performance: + python scripts/run_tests.py --performance --verbose + +test-fast: + python scripts/run_tests.py --fast --verbose + +test-coverage: + python scripts/run_tests.py --coverage --open-browser + +test-failed: + python scripts/run_tests.py --failed --verbose + +# Code quality targets +lint: + @echo "Running flake8..." + -flake8 detector_worker --max-line-length=120 --extend-ignore=E203,W503 + @echo "Running mypy..." + -mypy detector_worker --ignore-missing-imports --no-strict-optional + +format: + @echo "Formatting with black..." + black detector_worker tests scripts + @echo "Sorting imports with isort..." + isort detector_worker tests scripts + +quality: lint + python scripts/run_tests.py --quality + +# Development targets +run: + python app.py + +run-debug: + python app.py --debug + +clean: + @echo "Cleaning build artifacts..." + rm -rf build/ + rm -rf dist/ + rm -rf *.egg-info/ + rm -rf htmlcov/ + rm -rf .coverage + rm -rf coverage.xml + rm -rf test-results.xml + rm -rf .pytest_cache/ + find . -type d -name __pycache__ -delete + find . -type f -name "*.pyc" -delete + find . -type f -name "*.pyo" -delete + +# Docker targets +docker-build: + docker build -t detector-worker . + +docker-run: + docker run -p 8000:8000 detector-worker + +docker-dev: + docker run -it -v $(PWD):/app -p 8000:8000 detector-worker bash + +# CI/CD targets +ci-test: + python scripts/run_tests.py --all --skip-slow + +ci-quality: + python scripts/run_tests.py --quality + +# Documentation targets +docs: + @echo "Documentation generation not yet implemented" + +# Development utilities +check-deps: + pip check + +update-deps: + pip list --outdated + +freeze: + pip freeze > requirements-frozen.txt + +# Performance profiling +profile: + python -m cProfile -o profile_output.prof app.py + @echo "Profile saved to profile_output.prof" + @echo "View with: python -m pstats profile_output.prof" + +# Database utilities (if needed) +db-migrate: + @echo "Database migration not yet implemented" + +db-reset: + @echo "Database reset not yet implemented" + +# Monitor and debug +monitor: + @echo "Starting system monitor..." + python -c "import psutil; import time; [print(f'CPU: {psutil.cpu_percent()}%, Memory: {psutil.virtual_memory().percent}%') or time.sleep(1) for _ in range(60)]" + +# Utility targets +version: + python -c "import detector_worker; print(f'Detector Worker Version: {getattr(detector_worker, \"__version__\", \"unknown\")}')" + +env-info: + @echo "Environment Information:" + @echo "Python: $(shell python --version)" + @echo "Pip: $(shell pip --version)" + @echo "Working Directory: $(PWD)" + @echo "Git Branch: $(shell git branch --show-current 2>/dev/null || echo 'Not a git repository')" + @echo "Git Commit: $(shell git rev-parse --short HEAD 2>/dev/null || echo 'Not a git repository')" \ No newline at end of file diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..7ddb1f9 --- /dev/null +++ b/conftest.py @@ -0,0 +1,224 @@ +""" +Global pytest configuration and fixtures. + +This file provides shared fixtures and configuration for all test modules. +""" +import pytest +import asyncio +import tempfile +import os +from pathlib import Path +from unittest.mock import Mock, AsyncMock +import numpy as np + +# Configure asyncio event loop for async tests +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for the test session.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def mock_websocket(): + """Create a mock WebSocket for testing.""" + websocket = Mock() + websocket.accept = AsyncMock() + websocket.send_json = AsyncMock() + websocket.send_text = AsyncMock() + websocket.receive_json = AsyncMock() + websocket.receive_text = AsyncMock() + websocket.close = AsyncMock() + websocket.ping = AsyncMock() + return websocket + + +@pytest.fixture +def mock_redis_client(): + """Create a mock Redis client.""" + redis_client = Mock() + redis_client.ping.return_value = True + redis_client.set.return_value = True + redis_client.get.return_value = "test_value" + redis_client.delete.return_value = 1 + redis_client.exists.return_value = 1 + redis_client.expire.return_value = True + redis_client.ttl.return_value = 300 + redis_client.scan_iter.return_value = [] + return redis_client + + +@pytest.fixture +def mock_frame(): + """Create a mock frame for testing.""" + return np.ones((480, 640, 3), dtype=np.uint8) * 128 + + +@pytest.fixture +def sample_detection_result(): + """Create a sample detection result.""" + return { + "class": "car", + "confidence": 0.92, + "bbox": [100, 200, 300, 400], + "track_id": 1001, + "timestamp": 1640995200000 + } + + +@pytest.fixture +def temp_directory(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def temp_config_file(): + """Create a temporary configuration file.""" + config_data = """{ + "poll_interval_ms": 100, + "max_streams": 10, + "target_fps": 30, + "reconnect_interval_sec": 5, + "max_retries": 3, + "database": { + "enabled": false, + "host": "localhost", + "port": 5432, + "database": "test_db", + "user": "test_user", + "password": "test_pass" + }, + "redis": { + "enabled": false, + "host": "localhost", + "port": 6379, + "db": 0 + } + }""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + f.write(config_data) + temp_path = f.name + + yield temp_path + + # Cleanup + try: + os.unlink(temp_path) + except FileNotFoundError: + pass + + +@pytest.fixture +def sample_pipeline_config(): + """Create a sample pipeline configuration.""" + return { + "modelId": "test_detection_model", + "modelFile": "test_model.pt", + "multiClass": True, + "expectedClasses": ["Car", "Person"], + "triggerClasses": ["Car", "Person"], + "minConfidence": 0.8, + "actions": [ + { + "type": "redis_save_image", + "region": "Car", + "key": "detection:{display_id}:{timestamp}:{session_id}", + "expire_seconds": 600 + } + ], + "branches": [ + { + "modelId": "classification_model", + "modelFile": "classifier.pt", + "parallel": True, + "crop": True, + "cropClass": "Car", + "triggerClasses": ["Car"], + "minConfidence": 0.85 + } + ] + } + + +def pytest_configure(config): + """Configure pytest with custom settings.""" + # Register custom markers + config.addinivalue_line("markers", "unit: mark test as a unit test") + config.addinivalue_line("markers", "integration: mark test as an integration test") + config.addinivalue_line("markers", "performance: mark test as a performance benchmark") + config.addinivalue_line("markers", "slow: mark test as slow running") + config.addinivalue_line("markers", "network: mark test as requiring network access") + config.addinivalue_line("markers", "database: mark test as requiring database access") + config.addinivalue_line("markers", "redis: mark test as requiring Redis access") + + +def pytest_collection_modifyitems(config, items): + """Modify test collection to add markers automatically.""" + for item in items: + # Auto-mark tests based on file path + if "unit" in str(item.fspath): + item.add_marker(pytest.mark.unit) + elif "integration" in str(item.fspath): + item.add_marker(pytest.mark.integration) + elif "performance" in str(item.fspath): + item.add_marker(pytest.mark.performance) + + # Auto-mark slow tests + if "performance" in str(item.fspath) or "large" in item.name.lower(): + item.add_marker(pytest.mark.slow) + + # Auto-mark tests requiring external services + if "database" in item.name.lower() or "db" in item.name.lower(): + item.add_marker(pytest.mark.database) + if "redis" in item.name.lower(): + item.add_marker(pytest.mark.redis) + if "websocket" in item.name.lower() or "network" in item.name.lower(): + item.add_marker(pytest.mark.network) + + +@pytest.fixture(autouse=True) +def cleanup_singletons(): + """Clean up singleton instances between tests.""" + yield + + # Reset singleton managers to prevent test interference + try: + from detector_worker.core.singleton_managers import ( + ModelStateManager, StreamStateManager, SessionStateManager, + CacheStateManager, CameraStateManager, PipelineStateManager + ) + + # Clear singleton instances + for manager_class in [ + ModelStateManager, StreamStateManager, SessionStateManager, + CacheStateManager, CameraStateManager, PipelineStateManager + ]: + if hasattr(manager_class, '_instances'): + manager_class._instances.clear() + except ImportError: + # Modules may not be available in all test contexts + pass + + +@pytest.fixture(autouse=True) +def reset_asyncio_loop(): + """Reset asyncio event loop between tests.""" + # This helps prevent asyncio-related issues between tests + yield + + # Close any remaining tasks + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Cancel all remaining tasks + pending_tasks = asyncio.all_tasks(loop) + for task in pending_tasks: + if not task.done(): + task.cancel() + except RuntimeError: + # No event loop in current thread + pass \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..d15bd15 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,80 @@ +[tool:pytest] +# Pytest configuration file + +# Test discovery +testpaths = tests +python_files = test_*.py *_test.py +python_classes = Test* +python_functions = test_* + +# Markers for different test types +markers = + unit: Unit tests (fast, isolated) + integration: Integration tests (slower, multiple components) + performance: Performance benchmarks (may take longer) + slow: Slow tests that may be skipped in CI + network: Tests requiring network access + database: Tests requiring database access + redis: Tests requiring Redis access + +# Output options +addopts = + --strict-markers + --strict-config + --verbose + --tb=short + --durations=10 + --color=yes + --cov=detector_worker + --cov-report=term-missing + --cov-report=html:htmlcov + --cov-report=xml:coverage.xml + --cov-fail-under=80 + --junitxml=test-results.xml + +# Coverage configuration +[coverage:run] +source = detector_worker +omit = + */tests/* + */test_* + */__pycache__/* + */venv/* + */env/* + */build/* + */dist/* + setup.py + conftest.py + +[coverage:report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + + # Don't complain about abstract methods + @abstractmethod + +# Report precision +precision = 2 +show_missing = True +skip_covered = False + +[coverage:html] +directory = htmlcov +title = Detector Worker Test Coverage Report + +[coverage:xml] +output = coverage.xml \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..11395be --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,48 @@ +# Development dependencies for testing, linting, and debugging + +# Testing framework and extensions +pytest>=7.0.0 +pytest-asyncio>=0.20.0 +pytest-cov>=4.0.0 +pytest-mock>=3.10.0 +pytest-xdist>=3.0.0 # Parallel test execution +pytest-html>=3.1.0 # HTML test reports +pytest-benchmark>=4.0.0 # Performance benchmarking + +# Code coverage +coverage>=7.0.0 + +# Code quality and linting +flake8>=5.0.0 +black>=22.0.0 +isort>=5.10.0 +mypy>=1.0.0 +pylint>=2.15.0 + +# Type checking and stubs +types-redis>=4.0.0 +types-requests>=2.28.0 +types-PyYAML>=6.0.0 + +# Development utilities +ipython>=8.0.0 # Enhanced Python REPL +ipdb>=0.13.0 # Enhanced debugger +memory-profiler>=0.60.0 # Memory profiling +line-profiler>=3.5.0 # Line-by-line profiling + +# Documentation +sphinx>=5.0.0 +sphinx-rtd-theme>=1.0.0 + +# Performance monitoring +psutil>=5.9.0 # System monitoring (also in main requirements) +py-spy>=0.3.0 # Python profiler + +# Testing utilities +responses>=0.22.0 # HTTP request mocking +freezegun>=1.2.0 # Time mocking +factory-boy>=3.2.0 # Test data generation + +# Additional development tools +pre-commit>=2.20.0 # Git hooks +tox>=4.0.0 # Testing across multiple environments \ No newline at end of file diff --git a/scripts/run_tests.py b/scripts/run_tests.py new file mode 100755 index 0000000..c047443 --- /dev/null +++ b/scripts/run_tests.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python3 +""" +Test runner script with comprehensive test execution options. + +This script provides different test execution modes for development, +CI/CD, and performance testing scenarios. +""" +import sys +import os +import subprocess +import argparse +from pathlib import Path +import time + +# Add project root to path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + + +class TestRunner: + """Test runner with various execution modes.""" + + def __init__(self): + self.project_root = project_root + self.test_dir = self.project_root / "tests" + + def run_unit_tests(self, verbose=False, coverage=True): + """Run unit tests only.""" + print("๐Ÿงช Running unit tests...") + + cmd = [ + "python", "-m", "pytest", + "tests/unit", + "-m", "unit", + "--tb=short" + ] + + if verbose: + cmd.append("-v") + + if coverage: + cmd.extend([ + "--cov=detector_worker", + "--cov-report=term-missing" + ]) + + return self._run_command(cmd) + + def run_integration_tests(self, verbose=False): + """Run integration tests only.""" + print("๐Ÿ”— Running integration tests...") + + cmd = [ + "python", "-m", "pytest", + "tests/integration", + "-m", "integration", + "--tb=short" + ] + + if verbose: + cmd.append("-v") + + return self._run_command(cmd) + + def run_performance_tests(self, verbose=False): + """Run performance benchmarks.""" + print("๐Ÿš€ Running performance tests...") + + cmd = [ + "python", "-m", "pytest", + "tests/performance", + "-m", "performance", + "--tb=short", + "--durations=0" # Show all durations + ] + + if verbose: + cmd.append("-v") + + return self._run_command(cmd) + + def run_all_tests(self, verbose=False, coverage=True, skip_slow=False): + """Run all tests with full coverage.""" + print("๐Ÿงช Running complete test suite...") + + cmd = [ + "python", "-m", "pytest", + "tests/", + "--tb=short", + "--durations=10" + ] + + if skip_slow: + cmd.extend(["-m", "not slow"]) + + if verbose: + cmd.append("-v") + + if coverage: + cmd.extend([ + "--cov=detector_worker", + "--cov-report=term-missing", + "--cov-report=html:htmlcov", + "--cov-report=xml:coverage.xml", + "--cov-fail-under=80" + ]) + + return self._run_command(cmd) + + def run_fast_tests(self, verbose=False): + """Run only fast tests (unit tests, no slow markers).""" + print("โšก Running fast tests only...") + + cmd = [ + "python", "-m", "pytest", + "tests/unit", + "-m", "unit and not slow", + "--tb=short" + ] + + if verbose: + cmd.append("-v") + + return self._run_command(cmd) + + def run_specific_test(self, test_pattern, verbose=False): + """Run specific test(s) matching pattern.""" + print(f"๐ŸŽฏ Running tests matching: {test_pattern}") + + cmd = [ + "python", "-m", "pytest", + "-k", test_pattern, + "--tb=short" + ] + + if verbose: + cmd.append("-v") + + return self._run_command(cmd) + + def run_failed_tests(self, verbose=False): + """Rerun only failed tests from last run.""" + print("๐Ÿ”„ Rerunning failed tests...") + + cmd = [ + "python", "-m", "pytest", + "--lf", # Last failed + "--tb=short" + ] + + if verbose: + cmd.append("-v") + + return self._run_command(cmd) + + def run_with_coverage_report(self, open_browser=False): + """Run tests and generate detailed coverage report.""" + print("๐Ÿ“Š Running tests with detailed coverage analysis...") + + cmd = [ + "python", "-m", "pytest", + "tests/", + "-m", "not performance", # Skip performance tests for coverage + "--cov=detector_worker", + "--cov-report=html:htmlcov", + "--cov-report=xml:coverage.xml", + "--cov-report=term-missing", + "--cov-fail-under=80", + "--tb=short" + ] + + result = self._run_command(cmd) + + if result == 0 and open_browser: + coverage_html = self.project_root / "htmlcov" / "index.html" + if coverage_html.exists(): + import webbrowser + webbrowser.open(f"file://{coverage_html}") + print(f"๐Ÿ“– Coverage report opened in browser: {coverage_html}") + + return result + + def check_code_quality(self): + """Run code quality checks (linting, formatting).""" + print("๐Ÿ” Running code quality checks...") + + # Check if tools are available + tools = [] + + # Check for flake8 + if self._command_exists("flake8"): + tools.append(("flake8", ["flake8", "detector_worker", "--max-line-length=120"])) + + # Check for black + if self._command_exists("black"): + tools.append(("black", ["black", "--check", "--diff", "detector_worker"])) + + # Check for isort + if self._command_exists("isort"): + tools.append(("isort", ["isort", "--check-only", "--diff", "detector_worker"])) + + # Check for mypy + if self._command_exists("mypy"): + tools.append(("mypy", ["mypy", "detector_worker", "--ignore-missing-imports"])) + + if not tools: + print("โš ๏ธ No code quality tools found. Install flake8, black, isort, mypy for quality checks.") + return 0 + + all_passed = True + for tool_name, cmd in tools: + print(f" Running {tool_name}...") + result = self._run_command(cmd, capture_output=True) + if result != 0: + all_passed = False + print(f" โŒ {tool_name} failed") + else: + print(f" โœ… {tool_name} passed") + + return 0 if all_passed else 1 + + def _run_command(self, cmd, capture_output=False): + """Run a command and return exit code.""" + try: + if capture_output: + result = subprocess.run(cmd, cwd=self.project_root, capture_output=True) + return result.returncode + else: + result = subprocess.run(cmd, cwd=self.project_root) + return result.returncode + except KeyboardInterrupt: + print("\nโš ๏ธ Tests interrupted by user") + return 1 + except Exception as e: + print(f"โŒ Error running command: {e}") + return 1 + + def _command_exists(self, command): + """Check if a command exists in PATH.""" + try: + subprocess.run([command, "--version"], capture_output=True, check=True) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + def print_test_summary(self): + """Print test directory structure and available tests.""" + print("\n๐Ÿ“ Test Directory Structure:") + print("โ”œโ”€โ”€ tests/") + print("โ”‚ โ”œโ”€โ”€ unit/ # Fast, isolated unit tests") + print("โ”‚ โ”œโ”€โ”€ integration/ # Multi-component integration tests") + print("โ”‚ โ”œโ”€โ”€ performance/ # Performance benchmarks") + print("โ”‚ โ””โ”€โ”€ conftest.py # Shared fixtures and configuration") + print() + + if self.test_dir.exists(): + unit_tests = len(list((self.test_dir / "unit").rglob("test_*.py"))) + integration_tests = len(list((self.test_dir / "integration").rglob("test_*.py"))) + performance_tests = len(list((self.test_dir / "performance").rglob("test_*.py"))) + + print(f"๐Ÿ“Š Test Counts:") + print(f" Unit tests: {unit_tests} files") + print(f" Integration tests: {integration_tests} files") + print(f" Performance tests: {performance_tests} files") + print(f" Total: {unit_tests + integration_tests + performance_tests} test files") + print() + + +def main(): + """Main entry point for test runner.""" + parser = argparse.ArgumentParser( + description="Detector Worker Test Runner", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s --all # Run all tests with coverage + %(prog)s --unit # Run unit tests only + %(prog)s --integration # Run integration tests only + %(prog)s --performance # Run performance benchmarks + %(prog)s --fast # Run only fast tests + %(prog)s --failed # Rerun failed tests + %(prog)s --specific "test_config" # Run tests matching pattern + %(prog)s --coverage # Generate coverage report + %(prog)s --quality # Run code quality checks + """ + ) + + parser.add_argument("--all", action="store_true", help="Run all tests") + parser.add_argument("--unit", action="store_true", help="Run unit tests only") + parser.add_argument("--integration", action="store_true", help="Run integration tests only") + parser.add_argument("--performance", action="store_true", help="Run performance benchmarks") + parser.add_argument("--fast", action="store_true", help="Run fast tests only") + parser.add_argument("--failed", action="store_true", help="Rerun failed tests") + parser.add_argument("--specific", metavar="PATTERN", help="Run specific tests matching pattern") + parser.add_argument("--coverage", action="store_true", help="Generate coverage report") + parser.add_argument("--quality", action="store_true", help="Run code quality checks") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + parser.add_argument("--no-coverage", action="store_true", help="Skip coverage reporting") + parser.add_argument("--skip-slow", action="store_true", help="Skip slow tests") + parser.add_argument("--open-browser", action="store_true", help="Open coverage report in browser") + parser.add_argument("--summary", action="store_true", help="Print test summary and exit") + + args = parser.parse_args() + + runner = TestRunner() + + if args.summary: + runner.print_test_summary() + return 0 + + # If no specific test type specified, show help + if not any([args.all, args.unit, args.integration, args.performance, + args.fast, args.failed, args.specific, args.coverage, args.quality]): + parser.print_help() + print("\n๐Ÿ’ก Use --summary to see available tests") + return 0 + + start_time = time.time() + exit_code = 0 + + try: + if args.quality: + exit_code = runner.check_code_quality() + elif args.unit: + exit_code = runner.run_unit_tests(args.verbose, not args.no_coverage) + elif args.integration: + exit_code = runner.run_integration_tests(args.verbose) + elif args.performance: + exit_code = runner.run_performance_tests(args.verbose) + elif args.fast: + exit_code = runner.run_fast_tests(args.verbose) + elif args.failed: + exit_code = runner.run_failed_tests(args.verbose) + elif args.specific: + exit_code = runner.run_specific_test(args.specific, args.verbose) + elif args.coverage: + exit_code = runner.run_with_coverage_report(args.open_browser) + elif args.all: + exit_code = runner.run_all_tests(args.verbose, not args.no_coverage, args.skip_slow) + + end_time = time.time() + duration = end_time - start_time + + if exit_code == 0: + print(f"\nโœ… Tests completed successfully in {duration:.1f} seconds") + else: + print(f"\nโŒ Tests failed in {duration:.1f} seconds") + + except KeyboardInterrupt: + print("\nโš ๏ธ Test execution interrupted") + exit_code = 1 + + return exit_code + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..7030c82 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package for detector worker.""" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8cc4898 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,276 @@ +""" +Pytest configuration and shared fixtures for detector worker tests. +""" +import pytest +import tempfile +import os +from unittest.mock import Mock, MagicMock, patch +from typing import Dict, Any, Generator + +# Add the project root to the path so we can import detector_worker +import sys +from pathlib import Path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from detector_worker.core.config import get_config_manager, ConfigurationManager +from detector_worker.core.dependency_injection import get_container, DetectorWorkerContainer +from detector_worker.core.singleton_managers import ( + ModelStateManager, StreamStateManager, SessionStateManager, + CacheStateManager, CameraStateManager, PipelineStateManager +) + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +@pytest.fixture +def mock_config(): + """Mock configuration data for tests.""" + return { + "poll_interval_ms": 100, + "max_streams": 5, + "target_fps": 10, + "reconnect_interval_sec": 5, + "max_retries": 3, + "heartbeat_interval": 2, + "session_timeout": 600, + "models_dir": "models", + "log_level": "INFO", + "database": { + "enabled": False, + "host": "localhost", + "port": 5432, + "database": "test_db", + "username": "test_user", + "password": "test_pass", + "schema": "public" + }, + "redis": { + "enabled": False, + "host": "localhost", + "port": 6379, + "password": None, + "db": 0 + } + } + + +@pytest.fixture +def mock_detection_result(): + """Mock detection result for tests.""" + return { + "class": "car", + "confidence": 0.85, + "bbox": [100, 200, 300, 400], + "id": 12345, + "branch_results": {} + } + + +@pytest.fixture +def mock_frame(): + """Mock frame data for tests.""" + import numpy as np + return np.zeros((480, 640, 3), dtype=np.uint8) + + +@pytest.fixture +def mock_model_tree(): + """Mock model tree structure for tests.""" + return { + "modelId": "test_model_v1", + "modelFile": "test_model.pt", + "multiClass": True, + "expectedClasses": ["car", "truck"], + "triggerClasses": ["car"], + "minConfidence": 0.8, + "branches": [], + "actions": [] + } + + +@pytest.fixture +def mock_pipeline_context(): + """Mock pipeline context for tests.""" + return { + "camera_id": "test_camera_001", + "display_id": "display_001", + "session_id": "session_12345", + "timestamp": 1640995200000, + "subscription_id": "sub_001" + } + + +@pytest.fixture(autouse=True) +def reset_singletons(): + """Reset singleton managers before each test.""" + # Clear all singleton state before each test + yield + + # Cleanup after test + try: + ModelStateManager().clear_all() + StreamStateManager().clear_all() + SessionStateManager().clear_all() + CacheStateManager().clear_all() + CameraStateManager().clear_all() + PipelineStateManager().clear_all() + except Exception: + pass # Ignore cleanup errors + + +@pytest.fixture +def isolated_config_manager(temp_dir, mock_config): + """Create an isolated configuration manager for testing.""" + config_file = os.path.join(temp_dir, "test_config.json") + + import json + with open(config_file, 'w') as f: + json.dump(mock_config, f) + + # Create a fresh ConfigurationManager for testing + from detector_worker.core.config import JsonFileProvider, EnvironmentProvider + + manager = ConfigurationManager() + manager._providers.clear() # Remove default providers + manager.add_provider(JsonFileProvider(config_file)) + + return manager + + +@pytest.fixture +def mock_websocket(): + """Mock WebSocket for testing.""" + mock_ws = Mock() + mock_ws.accept = Mock() + mock_ws.send_text = Mock() + mock_ws.send_json = Mock() + mock_ws.receive_text = Mock() + mock_ws.receive_json = Mock() + mock_ws.close = Mock() + mock_ws.client_state = Mock() + mock_ws.client_state.DISCONNECTED = False + return mock_ws + + +@pytest.fixture +def mock_redis_client(): + """Mock Redis client for testing.""" + mock_redis = Mock() + mock_redis.get = Mock(return_value=None) + mock_redis.set = Mock(return_value=True) + mock_redis.delete = Mock(return_value=1) + mock_redis.exists = Mock(return_value=0) + mock_redis.expire = Mock(return_value=True) + mock_redis.publish = Mock(return_value=1) + return mock_redis + + +@pytest.fixture +def mock_database_connection(): + """Mock database connection for testing.""" + mock_conn = Mock() + mock_cursor = Mock() + + mock_cursor.execute = Mock() + mock_cursor.fetchone = Mock(return_value=None) + mock_cursor.fetchall = Mock(return_value=[]) + mock_cursor.fetchmany = Mock(return_value=[]) + mock_cursor.rowcount = 1 + + mock_conn.cursor = Mock(return_value=mock_cursor) + mock_conn.commit = Mock() + mock_conn.rollback = Mock() + mock_conn.close = Mock() + + return mock_conn + + +@pytest.fixture +def mock_yolo_model(): + """Mock YOLO model for testing.""" + mock_model = Mock() + + # Mock results with boxes + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = Mock() + mock_result.boxes.conf = Mock() + mock_result.boxes.cls = Mock() + mock_result.boxes.id = Mock() + + # Mock track method + mock_model.track = Mock(return_value=[mock_result]) + mock_model.predict = Mock(return_value=[mock_result]) + + return mock_model + + +@pytest.fixture +def sample_detection_data(): + """Sample detection data for testing.""" + return [ + { + "class": "car", + "confidence": 0.92, + "bbox": [100, 150, 250, 300], + "id": 1001, + "branch_results": {} + }, + { + "class": "truck", + "confidence": 0.87, + "bbox": [300, 200, 450, 350], + "id": 1002, + "branch_results": {} + } + ] + + +@pytest.fixture +def sample_session_data(): + """Sample session data for testing.""" + return { + "session_id": "session_test_001", + "display_id": "display_test_001", + "camera_id": "camera_test_001", + "created_at": 1640995200.0, + "last_activity": 1640995200.0, + "detection_data": { + "car_brand": "Toyota", + "car_model": "Camry", + "license_plate": "ABC-123" + } + } + + +# Helper functions for tests +def create_mock_detection_result(class_name: str = "car", confidence: float = 0.85, track_id: int = 1001): + """Helper function to create mock detection results.""" + return { + "class": class_name, + "confidence": confidence, + "bbox": [100, 200, 300, 400], + "id": track_id, + "branch_results": {} + } + + +def create_mock_regions_dict(detections: list = None): + """Helper function to create mock regions dictionary.""" + if detections is None: + detections = [create_mock_detection_result()] + + regions = {} + for det in detections: + regions[det["class"]] = { + "bbox": det["bbox"], + "confidence": det["confidence"], + "detection": det + } + return regions \ No newline at end of file diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..16e3ffd --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1,19 @@ +""" +Integration tests for the detector worker application. + +This package contains integration tests that verify the interaction +between multiple components and end-to-end workflows. +""" + +# Integration test modules +from . import ( + test_complete_detection_workflow, + test_websocket_protocol, + test_pipeline_integration +) + +__all__ = [ + "test_complete_detection_workflow", + "test_websocket_protocol", + "test_pipeline_integration" +] \ No newline at end of file diff --git a/tests/integration/test_complete_detection_workflow.py b/tests/integration/test_complete_detection_workflow.py new file mode 100644 index 0000000..0ea8273 --- /dev/null +++ b/tests/integration/test_complete_detection_workflow.py @@ -0,0 +1,681 @@ +""" +Integration tests for complete detection workflow. + +This module tests the full end-to-end detection pipeline from stream +to database update, ensuring all components work together correctly. +""" +import pytest +import asyncio +import uuid +import time +import json +import tempfile +from unittest.mock import Mock, AsyncMock, patch, MagicMock +from pathlib import Path +import numpy as np +import cv2 + +from detector_worker.app import create_app +from detector_worker.core.config import Configuration +from detector_worker.core.dependency_injection import ServiceContainer +from detector_worker.streams.stream_manager import StreamManager, StreamConfig +from detector_worker.models.model_manager import ModelManager +from detector_worker.pipeline.pipeline_executor import PipelineExecutor +from detector_worker.storage.database_manager import DatabaseManager +from detector_worker.storage.redis_client import RedisClient, RedisConfig +from detector_worker.communication.websocket_handler import WebSocketHandler +from detector_worker.communication.message_processor import MessageProcessor + + +@pytest.fixture +def temp_config_file(): + """Create temporary configuration file.""" + config_data = { + "poll_interval_ms": 100, + "max_streams": 5, + "target_fps": 10, + "reconnect_interval_sec": 1, + "max_retries": 2, + "database": { + "enabled": True, + "host": "localhost", + "port": 5432, + "database": "test_gas_station_1", + "user": "test_user", + "password": "test_pass" + }, + "redis": { + "enabled": True, + "host": "localhost", + "port": 6379, + "db": 0 + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config_data, f) + temp_path = f.name + + yield temp_path + + # Cleanup + Path(temp_path).unlink(missing_ok=True) + + +@pytest.fixture +def mock_frame(): + """Create a mock frame for testing.""" + return np.ones((480, 640, 3), dtype=np.uint8) * 128 + + +@pytest.fixture +def sample_mpta_file(): + """Create a sample MPTA (pipeline) file.""" + pipeline_config = { + "modelId": "car_frontal_detection_v1", + "modelFile": "car_frontal_detection_v1.pt", + "multiClass": True, + "expectedClasses": ["Car", "Frontal"], + "triggerClasses": ["Car", "Frontal"], + "minConfidence": 0.8, + "actions": [ + { + "type": "redis_save_image", + "region": "Frontal", + "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}", + "expire_seconds": 600 + }, + { + "type": "postgresql_create_record", + "table": "car_frontal_info", + "fields": { + "display_id": "{display_id}", + "captured_timestamp": "{timestamp}", + "session_id": "{session_id}", + "license_character": None, + "license_type": "No model available" + } + } + ], + "branches": [ + { + "modelId": "car_brand_cls_v1", + "modelFile": "car_brand_cls_v1.pt", + "parallel": True, + "crop": True, + "cropClass": "Frontal", + "triggerClasses": ["Frontal"], + "minConfidence": 0.85 + }, + { + "modelId": "car_bodytype_cls_v1", + "modelFile": "car_bodytype_cls_v1.pt", + "parallel": True, + "crop": True, + "cropClass": "Frontal", + "triggerClasses": ["Frontal"], + "minConfidence": 0.80 + } + ], + "parallelActions": [ + { + "type": "postgresql_update_combined", + "table": "car_frontal_info", + "key_field": "session_id", + "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"], + "fields": { + "car_brand": "{car_brand_cls_v1.brand}", + "car_body_type": "{car_bodytype_cls_v1.body_type}" + } + } + ] + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(pipeline_config, f) + temp_path = f.name + + yield temp_path + + # Cleanup + Path(temp_path).unlink(missing_ok=True) + + +class TestCompleteDetectionWorkflow: + """Test complete detection workflow from stream to database.""" + + @pytest.mark.asyncio + async def test_complete_rtsp_detection_workflow(self, temp_config_file, sample_mpta_file, mock_frame): + """Test complete workflow: RTSP stream -> detection -> classification -> database.""" + + # Initialize configuration + config = Configuration() + config.load_from_file(temp_config_file) + + # Create service container + container = ServiceContainer() + + # Mock all external dependencies + with patch('cv2.VideoCapture') as mock_video_cap, \ + patch('torch.load') as mock_torch_load, \ + patch('psycopg2.connect') as mock_db_connect, \ + patch('redis.Redis') as mock_redis: + + # Setup video capture mock + mock_cap_instance = Mock() + mock_video_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + mock_cap_instance.read.return_value = (True, mock_frame) + + # Setup model loading mock + mock_detection_model = Mock() + mock_brand_model = Mock() + mock_bodytype_model = Mock() + + def mock_model_load(path, **kwargs): + if "detection" in path: + return mock_detection_model + elif "brand" in path: + return mock_brand_model + elif "bodytype" in path: + return mock_bodytype_model + return Mock() + + mock_torch_load.side_effect = mock_model_load + + # Setup detection model predictions + mock_detection_model.return_value = Mock() + mock_detection_model.return_value.boxes = Mock() + mock_detection_model.return_value.boxes.xyxy = Mock() + mock_detection_model.return_value.boxes.conf = Mock() + mock_detection_model.return_value.boxes.cls = Mock() + mock_detection_model.return_value.names = {0: "Car", 1: "Frontal"} + + # Mock detection results - Car and Frontal detected + mock_detection_model.return_value.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [100, 200, 300, 400], # Car bbox + [150, 250, 250, 350] # Frontal bbox + ]) + mock_detection_model.return_value.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89]) + mock_detection_model.return_value.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1]) + + # Setup classification model predictions + mock_brand_result = Mock() + mock_brand_result.probs = Mock() + mock_brand_result.probs.top1 = 5 # Toyota index + mock_brand_result.probs.top1conf = Mock() + mock_brand_result.probs.top1conf.item.return_value = 0.87 + mock_brand_result.names = {5: "Toyota"} + mock_brand_model.return_value = mock_brand_result + + mock_bodytype_result = Mock() + mock_bodytype_result.probs = Mock() + mock_bodytype_result.probs.top1 = 2 # Sedan index + mock_bodytype_result.probs.top1conf = Mock() + mock_bodytype_result.probs.top1conf.item.return_value = 0.82 + mock_bodytype_result.names = {2: "Sedan"} + mock_bodytype_model.return_value = mock_bodytype_result + + # Setup database mock + mock_db_conn = Mock() + mock_db_connect.return_value = mock_db_conn + mock_cursor = Mock() + mock_db_conn.cursor.return_value = mock_cursor + + # Setup Redis mock + mock_redis_instance = Mock() + mock_redis.return_value = mock_redis_instance + mock_redis_instance.ping.return_value = True + mock_redis_instance.set.return_value = True + mock_redis_instance.expire.return_value = True + + # Initialize managers + stream_manager = StreamManager() + model_manager = ModelManager() + pipeline_executor = PipelineExecutor() + + # Register services in container + container.register_singleton(StreamManager, lambda: stream_manager) + container.register_singleton(ModelManager, lambda: model_manager) + container.register_singleton(PipelineExecutor, lambda: pipeline_executor) + + try: + # 1. Create RTSP stream + stream_config = StreamConfig( + stream_url="rtsp://example.com/stream", + stream_type="rtsp", + target_fps=10 + ) + + stream_info = await stream_manager.create_stream( + camera_id="camera_001", + config=stream_config, + subscription_id="sub_001" + ) + + # 2. Load pipeline and models + pipeline_config = json.loads(Path(sample_mpta_file).read_text()) + + # Mock model file paths exist + with patch('os.path.exists', return_value=True): + await model_manager.load_models_from_config(pipeline_config) + + # 3. Get frame from stream + frame = stream_manager.get_latest_frame("camera_001") + assert frame is not None + + # 4. Run detection pipeline + detection_context = { + "camera_id": "camera_001", + "display_id": "display_001", + "frame": mock_frame, + "timestamp": int(time.time() * 1000), + "session_id": str(uuid.uuid4()) + } + + # Mock cv2.imencode for Redis image storage + with patch('cv2.imencode') as mock_imencode: + encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + pipeline_result = await pipeline_executor.execute_pipeline( + pipeline_config, + detection_context + ) + + # 5. Verify pipeline execution + assert pipeline_result is not None + assert pipeline_result.get("status") == "completed" + + # 6. Verify database operations + # Should have called create record + create_calls = [call for call in mock_cursor.execute.call_args_list + if "INSERT" in str(call)] + assert len(create_calls) >= 1 + + # Should have called update with classification results + update_calls = [call for call in mock_cursor.execute.call_args_list + if "UPDATE" in str(call)] + assert len(update_calls) >= 1 + + # 7. Verify Redis operations + # Should have saved cropped image + assert mock_redis_instance.set.called + assert mock_redis_instance.expire.called + + # 8. Verify final results + assert "detections" in pipeline_result + detections = pipeline_result["detections"] + assert len(detections) >= 2 # Car and Frontal + + # Check detection classes + detected_classes = [d.get("class") for d in detections] + assert "Car" in detected_classes + assert "Frontal" in detected_classes + + # 9. Verify classification results in context + classification_results = pipeline_result.get("classification_results", {}) + assert "car_brand_cls_v1" in classification_results + assert "car_bodytype_cls_v1" in classification_results + + brand_result = classification_results["car_brand_cls_v1"] + assert brand_result.get("brand") == "Toyota" + assert brand_result.get("confidence") == 0.87 + + bodytype_result = classification_results["car_bodytype_cls_v1"] + assert bodytype_result.get("body_type") == "Sedan" + assert bodytype_result.get("confidence") == 0.82 + + finally: + # Cleanup + await stream_manager.stop_all_streams() + + @pytest.mark.asyncio + async def test_websocket_subscription_workflow(self, temp_config_file, sample_mpta_file): + """Test complete WebSocket subscription workflow.""" + + # Mock WebSocket for testing + mock_websocket = Mock() + mock_websocket.accept = AsyncMock() + mock_websocket.send_json = AsyncMock() + mock_websocket.receive_json = AsyncMock() + + # Initialize configuration + config = Configuration() + config.load_from_file(temp_config_file) + + # Initialize message processor and WebSocket handler + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + with patch('cv2.VideoCapture') as mock_video_cap, \ + patch('torch.load') as mock_torch_load, \ + patch('psycopg2.connect') as mock_db_connect, \ + patch('redis.Redis') as mock_redis: + + # Setup mocks + mock_cap_instance = Mock() + mock_video_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + mock_cap_instance.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8)) + + mock_torch_load.return_value = Mock() + mock_db_connect.return_value = Mock() + mock_redis.return_value = Mock() + + # Simulate WebSocket message sequence + subscription_message = { + "type": "subscribe", + "payload": { + "subscriptionIdentifier": "display-001;cam-001", + "rtspUrl": "rtsp://example.com/stream", + "modelUrl": f"file://{sample_mpta_file}", + "modelId": 101, + "modelName": "Vehicle Detection", + "cropX1": 100, "cropY1": 200, + "cropX2": 300, "cropY2": 400 + } + } + + mock_websocket.receive_json.side_effect = [ + subscription_message, + {"type": "requestState"}, + {"type": "unsubscribe", "payload": {"subscriptionIdentifier": "display-001;cam-001"}} + ] + + # Mock file operations + with patch('builtins.open', mock_open(read_data=json.dumps(json.loads(Path(sample_mpta_file).read_text())))): + with patch('os.path.exists', return_value=True): + + try: + # Start WebSocket handler (will process messages) + await websocket_handler.handle_websocket(mock_websocket, "client_001") + + # Verify WebSocket interactions + mock_websocket.accept.assert_called_once() + + # Should have sent subscription acknowledgment + send_calls = mock_websocket.send_json.call_args_list + assert len(send_calls) >= 1 + + # Check subscription acknowledgment + first_response = send_calls[0][0][0] + assert first_response.get("type") == "subscribeAck" + assert first_response.get("status") == "success" + + # Should have sent state report + state_responses = [call[0][0] for call in send_calls + if call[0][0].get("type") == "stateReport"] + assert len(state_responses) >= 1 + + except Exception as e: + # WebSocket disconnect is expected at end of message sequence + pass + + @pytest.mark.asyncio + async def test_http_snapshot_workflow(self, temp_config_file, mock_frame): + """Test HTTP snapshot workflow.""" + + config = Configuration() + config.load_from_file(temp_config_file) + + stream_manager = StreamManager() + + with patch('requests.get') as mock_requests: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b"fake_jpeg_data" + mock_requests.return_value = mock_response + + with patch('cv2.imdecode') as mock_imdecode: + mock_imdecode.return_value = mock_frame + + try: + # Create HTTP snapshot stream + stream_config = StreamConfig( + stream_url="http://camera.example.com/snapshot.jpg", + stream_type="http_snapshot", + snapshot_interval=1.0 + ) + + stream_info = await stream_manager.create_stream( + camera_id="camera_002", + config=stream_config, + subscription_id="sub_002" + ) + + # Wait for snapshot capture + await asyncio.sleep(1.2) + + # Verify frame was captured + frame = stream_manager.get_latest_frame("camera_002") + assert frame is not None + assert np.array_equal(frame, mock_frame) + + # Verify HTTP request was made + mock_requests.assert_called() + mock_imdecode.assert_called() + + finally: + await stream_manager.stop_all_streams() + + @pytest.mark.asyncio + async def test_error_recovery_workflow(self, temp_config_file): + """Test error recovery and resilience.""" + + config = Configuration() + config.load_from_file(temp_config_file) + + stream_manager = StreamManager() + + with patch('cv2.VideoCapture') as mock_video_cap: + # Simulate connection failures then success + mock_cap_instances = [] + + # First attempt fails + mock_cap_fail = Mock() + mock_cap_fail.isOpened.return_value = False + mock_cap_instances.append(mock_cap_fail) + + # Second attempt succeeds + mock_cap_success = Mock() + mock_cap_success.isOpened.return_value = True + mock_cap_success.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8)) + mock_cap_instances.append(mock_cap_success) + + mock_video_cap.side_effect = mock_cap_instances + + try: + stream_config = StreamConfig( + stream_url="rtsp://unreliable.example.com/stream", + stream_type="rtsp", + max_retries=2 + ) + + # First attempt should fail + try: + await stream_manager.create_stream( + camera_id="camera_003", + config=stream_config, + subscription_id="sub_003" + ) + assert False, "Should have failed on first attempt" + except Exception: + pass + + # Retry should succeed + stream_info = await stream_manager.create_stream( + camera_id="camera_003", + config=stream_config, + subscription_id="sub_003" + ) + + assert stream_info is not None + assert mock_video_cap.call_count == 2 + + finally: + await stream_manager.stop_all_streams() + + @pytest.mark.asyncio + async def test_concurrent_streams_workflow(self, temp_config_file, mock_frame): + """Test handling multiple concurrent streams.""" + + config = Configuration() + config.load_from_file(temp_config_file) + + stream_manager = StreamManager() + + with patch('cv2.VideoCapture') as mock_video_cap, \ + patch('requests.get') as mock_requests: + + # Setup RTSP mock + mock_cap_instance = Mock() + mock_video_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + mock_cap_instance.read.return_value = (True, mock_frame) + + # Setup HTTP mock + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b"fake_jpeg_data" + mock_requests.return_value = mock_response + + with patch('cv2.imdecode', return_value=mock_frame): + + try: + # Create multiple streams concurrently + stream_tasks = [] + + # RTSP streams + for i in range(3): + config = StreamConfig( + stream_url=f"rtsp://camera{i}.example.com/stream", + stream_type="rtsp" + ) + task = stream_manager.create_stream( + camera_id=f"camera_rtsp_{i}", + config=config, + subscription_id=f"sub_rtsp_{i}" + ) + stream_tasks.append(task) + + # HTTP snapshot streams + for i in range(2): + config = StreamConfig( + stream_url=f"http://camera{i}.example.com/snapshot.jpg", + stream_type="http_snapshot" + ) + task = stream_manager.create_stream( + camera_id=f"camera_http_{i}", + config=config, + subscription_id=f"sub_http_{i}" + ) + stream_tasks.append(task) + + # Wait for all streams to be created + stream_results = await asyncio.gather(*stream_tasks, return_exceptions=True) + + # Verify all streams were created successfully + successful_streams = [r for r in stream_results if not isinstance(r, Exception)] + assert len(successful_streams) == 5 + + # Verify frames can be retrieved from all streams + await asyncio.sleep(0.5) # Allow time for frame capture + + for i in range(3): + frame = stream_manager.get_latest_frame(f"camera_rtsp_{i}") + assert frame is not None + + for i in range(2): + frame = stream_manager.get_latest_frame(f"camera_http_{i}") + # HTTP snapshots might not have frames immediately + + # Verify stream statistics + stats = stream_manager.get_stream_statistics() + assert stats["total_streams"] == 5 + assert stats["active_streams"] >= 3 # At least RTSP streams should be active + + finally: + await stream_manager.stop_all_streams() + + @pytest.mark.asyncio + async def test_memory_usage_workflow(self, temp_config_file): + """Test memory usage tracking and cleanup.""" + + config = Configuration() + config.load_from_file(temp_config_file) + + # Create managers with small limits for testing + stream_manager = StreamManager({"max_streams": 10}) + model_manager = ModelManager({"cache_max_size": 5}) + + with patch('cv2.VideoCapture') as mock_video_cap, \ + patch('torch.load') as mock_torch_load: + + mock_cap_instance = Mock() + mock_video_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + mock_cap_instance.read.return_value = (True, np.ones((100, 100, 3), dtype=np.uint8)) + + # Mock model loading + def create_mock_model(): + model = Mock() + # Mock model parameters for memory estimation + param1 = Mock() + param1.numel.return_value = 1000 + param1.element_size.return_value = 4 + model.parameters.return_value = [param1] + return model + + mock_torch_load.side_effect = lambda *args, **kwargs: create_mock_model() + + try: + # Create streams up to limit + for i in range(8): + stream_config = StreamConfig( + stream_url=f"rtsp://test{i}.example.com/stream", + stream_type="rtsp" + ) + await stream_manager.create_stream( + camera_id=f"test_camera_{i}", + config=stream_config, + subscription_id=f"test_sub_{i}" + ) + + # Load models up to cache limit + with patch('os.path.exists', return_value=True): + for i in range(7): + config = { + "model_id": f"test_model_{i}", + "model_path": f"/fake/path/model_{i}.pt", + "model_type": "detection", + "device": "cpu" + } + await model_manager.load_model_from_dict(config) + + # Check memory usage tracking + stream_stats = stream_manager.get_stream_statistics() + model_stats = model_manager.get_cache_statistics() + + assert stream_stats["total_streams"] == 8 + assert model_stats["size"] <= 5 # Should be limited by cache size + + # Test cleanup + cleaned_models = model_manager.cleanup_unused_models() + assert cleaned_models >= 0 + + stopped_streams = await stream_manager.stop_all_streams() + assert stopped_streams == 8 + + # Verify cleanup + final_stream_stats = stream_manager.get_stream_statistics() + assert final_stream_stats["total_streams"] == 0 + + finally: + await stream_manager.stop_all_streams() + + +def mock_open(read_data=""): + """Create a mock file opener.""" + from unittest.mock import mock_open as _mock_open + return _mock_open(read_data=read_data) \ No newline at end of file diff --git a/tests/integration/test_pipeline_integration.py b/tests/integration/test_pipeline_integration.py new file mode 100644 index 0000000..79839f4 --- /dev/null +++ b/tests/integration/test_pipeline_integration.py @@ -0,0 +1,738 @@ +""" +Integration tests for pipeline execution workflows. + +Tests the complete machine learning pipeline execution including +detection, classification, database updates, and Redis actions. +""" +import pytest +import asyncio +import json +import tempfile +import uuid +import time +from pathlib import Path +from unittest.mock import Mock, patch, AsyncMock +import numpy as np + +from detector_worker.pipeline.pipeline_executor import PipelineExecutor +from detector_worker.pipeline.action_executor import ActionExecutor +from detector_worker.pipeline.field_mapper import FieldMapper +from detector_worker.models.model_manager import ModelManager +from detector_worker.storage.database_manager import DatabaseManager +from detector_worker.storage.redis_client import RedisClient, RedisConfig +from detector_worker.detection.detection_result import DetectionResult, BoundingBox + + +@pytest.fixture +def sample_detection_pipeline(): + """Create sample detection pipeline configuration.""" + return { + "modelId": "car_frontal_detection_v1", + "modelFile": "car_frontal_detection_v1.pt", + "multiClass": True, + "expectedClasses": ["Car", "Frontal"], + "triggerClasses": ["Car", "Frontal"], + "minConfidence": 0.8, + "actions": [ + { + "type": "redis_save_image", + "region": "Frontal", + "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}", + "expire_seconds": 600 + }, + { + "type": "postgresql_create_record", + "table": "car_frontal_info", + "fields": { + "display_id": "{display_id}", + "captured_timestamp": "{timestamp}", + "session_id": "{session_id}", + "license_character": None, + "license_type": "No model available" + } + } + ], + "branches": [ + { + "modelId": "car_brand_cls_v1", + "modelFile": "car_brand_cls_v1.pt", + "parallel": True, + "crop": True, + "cropClass": "Frontal", + "triggerClasses": ["Frontal"], + "minConfidence": 0.85 + }, + { + "modelId": "car_bodytype_cls_v1", + "modelFile": "car_bodytype_cls_v1.pt", + "parallel": True, + "crop": True, + "cropClass": "Frontal", + "triggerClasses": ["Frontal"], + "minConfidence": 0.80 + } + ], + "parallelActions": [ + { + "type": "postgresql_update_combined", + "table": "car_frontal_info", + "key_field": "session_id", + "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"], + "fields": { + "car_brand": "{car_brand_cls_v1.brand}", + "car_body_type": "{car_bodytype_cls_v1.body_type}" + } + } + ] + } + + +@pytest.fixture +def sample_frame(): + """Create sample frame for testing.""" + return np.ones((480, 640, 3), dtype=np.uint8) * 128 + + +@pytest.fixture +def detection_context(): + """Create sample detection context.""" + return { + "camera_id": "camera_001", + "display_id": "display_001", + "timestamp": int(time.time() * 1000), + "session_id": str(uuid.uuid4()), + "frame": np.ones((480, 640, 3), dtype=np.uint8) * 128, + "filename": "detection_image.jpg" + } + + +class TestPipelineIntegration: + """Test complete pipeline integration workflows.""" + + @pytest.mark.asyncio + async def test_complete_detection_classification_pipeline(self, sample_detection_pipeline, detection_context): + """Test complete detection to classification pipeline.""" + + pipeline_executor = PipelineExecutor() + model_manager = ModelManager() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True), \ + patch('psycopg2.connect') as mock_db_connect, \ + patch('redis.Redis') as mock_redis: + + # Setup detection model mock + mock_detection_model = Mock() + mock_detection_result = Mock() + + # Mock successful multi-class detection + mock_detection_result.boxes = Mock() + mock_detection_result.boxes.xyxy = Mock() + mock_detection_result.boxes.conf = Mock() + mock_detection_result.boxes.cls = Mock() + mock_detection_result.names = {0: "Car", 1: "Frontal"} + + # Detection results: Car and Frontal detected with high confidence + mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [50, 100, 350, 450], # Car bbox + [150, 200, 300, 400] # Frontal bbox (within Car) + ]) + mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89]) + mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1]) + + mock_detection_model.return_value = mock_detection_result + + # Setup classification models + mock_brand_model = Mock() + mock_brand_result = Mock() + mock_brand_result.probs = Mock() + mock_brand_result.probs.top1 = 3 # Toyota index + mock_brand_result.probs.top1conf = Mock() + mock_brand_result.probs.top1conf.item.return_value = 0.87 + mock_brand_result.names = {3: "Toyota"} + mock_brand_model.return_value = mock_brand_result + + mock_bodytype_model = Mock() + mock_bodytype_result = Mock() + mock_bodytype_result.probs = Mock() + mock_bodytype_result.probs.top1 = 1 # Sedan index + mock_bodytype_result.probs.top1conf = Mock() + mock_bodytype_result.probs.top1conf.item.return_value = 0.82 + mock_bodytype_result.names = {1: "Sedan"} + mock_bodytype_model.return_value = mock_bodytype_result + + # Route model loading to appropriate mocks + def model_loader(path, **kwargs): + if "detection" in path: + return mock_detection_model + elif "brand" in path: + return mock_brand_model + elif "bodytype" in path: + return mock_bodytype_model + return Mock() + + mock_torch_load.side_effect = model_loader + + # Setup database mock + mock_db_conn = Mock() + mock_db_connect.return_value = mock_db_conn + mock_cursor = Mock() + mock_db_conn.cursor.return_value = mock_cursor + mock_cursor.fetchone.return_value = None + + # Setup Redis mock + mock_redis_instance = Mock() + mock_redis.return_value = mock_redis_instance + mock_redis_instance.ping.return_value = True + mock_redis_instance.set.return_value = True + mock_redis_instance.expire.return_value = True + + # Mock image encoding for Redis storage + with patch('cv2.imencode') as mock_imencode: + encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + # Execute complete pipeline + result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context) + + # Verify pipeline execution + assert result is not None + assert result.get("status") == "completed" + assert "detections" in result + + # Verify detection results + detections = result["detections"] + assert len(detections) == 2 # Car and Frontal + + detection_classes = [d.get("class") for d in detections] + assert "Car" in detection_classes + assert "Frontal" in detection_classes + + # Verify classification results + assert "classification_results" in result + classification_results = result["classification_results"] + + assert "car_brand_cls_v1" in classification_results + brand_result = classification_results["car_brand_cls_v1"] + assert brand_result.get("brand") == "Toyota" + assert brand_result.get("confidence") == 0.87 + + assert "car_bodytype_cls_v1" in classification_results + bodytype_result = classification_results["car_bodytype_cls_v1"] + assert bodytype_result.get("body_type") == "Sedan" + assert bodytype_result.get("confidence") == 0.82 + + # Verify database operations + db_calls = mock_cursor.execute.call_args_list + + # Should have INSERT for initial record creation + insert_calls = [call for call in db_calls if "INSERT" in str(call[0])] + assert len(insert_calls) >= 1 + + # Should have UPDATE for classification results + update_calls = [call for call in db_calls if "UPDATE" in str(call[0])] + assert len(update_calls) >= 1 + + # Verify Redis operations + assert mock_redis_instance.set.called + assert mock_redis_instance.expire.called + + @pytest.mark.asyncio + async def test_pipeline_with_missing_detections(self, sample_detection_pipeline, detection_context): + """Test pipeline behavior when expected detections are missing.""" + + pipeline_executor = PipelineExecutor() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True): + + # Setup detection model that doesn't find expected classes + mock_detection_model = Mock() + mock_detection_result = Mock() + mock_detection_result.boxes = Mock() + mock_detection_result.boxes.xyxy = Mock() + mock_detection_result.boxes.conf = Mock() + mock_detection_result.boxes.cls = Mock() + mock_detection_result.names = {0: "Car", 1: "Frontal"} + + # Only detect Car, no Frontal + mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [50, 100, 350, 450] # Only Car bbox + ]) + mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92]) + mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0]) + + mock_detection_model.return_value = mock_detection_result + mock_torch_load.return_value = mock_detection_model + + # Execute pipeline + result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context) + + # Pipeline should complete but skip classification branches + assert result is not None + assert "detections" in result + + detections = result["detections"] + assert len(detections) == 1 # Only Car detected + assert detections[0].get("class") == "Car" + + # Classification should not have run (no Frontal detected) + classification_results = result.get("classification_results", {}) + assert len(classification_results) == 0 or all( + not res for res in classification_results.values() + ) + + @pytest.mark.asyncio + async def test_pipeline_with_low_confidence_detections(self, sample_detection_pipeline, detection_context): + """Test pipeline with detections below confidence threshold.""" + + pipeline_executor = PipelineExecutor() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True): + + mock_detection_model = Mock() + mock_detection_result = Mock() + mock_detection_result.boxes = Mock() + mock_detection_result.boxes.xyxy = Mock() + mock_detection_result.boxes.conf = Mock() + mock_detection_result.boxes.cls = Mock() + mock_detection_result.names = {0: "Car", 1: "Frontal"} + + # Detections with low confidence (below 0.8 threshold) + mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [50, 100, 350, 450], # Car bbox + [150, 200, 300, 400] # Frontal bbox + ]) + mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.75, 0.70]) # Below threshold + mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1]) + + mock_detection_model.return_value = mock_detection_result + mock_torch_load.return_value = mock_detection_model + + # Execute pipeline + result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context) + + # Should complete but with filtered detections + assert result is not None + + # Low confidence detections should be filtered out + detections = result.get("detections", []) + high_conf_detections = [d for d in detections if d.get("confidence", 0) >= 0.8] + assert len(high_conf_detections) == 0 + + @pytest.mark.asyncio + async def test_pipeline_branch_execution_order(self, sample_detection_pipeline, detection_context): + """Test that pipeline branches execute in correct order and parallel mode works.""" + + pipeline_executor = PipelineExecutor() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True), \ + patch('psycopg2.connect') as mock_db_connect: + + # Track execution order + execution_order = [] + + # Setup detection model + mock_detection_model = Mock() + mock_detection_result = Mock() + mock_detection_result.boxes = Mock() + mock_detection_result.boxes.xyxy = Mock() + mock_detection_result.boxes.conf = Mock() + mock_detection_result.boxes.cls = Mock() + mock_detection_result.names = {0: "Car", 1: "Frontal"} + + mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [50, 100, 350, 450], [150, 200, 300, 400] + ]) + mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89]) + mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1]) + + def track_detection_execution(*args, **kwargs): + execution_order.append("detection") + return mock_detection_result + + mock_detection_model.side_effect = track_detection_execution + + # Setup classification models with execution tracking + def create_tracked_model(model_id): + def track_execution(*args, **kwargs): + execution_order.append(model_id) + result = Mock() + result.probs = Mock() + result.probs.top1 = 0 + result.probs.top1conf = Mock() + result.probs.top1conf.item.return_value = 0.90 + result.names = {0: "TestResult"} + return result + + model = Mock() + model.side_effect = track_execution + return model + + # Route models with execution tracking + def model_loader(path, **kwargs): + if "detection" in path: + return mock_detection_model + elif "brand" in path: + return create_tracked_model("car_brand_cls_v1") + elif "bodytype" in path: + return create_tracked_model("car_bodytype_cls_v1") + return Mock() + + mock_torch_load.side_effect = model_loader + + # Setup database mock + mock_db_conn = Mock() + mock_db_connect.return_value = mock_db_conn + mock_cursor = Mock() + mock_db_conn.cursor.return_value = mock_cursor + + # Execute pipeline + result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context) + + # Verify execution order + assert "detection" in execution_order + assert execution_order[0] == "detection" # Detection should run first + + # Classification models should run after detection + brand_index = execution_order.index("car_brand_cls_v1") if "car_brand_cls_v1" in execution_order else -1 + bodytype_index = execution_order.index("car_bodytype_cls_v1") if "car_bodytype_cls_v1" in execution_order else -1 + detection_index = execution_order.index("detection") + + if brand_index >= 0: + assert brand_index > detection_index + if bodytype_index >= 0: + assert bodytype_index > detection_index + + # Since branches are parallel, they could run in any order relative to each other + # but both should run after detection + + @pytest.mark.asyncio + async def test_pipeline_error_recovery(self, sample_detection_pipeline, detection_context): + """Test pipeline error handling and recovery.""" + + pipeline_executor = PipelineExecutor() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True), \ + patch('psycopg2.connect') as mock_db_connect: + + # Setup detection model that works + mock_detection_model = Mock() + mock_detection_result = Mock() + mock_detection_result.boxes = Mock() + mock_detection_result.boxes.xyxy = Mock() + mock_detection_result.boxes.conf = Mock() + mock_detection_result.boxes.cls = Mock() + mock_detection_result.names = {0: "Car", 1: "Frontal"} + + mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [50, 100, 350, 450], [150, 200, 300, 400] + ]) + mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89]) + mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1]) + + mock_detection_model.return_value = mock_detection_result + + # Setup classification models - one fails, one succeeds + mock_brand_model = Mock() + mock_brand_model.side_effect = RuntimeError("Model inference failed") + + mock_bodytype_model = Mock() + mock_bodytype_result = Mock() + mock_bodytype_result.probs = Mock() + mock_bodytype_result.probs.top1 = 1 + mock_bodytype_result.probs.top1conf = Mock() + mock_bodytype_result.probs.top1conf.item.return_value = 0.85 + mock_bodytype_result.names = {1: "SUV"} + mock_bodytype_model.return_value = mock_bodytype_result + + def model_loader(path, **kwargs): + if "detection" in path: + return mock_detection_model + elif "brand" in path: + return mock_brand_model + elif "bodytype" in path: + return mock_bodytype_model + return Mock() + + mock_torch_load.side_effect = model_loader + + # Setup database mock + mock_db_conn = Mock() + mock_db_connect.return_value = mock_db_conn + mock_cursor = Mock() + mock_db_conn.cursor.return_value = mock_cursor + + # Execute pipeline + result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context) + + # Pipeline should complete despite one branch failing + assert result is not None + + # Detection should succeed + assert "detections" in result + detections = result["detections"] + assert len(detections) == 2 + + # Classification results should be partial + classification_results = result.get("classification_results", {}) + + # Brand classification should have failed + brand_result = classification_results.get("car_brand_cls_v1") + assert brand_result is None or brand_result.get("error") is not None + + # Body type classification should have succeeded + bodytype_result = classification_results.get("car_bodytype_cls_v1") + assert bodytype_result is not None + assert bodytype_result.get("body_type") == "SUV" + assert bodytype_result.get("confidence") == 0.85 + + @pytest.mark.asyncio + async def test_field_mapping_and_database_update(self, sample_detection_pipeline, detection_context): + """Test field mapping and database update integration.""" + + pipeline_executor = PipelineExecutor() + field_mapper = FieldMapper() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True), \ + patch('psycopg2.connect') as mock_db_connect: + + # Setup successful detection and classification + mock_detection_model = Mock() + mock_detection_result = Mock() + mock_detection_result.boxes = Mock() + mock_detection_result.boxes.xyxy = Mock() + mock_detection_result.boxes.conf = Mock() + mock_detection_result.boxes.cls = Mock() + mock_detection_result.names = {0: "Car", 1: "Frontal"} + + mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [50, 100, 350, 450], [150, 200, 300, 400] + ]) + mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89]) + mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1]) + mock_detection_model.return_value = mock_detection_result + + # Setup classification models + mock_brand_model = Mock() + mock_brand_result = Mock() + mock_brand_result.probs = Mock() + mock_brand_result.probs.top1 = 2 + mock_brand_result.probs.top1conf = Mock() + mock_brand_result.probs.top1conf.item.return_value = 0.88 + mock_brand_result.names = {2: "Honda"} + mock_brand_model.return_value = mock_brand_result + + mock_bodytype_model = Mock() + mock_bodytype_result = Mock() + mock_bodytype_result.probs = Mock() + mock_bodytype_result.probs.top1 = 0 + mock_bodytype_result.probs.top1conf = Mock() + mock_bodytype_result.probs.top1conf.item.return_value = 0.91 + mock_bodytype_result.names = {0: "Hatchback"} + mock_bodytype_model.return_value = mock_bodytype_result + + def model_loader(path, **kwargs): + if "detection" in path: + return mock_detection_model + elif "brand" in path: + return mock_brand_model + elif "bodytype" in path: + return mock_bodytype_model + return Mock() + + mock_torch_load.side_effect = model_loader + + # Setup database mock + mock_db_conn = Mock() + mock_db_connect.return_value = mock_db_conn + mock_cursor = Mock() + mock_db_conn.cursor.return_value = mock_cursor + + # Execute pipeline + result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context) + + # Verify pipeline completed successfully + assert result is not None + assert result.get("status") == "completed" + + # Check database operations + db_calls = mock_cursor.execute.call_args_list + + # Should have INSERT and UPDATE operations + insert_calls = [call for call in db_calls if "INSERT" in str(call[0])] + update_calls = [call for call in db_calls if "UPDATE" in str(call[0])] + + assert len(insert_calls) >= 1 + assert len(update_calls) >= 1 + + # Check that UPDATE includes field mapping results + update_sql = str(update_calls[0][0]) + assert "car_brand" in update_sql.lower() + assert "car_body_type" in update_sql.lower() + + # Check that classification results were properly mapped + classification_results = result.get("classification_results", {}) + assert "car_brand_cls_v1" in classification_results + assert "car_bodytype_cls_v1" in classification_results + + brand_result = classification_results["car_brand_cls_v1"] + bodytype_result = classification_results["car_bodytype_cls_v1"] + + assert brand_result.get("brand") == "Honda" + assert brand_result.get("confidence") == 0.88 + assert bodytype_result.get("body_type") == "Hatchback" + assert bodytype_result.get("confidence") == 0.91 + + @pytest.mark.asyncio + async def test_redis_image_storage_integration(self, sample_detection_pipeline, detection_context): + """Test Redis image storage integration in pipeline.""" + + pipeline_executor = PipelineExecutor() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True), \ + patch('redis.Redis') as mock_redis, \ + patch('cv2.imencode') as mock_imencode: + + # Setup successful detection + mock_detection_model = Mock() + mock_detection_result = Mock() + mock_detection_result.boxes = Mock() + mock_detection_result.boxes.xyxy = Mock() + mock_detection_result.boxes.conf = Mock() + mock_detection_result.boxes.cls = Mock() + mock_detection_result.names = {0: "Car", 1: "Frontal"} + + mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [50, 100, 350, 450], [150, 200, 300, 400] + ]) + mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89]) + mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1]) + mock_detection_model.return_value = mock_detection_result + + mock_torch_load.return_value = mock_detection_model + + # Setup Redis mock + mock_redis_instance = Mock() + mock_redis.return_value = mock_redis_instance + mock_redis_instance.ping.return_value = True + mock_redis_instance.set.return_value = True + mock_redis_instance.expire.return_value = True + + # Setup image encoding mock + encoded_data = np.array([1, 2, 3, 4, 5], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + # Execute pipeline + result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context) + + # Verify Redis operations + assert mock_redis_instance.set.called + assert mock_redis_instance.expire.called + + # Check that image was encoded + assert mock_imencode.called + + # Verify correct key format was used + set_call = mock_redis_instance.set.call_args + redis_key = set_call[0][0] + + # Key should contain display_id, timestamp, session_id + assert detection_context["display_id"] in redis_key + assert detection_context["session_id"] in redis_key + assert str(detection_context["timestamp"]) in redis_key + + # Should set expiration + expire_call = mock_redis_instance.expire.call_args + expire_key = expire_call[0][0] + expire_seconds = expire_call[0][1] + + assert expire_key == redis_key + assert expire_seconds == 600 # As configured in pipeline + + @pytest.mark.asyncio + async def test_pipeline_performance_timing(self, sample_detection_pipeline, detection_context): + """Test pipeline execution timing and performance.""" + + pipeline_executor = PipelineExecutor() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True), \ + patch('psycopg2.connect') as mock_db_connect, \ + patch('redis.Redis') as mock_redis, \ + patch('cv2.imencode') as mock_imencode: + + # Setup fast mocks + mock_detection_model = Mock() + mock_detection_result = Mock() + mock_detection_result.boxes = Mock() + mock_detection_result.boxes.xyxy = Mock() + mock_detection_result.boxes.conf = Mock() + mock_detection_result.boxes.cls = Mock() + mock_detection_result.names = {0: "Car", 1: "Frontal"} + + mock_detection_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [50, 100, 350, 450], [150, 200, 300, 400] + ]) + mock_detection_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.92, 0.89]) + mock_detection_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1]) + mock_detection_model.return_value = mock_detection_result + + # Setup fast classification models + def create_fast_model(): + model = Mock() + result = Mock() + result.probs = Mock() + result.probs.top1 = 0 + result.probs.top1conf = Mock() + result.probs.top1conf.item.return_value = 0.90 + result.names = {0: "TestClass"} + model.return_value = result + return model + + def model_loader(path, **kwargs): + if "detection" in path: + return mock_detection_model + else: + return create_fast_model() + + mock_torch_load.side_effect = model_loader + + # Setup fast database and Redis + mock_db_conn = Mock() + mock_db_connect.return_value = mock_db_conn + mock_cursor = Mock() + mock_db_conn.cursor.return_value = mock_cursor + + mock_redis_instance = Mock() + mock_redis.return_value = mock_redis_instance + mock_redis_instance.ping.return_value = True + mock_redis_instance.set.return_value = True + mock_redis_instance.expire.return_value = True + + encoded_data = np.array([1, 2, 3], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + # Measure execution time + start_time = time.time() + + result = await pipeline_executor.execute_pipeline(sample_detection_pipeline, detection_context) + + end_time = time.time() + execution_time = end_time - start_time + + # Pipeline should complete quickly (less than 1 second with mocks) + assert execution_time < 1.0 + + # Should have timing information in result + assert result is not None + if "execution_time" in result: + assert result["execution_time"] > 0 + + # Verify pipeline completed successfully + assert result.get("status") == "completed" \ No newline at end of file diff --git a/tests/integration/test_websocket_protocol.py b/tests/integration/test_websocket_protocol.py new file mode 100644 index 0000000..d874421 --- /dev/null +++ b/tests/integration/test_websocket_protocol.py @@ -0,0 +1,579 @@ +""" +Integration tests for WebSocket protocol compliance. + +Tests the complete WebSocket communication protocol to ensure +compatibility with existing clients and proper message handling. +""" +import pytest +import asyncio +import json +import uuid +from unittest.mock import Mock, AsyncMock, patch +from fastapi.websockets import WebSocket +from fastapi.testclient import TestClient + +from detector_worker.app import create_app +from detector_worker.communication.websocket_handler import WebSocketHandler +from detector_worker.communication.message_processor import MessageProcessor, MessageType +from detector_worker.core.exceptions import MessageProcessingError + + +@pytest.fixture +def test_app(): + """Create test FastAPI application.""" + return create_app() + + +@pytest.fixture +def mock_websocket(): + """Create mock WebSocket for testing.""" + websocket = Mock(spec=WebSocket) + websocket.accept = AsyncMock() + websocket.send_json = AsyncMock() + websocket.send_text = AsyncMock() + websocket.receive_json = AsyncMock() + websocket.receive_text = AsyncMock() + websocket.close = AsyncMock() + websocket.ping = AsyncMock() + return websocket + + +class TestWebSocketProtocol: + """Test WebSocket protocol compliance.""" + + @pytest.mark.asyncio + async def test_subscription_message_protocol(self, mock_websocket): + """Test subscription message handling protocol.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Mock external dependencies + with patch('cv2.VideoCapture') as mock_video_cap, \ + patch('torch.load') as mock_torch_load, \ + patch('builtins.open') as mock_file_open: + + # Setup video capture mock + mock_cap_instance = Mock() + mock_video_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + + # Setup model loading mock + mock_torch_load.return_value = Mock() + + # Setup pipeline file mock + pipeline_config = { + "modelId": "test_detection_model", + "modelFile": "test_model.pt", + "expectedClasses": ["Car"], + "minConfidence": 0.8 + } + mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config) + + # Test message sequence + subscription_message = { + "type": "subscribe", + "payload": { + "subscriptionIdentifier": "display-001;cam-001", + "rtspUrl": "rtsp://example.com/stream", + "modelUrl": "http://example.com/model.mpta", + "modelId": 101, + "modelName": "Test Detection", + "cropX1": 0, + "cropY1": 0, + "cropX2": 640, + "cropY2": 480 + } + } + + request_state_message = { + "type": "requestState" + } + + unsubscribe_message = { + "type": "unsubscribe", + "payload": { + "subscriptionIdentifier": "display-001;cam-001" + } + } + + # Mock WebSocket message sequence + mock_websocket.receive_json.side_effect = [ + subscription_message, + request_state_message, + unsubscribe_message, + asyncio.CancelledError() # Simulate client disconnect + ] + + client_id = "test_client_001" + + try: + # Handle WebSocket connection + await websocket_handler.handle_websocket(mock_websocket, client_id) + except asyncio.CancelledError: + pass # Expected when client disconnects + + # Verify protocol compliance + mock_websocket.accept.assert_called_once() + + # Check sent messages + sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list] + + # Should receive subscription acknowledgment + subscribe_acks = [msg for msg in sent_messages if msg.get("type") == "subscribeAck"] + assert len(subscribe_acks) >= 1 + + subscribe_ack = subscribe_acks[0] + assert subscribe_ack["status"] in ["success", "error"] + if subscribe_ack["status"] == "success": + assert "subscriptionId" in subscribe_ack + + # Should receive state report + state_reports = [msg for msg in sent_messages if msg.get("type") == "stateReport"] + assert len(state_reports) >= 1 + + state_report = state_reports[0] + assert "payload" in state_report + assert "subscriptions" in state_report["payload"] + assert "system" in state_report["payload"] + + # Should receive unsubscribe acknowledgment + unsubscribe_acks = [msg for msg in sent_messages if msg.get("type") == "unsubscribeAck"] + assert len(unsubscribe_acks) >= 1 + + unsubscribe_ack = unsubscribe_acks[0] + assert unsubscribe_ack["status"] in ["success", "error"] + + @pytest.mark.asyncio + async def test_invalid_message_handling(self, mock_websocket): + """Test handling of invalid messages.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Test invalid message types + invalid_messages = [ + {"invalid": "message"}, # Missing type + {"type": "unknown_type", "payload": {}}, # Unknown type + {"type": "subscribe"}, # Missing payload + {"type": "subscribe", "payload": {}}, # Missing required fields + {"type": "subscribe", "payload": {"subscriptionIdentifier": "test"}}, # Missing URL + ] + + mock_websocket.receive_json.side_effect = invalid_messages + [asyncio.CancelledError()] + + client_id = "test_client_error" + + try: + await websocket_handler.handle_websocket(mock_websocket, client_id) + except asyncio.CancelledError: + pass + + # Verify error responses + sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list] + error_messages = [msg for msg in sent_messages if msg.get("type") == "error"] + + # Should receive error responses for invalid messages + assert len(error_messages) >= len(invalid_messages) + + for error_msg in error_messages: + assert "message" in error_msg + assert error_msg["message"] # Non-empty error message + + @pytest.mark.asyncio + async def test_session_management_protocol(self, mock_websocket): + """Test session management protocol.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + session_id = str(uuid.uuid4()) + + # Test session messages + set_session_message = { + "type": "setSessionId", + "payload": { + "sessionId": session_id, + "displayId": "display-001" + } + } + + patch_session_message = { + "type": "patchSession", + "payload": { + "sessionId": session_id, + "data": { + "car_brand": "Toyota", + "confidence": 0.92 + } + } + } + + mock_websocket.receive_json.side_effect = [ + set_session_message, + patch_session_message, + asyncio.CancelledError() + ] + + client_id = "test_client_session" + + try: + await websocket_handler.handle_websocket(mock_websocket, client_id) + except asyncio.CancelledError: + pass + + # Verify session responses + sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list] + + # Should receive acknowledgments for session operations + set_session_acks = [msg for msg in sent_messages + if msg.get("type") == "setSessionIdAck"] + assert len(set_session_acks) >= 1 + + patch_session_acks = [msg for msg in sent_messages + if msg.get("type") == "patchSessionAck"] + assert len(patch_session_acks) >= 1 + + @pytest.mark.asyncio + async def test_heartbeat_protocol(self, mock_websocket): + """Test heartbeat/ping-pong protocol.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor, {"heartbeat_interval": 0.1}) + + # Simulate long-running connection + mock_websocket.receive_json.side_effect = [ + asyncio.TimeoutError(), # No messages for a while + asyncio.TimeoutError(), + asyncio.CancelledError() # Then disconnect + ] + + client_id = "test_client_heartbeat" + + # Start heartbeat task + heartbeat_task = asyncio.create_task(websocket_handler._heartbeat_loop()) + + try: + # Let heartbeat run briefly + await asyncio.sleep(0.2) + + # Cancel heartbeat + heartbeat_task.cancel() + await heartbeat_task + + except asyncio.CancelledError: + pass + + # Verify ping was sent + assert mock_websocket.ping.called + + @pytest.mark.asyncio + async def test_detection_result_protocol(self, mock_websocket): + """Test detection result message protocol.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Mock a detection result being sent + detection_result = { + "type": "imageDetection", + "payload": { + "subscriptionId": "display-001;cam-001", + "detections": [ + { + "class": "Car", + "confidence": 0.95, + "bbox": [100, 200, 300, 400], + "trackId": 1001 + } + ], + "timestamp": 1640995200000, + "modelInfo": { + "modelId": 101, + "modelName": "Vehicle Detection" + } + } + } + + # Send detection result to client + await websocket_handler.send_to_client("test_client", detection_result) + + # Verify message was sent + mock_websocket.send_json.assert_called_with(detection_result) + + @pytest.mark.asyncio + async def test_error_recovery_protocol(self, mock_websocket): + """Test error recovery and graceful degradation.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Simulate WebSocket errors + mock_websocket.send_json.side_effect = [ + None, # First message succeeds + ConnectionError("Connection lost"), # Second fails + None # Third succeeds after recovery + ] + + # Try to send multiple messages + messages = [ + {"type": "test", "message": "1"}, + {"type": "test", "message": "2"}, + {"type": "test", "message": "3"} + ] + + results = [] + for msg in messages: + try: + await websocket_handler.send_to_client("test_client", msg) + results.append("success") + except Exception: + results.append("error") + + # Should handle errors gracefully + assert "error" in results + # But should still be able to send other messages + assert "success" in results + + @pytest.mark.asyncio + async def test_concurrent_client_protocol(self, mock_websocket): + """Test handling multiple concurrent clients.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Create multiple mock WebSocket connections + mock_websockets = [] + for i in range(3): + ws = Mock(spec=WebSocket) + ws.accept = AsyncMock() + ws.send_json = AsyncMock() + ws.receive_json = AsyncMock(side_effect=[asyncio.CancelledError()]) + mock_websockets.append(ws) + + # Handle multiple clients concurrently + client_tasks = [] + for i, ws in enumerate(mock_websockets): + task = asyncio.create_task( + websocket_handler.handle_websocket(ws, f"client_{i}") + ) + client_tasks.append(task) + + # Wait briefly then cancel all + await asyncio.sleep(0.1) + for task in client_tasks: + task.cancel() + + try: + await asyncio.gather(*client_tasks) + except asyncio.CancelledError: + pass + + # Verify all connections were accepted + for ws in mock_websockets: + ws.accept.assert_called_once() + + @pytest.mark.asyncio + async def test_subscription_sharing_protocol(self, mock_websocket): + """Test shared subscription protocol.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Mock multiple clients subscribing to same camera + with patch('cv2.VideoCapture') as mock_video_cap: + mock_cap_instance = Mock() + mock_video_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + + # First client subscribes + subscription_msg1 = { + "type": "subscribe", + "payload": { + "subscriptionIdentifier": "display-001;cam-001", + "rtspUrl": "rtsp://shared.example.com/stream", + "modelUrl": "http://example.com/model.mpta" + } + } + + # Second client subscribes to same camera + subscription_msg2 = { + "type": "subscribe", + "payload": { + "subscriptionIdentifier": "display-002;cam-001", + "rtspUrl": "rtsp://shared.example.com/stream", # Same URL + "modelUrl": "http://example.com/model.mpta" + } + } + + # Mock file operations + with patch('builtins.open') as mock_file_open: + pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]} + mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config) + + # Process both subscriptions + response1 = await message_processor.process_message(subscription_msg1, "client_1") + response2 = await message_processor.process_message(subscription_msg2, "client_2") + + # Both should succeed and reference same underlying stream + assert response1.get("status") == "success" + assert response2.get("status") == "success" + + # Should only create one video capture instance (shared stream) + assert mock_video_cap.call_count == 1 + + @pytest.mark.asyncio + async def test_message_ordering_protocol(self, mock_websocket): + """Test message ordering and sequencing.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Test sequence of related messages + messages = [ + {"type": "subscribe", "payload": {"subscriptionIdentifier": "test", "rtspUrl": "rtsp://test.com"}}, + {"type": "setSessionId", "payload": {"sessionId": "session_123", "displayId": "display_001"}}, + {"type": "requestState"}, + {"type": "patchSession", "payload": {"sessionId": "session_123", "data": {"test": "data"}}}, + {"type": "unsubscribe", "payload": {"subscriptionIdentifier": "test"}} + ] + + mock_websocket.receive_json.side_effect = messages + [asyncio.CancelledError()] + + with patch('cv2.VideoCapture') as mock_video_cap, \ + patch('builtins.open') as mock_file_open: + + mock_cap_instance = Mock() + mock_video_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + + pipeline_config = {"modelId": "test", "expectedClasses": ["Car"]} + mock_file_open.return_value.__enter__.return_value.read.return_value = json.dumps(pipeline_config) + + client_id = "test_client_ordering" + + try: + await websocket_handler.handle_websocket(mock_websocket, client_id) + except asyncio.CancelledError: + pass + + # Verify responses were sent in correct order + sent_messages = [call[0][0] for call in mock_websocket.send_json.call_args_list] + + # Should receive responses for each message type + response_types = [msg.get("type") for msg in sent_messages] + + expected_types = ["subscribeAck", "setSessionIdAck", "stateReport", "patchSessionAck", "unsubscribeAck"] + + # Check that we got appropriate responses (order may vary slightly) + for expected_type in expected_types: + assert any(expected_type in response_types), f"Missing response type: {expected_type}" + + +class TestWebSocketPerformance: + """Test WebSocket performance characteristics.""" + + @pytest.mark.asyncio + async def test_message_throughput(self, mock_websocket): + """Test message processing throughput.""" + + message_processor = MessageProcessor() + + # Prepare batch of simple messages + state_request = {"type": "requestState"} + + import time + start_time = time.time() + + # Process many messages quickly + for _ in range(100): + await message_processor.process_message(state_request, "test_client") + + end_time = time.time() + processing_time = end_time - start_time + + # Should process messages quickly (less than 1 second for 100 messages) + assert processing_time < 1.0 + + # Calculate throughput + throughput = 100 / processing_time + assert throughput > 100 # Should handle > 100 messages/second + + @pytest.mark.asyncio + async def test_concurrent_message_handling(self, mock_websocket): + """Test concurrent message handling.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Create multiple mock clients + num_clients = 10 + mock_websockets = [] + + for i in range(num_clients): + ws = Mock(spec=WebSocket) + ws.accept = AsyncMock() + ws.send_json = AsyncMock() + ws.receive_json = AsyncMock(side_effect=[ + {"type": "requestState"}, + asyncio.CancelledError() + ]) + mock_websockets.append(ws) + + # Handle all clients concurrently + client_tasks = [] + for i, ws in enumerate(mock_websockets): + task = asyncio.create_task( + websocket_handler.handle_websocket(ws, f"perf_client_{i}") + ) + client_tasks.append(task) + + start_time = time.time() + + # Wait for all to complete + try: + await asyncio.gather(*client_tasks) + except asyncio.CancelledError: + pass + + end_time = time.time() + total_time = end_time - start_time + + # Should handle all clients efficiently + assert total_time < 2.0 # Should complete in less than 2 seconds + + # All clients should have been accepted + for ws in mock_websockets: + ws.accept.assert_called_once() + + @pytest.mark.asyncio + async def test_memory_usage_stability(self, mock_websocket): + """Test memory usage remains stable.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Simulate many connection/disconnection cycles + for cycle in range(10): + # Create client + client_id = f"memory_test_client_{cycle}" + + mock_websocket.receive_json.side_effect = [ + {"type": "requestState"}, + asyncio.CancelledError() + ] + + try: + await websocket_handler.handle_websocket(mock_websocket, client_id) + except asyncio.CancelledError: + pass + + # Reset mock for next cycle + mock_websocket.reset_mock() + mock_websocket.accept = AsyncMock() + mock_websocket.send_json = AsyncMock() + mock_websocket.receive_json = AsyncMock() + + # Connection manager should not accumulate stale connections + stats = websocket_handler.get_connection_stats() + assert stats["total_connections"] == 0 # All should be cleaned up \ No newline at end of file diff --git a/tests/performance/__init__.py b/tests/performance/__init__.py new file mode 100644 index 0000000..2979f63 --- /dev/null +++ b/tests/performance/__init__.py @@ -0,0 +1,19 @@ +""" +Performance tests for the detector worker application. + +This package contains performance benchmarks and load tests to ensure +the application meets scalability and throughput requirements. +""" + +# Performance test modules +from . import ( + test_detection_performance, + test_websocket_performance, + test_storage_performance +) + +__all__ = [ + "test_detection_performance", + "test_websocket_performance", + "test_storage_performance" +] \ No newline at end of file diff --git a/tests/performance/test_detection_performance.py b/tests/performance/test_detection_performance.py new file mode 100644 index 0000000..82cef4d --- /dev/null +++ b/tests/performance/test_detection_performance.py @@ -0,0 +1,672 @@ +""" +Performance tests for detection pipeline components. + +These tests benchmark the performance of key detection pipeline +components to ensure they meet performance requirements. +""" +import pytest +import time +import asyncio +import statistics +from unittest.mock import Mock, patch +import numpy as np +import psutil +import gc + +from detector_worker.detection.yolo_detector import YOLODetector +from detector_worker.detection.tracking_manager import TrackingManager +from detector_worker.detection.stability_validator import StabilityValidator +from detector_worker.pipeline.pipeline_executor import PipelineExecutor +from detector_worker.models.model_manager import ModelManager +from detector_worker.streams.stream_manager import StreamManager + + +@pytest.fixture +def sample_frame(): + """Create a sample frame for performance testing.""" + return np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + + +@pytest.fixture +def large_frame(): + """Create a large frame for stress testing.""" + return np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8) + + +@pytest.fixture +def performance_config(): + """Configuration for performance tests.""" + return { + "target_fps": 30, + "max_detection_time_ms": 100, + "max_tracking_time_ms": 50, + "max_pipeline_time_ms": 500, + "memory_limit_mb": 1024 + } + + +class TestDetectionPerformance: + """Test detection performance benchmarks.""" + + def test_yolo_detection_speed(self, sample_frame, performance_config): + """Benchmark YOLO detection speed.""" + + detector = YOLODetector() + + with patch('torch.load') as mock_torch_load: + # Setup fast mock model + mock_model = Mock() + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.xyxy = Mock() + mock_result.boxes.conf = Mock() + mock_result.boxes.cls = Mock() + mock_result.names = {0: "car", 1: "person"} + + # Mock detection results + mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [100, 200, 300, 400], + [150, 250, 350, 450] + ]) + mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9, 0.8]) + mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1]) + + mock_model.return_value = mock_result + mock_torch_load.return_value = mock_model + + # Warm up + for _ in range(5): + detector.detect(sample_frame, confidence_threshold=0.5) + + # Benchmark detection speed + detection_times = [] + num_iterations = 100 + + for _ in range(num_iterations): + start_time = time.perf_counter() + detections = detector.detect(sample_frame, confidence_threshold=0.5) + end_time = time.perf_counter() + + detection_time_ms = (end_time - start_time) * 1000 + detection_times.append(detection_time_ms) + + # Calculate statistics + avg_detection_time = statistics.mean(detection_times) + median_detection_time = statistics.median(detection_times) + max_detection_time = max(detection_times) + min_detection_time = min(detection_times) + + # Performance assertions + assert avg_detection_time < performance_config["max_detection_time_ms"] + assert median_detection_time < performance_config["max_detection_time_ms"] + + # Calculate theoretical FPS + theoretical_fps = 1000 / avg_detection_time + assert theoretical_fps >= performance_config["target_fps"] + + print(f"\nDetection Performance Metrics:") + print(f"Average detection time: {avg_detection_time:.2f} ms") + print(f"Median detection time: {median_detection_time:.2f} ms") + print(f"Min detection time: {min_detection_time:.2f} ms") + print(f"Max detection time: {max_detection_time:.2f} ms") + print(f"Theoretical FPS: {theoretical_fps:.1f}") + + def test_tracking_performance(self, sample_frame, performance_config): + """Benchmark object tracking performance.""" + + tracking_manager = TrackingManager() + + # Create mock detections + detections = [ + {"class": "car", "confidence": 0.9, "bbox": [100, 200, 300, 400]}, + {"class": "car", "confidence": 0.8, "bbox": [150, 250, 350, 450]}, + {"class": "person", "confidence": 0.7, "bbox": [200, 300, 250, 400]} + ] + + # Warm up tracking + for i in range(10): + tracking_manager.update_tracks(detections, frame_id=i) + + # Benchmark tracking speed + tracking_times = [] + num_iterations = 100 + + for i in range(num_iterations): + # Simulate moving detections + moving_detections = [] + for det in detections: + moved_det = det.copy() + # Add small random movement + bbox = moved_det["bbox"] + moved_det["bbox"] = [ + bbox[0] + np.random.randint(-5, 5), + bbox[1] + np.random.randint(-5, 5), + bbox[2] + np.random.randint(-5, 5), + bbox[3] + np.random.randint(-5, 5) + ] + moving_detections.append(moved_det) + + start_time = time.perf_counter() + tracks = tracking_manager.update_tracks(moving_detections, frame_id=i + 10) + end_time = time.perf_counter() + + tracking_time_ms = (end_time - start_time) * 1000 + tracking_times.append(tracking_time_ms) + + # Calculate statistics + avg_tracking_time = statistics.mean(tracking_times) + max_tracking_time = max(tracking_times) + + # Performance assertions + assert avg_tracking_time < performance_config["max_tracking_time_ms"] + assert max_tracking_time < performance_config["max_tracking_time_ms"] * 2 + + print(f"\nTracking Performance Metrics:") + print(f"Average tracking time: {avg_tracking_time:.2f} ms") + print(f"Max tracking time: {max_tracking_time:.2f} ms") + + def test_stability_validation_performance(self, performance_config): + """Benchmark stability validation performance.""" + + validator = StabilityValidator() + + # Create stable detections sequence + base_detection = { + "class": "car", + "confidence": 0.9, + "bbox": [100, 200, 300, 400], + "track_id": 1001 + } + + # Add sequence of stable detections + for i in range(20): + detection = base_detection.copy() + # Add small variations to simulate real detection noise + detection["confidence"] = 0.9 + np.random.normal(0, 0.02) + bbox = detection["bbox"] + detection["bbox"] = [ + bbox[0] + np.random.normal(0, 2), + bbox[1] + np.random.normal(0, 2), + bbox[2] + np.random.normal(0, 2), + bbox[3] + np.random.normal(0, 2) + ] + + validator.add_detection(detection, frame_id=i) + + # Benchmark validation performance + validation_times = [] + num_iterations = 1000 + + for i in range(num_iterations): + test_detection = base_detection.copy() + test_detection["confidence"] = 0.85 + np.random.normal(0, 0.05) + + start_time = time.perf_counter() + is_stable = validator.is_detection_stable( + test_detection, + stability_frames=10, + confidence_threshold=0.8 + ) + end_time = time.perf_counter() + + validation_time_ms = (end_time - start_time) * 1000 + validation_times.append(validation_time_ms) + + avg_validation_time = statistics.mean(validation_times) + max_validation_time = max(validation_times) + + # Should be very fast (< 1ms typically) + assert avg_validation_time < 1.0 + assert max_validation_time < 5.0 + + print(f"\nStability Validation Performance Metrics:") + print(f"Average validation time: {avg_validation_time:.3f} ms") + print(f"Max validation time: {max_validation_time:.3f} ms") + + @pytest.mark.asyncio + async def test_pipeline_executor_performance(self, sample_frame, performance_config): + """Benchmark complete pipeline execution performance.""" + + pipeline_executor = PipelineExecutor() + + # Simple pipeline configuration + pipeline_config = { + "modelId": "fast_detection_model", + "modelFile": "fast_model.pt", + "expectedClasses": ["car"], + "minConfidence": 0.5, + "actions": [], + "branches": [] + } + + detection_context = { + "camera_id": "perf_camera", + "display_id": "perf_display", + "frame": sample_frame, + "timestamp": int(time.time() * 1000), + "session_id": "perf_session" + } + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True): + + # Setup fast mock model + mock_model = Mock() + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.xyxy = Mock() + mock_result.boxes.conf = Mock() + mock_result.boxes.cls = Mock() + mock_result.names = {0: "car"} + + mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([[100, 200, 300, 400]]) + mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9]) + mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0]) + + mock_model.return_value = mock_result + mock_torch_load.return_value = mock_model + + # Warm up + for _ in range(3): + await pipeline_executor.execute_pipeline(pipeline_config, detection_context) + + # Benchmark pipeline execution + pipeline_times = [] + num_iterations = 50 + + for _ in range(num_iterations): + start_time = time.perf_counter() + result = await pipeline_executor.execute_pipeline(pipeline_config, detection_context) + end_time = time.perf_counter() + + pipeline_time_ms = (end_time - start_time) * 1000 + pipeline_times.append(pipeline_time_ms) + + # Ensure result is valid + assert result is not None + + avg_pipeline_time = statistics.mean(pipeline_times) + max_pipeline_time = max(pipeline_times) + + # Performance assertions + assert avg_pipeline_time < performance_config["max_pipeline_time_ms"] + + print(f"\nPipeline Execution Performance Metrics:") + print(f"Average pipeline time: {avg_pipeline_time:.2f} ms") + print(f"Max pipeline time: {max_pipeline_time:.2f} ms") + + def test_memory_usage_detection(self, sample_frame, performance_config): + """Test memory usage during detection operations.""" + + detector = YOLODetector() + + with patch('torch.load') as mock_torch_load: + # Setup mock model + mock_model = Mock() + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.xyxy = Mock() + mock_result.boxes.conf = Mock() + mock_result.boxes.cls = Mock() + mock_result.names = {0: "car"} + + mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([[100, 200, 300, 400]]) + mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9]) + mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0]) + + mock_model.return_value = mock_result + mock_torch_load.return_value = mock_model + + # Measure memory usage + gc.collect() # Clean up before measurement + initial_memory = psutil.Process().memory_info().rss / 1024 / 1024 # MB + + # Run detections and monitor memory + memory_measurements = [] + for i in range(100): + detections = detector.detect(sample_frame, confidence_threshold=0.5) + + if i % 10 == 0: # Measure every 10 iterations + current_memory = psutil.Process().memory_info().rss / 1024 / 1024 + memory_measurements.append(current_memory - initial_memory) + + # Final memory measurement + gc.collect() + final_memory = psutil.Process().memory_info().rss / 1024 / 1024 + memory_increase = final_memory - initial_memory + + # Memory should not grow significantly + assert memory_increase < 100 # Less than 100MB increase + + # Memory should be relatively stable (not constantly growing) + if len(memory_measurements) > 1: + memory_trend = memory_measurements[-1] - memory_measurements[0] + assert memory_trend < 50 # Less than 50MB trend growth + + print(f"\nMemory Usage Metrics:") + print(f"Initial memory: {initial_memory:.1f} MB") + print(f"Final memory: {final_memory:.1f} MB") + print(f"Memory increase: {memory_increase:.1f} MB") + + def test_concurrent_detection_performance(self, sample_frame): + """Test performance with concurrent detection operations.""" + + with patch('torch.load') as mock_torch_load: + # Setup mock model + mock_model = Mock() + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.xyxy = Mock() + mock_result.boxes.conf = Mock() + mock_result.boxes.cls = Mock() + mock_result.names = {0: "car"} + + mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([[100, 200, 300, 400]]) + mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9]) + mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0]) + + mock_model.return_value = mock_result + mock_torch_load.return_value = mock_model + + # Create multiple detectors + detectors = [YOLODetector() for _ in range(4)] + + import threading + import concurrent.futures + + def run_detection(detector, frame, iterations=25): + """Run detection iterations.""" + times = [] + for _ in range(iterations): + start_time = time.perf_counter() + detections = detector.detect(frame, confidence_threshold=0.5) + end_time = time.perf_counter() + times.append((end_time - start_time) * 1000) + return times + + # Run concurrent detections + start_time = time.perf_counter() + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [ + executor.submit(run_detection, detector, sample_frame) + for detector in detectors + ] + + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + end_time = time.perf_counter() + total_time = end_time - start_time + + # Analyze results + all_times = [time_ms for result in results for time_ms in result] + total_detections = len(all_times) + avg_detection_time = statistics.mean(all_times) + + # Calculate effective throughput + effective_fps = total_detections / total_time + + print(f"\nConcurrent Detection Performance:") + print(f"Total detections: {total_detections}") + print(f"Total time: {total_time:.2f} seconds") + print(f"Average detection time: {avg_detection_time:.2f} ms") + print(f"Effective throughput: {effective_fps:.1f} FPS") + + # Should maintain reasonable performance under load + assert avg_detection_time < 200 # Less than 200ms average + assert effective_fps > 20 # More than 20 effective FPS + + def test_large_frame_performance(self, large_frame): + """Test detection performance with large frames.""" + + detector = YOLODetector() + + with patch('torch.load') as mock_torch_load: + # Setup mock model + mock_model = Mock() + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.xyxy = Mock() + mock_result.boxes.conf = Mock() + mock_result.boxes.cls = Mock() + mock_result.names = {0: "car", 1: "person"} + + # Larger frame might have more detections + mock_result.boxes.xyxy.cpu.return_value.numpy.return_value = np.array([ + [100, 200, 300, 400], + [500, 600, 700, 800], + [1000, 200, 1200, 400] + ]) + mock_result.boxes.conf.cpu.return_value.numpy.return_value = np.array([0.9, 0.8, 0.7]) + mock_result.boxes.cls.cpu.return_value.numpy.return_value = np.array([0, 1, 0]) + + mock_model.return_value = mock_result + mock_torch_load.return_value = mock_model + + # Benchmark large frame detection + detection_times = [] + num_iterations = 20 # Fewer iterations for large frames + + for _ in range(num_iterations): + start_time = time.perf_counter() + detections = detector.detect(large_frame, confidence_threshold=0.5) + end_time = time.perf_counter() + + detection_time_ms = (end_time - start_time) * 1000 + detection_times.append(detection_time_ms) + + avg_detection_time = statistics.mean(detection_times) + max_detection_time = max(detection_times) + + print(f"\nLarge Frame Detection Performance:") + print(f"Frame size: {large_frame.shape}") + print(f"Average detection time: {avg_detection_time:.2f} ms") + print(f"Max detection time: {max_detection_time:.2f} ms") + + # Large frames should still be processed in reasonable time + assert avg_detection_time < 300 # Less than 300ms for large frames + assert max_detection_time < 500 # Less than 500ms max + + +class TestStreamPerformance: + """Test stream management performance.""" + + @pytest.mark.asyncio + async def test_stream_creation_performance(self): + """Test performance of stream creation and management.""" + + stream_manager = StreamManager() + + with patch('cv2.VideoCapture') as mock_video_cap: + # Setup fast mock + mock_cap_instance = Mock() + mock_video_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + mock_cap_instance.read.return_value = (True, np.ones((480, 640, 3), dtype=np.uint8)) + + # Benchmark stream creation + creation_times = [] + num_streams = 20 + + try: + for i in range(num_streams): + from detector_worker.streams.stream_manager import StreamConfig + config = StreamConfig( + stream_url=f"rtsp://test{i}.example.com/stream", + stream_type="rtsp" + ) + + start_time = time.perf_counter() + await stream_manager.create_stream(f"camera_{i}", config, f"sub_{i}") + end_time = time.perf_counter() + + creation_time_ms = (end_time - start_time) * 1000 + creation_times.append(creation_time_ms) + + avg_creation_time = statistics.mean(creation_times) + max_creation_time = max(creation_times) + + # Stream creation should be fast + assert avg_creation_time < 100 # Less than 100ms average + assert max_creation_time < 500 # Less than 500ms max + + print(f"\nStream Creation Performance:") + print(f"Streams created: {num_streams}") + print(f"Average creation time: {avg_creation_time:.2f} ms") + print(f"Max creation time: {max_creation_time:.2f} ms") + + finally: + await stream_manager.stop_all_streams() + + @pytest.mark.asyncio + async def test_frame_retrieval_performance(self, sample_frame): + """Test performance of frame retrieval operations.""" + + stream_manager = StreamManager() + + with patch('cv2.VideoCapture') as mock_video_cap: + mock_cap_instance = Mock() + mock_video_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + mock_cap_instance.read.return_value = (True, sample_frame) + + try: + # Create test stream + from detector_worker.streams.stream_manager import StreamConfig + config = StreamConfig( + stream_url="rtsp://perf.example.com/stream", + stream_type="rtsp" + ) + + await stream_manager.create_stream("perf_camera", config, "perf_sub") + + # Let stream capture some frames + await asyncio.sleep(0.1) + + # Benchmark frame retrieval + retrieval_times = [] + num_retrievals = 1000 + + for _ in range(num_retrievals): + start_time = time.perf_counter() + frame = stream_manager.get_latest_frame("perf_camera") + end_time = time.perf_counter() + + retrieval_time_ms = (end_time - start_time) * 1000 + retrieval_times.append(retrieval_time_ms) + + avg_retrieval_time = statistics.mean(retrieval_times) + max_retrieval_time = max(retrieval_times) + + # Frame retrieval should be very fast + assert avg_retrieval_time < 1.0 # Less than 1ms average + assert max_retrieval_time < 10.0 # Less than 10ms max + + print(f"\nFrame Retrieval Performance:") + print(f"Frame retrievals: {num_retrievals}") + print(f"Average retrieval time: {avg_retrieval_time:.3f} ms") + print(f"Max retrieval time: {max_retrieval_time:.3f} ms") + + finally: + await stream_manager.stop_all_streams() + + +class TestModelPerformance: + """Test model management performance.""" + + def test_model_loading_performance(self): + """Test performance of model loading operations.""" + + model_manager = ModelManager() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True): + + # Setup mock model + def create_mock_model(): + model = Mock() + # Mock model parameters for memory estimation + param = Mock() + param.numel.return_value = 1000000 # 1M parameters + param.element_size.return_value = 4 # 4 bytes each + model.parameters.return_value = [param] + return model + + mock_torch_load.side_effect = lambda *args, **kwargs: create_mock_model() + + # Benchmark model loading + loading_times = [] + num_models = 10 + + for i in range(num_models): + from detector_worker.models.model_manager import ModelConfig + config = ModelConfig( + model_id=f"perf_model_{i}", + model_path=f"/fake/path/model_{i}.pt", + model_type="detection", + device="cpu" + ) + + start_time = time.perf_counter() + model = model_manager.load_model(config) + end_time = time.perf_counter() + + loading_time_ms = (end_time - start_time) * 1000 + loading_times.append(loading_time_ms) + + avg_loading_time = statistics.mean(loading_times) + max_loading_time = max(loading_times) + + print(f"\nModel Loading Performance:") + print(f"Models loaded: {num_models}") + print(f"Average loading time: {avg_loading_time:.2f} ms") + print(f"Max loading time: {max_loading_time:.2f} ms") + + # Model loading should be reasonable + assert avg_loading_time < 200 # Less than 200ms average + + def test_model_cache_performance(self): + """Test performance of model cache operations.""" + + model_manager = ModelManager() + + with patch('torch.load') as mock_torch_load, \ + patch('os.path.exists', return_value=True): + + mock_torch_load.return_value = Mock() + + # Load model first + from detector_worker.models.model_manager import ModelConfig + config = ModelConfig( + model_id="cache_perf_model", + model_path="/fake/path/model.pt", + model_type="detection", + device="cpu" + ) + + # Initial load + model_manager.load_model(config) + + # Benchmark cache retrieval + cache_times = [] + num_retrievals = 10000 + + for _ in range(num_retrievals): + start_time = time.perf_counter() + model = model_manager.get_model("cache_perf_model") + end_time = time.perf_counter() + + cache_time_ms = (end_time - start_time) * 1000 + cache_times.append(cache_time_ms) + + avg_cache_time = statistics.mean(cache_times) + max_cache_time = max(cache_times) + + print(f"\nModel Cache Performance:") + print(f"Cache retrievals: {num_retrievals}") + print(f"Average cache time: {avg_cache_time:.4f} ms") + print(f"Max cache time: {max_cache_time:.4f} ms") + + # Cache should be very fast + assert avg_cache_time < 0.1 # Less than 0.1ms average + assert max_cache_time < 1.0 # Less than 1ms max \ No newline at end of file diff --git a/tests/performance/test_storage_performance.py b/tests/performance/test_storage_performance.py new file mode 100644 index 0000000..16be992 --- /dev/null +++ b/tests/performance/test_storage_performance.py @@ -0,0 +1,828 @@ +""" +Performance tests for storage components (database, Redis, session cache). + +These tests benchmark storage operations to ensure they meet +performance requirements for high-throughput scenarios. +""" +import pytest +import asyncio +import time +import statistics +import uuid +from unittest.mock import Mock, patch, MagicMock +import psutil +import gc +import numpy as np + +from detector_worker.storage.database_manager import DatabaseManager +from detector_worker.storage.redis_client import RedisClient, RedisConfig +from detector_worker.storage.session_cache import SessionCacheManager, SessionCache, CacheConfig + + +@pytest.fixture +def performance_config(): + """Configuration for performance tests.""" + return { + "max_db_query_time_ms": 50, + "max_redis_operation_time_ms": 10, + "max_cache_operation_time_ms": 1, + "min_db_throughput_ops_per_sec": 1000, + "min_redis_throughput_ops_per_sec": 5000, + "min_cache_throughput_ops_per_sec": 10000 + } + + +class TestDatabasePerformance: + """Test database performance benchmarks.""" + + def test_database_connection_performance(self, performance_config): + """Test database connection establishment performance.""" + + with patch('psycopg2.connect') as mock_connect: + # Setup mock connection + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + db_manager = DatabaseManager() + + # Benchmark connection times + connection_times = [] + num_connections = 100 + + for _ in range(num_connections): + start_time = time.perf_counter() + db_manager.connect() + end_time = time.perf_counter() + + connection_time_ms = (end_time - start_time) * 1000 + connection_times.append(connection_time_ms) + + # Disconnect for next test + db_manager.disconnect() + + avg_connection_time = statistics.mean(connection_times) + max_connection_time = max(connection_times) + + print(f"\nDatabase Connection Performance:") + print(f"Connections: {num_connections}") + print(f"Average connection time: {avg_connection_time:.2f} ms") + print(f"Max connection time: {max_connection_time:.2f} ms") + + # Connection should be fast + assert avg_connection_time < 10.0 # Less than 10ms average + assert max_connection_time < 50.0 # Less than 50ms max + + @pytest.mark.asyncio + async def test_database_insert_performance(self, performance_config): + """Test database insert performance.""" + + with patch('psycopg2.connect') as mock_connect: + # Setup mock database + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + db_manager = DatabaseManager() + db_manager.connect() + + # Prepare test data + table_name = "car_frontal_info" + test_records = [ + { + "display_id": f"display_{i}", + "captured_timestamp": str(int(time.time() * 1000) + i), + "session_id": str(uuid.uuid4()), + "license_character": None, + "license_type": "No model available" + } + for i in range(1000) + ] + + # Benchmark single inserts + insert_times = [] + + for record in test_records[:100]: # Test first 100 for individual timing + start_time = time.perf_counter() + await db_manager.create_record(table_name, record) + end_time = time.perf_counter() + + insert_time_ms = (end_time - start_time) * 1000 + insert_times.append(insert_time_ms) + + # Benchmark batch insert + start_time = time.perf_counter() + for record in test_records[100:]: + await db_manager.create_record(table_name, record) + end_time = time.perf_counter() + + batch_time = end_time - start_time + batch_throughput = 900 / batch_time # 900 records in batch + + avg_insert_time = statistics.mean(insert_times) + max_insert_time = max(insert_times) + + print(f"\nDatabase Insert Performance:") + print(f"Average insert time: {avg_insert_time:.2f} ms") + print(f"Max insert time: {max_insert_time:.2f} ms") + print(f"Batch throughput: {batch_throughput:.0f} inserts/second") + + assert avg_insert_time < performance_config["max_db_query_time_ms"] + assert batch_throughput > performance_config["min_db_throughput_ops_per_sec"] + + @pytest.mark.asyncio + async def test_database_update_performance(self, performance_config): + """Test database update performance.""" + + with patch('psycopg2.connect') as mock_connect: + # Setup mock database + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + db_manager = DatabaseManager() + db_manager.connect() + + table_name = "car_frontal_info" + session_ids = [str(uuid.uuid4()) for _ in range(1000)] + + # Benchmark updates + update_times = [] + + for session_id in session_ids[:100]: # Test first 100 for individual timing + update_data = { + "car_brand": "Toyota", + "car_body_type": "Sedan", + "updated_at": "NOW()" + } + + start_time = time.perf_counter() + await db_manager.update_record(table_name, session_id, update_data, key_field="session_id") + end_time = time.perf_counter() + + update_time_ms = (end_time - start_time) * 1000 + update_times.append(update_time_ms) + + # Benchmark batch updates + start_time = time.perf_counter() + for session_id in session_ids[100:]: + update_data = { + "car_brand": "Honda", + "car_body_type": "Hatchback" + } + await db_manager.update_record(table_name, session_id, update_data, key_field="session_id") + end_time = time.perf_counter() + + batch_time = end_time - start_time + batch_throughput = 900 / batch_time + + avg_update_time = statistics.mean(update_times) + max_update_time = max(update_times) + + print(f"\nDatabase Update Performance:") + print(f"Average update time: {avg_update_time:.2f} ms") + print(f"Max update time: {max_update_time:.2f} ms") + print(f"Batch throughput: {batch_throughput:.0f} updates/second") + + assert avg_update_time < performance_config["max_db_query_time_ms"] + assert batch_throughput > performance_config["min_db_throughput_ops_per_sec"] + + @pytest.mark.asyncio + async def test_database_query_performance(self, performance_config): + """Test database query performance.""" + + with patch('psycopg2.connect') as mock_connect: + # Setup mock database + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + + # Mock query results + mock_cursor.fetchone.return_value = ("display_1", "1640995200", "session_123", None, "No model", "Toyota", "Sedan") + mock_cursor.fetchall.return_value = [ + ("display_1", "1640995200", "session_123", None, "No model", "Toyota", "Sedan"), + ("display_2", "1640995201", "session_124", None, "No model", "Honda", "Hatchback") + ] + + mock_connect.return_value = mock_conn + + db_manager = DatabaseManager() + db_manager.connect() + + table_name = "car_frontal_info" + + # Benchmark single record queries + query_times = [] + num_queries = 1000 + + for i in range(num_queries): + session_id = f"session_{i}" + + start_time = time.perf_counter() + result = await db_manager.get_record(table_name, session_id, key_field="session_id") + end_time = time.perf_counter() + + query_time_ms = (end_time - start_time) * 1000 + query_times.append(query_time_ms) + + avg_query_time = statistics.mean(query_times) + max_query_time = max(query_times) + query_throughput = num_queries / (sum(query_times) / 1000) + + print(f"\nDatabase Query Performance:") + print(f"Queries: {num_queries}") + print(f"Average query time: {avg_query_time:.2f} ms") + print(f"Max query time: {max_query_time:.2f} ms") + print(f"Query throughput: {query_throughput:.0f} queries/second") + + assert avg_query_time < performance_config["max_db_query_time_ms"] + assert query_throughput > performance_config["min_db_throughput_ops_per_sec"] + + +class TestRedisPerformance: + """Test Redis client performance benchmarks.""" + + @pytest.mark.asyncio + async def test_redis_connection_performance(self): + """Test Redis connection performance.""" + + with patch('redis.Redis') as mock_redis_class, \ + patch('redis.ConnectionPool') as mock_pool_class: + + mock_redis = Mock() + mock_redis.ping.return_value = True + mock_redis_class.return_value = mock_redis + + mock_pool = Mock() + mock_pool_class.return_value = mock_pool + + config = RedisConfig(host="localhost", port=6379) + + # Benchmark connection times + connection_times = [] + num_connections = 100 + + for _ in range(num_connections): + redis_client = RedisClient(config) + + start_time = time.perf_counter() + await redis_client.connect() + end_time = time.perf_counter() + + connection_time_ms = (end_time - start_time) * 1000 + connection_times.append(connection_time_ms) + + await redis_client.disconnect() + + avg_connection_time = statistics.mean(connection_times) + max_connection_time = max(connection_times) + + print(f"\nRedis Connection Performance:") + print(f"Connections: {num_connections}") + print(f"Average connection time: {avg_connection_time:.2f} ms") + print(f"Max connection time: {max_connection_time:.2f} ms") + + # Redis connections should be very fast + assert avg_connection_time < 5.0 # Less than 5ms average + assert max_connection_time < 20.0 # Less than 20ms max + + @pytest.mark.asyncio + async def test_redis_basic_operations_performance(self, performance_config): + """Test basic Redis operations performance.""" + + with patch('redis.Redis') as mock_redis_class: + mock_redis = Mock() + mock_redis.ping.return_value = True + mock_redis.set.return_value = True + mock_redis.get.return_value = "test_value" + mock_redis.delete.return_value = 1 + mock_redis.exists.return_value = 1 + mock_redis_class.return_value = mock_redis + + config = RedisConfig(host="localhost") + redis_client = RedisClient(config) + await redis_client.connect() + + # Benchmark SET operations + set_times = [] + num_operations = 10000 + + for i in range(num_operations): + start_time = time.perf_counter() + await redis_client.set(f"key_{i}", f"value_{i}", expire_seconds=300) + end_time = time.perf_counter() + + set_time_ms = (end_time - start_time) * 1000 + set_times.append(set_time_ms) + + # Benchmark GET operations + get_times = [] + for i in range(num_operations): + start_time = time.perf_counter() + value = await redis_client.get(f"key_{i}") + end_time = time.perf_counter() + + get_time_ms = (end_time - start_time) * 1000 + get_times.append(get_time_ms) + + # Benchmark DELETE operations + delete_times = [] + for i in range(num_operations): + start_time = time.perf_counter() + result = await redis_client.delete(f"key_{i}") + end_time = time.perf_counter() + + delete_time_ms = (end_time - start_time) * 1000 + delete_times.append(delete_time_ms) + + # Calculate statistics + avg_set_time = statistics.mean(set_times) + avg_get_time = statistics.mean(get_times) + avg_delete_time = statistics.mean(delete_times) + + set_throughput = num_operations / (sum(set_times) / 1000) + get_throughput = num_operations / (sum(get_times) / 1000) + delete_throughput = num_operations / (sum(delete_times) / 1000) + + print(f"\nRedis Basic Operations Performance:") + print(f"Operations per type: {num_operations}") + print(f"Average SET time: {avg_set_time:.3f} ms") + print(f"Average GET time: {avg_get_time:.3f} ms") + print(f"Average DELETE time: {avg_delete_time:.3f} ms") + print(f"SET throughput: {set_throughput:.0f} ops/second") + print(f"GET throughput: {get_throughput:.0f} ops/second") + print(f"DELETE throughput: {delete_throughput:.0f} ops/second") + + assert avg_set_time < performance_config["max_redis_operation_time_ms"] + assert avg_get_time < performance_config["max_redis_operation_time_ms"] + assert avg_delete_time < performance_config["max_redis_operation_time_ms"] + + assert set_throughput > performance_config["min_redis_throughput_ops_per_sec"] + assert get_throughput > performance_config["min_redis_throughput_ops_per_sec"] + + @pytest.mark.asyncio + async def test_redis_image_storage_performance(self): + """Test Redis image storage performance.""" + + with patch('redis.Redis') as mock_redis_class, \ + patch('cv2.imencode') as mock_imencode: + + mock_redis = Mock() + mock_redis.ping.return_value = True + mock_redis.set.return_value = True + mock_redis.expire.return_value = True + mock_redis_class.return_value = mock_redis + + # Mock image encoding + encoded_data = np.array([1, 2, 3, 4, 5], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + config = RedisConfig(host="localhost") + redis_client = RedisClient(config) + await redis_client.connect() + + # Create test frames + small_frame = np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8) + medium_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + large_frame = np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8) + + frames = [ + ("small", small_frame), + ("medium", medium_frame), + ("large", large_frame) + ] + + for frame_type, frame in frames: + storage_times = [] + num_images = 100 + + for i in range(num_images): + key = f"test_image_{frame_type}_{i}" + + start_time = time.perf_counter() + await redis_client.image_storage.store_image(key, frame, expire_seconds=300) + end_time = time.perf_counter() + + storage_time_ms = (end_time - start_time) * 1000 + storage_times.append(storage_time_ms) + + avg_storage_time = statistics.mean(storage_times) + max_storage_time = max(storage_times) + throughput = num_images / (sum(storage_times) / 1000) + + print(f"\n{frame_type.capitalize()} Frame Storage Performance:") + print(f"Frame size: {frame.shape}") + print(f"Images stored: {num_images}") + print(f"Average storage time: {avg_storage_time:.2f} ms") + print(f"Max storage time: {max_storage_time:.2f} ms") + print(f"Storage throughput: {throughput:.1f} images/second") + + # Performance should scale reasonably with image size + expected_max_time = {"small": 50, "medium": 100, "large": 200} + assert avg_storage_time < expected_max_time[frame_type] + + @pytest.mark.asyncio + async def test_redis_pipeline_performance(self): + """Test Redis pipeline performance.""" + + with patch('redis.Redis') as mock_redis_class: + mock_redis = Mock() + mock_redis.ping.return_value = True + mock_redis_class.return_value = mock_redis + + # Mock pipeline + mock_pipeline = Mock() + mock_pipeline.execute.return_value = [True] * 1000 + mock_redis.pipeline.return_value = mock_pipeline + + config = RedisConfig(host="localhost") + redis_client = RedisClient(config) + await redis_client.connect() + + # Benchmark pipeline operations + num_operations = 1000 + + start_time = time.perf_counter() + + async with redis_client.pipeline() as pipe: + for i in range(num_operations): + pipe.set(f"pipeline_key_{i}", f"pipeline_value_{i}") + results = await pipe.execute() + + end_time = time.perf_counter() + + total_time = end_time - start_time + throughput = num_operations / total_time + + print(f"\nRedis Pipeline Performance:") + print(f"Operations: {num_operations}") + print(f"Total time: {total_time:.3f} seconds") + print(f"Throughput: {throughput:.0f} ops/second") + + # Pipeline should be much faster than individual operations + assert throughput > 10000 # Should exceed 10k ops/second with pipeline + assert len(results) == num_operations + + +class TestSessionCachePerformance: + """Test session cache performance benchmarks.""" + + def test_cache_basic_operations_performance(self, performance_config): + """Test basic cache operations performance.""" + + cache_config = CacheConfig(max_size=10000, ttl_seconds=3600) + cache = SessionCache(cache_config) + + # Prepare test data + test_sessions = [] + for i in range(10000): + from detector_worker.storage.session_cache import SessionData + session_data = SessionData( + session_id=f"session_{i}", + camera_id=f"camera_{i % 100}", # 100 unique cameras + display_id=f"display_{i % 50}" # 50 unique displays + ) + session_data.add_detection_data("main", {"class": "car", "confidence": 0.9}) + test_sessions.append((f"session_{i}", session_data)) + + # Benchmark PUT operations + put_times = [] + for session_id, session_data in test_sessions: + start_time = time.perf_counter() + cache.put(session_id, session_data) + end_time = time.perf_counter() + + put_time_ms = (end_time - start_time) * 1000 + put_times.append(put_time_ms) + + # Benchmark GET operations + get_times = [] + for session_id, _ in test_sessions: + start_time = time.perf_counter() + retrieved_data = cache.get(session_id) + end_time = time.perf_counter() + + get_time_ms = (end_time - start_time) * 1000 + get_times.append(get_time_ms) + + # Calculate statistics + avg_put_time = statistics.mean(put_times) + avg_get_time = statistics.mean(get_times) + max_put_time = max(put_times) + max_get_time = max(get_times) + + put_throughput = len(test_sessions) / (sum(put_times) / 1000) + get_throughput = len(test_sessions) / (sum(get_times) / 1000) + + print(f"\nSession Cache Basic Operations Performance:") + print(f"Operations per type: {len(test_sessions)}") + print(f"Average PUT time: {avg_put_time:.3f} ms") + print(f"Average GET time: {avg_get_time:.3f} ms") + print(f"Max PUT time: {max_put_time:.3f} ms") + print(f"Max GET time: {max_get_time:.3f} ms") + print(f"PUT throughput: {put_throughput:.0f} ops/second") + print(f"GET throughput: {get_throughput:.0f} ops/second") + + assert avg_put_time < performance_config["max_cache_operation_time_ms"] + assert avg_get_time < performance_config["max_cache_operation_time_ms"] + assert put_throughput > performance_config["min_cache_throughput_ops_per_sec"] + assert get_throughput > performance_config["min_cache_throughput_ops_per_sec"] + + def test_cache_manager_performance(self, performance_config): + """Test session cache manager performance.""" + + cache_manager = SessionCacheManager() + cache_manager.clear_all() + + # Benchmark detection caching + detection_times = [] + num_operations = 5000 + + for i in range(num_operations): + camera_id = f"camera_{i % 50}" + detection_data = { + "class": "car", + "confidence": 0.9, + "bbox": [100, 200, 300, 400], + "track_id": i + } + + start_time = time.perf_counter() + cache_manager.cache_detection(camera_id, detection_data) + end_time = time.perf_counter() + + detection_time_ms = (end_time - start_time) * 1000 + detection_times.append(detection_time_ms) + + # Benchmark detection retrieval + retrieval_times = [] + for i in range(num_operations): + camera_id = f"camera_{i % 50}" + + start_time = time.perf_counter() + cached_detection = cache_manager.get_cached_detection(camera_id) + end_time = time.perf_counter() + + retrieval_time_ms = (end_time - start_time) * 1000 + retrieval_times.append(retrieval_time_ms) + + # Benchmark session operations + session_times = [] + for i in range(1000): # Fewer session operations as they're more complex + session_id = str(uuid.uuid4()) + camera_id = f"camera_{i % 20}" + + start_time = time.perf_counter() + cache_manager.create_session(session_id, camera_id, {"initial": "data"}) + cache_manager.update_session_detection(session_id, {"car_brand": "Toyota"}) + session_data = cache_manager.get_session_detection(session_id) + end_time = time.perf_counter() + + session_time_ms = (end_time - start_time) * 1000 + session_times.append(session_time_ms) + + # Calculate statistics + avg_detection_time = statistics.mean(detection_times) + avg_retrieval_time = statistics.mean(retrieval_times) + avg_session_time = statistics.mean(session_times) + + detection_throughput = num_operations / (sum(detection_times) / 1000) + retrieval_throughput = num_operations / (sum(retrieval_times) / 1000) + session_throughput = 1000 / (sum(session_times) / 1000) + + print(f"\nCache Manager Performance:") + print(f"Average detection cache time: {avg_detection_time:.3f} ms") + print(f"Average retrieval time: {avg_retrieval_time:.3f} ms") + print(f"Average session operation time: {avg_session_time:.3f} ms") + print(f"Detection throughput: {detection_throughput:.0f} ops/second") + print(f"Retrieval throughput: {retrieval_throughput:.0f} ops/second") + print(f"Session throughput: {session_throughput:.0f} ops/second") + + assert avg_detection_time < performance_config["max_cache_operation_time_ms"] * 2 + assert avg_retrieval_time < performance_config["max_cache_operation_time_ms"] + assert detection_throughput > performance_config["min_cache_throughput_ops_per_sec"] / 2 + + def test_cache_memory_performance(self): + """Test cache memory usage and performance.""" + + # Measure initial memory + gc.collect() + initial_memory = psutil.Process().memory_info().rss / 1024 / 1024 # MB + + cache_config = CacheConfig(max_size=10000, ttl_seconds=3600) + cache = SessionCache(cache_config) + + # Add many sessions to test memory usage + num_sessions = 5000 + memory_measurements = [] + + for i in range(num_sessions): + from detector_worker.storage.session_cache import SessionData + session_data = SessionData( + session_id=f"memory_session_{i}", + camera_id=f"camera_{i % 100}", + display_id=f"display_{i % 50}" + ) + + # Add some detection data + session_data.add_detection_data("detection", { + "class": "car", + "confidence": 0.9, + "bbox": [100, 200, 300, 400], + "features": [float(j) for j in range(50)] # Add some bulk + }) + + cache.put(f"memory_session_{i}", session_data) + + # Measure memory periodically + if i % 500 == 0 and i > 0: + current_memory = psutil.Process().memory_info().rss / 1024 / 1024 + memory_increase = current_memory - initial_memory + memory_measurements.append((i, memory_increase)) + + # Final memory measurement + gc.collect() + final_memory = psutil.Process().memory_info().rss / 1024 / 1024 + total_memory_increase = final_memory - initial_memory + + # Calculate memory per session + memory_per_session = total_memory_increase / num_sessions + + print(f"\nCache Memory Performance:") + print(f"Sessions cached: {num_sessions}") + print(f"Initial memory: {initial_memory:.1f} MB") + print(f"Final memory: {final_memory:.1f} MB") + print(f"Total memory increase: {total_memory_increase:.1f} MB") + print(f"Memory per session: {memory_per_session * 1024:.1f} KB") + + # Memory usage should be reasonable + assert memory_per_session < 0.1 # Less than 100KB per session + assert total_memory_increase < 500 # Total increase less than 500MB + + # Test access performance with full cache + access_times = [] + for i in range(1000): + session_id = f"memory_session_{i}" + + start_time = time.perf_counter() + session_data = cache.get(session_id) + end_time = time.perf_counter() + + access_time_ms = (end_time - start_time) * 1000 + access_times.append(access_time_ms) + + avg_access_time = statistics.mean(access_times) + max_access_time = max(access_times) + + print(f"Full cache access performance:") + print(f"Average access time: {avg_access_time:.3f} ms") + print(f"Max access time: {max_access_time:.3f} ms") + + # Access should remain fast even with full cache + assert avg_access_time < 1.0 # Less than 1ms average + assert max_access_time < 10.0 # Less than 10ms max + + def test_cache_eviction_performance(self): + """Test cache eviction performance.""" + + # Create cache with small size to force evictions + cache_config = CacheConfig(max_size=1000, eviction_policy="lru") + cache = SessionCache(cache_config) + + # Fill cache beyond capacity + num_sessions = 2000 # Double the capacity + eviction_times = [] + + for i in range(num_sessions): + from detector_worker.storage.session_cache import SessionData + session_data = SessionData( + session_id=f"eviction_session_{i}", + camera_id=f"camera_{i % 100}", + display_id=f"display_{i % 50}" + ) + + start_time = time.perf_counter() + cache.put(f"eviction_session_{i}", session_data) + end_time = time.perf_counter() + + operation_time_ms = (end_time - start_time) * 1000 + eviction_times.append(operation_time_ms) + + # Analyze eviction performance + avg_operation_time = statistics.mean(eviction_times) + max_operation_time = max(eviction_times) + + # Check that cache size is maintained + assert cache.size() == 1000 # Should not exceed max_size + + print(f"\nCache Eviction Performance:") + print(f"Sessions processed: {num_sessions}") + print(f"Final cache size: {cache.size()}") + print(f"Average operation time: {avg_operation_time:.3f} ms") + print(f"Max operation time: {max_operation_time:.3f} ms") + + # Eviction should not significantly slow down operations + assert avg_operation_time < 5.0 # Less than 5ms average with eviction + assert max_operation_time < 20.0 # Less than 20ms max + + +class TestStorageIntegrationPerformance: + """Test integrated storage performance scenarios.""" + + @pytest.mark.asyncio + async def test_full_storage_pipeline_performance(self): + """Test performance of complete storage pipeline.""" + + with patch('psycopg2.connect') as mock_db_connect, \ + patch('redis.Redis') as mock_redis_class: + + # Setup mocks + mock_db_conn = Mock() + mock_db_cursor = Mock() + mock_db_conn.cursor.return_value = mock_db_cursor + mock_db_connect.return_value = mock_db_conn + + mock_redis = Mock() + mock_redis.ping.return_value = True + mock_redis.set.return_value = True + mock_redis.expire.return_value = True + mock_redis_class.return_value = mock_redis + + # Initialize storage components + db_manager = DatabaseManager() + db_manager.connect() + + redis_config = RedisConfig(host="localhost") + redis_client = RedisClient(redis_config) + await redis_client.connect() + + cache_manager = SessionCacheManager() + cache_manager.clear_all() + + # Benchmark complete storage pipeline + pipeline_times = [] + num_iterations = 500 + + for i in range(num_iterations): + session_id = str(uuid.uuid4()) + camera_id = f"camera_{i % 20}" + + start_time = time.perf_counter() + + # 1. Cache detection + detection_data = { + "class": "car", + "confidence": 0.9, + "bbox": [100, 200, 300, 400], + "track_id": i + 1000 + } + cache_manager.cache_detection(camera_id, detection_data) + + # 2. Create session + cache_manager.create_session(session_id, camera_id, {"initial": "data"}) + + # 3. Database insert + await db_manager.create_record("car_frontal_info", { + "session_id": session_id, + "display_id": f"display_{i % 10}", + "captured_timestamp": str(int(time.time() * 1000)), + "license_type": "No model available" + }) + + # 4. Redis store + await redis_client.set(f"detection:{session_id}", "image_data", expire_seconds=600) + + # 5. Update session with results + cache_manager.update_session_detection(session_id, { + "car_brand": "Toyota", + "car_body_type": "Sedan" + }) + + # 6. Database update + await db_manager.update_record("car_frontal_info", session_id, { + "car_brand": "Toyota", + "car_body_type": "Sedan" + }, key_field="session_id") + + end_time = time.perf_counter() + + pipeline_time_ms = (end_time - start_time) * 1000 + pipeline_times.append(pipeline_time_ms) + + # Analyze pipeline performance + avg_pipeline_time = statistics.mean(pipeline_times) + max_pipeline_time = max(pipeline_times) + pipeline_throughput = num_iterations / (sum(pipeline_times) / 1000) + + print(f"\nFull Storage Pipeline Performance:") + print(f"Pipeline iterations: {num_iterations}") + print(f"Average pipeline time: {avg_pipeline_time:.2f} ms") + print(f"Max pipeline time: {max_pipeline_time:.2f} ms") + print(f"Pipeline throughput: {pipeline_throughput:.1f} pipelines/second") + + # Complete pipeline should be efficient + assert avg_pipeline_time < 100 # Less than 100ms per complete pipeline + assert pipeline_throughput > 50 # At least 50 pipelines/second \ No newline at end of file diff --git a/tests/performance/test_websocket_performance.py b/tests/performance/test_websocket_performance.py new file mode 100644 index 0000000..68d76dd --- /dev/null +++ b/tests/performance/test_websocket_performance.py @@ -0,0 +1,596 @@ +""" +Performance tests for WebSocket communication and message processing. + +These tests benchmark WebSocket throughput, latency, and concurrent +connection handling to ensure scalability requirements are met. +""" +import pytest +import asyncio +import time +import statistics +import json +from unittest.mock import Mock, AsyncMock +from concurrent.futures import ThreadPoolExecutor +import psutil + +from detector_worker.communication.websocket_handler import WebSocketHandler +from detector_worker.communication.message_processor import MessageProcessor +from detector_worker.communication.websocket_handler import ConnectionManager + + +@pytest.fixture +def performance_config(): + """Configuration for performance tests.""" + return { + "max_message_latency_ms": 10, + "min_throughput_msgs_per_sec": 1000, + "max_concurrent_connections": 100, + "max_memory_per_connection_kb": 100 + } + + +@pytest.fixture +def mock_websocket(): + """Create mock WebSocket for performance testing.""" + websocket = Mock() + websocket.accept = AsyncMock() + websocket.send_json = AsyncMock() + websocket.send_text = AsyncMock() + websocket.receive_json = AsyncMock() + websocket.receive_text = AsyncMock() + websocket.close = AsyncMock() + websocket.ping = AsyncMock() + return websocket + + +class TestWebSocketMessagePerformance: + """Test WebSocket message processing performance.""" + + @pytest.mark.asyncio + async def test_message_processing_throughput(self, performance_config): + """Test message processing throughput.""" + + message_processor = MessageProcessor() + + # Simple state request message + test_message = {"type": "requestState"} + client_id = "perf_client" + + # Warm up + for _ in range(10): + await message_processor.process_message(test_message, client_id) + + # Benchmark throughput + num_messages = 10000 + start_time = time.perf_counter() + + for _ in range(num_messages): + await message_processor.process_message(test_message, client_id) + + end_time = time.perf_counter() + total_time = end_time - start_time + throughput = num_messages / total_time + + print(f"\nMessage Processing Throughput:") + print(f"Messages processed: {num_messages}") + print(f"Total time: {total_time:.2f} seconds") + print(f"Throughput: {throughput:.0f} messages/second") + + assert throughput >= performance_config["min_throughput_msgs_per_sec"] + + @pytest.mark.asyncio + async def test_message_processing_latency(self, performance_config): + """Test individual message processing latency.""" + + message_processor = MessageProcessor() + + test_messages = [ + {"type": "requestState"}, + {"type": "setSessionId", "payload": {"sessionId": "test", "displayId": "display"}}, + {"type": "patchSession", "payload": {"sessionId": "test", "data": {"test": "value"}}} + ] + + client_id = "latency_client" + + # Benchmark individual message latency + all_latencies = [] + + for message_type, test_message in enumerate(test_messages): + latencies = [] + + for _ in range(1000): + start_time = time.perf_counter() + await message_processor.process_message(test_message, client_id) + end_time = time.perf_counter() + + latency_ms = (end_time - start_time) * 1000 + latencies.append(latency_ms) + + avg_latency = statistics.mean(latencies) + max_latency = max(latencies) + p95_latency = statistics.quantiles(latencies, n=20)[18] # 95th percentile + + all_latencies.extend(latencies) + + print(f"\nMessage Type: {test_message['type']}") + print(f"Average latency: {avg_latency:.3f} ms") + print(f"Max latency: {max_latency:.3f} ms") + print(f"95th percentile: {p95_latency:.3f} ms") + + assert avg_latency < performance_config["max_message_latency_ms"] + assert p95_latency < performance_config["max_message_latency_ms"] * 2 + + # Overall statistics + overall_avg = statistics.mean(all_latencies) + overall_p95 = statistics.quantiles(all_latencies, n=20)[18] + + print(f"\nOverall Message Latency:") + print(f"Average latency: {overall_avg:.3f} ms") + print(f"95th percentile: {overall_p95:.3f} ms") + + @pytest.mark.asyncio + async def test_concurrent_message_processing(self, performance_config): + """Test concurrent message processing performance.""" + + message_processor = MessageProcessor() + + async def process_messages_batch(client_id, num_messages): + """Process a batch of messages for one client.""" + test_message = {"type": "requestState"} + latencies = [] + + for _ in range(num_messages): + start_time = time.perf_counter() + await message_processor.process_message(test_message, client_id) + end_time = time.perf_counter() + + latency_ms = (end_time - start_time) * 1000 + latencies.append(latency_ms) + + return latencies + + # Run concurrent processing + num_clients = 50 + messages_per_client = 100 + + start_time = time.perf_counter() + + tasks = [ + process_messages_batch(f"client_{i}", messages_per_client) + for i in range(num_clients) + ] + + results = await asyncio.gather(*tasks) + + end_time = time.perf_counter() + total_time = end_time - start_time + + # Analyze results + all_latencies = [latency for client_latencies in results for latency in client_latencies] + total_messages = len(all_latencies) + avg_latency = statistics.mean(all_latencies) + throughput = total_messages / total_time + + print(f"\nConcurrent Message Processing:") + print(f"Clients: {num_clients}") + print(f"Total messages: {total_messages}") + print(f"Total time: {total_time:.2f} seconds") + print(f"Throughput: {throughput:.0f} messages/second") + print(f"Average latency: {avg_latency:.3f} ms") + + assert throughput >= performance_config["min_throughput_msgs_per_sec"] / 2 # Reduced due to concurrency overhead + assert avg_latency < performance_config["max_message_latency_ms"] * 2 + + @pytest.mark.asyncio + async def test_large_message_performance(self): + """Test performance with large messages.""" + + message_processor = MessageProcessor() + + # Create large message (simulating detection results) + large_payload = { + "detections": [ + { + "class": f"object_{i}", + "confidence": 0.9, + "bbox": [i*10, i*10, (i+1)*10, (i+1)*10], + "metadata": { + "feature_vector": [float(j) for j in range(100)], + "description": "x" * 500 # Large text field + } + } + for i in range(50) # 50 detections + ], + "camera_info": { + "resolution": [1920, 1080], + "settings": {"brightness": 50, "contrast": 75}, + "history": [{"timestamp": i, "event": f"event_{i}"} for i in range(100)] + } + } + + large_message = { + "type": "imageDetection", + "payload": large_payload + } + + client_id = "large_msg_client" + + # Measure message size + message_size_bytes = len(json.dumps(large_message)) + print(f"\nLarge Message Performance:") + print(f"Message size: {message_size_bytes / 1024:.1f} KB") + + # Benchmark large message processing + processing_times = [] + num_iterations = 100 + + for _ in range(num_iterations): + start_time = time.perf_counter() + await message_processor.process_message(large_message, client_id) + end_time = time.perf_counter() + + processing_time_ms = (end_time - start_time) * 1000 + processing_times.append(processing_time_ms) + + avg_processing_time = statistics.mean(processing_times) + max_processing_time = max(processing_times) + + print(f"Average processing time: {avg_processing_time:.2f} ms") + print(f"Max processing time: {max_processing_time:.2f} ms") + + # Large messages should still be processed reasonably quickly + assert avg_processing_time < 100 # Less than 100ms for large messages + assert max_processing_time < 500 # Less than 500ms max + + +class TestConnectionManagerPerformance: + """Test connection manager performance.""" + + def test_connection_creation_performance(self, performance_config, mock_websocket): + """Test connection creation and management performance.""" + + connection_manager = ConnectionManager() + + # Benchmark connection creation + creation_times = [] + num_connections = 1000 + + for i in range(num_connections): + start_time = time.perf_counter() + connection_manager._create_connection(mock_websocket, f"client_{i}") + end_time = time.perf_counter() + + creation_time_ms = (end_time - start_time) * 1000 + creation_times.append(creation_time_ms) + + avg_creation_time = statistics.mean(creation_times) + max_creation_time = max(creation_times) + + print(f"\nConnection Creation Performance:") + print(f"Connections created: {num_connections}") + print(f"Average creation time: {avg_creation_time:.3f} ms") + print(f"Max creation time: {max_creation_time:.3f} ms") + + # Connection creation should be very fast + assert avg_creation_time < 1.0 # Less than 1ms average + assert max_creation_time < 10.0 # Less than 10ms max + + @pytest.mark.asyncio + async def test_broadcast_performance(self, mock_websocket): + """Test broadcast message performance.""" + + connection_manager = ConnectionManager() + + # Create many mock connections + num_connections = 1000 + mock_websockets = [] + + for i in range(num_connections): + ws = Mock() + ws.send_json = AsyncMock() + ws.send_text = AsyncMock() + mock_websockets.append(ws) + + # Add to connection manager + connection = connection_manager._create_connection(ws, f"client_{i}") + connection.is_connected = True + connection_manager.connections[f"client_{i}"] = connection + + # Test broadcast performance + test_message = {"type": "broadcast", "data": "test message"} + + broadcast_times = [] + num_broadcasts = 100 + + for _ in range(num_broadcasts): + start_time = time.perf_counter() + await connection_manager.broadcast(test_message) + end_time = time.perf_counter() + + broadcast_time_ms = (end_time - start_time) * 1000 + broadcast_times.append(broadcast_time_ms) + + avg_broadcast_time = statistics.mean(broadcast_times) + max_broadcast_time = max(broadcast_times) + + print(f"\nBroadcast Performance:") + print(f"Connections: {num_connections}") + print(f"Broadcasts: {num_broadcasts}") + print(f"Average broadcast time: {avg_broadcast_time:.2f} ms") + print(f"Max broadcast time: {max_broadcast_time:.2f} ms") + + # Broadcast should scale reasonably + assert avg_broadcast_time < 50 # Less than 50ms for 1000 connections + + # Verify all connections received the message + for ws in mock_websockets: + assert ws.send_json.call_count == num_broadcasts + + def test_subscription_management_performance(self): + """Test subscription management performance.""" + + connection_manager = ConnectionManager() + + # Test subscription operations performance + num_operations = 10000 + + # Add subscriptions + add_times = [] + for i in range(num_operations): + client_id = f"client_{i % 100}" # 100 unique clients + subscription_id = f"camera_{i % 50}" # 50 unique cameras + + start_time = time.perf_counter() + connection_manager.add_subscription(client_id, subscription_id) + end_time = time.perf_counter() + + add_time_ms = (end_time - start_time) * 1000 + add_times.append(add_time_ms) + + # Query subscriptions + query_times = [] + for i in range(1000): + client_id = f"client_{i % 100}" + + start_time = time.perf_counter() + subscriptions = connection_manager.get_client_subscriptions(client_id) + end_time = time.perf_counter() + + query_time_ms = (end_time - start_time) * 1000 + query_times.append(query_time_ms) + + # Remove subscriptions + remove_times = [] + for i in range(num_operations): + client_id = f"client_{i % 100}" + subscription_id = f"camera_{i % 50}" + + start_time = time.perf_counter() + connection_manager.remove_subscription(client_id, subscription_id) + end_time = time.perf_counter() + + remove_time_ms = (end_time - start_time) * 1000 + remove_times.append(remove_time_ms) + + # Analyze results + avg_add_time = statistics.mean(add_times) + avg_query_time = statistics.mean(query_times) + avg_remove_time = statistics.mean(remove_times) + + print(f"\nSubscription Management Performance:") + print(f"Average add time: {avg_add_time:.4f} ms") + print(f"Average query time: {avg_query_time:.4f} ms") + print(f"Average remove time: {avg_remove_time:.4f} ms") + + # Should be very fast operations + assert avg_add_time < 0.1 + assert avg_query_time < 0.1 + assert avg_remove_time < 0.1 + + +class TestWebSocketHandlerPerformance: + """Test complete WebSocket handler performance.""" + + @pytest.mark.asyncio + async def test_concurrent_connections_performance(self, performance_config): + """Test performance with many concurrent connections.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + async def simulate_client_session(client_id, num_messages=50): + """Simulate a client WebSocket session.""" + mock_ws = Mock() + mock_ws.accept = AsyncMock() + mock_ws.send_json = AsyncMock() + mock_ws.receive_json = AsyncMock() + + # Simulate message sequence + messages = [ + {"type": "requestState"} for _ in range(num_messages) + ] + messages.append(asyncio.CancelledError()) # Disconnect + + mock_ws.receive_json.side_effect = messages + + processing_times = [] + try: + await websocket_handler.handle_websocket(mock_ws, client_id) + except asyncio.CancelledError: + pass # Expected disconnect + + return len(messages) - 1 # Exclude the disconnect + + # Test concurrent connections + num_concurrent_clients = 100 + messages_per_client = 25 + + start_time = time.perf_counter() + + tasks = [ + simulate_client_session(f"perf_client_{i}", messages_per_client) + for i in range(num_concurrent_clients) + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + end_time = time.perf_counter() + total_time = end_time - start_time + + # Analyze results + successful_clients = len([r for r in results if not isinstance(r, Exception)]) + total_messages = sum(r for r in results if isinstance(r, int)) + + print(f"\nConcurrent Connections Performance:") + print(f"Concurrent clients: {num_concurrent_clients}") + print(f"Successful clients: {successful_clients}") + print(f"Total messages: {total_messages}") + print(f"Total time: {total_time:.2f} seconds") + print(f"Messages per second: {total_messages / total_time:.0f}") + + assert successful_clients >= num_concurrent_clients * 0.95 # 95% success rate + assert total_messages / total_time >= 1000 # At least 1000 msg/sec throughput + + @pytest.mark.asyncio + async def test_memory_usage_under_load(self, performance_config): + """Test memory usage under high connection load.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Measure initial memory + initial_memory = psutil.Process().memory_info().rss / 1024 / 1024 # MB + + # Create many connections + num_connections = 500 + connections = [] + + for i in range(num_connections): + mock_ws = Mock() + mock_ws.accept = AsyncMock() + mock_ws.send_json = AsyncMock() + + connection = websocket_handler.connection_manager._create_connection( + mock_ws, f"mem_test_client_{i}" + ) + connection.is_connected = True + websocket_handler.connection_manager.connections[f"mem_test_client_{i}"] = connection + connections.append(connection) + + # Measure memory after connections + after_connections_memory = psutil.Process().memory_info().rss / 1024 / 1024 + memory_per_connection = (after_connections_memory - initial_memory) / num_connections * 1024 # KB + + # Simulate some activity + test_message = {"type": "broadcast", "data": "test"} + for _ in range(10): + await websocket_handler.connection_manager.broadcast(test_message) + + # Measure memory after activity + after_activity_memory = psutil.Process().memory_info().rss / 1024 / 1024 + + print(f"\nMemory Usage Under Load:") + print(f"Initial memory: {initial_memory:.1f} MB") + print(f"After {num_connections} connections: {after_connections_memory:.1f} MB") + print(f"After activity: {after_activity_memory:.1f} MB") + print(f"Memory per connection: {memory_per_connection:.1f} KB") + + # Memory usage should be reasonable + assert memory_per_connection < performance_config["max_memory_per_connection_kb"] + + # Clean up + websocket_handler.connection_manager.connections.clear() + + @pytest.mark.asyncio + async def test_heartbeat_performance(self): + """Test heartbeat mechanism performance.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor, {"heartbeat_interval": 0.1}) + + # Create connections with mock WebSockets + num_connections = 100 + mock_websockets = [] + + for i in range(num_connections): + mock_ws = Mock() + mock_ws.ping = AsyncMock() + mock_websockets.append(mock_ws) + + connection = websocket_handler.connection_manager._create_connection( + mock_ws, f"heartbeat_client_{i}" + ) + connection.is_connected = True + websocket_handler.connection_manager.connections[f"heartbeat_client_{i}"] = connection + + # Start heartbeat task + heartbeat_task = asyncio.create_task(websocket_handler._heartbeat_loop()) + + # Let it run for several heartbeat cycles + start_time = time.perf_counter() + await asyncio.sleep(0.5) # 5 heartbeat cycles + end_time = time.perf_counter() + + # Cancel heartbeat + heartbeat_task.cancel() + + try: + await heartbeat_task + except asyncio.CancelledError: + pass + + # Analyze heartbeat performance + elapsed_time = end_time - start_time + expected_pings = int(elapsed_time / 0.1) * num_connections + + actual_pings = sum(ws.ping.call_count for ws in mock_websockets) + ping_efficiency = actual_pings / expected_pings if expected_pings > 0 else 0 + + print(f"\nHeartbeat Performance:") + print(f"Connections: {num_connections}") + print(f"Elapsed time: {elapsed_time:.2f} seconds") + print(f"Expected pings: {expected_pings}") + print(f"Actual pings: {actual_pings}") + print(f"Ping efficiency: {ping_efficiency:.2%}") + + # Should achieve reasonable ping efficiency + assert ping_efficiency > 0.8 # At least 80% efficiency + + # Clean up + websocket_handler.connection_manager.connections.clear() + + @pytest.mark.asyncio + async def test_error_handling_performance(self): + """Test performance impact of error handling.""" + + message_processor = MessageProcessor() + websocket_handler = WebSocketHandler(message_processor) + + # Create messages that will cause errors + error_messages = [ + {"invalid": "message"}, # Missing type + {"type": "unknown_type"}, # Unknown type + {"type": "subscribe"}, # Missing payload + ] + + valid_message = {"type": "requestState"} + + # Mix error messages with valid ones + test_sequence = (error_messages + [valid_message]) * 250 # 1000 total messages + + start_time = time.perf_counter() + + for message in test_sequence: + await message_processor.process_message(message, "error_perf_client") + + end_time = time.perf_counter() + total_time = end_time - start_time + throughput = len(test_sequence) / total_time + + print(f"\nError Handling Performance:") + print(f"Total messages (with errors): {len(test_sequence)}") + print(f"Total time: {total_time:.2f} seconds") + print(f"Throughput: {throughput:.0f} messages/second") + + # Error handling shouldn't significantly impact performance + assert throughput > 500 # Should still process > 500 msg/sec with errors \ No newline at end of file diff --git a/tests/unit/communication/test_websocket_handler.py b/tests/unit/communication/test_websocket_handler.py new file mode 100644 index 0000000..c937780 --- /dev/null +++ b/tests/unit/communication/test_websocket_handler.py @@ -0,0 +1,856 @@ +""" +Unit tests for WebSocket handling functionality. +""" +import pytest +import asyncio +import json +from unittest.mock import Mock, AsyncMock, patch, MagicMock +from fastapi.websockets import WebSocket, WebSocketDisconnect +import uuid + +from detector_worker.communication.websocket_handler import ( + WebSocketHandler, + ConnectionManager, + WebSocketConnection, + MessageHandler, + WebSocketError, + ConnectionError as WSConnectionError +) +from detector_worker.communication.message_processor import MessageType +from detector_worker.core.exceptions import MessageProcessingError + + +class TestWebSocketConnection: + """Test WebSocket connection wrapper.""" + + def test_creation(self, mock_websocket): + """Test WebSocket connection creation.""" + connection = WebSocketConnection(mock_websocket, "client_001") + + assert connection.websocket == mock_websocket + assert connection.client_id == "client_001" + assert connection.is_connected is False + assert connection.connected_at is None + assert connection.last_ping is None + assert connection.subscription_id is None + + @pytest.mark.asyncio + async def test_accept_connection(self, mock_websocket): + """Test accepting WebSocket connection.""" + connection = WebSocketConnection(mock_websocket, "client_001") + + mock_websocket.accept = AsyncMock() + + await connection.accept() + + assert connection.is_connected is True + assert connection.connected_at is not None + mock_websocket.accept.assert_called_once() + + @pytest.mark.asyncio + async def test_send_message_json(self, mock_websocket): + """Test sending JSON message.""" + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + + mock_websocket.send_json = AsyncMock() + + message = {"type": "test", "data": "hello"} + await connection.send_message(message) + + mock_websocket.send_json.assert_called_once_with(message) + + @pytest.mark.asyncio + async def test_send_message_text(self, mock_websocket): + """Test sending text message.""" + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + + mock_websocket.send_text = AsyncMock() + + await connection.send_message("hello world") + + mock_websocket.send_text.assert_called_once_with("hello world") + + @pytest.mark.asyncio + async def test_send_message_not_connected(self, mock_websocket): + """Test sending message when not connected.""" + connection = WebSocketConnection(mock_websocket, "client_001") + # Don't set is_connected = True + + with pytest.raises(WebSocketError) as exc_info: + await connection.send_message({"type": "test"}) + + assert "not connected" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_receive_message_json(self, mock_websocket): + """Test receiving JSON message.""" + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + + mock_websocket.receive_json = AsyncMock(return_value={"type": "test", "data": "received"}) + + message = await connection.receive_message() + + assert message == {"type": "test", "data": "received"} + mock_websocket.receive_json.assert_called_once() + + @pytest.mark.asyncio + async def test_receive_message_text(self, mock_websocket): + """Test receiving text message.""" + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + + # Mock receive_json to fail, then receive_text to succeed + mock_websocket.receive_json = AsyncMock(side_effect=json.JSONDecodeError("Invalid JSON", "", 0)) + mock_websocket.receive_text = AsyncMock(return_value="plain text message") + + message = await connection.receive_message() + + assert message == "plain text message" + mock_websocket.receive_text.assert_called_once() + + @pytest.mark.asyncio + async def test_ping_pong(self, mock_websocket): + """Test ping/pong functionality.""" + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + + mock_websocket.ping = AsyncMock() + + await connection.ping() + + assert connection.last_ping is not None + mock_websocket.ping.assert_called_once() + + @pytest.mark.asyncio + async def test_close_connection(self, mock_websocket): + """Test closing connection.""" + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + + mock_websocket.close = AsyncMock() + + await connection.close(code=1000, reason="Normal closure") + + assert connection.is_connected is False + mock_websocket.close.assert_called_once_with(code=1000, reason="Normal closure") + + def test_connection_info(self, mock_websocket): + """Test getting connection information.""" + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + connection.subscription_id = "sub_123" + + info = connection.get_connection_info() + + assert info["client_id"] == "client_001" + assert info["is_connected"] is True + assert info["subscription_id"] == "sub_123" + assert "connected_at" in info + assert "last_ping" in info + + +class TestConnectionManager: + """Test WebSocket connection management.""" + + def test_initialization(self): + """Test connection manager initialization.""" + manager = ConnectionManager() + + assert len(manager.connections) == 0 + assert len(manager.subscriptions) == 0 + assert manager.max_connections == 100 + + @pytest.mark.asyncio + async def test_add_connection(self, mock_websocket): + """Test adding a connection.""" + manager = ConnectionManager() + + client_id = "client_001" + connection = await manager.add_connection(mock_websocket, client_id) + + assert connection.client_id == client_id + assert client_id in manager.connections + assert manager.get_connection_count() == 1 + + @pytest.mark.asyncio + async def test_remove_connection(self, mock_websocket): + """Test removing a connection.""" + manager = ConnectionManager() + + client_id = "client_001" + await manager.add_connection(mock_websocket, client_id) + + assert client_id in manager.connections + + removed_connection = await manager.remove_connection(client_id) + + assert removed_connection is not None + assert removed_connection.client_id == client_id + assert client_id not in manager.connections + assert manager.get_connection_count() == 0 + + def test_get_connection(self, mock_websocket): + """Test getting a connection.""" + manager = ConnectionManager() + + client_id = "client_001" + # Manually add connection for testing + connection = WebSocketConnection(mock_websocket, client_id) + manager.connections[client_id] = connection + + retrieved_connection = manager.get_connection(client_id) + + assert retrieved_connection == connection + assert retrieved_connection.client_id == client_id + + def test_get_nonexistent_connection(self): + """Test getting non-existent connection.""" + manager = ConnectionManager() + + connection = manager.get_connection("nonexistent_client") + + assert connection is None + + @pytest.mark.asyncio + async def test_broadcast_message(self, mock_websocket): + """Test broadcasting message to all connections.""" + manager = ConnectionManager() + + # Add multiple connections + connections = [] + for i in range(3): + client_id = f"client_{i}" + ws = Mock() + ws.send_json = AsyncMock() + connection = WebSocketConnection(ws, client_id) + connection.is_connected = True + manager.connections[client_id] = connection + connections.append(connection) + + message = {"type": "broadcast", "data": "hello all"} + + await manager.broadcast(message) + + # All connections should have received the message + for connection in connections: + connection.websocket.send_json.assert_called_once_with(message) + + @pytest.mark.asyncio + async def test_broadcast_to_subscription(self, mock_websocket): + """Test broadcasting to specific subscription.""" + manager = ConnectionManager() + + # Add connections with different subscriptions + subscription_id = "camera_001" + + # Connection with target subscription + ws1 = Mock() + ws1.send_json = AsyncMock() + connection1 = WebSocketConnection(ws1, "client_001") + connection1.is_connected = True + connection1.subscription_id = subscription_id + manager.connections["client_001"] = connection1 + manager.subscriptions[subscription_id] = {"client_001"} + + # Connection with different subscription + ws2 = Mock() + ws2.send_json = AsyncMock() + connection2 = WebSocketConnection(ws2, "client_002") + connection2.is_connected = True + connection2.subscription_id = "camera_002" + manager.connections["client_002"] = connection2 + manager.subscriptions["camera_002"] = {"client_002"} + + message = {"type": "detection", "data": "camera detection"} + + await manager.broadcast_to_subscription(subscription_id, message) + + # Only connection1 should have received the message + ws1.send_json.assert_called_once_with(message) + ws2.send_json.assert_not_called() + + def test_add_subscription(self): + """Test adding subscription mapping.""" + manager = ConnectionManager() + + client_id = "client_001" + subscription_id = "camera_001" + + manager.add_subscription(client_id, subscription_id) + + assert subscription_id in manager.subscriptions + assert client_id in manager.subscriptions[subscription_id] + + def test_remove_subscription(self): + """Test removing subscription mapping.""" + manager = ConnectionManager() + + client_id = "client_001" + subscription_id = "camera_001" + + # Add subscription first + manager.add_subscription(client_id, subscription_id) + assert client_id in manager.subscriptions[subscription_id] + + # Remove subscription + manager.remove_subscription(client_id, subscription_id) + + assert client_id not in manager.subscriptions.get(subscription_id, set()) + + def test_get_subscription_clients(self): + """Test getting clients for a subscription.""" + manager = ConnectionManager() + + subscription_id = "camera_001" + clients = ["client_001", "client_002", "client_003"] + + for client_id in clients: + manager.add_subscription(client_id, subscription_id) + + subscription_clients = manager.get_subscription_clients(subscription_id) + + assert subscription_clients == set(clients) + + def test_get_client_subscriptions(self): + """Test getting subscriptions for a client.""" + manager = ConnectionManager() + + client_id = "client_001" + subscriptions = ["camera_001", "camera_002", "camera_003"] + + for subscription_id in subscriptions: + manager.add_subscription(client_id, subscription_id) + + client_subscriptions = manager.get_client_subscriptions(client_id) + + assert client_subscriptions == set(subscriptions) + + @pytest.mark.asyncio + async def test_cleanup_disconnected_connections(self): + """Test cleanup of disconnected connections.""" + manager = ConnectionManager() + + # Add connected and disconnected connections + ws1 = Mock() + connection1 = WebSocketConnection(ws1, "client_001") + connection1.is_connected = True + manager.connections["client_001"] = connection1 + + ws2 = Mock() + connection2 = WebSocketConnection(ws2, "client_002") + connection2.is_connected = False # Disconnected + manager.connections["client_002"] = connection2 + + # Add subscriptions + manager.add_subscription("client_001", "camera_001") + manager.add_subscription("client_002", "camera_002") + + cleaned_count = await manager.cleanup_disconnected() + + assert cleaned_count == 1 + assert "client_001" in manager.connections # Still connected + assert "client_002" not in manager.connections # Cleaned up + + # Subscriptions should also be cleaned up + assert manager.get_client_subscriptions("client_002") == set() + + def test_get_connection_stats(self): + """Test getting connection statistics.""" + manager = ConnectionManager() + + # Add various connections and subscriptions + for i in range(3): + client_id = f"client_{i}" + ws = Mock() + connection = WebSocketConnection(ws, client_id) + connection.is_connected = i < 2 # First 2 connected, last one disconnected + manager.connections[client_id] = connection + + if i < 2: # Add subscriptions for connected clients + manager.add_subscription(client_id, f"camera_{i}") + + stats = manager.get_connection_stats() + + assert stats["total_connections"] == 3 + assert stats["active_connections"] == 2 + assert stats["total_subscriptions"] == 2 + assert "uptime" in stats + + +class TestMessageHandler: + """Test message handling functionality.""" + + def test_creation(self): + """Test message handler creation.""" + mock_processor = Mock() + handler = MessageHandler(mock_processor) + + assert handler.message_processor == mock_processor + assert handler.connection_manager is None + + def test_set_connection_manager(self): + """Test setting connection manager.""" + mock_processor = Mock() + mock_manager = Mock() + handler = MessageHandler(mock_processor) + + handler.set_connection_manager(mock_manager) + + assert handler.connection_manager == mock_manager + + @pytest.mark.asyncio + async def test_handle_message_success(self, mock_websocket): + """Test successful message handling.""" + mock_processor = Mock() + mock_processor.process_message = AsyncMock(return_value={"type": "response", "status": "success"}) + + handler = MessageHandler(mock_processor) + connection = WebSocketConnection(mock_websocket, "client_001") + + message = {"type": "subscribe", "payload": {"camera_id": "camera_001"}} + + response = await handler.handle_message(connection, message) + + assert response["status"] == "success" + mock_processor.process_message.assert_called_once_with(message, "client_001") + + @pytest.mark.asyncio + async def test_handle_message_processing_error(self, mock_websocket): + """Test message handling with processing error.""" + mock_processor = Mock() + mock_processor.process_message = AsyncMock(side_effect=MessageProcessingError("Invalid message")) + + handler = MessageHandler(mock_processor) + connection = WebSocketConnection(mock_websocket, "client_001") + + message = {"type": "invalid", "payload": {}} + + response = await handler.handle_message(connection, message) + + assert response["type"] == "error" + assert "Invalid message" in response["message"] + + @pytest.mark.asyncio + async def test_handle_message_unexpected_error(self, mock_websocket): + """Test message handling with unexpected error.""" + mock_processor = Mock() + mock_processor.process_message = AsyncMock(side_effect=Exception("Unexpected error")) + + handler = MessageHandler(mock_processor) + connection = WebSocketConnection(mock_websocket, "client_001") + + message = {"type": "test", "payload": {}} + + response = await handler.handle_message(connection, message) + + assert response["type"] == "error" + assert "internal error" in response["message"].lower() + + @pytest.mark.asyncio + async def test_send_response(self, mock_websocket): + """Test sending response to client.""" + mock_processor = Mock() + handler = MessageHandler(mock_processor) + + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + mock_websocket.send_json = AsyncMock() + + response = {"type": "response", "data": "test response"} + + await handler.send_response(connection, response) + + mock_websocket.send_json.assert_called_once_with(response) + + @pytest.mark.asyncio + async def test_send_error_response(self, mock_websocket): + """Test sending error response.""" + mock_processor = Mock() + handler = MessageHandler(mock_processor) + + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + mock_websocket.send_json = AsyncMock() + + await handler.send_error_response(connection, "Test error message", "TEST_ERROR") + + mock_websocket.send_json.assert_called_once() + call_args = mock_websocket.send_json.call_args[0][0] + + assert call_args["type"] == "error" + assert call_args["message"] == "Test error message" + assert call_args["error_code"] == "TEST_ERROR" + + +class TestWebSocketHandler: + """Test main WebSocket handler functionality.""" + + def test_initialization(self): + """Test WebSocket handler initialization.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + assert isinstance(handler.connection_manager, ConnectionManager) + assert isinstance(handler.message_handler, MessageHandler) + assert handler.message_handler.connection_manager == handler.connection_manager + assert handler.heartbeat_interval == 30.0 + assert handler.max_connections == 100 + + def test_initialization_with_config(self): + """Test initialization with custom configuration.""" + mock_processor = Mock() + config = { + "heartbeat_interval": 60.0, + "max_connections": 200, + "connection_timeout": 300.0 + } + + handler = WebSocketHandler(mock_processor, config) + + assert handler.heartbeat_interval == 60.0 + assert handler.max_connections == 200 + assert handler.connection_timeout == 300.0 + + @pytest.mark.asyncio + async def test_handle_websocket_connection(self, mock_websocket): + """Test handling WebSocket connection.""" + mock_processor = Mock() + mock_processor.process_message = AsyncMock(return_value={"type": "ack", "status": "success"}) + + handler = WebSocketHandler(mock_processor) + + # Mock WebSocket behavior + mock_websocket.accept = AsyncMock() + mock_websocket.receive_json = AsyncMock(side_effect=[ + {"type": "subscribe", "payload": {"camera_id": "camera_001"}}, + WebSocketDisconnect() # Simulate disconnection + ]) + mock_websocket.send_json = AsyncMock() + + client_id = "test_client_001" + + # Handle connection (should not raise exception) + await handler.handle_websocket(mock_websocket, client_id) + + # Verify connection was accepted + mock_websocket.accept.assert_called_once() + + # Verify message was processed + mock_processor.process_message.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_websocket_max_connections(self, mock_websocket): + """Test handling max connections limit.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor, {"max_connections": 1}) + + # Add one connection to reach limit + client1_ws = Mock() + connection1 = WebSocketConnection(client1_ws, "client_001") + handler.connection_manager.connections["client_001"] = connection1 + + mock_websocket.close = AsyncMock() + + # Try to add second connection + await handler.handle_websocket(mock_websocket, "client_002") + + # Should close connection due to limit + mock_websocket.close.assert_called_once() + + @pytest.mark.asyncio + async def test_broadcast_message(self): + """Test broadcasting message to all connections.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + # Mock connection manager + handler.connection_manager.broadcast = AsyncMock() + + message = {"type": "system", "data": "Server maintenance in 10 minutes"} + + await handler.broadcast_message(message) + + handler.connection_manager.broadcast.assert_called_once_with(message) + + @pytest.mark.asyncio + async def test_send_to_client(self): + """Test sending message to specific client.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + # Create mock connection + mock_websocket = Mock() + mock_websocket.send_json = AsyncMock() + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + + handler.connection_manager.connections["client_001"] = connection + + message = {"type": "notification", "data": "Personal message"} + + result = await handler.send_to_client("client_001", message) + + assert result is True + mock_websocket.send_json.assert_called_once_with(message) + + @pytest.mark.asyncio + async def test_send_to_nonexistent_client(self): + """Test sending message to non-existent client.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + message = {"type": "notification", "data": "Message"} + + result = await handler.send_to_client("nonexistent_client", message) + + assert result is False + + @pytest.mark.asyncio + async def test_send_to_subscription(self): + """Test sending message to subscription.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + # Mock connection manager + handler.connection_manager.broadcast_to_subscription = AsyncMock() + + subscription_id = "camera_001" + message = {"type": "detection", "data": {"class": "car", "confidence": 0.95}} + + await handler.send_to_subscription(subscription_id, message) + + handler.connection_manager.broadcast_to_subscription.assert_called_once_with(subscription_id, message) + + @pytest.mark.asyncio + async def test_start_heartbeat_task(self): + """Test starting heartbeat task.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor, {"heartbeat_interval": 0.1}) + + # Mock connection with ping capability + mock_websocket = Mock() + mock_websocket.ping = AsyncMock() + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + handler.connection_manager.connections["client_001"] = connection + + # Start heartbeat task + heartbeat_task = asyncio.create_task(handler._heartbeat_loop()) + + # Let it run briefly + await asyncio.sleep(0.2) + + # Cancel task + heartbeat_task.cancel() + + try: + await heartbeat_task + except asyncio.CancelledError: + pass + + # Should have sent at least one ping + assert mock_websocket.ping.called + + def test_get_connection_stats(self): + """Test getting connection statistics.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + # Add some mock connections + for i in range(3): + client_id = f"client_{i}" + ws = Mock() + connection = WebSocketConnection(ws, client_id) + connection.is_connected = True + handler.connection_manager.connections[client_id] = connection + + stats = handler.get_connection_stats() + + assert stats["total_connections"] == 3 + assert stats["active_connections"] == 3 + + def test_get_client_info(self): + """Test getting client information.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + # Add mock connection + mock_websocket = Mock() + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + connection.subscription_id = "camera_001" + handler.connection_manager.connections["client_001"] = connection + + info = handler.get_client_info("client_001") + + assert info is not None + assert info["client_id"] == "client_001" + assert info["is_connected"] is True + assert info["subscription_id"] == "camera_001" + + @pytest.mark.asyncio + async def test_disconnect_client(self): + """Test disconnecting specific client.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + # Add mock connection + mock_websocket = Mock() + mock_websocket.close = AsyncMock() + connection = WebSocketConnection(mock_websocket, "client_001") + connection.is_connected = True + handler.connection_manager.connections["client_001"] = connection + + result = await handler.disconnect_client("client_001", code=1000, reason="Admin disconnect") + + assert result is True + mock_websocket.close.assert_called_once_with(code=1000, reason="Admin disconnect") + + @pytest.mark.asyncio + async def test_cleanup_connections(self): + """Test cleanup of disconnected connections.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + # Mock connection manager cleanup + handler.connection_manager.cleanup_disconnected = AsyncMock(return_value=2) + + cleaned_count = await handler.cleanup_connections() + + assert cleaned_count == 2 + handler.connection_manager.cleanup_disconnected.assert_called_once() + + +class TestWebSocketHandlerIntegration: + """Integration tests for WebSocket handler.""" + + @pytest.mark.asyncio + async def test_complete_subscription_workflow(self, mock_websocket): + """Test complete subscription workflow.""" + mock_processor = Mock() + + # Mock processor responses + mock_processor.process_message = AsyncMock(side_effect=[ + {"type": "subscribeAck", "status": "success", "subscription_id": "camera_001"}, + {"type": "unsubscribeAck", "status": "success"} + ]) + + handler = WebSocketHandler(mock_processor) + + # Mock WebSocket behavior + mock_websocket.accept = AsyncMock() + mock_websocket.send_json = AsyncMock() + mock_websocket.receive_json = AsyncMock(side_effect=[ + {"type": "subscribe", "payload": {"camera_id": "camera_001", "rtsp_url": "rtsp://example.com"}}, + {"type": "unsubscribe", "payload": {"subscription_id": "camera_001"}}, + WebSocketDisconnect() + ]) + + client_id = "test_client" + + # Handle complete workflow + await handler.handle_websocket(mock_websocket, client_id) + + # Verify both messages were processed + assert mock_processor.process_message.call_count == 2 + + # Verify responses were sent + assert mock_websocket.send_json.call_count == 2 + + @pytest.mark.asyncio + async def test_multiple_client_management(self): + """Test managing multiple concurrent clients.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor) + + clients = [] + for i in range(5): + client_id = f"client_{i}" + mock_ws = Mock() + mock_ws.send_json = AsyncMock() + + connection = WebSocketConnection(mock_ws, client_id) + connection.is_connected = True + handler.connection_manager.connections[client_id] = connection + clients.append(connection) + + # Test broadcasting to all clients + message = {"type": "broadcast", "data": "Hello all clients"} + await handler.broadcast_message(message) + + # All clients should receive the message + for connection in clients: + connection.websocket.send_json.assert_called_once_with(message) + + # Test subscription-specific messaging + subscription_id = "camera_001" + handler.connection_manager.add_subscription("client_0", subscription_id) + handler.connection_manager.add_subscription("client_2", subscription_id) + + subscription_message = {"type": "detection", "camera_id": "camera_001"} + await handler.send_to_subscription(subscription_id, subscription_message) + + # Only subscribed clients should receive the message + # Note: This would require additional mocking of broadcast_to_subscription + + @pytest.mark.asyncio + async def test_error_handling_and_recovery(self, mock_websocket): + """Test error handling and recovery scenarios.""" + mock_processor = Mock() + + # First message causes error, second succeeds + mock_processor.process_message = AsyncMock(side_effect=[ + MessageProcessingError("Invalid message format"), + {"type": "ack", "status": "success"} + ]) + + handler = WebSocketHandler(mock_processor) + + mock_websocket.accept = AsyncMock() + mock_websocket.send_json = AsyncMock() + mock_websocket.receive_json = AsyncMock(side_effect=[ + {"type": "invalid", "malformed": True}, + {"type": "valid", "payload": {"test": True}}, + WebSocketDisconnect() + ]) + + client_id = "error_test_client" + + # Should handle errors gracefully and continue processing + await handler.handle_websocket(mock_websocket, client_id) + + # Both messages should have been processed + assert mock_processor.process_message.call_count == 2 + + # Should have sent error response and success response + assert mock_websocket.send_json.call_count == 2 + + # First call should be error response + first_response = mock_websocket.send_json.call_args_list[0][0][0] + assert first_response["type"] == "error" + + @pytest.mark.asyncio + async def test_connection_timeout_handling(self): + """Test connection timeout handling.""" + mock_processor = Mock() + handler = WebSocketHandler(mock_processor, {"connection_timeout": 0.1}) + + # Add connection that hasn't been active + mock_websocket = Mock() + connection = WebSocketConnection(mock_websocket, "timeout_client") + connection.is_connected = True + # Don't update last_ping to simulate timeout + + handler.connection_manager.connections["timeout_client"] = connection + + # Wait longer than timeout + await asyncio.sleep(0.2) + + # Manual cleanup (in real implementation this would be automatic) + cleaned = await handler.cleanup_connections() + + # Connection should be identified for cleanup + # (Actual timeout logic would need to be implemented in the cleanup method) \ No newline at end of file diff --git a/tests/unit/core/test_config.py b/tests/unit/core/test_config.py new file mode 100644 index 0000000..dde6c87 --- /dev/null +++ b/tests/unit/core/test_config.py @@ -0,0 +1,429 @@ +""" +Unit tests for configuration management system. +""" +import pytest +import json +import os +import tempfile +from unittest.mock import Mock, patch, MagicMock + +from detector_worker.core.config import ( + ConfigurationManager, + JsonFileProvider, + EnvironmentProvider, + DatabaseConfig, + RedisConfig, + StreamConfig, + ModelConfig, + LoggingConfig, + get_config_manager, + validate_config +) +from detector_worker.core.exceptions import ConfigurationError + + +class TestJsonFileProvider: + """Test JSON file configuration provider.""" + + def test_get_config_from_valid_file(self, temp_dir): + """Test loading configuration from a valid JSON file.""" + config_data = {"test_key": "test_value", "number": 42} + config_file = os.path.join(temp_dir, "config.json") + + with open(config_file, 'w') as f: + json.dump(config_data, f) + + provider = JsonFileProvider(config_file) + result = provider.get_config() + + assert result == config_data + + def test_get_config_file_not_exists(self, temp_dir): + """Test handling of non-existent config file.""" + config_file = os.path.join(temp_dir, "nonexistent.json") + provider = JsonFileProvider(config_file) + + result = provider.get_config() + assert result == {} + + def test_get_config_invalid_json(self, temp_dir): + """Test handling of invalid JSON file.""" + config_file = os.path.join(temp_dir, "invalid.json") + + with open(config_file, 'w') as f: + f.write("invalid json content") + + provider = JsonFileProvider(config_file) + result = provider.get_config() + + assert result == {} + + def test_reload_updates_config(self, temp_dir): + """Test that reload updates configuration.""" + config_file = os.path.join(temp_dir, "config.json") + + # Initial config + initial_config = {"version": 1} + with open(config_file, 'w') as f: + json.dump(initial_config, f) + + provider = JsonFileProvider(config_file) + assert provider.get_config() == initial_config + + # Update config file + updated_config = {"version": 2} + with open(config_file, 'w') as f: + json.dump(updated_config, f) + + # Force reload + provider.reload() + assert provider.get_config() == updated_config + + +class TestEnvironmentProvider: + """Test environment variable configuration provider.""" + + def test_get_config_with_env_vars(self): + """Test loading configuration from environment variables.""" + env_vars = { + "DETECTOR_MAX_STREAMS": "10", + "DETECTOR_TARGET_FPS": "15", + "DETECTOR_CONFIG": '{"nested": "value"}', + "OTHER_VAR": "ignored" + } + + with patch.dict(os.environ, env_vars, clear=False): + provider = EnvironmentProvider("DETECTOR_") + config = provider.get_config() + + assert config["max_streams"] == "10" + assert config["target_fps"] == "15" + assert config["config"] == {"nested": "value"} + assert "other_var" not in config + + def test_get_config_no_env_vars(self): + """Test with no matching environment variables.""" + with patch.dict(os.environ, {}, clear=True): + provider = EnvironmentProvider("DETECTOR_") + config = provider.get_config() + + assert config == {} + + def test_custom_prefix(self): + """Test with custom prefix.""" + env_vars = {"CUSTOM_TEST": "value"} + + with patch.dict(os.environ, env_vars, clear=False): + provider = EnvironmentProvider("CUSTOM_") + config = provider.get_config() + + assert config["test"] == "value" + + +class TestConfigDataclasses: + """Test configuration dataclasses.""" + + def test_database_config_from_dict(self): + """Test DatabaseConfig creation from dictionary.""" + data = { + "enabled": True, + "host": "db.example.com", + "port": 5432, + "database": "testdb", + "username": "user", + "password": "pass", + "schema": "test_schema", + "unknown_field": "ignored" + } + + config = DatabaseConfig.from_dict(data) + + assert config.enabled is True + assert config.host == "db.example.com" + assert config.port == 5432 + assert config.database == "testdb" + assert config.username == "user" + assert config.password == "pass" + assert config.schema == "test_schema" + # Unknown fields should be ignored + assert not hasattr(config, 'unknown_field') + + def test_redis_config_from_dict(self): + """Test RedisConfig creation from dictionary.""" + data = { + "enabled": True, + "host": "redis.example.com", + "port": 6379, + "password": "secret", + "db": 1 + } + + config = RedisConfig.from_dict(data) + + assert config.enabled is True + assert config.host == "redis.example.com" + assert config.port == 6379 + assert config.password == "secret" + assert config.db == 1 + + def test_stream_config_from_dict(self): + """Test StreamConfig creation from dictionary.""" + data = { + "poll_interval_ms": 50, + "max_streams": 10, + "target_fps": 20, + "reconnect_interval_sec": 10, + "max_retries": 5 + } + + config = StreamConfig.from_dict(data) + + assert config.poll_interval_ms == 50 + assert config.max_streams == 10 + assert config.target_fps == 20 + assert config.reconnect_interval_sec == 10 + assert config.max_retries == 5 + + +class TestConfigurationManager: + """Test main configuration manager.""" + + def test_initialization_with_defaults(self): + """Test that manager initializes with default values.""" + manager = ConfigurationManager() + + # Should have default providers + assert len(manager._providers) >= 1 + + # Should have default configuration values + config = manager.get_all() + assert "poll_interval_ms" in config + assert "max_streams" in config + assert "target_fps" in config + + def test_add_provider(self): + """Test adding configuration providers.""" + manager = ConfigurationManager() + initial_count = len(manager._providers) + + mock_provider = Mock() + mock_provider.get_config.return_value = {"test": "value"} + + manager.add_provider(mock_provider) + + assert len(manager._providers) == initial_count + 1 + + def test_get_configuration_value(self): + """Test getting specific configuration values.""" + manager = ConfigurationManager() + + # Test existing key + value = manager.get("poll_interval_ms") + assert value is not None + + # Test non-existing key with default + value = manager.get("nonexistent", "default") + assert value == "default" + + # Test non-existing key without default + value = manager.get("nonexistent") + assert value is None + + def test_get_section(self): + """Test getting configuration sections.""" + manager = ConfigurationManager() + + # Test existing section + db_section = manager.get_section("database") + assert isinstance(db_section, dict) + + # Test non-existing section + empty_section = manager.get_section("nonexistent") + assert empty_section == {} + + def test_typed_config_access(self): + """Test typed configuration object access.""" + manager = ConfigurationManager() + + # Test database config + db_config = manager.get_database_config() + assert isinstance(db_config, DatabaseConfig) + + # Test Redis config + redis_config = manager.get_redis_config() + assert isinstance(redis_config, RedisConfig) + + # Test stream config + stream_config = manager.get_stream_config() + assert isinstance(stream_config, StreamConfig) + + # Test model config + model_config = manager.get_model_config() + assert isinstance(model_config, ModelConfig) + + # Test logging config + logging_config = manager.get_logging_config() + assert isinstance(logging_config, LoggingConfig) + + def test_set_configuration_value(self): + """Test setting configuration values at runtime.""" + manager = ConfigurationManager() + + manager.set("test_key", "test_value") + + assert manager.get("test_key") == "test_value" + + # Should also update typed configs + manager.set("poll_interval_ms", 200) + stream_config = manager.get_stream_config() + assert stream_config.poll_interval_ms == 200 + + def test_validation_success(self): + """Test configuration validation with valid config.""" + manager = ConfigurationManager() + + # Set valid configuration + manager.set("poll_interval_ms", 100) + manager.set("max_streams", 5) + manager.set("target_fps", 10) + + errors = manager.validate() + assert errors == [] + assert manager.is_valid() is True + + def test_validation_errors(self): + """Test configuration validation with invalid values.""" + manager = ConfigurationManager() + + # Set invalid configuration + manager.set("poll_interval_ms", 0) + manager.set("max_streams", -1) + manager.set("target_fps", 0) + + errors = manager.validate() + assert len(errors) > 0 + assert manager.is_valid() is False + + # Check specific errors + error_messages = " ".join(errors) + assert "poll_interval_ms must be positive" in error_messages + assert "max_streams must be positive" in error_messages + assert "target_fps must be positive" in error_messages + + def test_database_validation(self): + """Test database-specific validation.""" + manager = ConfigurationManager() + + # Enable database but don't provide required fields + db_config = { + "enabled": True, + "host": "", + "database": "" + } + manager.set("database", db_config) + + errors = manager.validate() + error_messages = " ".join(errors) + + assert "database host is required" in error_messages + assert "database name is required" in error_messages + + def test_redis_validation(self): + """Test Redis-specific validation.""" + manager = ConfigurationManager() + + # Enable Redis but don't provide required fields + redis_config = { + "enabled": True, + "host": "" + } + manager.set("redis", redis_config) + + errors = manager.validate() + error_messages = " ".join(errors) + + assert "redis host is required" in error_messages + + +class TestGlobalConfigurationFunctions: + """Test global configuration functions.""" + + def test_get_config_manager_singleton(self): + """Test that get_config_manager returns a singleton.""" + manager1 = get_config_manager() + manager2 = get_config_manager() + + assert manager1 is manager2 + + @patch('detector_worker.core.config.get_config_manager') + def test_validate_config_function(self, mock_get_manager): + """Test global validate_config function.""" + mock_manager = Mock() + mock_manager.validate.return_value = ["error1", "error2"] + mock_get_manager.return_value = mock_manager + + errors = validate_config() + + assert errors == ["error1", "error2"] + mock_manager.validate.assert_called_once() + + +class TestConfigurationIntegration: + """Integration tests for configuration system.""" + + def test_provider_priority(self, temp_dir): + """Test that later providers override earlier ones.""" + # Create JSON file with initial config + config_file = os.path.join(temp_dir, "config.json") + json_config = {"test_value": "from_json", "json_only": "json"} + + with open(config_file, 'w') as f: + json.dump(json_config, f) + + # Set environment variable that should override + env_vars = {"DETECTOR_TEST_VALUE": "from_env", "DETECTOR_ENV_ONLY": "env"} + + with patch.dict(os.environ, env_vars, clear=False): + manager = ConfigurationManager() + manager._providers.clear() # Start fresh + + # Add providers in order + manager.add_provider(JsonFileProvider(config_file)) + manager.add_provider(EnvironmentProvider("DETECTOR_")) + + config = manager.get_all() + + # Environment should override JSON + assert config["test_value"] == "from_env" + + # Both sources should be present + assert config["json_only"] == "json" + assert config["env_only"] == "env" + + def test_hot_reload(self, temp_dir): + """Test configuration hot reload functionality.""" + config_file = os.path.join(temp_dir, "config.json") + + # Initial config + initial_config = {"version": 1, "feature_enabled": False} + with open(config_file, 'w') as f: + json.dump(initial_config, f) + + manager = ConfigurationManager() + manager._providers.clear() + manager.add_provider(JsonFileProvider(config_file)) + + assert manager.get("version") == 1 + assert manager.get("feature_enabled") is False + + # Update config file + updated_config = {"version": 2, "feature_enabled": True} + with open(config_file, 'w') as f: + json.dump(updated_config, f) + + # Reload configuration + success = manager.reload() + assert success is True + + assert manager.get("version") == 2 + assert manager.get("feature_enabled") is True \ No newline at end of file diff --git a/tests/unit/core/test_dependency_injection.py b/tests/unit/core/test_dependency_injection.py new file mode 100644 index 0000000..2147950 --- /dev/null +++ b/tests/unit/core/test_dependency_injection.py @@ -0,0 +1,566 @@ +""" +Unit tests for dependency injection system. +""" +import pytest +import threading +from unittest.mock import Mock, MagicMock + +from detector_worker.core.dependency_injection import ( + ServiceContainer, + ServiceLifetime, + ServiceDescriptor, + ServiceScope, + DetectorWorkerContainer, + get_container, + resolve_service, + create_service_scope +) +from detector_worker.core.exceptions import DependencyInjectionError + + +class TestServiceContainer: + """Test core service container functionality.""" + + def test_register_singleton(self): + """Test singleton service registration.""" + container = ServiceContainer() + + class TestService: + def __init__(self): + self.value = 42 + + # Register singleton + container.register_singleton(TestService) + + # Resolve twice - should get same instance + instance1 = container.resolve(TestService) + instance2 = container.resolve(TestService) + + assert instance1 is instance2 + assert instance1.value == 42 + + def test_register_singleton_with_instance(self): + """Test singleton registration with pre-created instance.""" + container = ServiceContainer() + + class TestService: + def __init__(self, value): + self.value = value + + # Create instance and register + instance = TestService(99) + container.register_singleton(TestService, instance=instance) + + # Resolve should return the pre-created instance + resolved = container.resolve(TestService) + assert resolved is instance + assert resolved.value == 99 + + def test_register_transient(self): + """Test transient service registration.""" + container = ServiceContainer() + + class TestService: + def __init__(self): + self.value = 42 + + # Register transient + container.register_transient(TestService) + + # Resolve twice - should get different instances + instance1 = container.resolve(TestService) + instance2 = container.resolve(TestService) + + assert instance1 is not instance2 + assert instance1.value == instance2.value == 42 + + def test_register_scoped(self): + """Test scoped service registration.""" + container = ServiceContainer() + + class TestService: + def __init__(self): + self.value = 42 + + # Register scoped + container.register_scoped(TestService) + + # Resolve in same scope - should get same instance + instance1 = container.resolve(TestService, scope_id="scope1") + instance2 = container.resolve(TestService, scope_id="scope1") + + assert instance1 is instance2 + + # Resolve in different scope - should get different instance + instance3 = container.resolve(TestService, scope_id="scope2") + assert instance3 is not instance1 + + def test_register_with_factory(self): + """Test service registration with factory function.""" + container = ServiceContainer() + + class TestService: + def __init__(self, value): + self.value = value + + # Register with factory + def factory(): + return TestService(100) + + container.register_singleton(TestService, factory=factory) + + instance = container.resolve(TestService) + assert instance.value == 100 + + def test_register_with_implementation_type(self): + """Test service registration with implementation type.""" + container = ServiceContainer() + + class ITestService: + pass + + class TestService(ITestService): + def __init__(self): + self.value = 42 + + # Register interface with implementation + container.register_singleton(ITestService, implementation_type=TestService) + + instance = container.resolve(ITestService) + assert isinstance(instance, TestService) + assert instance.value == 42 + + def test_dependency_injection(self): + """Test automatic dependency injection.""" + container = ServiceContainer() + + class DatabaseService: + def __init__(self): + self.connected = True + + class UserService: + def __init__(self, database: DatabaseService): + self.database = database + + # Register services + container.register_singleton(DatabaseService) + container.register_transient(UserService) + + # Resolve should inject dependencies + user_service = container.resolve(UserService) + assert isinstance(user_service.database, DatabaseService) + assert user_service.database.connected is True + + def test_circular_dependency_detection(self): + """Test circular dependency detection.""" + container = ServiceContainer() + + class ServiceA: + def __init__(self, service_b: 'ServiceB'): + self.service_b = service_b + + class ServiceB: + def __init__(self, service_a: ServiceA): + self.service_a = service_a + + # Register circular dependencies + container.register_singleton(ServiceA) + container.register_singleton(ServiceB) + + # Should raise circular dependency error + with pytest.raises(DependencyInjectionError) as exc_info: + container.resolve(ServiceA) + + assert "Circular dependency detected" in str(exc_info.value) + + def test_unregistered_service_error(self): + """Test error when resolving unregistered service.""" + container = ServiceContainer() + + class UnregisteredService: + pass + + with pytest.raises(DependencyInjectionError) as exc_info: + container.resolve(UnregisteredService) + + assert "is not registered" in str(exc_info.value) + + def test_scoped_service_without_scope_id(self): + """Test error when resolving scoped service without scope ID.""" + container = ServiceContainer() + + class TestService: + pass + + container.register_scoped(TestService) + + with pytest.raises(DependencyInjectionError) as exc_info: + container.resolve(TestService) + + assert "Scope ID required" in str(exc_info.value) + + def test_factory_error_handling(self): + """Test factory error handling.""" + container = ServiceContainer() + + class TestService: + pass + + def failing_factory(): + raise ValueError("Factory failed") + + container.register_singleton(TestService, factory=failing_factory) + + with pytest.raises(DependencyInjectionError) as exc_info: + container.resolve(TestService) + + assert "Failed to create service using factory" in str(exc_info.value) + + def test_constructor_dependency_with_default(self): + """Test dependency with default value.""" + container = ServiceContainer() + + class TestService: + def __init__(self, value: int = 42): + self.value = value + + container.register_singleton(TestService) + + instance = container.resolve(TestService) + assert instance.value == 42 + + def test_unresolvable_dependency_with_default(self): + """Test unresolvable dependency that has a default value.""" + container = ServiceContainer() + + class UnregisteredService: + pass + + class TestService: + def __init__(self, dep: UnregisteredService = None): + self.dep = dep + + container.register_singleton(TestService) + + instance = container.resolve(TestService) + assert instance.dep is None + + def test_unresolvable_dependency_without_default(self): + """Test unresolvable dependency without default value.""" + container = ServiceContainer() + + class UnregisteredService: + pass + + class TestService: + def __init__(self, dep: UnregisteredService): + self.dep = dep + + container.register_singleton(TestService) + + with pytest.raises(DependencyInjectionError) as exc_info: + container.resolve(TestService) + + assert "Cannot resolve dependency" in str(exc_info.value) + + +class TestServiceScope: + """Test service scope functionality.""" + + def test_create_scope(self): + """Test scope creation.""" + container = ServiceContainer() + + class TestService: + def __init__(self): + self.value = 42 + + container.register_scoped(TestService) + + scope = container.create_scope("test_scope") + assert isinstance(scope, ServiceScope) + assert scope.scope_id == "test_scope" + + def test_scope_context_manager(self): + """Test scope as context manager.""" + container = ServiceContainer() + + class TestService: + def __init__(self): + self.disposed = False + + def dispose(self): + self.disposed = True + + container.register_scoped(TestService) + + instance = None + with container.create_scope("test_scope") as scope: + instance = scope.resolve(TestService) + assert not instance.disposed + + # Instance should be disposed after scope exit + assert instance.disposed + + def test_dispose_scope(self): + """Test manual scope disposal.""" + container = ServiceContainer() + + class TestService: + def __init__(self): + self.disposed = False + + def dispose(self): + self.disposed = True + + container.register_scoped(TestService) + + instance = container.resolve(TestService, scope_id="test_scope") + assert not instance.disposed + + container.dispose_scope("test_scope") + assert instance.disposed + + def test_dispose_error_handling(self): + """Test error handling during scope disposal.""" + container = ServiceContainer() + + class TestService: + def dispose(self): + raise ValueError("Dispose failed") + + container.register_scoped(TestService) + + container.resolve(TestService, scope_id="test_scope") + + # Should not raise error, just log it + container.dispose_scope("test_scope") + + +class TestContainerIntrospection: + """Test container introspection capabilities.""" + + def test_is_registered(self): + """Test checking if service is registered.""" + container = ServiceContainer() + + class RegisteredService: + pass + + class UnregisteredService: + pass + + container.register_singleton(RegisteredService) + + assert container.is_registered(RegisteredService) is True + assert container.is_registered(UnregisteredService) is False + + def test_get_registration_info(self): + """Test getting service registration information.""" + container = ServiceContainer() + + class TestService: + pass + + container.register_singleton(TestService) + + info = container.get_registration_info(TestService) + assert isinstance(info, ServiceDescriptor) + assert info.service_type == TestService + assert info.lifetime == ServiceLifetime.SINGLETON + + def test_get_registered_services(self): + """Test getting all registered services.""" + container = ServiceContainer() + + class Service1: + pass + + class Service2: + pass + + container.register_singleton(Service1) + container.register_transient(Service2) + + services = container.get_registered_services() + assert len(services) == 2 + assert Service1 in services + assert Service2 in services + + def test_clear_singletons(self): + """Test clearing singleton instances.""" + container = ServiceContainer() + + class TestService: + pass + + container.register_singleton(TestService) + + # Create singleton instance + instance1 = container.resolve(TestService) + + # Clear singletons + container.clear_singletons() + + # Next resolve should create new instance + instance2 = container.resolve(TestService) + assert instance2 is not instance1 + + def test_get_stats(self): + """Test getting container statistics.""" + container = ServiceContainer() + + class Service1: + pass + + class Service2: + pass + + class Service3: + pass + + container.register_singleton(Service1) + container.register_transient(Service2) + container.register_scoped(Service3) + + # Create some instances + container.resolve(Service1) + container.resolve(Service3, scope_id="scope1") + + stats = container.get_stats() + + assert stats["registered_services"] == 3 + assert stats["active_singletons"] == 1 + assert stats["active_scopes"] == 1 + assert stats["lifetime_breakdown"]["singleton"] == 1 + assert stats["lifetime_breakdown"]["transient"] == 1 + assert stats["lifetime_breakdown"]["scoped"] == 1 + + +class TestDetectorWorkerContainer: + """Test pre-configured detector worker container.""" + + def test_initialization(self): + """Test detector worker container initialization.""" + container = DetectorWorkerContainer() + + assert isinstance(container.container, ServiceContainer) + + # Should have core services registered + stats = container.container.get_stats() + assert stats["registered_services"] > 0 + + def test_resolve_convenience_method(self): + """Test resolve convenience method.""" + container = DetectorWorkerContainer() + + # Should be able to resolve through convenience method + from detector_worker.core.singleton_managers import ModelStateManager + + manager = container.resolve(ModelStateManager) + assert isinstance(manager, ModelStateManager) + + def test_create_scope_convenience_method(self): + """Test create scope convenience method.""" + container = DetectorWorkerContainer() + + scope = container.create_scope("test_scope") + assert isinstance(scope, ServiceScope) + assert scope.scope_id == "test_scope" + + +class TestGlobalContainerFunctions: + """Test global container functions.""" + + def test_get_container_singleton(self): + """Test that get_container returns a singleton.""" + container1 = get_container() + container2 = get_container() + + assert container1 is container2 + assert isinstance(container1, DetectorWorkerContainer) + + def test_resolve_service_convenience(self): + """Test resolve_service convenience function.""" + from detector_worker.core.singleton_managers import ModelStateManager + + manager = resolve_service(ModelStateManager) + assert isinstance(manager, ModelStateManager) + + def test_create_service_scope_convenience(self): + """Test create_service_scope convenience function.""" + scope = create_service_scope("test_scope") + assert isinstance(scope, ServiceScope) + assert scope.scope_id == "test_scope" + + +class TestThreadSafety: + """Test thread safety of dependency injection system.""" + + def test_container_thread_safety(self): + """Test that container is thread-safe.""" + container = ServiceContainer() + + class TestService: + def __init__(self): + import threading + self.thread_id = threading.current_thread().ident + + container.register_singleton(TestService) + + instances = {} + + def resolve_service(thread_id): + instances[thread_id] = container.resolve(TestService) + + # Create multiple threads + threads = [] + for i in range(10): + thread = threading.Thread(target=resolve_service, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads + for thread in threads: + thread.join() + + # All should get the same singleton instance + first_instance = list(instances.values())[0] + for instance in instances.values(): + assert instance is first_instance + + def test_scope_thread_safety(self): + """Test that scoped services are thread-safe.""" + container = ServiceContainer() + + class TestService: + def __init__(self): + import threading + self.thread_id = threading.current_thread().ident + + container.register_scoped(TestService) + + results = {} + + def resolve_in_scope(thread_id): + # Each thread uses its own scope + instance1 = container.resolve(TestService, scope_id=f"scope_{thread_id}") + instance2 = container.resolve(TestService, scope_id=f"scope_{thread_id}") + + results[thread_id] = { + "same_instance": instance1 is instance2, + "thread_id": instance1.thread_id + } + + threads = [] + for i in range(5): + thread = threading.Thread(target=resolve_in_scope, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # Each thread should get same instance within its scope + for thread_id, result in results.items(): + assert result["same_instance"] is True \ No newline at end of file diff --git a/tests/unit/core/test_singleton_managers.py b/tests/unit/core/test_singleton_managers.py new file mode 100644 index 0000000..4c0ad98 --- /dev/null +++ b/tests/unit/core/test_singleton_managers.py @@ -0,0 +1,560 @@ +""" +Unit tests for singleton state managers. +""" +import pytest +import time +import threading +from unittest.mock import Mock, patch, MagicMock + +from detector_worker.core.singleton_managers import ( + SingletonMeta, + ModelStateManager, + StreamStateManager, + SessionStateManager, + CacheStateManager, + CameraStateManager, + PipelineStateManager, + ModelInfo, + StreamInfo, + SessionInfo +) + + +class TestSingletonMeta: + """Test singleton metaclass.""" + + def test_singleton_behavior(self): + """Test that singleton metaclass creates only one instance.""" + class TestSingleton(metaclass=SingletonMeta): + def __init__(self): + self.value = 42 + + instance1 = TestSingleton() + instance2 = TestSingleton() + + assert instance1 is instance2 + assert instance1.value == instance2.value + + def test_singleton_thread_safety(self): + """Test that singleton is thread-safe.""" + class TestSingleton(metaclass=SingletonMeta): + def __init__(self): + self.created_by = threading.current_thread().name + + instances = {} + + def create_instance(thread_id): + instances[thread_id] = TestSingleton() + + threads = [] + for i in range(10): + thread = threading.Thread(target=create_instance, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # All instances should be the same object + first_instance = instances[0] + for instance in instances.values(): + assert instance is first_instance + + +class TestModelStateManager: + """Test model state management.""" + + def test_singleton_behavior(self): + """Test that ModelStateManager is a singleton.""" + manager1 = ModelStateManager() + manager2 = ModelStateManager() + + assert manager1 is manager2 + + def test_load_model(self): + """Test loading a model.""" + manager = ModelStateManager() + manager.clear_all() # Start fresh + + mock_model = Mock() + manager.load_model("camera1", "model1", mock_model) + + retrieved_model = manager.get_model("camera1", "model1") + assert retrieved_model is mock_model + + def test_load_same_model_increments_reference_count(self): + """Test that loading the same model increments reference count.""" + manager = ModelStateManager() + manager.clear_all() + + mock_model = Mock() + + # Load same model twice + manager.load_model("camera1", "model1", mock_model) + manager.load_model("camera1", "model1", mock_model) + + # Should still be accessible + assert manager.get_model("camera1", "model1") is mock_model + + def test_get_camera_models(self): + """Test getting all models for a camera.""" + manager = ModelStateManager() + manager.clear_all() + + mock_model1 = Mock() + mock_model2 = Mock() + + manager.load_model("camera1", "model1", mock_model1) + manager.load_model("camera1", "model2", mock_model2) + + models = manager.get_camera_models("camera1") + + assert len(models) == 2 + assert models["model1"] is mock_model1 + assert models["model2"] is mock_model2 + + def test_unload_model_with_multiple_references(self): + """Test unloading model with multiple references.""" + manager = ModelStateManager() + manager.clear_all() + + mock_model = Mock() + + # Load model twice (reference count = 2) + manager.load_model("camera1", "model1", mock_model) + manager.load_model("camera1", "model1", mock_model) + + # First unload should not remove model + result = manager.unload_model("camera1", "model1") + assert result is False # Still referenced + assert manager.get_model("camera1", "model1") is mock_model + + # Second unload should remove model + result = manager.unload_model("camera1", "model1") + assert result is True # Completely removed + assert manager.get_model("camera1", "model1") is None + + def test_unload_camera_models(self): + """Test unloading all models for a camera.""" + manager = ModelStateManager() + manager.clear_all() + + mock_model1 = Mock() + mock_model2 = Mock() + + manager.load_model("camera1", "model1", mock_model1) + manager.load_model("camera1", "model2", mock_model2) + + manager.unload_camera_models("camera1") + + assert manager.get_model("camera1", "model1") is None + assert manager.get_model("camera1", "model2") is None + + def test_get_stats(self): + """Test getting model statistics.""" + manager = ModelStateManager() + manager.clear_all() + + mock_model = Mock() + manager.load_model("camera1", "model1", mock_model) + manager.load_model("camera2", "model2", mock_model) + + stats = manager.get_stats() + + assert stats["total_models"] == 2 + assert stats["total_cameras"] == 2 + assert "camera1" in stats["cameras"] + assert "camera2" in stats["cameras"] + + +class TestStreamStateManager: + """Test stream state management.""" + + def test_add_stream(self): + """Test adding a stream.""" + manager = StreamStateManager() + manager.clear_all() + + config = {"rtsp_url": "rtsp://example.com", "model_id": "test"} + manager.add_stream("camera1", "sub1", config) + + stream = manager.get_stream("camera1") + assert stream is not None + assert stream.camera_id == "camera1" + assert stream.subscription_id == "sub1" + assert stream.config == config + + def test_subscription_mapping(self): + """Test subscription to camera mapping.""" + manager = StreamStateManager() + manager.clear_all() + + config = {"rtsp_url": "rtsp://example.com"} + manager.add_stream("camera1", "sub1", config) + + camera_id = manager.get_camera_by_subscription("sub1") + assert camera_id == "camera1" + + def test_remove_stream(self): + """Test removing a stream.""" + manager = StreamStateManager() + manager.clear_all() + + config = {"rtsp_url": "rtsp://example.com"} + manager.add_stream("camera1", "sub1", config) + + removed_stream = manager.remove_stream("camera1") + + assert removed_stream is not None + assert removed_stream.camera_id == "camera1" + assert manager.get_stream("camera1") is None + assert manager.get_camera_by_subscription("sub1") is None + + def test_shared_stream_management(self): + """Test shared stream management.""" + manager = StreamStateManager() + manager.clear_all() + + stream_data = {"reader": Mock(), "reference_count": 1} + manager.add_shared_stream("rtsp://example.com", stream_data) + + retrieved_data = manager.get_shared_stream("rtsp://example.com") + assert retrieved_data == stream_data + + removed_data = manager.remove_shared_stream("rtsp://example.com") + assert removed_data == stream_data + assert manager.get_shared_stream("rtsp://example.com") is None + + +class TestSessionStateManager: + """Test session state management.""" + + def test_session_id_management(self): + """Test session ID assignment.""" + manager = SessionStateManager() + manager.clear_all() + + manager.set_session_id("display1", "session123") + + session_id = manager.get_session_id("display1") + assert session_id == "session123" + + def test_create_session(self): + """Test session creation with detection data.""" + manager = SessionStateManager() + manager.clear_all() + + detection_data = {"class": "car", "confidence": 0.85} + manager.create_session("session123", "camera1", detection_data) + + retrieved_data = manager.get_session_detection("session123") + assert retrieved_data == detection_data + + camera_id = manager.get_camera_by_session("session123") + assert camera_id == "camera1" + + def test_update_session_detection(self): + """Test updating session detection data.""" + manager = SessionStateManager() + manager.clear_all() + + initial_data = {"class": "car", "confidence": 0.85} + manager.create_session("session123", "camera1", initial_data) + + update_data = {"brand": "Toyota"} + manager.update_session_detection("session123", update_data) + + final_data = manager.get_session_detection("session123") + assert final_data["class"] == "car" + assert final_data["brand"] == "Toyota" + + def test_session_expiration(self): + """Test session expiration based on TTL.""" + # Use a very short TTL for testing + manager = SessionStateManager(session_ttl=0.1) + manager.clear_all() + + detection_data = {"class": "car", "confidence": 0.85} + manager.create_session("session123", "camera1", detection_data) + + # Session should exist initially + assert manager.get_session_detection("session123") is not None + + # Wait for expiration + time.sleep(0.2) + + # Clean up expired sessions + expired_count = manager.cleanup_expired_sessions() + + assert expired_count == 1 + assert manager.get_session_detection("session123") is None + + def test_remove_session(self): + """Test manual session removal.""" + manager = SessionStateManager() + manager.clear_all() + + detection_data = {"class": "car", "confidence": 0.85} + manager.create_session("session123", "camera1", detection_data) + + result = manager.remove_session("session123") + assert result is True + + assert manager.get_session_detection("session123") is None + assert manager.get_camera_by_session("session123") is None + + +class TestCacheStateManager: + """Test cache state management.""" + + def test_cache_detection(self): + """Test caching detection results.""" + manager = CacheStateManager() + manager.clear_all() + + detection_data = {"class": "car", "confidence": 0.85, "bbox": [100, 200, 300, 400]} + manager.cache_detection("camera1", detection_data) + + cached_data = manager.get_cached_detection("camera1") + assert cached_data == detection_data + + def test_cache_pipeline_result(self): + """Test caching pipeline results.""" + manager = CacheStateManager() + manager.clear_all() + + pipeline_result = {"status": "success", "detections": []} + manager.cache_pipeline_result("camera1", pipeline_result) + + cached_result = manager.get_cached_pipeline_result("camera1") + assert cached_result == pipeline_result + + def test_latest_frame_management(self): + """Test latest frame storage.""" + manager = CacheStateManager() + manager.clear_all() + + frame_data = b"fake_frame_data" + manager.set_latest_frame("camera1", frame_data) + + retrieved_frame = manager.get_latest_frame("camera1") + assert retrieved_frame == frame_data + + def test_frame_skip_flag(self): + """Test frame skip flag management.""" + manager = CacheStateManager() + manager.clear_all() + + # Initially should be False + assert manager.get_frame_skip_flag("camera1") is False + + manager.set_frame_skip_flag("camera1", True) + assert manager.get_frame_skip_flag("camera1") is True + + manager.set_frame_skip_flag("camera1", False) + assert manager.get_frame_skip_flag("camera1") is False + + def test_clear_camera_cache(self): + """Test clearing all cache data for a camera.""" + manager = CacheStateManager() + manager.clear_all() + + # Set up cache data + detection_data = {"class": "car"} + pipeline_result = {"status": "success"} + frame_data = b"frame" + + manager.cache_detection("camera1", detection_data) + manager.cache_pipeline_result("camera1", pipeline_result) + manager.set_latest_frame("camera1", frame_data) + manager.set_frame_skip_flag("camera1", True) + + # Clear cache + manager.clear_camera_cache("camera1") + + # All data should be gone + assert manager.get_cached_detection("camera1") is None + assert manager.get_cached_pipeline_result("camera1") is None + assert manager.get_latest_frame("camera1") is None + assert manager.get_frame_skip_flag("camera1") is False + + +class TestCameraStateManager: + """Test camera state management.""" + + def test_camera_connection_state(self): + """Test camera connection state management.""" + manager = CameraStateManager() + manager.clear_all() + + # Initially connected (default) + assert manager.is_camera_connected("camera1") is True + + # Set disconnected + manager.set_camera_connected("camera1", False) + assert manager.is_camera_connected("camera1") is False + + # Set connected again + manager.set_camera_connected("camera1", True) + assert manager.is_camera_connected("camera1") is True + + def test_notification_flags(self): + """Test disconnection/reconnection notification flags.""" + manager = CameraStateManager() + manager.clear_all() + + # Set disconnected + manager.set_camera_connected("camera1", False) + + # Should notify disconnection once + assert manager.should_notify_disconnection("camera1") is True + manager.mark_disconnection_notified("camera1") + assert manager.should_notify_disconnection("camera1") is False + + # Reconnect + manager.set_camera_connected("camera1", True) + + # Should notify reconnection + assert manager.should_notify_reconnection("camera1") is True + manager.mark_reconnection_notified("camera1") + assert manager.should_notify_reconnection("camera1") is False + + def test_get_camera_state(self): + """Test getting full camera state.""" + manager = CameraStateManager() + manager.clear_all() + + manager.set_camera_connected("camera1", False) + + state = manager.get_camera_state("camera1") + + assert state["connected"] is False + assert "last_update" in state + assert "disconnection_notified" in state + assert "reconnection_notified" in state + + def test_get_stats(self): + """Test getting camera state statistics.""" + manager = CameraStateManager() + manager.clear_all() + + manager.set_camera_connected("camera1", True) + manager.set_camera_connected("camera2", False) + + stats = manager.get_stats() + + assert stats["total_cameras"] == 2 + assert stats["connected_cameras"] == 1 + assert stats["disconnected_cameras"] == 1 + + +class TestPipelineStateManager: + """Test pipeline state management.""" + + def test_get_or_init_state(self): + """Test getting or initializing pipeline state.""" + manager = PipelineStateManager() + manager.clear_all() + + state = manager.get_or_init_state("camera1") + + assert state["mode"] == "validation_detecting" + assert state["backend_session_id"] is None + assert state["yolo_inference_enabled"] is True + assert "created_at" in state + + def test_update_mode(self): + """Test updating pipeline mode.""" + manager = PipelineStateManager() + manager.clear_all() + + manager.update_mode("camera1", "classification", "session123") + + state = manager.get_state("camera1") + assert state["mode"] == "classification" + assert state["backend_session_id"] == "session123" + + def test_set_yolo_inference_enabled(self): + """Test setting YOLO inference state.""" + manager = PipelineStateManager() + manager.clear_all() + + manager.set_yolo_inference_enabled("camera1", False) + + state = manager.get_state("camera1") + assert state["yolo_inference_enabled"] is False + + def test_set_progression_stage(self): + """Test setting progression stage.""" + manager = PipelineStateManager() + manager.clear_all() + + manager.set_progression_stage("camera1", "brand_classification") + + state = manager.get_state("camera1") + assert state["progression_stage"] == "brand_classification" + + def test_set_validated_detection(self): + """Test setting validated detection.""" + manager = PipelineStateManager() + manager.clear_all() + + detection = {"class": "car", "confidence": 0.85} + manager.set_validated_detection("camera1", detection) + + state = manager.get_state("camera1") + assert state["validated_detection"] == detection + + def test_get_stats(self): + """Test getting pipeline state statistics.""" + manager = PipelineStateManager() + manager.clear_all() + + manager.update_mode("camera1", "validation_detecting") + manager.update_mode("camera2", "classification") + manager.update_mode("camera3", "classification") + + stats = manager.get_stats() + + assert stats["total_pipeline_states"] == 3 + assert stats["mode_breakdown"]["validation_detecting"] == 1 + assert stats["mode_breakdown"]["classification"] == 2 + + +class TestThreadSafety: + """Test thread safety of singleton managers.""" + + def test_model_manager_thread_safety(self): + """Test ModelStateManager thread safety.""" + manager = ModelStateManager() + manager.clear_all() + + results = {} + + def load_models(thread_id): + for i in range(10): + model = Mock() + model.thread_id = thread_id + model.model_id = i + manager.load_model(f"camera{thread_id}", f"model{i}", model) + + # Verify models + models = manager.get_camera_models(f"camera{thread_id}") + results[thread_id] = len(models) + + threads = [] + for i in range(5): + thread = threading.Thread(target=load_models, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # Each thread should have loaded 10 models + for thread_id, model_count in results.items(): + assert model_count == 10 + + # Total should be 50 models + stats = manager.get_stats() + assert stats["total_models"] == 50 \ No newline at end of file diff --git a/tests/unit/detection/test_detection_result.py b/tests/unit/detection/test_detection_result.py new file mode 100644 index 0000000..c363a41 --- /dev/null +++ b/tests/unit/detection/test_detection_result.py @@ -0,0 +1,479 @@ +""" +Unit tests for detection result data structures. +""" +import pytest +from dataclasses import asdict +import numpy as np + +from detector_worker.detection.detection_result import ( + BoundingBox, + DetectionResult, + LightweightDetectionResult, + DetectionSession, + TrackValidationResult +) + + +class TestBoundingBox: + """Test BoundingBox data structure.""" + + def test_creation_from_coordinates(self): + """Test creating bounding box from coordinates.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + + assert bbox.x1 == 100 + assert bbox.y1 == 200 + assert bbox.x2 == 300 + assert bbox.y2 == 400 + + def test_creation_from_list(self): + """Test creating bounding box from list.""" + coords = [100, 200, 300, 400] + bbox = BoundingBox.from_list(coords) + + assert bbox.x1 == 100 + assert bbox.y1 == 200 + assert bbox.x2 == 300 + assert bbox.y2 == 400 + + def test_creation_from_invalid_list(self): + """Test error handling for invalid list.""" + with pytest.raises(ValueError): + BoundingBox.from_list([100, 200, 300]) # Too few elements + + def test_to_list(self): + """Test converting bounding box to list.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + coords = bbox.to_list() + + assert coords == [100, 200, 300, 400] + + def test_area_calculation(self): + """Test area calculation.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + area = bbox.area() + + expected_area = (300 - 100) * (400 - 200) # 200 * 200 = 40000 + assert area == expected_area + + def test_area_zero_for_invalid_bbox(self): + """Test area is zero for invalid bounding box.""" + # x2 <= x1 + bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400) + assert bbox.area() == 0 + + # y2 <= y1 + bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200) + assert bbox.area() == 0 + + def test_width_height(self): + """Test width and height properties.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + + assert bbox.width() == 200 + assert bbox.height() == 200 + + def test_center_point(self): + """Test center point calculation.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + center = bbox.center() + + assert center == (200, 300) # (x1+x2)/2, (y1+y2)/2 + + def test_is_valid(self): + """Test bounding box validation.""" + # Valid bbox + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + assert bbox.is_valid() is True + + # Invalid bbox (x2 <= x1) + bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400) + assert bbox.is_valid() is False + + # Invalid bbox (y2 <= y1) + bbox = BoundingBox(x1=100, y1=400, x2=300, y2=200) + assert bbox.is_valid() is False + + def test_intersection(self): + """Test bounding box intersection.""" + bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300) + bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400) + + intersection = bbox1.intersection(bbox2) + + assert intersection.x1 == 200 + assert intersection.y1 == 200 + assert intersection.x2 == 300 + assert intersection.y2 == 300 + + def test_no_intersection(self): + """Test no intersection between bounding boxes.""" + bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200) + bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400) + + intersection = bbox1.intersection(bbox2) + + assert intersection.is_valid() is False + + def test_union(self): + """Test bounding box union.""" + bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300) + bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400) + + union = bbox1.union(bbox2) + + assert union.x1 == 100 + assert union.y1 == 100 + assert union.x2 == 400 + assert union.y2 == 400 + + def test_iou_calculation(self): + """Test IoU (Intersection over Union) calculation.""" + # Perfect overlap + bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300) + bbox2 = BoundingBox(x1=100, y1=100, x2=300, y2=300) + assert bbox1.iou(bbox2) == 1.0 + + # No overlap + bbox1 = BoundingBox(x1=100, y1=100, x2=200, y2=200) + bbox2 = BoundingBox(x1=300, y1=300, x2=400, y2=400) + assert bbox1.iou(bbox2) == 0.0 + + # Partial overlap + bbox1 = BoundingBox(x1=100, y1=100, x2=300, y2=300) + bbox2 = BoundingBox(x1=200, y1=200, x2=400, y2=400) + + # Intersection area: 100x100 = 10000 + # Union area: 200x200 + 200x200 - 10000 = 30000 + # IoU = 10000/30000 = 1/3 + expected_iou = 1.0 / 3.0 + assert abs(bbox1.iou(bbox2) - expected_iou) < 1e-6 + + +class TestDetectionResult: + """Test DetectionResult data structure.""" + + def test_creation_with_required_fields(self): + """Test creating detection result with required fields.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=12345 + ) + + assert detection.class_name == "car" + assert detection.confidence == 0.85 + assert detection.bbox == bbox + assert detection.track_id == 12345 + + def test_creation_with_all_fields(self): + """Test creating detection result with all fields.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=12345, + model_id="yolo_v8", + timestamp=1640995200000, + branch_results={"brand": "Toyota"} + ) + + assert detection.model_id == "yolo_v8" + assert detection.timestamp == 1640995200000 + assert detection.branch_results == {"brand": "Toyota"} + + def test_creation_from_dict(self): + """Test creating detection result from dictionary.""" + data = { + "class": "car", + "confidence": 0.85, + "bbox": [100, 200, 300, 400], + "id": 12345, + "model_id": "yolo_v8", + "timestamp": 1640995200000 + } + + detection = DetectionResult.from_dict(data) + + assert detection.class_name == "car" + assert detection.confidence == 0.85 + assert detection.bbox.to_list() == [100, 200, 300, 400] + assert detection.track_id == 12345 + + def test_to_dict(self): + """Test converting detection result to dictionary.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=12345 + ) + + data = detection.to_dict() + + assert data["class"] == "car" + assert data["confidence"] == 0.85 + assert data["bbox"] == [100, 200, 300, 400] + assert data["id"] == 12345 + + def test_is_valid_detection(self): + """Test detection validation.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + + # Valid detection + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=12345 + ) + assert detection.is_valid() is True + + # Invalid confidence (too low) + detection = DetectionResult( + class_name="car", + confidence=-0.1, + bbox=bbox, + track_id=12345 + ) + assert detection.is_valid() is False + + # Invalid confidence (too high) + detection = DetectionResult( + class_name="car", + confidence=1.5, + bbox=bbox, + track_id=12345 + ) + assert detection.is_valid() is False + + # Invalid bounding box + invalid_bbox = BoundingBox(x1=300, y1=200, x2=100, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=invalid_bbox, + track_id=12345 + ) + assert detection.is_valid() is False + + +class TestLightweightDetectionResult: + """Test LightweightDetectionResult data structure.""" + + def test_creation(self): + """Test creating lightweight detection result.""" + detection = LightweightDetectionResult( + class_name="car", + confidence=0.85, + bbox_area=40000, + frame_width=1920, + frame_height=1080 + ) + + assert detection.class_name == "car" + assert detection.confidence == 0.85 + assert detection.bbox_area == 40000 + assert detection.frame_width == 1920 + assert detection.frame_height == 1080 + + def test_area_ratio_calculation(self): + """Test bounding box area ratio calculation.""" + detection = LightweightDetectionResult( + class_name="car", + confidence=0.85, + bbox_area=40000, + frame_width=1920, + frame_height=1080 + ) + + expected_ratio = 40000 / (1920 * 1080) + assert abs(detection.area_ratio() - expected_ratio) < 1e-6 + + def test_meets_threshold(self): + """Test threshold checking.""" + detection = LightweightDetectionResult( + class_name="car", + confidence=0.85, + bbox_area=40000, + frame_width=1920, + frame_height=1080 + ) + + assert detection.meets_threshold(confidence=0.8, area_ratio=0.01) is True + assert detection.meets_threshold(confidence=0.9, area_ratio=0.01) is False + assert detection.meets_threshold(confidence=0.8, area_ratio=0.1) is False + + +class TestDetectionSession: + """Test DetectionSession data structure.""" + + def test_creation(self): + """Test creating detection session.""" + session = DetectionSession( + session_id="session_123", + camera_id="camera_001", + display_id="display_001" + ) + + assert session.session_id == "session_123" + assert session.camera_id == "camera_001" + assert session.display_id == "display_001" + assert session.detections == [] + assert session.metadata == {} + + def test_add_detection(self): + """Test adding detection to session.""" + session = DetectionSession( + session_id="session_123", + camera_id="camera_001", + display_id="display_001" + ) + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=12345 + ) + + session.add_detection(detection) + + assert len(session.detections) == 1 + assert session.detections[0] == detection + + def test_get_latest_detection(self): + """Test getting latest detection.""" + session = DetectionSession( + session_id="session_123", + camera_id="camera_001", + display_id="display_001" + ) + + # Add multiple detections + bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection1 = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox1, + track_id=12345, + timestamp=1640995200000 + ) + + bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450) + detection2 = DetectionResult( + class_name="car", + confidence=0.90, + bbox=bbox2, + track_id=12345, + timestamp=1640995300000 + ) + + session.add_detection(detection1) + session.add_detection(detection2) + + latest = session.get_latest_detection() + assert latest == detection2 # Should be the one with later timestamp + + def test_get_detections_by_class(self): + """Test filtering detections by class.""" + session = DetectionSession( + session_id="session_123", + camera_id="camera_001", + display_id="display_001" + ) + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + + car_detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=12345 + ) + + truck_detection = DetectionResult( + class_name="truck", + confidence=0.80, + bbox=bbox, + track_id=54321 + ) + + session.add_detection(car_detection) + session.add_detection(truck_detection) + + car_detections = session.get_detections_by_class("car") + assert len(car_detections) == 1 + assert car_detections[0] == car_detection + + truck_detections = session.get_detections_by_class("truck") + assert len(truck_detections) == 1 + assert truck_detections[0] == truck_detection + + +class TestTrackValidationResult: + """Test TrackValidationResult data structure.""" + + def test_creation(self): + """Test creating track validation result.""" + result = TrackValidationResult( + stable_tracks=[101, 102, 103], + current_tracks=[101, 102, 104, 105], + newly_stable=[103], + lost_tracks=[106] + ) + + assert result.stable_tracks == [101, 102, 103] + assert result.current_tracks == [101, 102, 104, 105] + assert result.newly_stable == [103] + assert result.lost_tracks == [106] + + def test_has_stable_tracks(self): + """Test checking for stable tracks.""" + result = TrackValidationResult( + stable_tracks=[101, 102], + current_tracks=[101, 102, 103] + ) + + assert result.has_stable_tracks() is True + + result_empty = TrackValidationResult( + stable_tracks=[], + current_tracks=[101, 102, 103] + ) + + assert result_empty.has_stable_tracks() is False + + def test_get_stats(self): + """Test getting validation statistics.""" + result = TrackValidationResult( + stable_tracks=[101, 102, 103], + current_tracks=[101, 102, 104, 105], + newly_stable=[103], + lost_tracks=[106] + ) + + stats = result.get_stats() + + assert stats["stable_count"] == 3 + assert stats["current_count"] == 4 + assert stats["newly_stable_count"] == 1 + assert stats["lost_count"] == 1 + assert stats["stability_ratio"] == 3/4 # stable/current + + def test_is_track_stable(self): + """Test checking if specific track is stable.""" + result = TrackValidationResult( + stable_tracks=[101, 102, 103], + current_tracks=[101, 102, 104, 105] + ) + + assert result.is_track_stable(101) is True + assert result.is_track_stable(102) is True + assert result.is_track_stable(104) is False + assert result.is_track_stable(999) is False \ No newline at end of file diff --git a/tests/unit/detection/test_stability_validator.py b/tests/unit/detection/test_stability_validator.py new file mode 100644 index 0000000..bf748c2 --- /dev/null +++ b/tests/unit/detection/test_stability_validator.py @@ -0,0 +1,701 @@ +""" +Unit tests for track stability validation. +""" +import pytest +import time +from unittest.mock import Mock, patch +from collections import defaultdict + +from detector_worker.detection.stability_validator import ( + StabilityValidator, + StabilityConfig, + ValidationResult, + TrackStabilityMetrics +) +from detector_worker.detection.detection_result import DetectionResult, BoundingBox, TrackValidationResult +from detector_worker.core.exceptions import ValidationError + + +class TestStabilityConfig: + """Test stability configuration data structure.""" + + def test_default_config(self): + """Test default stability configuration.""" + config = StabilityConfig() + + assert config.min_detection_frames == 10 + assert config.max_absence_frames == 30 + assert config.confidence_threshold == 0.5 + assert config.stability_window == 60.0 + assert config.iou_threshold == 0.3 + assert config.movement_threshold == 50.0 + + def test_custom_config(self): + """Test custom stability configuration.""" + config = StabilityConfig( + min_detection_frames=5, + max_absence_frames=15, + confidence_threshold=0.8, + stability_window=30.0, + iou_threshold=0.5, + movement_threshold=25.0 + ) + + assert config.min_detection_frames == 5 + assert config.max_absence_frames == 15 + assert config.confidence_threshold == 0.8 + assert config.stability_window == 30.0 + assert config.iou_threshold == 0.5 + assert config.movement_threshold == 25.0 + + def test_from_dict(self): + """Test creating config from dictionary.""" + config_dict = { + "min_detection_frames": 8, + "max_absence_frames": 25, + "confidence_threshold": 0.75, + "unknown_field": "ignored" + } + + config = StabilityConfig.from_dict(config_dict) + + assert config.min_detection_frames == 8 + assert config.max_absence_frames == 25 + assert config.confidence_threshold == 0.75 + # Unknown fields should use defaults + assert config.stability_window == 60.0 + + +class TestTrackStabilityMetrics: + """Test track stability metrics.""" + + def test_initialization(self): + """Test metrics initialization.""" + metrics = TrackStabilityMetrics(track_id=1001) + + assert metrics.track_id == 1001 + assert metrics.detection_count == 0 + assert metrics.absence_count == 0 + assert metrics.total_confidence == 0.0 + assert metrics.first_detection_time is None + assert metrics.last_detection_time is None + assert metrics.bounding_boxes == [] + assert metrics.confidence_scores == [] + + def test_add_detection(self): + """Test adding detection to metrics.""" + metrics = TrackStabilityMetrics(track_id=1001) + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + metrics.add_detection(detection, current_time=1640995200.0) + + assert metrics.detection_count == 1 + assert metrics.absence_count == 0 + assert metrics.total_confidence == 0.85 + assert metrics.first_detection_time == 1640995200.0 + assert metrics.last_detection_time == 1640995200.0 + assert len(metrics.bounding_boxes) == 1 + assert len(metrics.confidence_scores) == 1 + + def test_increment_absence(self): + """Test incrementing absence count.""" + metrics = TrackStabilityMetrics(track_id=1001) + + metrics.increment_absence() + assert metrics.absence_count == 1 + + metrics.increment_absence() + assert metrics.absence_count == 2 + + def test_reset_absence(self): + """Test resetting absence count.""" + metrics = TrackStabilityMetrics(track_id=1001) + + metrics.increment_absence() + metrics.increment_absence() + assert metrics.absence_count == 2 + + metrics.reset_absence() + assert metrics.absence_count == 0 + + def test_average_confidence(self): + """Test average confidence calculation.""" + metrics = TrackStabilityMetrics(track_id=1001) + + # No detections + assert metrics.average_confidence() == 0.0 + + # Add detections + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + + detection1 = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + detection2 = DetectionResult( + class_name="car", + confidence=0.9, + bbox=bbox, + track_id=1001, + timestamp=1640995300000 + ) + + metrics.add_detection(detection1, current_time=1640995200.0) + metrics.add_detection(detection2, current_time=1640995300.0) + + assert metrics.average_confidence() == 0.85 # (0.8 + 0.9) / 2 + + def test_tracking_duration(self): + """Test tracking duration calculation.""" + metrics = TrackStabilityMetrics(track_id=1001) + + # No detections + assert metrics.tracking_duration() == 0.0 + + # Add detections + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + + detection1 = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + detection2 = DetectionResult( + class_name="car", + confidence=0.9, + bbox=bbox, + track_id=1001, + timestamp=1640995300000 + ) + + metrics.add_detection(detection1, current_time=1640995200.0) + metrics.add_detection(detection2, current_time=1640995300.0) + + assert metrics.tracking_duration() == 100.0 # 1640995300 - 1640995200 + + def test_movement_distance(self): + """Test movement distance calculation.""" + metrics = TrackStabilityMetrics(track_id=1001) + + # No movement with single detection + bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection1 = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox1, + track_id=1001, + timestamp=1640995200000 + ) + + metrics.add_detection(detection1, current_time=1640995200.0) + assert metrics.total_movement_distance() == 0.0 + + # Add second detection with movement + bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410) + detection2 = DetectionResult( + class_name="car", + confidence=0.9, + bbox=bbox2, + track_id=1001, + timestamp=1640995300000 + ) + + metrics.add_detection(detection2, current_time=1640995300.0) + + # Distance between centers: (200,300) to (210,310) = sqrt(100+100) โ‰ˆ 14.14 + movement = metrics.total_movement_distance() + assert movement == pytest.approx(14.14, rel=1e-2) + + +class TestValidationResult: + """Test validation result data structure.""" + + def test_initialization(self): + """Test validation result initialization.""" + result = ValidationResult( + track_id=1001, + is_stable=True, + detection_count=15, + absence_count=2, + average_confidence=0.85, + tracking_duration=120.0 + ) + + assert result.track_id == 1001 + assert result.is_stable is True + assert result.detection_count == 15 + assert result.absence_count == 2 + assert result.average_confidence == 0.85 + assert result.tracking_duration == 120.0 + assert result.reasons == [] + + def test_with_reasons(self): + """Test validation result with failure reasons.""" + result = ValidationResult( + track_id=1001, + is_stable=False, + detection_count=5, + absence_count=35, + average_confidence=0.4, + tracking_duration=30.0, + reasons=["Insufficient detection frames", "Too many absences", "Low confidence"] + ) + + assert result.is_stable is False + assert len(result.reasons) == 3 + assert "Insufficient detection frames" in result.reasons + + +class TestStabilityValidator: + """Test stability validation functionality.""" + + def test_initialization_default(self): + """Test validator initialization with default config.""" + validator = StabilityValidator() + + assert isinstance(validator.config, StabilityConfig) + assert validator.config.min_detection_frames == 10 + assert len(validator.track_metrics) == 0 + + def test_initialization_custom_config(self): + """Test validator initialization with custom config.""" + config = StabilityConfig(min_detection_frames=5, confidence_threshold=0.8) + validator = StabilityValidator(config) + + assert validator.config.min_detection_frames == 5 + assert validator.config.confidence_threshold == 0.8 + + def test_update_detections_new_track(self): + """Test updating with new track.""" + validator = StabilityValidator() + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + validator.update_detections([detection], current_time=1640995200.0) + + assert 1001 in validator.track_metrics + metrics = validator.track_metrics[1001] + assert metrics.detection_count == 1 + assert metrics.absence_count == 0 + + def test_update_detections_existing_track(self): + """Test updating existing track.""" + validator = StabilityValidator() + + # First detection + bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection1 = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox1, + track_id=1001, + timestamp=1640995200000 + ) + + validator.update_detections([detection1], current_time=1640995200.0) + + # Second detection + bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410) + detection2 = DetectionResult( + class_name="car", + confidence=0.9, + bbox=bbox2, + track_id=1001, + timestamp=1640995300000 + ) + + validator.update_detections([detection2], current_time=1640995300.0) + + metrics = validator.track_metrics[1001] + assert metrics.detection_count == 2 + assert metrics.absence_count == 0 + assert metrics.average_confidence() == 0.85 + + def test_update_detections_missing_track(self): + """Test updating when track is missing (increment absence).""" + validator = StabilityValidator() + + # Add track + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + validator.update_detections([detection], current_time=1640995200.0) + + # Update with empty detections + validator.update_detections([], current_time=1640995300.0) + + metrics = validator.track_metrics[1001] + assert metrics.detection_count == 1 + assert metrics.absence_count == 1 + + def test_validate_track_stable(self): + """Test validating a stable track.""" + config = StabilityConfig(min_detection_frames=3, max_absence_frames=5) + validator = StabilityValidator(config) + + # Create track with sufficient detections + track_id = 1001 + validator.track_metrics[track_id] = TrackStabilityMetrics(track_id) + metrics = validator.track_metrics[track_id] + + # Add sufficient detections + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + for i in range(5): + detection = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=track_id, + timestamp=1640995200000 + i * 1000 + ) + metrics.add_detection(detection, current_time=1640995200.0 + i) + + result = validator.validate_track(track_id) + + assert result.is_stable is True + assert result.detection_count == 5 + assert result.absence_count == 0 + assert len(result.reasons) == 0 + + def test_validate_track_insufficient_detections(self): + """Test validating track with insufficient detections.""" + config = StabilityConfig(min_detection_frames=10, max_absence_frames=5) + validator = StabilityValidator(config) + + # Create track with insufficient detections + track_id = 1001 + validator.track_metrics[track_id] = TrackStabilityMetrics(track_id) + metrics = validator.track_metrics[track_id] + + # Add only few detections + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + for i in range(3): + detection = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=track_id, + timestamp=1640995200000 + i * 1000 + ) + metrics.add_detection(detection, current_time=1640995200.0 + i) + + result = validator.validate_track(track_id) + + assert result.is_stable is False + assert "Insufficient detection frames" in result.reasons + + def test_validate_track_too_many_absences(self): + """Test validating track with too many absences.""" + config = StabilityConfig(min_detection_frames=3, max_absence_frames=2) + validator = StabilityValidator(config) + + # Create track with too many absences + track_id = 1001 + validator.track_metrics[track_id] = TrackStabilityMetrics(track_id) + metrics = validator.track_metrics[track_id] + + # Add detections and absences + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + for i in range(5): + detection = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=track_id, + timestamp=1640995200000 + i * 1000 + ) + metrics.add_detection(detection, current_time=1640995200.0 + i) + + # Add too many absences + for _ in range(5): + metrics.increment_absence() + + result = validator.validate_track(track_id) + + assert result.is_stable is False + assert "Too many absence frames" in result.reasons + + def test_validate_track_low_confidence(self): + """Test validating track with low confidence.""" + config = StabilityConfig( + min_detection_frames=3, + max_absence_frames=5, + confidence_threshold=0.8 + ) + validator = StabilityValidator(config) + + # Create track with low confidence + track_id = 1001 + validator.track_metrics[track_id] = TrackStabilityMetrics(track_id) + metrics = validator.track_metrics[track_id] + + # Add detections with low confidence + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + for i in range(5): + detection = DetectionResult( + class_name="car", + confidence=0.5, # Below threshold + bbox=bbox, + track_id=track_id, + timestamp=1640995200000 + i * 1000 + ) + metrics.add_detection(detection, current_time=1640995200.0 + i) + + result = validator.validate_track(track_id) + + assert result.is_stable is False + assert "Low average confidence" in result.reasons + + def test_validate_all_tracks(self): + """Test validating all tracks.""" + config = StabilityConfig(min_detection_frames=3) + validator = StabilityValidator(config) + + # Add multiple tracks + for track_id in [1001, 1002, 1003]: + validator.track_metrics[track_id] = TrackStabilityMetrics(track_id) + metrics = validator.track_metrics[track_id] + + # Make some tracks stable, others not + detection_count = 5 if track_id == 1001 else 2 + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + + for i in range(detection_count): + detection = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=track_id, + timestamp=1640995200000 + i * 1000 + ) + metrics.add_detection(detection, current_time=1640995200.0 + i) + + results = validator.validate_all_tracks() + + assert len(results) == 3 + assert results[1001].is_stable is True # 5 detections + assert results[1002].is_stable is False # 2 detections + assert results[1003].is_stable is False # 2 detections + + def test_get_stable_tracks(self): + """Test getting stable track IDs.""" + config = StabilityConfig(min_detection_frames=3) + validator = StabilityValidator(config) + + # Add tracks with different stability + for track_id, detection_count in [(1001, 5), (1002, 2), (1003, 4)]: + validator.track_metrics[track_id] = TrackStabilityMetrics(track_id) + metrics = validator.track_metrics[track_id] + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + for i in range(detection_count): + detection = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=track_id, + timestamp=1640995200000 + i * 1000 + ) + metrics.add_detection(detection, current_time=1640995200.0 + i) + + stable_tracks = validator.get_stable_tracks() + + assert stable_tracks == [1001, 1003] # 5 and 4 detections respectively + + def test_cleanup_expired_tracks(self): + """Test cleanup of expired tracks.""" + config = StabilityConfig(stability_window=10.0) + validator = StabilityValidator(config) + + # Add tracks with different last detection times + current_time = 1640995300.0 + + for track_id, last_detection_time in [(1001, current_time - 5), (1002, current_time - 15)]: + validator.track_metrics[track_id] = TrackStabilityMetrics(track_id) + metrics = validator.track_metrics[track_id] + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=track_id, + timestamp=int(last_detection_time * 1000) + ) + metrics.add_detection(detection, current_time=last_detection_time) + + removed_count = validator.cleanup_expired_tracks(current_time) + + assert removed_count == 1 # 1002 should be removed (15 > 10 seconds) + assert 1001 in validator.track_metrics + assert 1002 not in validator.track_metrics + + def test_clear_all_tracks(self): + """Test clearing all track metrics.""" + validator = StabilityValidator() + + # Add some tracks + for track_id in [1001, 1002]: + validator.track_metrics[track_id] = TrackStabilityMetrics(track_id) + + assert len(validator.track_metrics) == 2 + + validator.clear_all_tracks() + + assert len(validator.track_metrics) == 0 + + def test_get_validation_summary(self): + """Test getting validation summary statistics.""" + config = StabilityConfig(min_detection_frames=3) + validator = StabilityValidator(config) + + # Add tracks with different characteristics + track_data = [ + (1001, 5, True), # Stable + (1002, 2, False), # Unstable + (1003, 4, True), # Stable + (1004, 1, False) # Unstable + ] + + for track_id, detection_count, _ in track_data: + validator.track_metrics[track_id] = TrackStabilityMetrics(track_id) + metrics = validator.track_metrics[track_id] + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + for i in range(detection_count): + detection = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=track_id, + timestamp=1640995200000 + i * 1000 + ) + metrics.add_detection(detection, current_time=1640995200.0 + i) + + summary = validator.get_validation_summary() + + assert summary["total_tracks"] == 4 + assert summary["stable_tracks"] == 2 + assert summary["unstable_tracks"] == 2 + assert summary["stability_rate"] == 0.5 + + +class TestStabilityValidatorIntegration: + """Integration tests for stability validator.""" + + def test_full_tracking_lifecycle(self): + """Test complete tracking lifecycle with stability validation.""" + config = StabilityConfig( + min_detection_frames=3, + max_absence_frames=2, + confidence_threshold=0.7 + ) + validator = StabilityValidator(config) + + track_id = 1001 + + # Phase 1: Initial detections (building up) + for i in range(5): + bbox = BoundingBox(x1=100+i*2, y1=200+i*2, x2=300+i*2, y2=400+i*2) + detection = DetectionResult( + class_name="car", + confidence=0.8, + bbox=bbox, + track_id=track_id, + timestamp=1640995200000 + i * 1000 + ) + validator.update_detections([detection], current_time=1640995200.0 + i) + + # Should be stable now + result = validator.validate_track(track_id) + assert result.is_stable is True + + # Phase 2: Some absences + for i in range(2): + validator.update_detections([], current_time=1640995205.0 + i) + + # Still stable (within absence threshold) + result = validator.validate_track(track_id) + assert result.is_stable is True + + # Phase 3: Track reappears + bbox = BoundingBox(x1=120, y1=220, x2=320, y2=420) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=track_id, + timestamp=1640995207000 + ) + validator.update_detections([detection], current_time=1640995207.0) + + # Should reset absence count and remain stable + result = validator.validate_track(track_id) + assert result.is_stable is True + assert validator.track_metrics[track_id].absence_count == 0 + + def test_multi_track_validation(self): + """Test validation with multiple tracks.""" + validator = StabilityValidator() + + # Simulate multi-track scenario + frame_detections = [ + # Frame 1 + [ + DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000), + DetectionResult("truck", 0.8, BoundingBox(400, 200, 600, 400), 1002, 1640995200000) + ], + # Frame 2 + [ + DetectionResult("car", 0.85, BoundingBox(105, 205, 305, 405), 1001, 1640995201000), + DetectionResult("truck", 0.82, BoundingBox(405, 205, 605, 405), 1002, 1640995201000), + DetectionResult("car", 0.75, BoundingBox(200, 300, 400, 500), 1003, 1640995201000) + ], + # Frame 3 - track 1002 disappears + [ + DetectionResult("car", 0.88, BoundingBox(110, 210, 310, 410), 1001, 1640995202000), + DetectionResult("car", 0.78, BoundingBox(205, 305, 405, 505), 1003, 1640995202000) + ] + ] + + # Process frames + for i, detections in enumerate(frame_detections): + validator.update_detections(detections, current_time=1640995200.0 + i) + + # Get validation results + validation_results = validator.validate_all_tracks() + + assert len(validation_results) == 3 + + # All tracks should be unstable (insufficient frames) + for result in validation_results.values(): + assert result.is_stable is False + assert "Insufficient detection frames" in result.reasons \ No newline at end of file diff --git a/tests/unit/detection/test_tracking_manager.py b/tests/unit/detection/test_tracking_manager.py new file mode 100644 index 0000000..3d63535 --- /dev/null +++ b/tests/unit/detection/test_tracking_manager.py @@ -0,0 +1,606 @@ +""" +Unit tests for BoT-SORT tracking management. +""" +import pytest +import numpy as np +from unittest.mock import Mock, MagicMock, patch +from collections import defaultdict + +from detector_worker.detection.tracking_manager import TrackingManager, TrackInfo +from detector_worker.detection.detection_result import DetectionResult, BoundingBox +from detector_worker.core.exceptions import TrackingError + + +class TestTrackInfo: + """Test TrackInfo data structure.""" + + def test_creation(self): + """Test TrackInfo creation.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + track = TrackInfo( + track_id=1001, + bbox=bbox, + confidence=0.85, + class_name="car", + first_seen=1640995200.0, + last_seen=1640995300.0 + ) + + assert track.track_id == 1001 + assert track.bbox == bbox + assert track.confidence == 0.85 + assert track.class_name == "car" + assert track.first_seen == 1640995200.0 + assert track.last_seen == 1640995300.0 + assert track.frame_count == 1 + assert track.absence_count == 0 + + def test_update_track(self): + """Test updating track information.""" + bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400) + track = TrackInfo( + track_id=1001, + bbox=bbox1, + confidence=0.85, + class_name="car", + first_seen=1640995200.0, + last_seen=1640995200.0 + ) + + bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410) + track.update(bbox2, 0.90, 1640995300.0) + + assert track.bbox == bbox2 + assert track.confidence == 0.90 + assert track.last_seen == 1640995300.0 + assert track.frame_count == 2 + assert track.absence_count == 0 + + def test_increment_absence(self): + """Test incrementing absence count.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + track = TrackInfo( + track_id=1001, + bbox=bbox, + confidence=0.85, + class_name="car", + first_seen=1640995200.0, + last_seen=1640995200.0 + ) + + track.increment_absence() + assert track.absence_count == 1 + + track.increment_absence() + assert track.absence_count == 2 + + def test_age_calculation(self): + """Test track age calculation.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + track = TrackInfo( + track_id=1001, + bbox=bbox, + confidence=0.85, + class_name="car", + first_seen=1640995200.0, + last_seen=1640995300.0 + ) + + age = track.age(current_time=1640995400.0) + assert age == 200.0 # 1640995400 - 1640995200 + + def test_time_since_last_seen(self): + """Test time since last seen calculation.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + track = TrackInfo( + track_id=1001, + bbox=bbox, + confidence=0.85, + class_name="car", + first_seen=1640995200.0, + last_seen=1640995300.0 + ) + + time_since = track.time_since_last_seen(current_time=1640995450.0) + assert time_since == 150.0 # 1640995450 - 1640995300 + + def test_is_stable(self): + """Test track stability checking.""" + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + track = TrackInfo( + track_id=1001, + bbox=bbox, + confidence=0.85, + class_name="car", + first_seen=1640995200.0, + last_seen=1640995300.0 + ) + + # Not stable initially + assert track.is_stable(min_frames=5, max_absence=3) is False + + # Make it stable + track.frame_count = 10 + track.absence_count = 1 + assert track.is_stable(min_frames=5, max_absence=3) is True + + # Too many absences + track.absence_count = 5 + assert track.is_stable(min_frames=5, max_absence=3) is False + + +class TestTrackingManager: + """Test tracking management functionality.""" + + def test_initialization(self): + """Test tracking manager initialization.""" + manager = TrackingManager() + + assert manager.max_absence_frames == 30 + assert manager.min_stable_frames == 10 + assert manager.track_timeout == 60.0 + assert len(manager.active_tracks) == 0 + assert len(manager.stable_tracks) == 0 + + def test_initialization_with_config(self): + """Test initialization with custom configuration.""" + config = { + "max_absence_frames": 20, + "min_stable_frames": 5, + "track_timeout": 30.0 + } + manager = TrackingManager(config) + + assert manager.max_absence_frames == 20 + assert manager.min_stable_frames == 5 + assert manager.track_timeout == 30.0 + + def test_update_tracks_new_detections(self): + """Test updating with new detections.""" + manager = TrackingManager() + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + manager.update_tracks([detection], current_time=1640995200.0) + + assert len(manager.active_tracks) == 1 + assert 1001 in manager.active_tracks + + track = manager.active_tracks[1001] + assert track.track_id == 1001 + assert track.class_name == "car" + assert track.confidence == 0.85 + assert track.frame_count == 1 + + def test_update_tracks_existing_detection(self): + """Test updating existing track.""" + manager = TrackingManager() + + # First detection + bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection1 = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox1, + track_id=1001, + timestamp=1640995200000 + ) + + manager.update_tracks([detection1], current_time=1640995200.0) + + # Second detection (same track, different position) + bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410) + detection2 = DetectionResult( + class_name="car", + confidence=0.90, + bbox=bbox2, + track_id=1001, + timestamp=1640995300000 + ) + + manager.update_tracks([detection2], current_time=1640995300.0) + + assert len(manager.active_tracks) == 1 + track = manager.active_tracks[1001] + assert track.frame_count == 2 + assert track.confidence == 0.90 + assert track.bbox == bbox2 + assert track.absence_count == 0 + + def test_update_tracks_no_detections(self): + """Test updating with no detections (increment absence).""" + manager = TrackingManager() + + # Add initial track + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + manager.update_tracks([detection], current_time=1640995200.0) + + # Update with no detections + manager.update_tracks([], current_time=1640995300.0) + + track = manager.active_tracks[1001] + assert track.absence_count == 1 + + def test_cleanup_expired_tracks(self): + """Test cleanup of expired tracks.""" + manager = TrackingManager({"track_timeout": 10.0}) + + # Add track + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + manager.update_tracks([detection], current_time=1640995200.0) + assert len(manager.active_tracks) == 1 + + # Cleanup after timeout + removed_count = manager.cleanup_expired_tracks(current_time=1640995220.0) # 20 seconds later + + assert removed_count == 1 + assert len(manager.active_tracks) == 0 + + def test_cleanup_absent_tracks(self): + """Test cleanup of tracks with too many absences.""" + manager = TrackingManager({"max_absence_frames": 3}) + + # Add track + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + manager.update_tracks([detection], current_time=1640995200.0) + + # Increment absence count beyond threshold + for i in range(5): + manager.update_tracks([], current_time=1640995200.0 + i) + + track = manager.active_tracks[1001] + assert track.absence_count == 5 + + # Cleanup absent tracks + removed_count = manager.cleanup_absent_tracks() + + assert removed_count == 1 + assert len(manager.active_tracks) == 0 + + def test_get_stable_tracks(self): + """Test getting stable tracks.""" + manager = TrackingManager({"min_stable_frames": 3}) + + # Add track and make it stable + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + track_info = TrackInfo( + track_id=1001, + bbox=bbox, + confidence=0.85, + class_name="car", + first_seen=1640995200.0, + last_seen=1640995300.0 + ) + track_info.frame_count = 5 # Make it stable + + manager.active_tracks[1001] = track_info + + stable_tracks = manager.get_stable_tracks() + + assert len(stable_tracks) == 1 + assert 1001 in stable_tracks + assert 1001 in manager.stable_tracks # Should be cached + + def test_get_track_by_id(self): + """Test getting track by ID.""" + manager = TrackingManager() + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + manager.update_tracks([detection], current_time=1640995200.0) + + track = manager.get_track_by_id(1001) + assert track is not None + assert track.track_id == 1001 + + non_existent = manager.get_track_by_id(9999) + assert non_existent is None + + def test_get_tracks_by_class(self): + """Test getting tracks by class name.""" + manager = TrackingManager() + + # Add different classes + bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection1 = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox1, + track_id=1001, + timestamp=1640995200000 + ) + + bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450) + detection2 = DetectionResult( + class_name="truck", + confidence=0.80, + bbox=bbox2, + track_id=1002, + timestamp=1640995200000 + ) + + bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500) + detection3 = DetectionResult( + class_name="car", + confidence=0.90, + bbox=bbox3, + track_id=1003, + timestamp=1640995200000 + ) + + manager.update_tracks([detection1, detection2, detection3], current_time=1640995200.0) + + car_tracks = manager.get_tracks_by_class("car") + assert len(car_tracks) == 2 + assert 1001 in car_tracks + assert 1003 in car_tracks + + truck_tracks = manager.get_tracks_by_class("truck") + assert len(truck_tracks) == 1 + assert 1002 in truck_tracks + + def test_get_track_count(self): + """Test getting track counts.""" + manager = TrackingManager() + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + manager.update_tracks([detection], current_time=1640995200.0) + + assert manager.get_active_track_count() == 1 + assert manager.get_track_count_by_class("car") == 1 + assert manager.get_track_count_by_class("truck") == 0 + + def test_clear_all_tracks(self): + """Test clearing all tracks.""" + manager = TrackingManager() + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001, + timestamp=1640995200000 + ) + + manager.update_tracks([detection], current_time=1640995200.0) + assert len(manager.active_tracks) == 1 + + manager.clear_all_tracks() + + assert len(manager.active_tracks) == 0 + assert len(manager.stable_tracks) == 0 + + def test_get_track_statistics(self): + """Test getting track statistics.""" + manager = TrackingManager({"min_stable_frames": 2}) + + # Add multiple tracks + detections = [] + for i in range(3): + bbox = BoundingBox(x1=100+i*50, y1=200, x2=300+i*50, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=1001+i, + timestamp=1640995200000 + ) + detections.append(detection) + + manager.update_tracks(detections, current_time=1640995200.0) + + # Make some tracks stable + manager.active_tracks[1001].frame_count = 5 + manager.active_tracks[1002].frame_count = 3 + # 1003 remains unstable with frame_count=1 + + stats = manager.get_track_statistics() + + assert stats["active_tracks"] == 3 + assert stats["stable_tracks"] == 2 + assert stats["unstable_tracks"] == 1 + assert "average_track_age" in stats + assert "average_confidence" in stats + + def test_validate_tracks(self): + """Test track validation.""" + manager = TrackingManager({"min_stable_frames": 3, "max_absence_frames": 2}) + + # Add tracks with different stability + bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400) + track1 = TrackInfo( + track_id=1001, + bbox=bbox1, + confidence=0.85, + class_name="car", + first_seen=1640995200.0, + last_seen=1640995300.0 + ) + track1.frame_count = 5 # Stable + track1.absence_count = 1 # Present + + bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450) + track2 = TrackInfo( + track_id=1002, + bbox=bbox2, + confidence=0.80, + class_name="car", + first_seen=1640995200.0, + last_seen=1640995250.0 + ) + track2.frame_count = 2 # Not stable + track2.absence_count = 1 + + bbox3 = BoundingBox(x1=200, y1=300, x2=400, y2=500) + track3 = TrackInfo( + track_id=1003, + bbox=bbox3, + confidence=0.90, + class_name="car", + first_seen=1640995100.0, + last_seen=1640995150.0 + ) + track3.frame_count = 8 # Was stable but now absent + track3.absence_count = 5 # Too many absences + + manager.active_tracks = {1001: track1, 1002: track2, 1003: track3} + manager.stable_tracks = {1001, 1003} # 1003 was previously stable + + validation_result = manager.validate_tracks() + + assert validation_result.stable_tracks == [1001] + assert validation_result.current_tracks == [1001, 1002, 1003] + assert validation_result.newly_stable == [] + assert validation_result.lost_tracks == [1003] + + def test_track_persistence_across_frames(self): + """Test track persistence across multiple frames.""" + manager = TrackingManager() + + # Frame 1 + bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection1 = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox1, + track_id=1001, + timestamp=1640995200000 + ) + + manager.update_tracks([detection1], current_time=1640995200.0) + + # Frame 2 - track moves + bbox2 = BoundingBox(x1=110, y1=210, x2=310, y2=410) + detection2 = DetectionResult( + class_name="car", + confidence=0.88, + bbox=bbox2, + track_id=1001, + timestamp=1640995300000 + ) + + manager.update_tracks([detection2], current_time=1640995300.0) + + # Frame 3 - track disappears + manager.update_tracks([], current_time=1640995400.0) + + # Frame 4 - track reappears + bbox4 = BoundingBox(x1=120, y1=220, x2=320, y2=420) + detection4 = DetectionResult( + class_name="car", + confidence=0.82, + bbox=bbox4, + track_id=1001, + timestamp=1640995500000 + ) + + manager.update_tracks([detection4], current_time=1640995500.0) + + track = manager.active_tracks[1001] + assert track.frame_count == 3 # Seen in 3 frames + assert track.absence_count == 0 # Reset when reappeared + assert track.bbox == bbox4 # Latest position + + +class TestTrackingManagerErrorHandling: + """Test error handling in tracking manager.""" + + def test_invalid_detection_input(self): + """Test handling of invalid detection input.""" + manager = TrackingManager() + + # None detection should be handled gracefully + with pytest.raises(TrackingError): + manager.update_tracks([None], current_time=1640995200.0) + + def test_negative_track_id(self): + """Test handling of negative track ID.""" + manager = TrackingManager() + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox, + track_id=-1, # Invalid track ID + timestamp=1640995200000 + ) + + with pytest.raises(TrackingError): + manager.update_tracks([detection], current_time=1640995200.0) + + def test_duplicate_track_ids_different_classes(self): + """Test handling of duplicate track IDs with different classes.""" + manager = TrackingManager() + + bbox1 = BoundingBox(x1=100, y1=200, x2=300, y2=400) + detection1 = DetectionResult( + class_name="car", + confidence=0.85, + bbox=bbox1, + track_id=1001, + timestamp=1640995200000 + ) + + bbox2 = BoundingBox(x1=150, y1=250, x2=350, y2=450) + detection2 = DetectionResult( + class_name="truck", # Different class, same ID + confidence=0.80, + bbox=bbox2, + track_id=1001, + timestamp=1640995200000 + ) + + # Should log warning but handle gracefully + manager.update_tracks([detection1, detection2], current_time=1640995200.0) + + # The later detection should update the track + track = manager.active_tracks[1001] + assert track.class_name == "truck" # Last update wins \ No newline at end of file diff --git a/tests/unit/detection/test_yolo_detector.py b/tests/unit/detection/test_yolo_detector.py new file mode 100644 index 0000000..c5dd0bb --- /dev/null +++ b/tests/unit/detection/test_yolo_detector.py @@ -0,0 +1,386 @@ +""" +Unit tests for YOLO detector with tracking functionality. +""" +import pytest +import numpy as np +from unittest.mock import Mock, MagicMock, patch +import torch + +from detector_worker.detection.yolo_detector import YOLODetector +from detector_worker.detection.detection_result import DetectionResult, BoundingBox +from detector_worker.core.exceptions import DetectionError + + +class TestYOLODetector: + """Test YOLO detection and tracking functionality.""" + + def test_initialization_with_valid_model(self, mock_yolo_model): + """Test detector initialization with valid model.""" + detector = YOLODetector(mock_yolo_model) + + assert detector.model is mock_yolo_model + assert detector.class_names == {} + assert detector.is_tracking_enabled is True + + def test_initialization_with_class_names(self, mock_yolo_model): + """Test detector initialization with class names.""" + class_names = {0: "car", 1: "truck", 2: "bus"} + detector = YOLODetector(mock_yolo_model, class_names=class_names) + + assert detector.class_names == class_names + + def test_initialization_tracking_disabled(self, mock_yolo_model): + """Test detector initialization with tracking disabled.""" + detector = YOLODetector(mock_yolo_model, enable_tracking=False) + + assert detector.is_tracking_enabled is False + + def test_detect_with_tracking(self, mock_yolo_model, mock_frame): + """Test detection with tracking enabled.""" + # Mock detection result + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0], # x1, y1, x2, y2, conf, class + [150, 250, 350, 450, 0.85, 1] + ]) + mock_result.boxes.id = torch.tensor([1001, 1002]) + + mock_yolo_model.track.return_value = [mock_result] + + detector = YOLODetector(mock_yolo_model) + detections = detector.detect(mock_frame) + + assert len(detections) == 2 + assert detections[0].confidence == 0.9 + assert detections[0].track_id == 1001 + assert detections[0].bbox.x1 == 100 + + mock_yolo_model.track.assert_called_once_with(mock_frame, persist=True, verbose=False) + + def test_detect_without_tracking(self, mock_yolo_model, mock_frame): + """Test detection with tracking disabled.""" + # Mock detection result + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0] + ]) + mock_result.boxes.id = None # No tracking IDs + + mock_yolo_model.predict.return_value = [mock_result] + + detector = YOLODetector(mock_yolo_model, enable_tracking=False) + detections = detector.detect(mock_frame) + + assert len(detections) == 1 + assert detections[0].track_id is None # No tracking ID + + mock_yolo_model.predict.assert_called_once_with(mock_frame, verbose=False) + + def test_detect_with_class_names(self, mock_yolo_model, mock_frame): + """Test detection with class name mapping.""" + class_names = {0: "car", 1: "truck"} + + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0], # car + [150, 250, 350, 450, 0.85, 1] # truck + ]) + mock_result.boxes.id = torch.tensor([1001, 1002]) + + mock_yolo_model.track.return_value = [mock_result] + + detector = YOLODetector(mock_yolo_model, class_names=class_names) + detections = detector.detect(mock_frame) + + assert detections[0].class_name == "car" + assert detections[1].class_name == "truck" + + def test_detect_no_boxes(self, mock_yolo_model, mock_frame): + """Test detection when no objects are detected.""" + mock_result = Mock() + mock_result.boxes = None + + mock_yolo_model.track.return_value = [mock_result] + + detector = YOLODetector(mock_yolo_model) + detections = detector.detect(mock_frame) + + assert detections == [] + + def test_detect_empty_boxes(self, mock_yolo_model, mock_frame): + """Test detection with empty boxes tensor.""" + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([]).reshape(0, 6) + mock_result.boxes.id = None + + mock_yolo_model.track.return_value = [mock_result] + + detector = YOLODetector(mock_yolo_model) + detections = detector.detect(mock_frame) + + assert detections == [] + + def test_detect_with_confidence_threshold(self, mock_yolo_model, mock_frame): + """Test detection with confidence threshold filtering.""" + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0], # Above threshold + [150, 250, 350, 450, 0.3, 1] # Below threshold + ]) + mock_result.boxes.id = torch.tensor([1001, 1002]) + + mock_yolo_model.track.return_value = [mock_result] + + detector = YOLODetector(mock_yolo_model) + detections = detector.detect(mock_frame, confidence_threshold=0.5) + + assert len(detections) == 1 # Only one above threshold + assert detections[0].confidence == 0.9 + + def test_detect_model_error_handling(self, mock_yolo_model, mock_frame): + """Test error handling when model fails.""" + mock_yolo_model.track.side_effect = Exception("Model inference failed") + + detector = YOLODetector(mock_yolo_model) + + with pytest.raises(DetectionError) as exc_info: + detector.detect(mock_frame) + + assert "Model inference failed" in str(exc_info.value) + + def test_detect_invalid_frame(self, mock_yolo_model): + """Test detection with invalid frame input.""" + detector = YOLODetector(mock_yolo_model) + + with pytest.raises(DetectionError) as exc_info: + detector.detect(None) + + assert "Invalid frame" in str(exc_info.value) + + def test_detect_result_validation(self, mock_yolo_model, mock_frame): + """Test detection result validation.""" + # Mock result with invalid bounding box (x2 <= x1) + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [300, 200, 100, 400, 0.9, 0] # Invalid: x2 < x1 + ]) + mock_result.boxes.id = torch.tensor([1001]) + + mock_yolo_model.track.return_value = [mock_result] + + detector = YOLODetector(mock_yolo_model) + detections = detector.detect(mock_frame) + + # Invalid detections should be filtered out + assert detections == [] + + def test_get_model_info(self, mock_yolo_model): + """Test getting model information.""" + mock_yolo_model.device = "cuda:0" + mock_yolo_model.names = {0: "car", 1: "truck"} + + detector = YOLODetector(mock_yolo_model) + info = detector.get_model_info() + + assert info["device"] == "cuda:0" + assert info["class_names"] == {0: "car", 1: "truck"} + assert info["tracking_enabled"] is True + + def test_set_tracking_enabled(self, mock_yolo_model): + """Test enabling/disabling tracking at runtime.""" + detector = YOLODetector(mock_yolo_model, enable_tracking=False) + assert detector.is_tracking_enabled is False + + detector.set_tracking_enabled(True) + assert detector.is_tracking_enabled is True + + detector.set_tracking_enabled(False) + assert detector.is_tracking_enabled is False + + def test_update_class_names(self, mock_yolo_model): + """Test updating class names at runtime.""" + detector = YOLODetector(mock_yolo_model) + + new_class_names = {0: "vehicle", 1: "person"} + detector.update_class_names(new_class_names) + + assert detector.class_names == new_class_names + + def test_reset_tracker(self, mock_yolo_model): + """Test resetting the tracking state.""" + detector = YOLODetector(mock_yolo_model) + + # This should not raise an error + detector.reset_tracker() + + def test_detect_with_crop_region(self, mock_yolo_model, mock_frame): + """Test detection with crop region specified.""" + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [50, 75, 150, 175, 0.9, 0] # Relative to cropped region + ]) + mock_result.boxes.id = torch.tensor([1001]) + + mock_yolo_model.track.return_value = [mock_result] + + detector = YOLODetector(mock_yolo_model) + crop_region = (100, 200, 300, 400) # x1, y1, x2, y2 + detections = detector.detect(mock_frame, crop_region=crop_region) + + # Bounding box should be adjusted to global coordinates + assert detections[0].bbox.x1 == 150 # 100 + 50 + assert detections[0].bbox.y1 == 275 # 200 + 75 + assert detections[0].bbox.x2 == 250 # 100 + 150 + assert detections[0].bbox.y2 == 375 # 200 + 175 + + def test_detect_batch_processing(self, mock_yolo_model): + """Test batch detection processing.""" + frames = [ + np.zeros((480, 640, 3), dtype=np.uint8), + np.ones((480, 640, 3), dtype=np.uint8) * 255 + ] + + mock_results = [] + for i in range(2): + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100 + i*10, 200, 300, 400, 0.9, 0] + ]) + mock_result.boxes.id = torch.tensor([1001 + i]) + mock_results.append(mock_result) + + mock_yolo_model.track.side_effect = [[result] for result in mock_results] + + detector = YOLODetector(mock_yolo_model) + batch_detections = detector.detect_batch(frames) + + assert len(batch_detections) == 2 + assert len(batch_detections[0]) == 1 + assert len(batch_detections[1]) == 1 + assert batch_detections[0][0].bbox.x1 == 100 + assert batch_detections[1][0].bbox.x1 == 110 + + def test_detect_batch_empty_frames(self, mock_yolo_model): + """Test batch detection with empty frame list.""" + detector = YOLODetector(mock_yolo_model) + batch_detections = detector.detect_batch([]) + + assert batch_detections == [] + + def test_detect_performance_metrics(self, mock_yolo_model, mock_frame): + """Test detection performance metrics collection.""" + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0] + ]) + mock_result.boxes.id = torch.tensor([1001]) + mock_result.speed = {"preprocess": 2.1, "inference": 15.3, "postprocess": 1.2} + + mock_yolo_model.track.return_value = [mock_result] + + detector = YOLODetector(mock_yolo_model) + detections = detector.detect(mock_frame, return_metrics=True) + + # Check if performance metrics are available + assert hasattr(detector, '_last_inference_time') + + @pytest.mark.parametrize("device", ["cpu", "cuda:0", "mps"]) + def test_detect_different_devices(self, device, mock_frame): + """Test detection on different devices.""" + mock_model = Mock() + mock_model.device = device + + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0] + ]) + mock_result.boxes.id = torch.tensor([1001]) + + mock_model.track.return_value = [mock_result] + + detector = YOLODetector(mock_model) + detections = detector.detect(mock_frame) + + assert len(detections) == 1 + assert detections[0].confidence == 0.9 + + +class TestYOLODetectorIntegration: + """Integration tests for YOLO detector.""" + + def test_detect_with_real_tensor_operations(self, mock_yolo_model, mock_frame): + """Test detection with realistic tensor operations.""" + # Create more realistic box data + boxes_data = torch.tensor([ + [100.5, 200.3, 299.7, 399.8, 0.95, 0], + [150.2, 250.1, 349.9, 449.6, 0.87, 1], + [200.0, 300.0, 400.0, 500.0, 0.45, 0] # Low confidence + ]) + + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = boxes_data + mock_result.boxes.id = torch.tensor([2001, 2002, 2003]) + + mock_yolo_model.track.return_value = [mock_result] + + class_names = {0: "car", 1: "truck"} + detector = YOLODetector(mock_yolo_model, class_names=class_names) + + detections = detector.detect(mock_frame, confidence_threshold=0.5) + + # Should filter out low confidence detection + assert len(detections) == 2 + + # Check first detection + det1 = detections[0] + assert det1.class_name == "car" + assert det1.confidence == pytest.approx(0.95) + assert det1.track_id == 2001 + assert det1.bbox.x1 == pytest.approx(100.5) + assert det1.bbox.y1 == pytest.approx(200.3) + + # Check second detection + det2 = detections[1] + assert det2.class_name == "truck" + assert det2.confidence == pytest.approx(0.87) + assert det2.track_id == 2002 + + def test_multi_frame_tracking_consistency(self, mock_yolo_model, mock_frame): + """Test that tracking IDs remain consistent across frames.""" + detector = YOLODetector(mock_yolo_model) + + # Frame 1 + mock_result1 = Mock() + mock_result1.boxes = Mock() + mock_result1.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0] + ]) + mock_result1.boxes.id = torch.tensor([5001]) + + mock_yolo_model.track.return_value = [mock_result1] + detections1 = detector.detect(mock_frame) + + # Frame 2 - same object, slightly moved + mock_result2 = Mock() + mock_result2.boxes = Mock() + mock_result2.boxes.data = torch.tensor([ + [105, 205, 305, 405, 0.88, 0] + ]) + mock_result2.boxes.id = torch.tensor([5001]) # Same ID + + mock_yolo_model.track.return_value = [mock_result2] + detections2 = detector.detect(mock_frame) + + # Should maintain same track ID + assert detections1[0].track_id == detections2[0].track_id == 5001 \ No newline at end of file diff --git a/tests/unit/models/test_model_manager.py b/tests/unit/models/test_model_manager.py new file mode 100644 index 0000000..7bc6726 --- /dev/null +++ b/tests/unit/models/test_model_manager.py @@ -0,0 +1,882 @@ +""" +Unit tests for model management functionality. +""" +import pytest +import os +import tempfile +import threading +import time +from unittest.mock import Mock, patch, MagicMock +import torch +import numpy as np + +from detector_worker.models.model_manager import ( + ModelManager, + ModelInfo, + ModelConfig, + ModelCache, + ModelLoader, + ModelError, + ModelLoadError, + ModelCacheError +) +from detector_worker.core.exceptions import ConfigurationError + + +class TestModelConfig: + """Test model configuration.""" + + def test_creation(self): + """Test model config creation.""" + config = ModelConfig( + model_id="yolo_v8_car", + model_path="/models/yolo_v8_car.pt", + model_type="detection", + device="cuda:0" + ) + + assert config.model_id == "yolo_v8_car" + assert config.model_path == "/models/yolo_v8_car.pt" + assert config.model_type == "detection" + assert config.device == "cuda:0" + assert config.confidence_threshold == 0.5 + assert config.max_memory_mb == 1024 + + def test_creation_with_optional_params(self): + """Test config creation with optional parameters.""" + config = ModelConfig( + model_id="classifier_v1", + model_path="/models/classifier.pt", + model_type="classification", + device="cpu", + confidence_threshold=0.8, + max_memory_mb=512, + class_names={0: "car", 1: "truck", 2: "bus"}, + preprocessing_config={"resize": (224, 224), "normalize": True} + ) + + assert config.confidence_threshold == 0.8 + assert config.max_memory_mb == 512 + assert config.class_names[0] == "car" + assert config.preprocessing_config["resize"] == (224, 224) + + def test_from_dict(self): + """Test creating config from dictionary.""" + config_dict = { + "model_id": "detection_model", + "model_path": "/path/to/model.pt", + "model_type": "detection", + "device": "cuda:0", + "confidence_threshold": 0.75, + "class_names": {0: "person", 1: "vehicle"}, + "unknown_field": "ignored" + } + + config = ModelConfig.from_dict(config_dict) + + assert config.model_id == "detection_model" + assert config.confidence_threshold == 0.75 + assert config.class_names[1] == "vehicle" + + def test_validation(self): + """Test config validation.""" + # Valid config + valid_config = ModelConfig( + model_id="test_model", + model_path="/valid/path/model.pt", + model_type="detection", + device="cpu" + ) + assert valid_config.is_valid() is True + + # Invalid config (empty model_id) + invalid_config = ModelConfig( + model_id="", + model_path="/path/model.pt", + model_type="detection", + device="cpu" + ) + assert invalid_config.is_valid() is False + + def test_get_memory_limit_bytes(self): + """Test getting memory limit in bytes.""" + config = ModelConfig( + model_id="test", + model_path="/path", + model_type="detection", + device="cpu", + max_memory_mb=256 + ) + + assert config.get_memory_limit_bytes() == 256 * 1024 * 1024 + + +class TestModelInfo: + """Test model information.""" + + def test_creation(self): + """Test model info creation.""" + config = ModelConfig( + model_id="test_model", + model_path="/path/model.pt", + model_type="detection", + device="cuda:0" + ) + + mock_model = Mock() + + info = ModelInfo( + config=config, + model_instance=mock_model, + load_time=1.5 + ) + + assert info.config == config + assert info.model_instance == mock_model + assert info.load_time == 1.5 + assert info.reference_count == 0 + assert info.last_used <= time.time() + assert info.memory_usage == 0 + + def test_increment_reference(self): + """Test incrementing reference count.""" + config = ModelConfig("test", "/path", "detection", "cpu") + info = ModelInfo(config, Mock(), 1.0) + + assert info.reference_count == 0 + + info.increment_reference() + assert info.reference_count == 1 + + info.increment_reference() + assert info.reference_count == 2 + + def test_decrement_reference(self): + """Test decrementing reference count.""" + config = ModelConfig("test", "/path", "detection", "cpu") + info = ModelInfo(config, Mock(), 1.0) + info.reference_count = 3 + + assert info.decrement_reference() == 2 + assert info.reference_count == 2 + + assert info.decrement_reference() == 1 + assert info.decrement_reference() == 0 + + # Should not go below 0 + assert info.decrement_reference() == 0 + + def test_update_usage(self): + """Test updating usage statistics.""" + config = ModelConfig("test", "/path", "detection", "cpu") + info = ModelInfo(config, Mock(), 1.0) + + original_time = info.last_used + original_count = info.usage_count + + time.sleep(0.01) # Small delay + info.update_usage(memory_usage=512*1024*1024) # 512MB + + assert info.last_used > original_time + assert info.usage_count == original_count + 1 + assert info.memory_usage == 512*1024*1024 + + def test_age_calculation(self): + """Test age calculation.""" + config = ModelConfig("test", "/path", "detection", "cpu") + info = ModelInfo(config, Mock(), 1.0) + + time.sleep(0.01) + age = info.age() + + assert age > 0 + assert age < 1 # Should be less than 1 second + + def test_get_stats(self): + """Test getting model statistics.""" + config = ModelConfig("test_model", "/path", "detection", "cuda:0") + info = ModelInfo(config, Mock(), 2.5) + + info.reference_count = 3 + info.usage_count = 100 + info.memory_usage = 1024*1024*1024 # 1GB + + stats = info.get_stats() + + assert stats["model_id"] == "test_model" + assert stats["device"] == "cuda:0" + assert stats["load_time"] == 2.5 + assert stats["reference_count"] == 3 + assert stats["usage_count"] == 100 + assert stats["memory_usage_mb"] == 1024 + assert "age_seconds" in stats + + +class TestModelLoader: + """Test model loading functionality.""" + + def test_creation(self): + """Test model loader creation.""" + loader = ModelLoader() + + assert loader.supported_formats == [".pt", ".pth", ".onnx", ".trt"] + assert loader.default_device == "cpu" + + def test_detect_device_cuda_available(self): + """Test device detection when CUDA is available.""" + loader = ModelLoader() + + with patch('torch.cuda.is_available', return_value=True): + with patch('torch.cuda.device_count', return_value=2): + device = loader.detect_optimal_device() + + assert device == "cuda:0" + + def test_detect_device_cuda_unavailable(self): + """Test device detection when CUDA is not available.""" + loader = ModelLoader() + + with patch('torch.cuda.is_available', return_value=False): + device = loader.detect_optimal_device() + + assert device == "cpu" + + def test_load_pytorch_model(self): + """Test loading PyTorch model.""" + loader = ModelLoader() + + with patch('torch.load') as mock_torch_load: + with patch('os.path.exists', return_value=True): + mock_model = Mock() + mock_torch_load.return_value = mock_model + + config = ModelConfig( + model_id="test_model", + model_path="/path/to/model.pt", + model_type="detection", + device="cpu" + ) + + loaded_model = loader.load_model(config) + + assert loaded_model == mock_model + mock_torch_load.assert_called_once_with("/path/to/model.pt", map_location="cpu") + + def test_load_model_file_not_exists(self): + """Test loading model when file doesn't exist.""" + loader = ModelLoader() + + with patch('os.path.exists', return_value=False): + config = ModelConfig( + model_id="missing_model", + model_path="/nonexistent/model.pt", + model_type="detection", + device="cpu" + ) + + with pytest.raises(ModelLoadError) as exc_info: + loader.load_model(config) + + assert "does not exist" in str(exc_info.value) + + def test_load_model_invalid_format(self): + """Test loading model with invalid format.""" + loader = ModelLoader() + + with patch('os.path.exists', return_value=True): + config = ModelConfig( + model_id="invalid_model", + model_path="/path/to/model.invalid", + model_type="detection", + device="cpu" + ) + + with pytest.raises(ModelLoadError) as exc_info: + loader.load_model(config) + + assert "unsupported format" in str(exc_info.value).lower() + + def test_load_model_torch_error(self): + """Test loading model with torch loading error.""" + loader = ModelLoader() + + with patch('os.path.exists', return_value=True): + with patch('torch.load', side_effect=RuntimeError("CUDA out of memory")): + config = ModelConfig( + model_id="error_model", + model_path="/path/to/model.pt", + model_type="detection", + device="cuda:0" + ) + + with pytest.raises(ModelLoadError) as exc_info: + loader.load_model(config) + + assert "CUDA out of memory" in str(exc_info.value) + + def test_validate_model_pytorch(self): + """Test validating PyTorch model.""" + loader = ModelLoader() + + mock_model = Mock() + mock_model.__class__.__module__ = "torch.nn" + + config = ModelConfig("test", "/path", "detection", "cpu") + + is_valid = loader.validate_model(mock_model, config) + + assert is_valid is True + + def test_validate_model_invalid(self): + """Test validating invalid model.""" + loader = ModelLoader() + + invalid_model = "not_a_model" + config = ModelConfig("test", "/path", "detection", "cpu") + + is_valid = loader.validate_model(invalid_model, config) + + assert is_valid is False + + def test_estimate_model_memory(self): + """Test estimating model memory usage.""" + loader = ModelLoader() + + mock_model = Mock() + mock_param1 = Mock() + mock_param1.numel.return_value = 1000000 # 1M parameters + mock_param1.element_size.return_value = 4 # 4 bytes per parameter + + mock_param2 = Mock() + mock_param2.numel.return_value = 500000 # 0.5M parameters + mock_param2.element_size.return_value = 4 + + mock_model.parameters.return_value = [mock_param1, mock_param2] + + memory_bytes = loader.estimate_memory_usage(mock_model) + + expected_bytes = (1000000 + 500000) * 4 # 6MB + assert memory_bytes == expected_bytes + + +class TestModelCache: + """Test model caching functionality.""" + + def test_creation(self): + """Test model cache creation.""" + cache = ModelCache(max_size=5, max_memory_mb=2048) + + assert cache.max_size == 5 + assert cache.max_memory_mb == 2048 + assert len(cache.models) == 0 + assert len(cache.access_order) == 0 + + def test_put_and_get_model(self): + """Test putting and getting model from cache.""" + cache = ModelCache(max_size=3) + + config = ModelConfig("test_model", "/path", "detection", "cpu") + mock_model = Mock() + model_info = ModelInfo(config, mock_model, 1.5) + + cache.put("test_model", model_info) + + retrieved_info = cache.get("test_model") + + assert retrieved_info == model_info + assert retrieved_info.reference_count == 1 # Should be incremented on get + + def test_get_nonexistent_model(self): + """Test getting non-existent model.""" + cache = ModelCache(max_size=3) + + result = cache.get("nonexistent_model") + + assert result is None + + def test_contains_check(self): + """Test checking if model exists in cache.""" + cache = ModelCache(max_size=3) + + config = ModelConfig("test_model", "/path", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + cache.put("test_model", model_info) + + assert cache.contains("test_model") is True + assert cache.contains("nonexistent_model") is False + + def test_remove_model(self): + """Test removing model from cache.""" + cache = ModelCache(max_size=3) + + config = ModelConfig("test_model", "/path", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + cache.put("test_model", model_info) + + assert cache.contains("test_model") is True + + removed_info = cache.remove("test_model") + + assert removed_info == model_info + assert cache.contains("test_model") is False + + def test_lru_eviction(self): + """Test LRU eviction policy.""" + cache = ModelCache(max_size=2) + + # Add models to fill cache + for i in range(2): + config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + cache.put(f"model_{i}", model_info) + + # Access model_0 to make it recently used + cache.get("model_0") + + # Add another model (should evict model_1, the least recently used) + config = ModelConfig("model_2", "/path_2", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + cache.put("model_2", model_info) + + assert cache.size() == 2 + assert cache.contains("model_0") is True # Recently accessed + assert cache.contains("model_1") is False # Evicted + assert cache.contains("model_2") is True # Newly added + + def test_memory_based_eviction(self): + """Test memory-based eviction.""" + cache = ModelCache(max_size=10, max_memory_mb=1) # 1MB limit + + # Add model that uses 0.8MB + config1 = ModelConfig("model_1", "/path_1", "detection", "cpu") + model1 = Mock() + info1 = ModelInfo(config1, model1, 1.0) + info1.memory_usage = 0.8 * 1024 * 1024 # 0.8MB + cache.put("model_1", info1) + + # Add model that would exceed memory limit + config2 = ModelConfig("model_2", "/path_2", "detection", "cpu") + model2 = Mock() + info2 = ModelInfo(config2, model2, 1.0) + info2.memory_usage = 0.5 * 1024 * 1024 # 0.5MB + cache.put("model_2", info2) + + # First model should be evicted due to memory constraint + assert cache.contains("model_1") is False + assert cache.contains("model_2") is True + + def test_get_stats(self): + """Test getting cache statistics.""" + cache = ModelCache(max_size=5) + + # Add some models + for i in range(3): + config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + model_info.memory_usage = 100 * 1024 * 1024 # 100MB each + cache.put(f"model_{i}", model_info) + + # Access some models + cache.get("model_0") + cache.get("model_1") + cache.get("nonexistent") # Miss + + stats = cache.get_stats() + + assert stats["size"] == 3 + assert stats["max_size"] == 5 + assert stats["hits"] == 2 + assert stats["misses"] == 1 + assert stats["hit_rate"] == 2/3 + assert stats["memory_usage_mb"] == 300 + + def test_clear_cache(self): + """Test clearing entire cache.""" + cache = ModelCache(max_size=5) + + # Add models + for i in range(3): + config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + cache.put(f"model_{i}", model_info) + + assert cache.size() == 3 + + cache.clear() + + assert cache.size() == 0 + assert len(cache.models) == 0 + assert len(cache.access_order) == 0 + + +class TestModelManager: + """Test main model manager functionality.""" + + def test_initialization(self): + """Test model manager initialization.""" + manager = ModelManager() + + assert isinstance(manager.cache, ModelCache) + assert isinstance(manager.loader, ModelLoader) + assert manager.models_directory == "models" + assert manager.default_device == "cpu" + + def test_initialization_with_config(self): + """Test initialization with custom configuration.""" + config = { + "models_directory": "/custom/models", + "default_device": "cuda:0", + "cache_max_size": 20, + "cache_max_memory_mb": 4096 + } + + manager = ModelManager(config) + + assert manager.models_directory == "/custom/models" + assert manager.default_device == "cuda:0" + assert manager.cache.max_size == 20 + assert manager.cache.max_memory_mb == 4096 + + def test_load_model_new(self): + """Test loading new model.""" + manager = ModelManager() + + config = ModelConfig( + model_id="test_model", + model_path="/path/to/model.pt", + model_type="detection", + device="cpu" + ) + + with patch.object(manager.loader, 'load_model') as mock_load: + with patch.object(manager.loader, 'estimate_memory_usage', return_value=512*1024*1024): + mock_model = Mock() + mock_load.return_value = mock_model + + loaded_model = manager.load_model(config) + + assert loaded_model == mock_model + assert manager.cache.contains("test_model") is True + mock_load.assert_called_once_with(config) + + def test_load_model_from_cache(self): + """Test loading model from cache.""" + manager = ModelManager() + + # Pre-populate cache + config = ModelConfig("cached_model", "/path", "detection", "cpu") + mock_model = Mock() + model_info = ModelInfo(config, mock_model, 1.0) + manager.cache.put("cached_model", model_info) + + with patch.object(manager.loader, 'load_model') as mock_load: + loaded_model = manager.load_model(config) + + assert loaded_model == mock_model + mock_load.assert_not_called() # Should not load from disk + + def test_get_model_by_id(self): + """Test getting model by ID.""" + manager = ModelManager() + + config = ModelConfig("test_model", "/path", "detection", "cpu") + mock_model = Mock() + model_info = ModelInfo(config, mock_model, 1.0) + manager.cache.put("test_model", model_info) + + retrieved_model = manager.get_model("test_model") + + assert retrieved_model == mock_model + + def test_get_nonexistent_model(self): + """Test getting non-existent model.""" + manager = ModelManager() + + model = manager.get_model("nonexistent_model") + + assert model is None + + def test_unload_model_with_references(self): + """Test unloading model with active references.""" + manager = ModelManager() + + config = ModelConfig("ref_model", "/path", "detection", "cpu") + mock_model = Mock() + model_info = ModelInfo(config, mock_model, 1.0) + model_info.reference_count = 2 # Active references + manager.cache.put("ref_model", model_info) + + result = manager.unload_model("ref_model") + + assert result is False # Should not unload with active references + assert manager.cache.contains("ref_model") is True + + def test_unload_model_no_references(self): + """Test unloading model without references.""" + manager = ModelManager() + + config = ModelConfig("no_ref_model", "/path", "detection", "cpu") + mock_model = Mock() + model_info = ModelInfo(config, mock_model, 1.0) + model_info.reference_count = 0 # No references + manager.cache.put("no_ref_model", model_info) + + result = manager.unload_model("no_ref_model") + + assert result is True + assert manager.cache.contains("no_ref_model") is False + + def test_list_loaded_models(self): + """Test listing loaded models.""" + manager = ModelManager() + + # Add models to cache + for i in range(3): + config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + manager.cache.put(f"model_{i}", model_info) + + loaded_models = manager.list_loaded_models() + + assert len(loaded_models) == 3 + assert all(info["model_id"].startswith("model_") for info in loaded_models) + + def test_get_model_info(self): + """Test getting model information.""" + manager = ModelManager() + + config = ModelConfig("info_model", "/path", "detection", "cuda:0") + mock_model = Mock() + model_info = ModelInfo(config, mock_model, 2.5) + model_info.usage_count = 10 + manager.cache.put("info_model", model_info) + + info = manager.get_model_info("info_model") + + assert info is not None + assert info["model_id"] == "info_model" + assert info["device"] == "cuda:0" + assert info["load_time"] == 2.5 + assert info["usage_count"] == 10 + + def test_cleanup_unused_models(self): + """Test cleaning up unused models.""" + manager = ModelManager() + + # Add models with different reference counts + models_data = [ + ("used_model", 2), # Has references + ("unused_model_1", 0), # No references + ("unused_model_2", 0) # No references + ] + + for model_id, ref_count in models_data: + config = ModelConfig(model_id, f"/path/{model_id}", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + model_info.reference_count = ref_count + manager.cache.put(model_id, model_info) + + cleaned_count = manager.cleanup_unused_models() + + assert cleaned_count == 2 # Two unused models cleaned + assert manager.cache.contains("used_model") is True + assert manager.cache.contains("unused_model_1") is False + assert manager.cache.contains("unused_model_2") is False + + def test_get_memory_usage(self): + """Test getting total memory usage.""" + manager = ModelManager() + + # Add models with different memory usage + memory_sizes = [256, 512, 1024] # MB + + for i, memory_mb in enumerate(memory_sizes): + config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + model_info.memory_usage = memory_mb * 1024 * 1024 # Convert to bytes + manager.cache.put(f"model_{i}", model_info) + + total_usage = manager.get_memory_usage() + + expected_bytes = sum(memory_sizes) * 1024 * 1024 + assert total_usage == expected_bytes + + def test_health_check(self): + """Test model manager health check.""" + manager = ModelManager() + + # Add models + for i in range(3): + config = ModelConfig(f"model_{i}", f"/path_{i}", "detection", "cpu") + model_info = ModelInfo(config, Mock(), 1.0) + model_info.memory_usage = 100 * 1024 * 1024 # 100MB each + manager.cache.put(f"model_{i}", model_info) + + health_report = manager.health_check() + + assert health_report["status"] == "healthy" + assert health_report["loaded_models"] == 3 + assert health_report["total_memory_mb"] == 300 + assert health_report["cache_hit_rate"] >= 0 + + +class TestModelManagerIntegration: + """Integration tests for model manager.""" + + def test_concurrent_model_loading(self): + """Test concurrent model loading.""" + manager = ModelManager() + + # Mock loader to simulate loading time + def slow_load(config): + time.sleep(0.1) # Simulate loading time + mock_model = Mock() + mock_model.model_id = config.model_id + return mock_model + + with patch.object(manager.loader, 'load_model', side_effect=slow_load): + with patch.object(manager.loader, 'estimate_memory_usage', return_value=100*1024*1024): + + # Create multiple threads loading different models + results = {} + errors = [] + + def load_model_thread(model_id): + try: + config = ModelConfig( + model_id=model_id, + model_path=f"/path/{model_id}.pt", + model_type="detection", + device="cpu" + ) + model = manager.load_model(config) + results[model_id] = model + except Exception as e: + errors.append((model_id, str(e))) + + threads = [] + for i in range(5): + thread = threading.Thread(target=load_model_thread, args=(f"model_{i}",)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # All models should be loaded successfully + assert len(errors) == 0 + assert len(results) == 5 + assert len(manager.cache.models) == 5 + + def test_memory_pressure_handling(self): + """Test handling memory pressure.""" + # Create manager with small memory limit + manager = ModelManager({ + "cache_max_memory_mb": 200 # 200MB limit + }) + + with patch.object(manager.loader, 'load_model') as mock_load: + with patch.object(manager.loader, 'estimate_memory_usage', return_value=100*1024*1024): # 100MB per model + + def create_mock_model(config): + mock_model = Mock() + mock_model.model_id = config.model_id + return mock_model + + mock_load.side_effect = create_mock_model + + # Load models that exceed memory limit + for i in range(4): # 4 * 100MB = 400MB > 200MB limit + config = ModelConfig( + model_id=f"large_model_{i}", + model_path=f"/path/large_model_{i}.pt", + model_type="detection", + device="cpu" + ) + manager.load_model(config) + + # Should not exceed memory limit due to eviction + total_memory = manager.get_memory_usage() + memory_limit = 200 * 1024 * 1024 + assert total_memory <= memory_limit + + def test_model_lifecycle_management(self): + """Test complete model lifecycle.""" + manager = ModelManager() + + with patch.object(manager.loader, 'load_model') as mock_load: + with patch.object(manager.loader, 'estimate_memory_usage', return_value=50*1024*1024): + + mock_model = Mock() + mock_load.return_value = mock_model + + config = ModelConfig( + model_id="lifecycle_model", + model_path="/path/lifecycle_model.pt", + model_type="detection", + device="cpu" + ) + + # 1. Load model + loaded_model = manager.load_model(config) + assert loaded_model == mock_model + assert manager.cache.contains("lifecycle_model") is True + + # 2. Get model multiple times (increase usage) + for _ in range(5): + model = manager.get_model("lifecycle_model") + assert model == mock_model + + # 3. Check model info + info = manager.get_model_info("lifecycle_model") + assert info["usage_count"] >= 5 + + # 4. Simulate model still in use + model_info = manager.cache.get("lifecycle_model") + model_info.reference_count = 1 + + # Should not unload while in use + unloaded = manager.unload_model("lifecycle_model") + assert unloaded is False + assert manager.cache.contains("lifecycle_model") is True + + # 5. Release reference and unload + model_info.reference_count = 0 + unloaded = manager.unload_model("lifecycle_model") + assert unloaded is True + assert manager.cache.contains("lifecycle_model") is False + + def test_error_recovery(self): + """Test error recovery scenarios.""" + manager = ModelManager() + + # Test loading model that fails initially then succeeds + call_count = 0 + def failing_then_success_load(config): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ModelLoadError("First attempt failed") + return Mock() + + with patch.object(manager.loader, 'load_model', side_effect=failing_then_success_load): + with patch.object(manager.loader, 'estimate_memory_usage', return_value=50*1024*1024): + + config = ModelConfig( + model_id="retry_model", + model_path="/path/retry_model.pt", + model_type="detection", + device="cpu" + ) + + # First attempt should fail + with pytest.raises(ModelLoadError): + manager.load_model(config) + + # Model should not be in cache + assert manager.cache.contains("retry_model") is False + + # Second attempt should succeed + model = manager.load_model(config) + assert model is not None + assert manager.cache.contains("retry_model") is True \ No newline at end of file diff --git a/tests/unit/pipeline/test_action_executor.py b/tests/unit/pipeline/test_action_executor.py new file mode 100644 index 0000000..08ae60c --- /dev/null +++ b/tests/unit/pipeline/test_action_executor.py @@ -0,0 +1,959 @@ +""" +Unit tests for action execution functionality. +""" +import pytest +import asyncio +import json +import base64 +import numpy as np +from unittest.mock import Mock, MagicMock, patch, AsyncMock +from datetime import datetime, timedelta + +from detector_worker.pipeline.action_executor import ( + ActionExecutor, + ActionResult, + ActionType, + RedisAction, + PostgreSQLAction, + FileAction +) +from detector_worker.detection.detection_result import DetectionResult, BoundingBox +from detector_worker.core.exceptions import ActionError, RedisError, DatabaseError + + +class TestActionResult: + """Test action execution result.""" + + def test_creation_success(self): + """Test successful action result creation.""" + result = ActionResult( + action_type=ActionType.REDIS_SAVE, + success=True, + execution_time=0.05, + metadata={"key": "saved_image_key", "expiry": 600} + ) + + assert result.action_type == ActionType.REDIS_SAVE + assert result.success is True + assert result.execution_time == 0.05 + assert result.metadata["key"] == "saved_image_key" + assert result.error is None + + def test_creation_failure(self): + """Test failed action result creation.""" + result = ActionResult( + action_type=ActionType.POSTGRESQL_INSERT, + success=False, + error="Database connection failed", + execution_time=0.02 + ) + + assert result.action_type == ActionType.POSTGRESQL_INSERT + assert result.success is False + assert result.error == "Database connection failed" + assert result.metadata == {} + + +class TestRedisAction: + """Test Redis action implementations.""" + + def test_creation(self): + """Test Redis action creation.""" + action_config = { + "type": "redis_save_image", + "region": "car", + "key": "inference:{display_id}:{timestamp}:{session_id}", + "expire_seconds": 600 + } + + action = RedisAction(action_config) + + assert action.action_type == ActionType.REDIS_SAVE + assert action.region == "car" + assert action.key_template == "inference:{display_id}:{timestamp}:{session_id}" + assert action.expire_seconds == 600 + + def test_resolve_key_template(self): + """Test key template resolution.""" + action_config = { + "type": "redis_save_image", + "region": "car", + "key": "inference:{display_id}:{timestamp}:{session_id}:{filename}", + "expire_seconds": 600 + } + + action = RedisAction(action_config) + + context = { + "display_id": "display_001", + "timestamp": "1640995200000", + "session_id": "session_123", + "filename": "detection.jpg" + } + + resolved_key = action.resolve_key(context) + expected_key = "inference:display_001:1640995200000:session_123:detection.jpg" + + assert resolved_key == expected_key + + def test_resolve_key_missing_variable(self): + """Test key resolution with missing variable.""" + action_config = { + "type": "redis_save_image", + "region": "car", + "key": "inference:{display_id}:{missing_var}", + "expire_seconds": 600 + } + + action = RedisAction(action_config) + + context = {"display_id": "display_001"} + + with pytest.raises(ActionError): + action.resolve_key(context) + + +class TestPostgreSQLAction: + """Test PostgreSQL action implementations.""" + + def test_creation_insert(self): + """Test PostgreSQL insert action creation.""" + action_config = { + "type": "postgresql_insert", + "table": "detections", + "fields": { + "camera_id": "{camera_id}", + "session_id": "{session_id}", + "detection_class": "{class}", + "confidence": "{confidence}", + "bbox_x1": "{bbox.x1}", + "created_at": "NOW()" + } + } + + action = PostgreSQLAction(action_config) + + assert action.action_type == ActionType.POSTGRESQL_INSERT + assert action.table == "detections" + assert len(action.fields) == 6 + assert action.key_field is None + + def test_creation_update(self): + """Test PostgreSQL update action creation.""" + action_config = { + "type": "postgresql_update_combined", + "table": "car_info", + "key_field": "session_id", + "fields": { + "car_brand": "{car_brand_cls.brand}", + "car_body_type": "{car_bodytype_cls.body_type}", + "updated_at": "NOW()" + }, + "waitForBranches": ["car_brand_cls", "car_bodytype_cls"] + } + + action = PostgreSQLAction(action_config) + + assert action.action_type == ActionType.POSTGRESQL_UPDATE + assert action.table == "car_info" + assert action.key_field == "session_id" + assert action.wait_for_branches == ["car_brand_cls", "car_bodytype_cls"] + + def test_resolve_field_values(self): + """Test field value resolution.""" + action_config = { + "type": "postgresql_insert", + "table": "detections", + "fields": { + "camera_id": "{camera_id}", + "detection_class": "{class}", + "confidence": "{confidence}", + "brand": "{car_brand_cls.brand}" + } + } + + action = PostgreSQLAction(action_config) + + context = { + "camera_id": "camera_001", + "class": "car", + "confidence": 0.85 + } + + branch_results = { + "car_brand_cls": {"brand": "Toyota", "confidence": 0.78} + } + + resolved_fields = action.resolve_field_values(context, branch_results) + + assert resolved_fields["camera_id"] == "camera_001" + assert resolved_fields["detection_class"] == "car" + assert resolved_fields["confidence"] == 0.85 + assert resolved_fields["brand"] == "Toyota" + + +class TestFileAction: + """Test file action implementations.""" + + def test_creation(self): + """Test file action creation.""" + action_config = { + "type": "save_image", + "path": "/tmp/detections/{camera_id}_{timestamp}.jpg", + "region": "car", + "format": "jpeg", + "quality": 85 + } + + action = FileAction(action_config) + + assert action.action_type == ActionType.SAVE_IMAGE + assert action.path_template == "/tmp/detections/{camera_id}_{timestamp}.jpg" + assert action.region == "car" + assert action.format == "jpeg" + assert action.quality == 85 + + def test_resolve_path_template(self): + """Test path template resolution.""" + action_config = { + "type": "save_image", + "path": "/tmp/detections/{camera_id}/{date}/{timestamp}.jpg" + } + + action = FileAction(action_config) + + context = { + "camera_id": "camera_001", + "timestamp": "1640995200000", + "date": "2022-01-01" + } + + resolved_path = action.resolve_path(context) + expected_path = "/tmp/detections/camera_001/2022-01-01/1640995200000.jpg" + + assert resolved_path == expected_path + + +class TestActionExecutor: + """Test action execution functionality.""" + + def test_initialization(self): + """Test action executor initialization.""" + executor = ActionExecutor() + + assert executor.redis_client is None + assert executor.db_manager is None + assert executor.max_concurrent_actions == 10 + assert executor.action_timeout == 30.0 + + def test_initialization_with_clients(self, mock_redis_client, mock_database_connection): + """Test initialization with client instances.""" + executor = ActionExecutor( + redis_client=mock_redis_client, + db_manager=mock_database_connection + ) + + assert executor.redis_client is mock_redis_client + assert executor.db_manager is mock_database_connection + + @pytest.mark.asyncio + async def test_execute_actions_empty_list(self): + """Test executing empty action list.""" + executor = ActionExecutor() + + context = { + "camera_id": "camera_001", + "session_id": "session_123" + } + + results = await executor.execute_actions([], {}, context) + + assert results == [] + + @pytest.mark.asyncio + async def test_execute_redis_save_action(self, mock_redis_client, mock_frame): + """Test executing Redis save image action.""" + executor = ActionExecutor(redis_client=mock_redis_client) + + actions = [ + { + "type": "redis_save_image", + "region": "car", + "key": "inference:{camera_id}:{session_id}", + "expire_seconds": 600 + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = { + "camera_id": "camera_001", + "session_id": "session_123", + "frame_data": mock_frame + } + + # Mock successful Redis operations + mock_redis_client.set.return_value = True + mock_redis_client.expire.return_value = True + + results = await executor.execute_actions(actions, regions, context) + + assert len(results) == 1 + assert results[0].success is True + assert results[0].action_type == ActionType.REDIS_SAVE + + # Verify Redis calls + mock_redis_client.set.assert_called_once() + mock_redis_client.expire.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_postgresql_insert_action(self, mock_database_connection): + """Test executing PostgreSQL insert action.""" + # Mock database manager + mock_db_manager = Mock() + mock_db_manager.execute_query = AsyncMock(return_value=True) + + executor = ActionExecutor(db_manager=mock_db_manager) + + actions = [ + { + "type": "postgresql_insert", + "table": "detections", + "fields": { + "camera_id": "{camera_id}", + "session_id": "{session_id}", + "detection_class": "{class}", + "confidence": "{confidence}", + "created_at": "NOW()" + } + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = { + "camera_id": "camera_001", + "session_id": "session_123", + "class": "car", + "confidence": 0.9 + } + + results = await executor.execute_actions(actions, regions, context) + + assert len(results) == 1 + assert results[0].success is True + assert results[0].action_type == ActionType.POSTGRESQL_INSERT + + # Verify database call + mock_db_manager.execute_query.assert_called_once() + call_args = mock_db_manager.execute_query.call_args[0] + assert "INSERT INTO detections" in call_args[0] + + @pytest.mark.asyncio + async def test_execute_postgresql_update_action(self, mock_database_connection): + """Test executing PostgreSQL update action.""" + mock_db_manager = Mock() + mock_db_manager.execute_query = AsyncMock(return_value=True) + + executor = ActionExecutor(db_manager=mock_db_manager) + + actions = [ + { + "type": "postgresql_update_combined", + "table": "car_info", + "key_field": "session_id", + "fields": { + "car_brand": "{car_brand_cls.brand}", + "car_body_type": "{car_bodytype_cls.body_type}", + "updated_at": "NOW()" + }, + "waitForBranches": ["car_brand_cls", "car_bodytype_cls"] + } + ] + + regions = {} + + context = { + "session_id": "session_123" + } + + branch_results = { + "car_brand_cls": {"brand": "Toyota"}, + "car_bodytype_cls": {"body_type": "Sedan"} + } + + results = await executor.execute_actions(actions, regions, context, branch_results) + + assert len(results) == 1 + assert results[0].success is True + assert results[0].action_type == ActionType.POSTGRESQL_UPDATE + + # Verify database call + mock_db_manager.execute_query.assert_called_once() + call_args = mock_db_manager.execute_query.call_args[0] + assert "UPDATE car_info SET" in call_args[0] + assert "WHERE session_id" in call_args[0] + + @pytest.mark.asyncio + async def test_execute_file_save_action(self, mock_frame): + """Test executing file save action.""" + executor = ActionExecutor() + + actions = [ + { + "type": "save_image", + "path": "/tmp/test_{camera_id}_{timestamp}.jpg", + "region": "car", + "format": "jpeg", + "quality": 85 + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = { + "camera_id": "camera_001", + "timestamp": "1640995200000", + "frame_data": mock_frame + } + + with patch('cv2.imwrite') as mock_imwrite: + mock_imwrite.return_value = True + + results = await executor.execute_actions(actions, regions, context) + + assert len(results) == 1 + assert results[0].success is True + assert results[0].action_type == ActionType.SAVE_IMAGE + + # Verify file save call + mock_imwrite.assert_called_once() + call_args = mock_imwrite.call_args + assert "/tmp/test_camera_001_1640995200000.jpg" in call_args[0][0] + + @pytest.mark.asyncio + async def test_execute_actions_parallel(self, mock_redis_client): + """Test parallel execution of multiple actions.""" + executor = ActionExecutor(redis_client=mock_redis_client) + + # Multiple Redis actions + actions = [ + { + "type": "redis_save_image", + "region": "car", + "key": "inference:car:{session_id}", + "expire_seconds": 600 + }, + { + "type": "redis_publish", + "channel": "detections", + "message": "{camera_id}:car_detected" + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = { + "camera_id": "camera_001", + "session_id": "session_123", + "frame_data": np.zeros((480, 640, 3), dtype=np.uint8) + } + + # Mock Redis operations + mock_redis_client.set.return_value = True + mock_redis_client.expire.return_value = True + mock_redis_client.publish.return_value = 1 + + import time + start_time = time.time() + + results = await executor.execute_actions(actions, regions, context) + + execution_time = time.time() - start_time + + assert len(results) == 2 + assert all(result.success for result in results) + + # Should execute in parallel (faster than sequential) + assert execution_time < 0.1 # Allow some overhead + + @pytest.mark.asyncio + async def test_execute_actions_error_handling(self, mock_redis_client): + """Test error handling in action execution.""" + executor = ActionExecutor(redis_client=mock_redis_client) + + actions = [ + { + "type": "redis_save_image", + "region": "car", + "key": "inference:{session_id}", + "expire_seconds": 600 + }, + { + "type": "redis_save_image", # This one will fail + "region": "truck", # Region not detected + "key": "inference:truck:{session_id}", + "expire_seconds": 600 + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + # No truck region + } + + context = { + "session_id": "session_123", + "frame_data": np.zeros((480, 640, 3), dtype=np.uint8) + } + + # Mock Redis operations + mock_redis_client.set.return_value = True + mock_redis_client.expire.return_value = True + + results = await executor.execute_actions(actions, regions, context) + + assert len(results) == 2 + assert results[0].success is True # Car action succeeds + assert results[1].success is False # Truck action fails + assert "Region 'truck' not found" in results[1].error + + @pytest.mark.asyncio + async def test_execute_actions_timeout(self, mock_redis_client): + """Test action execution timeout.""" + config = {"action_timeout": 0.001} # Very short timeout + executor = ActionExecutor(redis_client=mock_redis_client, config=config) + + def slow_redis_operation(*args, **kwargs): + import time + time.sleep(1) # Longer than timeout + return True + + mock_redis_client.set.side_effect = slow_redis_operation + + actions = [ + { + "type": "redis_save_image", + "region": "car", + "key": "inference:{session_id}", + "expire_seconds": 600 + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = { + "session_id": "session_123", + "frame_data": np.zeros((480, 640, 3), dtype=np.uint8) + } + + results = await executor.execute_actions(actions, regions, context) + + assert len(results) == 1 + assert results[0].success is False + assert "timeout" in results[0].error.lower() + + @pytest.mark.asyncio + async def test_execute_redis_publish_action(self, mock_redis_client): + """Test executing Redis publish action.""" + executor = ActionExecutor(redis_client=mock_redis_client) + + actions = [ + { + "type": "redis_publish", + "channel": "detections:{camera_id}", + "message": { + "camera_id": "{camera_id}", + "detection_class": "{class}", + "confidence": "{confidence}", + "timestamp": "{timestamp}" + } + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = { + "camera_id": "camera_001", + "class": "car", + "confidence": 0.9, + "timestamp": "1640995200000" + } + + mock_redis_client.publish.return_value = 1 + + results = await executor.execute_actions(actions, regions, context) + + assert len(results) == 1 + assert results[0].success is True + assert results[0].action_type == ActionType.REDIS_PUBLISH + + # Verify publish call + mock_redis_client.publish.assert_called_once() + call_args = mock_redis_client.publish.call_args + assert call_args[0][0] == "detections:camera_001" # Channel + + # Message should be JSON + message = call_args[0][1] + parsed_message = json.loads(message) + assert parsed_message["camera_id"] == "camera_001" + assert parsed_message["detection_class"] == "car" + + @pytest.mark.asyncio + async def test_execute_conditional_action(self): + """Test executing conditional actions.""" + executor = ActionExecutor() + + actions = [ + { + "type": "conditional", + "condition": "{confidence} > 0.8", + "actions": [ + { + "type": "log", + "message": "High confidence detection: {class} ({confidence})" + } + ] + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.95, # High confidence + "detection": DetectionResult("car", 0.95, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = { + "class": "car", + "confidence": 0.95 + } + + with patch('logging.info') as mock_log: + results = await executor.execute_actions(actions, regions, context) + + assert len(results) == 1 + assert results[0].success is True + + # Should have logged the message + mock_log.assert_called_once() + log_message = mock_log.call_args[0][0] + assert "High confidence detection: car (0.95)" in log_message + + def test_crop_region_from_frame(self, mock_frame): + """Test cropping region from frame.""" + executor = ActionExecutor() + + detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + + cropped = executor._crop_region_from_frame(mock_frame, detection.bbox) + + assert cropped.shape == (200, 200, 3) # 400-200, 300-100 + + def test_encode_image_base64(self, mock_frame): + """Test encoding image to base64.""" + executor = ActionExecutor() + + # Crop a small region + cropped_frame = mock_frame[200:400, 100:300] # 200x200 region + + with patch('cv2.imencode') as mock_imencode: + # Mock successful encoding + mock_imencode.return_value = (True, np.array([1, 2, 3, 4], dtype=np.uint8)) + + encoded = executor._encode_image_base64(cropped_frame, format="jpeg") + + # Should return base64 string + assert isinstance(encoded, str) + assert len(encoded) > 0 + + # Verify encoding call + mock_imencode.assert_called_once() + assert mock_imencode.call_args[0][0] == '.jpg' + + def test_build_insert_query(self): + """Test building INSERT SQL query.""" + executor = ActionExecutor() + + table = "detections" + fields = { + "camera_id": "camera_001", + "detection_class": "car", + "confidence": 0.9, + "created_at": "NOW()" + } + + query, values = executor._build_insert_query(table, fields) + + assert "INSERT INTO detections" in query + assert "camera_id, detection_class, confidence, created_at" in query + assert "VALUES (%s, %s, %s, NOW())" in query + assert values == ["camera_001", "car", 0.9] + + def test_build_update_query(self): + """Test building UPDATE SQL query.""" + executor = ActionExecutor() + + table = "car_info" + fields = { + "car_brand": "Toyota", + "car_body_type": "Sedan", + "updated_at": "NOW()" + } + key_field = "session_id" + key_value = "session_123" + + query, values = executor._build_update_query(table, fields, key_field, key_value) + + assert "UPDATE car_info SET" in query + assert "car_brand = %s" in query + assert "car_body_type = %s" in query + assert "updated_at = NOW()" in query + assert "WHERE session_id = %s" in query + assert values == ["Toyota", "Sedan", "session_123"] + + def test_evaluate_condition(self): + """Test evaluating conditional expressions.""" + executor = ActionExecutor() + + context = { + "confidence": 0.85, + "class": "car", + "area": 40000 + } + + # Simple comparisons + assert executor._evaluate_condition("{confidence} > 0.8", context) is True + assert executor._evaluate_condition("{confidence} < 0.8", context) is False + assert executor._evaluate_condition("{confidence} >= 0.85", context) is True + assert executor._evaluate_condition("{confidence} == 0.85", context) is True + + # String comparisons + assert executor._evaluate_condition("{class} == 'car'", context) is True + assert executor._evaluate_condition("{class} != 'truck'", context) is True + + # Complex conditions + assert executor._evaluate_condition("{confidence} > 0.8 and {area} > 30000", context) is True + assert executor._evaluate_condition("{confidence} > 0.9 or {area} > 30000", context) is True + assert executor._evaluate_condition("{confidence} > 0.9 and {area} < 30000", context) is False + + def test_validate_action_config(self): + """Test action configuration validation.""" + executor = ActionExecutor() + + # Valid Redis action + valid_redis = { + "type": "redis_save_image", + "region": "car", + "key": "inference:{session_id}", + "expire_seconds": 600 + } + assert executor._validate_action_config(valid_redis) is True + + # Invalid action (missing required fields) + invalid_action = { + "type": "redis_save_image" + # Missing region and key + } + with pytest.raises(ActionError): + executor._validate_action_config(invalid_action) + + # Unknown action type + unknown_action = { + "type": "unknown_action_type", + "some_field": "value" + } + with pytest.raises(ActionError): + executor._validate_action_config(unknown_action) + + +class TestActionExecutorIntegration: + """Integration tests for action execution.""" + + @pytest.mark.asyncio + async def test_complete_detection_workflow(self, mock_redis_client, mock_frame): + """Test complete detection workflow with multiple actions.""" + # Mock database manager + mock_db_manager = Mock() + mock_db_manager.execute_query = AsyncMock(return_value=True) + + executor = ActionExecutor( + redis_client=mock_redis_client, + db_manager=mock_db_manager + ) + + # Complete action workflow + actions = [ + # Save cropped image to Redis + { + "type": "redis_save_image", + "region": "car", + "key": "inference:{camera_id}:{timestamp}:{session_id}:car", + "expire_seconds": 600 + }, + # Insert initial detection record + { + "type": "postgresql_insert", + "table": "car_detections", + "fields": { + "camera_id": "{camera_id}", + "session_id": "{session_id}", + "detection_class": "{class}", + "confidence": "{confidence}", + "bbox_x1": "{bbox.x1}", + "bbox_y1": "{bbox.y1}", + "bbox_x2": "{bbox.x2}", + "bbox_y2": "{bbox.y2}", + "created_at": "NOW()" + } + }, + # Publish detection event + { + "type": "redis_publish", + "channel": "detections:{camera_id}", + "message": { + "event": "car_detected", + "camera_id": "{camera_id}", + "session_id": "{session_id}", + "timestamp": "{timestamp}" + } + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.92, + "detection": DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = { + "camera_id": "camera_001", + "session_id": "session_123", + "timestamp": "1640995200000", + "class": "car", + "confidence": 0.92, + "bbox": {"x1": 100, "y1": 200, "x2": 300, "y2": 400}, + "frame_data": mock_frame + } + + # Mock all Redis operations + mock_redis_client.set.return_value = True + mock_redis_client.expire.return_value = True + mock_redis_client.publish.return_value = 1 + + results = await executor.execute_actions(actions, regions, context) + + # All actions should succeed + assert len(results) == 3 + assert all(result.success for result in results) + + # Verify all operations were called + mock_redis_client.set.assert_called_once() # Image save + mock_redis_client.expire.assert_called_once() # Set expiry + mock_redis_client.publish.assert_called_once() # Publish event + mock_db_manager.execute_query.assert_called_once() # Database insert + + @pytest.mark.asyncio + async def test_branch_dependent_actions(self, mock_database_connection): + """Test actions that depend on branch results.""" + mock_db_manager = Mock() + mock_db_manager.execute_query = AsyncMock(return_value=True) + + executor = ActionExecutor(db_manager=mock_db_manager) + + # Action that waits for classification branches + actions = [ + { + "type": "postgresql_update_combined", + "table": "car_info", + "key_field": "session_id", + "fields": { + "car_brand": "{car_brand_cls.brand}", + "car_body_type": "{car_bodytype_cls.body_type}", + "car_color": "{car_color_cls.color}", + "confidence_brand": "{car_brand_cls.confidence}", + "confidence_bodytype": "{car_bodytype_cls.confidence}", + "updated_at": "NOW()" + }, + "waitForBranches": ["car_brand_cls", "car_bodytype_cls", "car_color_cls"] + } + ] + + regions = {} + + context = { + "session_id": "session_123" + } + + # Simulated branch results + branch_results = { + "car_brand_cls": {"brand": "Toyota", "confidence": 0.87}, + "car_bodytype_cls": {"body_type": "Sedan", "confidence": 0.82}, + "car_color_cls": {"color": "Red", "confidence": 0.79} + } + + results = await executor.execute_actions(actions, regions, context, branch_results) + + assert len(results) == 1 + assert results[0].success is True + assert results[0].action_type == ActionType.POSTGRESQL_UPDATE + + # Verify database call with all branch data + mock_db_manager.execute_query.assert_called_once() + call_args = mock_db_manager.execute_query.call_args + query = call_args[0][0] + values = call_args[0][1] + + assert "UPDATE car_info SET" in query + assert "car_brand = %s" in query + assert "car_body_type = %s" in query + assert "car_color = %s" in query + assert "WHERE session_id = %s" in query + + assert "Toyota" in values + assert "Sedan" in values + assert "Red" in values + assert "session_123" in values \ No newline at end of file diff --git a/tests/unit/pipeline/test_field_mapper.py b/tests/unit/pipeline/test_field_mapper.py new file mode 100644 index 0000000..4b61a89 --- /dev/null +++ b/tests/unit/pipeline/test_field_mapper.py @@ -0,0 +1,786 @@ +""" +Unit tests for field mapping and template resolution. +""" +import pytest +from unittest.mock import Mock, patch +from datetime import datetime +import json + +from detector_worker.pipeline.field_mapper import ( + FieldMapper, + MappingContext, + TemplateResolver, + FieldMappingError, + NestedFieldAccessor +) +from detector_worker.detection.detection_result import DetectionResult, BoundingBox + + +class TestNestedFieldAccessor: + """Test nested field access functionality.""" + + def test_get_nested_value_simple(self): + """Test getting simple nested values.""" + data = { + "user": { + "name": "John", + "age": 30, + "address": { + "city": "New York", + "zip": "10001" + } + } + } + + accessor = NestedFieldAccessor() + + assert accessor.get_nested_value(data, "user.name") == "John" + assert accessor.get_nested_value(data, "user.age") == 30 + assert accessor.get_nested_value(data, "user.address.city") == "New York" + assert accessor.get_nested_value(data, "user.address.zip") == "10001" + + def test_get_nested_value_array_access(self): + """Test accessing array elements.""" + data = { + "results": [ + {"score": 0.9, "label": "car"}, + {"score": 0.8, "label": "truck"} + ], + "bbox": [100, 200, 300, 400] + } + + accessor = NestedFieldAccessor() + + assert accessor.get_nested_value(data, "results[0].score") == 0.9 + assert accessor.get_nested_value(data, "results[0].label") == "car" + assert accessor.get_nested_value(data, "results[1].score") == 0.8 + assert accessor.get_nested_value(data, "bbox[0]") == 100 + assert accessor.get_nested_value(data, "bbox[3]") == 400 + + def test_get_nested_value_nonexistent_path(self): + """Test accessing non-existent paths.""" + data = {"user": {"name": "John"}} + accessor = NestedFieldAccessor() + + assert accessor.get_nested_value(data, "user.nonexistent") is None + assert accessor.get_nested_value(data, "nonexistent.field") is None + assert accessor.get_nested_value(data, "user.address.city") is None + + def test_get_nested_value_with_default(self): + """Test getting nested values with default fallback.""" + data = {"user": {"name": "John"}} + accessor = NestedFieldAccessor() + + assert accessor.get_nested_value(data, "user.age", default=25) == 25 + assert accessor.get_nested_value(data, "user.name", default="Unknown") == "John" + + def test_set_nested_value(self): + """Test setting nested values.""" + data = {} + accessor = NestedFieldAccessor() + + accessor.set_nested_value(data, "user.name", "John") + assert data["user"]["name"] == "John" + + accessor.set_nested_value(data, "user.address.city", "New York") + assert data["user"]["address"]["city"] == "New York" + + accessor.set_nested_value(data, "scores[0]", 0.95) + assert data["scores"][0] == 0.95 + + def test_set_nested_value_overwrite(self): + """Test overwriting existing nested values.""" + data = {"user": {"name": "John", "age": 30}} + accessor = NestedFieldAccessor() + + accessor.set_nested_value(data, "user.name", "Jane") + assert data["user"]["name"] == "Jane" + assert data["user"]["age"] == 30 # Should not affect other fields + + +class TestTemplateResolver: + """Test template string resolution.""" + + def test_resolve_simple_template(self): + """Test resolving simple template variables.""" + resolver = TemplateResolver() + + template = "Hello {name}, you are {age} years old" + context = {"name": "John", "age": 30} + + result = resolver.resolve(template, context) + assert result == "Hello John, you are 30 years old" + + def test_resolve_nested_template(self): + """Test resolving nested field templates.""" + resolver = TemplateResolver() + + template = "User: {user.name} from {user.address.city}" + context = { + "user": { + "name": "John", + "address": {"city": "New York", "zip": "10001"} + } + } + + result = resolver.resolve(template, context) + assert result == "User: John from New York" + + def test_resolve_array_template(self): + """Test resolving array element templates.""" + resolver = TemplateResolver() + + template = "First result: {results[0].label} ({results[0].score})" + context = { + "results": [ + {"label": "car", "score": 0.95}, + {"label": "truck", "score": 0.87} + ] + } + + result = resolver.resolve(template, context) + assert result == "First result: car (0.95)" + + def test_resolve_missing_variables(self): + """Test resolving templates with missing variables.""" + resolver = TemplateResolver() + + template = "Hello {name}, you are {age} years old" + context = {"name": "John"} # Missing age + + with pytest.raises(FieldMappingError) as exc_info: + resolver.resolve(template, context) + + assert "Variable 'age' not found" in str(exc_info.value) + + def test_resolve_with_defaults(self): + """Test resolving templates with default values.""" + resolver = TemplateResolver(allow_missing=True) + + template = "Hello {name}, you are {age|25} years old" + context = {"name": "John"} # Missing age, should use default + + result = resolver.resolve(template, context) + assert result == "Hello John, you are 25 years old" + + def test_resolve_complex_template(self): + """Test resolving complex templates with multiple variable types.""" + resolver = TemplateResolver() + + template = "{camera_id}:{timestamp}:{session_id}:{results[0].class}_{bbox[0]}_{bbox[1]}" + context = { + "camera_id": "cam001", + "timestamp": 1640995200000, + "session_id": "sess123", + "results": [{"class": "car", "confidence": 0.95}], + "bbox": [100, 200, 300, 400] + } + + result = resolver.resolve(template, context) + assert result == "cam001:1640995200000:sess123:car_100_200" + + def test_resolve_conditional_template(self): + """Test resolving conditional templates.""" + resolver = TemplateResolver() + + # Simple conditional + template = "{name} is {age > 18 ? 'adult' : 'minor'}" + + context_adult = {"name": "John", "age": 25} + result_adult = resolver.resolve(template, context_adult) + assert result_adult == "John is adult" + + context_minor = {"name": "Jane", "age": 16} + result_minor = resolver.resolve(template, context_minor) + assert result_minor == "Jane is minor" + + def test_escape_braces(self): + """Test escaping braces in templates.""" + resolver = TemplateResolver() + + template = "Literal {{braces}} and variable {name}" + context = {"name": "John"} + + result = resolver.resolve(template, context) + assert result == "Literal {braces} and variable John" + + +class TestMappingContext: + """Test mapping context data structure.""" + + def test_creation(self): + """Test mapping context creation.""" + detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000) + + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + detection=detection, + timestamp=1640995200000 + ) + + assert context.camera_id == "camera_001" + assert context.display_id == "display_001" + assert context.session_id == "session_123" + assert context.detection == detection + assert context.timestamp == 1640995200000 + assert context.branch_results == {} + assert context.metadata == {} + + def test_add_branch_result(self): + """Test adding branch results to context.""" + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123" + ) + + context.add_branch_result("car_brand_cls", {"brand": "Toyota", "confidence": 0.87}) + context.add_branch_result("car_bodytype_cls", {"body_type": "Sedan", "confidence": 0.82}) + + assert len(context.branch_results) == 2 + assert context.branch_results["car_brand_cls"]["brand"] == "Toyota" + assert context.branch_results["car_bodytype_cls"]["body_type"] == "Sedan" + + def test_to_dict(self): + """Test converting context to dictionary.""" + detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000) + + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + detection=detection, + timestamp=1640995200000 + ) + + context.add_branch_result("car_brand_cls", {"brand": "Toyota"}) + context.add_metadata("model_id", "yolo_v8") + + context_dict = context.to_dict() + + assert context_dict["camera_id"] == "camera_001" + assert context_dict["display_id"] == "display_001" + assert context_dict["session_id"] == "session_123" + assert context_dict["timestamp"] == 1640995200000 + assert context_dict["class"] == "car" + assert context_dict["confidence"] == 0.9 + assert context_dict["track_id"] == 1001 + assert context_dict["bbox"]["x1"] == 100 + assert context_dict["car_brand_cls"]["brand"] == "Toyota" + assert context_dict["model_id"] == "yolo_v8" + + def test_add_metadata(self): + """Test adding metadata to context.""" + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123" + ) + + context.add_metadata("model_version", "v2.1") + context.add_metadata("inference_time", 0.15) + + assert context.metadata["model_version"] == "v2.1" + assert context.metadata["inference_time"] == 0.15 + + +class TestFieldMapper: + """Test field mapping functionality.""" + + def test_initialization(self): + """Test field mapper initialization.""" + mapper = FieldMapper() + + assert isinstance(mapper.template_resolver, TemplateResolver) + assert isinstance(mapper.field_accessor, NestedFieldAccessor) + + def test_map_fields_simple(self): + """Test simple field mapping.""" + mapper = FieldMapper() + + field_mappings = { + "camera_id": "{camera_id}", + "detection_class": "{class}", + "confidence_score": "{confidence}", + "track_identifier": "{track_id}" + } + + detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000) + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + detection=detection, + timestamp=1640995200000 + ) + + mapped_fields = mapper.map_fields(field_mappings, context) + + assert mapped_fields["camera_id"] == "camera_001" + assert mapped_fields["detection_class"] == "car" + assert mapped_fields["confidence_score"] == 0.92 + assert mapped_fields["track_identifier"] == 1001 + + def test_map_fields_with_branch_results(self): + """Test field mapping with branch results.""" + mapper = FieldMapper() + + field_mappings = { + "car_brand": "{car_brand_cls.brand}", + "car_model": "{car_brand_cls.model}", + "body_type": "{car_bodytype_cls.body_type}", + "brand_confidence": "{car_brand_cls.confidence}", + "combined_info": "{car_brand_cls.brand} {car_bodytype_cls.body_type}" + } + + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123" + ) + + context.add_branch_result("car_brand_cls", { + "brand": "Toyota", + "model": "Camry", + "confidence": 0.87 + }) + context.add_branch_result("car_bodytype_cls", { + "body_type": "Sedan", + "confidence": 0.82 + }) + + mapped_fields = mapper.map_fields(field_mappings, context) + + assert mapped_fields["car_brand"] == "Toyota" + assert mapped_fields["car_model"] == "Camry" + assert mapped_fields["body_type"] == "Sedan" + assert mapped_fields["brand_confidence"] == 0.87 + assert mapped_fields["combined_info"] == "Toyota Sedan" + + def test_map_fields_bbox_access(self): + """Test field mapping with bounding box access.""" + mapper = FieldMapper() + + field_mappings = { + "bbox_x1": "{bbox.x1}", + "bbox_y1": "{bbox.y1}", + "bbox_x2": "{bbox.x2}", + "bbox_y2": "{bbox.y2}", + "bbox_width": "{bbox.width}", + "bbox_height": "{bbox.height}", + "bbox_area": "{bbox.area}", + "bbox_center_x": "{bbox.center_x}", + "bbox_center_y": "{bbox.center_y}" + } + + detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + detection=detection + ) + + mapped_fields = mapper.map_fields(field_mappings, context) + + assert mapped_fields["bbox_x1"] == 100 + assert mapped_fields["bbox_y1"] == 200 + assert mapped_fields["bbox_x2"] == 300 + assert mapped_fields["bbox_y2"] == 400 + assert mapped_fields["bbox_width"] == 200 # 300 - 100 + assert mapped_fields["bbox_height"] == 200 # 400 - 200 + assert mapped_fields["bbox_area"] == 40000 # 200 * 200 + assert mapped_fields["bbox_center_x"] == 200 # (100 + 300) / 2 + assert mapped_fields["bbox_center_y"] == 300 # (200 + 400) / 2 + + def test_map_fields_with_sql_functions(self): + """Test field mapping with SQL function templates.""" + mapper = FieldMapper() + + field_mappings = { + "created_at": "NOW()", + "updated_at": "CURRENT_TIMESTAMP", + "uuid_field": "UUID()", + "json_data": "JSON_OBJECT('class', '{class}', 'confidence', {confidence})" + } + + detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + detection=detection + ) + + mapped_fields = mapper.map_fields(field_mappings, context) + + # SQL functions should pass through unchanged + assert mapped_fields["created_at"] == "NOW()" + assert mapped_fields["updated_at"] == "CURRENT_TIMESTAMP" + assert mapped_fields["uuid_field"] == "UUID()" + assert mapped_fields["json_data"] == "JSON_OBJECT('class', 'car', 'confidence', 0.9)" + + def test_map_fields_missing_branch_data(self): + """Test field mapping with missing branch data.""" + mapper = FieldMapper() + + field_mappings = { + "car_brand": "{car_brand_cls.brand}", + "car_model": "{nonexistent_branch.model}" + } + + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123" + ) + + # Only add one branch result + context.add_branch_result("car_brand_cls", {"brand": "Toyota"}) + + with pytest.raises(FieldMappingError) as exc_info: + mapper.map_fields(field_mappings, context) + + assert "nonexistent_branch.model" in str(exc_info.value) + + def test_map_fields_with_defaults(self): + """Test field mapping with default values.""" + mapper = FieldMapper(allow_missing=True) + + field_mappings = { + "car_brand": "{car_brand_cls.brand|Unknown}", + "car_model": "{car_brand_cls.model|N/A}", + "confidence": "{confidence|0.0}" + } + + context = MappingContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123" + ) + + # Don't add any branch results + mapped_fields = mapper.map_fields(field_mappings, context) + + assert mapped_fields["car_brand"] == "Unknown" + assert mapped_fields["car_model"] == "N/A" + assert mapped_fields["confidence"] == "0.0" + + def test_map_database_fields(self): + """Test mapping fields for database operations.""" + mapper = FieldMapper() + + # Database field mapping + db_field_mappings = { + "camera_id": "{camera_id}", + "session_id": "{session_id}", + "detection_timestamp": "{timestamp}", + "object_class": "{class}", + "detection_confidence": "{confidence}", + "track_id": "{track_id}", + "bbox_json": "JSON_OBJECT('x1', {bbox.x1}, 'y1', {bbox.y1}, 'x2', {bbox.x2}, 'y2', {bbox.y2})", + "car_brand": "{car_brand_cls.brand}", + "car_body_type": "{car_bodytype_cls.body_type}", + "license_plate": "{license_ocr.text}", + "created_at": "NOW()", + "updated_at": "NOW()" + } + + detection = DetectionResult("car", 0.93, BoundingBox(150, 250, 350, 450), 2001, 1640995300000) + context = MappingContext( + camera_id="camera_002", + display_id="display_002", + session_id="session_456", + detection=detection, + timestamp=1640995300000 + ) + + # Add branch results + context.add_branch_result("car_brand_cls", {"brand": "Honda", "confidence": 0.89}) + context.add_branch_result("car_bodytype_cls", {"body_type": "SUV", "confidence": 0.85}) + context.add_branch_result("license_ocr", {"text": "ABC-123", "confidence": 0.76}) + + mapped_fields = mapper.map_fields(db_field_mappings, context) + + assert mapped_fields["camera_id"] == "camera_002" + assert mapped_fields["session_id"] == "session_456" + assert mapped_fields["detection_timestamp"] == 1640995300000 + assert mapped_fields["object_class"] == "car" + assert mapped_fields["detection_confidence"] == 0.93 + assert mapped_fields["track_id"] == 2001 + assert mapped_fields["bbox_json"] == "JSON_OBJECT('x1', 150, 'y1', 250, 'x2', 350, 'y2', 450)" + assert mapped_fields["car_brand"] == "Honda" + assert mapped_fields["car_body_type"] == "SUV" + assert mapped_fields["license_plate"] == "ABC-123" + assert mapped_fields["created_at"] == "NOW()" + assert mapped_fields["updated_at"] == "NOW()" + + def test_map_redis_keys(self): + """Test mapping Redis key templates.""" + mapper = FieldMapper() + + key_templates = [ + "inference:{camera_id}:{timestamp}:{session_id}:car", + "detection:{display_id}:{track_id}", + "cropped_image:{camera_id}:{session_id}:{class}", + "metadata:{session_id}:brands:{car_brand_cls.brand}", + "tracking:{camera_id}:active_tracks" + ] + + detection = DetectionResult("car", 0.88, BoundingBox(200, 300, 400, 500), 3001, 1640995400000) + context = MappingContext( + camera_id="camera_003", + display_id="display_003", + session_id="session_789", + detection=detection, + timestamp=1640995400000 + ) + + context.add_branch_result("car_brand_cls", {"brand": "Ford"}) + + mapped_keys = [mapper.map_template(template, context) for template in key_templates] + + expected_keys = [ + "inference:camera_003:1640995400000:session_789:car", + "detection:display_003:3001", + "cropped_image:camera_003:session_789:car", + "metadata:session_789:brands:Ford", + "tracking:camera_003:active_tracks" + ] + + assert mapped_keys == expected_keys + + def test_map_template(self): + """Test single template mapping.""" + mapper = FieldMapper() + + template = "Camera {camera_id} detected {class} with {confidence:.2f} confidence at {timestamp}" + + detection = DetectionResult("truck", 0.876, BoundingBox(100, 150, 300, 350), 4001, 1640995500000) + context = MappingContext( + camera_id="camera_004", + display_id="display_004", + session_id="session_101", + detection=detection, + timestamp=1640995500000 + ) + + result = mapper.map_template(template, context) + expected = "Camera camera_004 detected truck with 0.88 confidence at 1640995500000" + + assert result == expected + + def test_validate_field_mappings(self): + """Test field mapping validation.""" + mapper = FieldMapper() + + # Valid mappings + valid_mappings = { + "camera_id": "{camera_id}", + "class": "{class}", + "confidence": "{confidence}", + "created_at": "NOW()" + } + + assert mapper.validate_field_mappings(valid_mappings) is True + + # Invalid mappings (malformed templates) + invalid_mappings = { + "camera_id": "{camera_id", # Missing closing brace + "class": "class}", # Missing opening brace + "confidence": "{nonexistent_field}" # This might be valid depending on context + } + + with pytest.raises(FieldMappingError): + mapper.validate_field_mappings(invalid_mappings) + + def test_create_context_from_detection(self): + """Test creating mapping context from detection result.""" + mapper = FieldMapper() + + detection = DetectionResult("car", 0.95, BoundingBox(50, 100, 250, 300), 5001, 1640995600000) + + context = mapper.create_context_from_detection( + detection, + camera_id="camera_005", + display_id="display_005", + session_id="session_202" + ) + + assert context.camera_id == "camera_005" + assert context.display_id == "display_005" + assert context.session_id == "session_202" + assert context.detection == detection + assert context.timestamp == 1640995600000 + + def test_format_sql_value(self): + """Test SQL value formatting.""" + mapper = FieldMapper() + + # String values should be quoted + assert mapper.format_sql_value("test_string") == "'test_string'" + assert mapper.format_sql_value("John's car") == "'John''s car'" # Escape quotes + + # Numeric values should not be quoted + assert mapper.format_sql_value(42) == "42" + assert mapper.format_sql_value(3.14) == "3.14" + assert mapper.format_sql_value(0.95) == "0.95" + + # Boolean values + assert mapper.format_sql_value(True) == "TRUE" + assert mapper.format_sql_value(False) == "FALSE" + + # None/NULL values + assert mapper.format_sql_value(None) == "NULL" + + # SQL functions should pass through + assert mapper.format_sql_value("NOW()") == "NOW()" + assert mapper.format_sql_value("CURRENT_TIMESTAMP") == "CURRENT_TIMESTAMP" + + +class TestFieldMapperIntegration: + """Integration tests for field mapping.""" + + def test_complete_mapping_workflow(self): + """Test complete field mapping workflow.""" + mapper = FieldMapper() + + # Simulate complete detection workflow + detection = DetectionResult("car", 0.91, BoundingBox(120, 180, 320, 380), 6001, 1640995700000) + context = MappingContext( + camera_id="camera_006", + display_id="display_006", + session_id="session_303", + detection=detection, + timestamp=1640995700000 + ) + + # Add comprehensive branch results + context.add_branch_result("car_brand_cls", { + "brand": "BMW", + "model": "X5", + "confidence": 0.84, + "top3_brands": ["BMW", "Audi", "Mercedes"] + }) + + context.add_branch_result("car_bodytype_cls", { + "body_type": "SUV", + "confidence": 0.79, + "features": ["tall", "4_doors", "roof_rails"] + }) + + context.add_branch_result("car_color_cls", { + "color": "Black", + "confidence": 0.73, + "rgb_values": [20, 25, 30] + }) + + context.add_branch_result("license_ocr", { + "text": "XYZ-789", + "confidence": 0.68, + "region_bbox": [150, 320, 290, 360] + }) + + # Database field mapping + db_mappings = { + "camera_id": "{camera_id}", + "display_id": "{display_id}", + "session_id": "{session_id}", + "detection_timestamp": "{timestamp}", + "object_class": "{class}", + "detection_confidence": "{confidence}", + "track_id": "{track_id}", + "bbox_x1": "{bbox.x1}", + "bbox_y1": "{bbox.y1}", + "bbox_x2": "{bbox.x2}", + "bbox_y2": "{bbox.y2}", + "bbox_area": "{bbox.area}", + "car_brand": "{car_brand_cls.brand}", + "car_model": "{car_brand_cls.model}", + "car_body_type": "{car_bodytype_cls.body_type}", + "car_color": "{car_color_cls.color}", + "license_plate": "{license_ocr.text}", + "brand_confidence": "{car_brand_cls.confidence}", + "bodytype_confidence": "{car_bodytype_cls.confidence}", + "color_confidence": "{car_color_cls.confidence}", + "license_confidence": "{license_ocr.confidence}", + "detection_summary": "{car_brand_cls.brand} {car_bodytype_cls.body_type} ({car_color_cls.color})", + "created_at": "NOW()", + "updated_at": "NOW()" + } + + mapped_db_fields = mapper.map_fields(db_mappings, context) + + # Verify all mappings + assert mapped_db_fields["camera_id"] == "camera_006" + assert mapped_db_fields["session_id"] == "session_303" + assert mapped_db_fields["object_class"] == "car" + assert mapped_db_fields["detection_confidence"] == 0.91 + assert mapped_db_fields["track_id"] == 6001 + assert mapped_db_fields["bbox_area"] == 40000 # 200 * 200 + assert mapped_db_fields["car_brand"] == "BMW" + assert mapped_db_fields["car_model"] == "X5" + assert mapped_db_fields["car_body_type"] == "SUV" + assert mapped_db_fields["car_color"] == "Black" + assert mapped_db_fields["license_plate"] == "XYZ-789" + assert mapped_db_fields["detection_summary"] == "BMW SUV (Black)" + + # Redis key mapping + redis_key_templates = [ + "detection:{camera_id}:{session_id}:main", + "cropped:{camera_id}:{session_id}:car_image", + "metadata:{session_id}:brand:{car_brand_cls.brand}", + "tracking:{camera_id}:track_{track_id}", + "classification:{session_id}:results" + ] + + mapped_redis_keys = [ + mapper.map_template(template, context) + for template in redis_key_templates + ] + + expected_redis_keys = [ + "detection:camera_006:session_303:main", + "cropped:camera_006:session_303:car_image", + "metadata:session_303:brand:BMW", + "tracking:camera_006:track_6001", + "classification:session_303:results" + ] + + assert mapped_redis_keys == expected_redis_keys + + def test_error_handling_and_recovery(self): + """Test error handling and recovery in field mapping.""" + mapper = FieldMapper(allow_missing=True) + + # Context with missing detection + context = MappingContext( + camera_id="camera_007", + display_id="display_007", + session_id="session_404" + ) + + # Partial branch results + context.add_branch_result("car_brand_cls", {"brand": "Unknown"}) + # Missing car_bodytype_cls branch + + # Field mappings with some missing data + mappings = { + "camera_id": "{camera_id}", + "detection_class": "{class|Unknown}", + "confidence": "{confidence|0.0}", + "car_brand": "{car_brand_cls.brand|N/A}", + "car_body_type": "{car_bodytype_cls.body_type|Unknown}", + "car_model": "{car_brand_cls.model|N/A}" + } + + mapped_fields = mapper.map_fields(mappings, context) + + assert mapped_fields["camera_id"] == "camera_007" + assert mapped_fields["detection_class"] == "Unknown" + assert mapped_fields["confidence"] == "0.0" + assert mapped_fields["car_brand"] == "Unknown" + assert mapped_fields["car_body_type"] == "Unknown" + assert mapped_fields["car_model"] == "N/A" \ No newline at end of file diff --git a/tests/unit/pipeline/test_pipeline_executor.py b/tests/unit/pipeline/test_pipeline_executor.py new file mode 100644 index 0000000..69b9a67 --- /dev/null +++ b/tests/unit/pipeline/test_pipeline_executor.py @@ -0,0 +1,921 @@ +""" +Unit tests for pipeline execution functionality. +""" +import pytest +import asyncio +import numpy as np +from unittest.mock import Mock, MagicMock, patch, AsyncMock +from concurrent.futures import ThreadPoolExecutor +import json + +from detector_worker.pipeline.pipeline_executor import ( + PipelineExecutor, + PipelineContext, + PipelineResult, + BranchResult, + ExecutionMode +) +from detector_worker.detection.detection_result import DetectionResult, BoundingBox +from detector_worker.core.exceptions import PipelineError, ModelError, ActionError + + +class TestPipelineContext: + """Test pipeline context data structure.""" + + def test_creation(self): + """Test pipeline context creation.""" + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=np.zeros((480, 640, 3), dtype=np.uint8) + ) + + assert context.camera_id == "camera_001" + assert context.display_id == "display_001" + assert context.session_id == "session_123" + assert context.timestamp == 1640995200000 + assert context.frame_data.shape == (480, 640, 3) + assert context.metadata == {} + assert context.crop_region is None + + def test_creation_with_crop_region(self): + """Test context creation with crop region.""" + crop_region = (100, 200, 300, 400) + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=np.zeros((480, 640, 3), dtype=np.uint8), + crop_region=crop_region + ) + + assert context.crop_region == crop_region + + def test_add_metadata(self): + """Test adding metadata to context.""" + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=np.zeros((480, 640, 3), dtype=np.uint8) + ) + + context.add_metadata("model_id", "yolo_v8") + context.add_metadata("confidence_threshold", 0.8) + + assert context.metadata["model_id"] == "yolo_v8" + assert context.metadata["confidence_threshold"] == 0.8 + + def test_get_cropped_frame(self): + """Test getting cropped frame.""" + frame = np.ones((480, 640, 3), dtype=np.uint8) * 255 + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=frame, + crop_region=(100, 200, 300, 400) + ) + + cropped = context.get_cropped_frame() + + assert cropped.shape == (200, 200, 3) # 400-200, 300-100 + assert np.all(cropped == 255) + + def test_get_cropped_frame_no_crop(self): + """Test getting frame when no crop region specified.""" + frame = np.ones((480, 640, 3), dtype=np.uint8) * 255 + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=frame + ) + + cropped = context.get_cropped_frame() + + assert np.array_equal(cropped, frame) + + +class TestBranchResult: + """Test branch execution result.""" + + def test_creation_success(self): + """Test successful branch result creation.""" + detections = [ + DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000) + ] + + result = BranchResult( + branch_id="car_brand_cls", + success=True, + detections=detections, + metadata={"brand": "Toyota"}, + execution_time=0.15 + ) + + assert result.branch_id == "car_brand_cls" + assert result.success is True + assert len(result.detections) == 1 + assert result.metadata["brand"] == "Toyota" + assert result.execution_time == 0.15 + assert result.error is None + + def test_creation_failure(self): + """Test failed branch result creation.""" + result = BranchResult( + branch_id="car_brand_cls", + success=False, + error="Model inference failed", + execution_time=0.05 + ) + + assert result.branch_id == "car_brand_cls" + assert result.success is False + assert result.detections == [] + assert result.metadata == {} + assert result.error == "Model inference failed" + + +class TestPipelineResult: + """Test pipeline execution result.""" + + def test_creation_success(self): + """Test successful pipeline result creation.""" + main_detections = [ + DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000) + ] + + branch_results = { + "car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1), + "car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12) + } + + result = PipelineResult( + success=True, + detections=main_detections, + branch_results=branch_results, + total_execution_time=0.5 + ) + + assert result.success is True + assert len(result.detections) == 1 + assert len(result.branch_results) == 2 + assert result.total_execution_time == 0.5 + assert result.error is None + + def test_creation_failure(self): + """Test failed pipeline result creation.""" + result = PipelineResult( + success=False, + error="Pipeline execution failed", + total_execution_time=0.1 + ) + + assert result.success is False + assert result.detections == [] + assert result.branch_results == {} + assert result.error == "Pipeline execution failed" + + def test_get_combined_results(self): + """Test getting combined results from all branches.""" + main_detections = [ + DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001, 1640995200000) + ] + + branch_results = { + "car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1), + "car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12) + } + + result = PipelineResult( + success=True, + detections=main_detections, + branch_results=branch_results, + total_execution_time=0.5 + ) + + combined = result.get_combined_results() + + assert "brand" in combined + assert "body_type" in combined + assert combined["brand"] == "Toyota" + assert combined["body_type"] == "Sedan" + + +class TestPipelineExecutor: + """Test pipeline execution functionality.""" + + def test_initialization(self): + """Test pipeline executor initialization.""" + executor = PipelineExecutor() + + assert isinstance(executor.thread_pool, ThreadPoolExecutor) + assert executor.max_workers == 4 + assert executor.execution_mode == ExecutionMode.PARALLEL + assert executor.timeout == 30.0 + + def test_initialization_custom_config(self): + """Test initialization with custom configuration.""" + config = { + "max_workers": 8, + "execution_mode": "sequential", + "timeout": 60.0 + } + + executor = PipelineExecutor(config) + + assert executor.max_workers == 8 + assert executor.execution_mode == ExecutionMode.SEQUENTIAL + assert executor.timeout == 60.0 + + @pytest.mark.asyncio + async def test_execute_pipeline_simple(self, mock_yolo_model, mock_frame): + """Test simple pipeline execution.""" + # Mock pipeline configuration + pipeline_config = { + "modelId": "car_detection_v1", + "modelFile": "car_detection.pt", + "expectedClasses": ["car"], + "triggerClasses": ["car"], + "minConfidence": 0.8, + "branches": [], + "actions": [] + } + + # Mock detection result + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0] + ]) + mock_result.boxes.id = torch.tensor([1001]) + + mock_yolo_model.track.return_value = [mock_result] + + executor = PipelineExecutor() + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager: + mock_model_manager.return_value.get_model.return_value = mock_yolo_model + + result = await executor.execute_pipeline(pipeline_config, context) + + assert result.success is True + assert len(result.detections) == 1 + assert result.detections[0].class_name == "0" # Default class name + assert result.detections[0].confidence == 0.9 + + @pytest.mark.asyncio + async def test_execute_pipeline_with_branches(self, mock_yolo_model, mock_frame): + """Test pipeline execution with classification branches.""" + import torch + + # Mock main detection + mock_detection_result = Mock() + mock_detection_result.boxes = Mock() + mock_detection_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0] # car detection + ]) + mock_detection_result.boxes.id = torch.tensor([1001]) + + # Mock classification results + mock_brand_result = Mock() + mock_brand_result.probs = Mock() + mock_brand_result.probs.top1 = 2 # Toyota + mock_brand_result.probs.top1conf = 0.85 + + mock_bodytype_result = Mock() + mock_bodytype_result.probs = Mock() + mock_bodytype_result.probs.top1 = 1 # Sedan + mock_bodytype_result.probs.top1conf = 0.78 + + mock_yolo_model.track.return_value = [mock_detection_result] + mock_yolo_model.predict.return_value = [mock_brand_result] + + mock_brand_model = Mock() + mock_brand_model.predict.return_value = [mock_brand_result] + mock_brand_model.names = {0: "Honda", 1: "Ford", 2: "Toyota"} + + mock_bodytype_model = Mock() + mock_bodytype_model.predict.return_value = [mock_bodytype_result] + mock_bodytype_model.names = {0: "SUV", 1: "Sedan", 2: "Hatchback"} + + # Pipeline configuration with branches + pipeline_config = { + "modelId": "car_detection_v1", + "modelFile": "car_detection.pt", + "expectedClasses": ["car"], + "triggerClasses": ["car"], + "minConfidence": 0.8, + "branches": [ + { + "modelId": "car_brand_cls", + "modelFile": "car_brand.pt", + "triggerClasses": ["car"], + "minConfidence": 0.7, + "parallel": True, + "crop": True, + "cropClass": "car" + }, + { + "modelId": "car_bodytype_cls", + "modelFile": "car_bodytype.pt", + "triggerClasses": ["car"], + "minConfidence": 0.7, + "parallel": True, + "crop": True, + "cropClass": "car" + } + ], + "actions": [] + } + + executor = PipelineExecutor() + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager: + def get_model_side_effect(model_id, camera_id): + if model_id == "car_detection_v1": + return mock_yolo_model + elif model_id == "car_brand_cls": + return mock_brand_model + elif model_id == "car_bodytype_cls": + return mock_bodytype_model + return None + + mock_model_manager.return_value.get_model.side_effect = get_model_side_effect + + result = await executor.execute_pipeline(pipeline_config, context) + + assert result.success is True + assert len(result.detections) == 1 + assert len(result.branch_results) == 2 + + # Check branch results + assert "car_brand_cls" in result.branch_results + assert "car_bodytype_cls" in result.branch_results + + brand_result = result.branch_results["car_brand_cls"] + assert brand_result.success is True + assert brand_result.metadata.get("brand") == "Toyota" + + bodytype_result = result.branch_results["car_bodytype_cls"] + assert bodytype_result.success is True + assert bodytype_result.metadata.get("body_type") == "Sedan" + + @pytest.mark.asyncio + async def test_execute_pipeline_sequential_mode(self, mock_yolo_model, mock_frame): + """Test pipeline execution in sequential mode.""" + import torch + + config = {"execution_mode": "sequential"} + executor = PipelineExecutor(config) + + # Mock detection result + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0] + ]) + mock_result.boxes.id = torch.tensor([1001]) + + mock_yolo_model.track.return_value = [mock_result] + + pipeline_config = { + "modelId": "car_detection_v1", + "modelFile": "car_detection.pt", + "expectedClasses": ["car"], + "triggerClasses": ["car"], + "minConfidence": 0.8, + "branches": [ + { + "modelId": "car_brand_cls", + "modelFile": "car_brand.pt", + "triggerClasses": ["car"], + "minConfidence": 0.7, + "parallel": False # Sequential execution + } + ], + "actions": [] + } + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager: + mock_model_manager.return_value.get_model.return_value = mock_yolo_model + + result = await executor.execute_pipeline(pipeline_config, context) + + assert result.success is True + assert executor.execution_mode == ExecutionMode.SEQUENTIAL + + @pytest.mark.asyncio + async def test_execute_pipeline_with_actions(self, mock_yolo_model, mock_frame): + """Test pipeline execution with actions.""" + import torch + + # Mock detection result + mock_result = Mock() + mock_result.boxes = Mock() + mock_result.boxes.data = torch.tensor([ + [100, 200, 300, 400, 0.9, 0] + ]) + mock_result.boxes.id = torch.tensor([1001]) + + mock_yolo_model.track.return_value = [mock_result] + + # Pipeline configuration with actions + pipeline_config = { + "modelId": "car_detection_v1", + "modelFile": "car_detection.pt", + "expectedClasses": ["car"], + "triggerClasses": ["car"], + "minConfidence": 0.8, + "branches": [], + "actions": [ + { + "type": "redis_save_image", + "region": "car", + "key": "inference:{display_id}:{timestamp}:{session_id}", + "expire_seconds": 600 + }, + { + "type": "postgresql_insert", + "table": "detections", + "fields": { + "camera_id": "{camera_id}", + "detection_class": "{class}", + "confidence": "{confidence}" + } + } + ] + } + + executor = PipelineExecutor() + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager, \ + patch('detector_worker.pipeline.action_executor.ActionExecutor') as mock_action_executor: + + mock_model_manager.return_value.get_model.return_value = mock_yolo_model + mock_action_executor.return_value.execute_actions = AsyncMock(return_value=True) + + result = await executor.execute_pipeline(pipeline_config, context) + + assert result.success is True + # Actions should be executed + mock_action_executor.return_value.execute_actions.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_pipeline_model_error(self, mock_frame): + """Test pipeline execution with model error.""" + pipeline_config = { + "modelId": "car_detection_v1", + "modelFile": "car_detection.pt", + "expectedClasses": ["car"], + "triggerClasses": ["car"], + "minConfidence": 0.8, + "branches": [], + "actions": [] + } + + executor = PipelineExecutor() + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager: + # Model manager raises error + mock_model_manager.return_value.get_model.side_effect = ModelError("Model not found") + + result = await executor.execute_pipeline(pipeline_config, context) + + assert result.success is False + assert "Model not found" in result.error + + @pytest.mark.asyncio + async def test_execute_pipeline_timeout(self, mock_yolo_model, mock_frame): + """Test pipeline execution timeout.""" + import torch + + # Configure short timeout + config = {"timeout": 0.001} # Very short timeout + executor = PipelineExecutor(config) + + # Mock slow model inference + def slow_inference(*args, **kwargs): + import time + time.sleep(1) # Longer than timeout + mock_result = Mock() + mock_result.boxes = None + return [mock_result] + + mock_yolo_model.track.side_effect = slow_inference + + pipeline_config = { + "modelId": "car_detection_v1", + "modelFile": "car_detection.pt", + "expectedClasses": ["car"], + "triggerClasses": ["car"], + "minConfidence": 0.8, + "branches": [], + "actions": [] + } + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager: + mock_model_manager.return_value.get_model.return_value = mock_yolo_model + + result = await executor.execute_pipeline(pipeline_config, context) + + assert result.success is False + assert "timeout" in result.error.lower() + + @pytest.mark.asyncio + async def test_execute_branch_parallel(self, mock_frame): + """Test parallel branch execution.""" + import torch + + # Mock classification model + mock_brand_model = Mock() + mock_result = Mock() + mock_result.probs = Mock() + mock_result.probs.top1 = 1 + mock_result.probs.top1conf = 0.85 + mock_brand_model.predict.return_value = [mock_result] + mock_brand_model.names = {0: "Honda", 1: "Toyota", 2: "Ford"} + + executor = PipelineExecutor() + + # Branch configuration + branch_config = { + "modelId": "car_brand_cls", + "modelFile": "car_brand.pt", + "triggerClasses": ["car"], + "minConfidence": 0.7, + "parallel": True, + "crop": True, + "cropClass": "car" + } + + # Mock detected regions + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager: + mock_model_manager.return_value.get_model.return_value = mock_brand_model + + result = await executor._execute_branch(branch_config, regions, context) + + assert result.success is True + assert result.branch_id == "car_brand_cls" + assert result.metadata.get("brand") == "Toyota" + assert result.execution_time > 0 + + @pytest.mark.asyncio + async def test_execute_branch_no_trigger_class(self, mock_frame): + """Test branch execution when trigger class not detected.""" + executor = PipelineExecutor() + + branch_config = { + "modelId": "car_brand_cls", + "modelFile": "car_brand.pt", + "triggerClasses": ["car"], + "minConfidence": 0.7 + } + + # No car detected + regions = { + "truck": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("truck", 0.9, BoundingBox(100, 200, 300, 400), 1002) + } + } + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + result = await executor._execute_branch(branch_config, regions, context) + + assert result.success is False + assert "trigger class not detected" in result.error.lower() + + def test_wait_for_branches(self): + """Test waiting for specific branches to complete.""" + executor = PipelineExecutor() + + # Mock completed branch results + branch_results = { + "car_brand_cls": BranchResult("car_brand_cls", True, [], {"brand": "Toyota"}, 0.1), + "car_bodytype_cls": BranchResult("car_bodytype_cls", True, [], {"body_type": "Sedan"}, 0.12), + "license_ocr": BranchResult("license_ocr", True, [], {"license": "ABC123"}, 0.2) + } + + # Wait for specific branches + wait_for = ["car_brand_cls", "car_bodytype_cls"] + completed = executor._wait_for_branches(branch_results, wait_for, timeout=1.0) + + assert completed is True + + # Wait for non-existent branch (should timeout) + wait_for_missing = ["car_brand_cls", "nonexistent_branch"] + completed = executor._wait_for_branches(branch_results, wait_for_missing, timeout=0.1) + + assert completed is False + + def test_validate_pipeline_config(self): + """Test pipeline configuration validation.""" + executor = PipelineExecutor() + + # Valid configuration + valid_config = { + "modelId": "car_detection_v1", + "modelFile": "car_detection.pt", + "expectedClasses": ["car"], + "triggerClasses": ["car"], + "minConfidence": 0.8 + } + + assert executor._validate_pipeline_config(valid_config) is True + + # Invalid configuration (missing required fields) + invalid_config = { + "modelFile": "car_detection.pt" + # Missing modelId + } + + with pytest.raises(PipelineError): + executor._validate_pipeline_config(invalid_config) + + def test_crop_frame_for_detection(self, mock_frame): + """Test frame cropping for detection.""" + executor = PipelineExecutor() + + detection = DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + + cropped = executor._crop_frame_for_detection(mock_frame, detection) + + assert cropped.shape == (200, 200, 3) # 400-200, 300-100 + + def test_crop_frame_invalid_bounds(self, mock_frame): + """Test frame cropping with invalid bounds.""" + executor = PipelineExecutor() + + # Detection outside frame bounds + detection = DetectionResult("car", 0.9, BoundingBox(-100, -200, 50, 100), 1001) + + cropped = executor._crop_frame_for_detection(mock_frame, detection) + + # Should handle bounds gracefully + assert cropped.shape[0] > 0 + assert cropped.shape[1] > 0 + + +class TestPipelineExecutorPerformance: + """Test pipeline executor performance and optimization.""" + + @pytest.mark.asyncio + async def test_parallel_branch_execution_performance(self, mock_frame): + """Test that parallel execution is faster than sequential.""" + import time + import torch + + def slow_inference(*args, **kwargs): + time.sleep(0.1) # Simulate slow inference + mock_result = Mock() + mock_result.probs = Mock() + mock_result.probs.top1 = 1 + mock_result.probs.top1conf = 0.85 + return [mock_result] + + mock_model = Mock() + mock_model.predict.side_effect = slow_inference + mock_model.names = {0: "Class0", 1: "Class1"} + + # Test parallel execution + parallel_executor = PipelineExecutor({"execution_mode": "parallel", "max_workers": 2}) + + branch_configs = [ + { + "modelId": f"branch_{i}", + "modelFile": f"branch_{i}.pt", + "triggerClasses": ["car"], + "minConfidence": 0.7, + "parallel": True + } + for i in range(3) # 3 branches + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager: + mock_model_manager.return_value.get_model.return_value = mock_model + + start_time = time.time() + + # Execute branches in parallel + tasks = [ + parallel_executor._execute_branch(config, regions, context) + for config in branch_configs + ] + + results = await asyncio.gather(*tasks) + + parallel_time = time.time() - start_time + + # Parallel execution should be faster than 3 * 0.1 seconds + assert parallel_time < 0.25 # Allow some overhead + assert len(results) == 3 + assert all(result.success for result in results) + + def test_thread_pool_management(self): + """Test thread pool creation and management.""" + # Test different worker counts + for workers in [1, 2, 4, 8]: + executor = PipelineExecutor({"max_workers": workers}) + assert executor.max_workers == workers + assert executor.thread_pool._max_workers == workers + + def test_memory_management_large_frames(self): + """Test memory management with large frames.""" + executor = PipelineExecutor() + + # Create large frame + large_frame = np.ones((1080, 1920, 3), dtype=np.uint8) * 128 + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=large_frame, + crop_region=(500, 400, 1000, 800) + ) + + # Get cropped frame + cropped = context.get_cropped_frame() + + # Should reduce memory usage + assert cropped.shape == (400, 500, 3) # Much smaller than original + assert cropped.nbytes < large_frame.nbytes + + +class TestPipelineExecutorErrorHandling: + """Test comprehensive error handling.""" + + @pytest.mark.asyncio + async def test_branch_execution_error_isolation(self, mock_frame): + """Test that errors in one branch don't affect others.""" + executor = PipelineExecutor() + + # Mock models - one fails, one succeeds + failing_model = Mock() + failing_model.predict.side_effect = Exception("Model crashed") + + success_model = Mock() + mock_result = Mock() + mock_result.probs = Mock() + mock_result.probs.top1 = 1 + mock_result.probs.top1conf = 0.85 + success_model.predict.return_value = [mock_result] + success_model.names = {0: "Class0", 1: "Class1"} + + branch_configs = [ + { + "modelId": "failing_branch", + "modelFile": "failing.pt", + "triggerClasses": ["car"], + "minConfidence": 0.7, + "parallel": True + }, + { + "modelId": "success_branch", + "modelFile": "success.pt", + "triggerClasses": ["car"], + "minConfidence": 0.7, + "parallel": True + } + ] + + regions = { + "car": { + "bbox": [100, 200, 300, 400], + "confidence": 0.9, + "detection": DetectionResult("car", 0.9, BoundingBox(100, 200, 300, 400), 1001) + } + } + + context = PipelineContext( + camera_id="camera_001", + display_id="display_001", + session_id="session_123", + timestamp=1640995200000, + frame_data=mock_frame + ) + + def get_model_side_effect(model_id, camera_id): + if model_id == "failing_branch": + return failing_model + elif model_id == "success_branch": + return success_model + return None + + with patch('detector_worker.models.model_manager.ModelManager') as mock_model_manager: + mock_model_manager.return_value.get_model.side_effect = get_model_side_effect + + # Execute branches + tasks = [ + executor._execute_branch(config, regions, context) + for config in branch_configs + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # One should fail, one should succeed + failing_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "failing_branch") + success_result = next(r for r in results if isinstance(r, BranchResult) and r.branch_id == "success_branch") + + assert failing_result.success is False + assert "Model crashed" in failing_result.error + + assert success_result.success is True + assert success_result.error is None \ No newline at end of file diff --git a/tests/unit/storage/test_database_manager.py b/tests/unit/storage/test_database_manager.py new file mode 100644 index 0000000..a26e558 --- /dev/null +++ b/tests/unit/storage/test_database_manager.py @@ -0,0 +1,976 @@ +""" +Unit tests for database management functionality. +""" +import pytest +import asyncio +from unittest.mock import Mock, MagicMock, patch, AsyncMock +from datetime import datetime, timedelta +import psycopg2 +import uuid + +from detector_worker.storage.database_manager import ( + DatabaseManager, + DatabaseConfig, + DatabaseConnection, + QueryBuilder, + TransactionManager, + DatabaseError, + ConnectionPoolError +) +from detector_worker.core.exceptions import ConfigurationError + + +class TestDatabaseConfig: + """Test database configuration.""" + + def test_creation_minimal(self): + """Test creating database config with minimal parameters.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + assert config.host == "localhost" + assert config.port == 5432 # Default port + assert config.database == "test_db" + assert config.username == "test_user" + assert config.password == "test_pass" + assert config.schema == "public" # Default schema + assert config.enabled is True + + def test_creation_full(self): + """Test creating database config with all parameters.""" + config = DatabaseConfig( + host="db.example.com", + port=5433, + database="production_db", + username="prod_user", + password="secure_pass", + schema="gas_station_1", + enabled=True, + pool_min_conn=2, + pool_max_conn=20, + pool_timeout=30.0, + connection_timeout=10.0, + ssl_mode="require" + ) + + assert config.host == "db.example.com" + assert config.port == 5433 + assert config.database == "production_db" + assert config.schema == "gas_station_1" + assert config.pool_min_conn == 2 + assert config.pool_max_conn == 20 + assert config.ssl_mode == "require" + + def test_get_connection_string(self): + """Test generating connection string.""" + config = DatabaseConfig( + host="localhost", + port=5432, + database="test_db", + username="test_user", + password="test_pass" + ) + + conn_string = config.get_connection_string() + + expected = "host=localhost port=5432 database=test_db user=test_user password=test_pass" + assert conn_string == expected + + def test_get_connection_string_with_ssl(self): + """Test generating connection string with SSL.""" + config = DatabaseConfig( + host="db.example.com", + database="secure_db", + username="user", + password="pass", + ssl_mode="require" + ) + + conn_string = config.get_connection_string() + + assert "sslmode=require" in conn_string + + def test_from_dict(self): + """Test creating config from dictionary.""" + config_dict = { + "host": "test-host", + "port": 5433, + "database": "test-db", + "username": "test-user", + "password": "test-pass", + "schema": "test_schema", + "pool_max_conn": 15, + "unknown_field": "ignored" + } + + config = DatabaseConfig.from_dict(config_dict) + + assert config.host == "test-host" + assert config.port == 5433 + assert config.database == "test-db" + assert config.schema == "test_schema" + assert config.pool_max_conn == 15 + + +class TestQueryBuilder: + """Test SQL query building functionality.""" + + def test_build_select_query(self): + """Test building SELECT queries.""" + builder = QueryBuilder("test_schema") + + query, params = builder.build_select_query( + table="users", + columns=["id", "name", "email"], + where={"status": "active", "age": 25}, + order_by="created_at DESC", + limit=10 + ) + + expected_query = ( + "SELECT id, name, email FROM test_schema.users " + "WHERE status = %s AND age = %s " + "ORDER BY created_at DESC LIMIT 10" + ) + + assert query == expected_query + assert params == ["active", 25] + + def test_build_select_all_columns(self): + """Test building SELECT * query.""" + builder = QueryBuilder("public") + + query, params = builder.build_select_query("products") + + expected_query = "SELECT * FROM public.products" + assert query == expected_query + assert params == [] + + def test_build_insert_query(self): + """Test building INSERT queries.""" + builder = QueryBuilder("inventory") + + data = { + "product_name": "Widget", + "price": 19.99, + "quantity": 100, + "created_at": "NOW()" + } + + query, params = builder.build_insert_query("products", data) + + expected_query = ( + "INSERT INTO inventory.products (product_name, price, quantity, created_at) " + "VALUES (%s, %s, %s, NOW()) RETURNING id" + ) + + assert query == expected_query + assert params == ["Widget", 19.99, 100] + + def test_build_update_query(self): + """Test building UPDATE queries.""" + builder = QueryBuilder("sales") + + data = { + "status": "shipped", + "shipped_date": "NOW()", + "tracking_number": "ABC123" + } + + where_conditions = {"order_id": 12345} + + query, params = builder.build_update_query("orders", data, where_conditions) + + expected_query = ( + "UPDATE sales.orders SET status = %s, shipped_date = NOW(), tracking_number = %s " + "WHERE order_id = %s" + ) + + assert query == expected_query + assert params == ["shipped", "ABC123", 12345] + + def test_build_delete_query(self): + """Test building DELETE queries.""" + builder = QueryBuilder("logs") + + where_conditions = { + "level": "DEBUG", + "created_at": "< NOW() - INTERVAL '7 days'" + } + + query, params = builder.build_delete_query("application_logs", where_conditions) + + expected_query = ( + "DELETE FROM logs.application_logs " + "WHERE level = %s AND created_at < NOW() - INTERVAL '7 days'" + ) + + assert query == expected_query + assert params == ["DEBUG"] + + def test_build_create_table_query(self): + """Test building CREATE TABLE queries.""" + builder = QueryBuilder("gas_station_1") + + columns = { + "id": "SERIAL PRIMARY KEY", + "session_id": "VARCHAR(255) UNIQUE NOT NULL", + "camera_id": "VARCHAR(255) NOT NULL", + "detection_class": "VARCHAR(100)", + "confidence": "DECIMAL(4,3)", + "bbox_data": "JSON", + "created_at": "TIMESTAMP DEFAULT NOW()", + "updated_at": "TIMESTAMP DEFAULT NOW()" + } + + query = builder.build_create_table_query("detections", columns) + + expected_parts = [ + "CREATE TABLE IF NOT EXISTS gas_station_1.detections", + "id SERIAL PRIMARY KEY", + "session_id VARCHAR(255) UNIQUE NOT NULL", + "camera_id VARCHAR(255) NOT NULL", + "bbox_data JSON", + "created_at TIMESTAMP DEFAULT NOW()" + ] + + for part in expected_parts: + assert part in query + + def test_escape_identifier(self): + """Test SQL identifier escaping.""" + builder = QueryBuilder("test") + + assert builder.escape_identifier("table") == '"table"' + assert builder.escape_identifier("column_name") == '"column_name"' + assert builder.escape_identifier("user-table") == '"user-table"' + + def test_format_value_for_sql(self): + """Test SQL value formatting.""" + builder = QueryBuilder("test") + + # Regular values should use placeholder + assert builder.format_value_for_sql("string") == ("%s", "string") + assert builder.format_value_for_sql(42) == ("%s", 42) + assert builder.format_value_for_sql(3.14) == ("%s", 3.14) + + # SQL functions should be literal + assert builder.format_value_for_sql("NOW()") == ("NOW()", None) + assert builder.format_value_for_sql("CURRENT_TIMESTAMP") == ("CURRENT_TIMESTAMP", None) + assert builder.format_value_for_sql("UUID()") == ("UUID()", None) + + +class TestDatabaseConnection: + """Test database connection management.""" + + def test_creation(self, mock_database_connection): + """Test connection creation.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + conn = DatabaseConnection(config, mock_database_connection) + + assert conn.config == config + assert conn.connection == mock_database_connection + assert conn.is_connected is True + + def test_execute_query(self, mock_database_connection): + """Test query execution.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + # Mock cursor behavior + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.fetchall.return_value = [ + (1, "John", "john@example.com"), + (2, "Jane", "jane@example.com") + ] + mock_cursor.rowcount = 2 + + conn = DatabaseConnection(config, mock_database_connection) + + query = "SELECT id, name, email FROM users WHERE status = %s" + params = ["active"] + + result = conn.execute_query(query, params) + + assert result == [ + (1, "John", "john@example.com"), + (2, "Jane", "jane@example.com") + ] + + mock_cursor.execute.assert_called_once_with(query, params) + mock_cursor.fetchall.assert_called_once() + + def test_execute_query_single_result(self, mock_database_connection): + """Test query execution with single result.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.fetchone.return_value = (1, "John", "john@example.com") + + conn = DatabaseConnection(config, mock_database_connection) + + result = conn.execute_query("SELECT * FROM users WHERE id = %s", [1], fetch_one=True) + + assert result == (1, "John", "john@example.com") + mock_cursor.fetchone.assert_called_once() + + def test_execute_query_no_fetch(self, mock_database_connection): + """Test query execution without fetching results.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.rowcount = 1 + + conn = DatabaseConnection(config, mock_database_connection) + + result = conn.execute_query( + "INSERT INTO users (name) VALUES (%s)", + ["John"], + fetch_results=False + ) + + assert result == 1 # Row count + mock_cursor.execute.assert_called_once() + mock_cursor.fetchall.assert_not_called() + mock_cursor.fetchone.assert_not_called() + + def test_execute_query_error(self, mock_database_connection): + """Test query execution error handling.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.execute.side_effect = psycopg2.Error("Database error") + + conn = DatabaseConnection(config, mock_database_connection) + + with pytest.raises(DatabaseError) as exc_info: + conn.execute_query("SELECT * FROM invalid_table") + + assert "Database error" in str(exc_info.value) + + def test_commit_transaction(self, mock_database_connection): + """Test transaction commit.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + conn = DatabaseConnection(config, mock_database_connection) + conn.commit() + + mock_database_connection.commit.assert_called_once() + + def test_rollback_transaction(self, mock_database_connection): + """Test transaction rollback.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + conn = DatabaseConnection(config, mock_database_connection) + conn.rollback() + + mock_database_connection.rollback.assert_called_once() + + def test_close_connection(self, mock_database_connection): + """Test connection closing.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + conn = DatabaseConnection(config, mock_database_connection) + conn.close() + + assert conn.is_connected is False + mock_database_connection.close.assert_called_once() + + +class TestTransactionManager: + """Test transaction management.""" + + def test_transaction_context_success(self, mock_database_connection): + """Test successful transaction context.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + conn = DatabaseConnection(config, mock_database_connection) + tx_manager = TransactionManager(conn) + + with tx_manager: + # Simulate some database operations + conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"]) + conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"]) + + # Should commit on successful exit + mock_database_connection.commit.assert_called_once() + mock_database_connection.rollback.assert_not_called() + + def test_transaction_context_error(self, mock_database_connection): + """Test transaction context with error.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + conn = DatabaseConnection(config, mock_database_connection) + tx_manager = TransactionManager(conn) + + with pytest.raises(DatabaseError): + with tx_manager: + conn.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"]) + # Simulate an error + raise DatabaseError("Something went wrong") + + # Should rollback on error + mock_database_connection.rollback.assert_called_once() + mock_database_connection.commit.assert_not_called() + + +class TestDatabaseManager: + """Test main database manager functionality.""" + + def test_initialization(self): + """Test database manager initialization.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass", + schema="gas_station_1" + ) + + manager = DatabaseManager(config) + + assert manager.config == config + assert isinstance(manager.query_builder, QueryBuilder) + assert manager.query_builder.schema == "gas_station_1" + assert manager.connection is None + + @pytest.mark.asyncio + async def test_connect_success(self): + """Test successful database connection.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + + with patch('psycopg2.connect') as mock_connect: + mock_connection = Mock() + mock_connect.return_value = mock_connection + + await manager.connect() + + assert manager.connection is not None + assert manager.is_connected is True + mock_connect.assert_called_once() + + @pytest.mark.asyncio + async def test_connect_failure(self): + """Test database connection failure.""" + config = DatabaseConfig( + host="nonexistent-host", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + + with patch('psycopg2.connect') as mock_connect: + mock_connect.side_effect = psycopg2.Error("Connection failed") + + with pytest.raises(DatabaseError) as exc_info: + await manager.connect() + + assert "Connection failed" in str(exc_info.value) + assert manager.is_connected is False + + @pytest.mark.asyncio + async def test_disconnect(self): + """Test database disconnection.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + + # Mock connection + mock_connection = Mock() + manager.connection = DatabaseConnection(config, mock_connection) + + await manager.disconnect() + + assert manager.connection is None + mock_connection.close.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_query(self, mock_database_connection): + """Test query execution through manager.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + # Mock cursor behavior + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.fetchall.return_value = [(1, "Test"), (2, "Data")] + + result = await manager.execute_query("SELECT * FROM test_table") + + assert result == [(1, "Test"), (2, "Data")] + mock_cursor.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_execute_query_not_connected(self): + """Test query execution when not connected.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + + with pytest.raises(DatabaseError) as exc_info: + await manager.execute_query("SELECT * FROM test_table") + + assert "not connected" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_insert_record(self, mock_database_connection): + """Test inserting a record.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass", + schema="gas_station_1" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + # Mock cursor behavior + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.fetchone.return_value = (123,) # Returned ID + + data = { + "session_id": "session_123", + "camera_id": "camera_001", + "detection_class": "car", + "confidence": 0.95, + "created_at": "NOW()" + } + + record_id = await manager.insert_record("car_detections", data) + + assert record_id == 123 + mock_cursor.execute.assert_called_once() + mock_database_connection.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_update_record(self, mock_database_connection): + """Test updating a record.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass", + schema="gas_station_1" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + # Mock cursor behavior + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.rowcount = 1 + + data = { + "car_brand": "Toyota", + "car_body_type": "Sedan", + "updated_at": "NOW()" + } + + where_conditions = {"session_id": "session_123"} + + rows_affected = await manager.update_record("car_info", data, where_conditions) + + assert rows_affected == 1 + mock_cursor.execute.assert_called_once() + mock_database_connection.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_records(self, mock_database_connection): + """Test deleting records.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + # Mock cursor behavior + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.rowcount = 3 + + where_conditions = { + "created_at": "< NOW() - INTERVAL '30 days'", + "processed": True + } + + rows_deleted = await manager.delete_records("old_detections", where_conditions) + + assert rows_deleted == 3 + mock_cursor.execute.assert_called_once() + mock_database_connection.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_create_table(self, mock_database_connection): + """Test creating a table.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass", + schema="gas_station_1" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + columns = { + "id": "SERIAL PRIMARY KEY", + "session_id": "VARCHAR(255) UNIQUE NOT NULL", + "camera_id": "VARCHAR(255) NOT NULL", + "detection_data": "JSON", + "created_at": "TIMESTAMP DEFAULT NOW()" + } + + await manager.create_table("test_detections", columns) + + mock_database_connection.cursor.return_value.execute.assert_called_once() + mock_database_connection.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_table_exists(self, mock_database_connection): + """Test checking if table exists.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass", + schema="gas_station_1" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + # Mock cursor behavior - table exists + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.fetchone.return_value = (1,) + + exists = await manager.table_exists("car_detections") + + assert exists is True + mock_cursor.execute.assert_called_once() + + # Mock cursor behavior - table doesn't exist + mock_cursor.fetchone.return_value = None + + exists = await manager.table_exists("nonexistent_table") + + assert exists is False + + @pytest.mark.asyncio + async def test_transaction_context(self, mock_database_connection): + """Test transaction context manager.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + async with manager.transaction(): + await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"]) + await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"]) + + # Should commit on successful completion + mock_database_connection.commit.assert_called() + + @pytest.mark.asyncio + async def test_get_table_schema(self, mock_database_connection): + """Test getting table schema information.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass", + schema="gas_station_1" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + # Mock cursor behavior + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.fetchall.return_value = [ + ("id", "integer", "NOT NULL"), + ("session_id", "character varying", "NOT NULL"), + ("created_at", "timestamp without time zone", "DEFAULT now()") + ] + + schema = await manager.get_table_schema("car_detections") + + assert len(schema) == 3 + assert schema[0] == ("id", "integer", "NOT NULL") + assert schema[1] == ("session_id", "character varying", "NOT NULL") + + @pytest.mark.asyncio + async def test_bulk_insert(self, mock_database_connection): + """Test bulk insert operation.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + records = [ + {"name": "John", "email": "john@example.com"}, + {"name": "Jane", "email": "jane@example.com"}, + {"name": "Bob", "email": "bob@example.com"} + ] + + mock_cursor = mock_database_connection.cursor.return_value + mock_cursor.rowcount = 3 + + rows_inserted = await manager.bulk_insert("users", records) + + assert rows_inserted == 3 + mock_cursor.executemany.assert_called_once() + mock_database_connection.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_get_connection_stats(self, mock_database_connection): + """Test getting connection statistics.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + stats = manager.get_connection_stats() + + assert "connected" in stats + assert "host" in stats + assert "database" in stats + assert "schema" in stats + assert stats["connected"] is True + assert stats["host"] == "localhost" + assert stats["database"] == "test_db" + + +class TestDatabaseManagerIntegration: + """Integration tests for database manager.""" + + @pytest.mark.asyncio + async def test_complete_car_detection_workflow(self, mock_database_connection): + """Test complete car detection database workflow.""" + config = DatabaseConfig( + host="localhost", + database="gas_station_db", + username="detector_user", + password="detector_pass", + schema="gas_station_1" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + # Mock cursor behaviors for different operations + mock_cursor = mock_database_connection.cursor.return_value + + # 1. Create initial detection record + mock_cursor.fetchone.return_value = (456,) # Returned ID + + detection_data = { + "session_id": str(uuid.uuid4()), + "camera_id": "camera_001", + "display_id": "display_001", + "detection_class": "car", + "confidence": 0.92, + "bbox_x1": 100, + "bbox_y1": 200, + "bbox_x2": 300, + "bbox_y2": 400, + "track_id": 1001, + "created_at": "NOW()" + } + + detection_id = await manager.insert_record("car_detections", detection_data) + assert detection_id == 456 + + # 2. Update with classification results + mock_cursor.rowcount = 1 + + classification_data = { + "car_brand": "Toyota", + "car_model": "Camry", + "car_body_type": "Sedan", + "car_color": "Blue", + "brand_confidence": 0.87, + "bodytype_confidence": 0.82, + "color_confidence": 0.79, + "updated_at": "NOW()" + } + + where_conditions = {"session_id": detection_data["session_id"]} + + rows_updated = await manager.update_record("car_detections", classification_data, where_conditions) + assert rows_updated == 1 + + # 3. Query final results + mock_cursor.fetchall.return_value = [ + (456, detection_data["session_id"], "camera_001", "car", 0.92, "Toyota", "Sedan") + ] + + results = await manager.execute_query( + "SELECT id, session_id, camera_id, detection_class, confidence, car_brand, car_body_type " + "FROM gas_station_1.car_detections WHERE session_id = %s", + [detection_data["session_id"]] + ) + + assert len(results) == 1 + assert results[0][0] == 456 # ID + assert results[0][3] == "car" # detection_class + assert results[0][5] == "Toyota" # car_brand + + # Verify all database operations were called + assert mock_cursor.execute.call_count == 3 + assert mock_database_connection.commit.call_count == 2 + + @pytest.mark.asyncio + async def test_error_handling_and_recovery(self, mock_database_connection): + """Test error handling and recovery scenarios.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + manager.connection = DatabaseConnection(config, mock_database_connection) + + # Test transaction rollback on error + mock_cursor = mock_database_connection.cursor.return_value + + with pytest.raises(DatabaseError): + async with manager.transaction(): + # First operation succeeds + await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["John"]) + + # Second operation fails + mock_cursor.execute.side_effect = psycopg2.Error("Constraint violation") + await manager.execute_query("INSERT INTO users (name) VALUES (%s)", ["Jane"]) + + # Should have rolled back + mock_database_connection.rollback.assert_called_once() + mock_database_connection.commit.assert_not_called() + + @pytest.mark.asyncio + async def test_connection_recovery(self): + """Test automatic connection recovery.""" + config = DatabaseConfig( + host="localhost", + database="test_db", + username="test_user", + password="test_pass" + ) + + manager = DatabaseManager(config) + + with patch('psycopg2.connect') as mock_connect: + # First connection attempt fails + mock_connect.side_effect = [ + psycopg2.Error("Connection refused"), + Mock() # Second attempt succeeds + ] + + # First attempt should fail + with pytest.raises(DatabaseError): + await manager.connect() + + # Second attempt should succeed + await manager.connect() + assert manager.is_connected is True \ No newline at end of file diff --git a/tests/unit/storage/test_redis_client.py b/tests/unit/storage/test_redis_client.py new file mode 100644 index 0000000..b4fd1f5 --- /dev/null +++ b/tests/unit/storage/test_redis_client.py @@ -0,0 +1,964 @@ +""" +Unit tests for Redis client functionality. +""" +import pytest +import asyncio +import json +import base64 +import time +from unittest.mock import Mock, MagicMock, patch, AsyncMock +from datetime import datetime, timedelta +import redis +import numpy as np + +from detector_worker.storage.redis_client import ( + RedisClient, + RedisConfig, + RedisConnectionPool, + RedisPublisher, + RedisSubscriber, + RedisImageStorage, + RedisError, + ConnectionPoolError +) +from detector_worker.detection.detection_result import DetectionResult, BoundingBox +from detector_worker.core.exceptions import ConfigurationError + + +class TestRedisConfig: + """Test Redis configuration.""" + + def test_creation_minimal(self): + """Test creating Redis config with minimal parameters.""" + config = RedisConfig( + host="localhost" + ) + + assert config.host == "localhost" + assert config.port == 6379 # Default port + assert config.password is None + assert config.db == 0 # Default database + assert config.enabled is True + + def test_creation_full(self): + """Test creating Redis config with all parameters.""" + config = RedisConfig( + host="redis.example.com", + port=6380, + password="secure_pass", + db=2, + enabled=True, + connection_timeout=5.0, + socket_timeout=3.0, + socket_connect_timeout=2.0, + max_connections=50, + retry_on_timeout=True, + health_check_interval=30 + ) + + assert config.host == "redis.example.com" + assert config.port == 6380 + assert config.password == "secure_pass" + assert config.db == 2 + assert config.connection_timeout == 5.0 + assert config.max_connections == 50 + assert config.retry_on_timeout is True + + def test_get_connection_params(self): + """Test getting Redis connection parameters.""" + config = RedisConfig( + host="localhost", + port=6379, + password="test_pass", + db=1, + connection_timeout=10.0 + ) + + params = config.get_connection_params() + + assert params["host"] == "localhost" + assert params["port"] == 6379 + assert params["password"] == "test_pass" + assert params["db"] == 1 + assert params["socket_timeout"] == 10.0 + + def test_from_dict(self): + """Test creating config from dictionary.""" + config_dict = { + "host": "redis-server", + "port": 6380, + "password": "secret", + "db": 3, + "max_connections": 100, + "unknown_field": "ignored" + } + + config = RedisConfig.from_dict(config_dict) + + assert config.host == "redis-server" + assert config.port == 6380 + assert config.password == "secret" + assert config.db == 3 + assert config.max_connections == 100 + + +class TestRedisConnectionPool: + """Test Redis connection pool management.""" + + def test_creation(self): + """Test connection pool creation.""" + config = RedisConfig( + host="localhost", + max_connections=20 + ) + + pool = RedisConnectionPool(config) + + assert pool.config == config + assert pool.pool is None + assert pool.is_connected is False + + @pytest.mark.asyncio + async def test_connect_success(self): + """Test successful connection to Redis.""" + config = RedisConfig(host="localhost") + pool = RedisConnectionPool(config) + + with patch('redis.ConnectionPool') as mock_pool_class: + mock_pool = Mock() + mock_pool_class.return_value = mock_pool + + with patch('redis.Redis') as mock_redis_class: + mock_redis = Mock() + mock_redis.ping.return_value = True + mock_redis_class.return_value = mock_redis + + await pool.connect() + + assert pool.is_connected is True + assert pool.pool is not None + mock_pool_class.assert_called_once() + + @pytest.mark.asyncio + async def test_connect_failure(self): + """Test Redis connection failure.""" + config = RedisConfig(host="nonexistent-redis") + pool = RedisConnectionPool(config) + + with patch('redis.ConnectionPool') as mock_pool_class: + mock_pool_class.side_effect = redis.ConnectionError("Connection failed") + + with pytest.raises(RedisError) as exc_info: + await pool.connect() + + assert "Connection failed" in str(exc_info.value) + assert pool.is_connected is False + + @pytest.mark.asyncio + async def test_disconnect(self): + """Test Redis disconnection.""" + config = RedisConfig(host="localhost") + pool = RedisConnectionPool(config) + + # Mock connected state + mock_pool = Mock() + mock_redis = Mock() + pool.pool = mock_pool + pool._redis_client = mock_redis + pool.is_connected = True + + await pool.disconnect() + + assert pool.is_connected is False + assert pool.pool is None + mock_pool.disconnect.assert_called_once() + + def test_get_client_connected(self): + """Test getting Redis client when connected.""" + config = RedisConfig(host="localhost") + pool = RedisConnectionPool(config) + + mock_pool = Mock() + mock_redis = Mock() + pool.pool = mock_pool + pool._redis_client = mock_redis + pool.is_connected = True + + client = pool.get_client() + assert client == mock_redis + + def test_get_client_not_connected(self): + """Test getting Redis client when not connected.""" + config = RedisConfig(host="localhost") + pool = RedisConnectionPool(config) + + with pytest.raises(RedisError) as exc_info: + pool.get_client() + + assert "not connected" in str(exc_info.value).lower() + + def test_health_check(self): + """Test Redis health check.""" + config = RedisConfig(host="localhost") + pool = RedisConnectionPool(config) + + mock_redis = Mock() + mock_redis.ping.return_value = True + pool._redis_client = mock_redis + pool.is_connected = True + + is_healthy = pool.health_check() + + assert is_healthy is True + mock_redis.ping.assert_called_once() + + def test_health_check_failure(self): + """Test Redis health check failure.""" + config = RedisConfig(host="localhost") + pool = RedisConnectionPool(config) + + mock_redis = Mock() + mock_redis.ping.side_effect = redis.ConnectionError("Connection lost") + pool._redis_client = mock_redis + pool.is_connected = True + + is_healthy = pool.health_check() + + assert is_healthy is False + + +class TestRedisImageStorage: + """Test Redis image storage functionality.""" + + def test_creation(self, mock_redis_client): + """Test Redis image storage creation.""" + storage = RedisImageStorage(mock_redis_client) + + assert storage.redis_client == mock_redis_client + assert storage.default_expiry == 3600 # 1 hour + assert storage.compression_enabled is True + + @pytest.mark.asyncio + async def test_store_image_success(self, mock_redis_client, mock_frame): + """Test successful image storage.""" + storage = RedisImageStorage(mock_redis_client) + + mock_redis_client.set.return_value = True + mock_redis_client.expire.return_value = True + + with patch('cv2.imencode') as mock_imencode: + # Mock successful encoding + encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + result = await storage.store_image("test_key", mock_frame, expire_seconds=600) + + assert result is True + mock_redis_client.set.assert_called_once() + mock_redis_client.expire.assert_called_once_with("test_key", 600) + mock_imencode.assert_called_once() + + @pytest.mark.asyncio + async def test_store_image_cropped(self, mock_redis_client, mock_frame): + """Test storing cropped image.""" + storage = RedisImageStorage(mock_redis_client) + + mock_redis_client.set.return_value = True + mock_redis_client.expire.return_value = True + + bbox = BoundingBox(x1=100, y1=200, x2=300, y2=400) + + with patch('cv2.imencode') as mock_imencode: + encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + result = await storage.store_image("cropped_key", mock_frame, crop_bbox=bbox) + + assert result is True + mock_redis_client.set.assert_called_once() + + @pytest.mark.asyncio + async def test_store_image_encoding_failure(self, mock_redis_client, mock_frame): + """Test image storage with encoding failure.""" + storage = RedisImageStorage(mock_redis_client) + + with patch('cv2.imencode') as mock_imencode: + # Mock encoding failure + mock_imencode.return_value = (False, None) + + with pytest.raises(RedisError) as exc_info: + await storage.store_image("test_key", mock_frame) + + assert "Failed to encode image" in str(exc_info.value) + mock_redis_client.set.assert_not_called() + + @pytest.mark.asyncio + async def test_store_image_redis_failure(self, mock_redis_client, mock_frame): + """Test image storage with Redis failure.""" + storage = RedisImageStorage(mock_redis_client) + + mock_redis_client.set.side_effect = redis.RedisError("Redis error") + + with patch('cv2.imencode') as mock_imencode: + encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + with pytest.raises(RedisError) as exc_info: + await storage.store_image("test_key", mock_frame) + + assert "Redis error" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_retrieve_image_success(self, mock_redis_client): + """Test successful image retrieval.""" + storage = RedisImageStorage(mock_redis_client) + + # Mock encoded image data + original_image = np.ones((100, 100, 3), dtype=np.uint8) * 128 + + with patch('cv2.imencode') as mock_imencode: + encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + # Mock Redis returning base64 encoded data + base64_data = base64.b64encode(encoded_data.tobytes()).decode('utf-8') + mock_redis_client.get.return_value = base64_data + + with patch('cv2.imdecode') as mock_imdecode: + mock_imdecode.return_value = original_image + + retrieved_image = await storage.retrieve_image("test_key") + + assert retrieved_image is not None + assert retrieved_image.shape == (100, 100, 3) + mock_redis_client.get.assert_called_once_with("test_key") + + @pytest.mark.asyncio + async def test_retrieve_image_not_found(self, mock_redis_client): + """Test image retrieval when key not found.""" + storage = RedisImageStorage(mock_redis_client) + + mock_redis_client.get.return_value = None + + retrieved_image = await storage.retrieve_image("nonexistent_key") + + assert retrieved_image is None + mock_redis_client.get.assert_called_once_with("nonexistent_key") + + @pytest.mark.asyncio + async def test_delete_image(self, mock_redis_client): + """Test image deletion.""" + storage = RedisImageStorage(mock_redis_client) + + mock_redis_client.delete.return_value = 1 + + result = await storage.delete_image("test_key") + + assert result is True + mock_redis_client.delete.assert_called_once_with("test_key") + + @pytest.mark.asyncio + async def test_delete_image_not_found(self, mock_redis_client): + """Test deleting non-existent image.""" + storage = RedisImageStorage(mock_redis_client) + + mock_redis_client.delete.return_value = 0 + + result = await storage.delete_image("nonexistent_key") + + assert result is False + mock_redis_client.delete.assert_called_once_with("nonexistent_key") + + @pytest.mark.asyncio + async def test_bulk_delete_images(self, mock_redis_client): + """Test bulk image deletion.""" + storage = RedisImageStorage(mock_redis_client) + + keys = ["key1", "key2", "key3"] + mock_redis_client.delete.return_value = 3 + + deleted_count = await storage.bulk_delete_images(keys) + + assert deleted_count == 3 + mock_redis_client.delete.assert_called_once_with(*keys) + + @pytest.mark.asyncio + async def test_cleanup_expired_images(self, mock_redis_client): + """Test cleanup of expired images.""" + storage = RedisImageStorage(mock_redis_client) + + # Mock scan to return image keys + mock_redis_client.scan_iter.return_value = [ + b"inference:camera1:image1", + b"inference:camera2:image2", + b"inference:camera1:image3" + ] + + # Mock ttl to return different expiry times + mock_redis_client.ttl.side_effect = [-1, 100, -2] # No expiry, valid, expired + mock_redis_client.delete.return_value = 1 + + deleted_count = await storage.cleanup_expired_images("inference:*") + + assert deleted_count == 1 # Only expired images deleted + mock_redis_client.delete.assert_called_once() + + def test_get_image_info(self, mock_redis_client): + """Test getting image metadata.""" + storage = RedisImageStorage(mock_redis_client) + + mock_redis_client.exists.return_value = 1 + mock_redis_client.ttl.return_value = 1800 # 30 minutes + mock_redis_client.memory_usage.return_value = 4096 # 4KB + + info = storage.get_image_info("test_key") + + assert info["exists"] is True + assert info["ttl"] == 1800 + assert info["size_bytes"] == 4096 + + mock_redis_client.exists.assert_called_once_with("test_key") + mock_redis_client.ttl.assert_called_once_with("test_key") + + +class TestRedisPublisher: + """Test Redis publisher functionality.""" + + def test_creation(self, mock_redis_client): + """Test Redis publisher creation.""" + publisher = RedisPublisher(mock_redis_client) + + assert publisher.redis_client == mock_redis_client + + @pytest.mark.asyncio + async def test_publish_message_string(self, mock_redis_client): + """Test publishing string message.""" + publisher = RedisPublisher(mock_redis_client) + + mock_redis_client.publish.return_value = 2 # 2 subscribers + + result = await publisher.publish("test_channel", "Hello, Redis!") + + assert result == 2 + mock_redis_client.publish.assert_called_once_with("test_channel", "Hello, Redis!") + + @pytest.mark.asyncio + async def test_publish_message_json(self, mock_redis_client): + """Test publishing JSON message.""" + publisher = RedisPublisher(mock_redis_client) + + mock_redis_client.publish.return_value = 1 + + message_data = { + "camera_id": "camera_001", + "detection_class": "car", + "confidence": 0.95, + "timestamp": 1640995200000 + } + + result = await publisher.publish("detections", message_data) + + assert result == 1 + + # Should have been JSON serialized + expected_json = json.dumps(message_data) + mock_redis_client.publish.assert_called_once_with("detections", expected_json) + + @pytest.mark.asyncio + async def test_publish_detection_event(self, mock_redis_client): + """Test publishing detection event.""" + publisher = RedisPublisher(mock_redis_client) + + mock_redis_client.publish.return_value = 3 + + detection = DetectionResult("car", 0.92, BoundingBox(100, 200, 300, 400), 1001, 1640995200000) + + result = await publisher.publish_detection_event( + "camera_detections", + detection, + camera_id="camera_001", + session_id="session_123" + ) + + assert result == 3 + + # Verify the published message structure + call_args = mock_redis_client.publish.call_args + channel = call_args[0][0] + message_str = call_args[0][1] + message_data = json.loads(message_str) + + assert channel == "camera_detections" + assert message_data["event_type"] == "detection" + assert message_data["camera_id"] == "camera_001" + assert message_data["session_id"] == "session_123" + assert message_data["detection"]["class"] == "car" + assert message_data["detection"]["confidence"] == 0.92 + + @pytest.mark.asyncio + async def test_publish_batch_messages(self, mock_redis_client): + """Test publishing multiple messages in batch.""" + publisher = RedisPublisher(mock_redis_client) + + mock_pipeline = Mock() + mock_redis_client.pipeline.return_value = mock_pipeline + mock_pipeline.execute.return_value = [1, 2, 1] # Subscriber counts + + messages = [ + ("channel1", "message1"), + ("channel2", {"data": "message2"}), + ("channel1", "message3") + ] + + results = await publisher.publish_batch(messages) + + assert results == [1, 2, 1] + mock_redis_client.pipeline.assert_called_once() + assert mock_pipeline.publish.call_count == 3 + mock_pipeline.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_publish_error_handling(self, mock_redis_client): + """Test error handling in publishing.""" + publisher = RedisPublisher(mock_redis_client) + + mock_redis_client.publish.side_effect = redis.RedisError("Publish failed") + + with pytest.raises(RedisError) as exc_info: + await publisher.publish("test_channel", "test_message") + + assert "Publish failed" in str(exc_info.value) + + +class TestRedisSubscriber: + """Test Redis subscriber functionality.""" + + def test_creation(self, mock_redis_client): + """Test Redis subscriber creation.""" + subscriber = RedisSubscriber(mock_redis_client) + + assert subscriber.redis_client == mock_redis_client + assert subscriber.pubsub is None + assert subscriber.subscriptions == set() + + @pytest.mark.asyncio + async def test_subscribe_to_channel(self, mock_redis_client): + """Test subscribing to a channel.""" + subscriber = RedisSubscriber(mock_redis_client) + + mock_pubsub = Mock() + mock_redis_client.pubsub.return_value = mock_pubsub + + await subscriber.subscribe("test_channel") + + assert "test_channel" in subscriber.subscriptions + mock_pubsub.subscribe.assert_called_once_with("test_channel") + + @pytest.mark.asyncio + async def test_subscribe_to_pattern(self, mock_redis_client): + """Test subscribing to a pattern.""" + subscriber = RedisSubscriber(mock_redis_client) + + mock_pubsub = Mock() + mock_redis_client.pubsub.return_value = mock_pubsub + + await subscriber.subscribe_pattern("detection:*") + + assert "detection:*" in subscriber.subscriptions + mock_pubsub.psubscribe.assert_called_once_with("detection:*") + + @pytest.mark.asyncio + async def test_unsubscribe_from_channel(self, mock_redis_client): + """Test unsubscribing from a channel.""" + subscriber = RedisSubscriber(mock_redis_client) + + mock_pubsub = Mock() + mock_redis_client.pubsub.return_value = mock_pubsub + subscriber.pubsub = mock_pubsub + subscriber.subscriptions.add("test_channel") + + await subscriber.unsubscribe("test_channel") + + assert "test_channel" not in subscriber.subscriptions + mock_pubsub.unsubscribe.assert_called_once_with("test_channel") + + @pytest.mark.asyncio + async def test_listen_for_messages(self, mock_redis_client): + """Test listening for messages.""" + subscriber = RedisSubscriber(mock_redis_client) + + mock_pubsub = Mock() + mock_redis_client.pubsub.return_value = mock_pubsub + + # Mock message stream + messages = [ + {"type": "subscribe", "channel": "test", "data": 1}, + {"type": "message", "channel": "test", "data": "Hello"}, + {"type": "message", "channel": "test", "data": '{"key": "value"}'} + ] + + mock_pubsub.listen.return_value = iter(messages) + + received_messages = [] + message_count = 0 + + async for message in subscriber.listen(): + received_messages.append(message) + message_count += 1 + if message_count >= 2: # Only process actual messages + break + + # Should receive 2 actual messages (excluding subscribe confirmation) + assert len(received_messages) == 2 + assert received_messages[0]["data"] == "Hello" + assert received_messages[1]["data"] == {"key": "value"} # Should be parsed as JSON + + @pytest.mark.asyncio + async def test_close_subscription(self, mock_redis_client): + """Test closing subscription.""" + subscriber = RedisSubscriber(mock_redis_client) + + mock_pubsub = Mock() + subscriber.pubsub = mock_pubsub + subscriber.subscriptions = {"channel1", "pattern:*"} + + await subscriber.close() + + assert len(subscriber.subscriptions) == 0 + mock_pubsub.close.assert_called_once() + assert subscriber.pubsub is None + + +class TestRedisClient: + """Test main Redis client functionality.""" + + def test_initialization(self): + """Test Redis client initialization.""" + config = RedisConfig(host="localhost", port=6379) + client = RedisClient(config) + + assert client.config == config + assert isinstance(client.connection_pool, RedisConnectionPool) + assert client.image_storage is None + assert client.publisher is None + assert client.subscriber is None + + @pytest.mark.asyncio + async def test_connect_and_initialize_components(self): + """Test connecting and initializing all components.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect: + mock_redis_client = Mock() + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + + await client.connect() + + assert client.image_storage is not None + assert client.publisher is not None + assert client.subscriber is not None + assert isinstance(client.image_storage, RedisImageStorage) + assert isinstance(client.publisher, RedisPublisher) + assert isinstance(client.subscriber, RedisSubscriber) + + mock_connect.assert_called_once() + + @pytest.mark.asyncio + async def test_disconnect(self): + """Test disconnection.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.is_connected = True + client.subscriber = Mock() + client.subscriber.close = AsyncMock() + + with patch.object(client.connection_pool, 'disconnect', new_callable=AsyncMock) as mock_disconnect: + await client.disconnect() + + client.subscriber.close.assert_called_once() + mock_disconnect.assert_called_once() + + assert client.image_storage is None + assert client.publisher is None + assert client.subscriber is None + + @pytest.mark.asyncio + async def test_store_and_retrieve_data(self, mock_redis_client): + """Test storing and retrieving data.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + + # Test storing data + mock_redis_client.set.return_value = True + result = await client.set("test_key", "test_value", expire_seconds=300) + assert result is True + mock_redis_client.set.assert_called_once_with("test_key", "test_value") + mock_redis_client.expire.assert_called_once_with("test_key", 300) + + # Test retrieving data + mock_redis_client.get.return_value = "test_value" + value = await client.get("test_key") + assert value == "test_value" + mock_redis_client.get.assert_called_once_with("test_key") + + @pytest.mark.asyncio + async def test_delete_keys(self, mock_redis_client): + """Test deleting keys.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + + mock_redis_client.delete.return_value = 2 + + result = await client.delete("key1", "key2") + + assert result == 2 + mock_redis_client.delete.assert_called_once_with("key1", "key2") + + @pytest.mark.asyncio + async def test_exists_check(self, mock_redis_client): + """Test checking key existence.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + + mock_redis_client.exists.return_value = 1 + + exists = await client.exists("test_key") + + assert exists is True + mock_redis_client.exists.assert_called_once_with("test_key") + + @pytest.mark.asyncio + async def test_expire_key(self, mock_redis_client): + """Test setting key expiration.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + + mock_redis_client.expire.return_value = True + + result = await client.expire("test_key", 600) + + assert result is True + mock_redis_client.expire.assert_called_once_with("test_key", 600) + + @pytest.mark.asyncio + async def test_get_ttl(self, mock_redis_client): + """Test getting key TTL.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + + mock_redis_client.ttl.return_value = 300 + + ttl = await client.ttl("test_key") + + assert ttl == 300 + mock_redis_client.ttl.assert_called_once_with("test_key") + + @pytest.mark.asyncio + async def test_scan_keys(self, mock_redis_client): + """Test scanning for keys.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + + mock_redis_client.scan_iter.return_value = [b"key1", b"key2", b"key3"] + + keys = await client.scan_keys("test:*") + + assert keys == ["key1", "key2", "key3"] + mock_redis_client.scan_iter.assert_called_once_with(match="test:*") + + @pytest.mark.asyncio + async def test_flush_database(self, mock_redis_client): + """Test flushing database.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + + mock_redis_client.flushdb.return_value = True + + result = await client.flush_db() + + assert result is True + mock_redis_client.flushdb.assert_called_once() + + def test_get_connection_info(self): + """Test getting connection information.""" + config = RedisConfig( + host="redis.example.com", + port=6380, + db=2 + ) + client = RedisClient(config) + client.connection_pool.is_connected = True + + info = client.get_connection_info() + + assert info["connected"] is True + assert info["host"] == "redis.example.com" + assert info["port"] == 6380 + assert info["database"] == 2 + + @pytest.mark.asyncio + async def test_pipeline_operations(self, mock_redis_client): + """Test Redis pipeline operations.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + + mock_pipeline = Mock() + mock_redis_client.pipeline.return_value = mock_pipeline + mock_pipeline.execute.return_value = [True, True, 1] + + async with client.pipeline() as pipe: + pipe.set("key1", "value1") + pipe.set("key2", "value2") + pipe.delete("key3") + results = await pipe.execute() + + assert results == [True, True, 1] + mock_redis_client.pipeline.assert_called_once() + mock_pipeline.execute.assert_called_once() + + +class TestRedisClientIntegration: + """Integration tests for Redis client.""" + + @pytest.mark.asyncio + async def test_complete_image_workflow(self, mock_redis_client, mock_frame): + """Test complete image storage workflow.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state and components + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + client.image_storage = RedisImageStorage(mock_redis_client) + client.publisher = RedisPublisher(mock_redis_client) + + # Mock Redis operations + mock_redis_client.set.return_value = True + mock_redis_client.expire.return_value = True + mock_redis_client.publish.return_value = 2 + + with patch('cv2.imencode') as mock_imencode: + encoded_data = np.array([1, 2, 3, 4], dtype=np.uint8) + mock_imencode.return_value = (True, encoded_data) + + # Store image + store_result = await client.image_storage.store_image( + "detection:camera001:1640995200:session123", + mock_frame, + expire_seconds=600 + ) + + # Publish detection event + detection_event = { + "camera_id": "camera001", + "session_id": "session123", + "detection_class": "car", + "confidence": 0.95, + "timestamp": 1640995200000 + } + + publish_result = await client.publisher.publish("detections:camera001", detection_event) + + assert store_result is True + assert publish_result == 2 + + # Verify Redis operations + mock_redis_client.set.assert_called_once() + mock_redis_client.expire.assert_called_once() + mock_redis_client.publish.assert_called_once() + + @pytest.mark.asyncio + async def test_error_recovery_and_reconnection(self): + """Test error recovery and reconnection.""" + config = RedisConfig(host="localhost", retry_on_timeout=True) + client = RedisClient(config) + + with patch.object(client.connection_pool, 'connect', new_callable=AsyncMock) as mock_connect: + with patch.object(client.connection_pool, 'health_check') as mock_health_check: + # First health check fails, second succeeds + mock_health_check.side_effect = [False, True] + + # First connection attempt fails, second succeeds + mock_connect.side_effect = [RedisError("Connection failed"), None] + + # Simulate connection recovery + try: + await client.connect() + except RedisError: + # Retry connection + await client.connect() + + assert mock_connect.call_count == 2 + + @pytest.mark.asyncio + async def test_bulk_operations_performance(self, mock_redis_client): + """Test bulk operations for performance.""" + config = RedisConfig(host="localhost") + client = RedisClient(config) + + # Mock connected state + client.connection_pool.get_client = Mock(return_value=mock_redis_client) + client.connection_pool.is_connected = True + client.publisher = RedisPublisher(mock_redis_client) + + # Mock pipeline operations + mock_pipeline = Mock() + mock_redis_client.pipeline.return_value = mock_pipeline + mock_pipeline.execute.return_value = [1] * 100 # 100 successful operations + + # Prepare bulk messages + messages = [ + (f"channel_{i}", f"message_{i}") + for i in range(100) + ] + + start_time = time.time() + results = await client.publisher.publish_batch(messages) + execution_time = time.time() - start_time + + assert len(results) == 100 + assert all(result == 1 for result in results) + + # Should be faster than individual operations + assert execution_time < 1.0 # Should complete in less than 1 second + + # Pipeline should be used for efficiency + mock_redis_client.pipeline.assert_called_once() + assert mock_pipeline.publish.call_count == 100 + mock_pipeline.execute.assert_called_once() \ No newline at end of file diff --git a/tests/unit/storage/test_session_cache.py b/tests/unit/storage/test_session_cache.py new file mode 100644 index 0000000..1606f3b --- /dev/null +++ b/tests/unit/storage/test_session_cache.py @@ -0,0 +1,883 @@ +""" +Unit tests for session cache management. +""" +import pytest +import time +import uuid +from unittest.mock import Mock, patch +from datetime import datetime, timedelta +from collections import defaultdict + +from detector_worker.storage.session_cache import ( + SessionCache, + SessionCacheManager, + SessionData, + CacheConfig, + CacheEntry, + CacheStats, + SessionError, + CacheError +) +from detector_worker.detection.detection_result import DetectionResult, BoundingBox + + +class TestCacheConfig: + """Test cache configuration.""" + + def test_creation_default(self): + """Test creating cache config with default values.""" + config = CacheConfig() + + assert config.max_size == 1000 + assert config.ttl_seconds == 3600 # 1 hour + assert config.cleanup_interval == 300 # 5 minutes + assert config.eviction_policy == "lru" + assert config.enable_persistence is False + + def test_creation_custom(self): + """Test creating cache config with custom values.""" + config = CacheConfig( + max_size=5000, + ttl_seconds=7200, + cleanup_interval=600, + eviction_policy="lfu", + enable_persistence=True, + persistence_path="/tmp/cache" + ) + + assert config.max_size == 5000 + assert config.ttl_seconds == 7200 + assert config.cleanup_interval == 600 + assert config.eviction_policy == "lfu" + assert config.enable_persistence is True + assert config.persistence_path == "/tmp/cache" + + def test_from_dict(self): + """Test creating config from dictionary.""" + config_dict = { + "max_size": 2000, + "ttl_seconds": 1800, + "eviction_policy": "fifo", + "enable_persistence": True, + "unknown_field": "ignored" + } + + config = CacheConfig.from_dict(config_dict) + + assert config.max_size == 2000 + assert config.ttl_seconds == 1800 + assert config.eviction_policy == "fifo" + assert config.enable_persistence is True + + +class TestCacheEntry: + """Test cache entry data structure.""" + + def test_creation(self): + """Test cache entry creation.""" + data = {"key": "value", "number": 42} + entry = CacheEntry(data, ttl_seconds=600) + + assert entry.data == data + assert entry.ttl_seconds == 600 + assert entry.created_at <= time.time() + assert entry.last_accessed <= time.time() + assert entry.access_count == 1 + assert entry.size > 0 + + def test_is_expired(self): + """Test expiration checking.""" + # Non-expired entry + entry = CacheEntry({"data": "test"}, ttl_seconds=600) + assert entry.is_expired() is False + + # Expired entry (simulate by setting old creation time) + entry.created_at = time.time() - 700 # Created 700 seconds ago + assert entry.is_expired() is True + + # Entry without expiration + entry_no_ttl = CacheEntry({"data": "test"}) + assert entry_no_ttl.is_expired() is False + + def test_touch(self): + """Test updating access time and count.""" + entry = CacheEntry({"data": "test"}) + + original_access_time = entry.last_accessed + original_access_count = entry.access_count + + time.sleep(0.01) # Small delay + entry.touch() + + assert entry.last_accessed > original_access_time + assert entry.access_count == original_access_count + 1 + + def test_age(self): + """Test age calculation.""" + entry = CacheEntry({"data": "test"}) + + time.sleep(0.01) # Small delay + age = entry.age() + + assert age > 0 + assert age < 1 # Should be less than 1 second + + def test_size_estimation(self): + """Test size estimation.""" + small_entry = CacheEntry({"key": "value"}) + large_entry = CacheEntry({"key": "x" * 1000, "data": list(range(100))}) + + assert large_entry.size > small_entry.size + + +class TestSessionData: + """Test session data structure.""" + + def test_creation(self): + """Test session data creation.""" + session_data = SessionData( + session_id="session_123", + camera_id="camera_001", + display_id="display_001" + ) + + assert session_data.session_id == "session_123" + assert session_data.camera_id == "camera_001" + assert session_data.display_id == "display_001" + assert session_data.created_at <= time.time() + assert session_data.last_activity <= time.time() + assert session_data.detection_data == {} + assert session_data.metadata == {} + + def test_update_activity(self): + """Test updating last activity.""" + session_data = SessionData("session_123", "camera_001", "display_001") + + original_activity = session_data.last_activity + time.sleep(0.01) + session_data.update_activity() + + assert session_data.last_activity > original_activity + + def test_add_detection_data(self): + """Test adding detection data.""" + session_data = SessionData("session_123", "camera_001", "display_001") + + detection_data = { + "class": "car", + "confidence": 0.95, + "bbox": [100, 200, 300, 400] + } + + session_data.add_detection_data("main_detection", detection_data) + + assert "main_detection" in session_data.detection_data + assert session_data.detection_data["main_detection"] == detection_data + + def test_add_metadata(self): + """Test adding metadata.""" + session_data = SessionData("session_123", "camera_001", "display_001") + + session_data.add_metadata("model_version", "v2.1") + session_data.add_metadata("inference_time", 0.15) + + assert session_data.metadata["model_version"] == "v2.1" + assert session_data.metadata["inference_time"] == 0.15 + + def test_is_expired(self): + """Test session expiration.""" + session_data = SessionData("session_123", "camera_001", "display_001") + + # Not expired with default timeout + assert session_data.is_expired() is False + + # Expired with short timeout + assert session_data.is_expired(timeout_seconds=0.001) is True + + def test_to_dict(self): + """Test converting session to dictionary.""" + session_data = SessionData("session_123", "camera_001", "display_001") + session_data.add_detection_data("detection", {"class": "car", "confidence": 0.9}) + session_data.add_metadata("model_id", "yolo_v8") + + data_dict = session_data.to_dict() + + assert data_dict["session_id"] == "session_123" + assert data_dict["camera_id"] == "camera_001" + assert data_dict["detection_data"]["detection"]["class"] == "car" + assert data_dict["metadata"]["model_id"] == "yolo_v8" + assert "created_at" in data_dict + assert "last_activity" in data_dict + + +class TestCacheStats: + """Test cache statistics.""" + + def test_creation(self): + """Test cache stats creation.""" + stats = CacheStats() + + assert stats.hits == 0 + assert stats.misses == 0 + assert stats.evictions == 0 + assert stats.size == 0 + assert stats.memory_usage == 0 + + def test_hit_rate_calculation(self): + """Test hit rate calculation.""" + stats = CacheStats() + + # No requests yet + assert stats.hit_rate() == 0.0 + + # Some hits and misses + stats.hits = 8 + stats.misses = 2 + + assert stats.hit_rate() == 0.8 # 8 / (8 + 2) + + def test_total_requests(self): + """Test total requests calculation.""" + stats = CacheStats() + + stats.hits = 15 + stats.misses = 5 + + assert stats.total_requests() == 20 + + +class TestSessionCache: + """Test session cache functionality.""" + + def test_creation(self): + """Test session cache creation.""" + config = CacheConfig(max_size=100, ttl_seconds=300) + cache = SessionCache(config) + + assert cache.config == config + assert cache.max_size == 100 + assert cache.ttl_seconds == 300 + assert len(cache._cache) == 0 + assert len(cache._access_order) == 0 + + def test_put_and_get_session(self): + """Test putting and getting session data.""" + cache = SessionCache(CacheConfig(max_size=10)) + + session_data = SessionData("session_123", "camera_001", "display_001") + session_data.add_detection_data("main", {"class": "car", "confidence": 0.9}) + + # Put session + cache.put("session_123", session_data) + + # Get session + retrieved_data = cache.get("session_123") + + assert retrieved_data is not None + assert retrieved_data.session_id == "session_123" + assert retrieved_data.camera_id == "camera_001" + assert "main" in retrieved_data.detection_data + + def test_get_nonexistent_session(self): + """Test getting non-existent session.""" + cache = SessionCache(CacheConfig(max_size=10)) + + result = cache.get("nonexistent_session") + + assert result is None + + def test_contains_check(self): + """Test checking if session exists.""" + cache = SessionCache(CacheConfig(max_size=10)) + + session_data = SessionData("session_123", "camera_001", "display_001") + cache.put("session_123", session_data) + + assert cache.contains("session_123") is True + assert cache.contains("nonexistent_session") is False + + def test_remove_session(self): + """Test removing session.""" + cache = SessionCache(CacheConfig(max_size=10)) + + session_data = SessionData("session_123", "camera_001", "display_001") + cache.put("session_123", session_data) + + assert cache.contains("session_123") is True + + removed_data = cache.remove("session_123") + + assert removed_data is not None + assert removed_data.session_id == "session_123" + assert cache.contains("session_123") is False + + def test_size_tracking(self): + """Test cache size tracking.""" + cache = SessionCache(CacheConfig(max_size=10)) + + assert cache.size() == 0 + assert cache.is_empty() is True + + # Add sessions + for i in range(3): + session_data = SessionData(f"session_{i}", "camera_001", "display_001") + cache.put(f"session_{i}", session_data) + + assert cache.size() == 3 + assert cache.is_empty() is False + + def test_lru_eviction(self): + """Test LRU eviction policy.""" + cache = SessionCache(CacheConfig(max_size=3, eviction_policy="lru")) + + # Fill cache to capacity + for i in range(3): + session_data = SessionData(f"session_{i}", "camera_001", "display_001") + cache.put(f"session_{i}", session_data) + + # Access session_1 to make it recently used + cache.get("session_1") + + # Add another session (should evict session_0, the least recently used) + new_session = SessionData("session_3", "camera_001", "display_001") + cache.put("session_3", new_session) + + assert cache.size() == 3 + assert cache.contains("session_0") is False # Evicted + assert cache.contains("session_1") is True # Recently accessed + assert cache.contains("session_2") is True + assert cache.contains("session_3") is True # Newly added + + def test_ttl_expiration(self): + """Test TTL-based expiration.""" + cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1)) # 100ms TTL + + session_data = SessionData("session_123", "camera_001", "display_001") + cache.put("session_123", session_data) + + # Should exist immediately + assert cache.contains("session_123") is True + + # Wait for expiration + time.sleep(0.2) + + # Should be expired (but might still be in cache until cleanup) + entry = cache._cache.get("session_123") + if entry: + assert entry.is_expired() is True + + # Getting expired entry should return None and clean it up + retrieved = cache.get("session_123") + assert retrieved is None + assert cache.contains("session_123") is False + + def test_cleanup_expired_entries(self): + """Test cleanup of expired entries.""" + cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1)) + + # Add multiple sessions + for i in range(3): + session_data = SessionData(f"session_{i}", "camera_001", "display_001") + cache.put(f"session_{i}", session_data) + + assert cache.size() == 3 + + # Wait for expiration + time.sleep(0.2) + + # Cleanup expired entries + cleaned_count = cache.cleanup_expired() + + assert cleaned_count == 3 + assert cache.size() == 0 + + def test_clear_cache(self): + """Test clearing entire cache.""" + cache = SessionCache(CacheConfig(max_size=10)) + + # Add sessions + for i in range(5): + session_data = SessionData(f"session_{i}", "camera_001", "display_001") + cache.put(f"session_{i}", session_data) + + assert cache.size() == 5 + + cache.clear() + + assert cache.size() == 0 + assert cache.is_empty() is True + + def test_get_all_sessions(self): + """Test getting all sessions.""" + cache = SessionCache(CacheConfig(max_size=10)) + + sessions = [] + for i in range(3): + session_data = SessionData(f"session_{i}", f"camera_{i}", "display_001") + cache.put(f"session_{i}", session_data) + sessions.append(session_data) + + all_sessions = cache.get_all() + + assert len(all_sessions) == 3 + for session_id, session_data in all_sessions.items(): + assert session_id.startswith("session_") + assert session_data.session_id == session_id + + def test_get_sessions_by_camera(self): + """Test getting sessions by camera ID.""" + cache = SessionCache(CacheConfig(max_size=10)) + + # Add sessions for different cameras + for i in range(2): + session_data1 = SessionData(f"session_cam1_{i}", "camera_001", "display_001") + session_data2 = SessionData(f"session_cam2_{i}", "camera_002", "display_001") + cache.put(f"session_cam1_{i}", session_data1) + cache.put(f"session_cam2_{i}", session_data2) + + camera1_sessions = cache.get_by_camera("camera_001") + camera2_sessions = cache.get_by_camera("camera_002") + + assert len(camera1_sessions) == 2 + assert len(camera2_sessions) == 2 + + for session_data in camera1_sessions: + assert session_data.camera_id == "camera_001" + + for session_data in camera2_sessions: + assert session_data.camera_id == "camera_002" + + def test_statistics_tracking(self): + """Test cache statistics tracking.""" + cache = SessionCache(CacheConfig(max_size=10)) + + session_data = SessionData("session_123", "camera_001", "display_001") + cache.put("session_123", session_data) + + # Cache miss + cache.get("nonexistent_session") + + # Cache hit + cache.get("session_123") + cache.get("session_123") # Another hit + + stats = cache.get_stats() + + assert stats.hits == 2 + assert stats.misses == 1 + assert stats.size == 1 + assert stats.hit_rate() == 2/3 # 2 hits out of 3 total requests + + def test_memory_usage_estimation(self): + """Test memory usage estimation.""" + cache = SessionCache(CacheConfig(max_size=10)) + + initial_memory = cache.get_memory_usage() + + # Add large session + session_data = SessionData("session_123", "camera_001", "display_001") + session_data.add_detection_data("large_data", {"data": "x" * 1000}) + cache.put("session_123", session_data) + + after_memory = cache.get_memory_usage() + + assert after_memory > initial_memory + + +class TestSessionCacheManager: + """Test session cache manager.""" + + def test_singleton_behavior(self): + """Test that SessionCacheManager is a singleton.""" + manager1 = SessionCacheManager() + manager2 = SessionCacheManager() + + assert manager1 is manager2 + + def test_initialization(self): + """Test session cache manager initialization.""" + manager = SessionCacheManager() + + assert manager.detection_cache is not None + assert manager.pipeline_cache is not None + assert manager.session_cache is not None + assert isinstance(manager.detection_cache, SessionCache) + assert isinstance(manager.pipeline_cache, SessionCache) + assert isinstance(manager.session_cache, SessionCache) + + def test_cache_detection_result(self): + """Test caching detection results.""" + manager = SessionCacheManager() + manager.clear_all() # Start fresh + + detection_data = { + "class": "car", + "confidence": 0.95, + "bbox": [100, 200, 300, 400], + "track_id": 1001 + } + + manager.cache_detection("camera_001", detection_data) + + cached_detection = manager.get_cached_detection("camera_001") + + assert cached_detection is not None + assert cached_detection["class"] == "car" + assert cached_detection["confidence"] == 0.95 + assert cached_detection["track_id"] == 1001 + + def test_cache_pipeline_result(self): + """Test caching pipeline results.""" + manager = SessionCacheManager() + manager.clear_all() + + pipeline_result = { + "status": "success", + "detections": [{"class": "car", "confidence": 0.9}], + "execution_time": 0.15, + "model_id": "yolo_v8" + } + + manager.cache_pipeline_result("camera_001", pipeline_result) + + cached_result = manager.get_cached_pipeline_result("camera_001") + + assert cached_result is not None + assert cached_result["status"] == "success" + assert cached_result["execution_time"] == 0.15 + assert len(cached_result["detections"]) == 1 + + def test_manage_session_data(self): + """Test session data management.""" + manager = SessionCacheManager() + manager.clear_all() + + session_id = str(uuid.uuid4()) + + # Create session + manager.create_session(session_id, "camera_001", {"initial": "data"}) + + # Update session + manager.update_session_detection(session_id, {"car_brand": "Toyota"}) + + # Get session + session_data = manager.get_session_detection(session_id) + + assert session_data is not None + assert "initial" in session_data + assert session_data["car_brand"] == "Toyota" + + def test_set_latest_frame(self): + """Test setting and getting latest frame.""" + manager = SessionCacheManager() + manager.clear_all() + + frame_data = b"fake_frame_data" + + manager.set_latest_frame("camera_001", frame_data) + + retrieved_frame = manager.get_latest_frame("camera_001") + + assert retrieved_frame == frame_data + + def test_frame_skip_flag_management(self): + """Test frame skip flag management.""" + manager = SessionCacheManager() + manager.clear_all() + + # Initially should be False + assert manager.get_frame_skip_flag("camera_001") is False + + # Set to True + manager.set_frame_skip_flag("camera_001", True) + assert manager.get_frame_skip_flag("camera_001") is True + + # Set back to False + manager.set_frame_skip_flag("camera_001", False) + assert manager.get_frame_skip_flag("camera_001") is False + + def test_cleanup_expired_sessions(self): + """Test cleanup of expired sessions.""" + manager = SessionCacheManager() + manager.clear_all() + + # Create sessions with short TTL + manager.session_cache = SessionCache(CacheConfig(max_size=10, ttl_seconds=0.1)) + + # Add sessions + for i in range(3): + session_id = f"session_{i}" + manager.create_session(session_id, "camera_001", {"test": "data"}) + + assert manager.session_cache.size() == 3 + + # Wait for expiration + time.sleep(0.2) + + # Cleanup + expired_count = manager.cleanup_expired_sessions() + + assert expired_count == 3 + assert manager.session_cache.size() == 0 + + def test_clear_camera_cache(self): + """Test clearing cache for specific camera.""" + manager = SessionCacheManager() + manager.clear_all() + + # Add data for multiple cameras + manager.cache_detection("camera_001", {"class": "car"}) + manager.cache_detection("camera_002", {"class": "truck"}) + manager.cache_pipeline_result("camera_001", {"status": "success"}) + manager.set_latest_frame("camera_001", b"frame1") + manager.set_latest_frame("camera_002", b"frame2") + + # Clear camera_001 cache + manager.clear_camera_cache("camera_001") + + # camera_001 data should be gone + assert manager.get_cached_detection("camera_001") is None + assert manager.get_cached_pipeline_result("camera_001") is None + assert manager.get_latest_frame("camera_001") is None + + # camera_002 data should remain + assert manager.get_cached_detection("camera_002") is not None + assert manager.get_latest_frame("camera_002") is not None + + def test_get_cache_statistics(self): + """Test getting cache statistics.""" + manager = SessionCacheManager() + manager.clear_all() + + # Add some data to generate statistics + manager.cache_detection("camera_001", {"class": "car"}) + manager.cache_pipeline_result("camera_001", {"status": "success"}) + manager.create_session("session_123", "camera_001", {"initial": "data"}) + + # Access data to generate hits/misses + manager.get_cached_detection("camera_001") # Hit + manager.get_cached_detection("camera_002") # Miss + + stats = manager.get_cache_statistics() + + assert "detection_cache" in stats + assert "pipeline_cache" in stats + assert "session_cache" in stats + assert "total_memory_usage" in stats + + detection_stats = stats["detection_cache"] + assert detection_stats["size"] >= 1 + assert detection_stats["hits"] >= 1 + assert detection_stats["misses"] >= 1 + + def test_memory_pressure_handling(self): + """Test handling memory pressure.""" + # Create manager with small cache sizes + config = CacheConfig(max_size=3) + manager = SessionCacheManager() + manager.detection_cache = SessionCache(config) + manager.pipeline_cache = SessionCache(config) + manager.session_cache = SessionCache(config) + + # Fill caches beyond capacity + for i in range(5): + manager.cache_detection(f"camera_{i}", {"class": "car", "data": "x" * 100}) + manager.cache_pipeline_result(f"camera_{i}", {"status": "success", "data": "y" * 100}) + manager.create_session(f"session_{i}", f"camera_{i}", {"data": "z" * 100}) + + # Caches should not exceed max size due to eviction + assert manager.detection_cache.size() <= 3 + assert manager.pipeline_cache.size() <= 3 + assert manager.session_cache.size() <= 3 + + def test_concurrent_access_thread_safety(self): + """Test thread safety of concurrent cache access.""" + import threading + import concurrent.futures + + manager = SessionCacheManager() + manager.clear_all() + + results = [] + errors = [] + + def cache_operation(thread_id): + try: + # Each thread performs multiple cache operations + for i in range(10): + session_id = f"thread_{thread_id}_session_{i}" + + # Create session + manager.create_session(session_id, f"camera_{thread_id}", {"thread": thread_id, "index": i}) + + # Update session + manager.update_session_detection(session_id, {"updated": True}) + + # Read session + data = manager.get_session_detection(session_id) + if data and data.get("thread") == thread_id: + results.append((thread_id, i)) + + except Exception as e: + errors.append((thread_id, str(e))) + + # Run operations concurrently + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(cache_operation, i) for i in range(5)] + concurrent.futures.wait(futures) + + # Should have no errors and successful operations + assert len(errors) == 0 + assert len(results) >= 25 # At least some operations should succeed + + +class TestSessionCacheIntegration: + """Integration tests for session cache.""" + + def test_complete_detection_workflow(self): + """Test complete detection workflow with caching.""" + manager = SessionCacheManager() + manager.clear_all() + + camera_id = "camera_001" + session_id = str(uuid.uuid4()) + + # 1. Cache initial detection + detection_data = { + "class": "car", + "confidence": 0.92, + "bbox": [100, 200, 300, 400], + "track_id": 1001, + "timestamp": int(time.time() * 1000) + } + + manager.cache_detection(camera_id, detection_data) + + # 2. Create session for tracking + initial_session_data = { + "detection_class": detection_data["class"], + "confidence": detection_data["confidence"], + "track_id": detection_data["track_id"] + } + + manager.create_session(session_id, camera_id, initial_session_data) + + # 3. Cache pipeline processing result + pipeline_result = { + "status": "processing", + "stage": "classification", + "detections": [detection_data], + "branches_completed": [], + "branches_pending": ["car_brand_cls", "car_bodytype_cls"] + } + + manager.cache_pipeline_result(camera_id, pipeline_result) + + # 4. Update session with classification results + classification_updates = [ + {"car_brand": "Toyota", "brand_confidence": 0.87}, + {"car_body_type": "Sedan", "bodytype_confidence": 0.82} + ] + + for update in classification_updates: + manager.update_session_detection(session_id, update) + + # 5. Update pipeline result to completed + final_pipeline_result = { + "status": "completed", + "stage": "finished", + "detections": [detection_data], + "branches_completed": ["car_brand_cls", "car_bodytype_cls"], + "branches_pending": [], + "execution_time": 0.25 + } + + manager.cache_pipeline_result(camera_id, final_pipeline_result) + + # 6. Verify all cached data + cached_detection = manager.get_cached_detection(camera_id) + cached_pipeline = manager.get_cached_pipeline_result(camera_id) + cached_session = manager.get_session_detection(session_id) + + # Assertions + assert cached_detection["class"] == "car" + assert cached_detection["track_id"] == 1001 + + assert cached_pipeline["status"] == "completed" + assert len(cached_pipeline["branches_completed"]) == 2 + + assert cached_session["detection_class"] == "car" + assert cached_session["car_brand"] == "Toyota" + assert cached_session["car_body_type"] == "Sedan" + assert cached_session["brand_confidence"] == 0.87 + + def test_cache_performance_under_load(self): + """Test cache performance under load.""" + manager = SessionCacheManager() + manager.clear_all() + + import time + + # Measure performance of cache operations + start_time = time.time() + + # Perform many cache operations + for i in range(1000): + camera_id = f"camera_{i % 10}" # 10 different cameras + session_id = f"session_{i}" + + # Cache detection + detection_data = { + "class": "car", + "confidence": 0.9 + (i % 10) * 0.01, + "track_id": i, + "bbox": [i % 100, i % 100, (i % 100) + 200, (i % 100) + 200] + } + manager.cache_detection(camera_id, detection_data) + + # Create session + manager.create_session(session_id, camera_id, {"index": i}) + + # Read back (every 10th operation) + if i % 10 == 0: + manager.get_cached_detection(camera_id) + manager.get_session_detection(session_id) + + end_time = time.time() + total_time = end_time - start_time + + # Should complete in reasonable time (less than 1 second) + assert total_time < 1.0 + + # Verify cache statistics + stats = manager.get_cache_statistics() + assert stats["detection_cache"]["size"] > 0 + assert stats["session_cache"]["size"] > 0 + assert stats["detection_cache"]["hits"] > 0 + + def test_cache_persistence_and_recovery(self): + """Test cache persistence and recovery (if enabled).""" + # This test would be more meaningful with actual persistence + # For now, test the configuration and structure + + persistence_config = CacheConfig( + max_size=100, + enable_persistence=True, + persistence_path="/tmp/detector_cache_test" + ) + + cache = SessionCache(persistence_config) + + # Add some data + session_data = SessionData("session_123", "camera_001", "display_001") + session_data.add_detection_data("main", {"class": "car", "confidence": 0.95}) + + cache.put("session_123", session_data) + + # Verify data exists + assert cache.contains("session_123") is True + + # In a real implementation, this would test: + # 1. Saving cache to disk + # 2. Loading cache from disk + # 3. Verifying data integrity after reload \ No newline at end of file diff --git a/tests/unit/streams/test_stream_manager.py b/tests/unit/streams/test_stream_manager.py new file mode 100644 index 0000000..2ec9433 --- /dev/null +++ b/tests/unit/streams/test_stream_manager.py @@ -0,0 +1,818 @@ +""" +Unit tests for stream management functionality. +""" +import pytest +import asyncio +import threading +import time +from unittest.mock import Mock, AsyncMock, patch, MagicMock +import numpy as np +import cv2 + +from detector_worker.streams.stream_manager import ( + StreamManager, + StreamInfo, + StreamConfig, + StreamReader, + StreamError, + ConnectionError as StreamConnectionError +) +from detector_worker.streams.frame_reader import FrameReader +from detector_worker.core.exceptions import ConfigurationError + + +class TestStreamConfig: + """Test stream configuration.""" + + def test_creation_rtsp(self): + """Test creating RTSP stream config.""" + config = StreamConfig( + stream_url="rtsp://example.com/stream1", + stream_type="rtsp", + target_fps=15, + reconnect_interval=5.0, + max_retries=3 + ) + + assert config.stream_url == "rtsp://example.com/stream1" + assert config.stream_type == "rtsp" + assert config.target_fps == 15 + assert config.reconnect_interval == 5.0 + assert config.max_retries == 3 + + def test_creation_http_snapshot(self): + """Test creating HTTP snapshot config.""" + config = StreamConfig( + stream_url="http://example.com/snapshot.jpg", + stream_type="http_snapshot", + snapshot_interval=1.0, + timeout=10.0 + ) + + assert config.stream_url == "http://example.com/snapshot.jpg" + assert config.stream_type == "http_snapshot" + assert config.snapshot_interval == 1.0 + assert config.timeout == 10.0 + + def test_from_dict(self): + """Test creating config from dictionary.""" + config_dict = { + "stream_url": "rtsp://camera.example.com/live", + "stream_type": "rtsp", + "target_fps": 20, + "reconnect_interval": 3.0, + "max_retries": 5, + "crop_region": [100, 200, 300, 400], + "unknown_field": "ignored" + } + + config = StreamConfig.from_dict(config_dict) + + assert config.stream_url == "rtsp://camera.example.com/live" + assert config.target_fps == 20 + assert config.crop_region == [100, 200, 300, 400] + + def test_validation(self): + """Test config validation.""" + # Valid config + valid_config = StreamConfig( + stream_url="rtsp://example.com/stream", + stream_type="rtsp" + ) + assert valid_config.is_valid() is True + + # Invalid config (empty URL) + invalid_config = StreamConfig( + stream_url="", + stream_type="rtsp" + ) + assert invalid_config.is_valid() is False + + +class TestStreamInfo: + """Test stream information.""" + + def test_creation(self): + """Test stream info creation.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + info = StreamInfo( + stream_id="stream_001", + config=config, + camera_id="camera_001" + ) + + assert info.stream_id == "stream_001" + assert info.config == config + assert info.camera_id == "camera_001" + assert info.status == "inactive" + assert info.reference_count == 0 + assert info.created_at <= time.time() + + def test_increment_reference(self): + """Test incrementing reference count.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + info = StreamInfo("stream_001", config, "camera_001") + + assert info.reference_count == 0 + + info.increment_reference() + assert info.reference_count == 1 + + info.increment_reference() + assert info.reference_count == 2 + + def test_decrement_reference(self): + """Test decrementing reference count.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + info = StreamInfo("stream_001", config, "camera_001") + + info.reference_count = 3 + + assert info.decrement_reference() == 2 + assert info.reference_count == 2 + + assert info.decrement_reference() == 1 + assert info.decrement_reference() == 0 + + # Should not go below 0 + assert info.decrement_reference() == 0 + + def test_update_status(self): + """Test updating stream status.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + info = StreamInfo("stream_001", config, "camera_001") + + info.update_status("connecting") + assert info.status == "connecting" + assert info.last_update <= time.time() + + info.update_status("active", frame_count=100) + assert info.status == "active" + assert info.frame_count == 100 + + def test_get_stats(self): + """Test getting stream statistics.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + info = StreamInfo("stream_001", config, "camera_001") + + info.frame_count = 1000 + info.error_count = 5 + info.reference_count = 2 + + stats = info.get_stats() + + assert stats["stream_id"] == "stream_001" + assert stats["status"] == "inactive" + assert stats["frame_count"] == 1000 + assert stats["error_count"] == 5 + assert stats["reference_count"] == 2 + assert "uptime" in stats + + +class TestStreamReader: + """Test stream reader functionality.""" + + def test_creation(self): + """Test stream reader creation.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + reader = StreamReader("stream_001", config) + + assert reader.stream_id == "stream_001" + assert reader.config == config + assert reader.is_running is False + assert reader.latest_frame is None + assert reader.frame_queue.qsize() == 0 + + @pytest.mark.asyncio + async def test_start_rtsp_stream(self): + """Test starting RTSP stream.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp", target_fps=10) + reader = StreamReader("stream_001", config) + + # Mock cv2.VideoCapture + with patch('cv2.VideoCapture') as mock_cap: + mock_cap_instance = Mock() + mock_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + mock_cap_instance.read.return_value = (True, np.zeros((480, 640, 3), dtype=np.uint8)) + + await reader.start() + + assert reader.is_running is True + assert reader.capture is not None + mock_cap.assert_called_once_with("rtsp://example.com/stream") + + @pytest.mark.asyncio + async def test_start_rtsp_connection_failure(self): + """Test RTSP connection failure.""" + config = StreamConfig("rtsp://invalid.com/stream", "rtsp") + reader = StreamReader("stream_001", config) + + with patch('cv2.VideoCapture') as mock_cap: + mock_cap_instance = Mock() + mock_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = False + + with pytest.raises(StreamConnectionError): + await reader.start() + + @pytest.mark.asyncio + async def test_start_http_snapshot(self): + """Test starting HTTP snapshot stream.""" + config = StreamConfig("http://example.com/snapshot.jpg", "http_snapshot", snapshot_interval=1.0) + reader = StreamReader("stream_001", config) + + with patch('requests.get') as mock_get: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b"fake_image_data" + mock_get.return_value = mock_response + + with patch('cv2.imdecode') as mock_decode: + mock_decode.return_value = np.zeros((480, 640, 3), dtype=np.uint8) + + await reader.start() + + assert reader.is_running is True + mock_get.assert_called_once() + + @pytest.mark.asyncio + async def test_stop_stream(self): + """Test stopping stream.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + reader = StreamReader("stream_001", config) + + # Simulate running state + reader.is_running = True + reader.capture = Mock() + reader.capture.release = Mock() + reader._reader_task = Mock() + reader._reader_task.cancel = Mock() + + await reader.stop() + + assert reader.is_running is False + reader.capture.release.assert_called_once() + reader._reader_task.cancel.assert_called_once() + + def test_get_latest_frame(self): + """Test getting latest frame.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + reader = StreamReader("stream_001", config) + + test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128 + reader.latest_frame = test_frame + + frame = reader.get_latest_frame() + + assert np.array_equal(frame, test_frame) + + def test_get_frame_from_queue(self): + """Test getting frame from queue.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + reader = StreamReader("stream_001", config) + + test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128 + reader.frame_queue.put(test_frame) + + frame = reader.get_frame(timeout=0.1) + + assert np.array_equal(frame, test_frame) + + def test_get_frame_timeout(self): + """Test getting frame with timeout.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + reader = StreamReader("stream_001", config) + + # Queue is empty, should timeout + frame = reader.get_frame(timeout=0.1) + + assert frame is None + + def test_get_stats(self): + """Test getting reader statistics.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + reader = StreamReader("stream_001", config) + + reader.frame_count = 500 + reader.error_count = 2 + + stats = reader.get_stats() + + assert stats["stream_id"] == "stream_001" + assert stats["frame_count"] == 500 + assert stats["error_count"] == 2 + assert stats["is_running"] is False + + +class TestStreamManager: + """Test stream manager functionality.""" + + def test_initialization(self): + """Test stream manager initialization.""" + manager = StreamManager() + + assert len(manager.streams) == 0 + assert len(manager.readers) == 0 + assert manager.max_streams == 10 + assert manager.default_timeout == 30.0 + + def test_initialization_with_config(self): + """Test initialization with custom configuration.""" + config = { + "max_streams": 20, + "default_timeout": 60.0, + "frame_buffer_size": 5 + } + + manager = StreamManager(config) + + assert manager.max_streams == 20 + assert manager.default_timeout == 60.0 + assert manager.frame_buffer_size == 5 + + @pytest.mark.asyncio + async def test_create_stream_new(self): + """Test creating new stream.""" + manager = StreamManager() + + config = StreamConfig("rtsp://example.com/stream", "rtsp") + + with patch.object(StreamReader, 'start', new_callable=AsyncMock): + stream_info = await manager.create_stream("camera_001", config, "sub_001") + + assert "camera_001" in manager.streams + assert manager.streams["camera_001"].reference_count == 1 + assert manager.streams["camera_001"].camera_id == "camera_001" + + @pytest.mark.asyncio + async def test_create_stream_shared(self): + """Test creating shared stream (same URL).""" + manager = StreamManager() + + config = StreamConfig("rtsp://example.com/stream", "rtsp") + + with patch.object(StreamReader, 'start', new_callable=AsyncMock): + # Create first stream + stream_info1 = await manager.create_stream("camera_001", config, "sub_001") + + # Create second stream with same URL + stream_info2 = await manager.create_stream("camera_001", config, "sub_002") + + assert stream_info1 == stream_info2 # Should be same stream + assert manager.streams["camera_001"].reference_count == 2 + + @pytest.mark.asyncio + async def test_create_stream_max_limit(self): + """Test creating stream when at max limit.""" + manager = StreamManager({"max_streams": 1}) + + config1 = StreamConfig("rtsp://example.com/stream1", "rtsp") + config2 = StreamConfig("rtsp://example.com/stream2", "rtsp") + + with patch.object(StreamReader, 'start', new_callable=AsyncMock): + # Create first stream (should succeed) + await manager.create_stream("camera_001", config1, "sub_001") + + # Try to create second stream (should fail) + with pytest.raises(StreamError) as exc_info: + await manager.create_stream("camera_002", config2, "sub_002") + + assert "maximum number of streams" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_remove_stream_single_reference(self): + """Test removing stream with single reference.""" + manager = StreamManager() + + config = StreamConfig("rtsp://example.com/stream", "rtsp") + + with patch.object(StreamReader, 'start', new_callable=AsyncMock): + with patch.object(StreamReader, 'stop', new_callable=AsyncMock): + # Create stream + await manager.create_stream("camera_001", config, "sub_001") + + # Remove stream + removed = await manager.remove_stream("camera_001", "sub_001") + + assert removed is True + assert "camera_001" not in manager.streams + + @pytest.mark.asyncio + async def test_remove_stream_multiple_references(self): + """Test removing stream with multiple references.""" + manager = StreamManager() + + config = StreamConfig("rtsp://example.com/stream", "rtsp") + + with patch.object(StreamReader, 'start', new_callable=AsyncMock): + # Create shared stream + await manager.create_stream("camera_001", config, "sub_001") + await manager.create_stream("camera_001", config, "sub_002") + + assert manager.streams["camera_001"].reference_count == 2 + + # Remove one reference + removed = await manager.remove_stream("camera_001", "sub_001") + + assert removed is True + assert "camera_001" in manager.streams # Still exists + assert manager.streams["camera_001"].reference_count == 1 + + def test_get_stream_info(self): + """Test getting stream information.""" + manager = StreamManager() + + config = StreamConfig("rtsp://example.com/stream", "rtsp") + stream_info = StreamInfo("camera_001", config, "camera_001") + manager.streams["camera_001"] = stream_info + + retrieved_info = manager.get_stream_info("camera_001") + + assert retrieved_info == stream_info + + def test_get_nonexistent_stream_info(self): + """Test getting info for non-existent stream.""" + manager = StreamManager() + + info = manager.get_stream_info("nonexistent_camera") + + assert info is None + + def test_get_latest_frame(self): + """Test getting latest frame from stream.""" + manager = StreamManager() + + # Create mock reader + mock_reader = Mock() + test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128 + mock_reader.get_latest_frame.return_value = test_frame + + manager.readers["camera_001"] = mock_reader + + frame = manager.get_latest_frame("camera_001") + + assert np.array_equal(frame, test_frame) + mock_reader.get_latest_frame.assert_called_once() + + def test_get_frame_from_nonexistent_stream(self): + """Test getting frame from non-existent stream.""" + manager = StreamManager() + + frame = manager.get_latest_frame("nonexistent_camera") + + assert frame is None + + def test_list_active_streams(self): + """Test listing active streams.""" + manager = StreamManager() + + # Add streams + config1 = StreamConfig("rtsp://example.com/stream1", "rtsp") + config2 = StreamConfig("rtsp://example.com/stream2", "rtsp") + + stream1 = StreamInfo("camera_001", config1, "camera_001") + stream1.update_status("active") + + stream2 = StreamInfo("camera_002", config2, "camera_002") + stream2.update_status("inactive") + + manager.streams["camera_001"] = stream1 + manager.streams["camera_002"] = stream2 + + active_streams = manager.list_active_streams() + + assert len(active_streams) == 1 + assert active_streams[0]["camera_id"] == "camera_001" + assert active_streams[0]["status"] == "active" + + @pytest.mark.asyncio + async def test_stop_all_streams(self): + """Test stopping all streams.""" + manager = StreamManager() + + # Add mock streams + mock_reader1 = Mock() + mock_reader1.stop = AsyncMock() + mock_reader2 = Mock() + mock_reader2.stop = AsyncMock() + + manager.readers["camera_001"] = mock_reader1 + manager.readers["camera_002"] = mock_reader2 + + stopped_count = await manager.stop_all_streams() + + assert stopped_count == 2 + mock_reader1.stop.assert_called_once() + mock_reader2.stop.assert_called_once() + assert len(manager.readers) == 0 + assert len(manager.streams) == 0 + + def test_get_stream_statistics(self): + """Test getting stream statistics.""" + manager = StreamManager() + + # Add streams + config = StreamConfig("rtsp://example.com/stream", "rtsp") + + stream1 = StreamInfo("camera_001", config, "camera_001") + stream1.update_status("active") + stream1.frame_count = 1000 + stream1.reference_count = 2 + + stream2 = StreamInfo("camera_002", config, "camera_002") + stream2.update_status("error") + stream2.error_count = 5 + + manager.streams["camera_001"] = stream1 + manager.streams["camera_002"] = stream2 + + stats = manager.get_stream_statistics() + + assert stats["total_streams"] == 2 + assert stats["active_streams"] == 1 + assert stats["error_streams"] == 1 + assert stats["total_references"] == 2 + assert "status_breakdown" in stats + + @pytest.mark.asyncio + async def test_reconnect_stream(self): + """Test reconnecting failed stream.""" + manager = StreamManager() + + config = StreamConfig("rtsp://example.com/stream", "rtsp") + stream_info = StreamInfo("camera_001", config, "camera_001") + stream_info.update_status("error") + manager.streams["camera_001"] = stream_info + + # Mock reader + mock_reader = Mock() + mock_reader.start = AsyncMock() + mock_reader.stop = AsyncMock() + manager.readers["camera_001"] = mock_reader + + result = await manager.reconnect_stream("camera_001") + + assert result is True + mock_reader.stop.assert_called_once() + mock_reader.start.assert_called_once() + assert stream_info.status != "error" + + @pytest.mark.asyncio + async def test_health_check_streams(self): + """Test health check of all streams.""" + manager = StreamManager() + + # Add streams with different states + config = StreamConfig("rtsp://example.com/stream", "rtsp") + + stream1 = StreamInfo("camera_001", config, "camera_001") + stream1.update_status("active") + + stream2 = StreamInfo("camera_002", config, "camera_002") + stream2.update_status("error") + + manager.streams["camera_001"] = stream1 + manager.streams["camera_002"] = stream2 + + # Mock readers + mock_reader1 = Mock() + mock_reader1.is_running = True + mock_reader2 = Mock() + mock_reader2.is_running = False + + manager.readers["camera_001"] = mock_reader1 + manager.readers["camera_002"] = mock_reader2 + + health_report = await manager.health_check() + + assert health_report["total_streams"] == 2 + assert health_report["healthy_streams"] == 1 + assert health_report["unhealthy_streams"] == 1 + assert len(health_report["unhealthy_stream_ids"]) == 1 + + +class TestStreamManagerIntegration: + """Integration tests for stream manager.""" + + @pytest.mark.asyncio + async def test_multiple_subscribers_same_stream(self): + """Test multiple subscribers to same stream.""" + manager = StreamManager() + + config = StreamConfig("rtsp://example.com/shared_stream", "rtsp") + + with patch.object(StreamReader, 'start', new_callable=AsyncMock): + # Multiple subscribers to same stream + stream1 = await manager.create_stream("camera_001", config, "sub_001") + stream2 = await manager.create_stream("camera_001", config, "sub_002") + stream3 = await manager.create_stream("camera_001", config, "sub_003") + + # All should reference same stream + assert stream1 == stream2 == stream3 + assert manager.streams["camera_001"].reference_count == 3 + assert len(manager.readers) == 1 # Only one actual reader + + # Remove subscribers one by one + with patch.object(StreamReader, 'stop', new_callable=AsyncMock) as mock_stop: + await manager.remove_stream("camera_001", "sub_001") # ref_count = 2 + await manager.remove_stream("camera_001", "sub_002") # ref_count = 1 + + # Stream should still exist + assert "camera_001" in manager.streams + mock_stop.assert_not_called() + + await manager.remove_stream("camera_001", "sub_003") # ref_count = 0 + + # Now stream should be stopped and removed + assert "camera_001" not in manager.streams + mock_stop.assert_called_once() + + @pytest.mark.asyncio + async def test_stream_failure_and_recovery(self): + """Test stream failure and recovery workflow.""" + manager = StreamManager() + + config = StreamConfig("rtsp://unreliable.com/stream", "rtsp", max_retries=2) + + # Mock reader that fails initially then succeeds + with patch.object(StreamReader, 'start', new_callable=AsyncMock) as mock_start: + mock_start.side_effect = [ + StreamConnectionError("Connection failed"), # First attempt fails + None # Second attempt succeeds + ] + + # First attempt should fail + with pytest.raises(StreamConnectionError): + await manager.create_stream("camera_001", config, "sub_001") + + # Retry should succeed + stream_info = await manager.create_stream("camera_001", config, "sub_001") + + assert stream_info is not None + assert mock_start.call_count == 2 + + @pytest.mark.asyncio + async def test_concurrent_stream_operations(self): + """Test concurrent stream operations.""" + manager = StreamManager() + + configs = [ + StreamConfig(f"rtsp://example.com/stream{i}", "rtsp") + for i in range(5) + ] + + with patch.object(StreamReader, 'start', new_callable=AsyncMock): + with patch.object(StreamReader, 'stop', new_callable=AsyncMock): + # Create streams concurrently + create_tasks = [ + manager.create_stream(f"camera_{i}", configs[i], f"sub_{i}") + for i in range(5) + ] + + results = await asyncio.gather(*create_tasks) + + assert len(results) == 5 + assert len(manager.streams) == 5 + + # Remove streams concurrently + remove_tasks = [ + manager.remove_stream(f"camera_{i}", f"sub_{i}") + for i in range(5) + ] + + remove_results = await asyncio.gather(*remove_tasks) + + assert all(remove_results) + assert len(manager.streams) == 0 + + @pytest.mark.asyncio + async def test_memory_management_large_scale(self): + """Test memory management with many streams.""" + manager = StreamManager({"max_streams": 50}) + + # Create many streams + with patch.object(StreamReader, 'start', new_callable=AsyncMock): + for i in range(30): + config = StreamConfig(f"rtsp://example.com/stream{i}", "rtsp") + await manager.create_stream(f"camera_{i}", config, f"sub_{i}") + + # Verify memory usage is reasonable + stats = manager.get_stream_statistics() + assert stats["total_streams"] == 30 + assert stats["active_streams"] <= 30 + + # Test bulk cleanup + with patch.object(StreamReader, 'stop', new_callable=AsyncMock): + stopped_count = await manager.stop_all_streams() + + assert stopped_count == 30 + assert len(manager.streams) == 0 + assert len(manager.readers) == 0 + + +class TestFrameReaderIntegration: + """Integration tests for frame reader.""" + + @pytest.mark.asyncio + async def test_rtsp_frame_processing(self): + """Test RTSP frame processing pipeline.""" + config = StreamConfig( + stream_url="rtsp://example.com/stream", + stream_type="rtsp", + target_fps=10, + crop_region=[100, 100, 400, 300] + ) + + reader = StreamReader("test_stream", config) + + # Mock cv2.VideoCapture + with patch('cv2.VideoCapture') as mock_cap: + mock_cap_instance = Mock() + mock_cap.return_value = mock_cap_instance + mock_cap_instance.isOpened.return_value = True + + # Mock frame sequence + test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 128 + mock_cap_instance.read.side_effect = [ + (True, test_frame), # First frame + (True, test_frame * 0.8), # Second frame + (False, None), # Connection lost + (True, test_frame * 1.2), # Reconnected + ] + + await reader.start() + + # Let reader process some frames + await asyncio.sleep(0.1) + + # Verify frame processing + latest_frame = reader.get_latest_frame() + assert latest_frame is not None + assert latest_frame.shape == (480, 640, 3) + + await reader.stop() + + @pytest.mark.asyncio + async def test_http_snapshot_processing(self): + """Test HTTP snapshot processing.""" + config = StreamConfig( + stream_url="http://camera.example.com/snapshot.jpg", + stream_type="http_snapshot", + snapshot_interval=0.5, + timeout=5.0 + ) + + reader = StreamReader("snapshot_stream", config) + + with patch('requests.get') as mock_get: + # Mock HTTP responses + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = b"fake_jpeg_data" + mock_get.return_value = mock_response + + with patch('cv2.imdecode') as mock_decode: + test_frame = np.ones((480, 640, 3), dtype=np.uint8) * 200 + mock_decode.return_value = test_frame + + await reader.start() + + # Wait for snapshot capture + await asyncio.sleep(0.6) + + # Verify snapshot processing + latest_frame = reader.get_latest_frame() + assert latest_frame is not None + assert np.array_equal(latest_frame, test_frame) + + await reader.stop() + + def test_frame_queue_management(self): + """Test frame queue management and buffering.""" + config = StreamConfig("rtsp://example.com/stream", "rtsp") + reader = StreamReader("queue_test", config, frame_buffer_size=3) + + # Add frames to queue + frames = [ + np.ones((100, 100, 3), dtype=np.uint8) * i + for i in range(50, 250, 50) # 4 different frames + ] + + for frame in frames[:3]: # Fill buffer + reader._add_frame_to_queue(frame) + + assert reader.frame_queue.qsize() == 3 + + # Add one more (should drop oldest) + reader._add_frame_to_queue(frames[3]) + assert reader.frame_queue.qsize() == 3 + + # Verify frame order (oldest should be dropped) + retrieved_frames = [] + while not reader.frame_queue.empty(): + retrieved_frames.append(reader.get_frame(timeout=0.1)) + + assert len(retrieved_frames) == 3 + # First frame should have been dropped, so we should have frames 1,2,3 + assert not np.array_equal(retrieved_frames[0], frames[0]) \ No newline at end of file diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..7854dfb --- /dev/null +++ b/tox.ini @@ -0,0 +1,78 @@ +[tox] +envlist = py39,py310,py311,py312,py313,flake8,mypy,coverage +skipsdist = True +isolated_build = True + +[testenv] +deps = -r{toxinidir}/requirements.txt + -r{toxinidir}/requirements-dev.txt +commands = python scripts/run_tests.py --unit --no-coverage +setenv = + PYTHONPATH = {toxinidir} + DETECTOR_WORKER_ENV = test + +[testenv:integration] +commands = python scripts/run_tests.py --integration + +[testenv:performance] +commands = python scripts/run_tests.py --performance + +[testenv:flake8] +deps = flake8 + flake8-docstrings + flake8-import-order +commands = flake8 detector_worker tests scripts + +[testenv:mypy] +deps = mypy + types-redis + types-requests + types-PyYAML +commands = mypy detector_worker --ignore-missing-imports + +[testenv:coverage] +commands = python scripts/run_tests.py --coverage + +[testenv:format] +deps = black + isort +commands = + black detector_worker tests scripts + isort detector_worker tests scripts + +[testenv:security] +deps = bandit + safety +commands = + bandit -r detector_worker + safety check + +[testenv:docs] +deps = sphinx + sphinx-rtd-theme +commands = sphinx-build -b html docs docs/_build + +[flake8] +max-line-length = 120 +extend-ignore = E203,W503,E501 +exclude = .git,__pycache__,build,dist,*.egg-info,.tox +per-file-ignores = + __init__.py:F401 + tests/*:S101,S105,S106 + +[coverage:run] +source = detector_worker +omit = + */tests/* + */test_* + */__pycache__/* + +[coverage:report] +exclude_lines = + pragma: no cover + def __repr__ + raise AssertionError + raise NotImplementedError + if 0: + if __name__ == .__main__.: + @abstractmethod \ No newline at end of file