Refactor: PHASE 8: Testing & Integration
This commit is contained in:
parent
af34f4fd08
commit
9e8c6804a7
32 changed files with 17128 additions and 0 deletions
429
tests/unit/core/test_config.py
Normal file
429
tests/unit/core/test_config.py
Normal file
|
@ -0,0 +1,429 @@
|
|||
"""
|
||||
Unit tests for configuration management system.
|
||||
"""
|
||||
import pytest
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from detector_worker.core.config import (
|
||||
ConfigurationManager,
|
||||
JsonFileProvider,
|
||||
EnvironmentProvider,
|
||||
DatabaseConfig,
|
||||
RedisConfig,
|
||||
StreamConfig,
|
||||
ModelConfig,
|
||||
LoggingConfig,
|
||||
get_config_manager,
|
||||
validate_config
|
||||
)
|
||||
from detector_worker.core.exceptions import ConfigurationError
|
||||
|
||||
|
||||
class TestJsonFileProvider:
|
||||
"""Test JSON file configuration provider."""
|
||||
|
||||
def test_get_config_from_valid_file(self, temp_dir):
|
||||
"""Test loading configuration from a valid JSON file."""
|
||||
config_data = {"test_key": "test_value", "number": 42}
|
||||
config_file = os.path.join(temp_dir, "config.json")
|
||||
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
provider = JsonFileProvider(config_file)
|
||||
result = provider.get_config()
|
||||
|
||||
assert result == config_data
|
||||
|
||||
def test_get_config_file_not_exists(self, temp_dir):
|
||||
"""Test handling of non-existent config file."""
|
||||
config_file = os.path.join(temp_dir, "nonexistent.json")
|
||||
provider = JsonFileProvider(config_file)
|
||||
|
||||
result = provider.get_config()
|
||||
assert result == {}
|
||||
|
||||
def test_get_config_invalid_json(self, temp_dir):
|
||||
"""Test handling of invalid JSON file."""
|
||||
config_file = os.path.join(temp_dir, "invalid.json")
|
||||
|
||||
with open(config_file, 'w') as f:
|
||||
f.write("invalid json content")
|
||||
|
||||
provider = JsonFileProvider(config_file)
|
||||
result = provider.get_config()
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_reload_updates_config(self, temp_dir):
|
||||
"""Test that reload updates configuration."""
|
||||
config_file = os.path.join(temp_dir, "config.json")
|
||||
|
||||
# Initial config
|
||||
initial_config = {"version": 1}
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(initial_config, f)
|
||||
|
||||
provider = JsonFileProvider(config_file)
|
||||
assert provider.get_config() == initial_config
|
||||
|
||||
# Update config file
|
||||
updated_config = {"version": 2}
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(updated_config, f)
|
||||
|
||||
# Force reload
|
||||
provider.reload()
|
||||
assert provider.get_config() == updated_config
|
||||
|
||||
|
||||
class TestEnvironmentProvider:
|
||||
"""Test environment variable configuration provider."""
|
||||
|
||||
def test_get_config_with_env_vars(self):
|
||||
"""Test loading configuration from environment variables."""
|
||||
env_vars = {
|
||||
"DETECTOR_MAX_STREAMS": "10",
|
||||
"DETECTOR_TARGET_FPS": "15",
|
||||
"DETECTOR_CONFIG": '{"nested": "value"}',
|
||||
"OTHER_VAR": "ignored"
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
provider = EnvironmentProvider("DETECTOR_")
|
||||
config = provider.get_config()
|
||||
|
||||
assert config["max_streams"] == "10"
|
||||
assert config["target_fps"] == "15"
|
||||
assert config["config"] == {"nested": "value"}
|
||||
assert "other_var" not in config
|
||||
|
||||
def test_get_config_no_env_vars(self):
|
||||
"""Test with no matching environment variables."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
provider = EnvironmentProvider("DETECTOR_")
|
||||
config = provider.get_config()
|
||||
|
||||
assert config == {}
|
||||
|
||||
def test_custom_prefix(self):
|
||||
"""Test with custom prefix."""
|
||||
env_vars = {"CUSTOM_TEST": "value"}
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
provider = EnvironmentProvider("CUSTOM_")
|
||||
config = provider.get_config()
|
||||
|
||||
assert config["test"] == "value"
|
||||
|
||||
|
||||
class TestConfigDataclasses:
|
||||
"""Test configuration dataclasses."""
|
||||
|
||||
def test_database_config_from_dict(self):
|
||||
"""Test DatabaseConfig creation from dictionary."""
|
||||
data = {
|
||||
"enabled": True,
|
||||
"host": "db.example.com",
|
||||
"port": 5432,
|
||||
"database": "testdb",
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
"schema": "test_schema",
|
||||
"unknown_field": "ignored"
|
||||
}
|
||||
|
||||
config = DatabaseConfig.from_dict(data)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.host == "db.example.com"
|
||||
assert config.port == 5432
|
||||
assert config.database == "testdb"
|
||||
assert config.username == "user"
|
||||
assert config.password == "pass"
|
||||
assert config.schema == "test_schema"
|
||||
# Unknown fields should be ignored
|
||||
assert not hasattr(config, 'unknown_field')
|
||||
|
||||
def test_redis_config_from_dict(self):
|
||||
"""Test RedisConfig creation from dictionary."""
|
||||
data = {
|
||||
"enabled": True,
|
||||
"host": "redis.example.com",
|
||||
"port": 6379,
|
||||
"password": "secret",
|
||||
"db": 1
|
||||
}
|
||||
|
||||
config = RedisConfig.from_dict(data)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.host == "redis.example.com"
|
||||
assert config.port == 6379
|
||||
assert config.password == "secret"
|
||||
assert config.db == 1
|
||||
|
||||
def test_stream_config_from_dict(self):
|
||||
"""Test StreamConfig creation from dictionary."""
|
||||
data = {
|
||||
"poll_interval_ms": 50,
|
||||
"max_streams": 10,
|
||||
"target_fps": 20,
|
||||
"reconnect_interval_sec": 10,
|
||||
"max_retries": 5
|
||||
}
|
||||
|
||||
config = StreamConfig.from_dict(data)
|
||||
|
||||
assert config.poll_interval_ms == 50
|
||||
assert config.max_streams == 10
|
||||
assert config.target_fps == 20
|
||||
assert config.reconnect_interval_sec == 10
|
||||
assert config.max_retries == 5
|
||||
|
||||
|
||||
class TestConfigurationManager:
|
||||
"""Test main configuration manager."""
|
||||
|
||||
def test_initialization_with_defaults(self):
|
||||
"""Test that manager initializes with default values."""
|
||||
manager = ConfigurationManager()
|
||||
|
||||
# Should have default providers
|
||||
assert len(manager._providers) >= 1
|
||||
|
||||
# Should have default configuration values
|
||||
config = manager.get_all()
|
||||
assert "poll_interval_ms" in config
|
||||
assert "max_streams" in config
|
||||
assert "target_fps" in config
|
||||
|
||||
def test_add_provider(self):
|
||||
"""Test adding configuration providers."""
|
||||
manager = ConfigurationManager()
|
||||
initial_count = len(manager._providers)
|
||||
|
||||
mock_provider = Mock()
|
||||
mock_provider.get_config.return_value = {"test": "value"}
|
||||
|
||||
manager.add_provider(mock_provider)
|
||||
|
||||
assert len(manager._providers) == initial_count + 1
|
||||
|
||||
def test_get_configuration_value(self):
|
||||
"""Test getting specific configuration values."""
|
||||
manager = ConfigurationManager()
|
||||
|
||||
# Test existing key
|
||||
value = manager.get("poll_interval_ms")
|
||||
assert value is not None
|
||||
|
||||
# Test non-existing key with default
|
||||
value = manager.get("nonexistent", "default")
|
||||
assert value == "default"
|
||||
|
||||
# Test non-existing key without default
|
||||
value = manager.get("nonexistent")
|
||||
assert value is None
|
||||
|
||||
def test_get_section(self):
|
||||
"""Test getting configuration sections."""
|
||||
manager = ConfigurationManager()
|
||||
|
||||
# Test existing section
|
||||
db_section = manager.get_section("database")
|
||||
assert isinstance(db_section, dict)
|
||||
|
||||
# Test non-existing section
|
||||
empty_section = manager.get_section("nonexistent")
|
||||
assert empty_section == {}
|
||||
|
||||
def test_typed_config_access(self):
|
||||
"""Test typed configuration object access."""
|
||||
manager = ConfigurationManager()
|
||||
|
||||
# Test database config
|
||||
db_config = manager.get_database_config()
|
||||
assert isinstance(db_config, DatabaseConfig)
|
||||
|
||||
# Test Redis config
|
||||
redis_config = manager.get_redis_config()
|
||||
assert isinstance(redis_config, RedisConfig)
|
||||
|
||||
# Test stream config
|
||||
stream_config = manager.get_stream_config()
|
||||
assert isinstance(stream_config, StreamConfig)
|
||||
|
||||
# Test model config
|
||||
model_config = manager.get_model_config()
|
||||
assert isinstance(model_config, ModelConfig)
|
||||
|
||||
# Test logging config
|
||||
logging_config = manager.get_logging_config()
|
||||
assert isinstance(logging_config, LoggingConfig)
|
||||
|
||||
def test_set_configuration_value(self):
|
||||
"""Test setting configuration values at runtime."""
|
||||
manager = ConfigurationManager()
|
||||
|
||||
manager.set("test_key", "test_value")
|
||||
|
||||
assert manager.get("test_key") == "test_value"
|
||||
|
||||
# Should also update typed configs
|
||||
manager.set("poll_interval_ms", 200)
|
||||
stream_config = manager.get_stream_config()
|
||||
assert stream_config.poll_interval_ms == 200
|
||||
|
||||
def test_validation_success(self):
|
||||
"""Test configuration validation with valid config."""
|
||||
manager = ConfigurationManager()
|
||||
|
||||
# Set valid configuration
|
||||
manager.set("poll_interval_ms", 100)
|
||||
manager.set("max_streams", 5)
|
||||
manager.set("target_fps", 10)
|
||||
|
||||
errors = manager.validate()
|
||||
assert errors == []
|
||||
assert manager.is_valid() is True
|
||||
|
||||
def test_validation_errors(self):
|
||||
"""Test configuration validation with invalid values."""
|
||||
manager = ConfigurationManager()
|
||||
|
||||
# Set invalid configuration
|
||||
manager.set("poll_interval_ms", 0)
|
||||
manager.set("max_streams", -1)
|
||||
manager.set("target_fps", 0)
|
||||
|
||||
errors = manager.validate()
|
||||
assert len(errors) > 0
|
||||
assert manager.is_valid() is False
|
||||
|
||||
# Check specific errors
|
||||
error_messages = " ".join(errors)
|
||||
assert "poll_interval_ms must be positive" in error_messages
|
||||
assert "max_streams must be positive" in error_messages
|
||||
assert "target_fps must be positive" in error_messages
|
||||
|
||||
def test_database_validation(self):
|
||||
"""Test database-specific validation."""
|
||||
manager = ConfigurationManager()
|
||||
|
||||
# Enable database but don't provide required fields
|
||||
db_config = {
|
||||
"enabled": True,
|
||||
"host": "",
|
||||
"database": ""
|
||||
}
|
||||
manager.set("database", db_config)
|
||||
|
||||
errors = manager.validate()
|
||||
error_messages = " ".join(errors)
|
||||
|
||||
assert "database host is required" in error_messages
|
||||
assert "database name is required" in error_messages
|
||||
|
||||
def test_redis_validation(self):
|
||||
"""Test Redis-specific validation."""
|
||||
manager = ConfigurationManager()
|
||||
|
||||
# Enable Redis but don't provide required fields
|
||||
redis_config = {
|
||||
"enabled": True,
|
||||
"host": ""
|
||||
}
|
||||
manager.set("redis", redis_config)
|
||||
|
||||
errors = manager.validate()
|
||||
error_messages = " ".join(errors)
|
||||
|
||||
assert "redis host is required" in error_messages
|
||||
|
||||
|
||||
class TestGlobalConfigurationFunctions:
|
||||
"""Test global configuration functions."""
|
||||
|
||||
def test_get_config_manager_singleton(self):
|
||||
"""Test that get_config_manager returns a singleton."""
|
||||
manager1 = get_config_manager()
|
||||
manager2 = get_config_manager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
@patch('detector_worker.core.config.get_config_manager')
|
||||
def test_validate_config_function(self, mock_get_manager):
|
||||
"""Test global validate_config function."""
|
||||
mock_manager = Mock()
|
||||
mock_manager.validate.return_value = ["error1", "error2"]
|
||||
mock_get_manager.return_value = mock_manager
|
||||
|
||||
errors = validate_config()
|
||||
|
||||
assert errors == ["error1", "error2"]
|
||||
mock_manager.validate.assert_called_once()
|
||||
|
||||
|
||||
class TestConfigurationIntegration:
|
||||
"""Integration tests for configuration system."""
|
||||
|
||||
def test_provider_priority(self, temp_dir):
|
||||
"""Test that later providers override earlier ones."""
|
||||
# Create JSON file with initial config
|
||||
config_file = os.path.join(temp_dir, "config.json")
|
||||
json_config = {"test_value": "from_json", "json_only": "json"}
|
||||
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(json_config, f)
|
||||
|
||||
# Set environment variable that should override
|
||||
env_vars = {"DETECTOR_TEST_VALUE": "from_env", "DETECTOR_ENV_ONLY": "env"}
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=False):
|
||||
manager = ConfigurationManager()
|
||||
manager._providers.clear() # Start fresh
|
||||
|
||||
# Add providers in order
|
||||
manager.add_provider(JsonFileProvider(config_file))
|
||||
manager.add_provider(EnvironmentProvider("DETECTOR_"))
|
||||
|
||||
config = manager.get_all()
|
||||
|
||||
# Environment should override JSON
|
||||
assert config["test_value"] == "from_env"
|
||||
|
||||
# Both sources should be present
|
||||
assert config["json_only"] == "json"
|
||||
assert config["env_only"] == "env"
|
||||
|
||||
def test_hot_reload(self, temp_dir):
|
||||
"""Test configuration hot reload functionality."""
|
||||
config_file = os.path.join(temp_dir, "config.json")
|
||||
|
||||
# Initial config
|
||||
initial_config = {"version": 1, "feature_enabled": False}
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(initial_config, f)
|
||||
|
||||
manager = ConfigurationManager()
|
||||
manager._providers.clear()
|
||||
manager.add_provider(JsonFileProvider(config_file))
|
||||
|
||||
assert manager.get("version") == 1
|
||||
assert manager.get("feature_enabled") is False
|
||||
|
||||
# Update config file
|
||||
updated_config = {"version": 2, "feature_enabled": True}
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(updated_config, f)
|
||||
|
||||
# Reload configuration
|
||||
success = manager.reload()
|
||||
assert success is True
|
||||
|
||||
assert manager.get("version") == 2
|
||||
assert manager.get("feature_enabled") is True
|
566
tests/unit/core/test_dependency_injection.py
Normal file
566
tests/unit/core/test_dependency_injection.py
Normal file
|
@ -0,0 +1,566 @@
|
|||
"""
|
||||
Unit tests for dependency injection system.
|
||||
"""
|
||||
import pytest
|
||||
import threading
|
||||
from unittest.mock import Mock, MagicMock
|
||||
|
||||
from detector_worker.core.dependency_injection import (
|
||||
ServiceContainer,
|
||||
ServiceLifetime,
|
||||
ServiceDescriptor,
|
||||
ServiceScope,
|
||||
DetectorWorkerContainer,
|
||||
get_container,
|
||||
resolve_service,
|
||||
create_service_scope
|
||||
)
|
||||
from detector_worker.core.exceptions import DependencyInjectionError
|
||||
|
||||
|
||||
class TestServiceContainer:
|
||||
"""Test core service container functionality."""
|
||||
|
||||
def test_register_singleton(self):
|
||||
"""Test singleton service registration."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
# Register singleton
|
||||
container.register_singleton(TestService)
|
||||
|
||||
# Resolve twice - should get same instance
|
||||
instance1 = container.resolve(TestService)
|
||||
instance2 = container.resolve(TestService)
|
||||
|
||||
assert instance1 is instance2
|
||||
assert instance1.value == 42
|
||||
|
||||
def test_register_singleton_with_instance(self):
|
||||
"""Test singleton registration with pre-created instance."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
# Create instance and register
|
||||
instance = TestService(99)
|
||||
container.register_singleton(TestService, instance=instance)
|
||||
|
||||
# Resolve should return the pre-created instance
|
||||
resolved = container.resolve(TestService)
|
||||
assert resolved is instance
|
||||
assert resolved.value == 99
|
||||
|
||||
def test_register_transient(self):
|
||||
"""Test transient service registration."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
# Register transient
|
||||
container.register_transient(TestService)
|
||||
|
||||
# Resolve twice - should get different instances
|
||||
instance1 = container.resolve(TestService)
|
||||
instance2 = container.resolve(TestService)
|
||||
|
||||
assert instance1 is not instance2
|
||||
assert instance1.value == instance2.value == 42
|
||||
|
||||
def test_register_scoped(self):
|
||||
"""Test scoped service registration."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
# Register scoped
|
||||
container.register_scoped(TestService)
|
||||
|
||||
# Resolve in same scope - should get same instance
|
||||
instance1 = container.resolve(TestService, scope_id="scope1")
|
||||
instance2 = container.resolve(TestService, scope_id="scope1")
|
||||
|
||||
assert instance1 is instance2
|
||||
|
||||
# Resolve in different scope - should get different instance
|
||||
instance3 = container.resolve(TestService, scope_id="scope2")
|
||||
assert instance3 is not instance1
|
||||
|
||||
def test_register_with_factory(self):
|
||||
"""Test service registration with factory function."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
# Register with factory
|
||||
def factory():
|
||||
return TestService(100)
|
||||
|
||||
container.register_singleton(TestService, factory=factory)
|
||||
|
||||
instance = container.resolve(TestService)
|
||||
assert instance.value == 100
|
||||
|
||||
def test_register_with_implementation_type(self):
|
||||
"""Test service registration with implementation type."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class ITestService:
|
||||
pass
|
||||
|
||||
class TestService(ITestService):
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
# Register interface with implementation
|
||||
container.register_singleton(ITestService, implementation_type=TestService)
|
||||
|
||||
instance = container.resolve(ITestService)
|
||||
assert isinstance(instance, TestService)
|
||||
assert instance.value == 42
|
||||
|
||||
def test_dependency_injection(self):
|
||||
"""Test automatic dependency injection."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class DatabaseService:
|
||||
def __init__(self):
|
||||
self.connected = True
|
||||
|
||||
class UserService:
|
||||
def __init__(self, database: DatabaseService):
|
||||
self.database = database
|
||||
|
||||
# Register services
|
||||
container.register_singleton(DatabaseService)
|
||||
container.register_transient(UserService)
|
||||
|
||||
# Resolve should inject dependencies
|
||||
user_service = container.resolve(UserService)
|
||||
assert isinstance(user_service.database, DatabaseService)
|
||||
assert user_service.database.connected is True
|
||||
|
||||
def test_circular_dependency_detection(self):
|
||||
"""Test circular dependency detection."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class ServiceA:
|
||||
def __init__(self, service_b: 'ServiceB'):
|
||||
self.service_b = service_b
|
||||
|
||||
class ServiceB:
|
||||
def __init__(self, service_a: ServiceA):
|
||||
self.service_a = service_a
|
||||
|
||||
# Register circular dependencies
|
||||
container.register_singleton(ServiceA)
|
||||
container.register_singleton(ServiceB)
|
||||
|
||||
# Should raise circular dependency error
|
||||
with pytest.raises(DependencyInjectionError) as exc_info:
|
||||
container.resolve(ServiceA)
|
||||
|
||||
assert "Circular dependency detected" in str(exc_info.value)
|
||||
|
||||
def test_unregistered_service_error(self):
|
||||
"""Test error when resolving unregistered service."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class UnregisteredService:
|
||||
pass
|
||||
|
||||
with pytest.raises(DependencyInjectionError) as exc_info:
|
||||
container.resolve(UnregisteredService)
|
||||
|
||||
assert "is not registered" in str(exc_info.value)
|
||||
|
||||
def test_scoped_service_without_scope_id(self):
|
||||
"""Test error when resolving scoped service without scope ID."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
pass
|
||||
|
||||
container.register_scoped(TestService)
|
||||
|
||||
with pytest.raises(DependencyInjectionError) as exc_info:
|
||||
container.resolve(TestService)
|
||||
|
||||
assert "Scope ID required" in str(exc_info.value)
|
||||
|
||||
def test_factory_error_handling(self):
|
||||
"""Test factory error handling."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
pass
|
||||
|
||||
def failing_factory():
|
||||
raise ValueError("Factory failed")
|
||||
|
||||
container.register_singleton(TestService, factory=failing_factory)
|
||||
|
||||
with pytest.raises(DependencyInjectionError) as exc_info:
|
||||
container.resolve(TestService)
|
||||
|
||||
assert "Failed to create service using factory" in str(exc_info.value)
|
||||
|
||||
def test_constructor_dependency_with_default(self):
|
||||
"""Test dependency with default value."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self, value: int = 42):
|
||||
self.value = value
|
||||
|
||||
container.register_singleton(TestService)
|
||||
|
||||
instance = container.resolve(TestService)
|
||||
assert instance.value == 42
|
||||
|
||||
def test_unresolvable_dependency_with_default(self):
|
||||
"""Test unresolvable dependency that has a default value."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class UnregisteredService:
|
||||
pass
|
||||
|
||||
class TestService:
|
||||
def __init__(self, dep: UnregisteredService = None):
|
||||
self.dep = dep
|
||||
|
||||
container.register_singleton(TestService)
|
||||
|
||||
instance = container.resolve(TestService)
|
||||
assert instance.dep is None
|
||||
|
||||
def test_unresolvable_dependency_without_default(self):
|
||||
"""Test unresolvable dependency without default value."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class UnregisteredService:
|
||||
pass
|
||||
|
||||
class TestService:
|
||||
def __init__(self, dep: UnregisteredService):
|
||||
self.dep = dep
|
||||
|
||||
container.register_singleton(TestService)
|
||||
|
||||
with pytest.raises(DependencyInjectionError) as exc_info:
|
||||
container.resolve(TestService)
|
||||
|
||||
assert "Cannot resolve dependency" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestServiceScope:
|
||||
"""Test service scope functionality."""
|
||||
|
||||
def test_create_scope(self):
|
||||
"""Test scope creation."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
container.register_scoped(TestService)
|
||||
|
||||
scope = container.create_scope("test_scope")
|
||||
assert isinstance(scope, ServiceScope)
|
||||
assert scope.scope_id == "test_scope"
|
||||
|
||||
def test_scope_context_manager(self):
|
||||
"""Test scope as context manager."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
self.disposed = False
|
||||
|
||||
def dispose(self):
|
||||
self.disposed = True
|
||||
|
||||
container.register_scoped(TestService)
|
||||
|
||||
instance = None
|
||||
with container.create_scope("test_scope") as scope:
|
||||
instance = scope.resolve(TestService)
|
||||
assert not instance.disposed
|
||||
|
||||
# Instance should be disposed after scope exit
|
||||
assert instance.disposed
|
||||
|
||||
def test_dispose_scope(self):
|
||||
"""Test manual scope disposal."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
self.disposed = False
|
||||
|
||||
def dispose(self):
|
||||
self.disposed = True
|
||||
|
||||
container.register_scoped(TestService)
|
||||
|
||||
instance = container.resolve(TestService, scope_id="test_scope")
|
||||
assert not instance.disposed
|
||||
|
||||
container.dispose_scope("test_scope")
|
||||
assert instance.disposed
|
||||
|
||||
def test_dispose_error_handling(self):
|
||||
"""Test error handling during scope disposal."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def dispose(self):
|
||||
raise ValueError("Dispose failed")
|
||||
|
||||
container.register_scoped(TestService)
|
||||
|
||||
container.resolve(TestService, scope_id="test_scope")
|
||||
|
||||
# Should not raise error, just log it
|
||||
container.dispose_scope("test_scope")
|
||||
|
||||
|
||||
class TestContainerIntrospection:
|
||||
"""Test container introspection capabilities."""
|
||||
|
||||
def test_is_registered(self):
|
||||
"""Test checking if service is registered."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class RegisteredService:
|
||||
pass
|
||||
|
||||
class UnregisteredService:
|
||||
pass
|
||||
|
||||
container.register_singleton(RegisteredService)
|
||||
|
||||
assert container.is_registered(RegisteredService) is True
|
||||
assert container.is_registered(UnregisteredService) is False
|
||||
|
||||
def test_get_registration_info(self):
|
||||
"""Test getting service registration information."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
pass
|
||||
|
||||
container.register_singleton(TestService)
|
||||
|
||||
info = container.get_registration_info(TestService)
|
||||
assert isinstance(info, ServiceDescriptor)
|
||||
assert info.service_type == TestService
|
||||
assert info.lifetime == ServiceLifetime.SINGLETON
|
||||
|
||||
def test_get_registered_services(self):
|
||||
"""Test getting all registered services."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class Service1:
|
||||
pass
|
||||
|
||||
class Service2:
|
||||
pass
|
||||
|
||||
container.register_singleton(Service1)
|
||||
container.register_transient(Service2)
|
||||
|
||||
services = container.get_registered_services()
|
||||
assert len(services) == 2
|
||||
assert Service1 in services
|
||||
assert Service2 in services
|
||||
|
||||
def test_clear_singletons(self):
|
||||
"""Test clearing singleton instances."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
pass
|
||||
|
||||
container.register_singleton(TestService)
|
||||
|
||||
# Create singleton instance
|
||||
instance1 = container.resolve(TestService)
|
||||
|
||||
# Clear singletons
|
||||
container.clear_singletons()
|
||||
|
||||
# Next resolve should create new instance
|
||||
instance2 = container.resolve(TestService)
|
||||
assert instance2 is not instance1
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test getting container statistics."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class Service1:
|
||||
pass
|
||||
|
||||
class Service2:
|
||||
pass
|
||||
|
||||
class Service3:
|
||||
pass
|
||||
|
||||
container.register_singleton(Service1)
|
||||
container.register_transient(Service2)
|
||||
container.register_scoped(Service3)
|
||||
|
||||
# Create some instances
|
||||
container.resolve(Service1)
|
||||
container.resolve(Service3, scope_id="scope1")
|
||||
|
||||
stats = container.get_stats()
|
||||
|
||||
assert stats["registered_services"] == 3
|
||||
assert stats["active_singletons"] == 1
|
||||
assert stats["active_scopes"] == 1
|
||||
assert stats["lifetime_breakdown"]["singleton"] == 1
|
||||
assert stats["lifetime_breakdown"]["transient"] == 1
|
||||
assert stats["lifetime_breakdown"]["scoped"] == 1
|
||||
|
||||
|
||||
class TestDetectorWorkerContainer:
|
||||
"""Test pre-configured detector worker container."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test detector worker container initialization."""
|
||||
container = DetectorWorkerContainer()
|
||||
|
||||
assert isinstance(container.container, ServiceContainer)
|
||||
|
||||
# Should have core services registered
|
||||
stats = container.container.get_stats()
|
||||
assert stats["registered_services"] > 0
|
||||
|
||||
def test_resolve_convenience_method(self):
|
||||
"""Test resolve convenience method."""
|
||||
container = DetectorWorkerContainer()
|
||||
|
||||
# Should be able to resolve through convenience method
|
||||
from detector_worker.core.singleton_managers import ModelStateManager
|
||||
|
||||
manager = container.resolve(ModelStateManager)
|
||||
assert isinstance(manager, ModelStateManager)
|
||||
|
||||
def test_create_scope_convenience_method(self):
|
||||
"""Test create scope convenience method."""
|
||||
container = DetectorWorkerContainer()
|
||||
|
||||
scope = container.create_scope("test_scope")
|
||||
assert isinstance(scope, ServiceScope)
|
||||
assert scope.scope_id == "test_scope"
|
||||
|
||||
|
||||
class TestGlobalContainerFunctions:
|
||||
"""Test global container functions."""
|
||||
|
||||
def test_get_container_singleton(self):
|
||||
"""Test that get_container returns a singleton."""
|
||||
container1 = get_container()
|
||||
container2 = get_container()
|
||||
|
||||
assert container1 is container2
|
||||
assert isinstance(container1, DetectorWorkerContainer)
|
||||
|
||||
def test_resolve_service_convenience(self):
|
||||
"""Test resolve_service convenience function."""
|
||||
from detector_worker.core.singleton_managers import ModelStateManager
|
||||
|
||||
manager = resolve_service(ModelStateManager)
|
||||
assert isinstance(manager, ModelStateManager)
|
||||
|
||||
def test_create_service_scope_convenience(self):
|
||||
"""Test create_service_scope convenience function."""
|
||||
scope = create_service_scope("test_scope")
|
||||
assert isinstance(scope, ServiceScope)
|
||||
assert scope.scope_id == "test_scope"
|
||||
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Test thread safety of dependency injection system."""
|
||||
|
||||
def test_container_thread_safety(self):
|
||||
"""Test that container is thread-safe."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
import threading
|
||||
self.thread_id = threading.current_thread().ident
|
||||
|
||||
container.register_singleton(TestService)
|
||||
|
||||
instances = {}
|
||||
|
||||
def resolve_service(thread_id):
|
||||
instances[thread_id] = container.resolve(TestService)
|
||||
|
||||
# Create multiple threads
|
||||
threads = []
|
||||
for i in range(10):
|
||||
thread = threading.Thread(target=resolve_service, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# All should get the same singleton instance
|
||||
first_instance = list(instances.values())[0]
|
||||
for instance in instances.values():
|
||||
assert instance is first_instance
|
||||
|
||||
def test_scope_thread_safety(self):
|
||||
"""Test that scoped services are thread-safe."""
|
||||
container = ServiceContainer()
|
||||
|
||||
class TestService:
|
||||
def __init__(self):
|
||||
import threading
|
||||
self.thread_id = threading.current_thread().ident
|
||||
|
||||
container.register_scoped(TestService)
|
||||
|
||||
results = {}
|
||||
|
||||
def resolve_in_scope(thread_id):
|
||||
# Each thread uses its own scope
|
||||
instance1 = container.resolve(TestService, scope_id=f"scope_{thread_id}")
|
||||
instance2 = container.resolve(TestService, scope_id=f"scope_{thread_id}")
|
||||
|
||||
results[thread_id] = {
|
||||
"same_instance": instance1 is instance2,
|
||||
"thread_id": instance1.thread_id
|
||||
}
|
||||
|
||||
threads = []
|
||||
for i in range(5):
|
||||
thread = threading.Thread(target=resolve_in_scope, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Each thread should get same instance within its scope
|
||||
for thread_id, result in results.items():
|
||||
assert result["same_instance"] is True
|
560
tests/unit/core/test_singleton_managers.py
Normal file
560
tests/unit/core/test_singleton_managers.py
Normal file
|
@ -0,0 +1,560 @@
|
|||
"""
|
||||
Unit tests for singleton state managers.
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
import threading
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from detector_worker.core.singleton_managers import (
|
||||
SingletonMeta,
|
||||
ModelStateManager,
|
||||
StreamStateManager,
|
||||
SessionStateManager,
|
||||
CacheStateManager,
|
||||
CameraStateManager,
|
||||
PipelineStateManager,
|
||||
ModelInfo,
|
||||
StreamInfo,
|
||||
SessionInfo
|
||||
)
|
||||
|
||||
|
||||
class TestSingletonMeta:
|
||||
"""Test singleton metaclass."""
|
||||
|
||||
def test_singleton_behavior(self):
|
||||
"""Test that singleton metaclass creates only one instance."""
|
||||
class TestSingleton(metaclass=SingletonMeta):
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
|
||||
instance1 = TestSingleton()
|
||||
instance2 = TestSingleton()
|
||||
|
||||
assert instance1 is instance2
|
||||
assert instance1.value == instance2.value
|
||||
|
||||
def test_singleton_thread_safety(self):
|
||||
"""Test that singleton is thread-safe."""
|
||||
class TestSingleton(metaclass=SingletonMeta):
|
||||
def __init__(self):
|
||||
self.created_by = threading.current_thread().name
|
||||
|
||||
instances = {}
|
||||
|
||||
def create_instance(thread_id):
|
||||
instances[thread_id] = TestSingleton()
|
||||
|
||||
threads = []
|
||||
for i in range(10):
|
||||
thread = threading.Thread(target=create_instance, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# All instances should be the same object
|
||||
first_instance = instances[0]
|
||||
for instance in instances.values():
|
||||
assert instance is first_instance
|
||||
|
||||
|
||||
class TestModelStateManager:
|
||||
"""Test model state management."""
|
||||
|
||||
def test_singleton_behavior(self):
|
||||
"""Test that ModelStateManager is a singleton."""
|
||||
manager1 = ModelStateManager()
|
||||
manager2 = ModelStateManager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
def test_load_model(self):
|
||||
"""Test loading a model."""
|
||||
manager = ModelStateManager()
|
||||
manager.clear_all() # Start fresh
|
||||
|
||||
mock_model = Mock()
|
||||
manager.load_model("camera1", "model1", mock_model)
|
||||
|
||||
retrieved_model = manager.get_model("camera1", "model1")
|
||||
assert retrieved_model is mock_model
|
||||
|
||||
def test_load_same_model_increments_reference_count(self):
|
||||
"""Test that loading the same model increments reference count."""
|
||||
manager = ModelStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
mock_model = Mock()
|
||||
|
||||
# Load same model twice
|
||||
manager.load_model("camera1", "model1", mock_model)
|
||||
manager.load_model("camera1", "model1", mock_model)
|
||||
|
||||
# Should still be accessible
|
||||
assert manager.get_model("camera1", "model1") is mock_model
|
||||
|
||||
def test_get_camera_models(self):
|
||||
"""Test getting all models for a camera."""
|
||||
manager = ModelStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
mock_model1 = Mock()
|
||||
mock_model2 = Mock()
|
||||
|
||||
manager.load_model("camera1", "model1", mock_model1)
|
||||
manager.load_model("camera1", "model2", mock_model2)
|
||||
|
||||
models = manager.get_camera_models("camera1")
|
||||
|
||||
assert len(models) == 2
|
||||
assert models["model1"] is mock_model1
|
||||
assert models["model2"] is mock_model2
|
||||
|
||||
def test_unload_model_with_multiple_references(self):
|
||||
"""Test unloading model with multiple references."""
|
||||
manager = ModelStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
mock_model = Mock()
|
||||
|
||||
# Load model twice (reference count = 2)
|
||||
manager.load_model("camera1", "model1", mock_model)
|
||||
manager.load_model("camera1", "model1", mock_model)
|
||||
|
||||
# First unload should not remove model
|
||||
result = manager.unload_model("camera1", "model1")
|
||||
assert result is False # Still referenced
|
||||
assert manager.get_model("camera1", "model1") is mock_model
|
||||
|
||||
# Second unload should remove model
|
||||
result = manager.unload_model("camera1", "model1")
|
||||
assert result is True # Completely removed
|
||||
assert manager.get_model("camera1", "model1") is None
|
||||
|
||||
def test_unload_camera_models(self):
|
||||
"""Test unloading all models for a camera."""
|
||||
manager = ModelStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
mock_model1 = Mock()
|
||||
mock_model2 = Mock()
|
||||
|
||||
manager.load_model("camera1", "model1", mock_model1)
|
||||
manager.load_model("camera1", "model2", mock_model2)
|
||||
|
||||
manager.unload_camera_models("camera1")
|
||||
|
||||
assert manager.get_model("camera1", "model1") is None
|
||||
assert manager.get_model("camera1", "model2") is None
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test getting model statistics."""
|
||||
manager = ModelStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
mock_model = Mock()
|
||||
manager.load_model("camera1", "model1", mock_model)
|
||||
manager.load_model("camera2", "model2", mock_model)
|
||||
|
||||
stats = manager.get_stats()
|
||||
|
||||
assert stats["total_models"] == 2
|
||||
assert stats["total_cameras"] == 2
|
||||
assert "camera1" in stats["cameras"]
|
||||
assert "camera2" in stats["cameras"]
|
||||
|
||||
|
||||
class TestStreamStateManager:
|
||||
"""Test stream state management."""
|
||||
|
||||
def test_add_stream(self):
|
||||
"""Test adding a stream."""
|
||||
manager = StreamStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
config = {"rtsp_url": "rtsp://example.com", "model_id": "test"}
|
||||
manager.add_stream("camera1", "sub1", config)
|
||||
|
||||
stream = manager.get_stream("camera1")
|
||||
assert stream is not None
|
||||
assert stream.camera_id == "camera1"
|
||||
assert stream.subscription_id == "sub1"
|
||||
assert stream.config == config
|
||||
|
||||
def test_subscription_mapping(self):
|
||||
"""Test subscription to camera mapping."""
|
||||
manager = StreamStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
config = {"rtsp_url": "rtsp://example.com"}
|
||||
manager.add_stream("camera1", "sub1", config)
|
||||
|
||||
camera_id = manager.get_camera_by_subscription("sub1")
|
||||
assert camera_id == "camera1"
|
||||
|
||||
def test_remove_stream(self):
|
||||
"""Test removing a stream."""
|
||||
manager = StreamStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
config = {"rtsp_url": "rtsp://example.com"}
|
||||
manager.add_stream("camera1", "sub1", config)
|
||||
|
||||
removed_stream = manager.remove_stream("camera1")
|
||||
|
||||
assert removed_stream is not None
|
||||
assert removed_stream.camera_id == "camera1"
|
||||
assert manager.get_stream("camera1") is None
|
||||
assert manager.get_camera_by_subscription("sub1") is None
|
||||
|
||||
def test_shared_stream_management(self):
|
||||
"""Test shared stream management."""
|
||||
manager = StreamStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
stream_data = {"reader": Mock(), "reference_count": 1}
|
||||
manager.add_shared_stream("rtsp://example.com", stream_data)
|
||||
|
||||
retrieved_data = manager.get_shared_stream("rtsp://example.com")
|
||||
assert retrieved_data == stream_data
|
||||
|
||||
removed_data = manager.remove_shared_stream("rtsp://example.com")
|
||||
assert removed_data == stream_data
|
||||
assert manager.get_shared_stream("rtsp://example.com") is None
|
||||
|
||||
|
||||
class TestSessionStateManager:
|
||||
"""Test session state management."""
|
||||
|
||||
def test_session_id_management(self):
|
||||
"""Test session ID assignment."""
|
||||
manager = SessionStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
manager.set_session_id("display1", "session123")
|
||||
|
||||
session_id = manager.get_session_id("display1")
|
||||
assert session_id == "session123"
|
||||
|
||||
def test_create_session(self):
|
||||
"""Test session creation with detection data."""
|
||||
manager = SessionStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
detection_data = {"class": "car", "confidence": 0.85}
|
||||
manager.create_session("session123", "camera1", detection_data)
|
||||
|
||||
retrieved_data = manager.get_session_detection("session123")
|
||||
assert retrieved_data == detection_data
|
||||
|
||||
camera_id = manager.get_camera_by_session("session123")
|
||||
assert camera_id == "camera1"
|
||||
|
||||
def test_update_session_detection(self):
|
||||
"""Test updating session detection data."""
|
||||
manager = SessionStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
initial_data = {"class": "car", "confidence": 0.85}
|
||||
manager.create_session("session123", "camera1", initial_data)
|
||||
|
||||
update_data = {"brand": "Toyota"}
|
||||
manager.update_session_detection("session123", update_data)
|
||||
|
||||
final_data = manager.get_session_detection("session123")
|
||||
assert final_data["class"] == "car"
|
||||
assert final_data["brand"] == "Toyota"
|
||||
|
||||
def test_session_expiration(self):
|
||||
"""Test session expiration based on TTL."""
|
||||
# Use a very short TTL for testing
|
||||
manager = SessionStateManager(session_ttl=0.1)
|
||||
manager.clear_all()
|
||||
|
||||
detection_data = {"class": "car", "confidence": 0.85}
|
||||
manager.create_session("session123", "camera1", detection_data)
|
||||
|
||||
# Session should exist initially
|
||||
assert manager.get_session_detection("session123") is not None
|
||||
|
||||
# Wait for expiration
|
||||
time.sleep(0.2)
|
||||
|
||||
# Clean up expired sessions
|
||||
expired_count = manager.cleanup_expired_sessions()
|
||||
|
||||
assert expired_count == 1
|
||||
assert manager.get_session_detection("session123") is None
|
||||
|
||||
def test_remove_session(self):
|
||||
"""Test manual session removal."""
|
||||
manager = SessionStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
detection_data = {"class": "car", "confidence": 0.85}
|
||||
manager.create_session("session123", "camera1", detection_data)
|
||||
|
||||
result = manager.remove_session("session123")
|
||||
assert result is True
|
||||
|
||||
assert manager.get_session_detection("session123") is None
|
||||
assert manager.get_camera_by_session("session123") is None
|
||||
|
||||
|
||||
class TestCacheStateManager:
|
||||
"""Test cache state management."""
|
||||
|
||||
def test_cache_detection(self):
|
||||
"""Test caching detection results."""
|
||||
manager = CacheStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
detection_data = {"class": "car", "confidence": 0.85, "bbox": [100, 200, 300, 400]}
|
||||
manager.cache_detection("camera1", detection_data)
|
||||
|
||||
cached_data = manager.get_cached_detection("camera1")
|
||||
assert cached_data == detection_data
|
||||
|
||||
def test_cache_pipeline_result(self):
|
||||
"""Test caching pipeline results."""
|
||||
manager = CacheStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
pipeline_result = {"status": "success", "detections": []}
|
||||
manager.cache_pipeline_result("camera1", pipeline_result)
|
||||
|
||||
cached_result = manager.get_cached_pipeline_result("camera1")
|
||||
assert cached_result == pipeline_result
|
||||
|
||||
def test_latest_frame_management(self):
|
||||
"""Test latest frame storage."""
|
||||
manager = CacheStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
frame_data = b"fake_frame_data"
|
||||
manager.set_latest_frame("camera1", frame_data)
|
||||
|
||||
retrieved_frame = manager.get_latest_frame("camera1")
|
||||
assert retrieved_frame == frame_data
|
||||
|
||||
def test_frame_skip_flag(self):
|
||||
"""Test frame skip flag management."""
|
||||
manager = CacheStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
# Initially should be False
|
||||
assert manager.get_frame_skip_flag("camera1") is False
|
||||
|
||||
manager.set_frame_skip_flag("camera1", True)
|
||||
assert manager.get_frame_skip_flag("camera1") is True
|
||||
|
||||
manager.set_frame_skip_flag("camera1", False)
|
||||
assert manager.get_frame_skip_flag("camera1") is False
|
||||
|
||||
def test_clear_camera_cache(self):
|
||||
"""Test clearing all cache data for a camera."""
|
||||
manager = CacheStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
# Set up cache data
|
||||
detection_data = {"class": "car"}
|
||||
pipeline_result = {"status": "success"}
|
||||
frame_data = b"frame"
|
||||
|
||||
manager.cache_detection("camera1", detection_data)
|
||||
manager.cache_pipeline_result("camera1", pipeline_result)
|
||||
manager.set_latest_frame("camera1", frame_data)
|
||||
manager.set_frame_skip_flag("camera1", True)
|
||||
|
||||
# Clear cache
|
||||
manager.clear_camera_cache("camera1")
|
||||
|
||||
# All data should be gone
|
||||
assert manager.get_cached_detection("camera1") is None
|
||||
assert manager.get_cached_pipeline_result("camera1") is None
|
||||
assert manager.get_latest_frame("camera1") is None
|
||||
assert manager.get_frame_skip_flag("camera1") is False
|
||||
|
||||
|
||||
class TestCameraStateManager:
|
||||
"""Test camera state management."""
|
||||
|
||||
def test_camera_connection_state(self):
|
||||
"""Test camera connection state management."""
|
||||
manager = CameraStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
# Initially connected (default)
|
||||
assert manager.is_camera_connected("camera1") is True
|
||||
|
||||
# Set disconnected
|
||||
manager.set_camera_connected("camera1", False)
|
||||
assert manager.is_camera_connected("camera1") is False
|
||||
|
||||
# Set connected again
|
||||
manager.set_camera_connected("camera1", True)
|
||||
assert manager.is_camera_connected("camera1") is True
|
||||
|
||||
def test_notification_flags(self):
|
||||
"""Test disconnection/reconnection notification flags."""
|
||||
manager = CameraStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
# Set disconnected
|
||||
manager.set_camera_connected("camera1", False)
|
||||
|
||||
# Should notify disconnection once
|
||||
assert manager.should_notify_disconnection("camera1") is True
|
||||
manager.mark_disconnection_notified("camera1")
|
||||
assert manager.should_notify_disconnection("camera1") is False
|
||||
|
||||
# Reconnect
|
||||
manager.set_camera_connected("camera1", True)
|
||||
|
||||
# Should notify reconnection
|
||||
assert manager.should_notify_reconnection("camera1") is True
|
||||
manager.mark_reconnection_notified("camera1")
|
||||
assert manager.should_notify_reconnection("camera1") is False
|
||||
|
||||
def test_get_camera_state(self):
|
||||
"""Test getting full camera state."""
|
||||
manager = CameraStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
manager.set_camera_connected("camera1", False)
|
||||
|
||||
state = manager.get_camera_state("camera1")
|
||||
|
||||
assert state["connected"] is False
|
||||
assert "last_update" in state
|
||||
assert "disconnection_notified" in state
|
||||
assert "reconnection_notified" in state
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test getting camera state statistics."""
|
||||
manager = CameraStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
manager.set_camera_connected("camera1", True)
|
||||
manager.set_camera_connected("camera2", False)
|
||||
|
||||
stats = manager.get_stats()
|
||||
|
||||
assert stats["total_cameras"] == 2
|
||||
assert stats["connected_cameras"] == 1
|
||||
assert stats["disconnected_cameras"] == 1
|
||||
|
||||
|
||||
class TestPipelineStateManager:
|
||||
"""Test pipeline state management."""
|
||||
|
||||
def test_get_or_init_state(self):
|
||||
"""Test getting or initializing pipeline state."""
|
||||
manager = PipelineStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
state = manager.get_or_init_state("camera1")
|
||||
|
||||
assert state["mode"] == "validation_detecting"
|
||||
assert state["backend_session_id"] is None
|
||||
assert state["yolo_inference_enabled"] is True
|
||||
assert "created_at" in state
|
||||
|
||||
def test_update_mode(self):
|
||||
"""Test updating pipeline mode."""
|
||||
manager = PipelineStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
manager.update_mode("camera1", "classification", "session123")
|
||||
|
||||
state = manager.get_state("camera1")
|
||||
assert state["mode"] == "classification"
|
||||
assert state["backend_session_id"] == "session123"
|
||||
|
||||
def test_set_yolo_inference_enabled(self):
|
||||
"""Test setting YOLO inference state."""
|
||||
manager = PipelineStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
manager.set_yolo_inference_enabled("camera1", False)
|
||||
|
||||
state = manager.get_state("camera1")
|
||||
assert state["yolo_inference_enabled"] is False
|
||||
|
||||
def test_set_progression_stage(self):
|
||||
"""Test setting progression stage."""
|
||||
manager = PipelineStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
manager.set_progression_stage("camera1", "brand_classification")
|
||||
|
||||
state = manager.get_state("camera1")
|
||||
assert state["progression_stage"] == "brand_classification"
|
||||
|
||||
def test_set_validated_detection(self):
|
||||
"""Test setting validated detection."""
|
||||
manager = PipelineStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
detection = {"class": "car", "confidence": 0.85}
|
||||
manager.set_validated_detection("camera1", detection)
|
||||
|
||||
state = manager.get_state("camera1")
|
||||
assert state["validated_detection"] == detection
|
||||
|
||||
def test_get_stats(self):
|
||||
"""Test getting pipeline state statistics."""
|
||||
manager = PipelineStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
manager.update_mode("camera1", "validation_detecting")
|
||||
manager.update_mode("camera2", "classification")
|
||||
manager.update_mode("camera3", "classification")
|
||||
|
||||
stats = manager.get_stats()
|
||||
|
||||
assert stats["total_pipeline_states"] == 3
|
||||
assert stats["mode_breakdown"]["validation_detecting"] == 1
|
||||
assert stats["mode_breakdown"]["classification"] == 2
|
||||
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Test thread safety of singleton managers."""
|
||||
|
||||
def test_model_manager_thread_safety(self):
|
||||
"""Test ModelStateManager thread safety."""
|
||||
manager = ModelStateManager()
|
||||
manager.clear_all()
|
||||
|
||||
results = {}
|
||||
|
||||
def load_models(thread_id):
|
||||
for i in range(10):
|
||||
model = Mock()
|
||||
model.thread_id = thread_id
|
||||
model.model_id = i
|
||||
manager.load_model(f"camera{thread_id}", f"model{i}", model)
|
||||
|
||||
# Verify models
|
||||
models = manager.get_camera_models(f"camera{thread_id}")
|
||||
results[thread_id] = len(models)
|
||||
|
||||
threads = []
|
||||
for i in range(5):
|
||||
thread = threading.Thread(target=load_models, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Each thread should have loaded 10 models
|
||||
for thread_id, model_count in results.items():
|
||||
assert model_count == 10
|
||||
|
||||
# Total should be 50 models
|
||||
stats = manager.get_stats()
|
||||
assert stats["total_models"] == 50
|
Loading…
Add table
Add a link
Reference in a new issue