Refactor: PHASE 8: Testing & Integration

This commit is contained in:
ziesorx 2025-09-12 18:55:23 +07:00
parent af34f4fd08
commit 9e8c6804a7
32 changed files with 17128 additions and 0 deletions

View file

@ -0,0 +1,429 @@
"""
Unit tests for configuration management system.
"""
import pytest
import json
import os
import tempfile
from unittest.mock import Mock, patch, MagicMock
from detector_worker.core.config import (
ConfigurationManager,
JsonFileProvider,
EnvironmentProvider,
DatabaseConfig,
RedisConfig,
StreamConfig,
ModelConfig,
LoggingConfig,
get_config_manager,
validate_config
)
from detector_worker.core.exceptions import ConfigurationError
class TestJsonFileProvider:
"""Test JSON file configuration provider."""
def test_get_config_from_valid_file(self, temp_dir):
"""Test loading configuration from a valid JSON file."""
config_data = {"test_key": "test_value", "number": 42}
config_file = os.path.join(temp_dir, "config.json")
with open(config_file, 'w') as f:
json.dump(config_data, f)
provider = JsonFileProvider(config_file)
result = provider.get_config()
assert result == config_data
def test_get_config_file_not_exists(self, temp_dir):
"""Test handling of non-existent config file."""
config_file = os.path.join(temp_dir, "nonexistent.json")
provider = JsonFileProvider(config_file)
result = provider.get_config()
assert result == {}
def test_get_config_invalid_json(self, temp_dir):
"""Test handling of invalid JSON file."""
config_file = os.path.join(temp_dir, "invalid.json")
with open(config_file, 'w') as f:
f.write("invalid json content")
provider = JsonFileProvider(config_file)
result = provider.get_config()
assert result == {}
def test_reload_updates_config(self, temp_dir):
"""Test that reload updates configuration."""
config_file = os.path.join(temp_dir, "config.json")
# Initial config
initial_config = {"version": 1}
with open(config_file, 'w') as f:
json.dump(initial_config, f)
provider = JsonFileProvider(config_file)
assert provider.get_config() == initial_config
# Update config file
updated_config = {"version": 2}
with open(config_file, 'w') as f:
json.dump(updated_config, f)
# Force reload
provider.reload()
assert provider.get_config() == updated_config
class TestEnvironmentProvider:
"""Test environment variable configuration provider."""
def test_get_config_with_env_vars(self):
"""Test loading configuration from environment variables."""
env_vars = {
"DETECTOR_MAX_STREAMS": "10",
"DETECTOR_TARGET_FPS": "15",
"DETECTOR_CONFIG": '{"nested": "value"}',
"OTHER_VAR": "ignored"
}
with patch.dict(os.environ, env_vars, clear=False):
provider = EnvironmentProvider("DETECTOR_")
config = provider.get_config()
assert config["max_streams"] == "10"
assert config["target_fps"] == "15"
assert config["config"] == {"nested": "value"}
assert "other_var" not in config
def test_get_config_no_env_vars(self):
"""Test with no matching environment variables."""
with patch.dict(os.environ, {}, clear=True):
provider = EnvironmentProvider("DETECTOR_")
config = provider.get_config()
assert config == {}
def test_custom_prefix(self):
"""Test with custom prefix."""
env_vars = {"CUSTOM_TEST": "value"}
with patch.dict(os.environ, env_vars, clear=False):
provider = EnvironmentProvider("CUSTOM_")
config = provider.get_config()
assert config["test"] == "value"
class TestConfigDataclasses:
"""Test configuration dataclasses."""
def test_database_config_from_dict(self):
"""Test DatabaseConfig creation from dictionary."""
data = {
"enabled": True,
"host": "db.example.com",
"port": 5432,
"database": "testdb",
"username": "user",
"password": "pass",
"schema": "test_schema",
"unknown_field": "ignored"
}
config = DatabaseConfig.from_dict(data)
assert config.enabled is True
assert config.host == "db.example.com"
assert config.port == 5432
assert config.database == "testdb"
assert config.username == "user"
assert config.password == "pass"
assert config.schema == "test_schema"
# Unknown fields should be ignored
assert not hasattr(config, 'unknown_field')
def test_redis_config_from_dict(self):
"""Test RedisConfig creation from dictionary."""
data = {
"enabled": True,
"host": "redis.example.com",
"port": 6379,
"password": "secret",
"db": 1
}
config = RedisConfig.from_dict(data)
assert config.enabled is True
assert config.host == "redis.example.com"
assert config.port == 6379
assert config.password == "secret"
assert config.db == 1
def test_stream_config_from_dict(self):
"""Test StreamConfig creation from dictionary."""
data = {
"poll_interval_ms": 50,
"max_streams": 10,
"target_fps": 20,
"reconnect_interval_sec": 10,
"max_retries": 5
}
config = StreamConfig.from_dict(data)
assert config.poll_interval_ms == 50
assert config.max_streams == 10
assert config.target_fps == 20
assert config.reconnect_interval_sec == 10
assert config.max_retries == 5
class TestConfigurationManager:
"""Test main configuration manager."""
def test_initialization_with_defaults(self):
"""Test that manager initializes with default values."""
manager = ConfigurationManager()
# Should have default providers
assert len(manager._providers) >= 1
# Should have default configuration values
config = manager.get_all()
assert "poll_interval_ms" in config
assert "max_streams" in config
assert "target_fps" in config
def test_add_provider(self):
"""Test adding configuration providers."""
manager = ConfigurationManager()
initial_count = len(manager._providers)
mock_provider = Mock()
mock_provider.get_config.return_value = {"test": "value"}
manager.add_provider(mock_provider)
assert len(manager._providers) == initial_count + 1
def test_get_configuration_value(self):
"""Test getting specific configuration values."""
manager = ConfigurationManager()
# Test existing key
value = manager.get("poll_interval_ms")
assert value is not None
# Test non-existing key with default
value = manager.get("nonexistent", "default")
assert value == "default"
# Test non-existing key without default
value = manager.get("nonexistent")
assert value is None
def test_get_section(self):
"""Test getting configuration sections."""
manager = ConfigurationManager()
# Test existing section
db_section = manager.get_section("database")
assert isinstance(db_section, dict)
# Test non-existing section
empty_section = manager.get_section("nonexistent")
assert empty_section == {}
def test_typed_config_access(self):
"""Test typed configuration object access."""
manager = ConfigurationManager()
# Test database config
db_config = manager.get_database_config()
assert isinstance(db_config, DatabaseConfig)
# Test Redis config
redis_config = manager.get_redis_config()
assert isinstance(redis_config, RedisConfig)
# Test stream config
stream_config = manager.get_stream_config()
assert isinstance(stream_config, StreamConfig)
# Test model config
model_config = manager.get_model_config()
assert isinstance(model_config, ModelConfig)
# Test logging config
logging_config = manager.get_logging_config()
assert isinstance(logging_config, LoggingConfig)
def test_set_configuration_value(self):
"""Test setting configuration values at runtime."""
manager = ConfigurationManager()
manager.set("test_key", "test_value")
assert manager.get("test_key") == "test_value"
# Should also update typed configs
manager.set("poll_interval_ms", 200)
stream_config = manager.get_stream_config()
assert stream_config.poll_interval_ms == 200
def test_validation_success(self):
"""Test configuration validation with valid config."""
manager = ConfigurationManager()
# Set valid configuration
manager.set("poll_interval_ms", 100)
manager.set("max_streams", 5)
manager.set("target_fps", 10)
errors = manager.validate()
assert errors == []
assert manager.is_valid() is True
def test_validation_errors(self):
"""Test configuration validation with invalid values."""
manager = ConfigurationManager()
# Set invalid configuration
manager.set("poll_interval_ms", 0)
manager.set("max_streams", -1)
manager.set("target_fps", 0)
errors = manager.validate()
assert len(errors) > 0
assert manager.is_valid() is False
# Check specific errors
error_messages = " ".join(errors)
assert "poll_interval_ms must be positive" in error_messages
assert "max_streams must be positive" in error_messages
assert "target_fps must be positive" in error_messages
def test_database_validation(self):
"""Test database-specific validation."""
manager = ConfigurationManager()
# Enable database but don't provide required fields
db_config = {
"enabled": True,
"host": "",
"database": ""
}
manager.set("database", db_config)
errors = manager.validate()
error_messages = " ".join(errors)
assert "database host is required" in error_messages
assert "database name is required" in error_messages
def test_redis_validation(self):
"""Test Redis-specific validation."""
manager = ConfigurationManager()
# Enable Redis but don't provide required fields
redis_config = {
"enabled": True,
"host": ""
}
manager.set("redis", redis_config)
errors = manager.validate()
error_messages = " ".join(errors)
assert "redis host is required" in error_messages
class TestGlobalConfigurationFunctions:
"""Test global configuration functions."""
def test_get_config_manager_singleton(self):
"""Test that get_config_manager returns a singleton."""
manager1 = get_config_manager()
manager2 = get_config_manager()
assert manager1 is manager2
@patch('detector_worker.core.config.get_config_manager')
def test_validate_config_function(self, mock_get_manager):
"""Test global validate_config function."""
mock_manager = Mock()
mock_manager.validate.return_value = ["error1", "error2"]
mock_get_manager.return_value = mock_manager
errors = validate_config()
assert errors == ["error1", "error2"]
mock_manager.validate.assert_called_once()
class TestConfigurationIntegration:
"""Integration tests for configuration system."""
def test_provider_priority(self, temp_dir):
"""Test that later providers override earlier ones."""
# Create JSON file with initial config
config_file = os.path.join(temp_dir, "config.json")
json_config = {"test_value": "from_json", "json_only": "json"}
with open(config_file, 'w') as f:
json.dump(json_config, f)
# Set environment variable that should override
env_vars = {"DETECTOR_TEST_VALUE": "from_env", "DETECTOR_ENV_ONLY": "env"}
with patch.dict(os.environ, env_vars, clear=False):
manager = ConfigurationManager()
manager._providers.clear() # Start fresh
# Add providers in order
manager.add_provider(JsonFileProvider(config_file))
manager.add_provider(EnvironmentProvider("DETECTOR_"))
config = manager.get_all()
# Environment should override JSON
assert config["test_value"] == "from_env"
# Both sources should be present
assert config["json_only"] == "json"
assert config["env_only"] == "env"
def test_hot_reload(self, temp_dir):
"""Test configuration hot reload functionality."""
config_file = os.path.join(temp_dir, "config.json")
# Initial config
initial_config = {"version": 1, "feature_enabled": False}
with open(config_file, 'w') as f:
json.dump(initial_config, f)
manager = ConfigurationManager()
manager._providers.clear()
manager.add_provider(JsonFileProvider(config_file))
assert manager.get("version") == 1
assert manager.get("feature_enabled") is False
# Update config file
updated_config = {"version": 2, "feature_enabled": True}
with open(config_file, 'w') as f:
json.dump(updated_config, f)
# Reload configuration
success = manager.reload()
assert success is True
assert manager.get("version") == 2
assert manager.get("feature_enabled") is True

View file

@ -0,0 +1,566 @@
"""
Unit tests for dependency injection system.
"""
import pytest
import threading
from unittest.mock import Mock, MagicMock
from detector_worker.core.dependency_injection import (
ServiceContainer,
ServiceLifetime,
ServiceDescriptor,
ServiceScope,
DetectorWorkerContainer,
get_container,
resolve_service,
create_service_scope
)
from detector_worker.core.exceptions import DependencyInjectionError
class TestServiceContainer:
"""Test core service container functionality."""
def test_register_singleton(self):
"""Test singleton service registration."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
# Register singleton
container.register_singleton(TestService)
# Resolve twice - should get same instance
instance1 = container.resolve(TestService)
instance2 = container.resolve(TestService)
assert instance1 is instance2
assert instance1.value == 42
def test_register_singleton_with_instance(self):
"""Test singleton registration with pre-created instance."""
container = ServiceContainer()
class TestService:
def __init__(self, value):
self.value = value
# Create instance and register
instance = TestService(99)
container.register_singleton(TestService, instance=instance)
# Resolve should return the pre-created instance
resolved = container.resolve(TestService)
assert resolved is instance
assert resolved.value == 99
def test_register_transient(self):
"""Test transient service registration."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
# Register transient
container.register_transient(TestService)
# Resolve twice - should get different instances
instance1 = container.resolve(TestService)
instance2 = container.resolve(TestService)
assert instance1 is not instance2
assert instance1.value == instance2.value == 42
def test_register_scoped(self):
"""Test scoped service registration."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
# Register scoped
container.register_scoped(TestService)
# Resolve in same scope - should get same instance
instance1 = container.resolve(TestService, scope_id="scope1")
instance2 = container.resolve(TestService, scope_id="scope1")
assert instance1 is instance2
# Resolve in different scope - should get different instance
instance3 = container.resolve(TestService, scope_id="scope2")
assert instance3 is not instance1
def test_register_with_factory(self):
"""Test service registration with factory function."""
container = ServiceContainer()
class TestService:
def __init__(self, value):
self.value = value
# Register with factory
def factory():
return TestService(100)
container.register_singleton(TestService, factory=factory)
instance = container.resolve(TestService)
assert instance.value == 100
def test_register_with_implementation_type(self):
"""Test service registration with implementation type."""
container = ServiceContainer()
class ITestService:
pass
class TestService(ITestService):
def __init__(self):
self.value = 42
# Register interface with implementation
container.register_singleton(ITestService, implementation_type=TestService)
instance = container.resolve(ITestService)
assert isinstance(instance, TestService)
assert instance.value == 42
def test_dependency_injection(self):
"""Test automatic dependency injection."""
container = ServiceContainer()
class DatabaseService:
def __init__(self):
self.connected = True
class UserService:
def __init__(self, database: DatabaseService):
self.database = database
# Register services
container.register_singleton(DatabaseService)
container.register_transient(UserService)
# Resolve should inject dependencies
user_service = container.resolve(UserService)
assert isinstance(user_service.database, DatabaseService)
assert user_service.database.connected is True
def test_circular_dependency_detection(self):
"""Test circular dependency detection."""
container = ServiceContainer()
class ServiceA:
def __init__(self, service_b: 'ServiceB'):
self.service_b = service_b
class ServiceB:
def __init__(self, service_a: ServiceA):
self.service_a = service_a
# Register circular dependencies
container.register_singleton(ServiceA)
container.register_singleton(ServiceB)
# Should raise circular dependency error
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(ServiceA)
assert "Circular dependency detected" in str(exc_info.value)
def test_unregistered_service_error(self):
"""Test error when resolving unregistered service."""
container = ServiceContainer()
class UnregisteredService:
pass
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(UnregisteredService)
assert "is not registered" in str(exc_info.value)
def test_scoped_service_without_scope_id(self):
"""Test error when resolving scoped service without scope ID."""
container = ServiceContainer()
class TestService:
pass
container.register_scoped(TestService)
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(TestService)
assert "Scope ID required" in str(exc_info.value)
def test_factory_error_handling(self):
"""Test factory error handling."""
container = ServiceContainer()
class TestService:
pass
def failing_factory():
raise ValueError("Factory failed")
container.register_singleton(TestService, factory=failing_factory)
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(TestService)
assert "Failed to create service using factory" in str(exc_info.value)
def test_constructor_dependency_with_default(self):
"""Test dependency with default value."""
container = ServiceContainer()
class TestService:
def __init__(self, value: int = 42):
self.value = value
container.register_singleton(TestService)
instance = container.resolve(TestService)
assert instance.value == 42
def test_unresolvable_dependency_with_default(self):
"""Test unresolvable dependency that has a default value."""
container = ServiceContainer()
class UnregisteredService:
pass
class TestService:
def __init__(self, dep: UnregisteredService = None):
self.dep = dep
container.register_singleton(TestService)
instance = container.resolve(TestService)
assert instance.dep is None
def test_unresolvable_dependency_without_default(self):
"""Test unresolvable dependency without default value."""
container = ServiceContainer()
class UnregisteredService:
pass
class TestService:
def __init__(self, dep: UnregisteredService):
self.dep = dep
container.register_singleton(TestService)
with pytest.raises(DependencyInjectionError) as exc_info:
container.resolve(TestService)
assert "Cannot resolve dependency" in str(exc_info.value)
class TestServiceScope:
"""Test service scope functionality."""
def test_create_scope(self):
"""Test scope creation."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.value = 42
container.register_scoped(TestService)
scope = container.create_scope("test_scope")
assert isinstance(scope, ServiceScope)
assert scope.scope_id == "test_scope"
def test_scope_context_manager(self):
"""Test scope as context manager."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.disposed = False
def dispose(self):
self.disposed = True
container.register_scoped(TestService)
instance = None
with container.create_scope("test_scope") as scope:
instance = scope.resolve(TestService)
assert not instance.disposed
# Instance should be disposed after scope exit
assert instance.disposed
def test_dispose_scope(self):
"""Test manual scope disposal."""
container = ServiceContainer()
class TestService:
def __init__(self):
self.disposed = False
def dispose(self):
self.disposed = True
container.register_scoped(TestService)
instance = container.resolve(TestService, scope_id="test_scope")
assert not instance.disposed
container.dispose_scope("test_scope")
assert instance.disposed
def test_dispose_error_handling(self):
"""Test error handling during scope disposal."""
container = ServiceContainer()
class TestService:
def dispose(self):
raise ValueError("Dispose failed")
container.register_scoped(TestService)
container.resolve(TestService, scope_id="test_scope")
# Should not raise error, just log it
container.dispose_scope("test_scope")
class TestContainerIntrospection:
"""Test container introspection capabilities."""
def test_is_registered(self):
"""Test checking if service is registered."""
container = ServiceContainer()
class RegisteredService:
pass
class UnregisteredService:
pass
container.register_singleton(RegisteredService)
assert container.is_registered(RegisteredService) is True
assert container.is_registered(UnregisteredService) is False
def test_get_registration_info(self):
"""Test getting service registration information."""
container = ServiceContainer()
class TestService:
pass
container.register_singleton(TestService)
info = container.get_registration_info(TestService)
assert isinstance(info, ServiceDescriptor)
assert info.service_type == TestService
assert info.lifetime == ServiceLifetime.SINGLETON
def test_get_registered_services(self):
"""Test getting all registered services."""
container = ServiceContainer()
class Service1:
pass
class Service2:
pass
container.register_singleton(Service1)
container.register_transient(Service2)
services = container.get_registered_services()
assert len(services) == 2
assert Service1 in services
assert Service2 in services
def test_clear_singletons(self):
"""Test clearing singleton instances."""
container = ServiceContainer()
class TestService:
pass
container.register_singleton(TestService)
# Create singleton instance
instance1 = container.resolve(TestService)
# Clear singletons
container.clear_singletons()
# Next resolve should create new instance
instance2 = container.resolve(TestService)
assert instance2 is not instance1
def test_get_stats(self):
"""Test getting container statistics."""
container = ServiceContainer()
class Service1:
pass
class Service2:
pass
class Service3:
pass
container.register_singleton(Service1)
container.register_transient(Service2)
container.register_scoped(Service3)
# Create some instances
container.resolve(Service1)
container.resolve(Service3, scope_id="scope1")
stats = container.get_stats()
assert stats["registered_services"] == 3
assert stats["active_singletons"] == 1
assert stats["active_scopes"] == 1
assert stats["lifetime_breakdown"]["singleton"] == 1
assert stats["lifetime_breakdown"]["transient"] == 1
assert stats["lifetime_breakdown"]["scoped"] == 1
class TestDetectorWorkerContainer:
"""Test pre-configured detector worker container."""
def test_initialization(self):
"""Test detector worker container initialization."""
container = DetectorWorkerContainer()
assert isinstance(container.container, ServiceContainer)
# Should have core services registered
stats = container.container.get_stats()
assert stats["registered_services"] > 0
def test_resolve_convenience_method(self):
"""Test resolve convenience method."""
container = DetectorWorkerContainer()
# Should be able to resolve through convenience method
from detector_worker.core.singleton_managers import ModelStateManager
manager = container.resolve(ModelStateManager)
assert isinstance(manager, ModelStateManager)
def test_create_scope_convenience_method(self):
"""Test create scope convenience method."""
container = DetectorWorkerContainer()
scope = container.create_scope("test_scope")
assert isinstance(scope, ServiceScope)
assert scope.scope_id == "test_scope"
class TestGlobalContainerFunctions:
"""Test global container functions."""
def test_get_container_singleton(self):
"""Test that get_container returns a singleton."""
container1 = get_container()
container2 = get_container()
assert container1 is container2
assert isinstance(container1, DetectorWorkerContainer)
def test_resolve_service_convenience(self):
"""Test resolve_service convenience function."""
from detector_worker.core.singleton_managers import ModelStateManager
manager = resolve_service(ModelStateManager)
assert isinstance(manager, ModelStateManager)
def test_create_service_scope_convenience(self):
"""Test create_service_scope convenience function."""
scope = create_service_scope("test_scope")
assert isinstance(scope, ServiceScope)
assert scope.scope_id == "test_scope"
class TestThreadSafety:
"""Test thread safety of dependency injection system."""
def test_container_thread_safety(self):
"""Test that container is thread-safe."""
container = ServiceContainer()
class TestService:
def __init__(self):
import threading
self.thread_id = threading.current_thread().ident
container.register_singleton(TestService)
instances = {}
def resolve_service(thread_id):
instances[thread_id] = container.resolve(TestService)
# Create multiple threads
threads = []
for i in range(10):
thread = threading.Thread(target=resolve_service, args=(i,))
threads.append(thread)
thread.start()
# Wait for all threads
for thread in threads:
thread.join()
# All should get the same singleton instance
first_instance = list(instances.values())[0]
for instance in instances.values():
assert instance is first_instance
def test_scope_thread_safety(self):
"""Test that scoped services are thread-safe."""
container = ServiceContainer()
class TestService:
def __init__(self):
import threading
self.thread_id = threading.current_thread().ident
container.register_scoped(TestService)
results = {}
def resolve_in_scope(thread_id):
# Each thread uses its own scope
instance1 = container.resolve(TestService, scope_id=f"scope_{thread_id}")
instance2 = container.resolve(TestService, scope_id=f"scope_{thread_id}")
results[thread_id] = {
"same_instance": instance1 is instance2,
"thread_id": instance1.thread_id
}
threads = []
for i in range(5):
thread = threading.Thread(target=resolve_in_scope, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Each thread should get same instance within its scope
for thread_id, result in results.items():
assert result["same_instance"] is True

View file

@ -0,0 +1,560 @@
"""
Unit tests for singleton state managers.
"""
import pytest
import time
import threading
from unittest.mock import Mock, patch, MagicMock
from detector_worker.core.singleton_managers import (
SingletonMeta,
ModelStateManager,
StreamStateManager,
SessionStateManager,
CacheStateManager,
CameraStateManager,
PipelineStateManager,
ModelInfo,
StreamInfo,
SessionInfo
)
class TestSingletonMeta:
"""Test singleton metaclass."""
def test_singleton_behavior(self):
"""Test that singleton metaclass creates only one instance."""
class TestSingleton(metaclass=SingletonMeta):
def __init__(self):
self.value = 42
instance1 = TestSingleton()
instance2 = TestSingleton()
assert instance1 is instance2
assert instance1.value == instance2.value
def test_singleton_thread_safety(self):
"""Test that singleton is thread-safe."""
class TestSingleton(metaclass=SingletonMeta):
def __init__(self):
self.created_by = threading.current_thread().name
instances = {}
def create_instance(thread_id):
instances[thread_id] = TestSingleton()
threads = []
for i in range(10):
thread = threading.Thread(target=create_instance, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# All instances should be the same object
first_instance = instances[0]
for instance in instances.values():
assert instance is first_instance
class TestModelStateManager:
"""Test model state management."""
def test_singleton_behavior(self):
"""Test that ModelStateManager is a singleton."""
manager1 = ModelStateManager()
manager2 = ModelStateManager()
assert manager1 is manager2
def test_load_model(self):
"""Test loading a model."""
manager = ModelStateManager()
manager.clear_all() # Start fresh
mock_model = Mock()
manager.load_model("camera1", "model1", mock_model)
retrieved_model = manager.get_model("camera1", "model1")
assert retrieved_model is mock_model
def test_load_same_model_increments_reference_count(self):
"""Test that loading the same model increments reference count."""
manager = ModelStateManager()
manager.clear_all()
mock_model = Mock()
# Load same model twice
manager.load_model("camera1", "model1", mock_model)
manager.load_model("camera1", "model1", mock_model)
# Should still be accessible
assert manager.get_model("camera1", "model1") is mock_model
def test_get_camera_models(self):
"""Test getting all models for a camera."""
manager = ModelStateManager()
manager.clear_all()
mock_model1 = Mock()
mock_model2 = Mock()
manager.load_model("camera1", "model1", mock_model1)
manager.load_model("camera1", "model2", mock_model2)
models = manager.get_camera_models("camera1")
assert len(models) == 2
assert models["model1"] is mock_model1
assert models["model2"] is mock_model2
def test_unload_model_with_multiple_references(self):
"""Test unloading model with multiple references."""
manager = ModelStateManager()
manager.clear_all()
mock_model = Mock()
# Load model twice (reference count = 2)
manager.load_model("camera1", "model1", mock_model)
manager.load_model("camera1", "model1", mock_model)
# First unload should not remove model
result = manager.unload_model("camera1", "model1")
assert result is False # Still referenced
assert manager.get_model("camera1", "model1") is mock_model
# Second unload should remove model
result = manager.unload_model("camera1", "model1")
assert result is True # Completely removed
assert manager.get_model("camera1", "model1") is None
def test_unload_camera_models(self):
"""Test unloading all models for a camera."""
manager = ModelStateManager()
manager.clear_all()
mock_model1 = Mock()
mock_model2 = Mock()
manager.load_model("camera1", "model1", mock_model1)
manager.load_model("camera1", "model2", mock_model2)
manager.unload_camera_models("camera1")
assert manager.get_model("camera1", "model1") is None
assert manager.get_model("camera1", "model2") is None
def test_get_stats(self):
"""Test getting model statistics."""
manager = ModelStateManager()
manager.clear_all()
mock_model = Mock()
manager.load_model("camera1", "model1", mock_model)
manager.load_model("camera2", "model2", mock_model)
stats = manager.get_stats()
assert stats["total_models"] == 2
assert stats["total_cameras"] == 2
assert "camera1" in stats["cameras"]
assert "camera2" in stats["cameras"]
class TestStreamStateManager:
"""Test stream state management."""
def test_add_stream(self):
"""Test adding a stream."""
manager = StreamStateManager()
manager.clear_all()
config = {"rtsp_url": "rtsp://example.com", "model_id": "test"}
manager.add_stream("camera1", "sub1", config)
stream = manager.get_stream("camera1")
assert stream is not None
assert stream.camera_id == "camera1"
assert stream.subscription_id == "sub1"
assert stream.config == config
def test_subscription_mapping(self):
"""Test subscription to camera mapping."""
manager = StreamStateManager()
manager.clear_all()
config = {"rtsp_url": "rtsp://example.com"}
manager.add_stream("camera1", "sub1", config)
camera_id = manager.get_camera_by_subscription("sub1")
assert camera_id == "camera1"
def test_remove_stream(self):
"""Test removing a stream."""
manager = StreamStateManager()
manager.clear_all()
config = {"rtsp_url": "rtsp://example.com"}
manager.add_stream("camera1", "sub1", config)
removed_stream = manager.remove_stream("camera1")
assert removed_stream is not None
assert removed_stream.camera_id == "camera1"
assert manager.get_stream("camera1") is None
assert manager.get_camera_by_subscription("sub1") is None
def test_shared_stream_management(self):
"""Test shared stream management."""
manager = StreamStateManager()
manager.clear_all()
stream_data = {"reader": Mock(), "reference_count": 1}
manager.add_shared_stream("rtsp://example.com", stream_data)
retrieved_data = manager.get_shared_stream("rtsp://example.com")
assert retrieved_data == stream_data
removed_data = manager.remove_shared_stream("rtsp://example.com")
assert removed_data == stream_data
assert manager.get_shared_stream("rtsp://example.com") is None
class TestSessionStateManager:
"""Test session state management."""
def test_session_id_management(self):
"""Test session ID assignment."""
manager = SessionStateManager()
manager.clear_all()
manager.set_session_id("display1", "session123")
session_id = manager.get_session_id("display1")
assert session_id == "session123"
def test_create_session(self):
"""Test session creation with detection data."""
manager = SessionStateManager()
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", detection_data)
retrieved_data = manager.get_session_detection("session123")
assert retrieved_data == detection_data
camera_id = manager.get_camera_by_session("session123")
assert camera_id == "camera1"
def test_update_session_detection(self):
"""Test updating session detection data."""
manager = SessionStateManager()
manager.clear_all()
initial_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", initial_data)
update_data = {"brand": "Toyota"}
manager.update_session_detection("session123", update_data)
final_data = manager.get_session_detection("session123")
assert final_data["class"] == "car"
assert final_data["brand"] == "Toyota"
def test_session_expiration(self):
"""Test session expiration based on TTL."""
# Use a very short TTL for testing
manager = SessionStateManager(session_ttl=0.1)
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", detection_data)
# Session should exist initially
assert manager.get_session_detection("session123") is not None
# Wait for expiration
time.sleep(0.2)
# Clean up expired sessions
expired_count = manager.cleanup_expired_sessions()
assert expired_count == 1
assert manager.get_session_detection("session123") is None
def test_remove_session(self):
"""Test manual session removal."""
manager = SessionStateManager()
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85}
manager.create_session("session123", "camera1", detection_data)
result = manager.remove_session("session123")
assert result is True
assert manager.get_session_detection("session123") is None
assert manager.get_camera_by_session("session123") is None
class TestCacheStateManager:
"""Test cache state management."""
def test_cache_detection(self):
"""Test caching detection results."""
manager = CacheStateManager()
manager.clear_all()
detection_data = {"class": "car", "confidence": 0.85, "bbox": [100, 200, 300, 400]}
manager.cache_detection("camera1", detection_data)
cached_data = manager.get_cached_detection("camera1")
assert cached_data == detection_data
def test_cache_pipeline_result(self):
"""Test caching pipeline results."""
manager = CacheStateManager()
manager.clear_all()
pipeline_result = {"status": "success", "detections": []}
manager.cache_pipeline_result("camera1", pipeline_result)
cached_result = manager.get_cached_pipeline_result("camera1")
assert cached_result == pipeline_result
def test_latest_frame_management(self):
"""Test latest frame storage."""
manager = CacheStateManager()
manager.clear_all()
frame_data = b"fake_frame_data"
manager.set_latest_frame("camera1", frame_data)
retrieved_frame = manager.get_latest_frame("camera1")
assert retrieved_frame == frame_data
def test_frame_skip_flag(self):
"""Test frame skip flag management."""
manager = CacheStateManager()
manager.clear_all()
# Initially should be False
assert manager.get_frame_skip_flag("camera1") is False
manager.set_frame_skip_flag("camera1", True)
assert manager.get_frame_skip_flag("camera1") is True
manager.set_frame_skip_flag("camera1", False)
assert manager.get_frame_skip_flag("camera1") is False
def test_clear_camera_cache(self):
"""Test clearing all cache data for a camera."""
manager = CacheStateManager()
manager.clear_all()
# Set up cache data
detection_data = {"class": "car"}
pipeline_result = {"status": "success"}
frame_data = b"frame"
manager.cache_detection("camera1", detection_data)
manager.cache_pipeline_result("camera1", pipeline_result)
manager.set_latest_frame("camera1", frame_data)
manager.set_frame_skip_flag("camera1", True)
# Clear cache
manager.clear_camera_cache("camera1")
# All data should be gone
assert manager.get_cached_detection("camera1") is None
assert manager.get_cached_pipeline_result("camera1") is None
assert manager.get_latest_frame("camera1") is None
assert manager.get_frame_skip_flag("camera1") is False
class TestCameraStateManager:
"""Test camera state management."""
def test_camera_connection_state(self):
"""Test camera connection state management."""
manager = CameraStateManager()
manager.clear_all()
# Initially connected (default)
assert manager.is_camera_connected("camera1") is True
# Set disconnected
manager.set_camera_connected("camera1", False)
assert manager.is_camera_connected("camera1") is False
# Set connected again
manager.set_camera_connected("camera1", True)
assert manager.is_camera_connected("camera1") is True
def test_notification_flags(self):
"""Test disconnection/reconnection notification flags."""
manager = CameraStateManager()
manager.clear_all()
# Set disconnected
manager.set_camera_connected("camera1", False)
# Should notify disconnection once
assert manager.should_notify_disconnection("camera1") is True
manager.mark_disconnection_notified("camera1")
assert manager.should_notify_disconnection("camera1") is False
# Reconnect
manager.set_camera_connected("camera1", True)
# Should notify reconnection
assert manager.should_notify_reconnection("camera1") is True
manager.mark_reconnection_notified("camera1")
assert manager.should_notify_reconnection("camera1") is False
def test_get_camera_state(self):
"""Test getting full camera state."""
manager = CameraStateManager()
manager.clear_all()
manager.set_camera_connected("camera1", False)
state = manager.get_camera_state("camera1")
assert state["connected"] is False
assert "last_update" in state
assert "disconnection_notified" in state
assert "reconnection_notified" in state
def test_get_stats(self):
"""Test getting camera state statistics."""
manager = CameraStateManager()
manager.clear_all()
manager.set_camera_connected("camera1", True)
manager.set_camera_connected("camera2", False)
stats = manager.get_stats()
assert stats["total_cameras"] == 2
assert stats["connected_cameras"] == 1
assert stats["disconnected_cameras"] == 1
class TestPipelineStateManager:
"""Test pipeline state management."""
def test_get_or_init_state(self):
"""Test getting or initializing pipeline state."""
manager = PipelineStateManager()
manager.clear_all()
state = manager.get_or_init_state("camera1")
assert state["mode"] == "validation_detecting"
assert state["backend_session_id"] is None
assert state["yolo_inference_enabled"] is True
assert "created_at" in state
def test_update_mode(self):
"""Test updating pipeline mode."""
manager = PipelineStateManager()
manager.clear_all()
manager.update_mode("camera1", "classification", "session123")
state = manager.get_state("camera1")
assert state["mode"] == "classification"
assert state["backend_session_id"] == "session123"
def test_set_yolo_inference_enabled(self):
"""Test setting YOLO inference state."""
manager = PipelineStateManager()
manager.clear_all()
manager.set_yolo_inference_enabled("camera1", False)
state = manager.get_state("camera1")
assert state["yolo_inference_enabled"] is False
def test_set_progression_stage(self):
"""Test setting progression stage."""
manager = PipelineStateManager()
manager.clear_all()
manager.set_progression_stage("camera1", "brand_classification")
state = manager.get_state("camera1")
assert state["progression_stage"] == "brand_classification"
def test_set_validated_detection(self):
"""Test setting validated detection."""
manager = PipelineStateManager()
manager.clear_all()
detection = {"class": "car", "confidence": 0.85}
manager.set_validated_detection("camera1", detection)
state = manager.get_state("camera1")
assert state["validated_detection"] == detection
def test_get_stats(self):
"""Test getting pipeline state statistics."""
manager = PipelineStateManager()
manager.clear_all()
manager.update_mode("camera1", "validation_detecting")
manager.update_mode("camera2", "classification")
manager.update_mode("camera3", "classification")
stats = manager.get_stats()
assert stats["total_pipeline_states"] == 3
assert stats["mode_breakdown"]["validation_detecting"] == 1
assert stats["mode_breakdown"]["classification"] == 2
class TestThreadSafety:
"""Test thread safety of singleton managers."""
def test_model_manager_thread_safety(self):
"""Test ModelStateManager thread safety."""
manager = ModelStateManager()
manager.clear_all()
results = {}
def load_models(thread_id):
for i in range(10):
model = Mock()
model.thread_id = thread_id
model.model_id = i
manager.load_model(f"camera{thread_id}", f"model{i}", model)
# Verify models
models = manager.get_camera_models(f"camera{thread_id}")
results[thread_id] = len(models)
threads = []
for i in range(5):
thread = threading.Thread(target=load_models, args=(i,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
# Each thread should have loaded 10 models
for thread_id, model_count in results.items():
assert model_count == 10
# Total should be 50 models
stats = manager.get_stats()
assert stats["total_models"] == 50