Refactor: PHASE 6: Decoupling & Integration

This commit is contained in:
ziesorx 2025-09-12 15:57:51 +07:00
parent 6c7c4c5d9c
commit accefde8a1
8 changed files with 2344 additions and 86 deletions

View file

@ -14,17 +14,18 @@ from typing import Dict, Any, Optional, Callable, List, Set
from contextlib import asynccontextmanager
from fastapi import WebSocket
from websockets.exceptions import ConnectionClosedError, WebSocketDisconnect
from fastapi.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from ..core.config import config, subscription_to_camera, latest_frames
from ..core.constants import HEARTBEAT_INTERVAL
from ..core.exceptions import WebSocketError, StreamError
from ..streams.stream_manager import StreamManager
from ..streams.camera_monitor import CameraConnectionMonitor
from ..streams.camera_monitor import CameraMonitor
from ..detection.detection_result import DetectionResult
from ..models.model_manager import ModelManager
from ..pipeline.pipeline_executor import PipelineExecutor
from ..storage.session_cache import SessionCache
from ..storage.session_cache import SessionCacheManager
from ..storage.redis_client import RedisClientManager
from ..utils.system_monitor import get_system_metrics
@ -55,7 +56,7 @@ class WebSocketHandler:
stream_manager: StreamManager,
model_manager: ModelManager,
pipeline_executor: PipelineExecutor,
session_cache: SessionCache,
session_cache: SessionCacheManager,
redis_client: Optional[RedisClientManager] = None
):
"""
@ -603,7 +604,7 @@ async def handle_websocket_connection(
stream_manager: StreamManager,
model_manager: ModelManager,
pipeline_executor: PipelineExecutor,
session_cache: SessionCache,
session_cache: SessionCacheManager,
redis_client: Optional[RedisClientManager] = None
) -> None:
"""

View file

@ -1,22 +1,429 @@
"""
Configuration management for detector worker.
Centralized configuration management for the detector worker.
This module handles application configuration loading, validation,
and provides centralized access to configuration parameters.
This module provides comprehensive configuration management including:
- Environment variable support
- Configuration validation
- Hot-reload capabilities
- Type-safe configuration access
"""
import json
import os
import logging
from typing import Dict, Any, Optional
from dataclasses import dataclass
import threading
from typing import Dict, Any, Optional, Union, List
from pathlib import Path
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
logger = logging.getLogger(__name__)
from .exceptions import ConfigurationError
# Setup logging
logger = logging.getLogger("detector_worker.config")
class ConfigurationProvider(ABC):
"""Abstract base class for configuration providers."""
@abstractmethod
def get_config(self) -> Dict[str, Any]:
"""Get configuration data."""
pass
@abstractmethod
def reload(self) -> bool:
"""Reload configuration data."""
pass
class JsonFileProvider(ConfigurationProvider):
"""Configuration provider that reads from JSON files."""
def __init__(self, file_path: str):
"""Initialize with JSON file path."""
self.file_path = Path(file_path)
self._config: Dict[str, Any] = {}
self._last_modified: Optional[float] = None
def get_config(self) -> Dict[str, Any]:
"""Get configuration from JSON file."""
if not self._config or self._should_reload():
self.reload()
return self._config.copy()
def reload(self) -> bool:
"""Reload configuration from file."""
try:
if self.file_path.exists():
with open(self.file_path, 'r') as f:
self._config = json.load(f)
self._last_modified = self.file_path.stat().st_mtime
logger.debug(f"Loaded configuration from {self.file_path}")
return True
else:
logger.warning(f"Configuration file not found: {self.file_path}")
return False
except Exception as e:
logger.error(f"Error loading configuration from {self.file_path}: {e}")
return False
def _should_reload(self) -> bool:
"""Check if file has been modified since last load."""
if not self.file_path.exists():
return False
current_mtime = self.file_path.stat().st_mtime
return self._last_modified is None or current_mtime > self._last_modified
class EnvironmentProvider(ConfigurationProvider):
"""Configuration provider that reads from environment variables."""
def __init__(self, prefix: str = "DETECTOR_"):
"""Initialize with environment variable prefix."""
self.prefix = prefix
def get_config(self) -> Dict[str, Any]:
"""Get configuration from environment variables."""
config = {}
for key, value in os.environ.items():
if key.startswith(self.prefix):
# Convert DETECTOR_POLL_INTERVAL_MS -> poll_interval_ms
config_key = key[len(self.prefix):].lower()
# Try to parse as JSON first, then as string
try:
config[config_key] = json.loads(value)
except json.JSONDecodeError:
config[config_key] = value
return config
def reload(self) -> bool:
"""Environment variables don't need explicit reload."""
return True
@dataclass
class DatabaseConfig:
"""Database configuration."""
enabled: bool = False
host: str = "localhost"
port: int = 5432
database: str = "detector_worker"
username: str = "postgres"
password: str = ""
schema: str = "public"
connection_pool_size: int = 10
connection_timeout: int = 30
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'DatabaseConfig':
"""Create from dictionary."""
return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
@dataclass
class RedisConfig:
"""Redis configuration."""
enabled: bool = False
host: str = "localhost"
port: int = 6379
password: Optional[str] = None
db: int = 0
connection_pool_size: int = 10
connection_timeout: int = 30
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
"""Create from dictionary."""
return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
@dataclass
class StreamConfig:
"""Stream processing configuration."""
poll_interval_ms: int = 100
max_streams: int = 5
target_fps: int = 10
reconnect_interval_sec: int = 5
max_retries: int = 3
buffer_size: int = 1
timeout_ms: int = 10000
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'StreamConfig':
"""Create from dictionary."""
return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
@dataclass
class ModelConfig:
"""Model configuration."""
models_dir: str = "models"
cache_size_mb: int = 1024
load_timeout_sec: int = 60
inference_timeout_sec: int = 10
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ModelConfig':
"""Create from dictionary."""
return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
@dataclass
class LoggingConfig:
"""Logging configuration."""
level: str = "INFO"
format: str = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
file_path: Optional[str] = "detector_worker.log"
max_file_size_mb: int = 10
backup_count: int = 5
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'LoggingConfig':
"""Create from dictionary."""
return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
class ConfigurationManager:
"""Centralized configuration manager with multiple providers and validation."""
def __init__(self):
"""Initialize configuration manager."""
self._providers: List[ConfigurationProvider] = []
self._config: Dict[str, Any] = {}
self._typed_configs: Dict[str, Any] = {}
self._lock = threading.RLock()
self._default_config = self._get_default_config()
# Set up default providers
self.add_provider(JsonFileProvider("config.json"))
self.add_provider(EnvironmentProvider("DETECTOR_"))
# Load initial configuration
self.reload()
def add_provider(self, provider: ConfigurationProvider) -> None:
"""Add a configuration provider."""
with self._lock:
self._providers.append(provider)
logger.debug(f"Added configuration provider: {type(provider).__name__}")
def reload(self) -> bool:
"""Reload configuration from all providers."""
with self._lock:
# Start with default configuration
config = self._default_config.copy()
# Merge configurations from providers (later providers override earlier ones)
for provider in self._providers:
try:
provider_config = provider.get_config()
config.update(provider_config)
except Exception as e:
logger.error(f"Error loading from provider {type(provider).__name__}: {e}")
self._config = config
self._update_typed_configs()
logger.debug("Configuration reloaded")
return True
def _get_default_config(self) -> Dict[str, Any]:
"""Get default configuration values."""
return {
"poll_interval_ms": 100,
"max_streams": 5,
"target_fps": 10,
"reconnect_interval_sec": 5,
"max_retries": 3,
"heartbeat_interval": 2,
"session_timeout": 600,
"models_dir": "models",
"log_level": "INFO",
"database": {
"enabled": False,
"host": "localhost",
"port": 5432,
"database": "detector_worker",
"username": "postgres",
"password": "",
"schema": "public"
},
"redis": {
"enabled": False,
"host": "localhost",
"port": 6379,
"password": None,
"db": 0
}
}
def _update_typed_configs(self) -> None:
"""Update typed configuration objects."""
# Database configuration
db_config = self._config.get("database", {})
self._typed_configs["database"] = DatabaseConfig.from_dict(db_config)
# Redis configuration
redis_config = self._config.get("redis", {})
self._typed_configs["redis"] = RedisConfig.from_dict(redis_config)
# Stream configuration
stream_config = {k: v for k, v in self._config.items()
if k in StreamConfig.__dataclass_fields__}
self._typed_configs["stream"] = StreamConfig.from_dict(stream_config)
# Model configuration
model_config = {k: v for k, v in self._config.items()
if k in ModelConfig.__dataclass_fields__}
self._typed_configs["model"] = ModelConfig.from_dict(model_config)
# Logging configuration
logging_config = self._config.get("logging", {})
self._typed_configs["logging"] = LoggingConfig.from_dict(logging_config)
def get(self, key: str, default: Any = None) -> Any:
"""Get a configuration value."""
with self._lock:
return self._config.get(key, default)
def get_section(self, section: str) -> Dict[str, Any]:
"""Get a configuration section."""
with self._lock:
return self._config.get(section, {})
def get_database_config(self) -> DatabaseConfig:
"""Get typed database configuration."""
with self._lock:
return self._typed_configs["database"]
def get_redis_config(self) -> RedisConfig:
"""Get typed Redis configuration."""
with self._lock:
return self._typed_configs["redis"]
def get_stream_config(self) -> StreamConfig:
"""Get typed stream configuration."""
with self._lock:
return self._typed_configs["stream"]
def get_model_config(self) -> ModelConfig:
"""Get typed model configuration."""
with self._lock:
return self._typed_configs["model"]
def get_logging_config(self) -> LoggingConfig:
"""Get typed logging configuration."""
with self._lock:
return self._typed_configs["logging"]
def get_all(self) -> Dict[str, Any]:
"""Get all configuration values."""
with self._lock:
return self._config.copy()
def set(self, key: str, value: Any) -> None:
"""Set a configuration value (runtime only)."""
with self._lock:
self._config[key] = value
self._update_typed_configs()
def validate(self) -> List[str]:
"""Validate configuration and return any errors."""
errors = []
with self._lock:
# Validate stream configuration
if self._config.get("poll_interval_ms", 0) <= 0:
errors.append("poll_interval_ms must be positive")
if self._config.get("max_streams", 0) <= 0:
errors.append("max_streams must be positive")
if self._config.get("target_fps", 0) <= 0:
errors.append("target_fps must be positive")
# Validate database configuration if enabled
db_config = self.get_database_config()
if db_config.enabled:
if not db_config.host:
errors.append("database host is required when database is enabled")
if not db_config.database:
errors.append("database name is required when database is enabled")
# Validate Redis configuration if enabled
redis_config = self.get_redis_config()
if redis_config.enabled:
if not redis_config.host:
errors.append("redis host is required when redis is enabled")
return errors
def is_valid(self) -> bool:
"""Check if configuration is valid."""
return len(self.validate()) == 0
# Global configuration manager instance
_config_manager: Optional[ConfigurationManager] = None
_config_lock = threading.Lock()
def get_config_manager() -> ConfigurationManager:
"""Get or create the global configuration manager."""
global _config_manager
if _config_manager is None:
with _config_lock:
if _config_manager is None:
_config_manager = ConfigurationManager()
logger.info("Created global configuration manager")
return _config_manager
# Backward compatibility - these will be replaced by singleton managers
subscription_to_camera: Dict[str, str] = {}
latest_frames: Dict[str, Any] = {}
# Configuration access
config_manager = get_config_manager()
config = config_manager.get_all()
# Constants derived from configuration
MAX_STREAMS = config_manager.get("max_streams", 5)
TARGET_FPS = config_manager.get("target_fps", 10)
POLL_INTERVAL_MS = config_manager.get("poll_interval_ms", 100)
RECONNECT_INTERVAL_SEC = config_manager.get("reconnect_interval_sec", 5)
MAX_RETRIES = config_manager.get("max_retries", 3)
MODELS_DIR = config_manager.get("models_dir", "models")
# Convenience functions for backward compatibility
def get_config(key: str, default: Any = None) -> Any:
"""Get configuration value."""
return config_manager.get(key, default)
def reload_config() -> bool:
"""Reload configuration from all sources."""
global config
success = config_manager.reload()
if success:
config = config_manager.get_all()
return success
def validate_config() -> List[str]:
"""Validate current configuration."""
return config_manager.validate()
# Legacy compatibility
@dataclass
class DetectorConfig:
"""Configuration class for detector worker parameters."""
"""Legacy configuration class for backward compatibility."""
# Frame processing settings
poll_interval_ms: int = 100
@ -38,78 +445,16 @@ class DetectorConfig:
return 1000 / self.target_fps if self.target_fps > 0 else self.poll_interval_ms
class ConfigManager:
"""Centralized configuration manager."""
def __init__(self, config_file: str = "config.json"):
self.config_file = config_file
self._config: Optional[DetectorConfig] = None
def load_config(self) -> DetectorConfig:
"""Load configuration from file with defaults fallback."""
if self._config is not None:
return self._config
config_data = {}
# Try to load from config file
if os.path.exists(self.config_file):
try:
with open(self.config_file, "r") as f:
config_data = json.load(f)
logger.info(f"Loaded configuration from {self.config_file}")
except (json.JSONDecodeError, IOError) as e:
logger.warning(f"Failed to load config from {self.config_file}: {e}")
logger.info("Using default configuration")
else:
logger.info(f"Config file {self.config_file} not found, using defaults")
# Create config with defaults + loaded values
self._config = DetectorConfig(
poll_interval_ms=config_data.get("poll_interval_ms", 100),
target_fps=config_data.get("target_fps", 10),
max_streams=config_data.get("max_streams", 5),
reconnect_interval_sec=config_data.get("reconnect_interval_sec", 5),
max_retries=config_data.get("max_retries", 3),
log_level=config_data.get("log_level", "INFO"),
log_file=config_data.get("log_file", "detector_worker.log"),
websocket_log_file=config_data.get("websocket_log_file", "websocket_comm.log")
)
# Log configuration summary
self._log_config_summary()
return self._config
def _log_config_summary(self):
"""Log configuration summary for debugging."""
if self._config:
logger.info(f"Configuration loaded:")
logger.info(f" Target FPS: {self._config.target_fps}")
logger.info(f" Poll interval: {self._config.poll_interval}ms")
logger.info(f" Max streams: {self._config.max_streams}")
logger.info(f" Max retries: {self._config.max_retries}")
logger.info(f" Log level: {self._config.log_level}")
def get_config(self) -> DetectorConfig:
"""Get current configuration, loading if necessary."""
if self._config is None:
return self.load_config()
return self._config
def reload_config(self) -> DetectorConfig:
"""Force reload configuration from file."""
self._config = None
return self.load_config()
# Global config manager instance
_config_manager = ConfigManager()
def get_config() -> DetectorConfig:
"""Get the global configuration instance."""
return _config_manager.get_config()
def reload_config() -> DetectorConfig:
"""Reload configuration from file."""
return _config_manager.reload_config()
"""Get legacy detector config for backward compatibility."""
config_data = config_manager.get_all()
return DetectorConfig(
poll_interval_ms=config_data.get("poll_interval_ms", 100),
target_fps=config_data.get("target_fps", 10),
max_streams=config_data.get("max_streams", 5),
reconnect_interval_sec=config_data.get("reconnect_interval_sec", 5),
max_retries=config_data.get("max_retries", 3),
log_level=config_data.get("log_level", "INFO"),
log_file=config_data.get("log_file", "detector_worker.log"),
websocket_log_file=config_data.get("websocket_log_file", "websocket_comm.log")
)

View file

@ -0,0 +1,514 @@
"""
Dependency Injection Container (IoC Container).
This module provides a comprehensive dependency injection system that manages
all components and their dependencies, enabling clean decoupling and testability.
"""
import logging
import threading
from typing import Dict, Any, Type, Optional, TypeVar, Generic, Callable, Union
from dataclasses import dataclass
from enum import Enum
from abc import ABC, abstractmethod
from .exceptions import DependencyInjectionError
from .singleton_managers import (
ModelStateManager, StreamStateManager, SessionStateManager,
CacheStateManager, CameraStateManager, PipelineStateManager
)
# Setup logging
logger = logging.getLogger("detector_worker.dependency_injection")
T = TypeVar('T')
class ServiceLifetime(Enum):
"""Service lifetime management options."""
SINGLETON = "singleton"
TRANSIENT = "transient"
SCOPED = "scoped"
@dataclass
class ServiceDescriptor:
"""Describes how to create and manage a service."""
service_type: Type
implementation_type: Optional[Type] = None
factory: Optional[Callable] = None
instance: Optional[Any] = None
lifetime: ServiceLifetime = ServiceLifetime.TRANSIENT
dependencies: list = None
class ServiceContainer:
"""
Dependency injection container for managing services and their dependencies.
This container provides comprehensive dependency injection capabilities including:
- Service registration with different lifetimes
- Automatic dependency resolution
- Circular dependency detection
- Thread-safe service creation
"""
def __init__(self):
"""Initialize the service container."""
self._services: Dict[Type, ServiceDescriptor] = {}
self._singletons: Dict[Type, Any] = {}
self._scoped_services: Dict[str, Dict[Type, Any]] = {} # scope_id -> services
self._lock = threading.RLock()
self._resolution_stack: list = []
def register_singleton(
self,
service_type: Type[T],
implementation_type: Optional[Type] = None,
factory: Optional[Callable[[], T]] = None,
instance: Optional[T] = None
) -> 'ServiceContainer':
"""
Register a singleton service.
Args:
service_type: The service interface/type
implementation_type: The concrete implementation
factory: Factory function to create the service
instance: Pre-created instance
Returns:
Self for method chaining
"""
with self._lock:
if instance is not None:
self._singletons[service_type] = instance
descriptor = ServiceDescriptor(
service_type=service_type,
implementation_type=implementation_type,
factory=factory,
instance=instance,
lifetime=ServiceLifetime.SINGLETON
)
self._services[service_type] = descriptor
logger.debug(f"Registered singleton service: {service_type.__name__}")
return self
def register_transient(
self,
service_type: Type[T],
implementation_type: Optional[Type] = None,
factory: Optional[Callable[[], T]] = None
) -> 'ServiceContainer':
"""
Register a transient service (new instance each time).
Args:
service_type: The service interface/type
implementation_type: The concrete implementation
factory: Factory function to create the service
Returns:
Self for method chaining
"""
with self._lock:
descriptor = ServiceDescriptor(
service_type=service_type,
implementation_type=implementation_type,
factory=factory,
lifetime=ServiceLifetime.TRANSIENT
)
self._services[service_type] = descriptor
logger.debug(f"Registered transient service: {service_type.__name__}")
return self
def register_scoped(
self,
service_type: Type[T],
implementation_type: Optional[Type] = None,
factory: Optional[Callable[[], T]] = None
) -> 'ServiceContainer':
"""
Register a scoped service (same instance within a scope).
Args:
service_type: The service interface/type
implementation_type: The concrete implementation
factory: Factory function to create the service
Returns:
Self for method chaining
"""
with self._lock:
descriptor = ServiceDescriptor(
service_type=service_type,
implementation_type=implementation_type,
factory=factory,
lifetime=ServiceLifetime.SCOPED
)
self._services[service_type] = descriptor
logger.debug(f"Registered scoped service: {service_type.__name__}")
return self
def resolve(self, service_type: Type[T], scope_id: Optional[str] = None) -> T:
"""
Resolve a service instance.
Args:
service_type: The service type to resolve
scope_id: Optional scope identifier for scoped services
Returns:
Service instance
Raises:
DependencyInjectionError: If service cannot be resolved
"""
with self._lock:
# Check for circular dependencies
if service_type in self._resolution_stack:
cycle = " -> ".join(cls.__name__ for cls in self._resolution_stack)
cycle += f" -> {service_type.__name__}"
raise DependencyInjectionError(f"Circular dependency detected: {cycle}")
self._resolution_stack.append(service_type)
try:
return self._resolve_service(service_type, scope_id)
finally:
self._resolution_stack.pop()
def _resolve_service(self, service_type: Type[T], scope_id: Optional[str]) -> T:
"""Internal service resolution."""
# Check if service is registered
if service_type not in self._services:
raise DependencyInjectionError(f"Service {service_type.__name__} is not registered")
descriptor = self._services[service_type]
# Handle singleton lifetime
if descriptor.lifetime == ServiceLifetime.SINGLETON:
if service_type in self._singletons:
return self._singletons[service_type]
instance = self._create_instance(descriptor)
self._singletons[service_type] = instance
return instance
# Handle scoped lifetime
elif descriptor.lifetime == ServiceLifetime.SCOPED:
if scope_id is None:
raise DependencyInjectionError(f"Scope ID required for scoped service {service_type.__name__}")
if scope_id not in self._scoped_services:
self._scoped_services[scope_id] = {}
scoped_services = self._scoped_services[scope_id]
if service_type in scoped_services:
return scoped_services[service_type]
instance = self._create_instance(descriptor)
scoped_services[service_type] = instance
return instance
# Handle transient lifetime
else:
return self._create_instance(descriptor)
def _create_instance(self, descriptor: ServiceDescriptor) -> Any:
"""Create a service instance using the descriptor."""
# Use existing instance if available
if descriptor.instance is not None:
return descriptor.instance
# Use factory if available
if descriptor.factory is not None:
try:
return descriptor.factory()
except Exception as e:
raise DependencyInjectionError(f"Failed to create service using factory: {e}")
# Use implementation type
if descriptor.implementation_type is not None:
try:
return self._create_with_dependencies(descriptor.implementation_type)
except Exception as e:
raise DependencyInjectionError(f"Failed to create service {descriptor.implementation_type.__name__}: {e}")
# Use service type directly
try:
return self._create_with_dependencies(descriptor.service_type)
except Exception as e:
raise DependencyInjectionError(f"Failed to create service {descriptor.service_type.__name__}: {e}")
def _create_with_dependencies(self, service_type: Type) -> Any:
"""Create service instance with automatic dependency injection."""
# Get constructor parameters
import inspect
try:
signature = inspect.signature(service_type.__init__)
parameters = list(signature.parameters.values())[1:] # Skip 'self'
# Resolve dependencies
dependencies = []
for param in parameters:
if param.annotation != inspect.Parameter.empty:
# Try to resolve the parameter type
try:
dependency = self.resolve(param.annotation)
dependencies.append(dependency)
except DependencyInjectionError:
# If dependency cannot be resolved and has default, use default
if param.default != inspect.Parameter.empty:
dependencies.append(param.default)
else:
raise DependencyInjectionError(
f"Cannot resolve dependency {param.annotation.__name__} for {service_type.__name__}"
)
else:
# Parameter without type annotation, use default if available
if param.default != inspect.Parameter.empty:
dependencies.append(param.default)
else:
raise DependencyInjectionError(
f"Cannot resolve untyped parameter {param.name} for {service_type.__name__}"
)
return service_type(*dependencies)
except Exception as e:
if isinstance(e, DependencyInjectionError):
raise
else:
# Try to create without dependencies
return service_type()
def create_scope(self, scope_id: str) -> 'ServiceScope':
"""Create a new service scope."""
return ServiceScope(self, scope_id)
def dispose_scope(self, scope_id: str) -> None:
"""Dispose a service scope and its services."""
with self._lock:
if scope_id in self._scoped_services:
scoped_services = self._scoped_services.pop(scope_id)
# Dispose services that implement IDisposable
for service in scoped_services.values():
if hasattr(service, 'dispose'):
try:
service.dispose()
except Exception as e:
logger.error(f"Error disposing scoped service: {e}")
logger.debug(f"Disposed scope {scope_id} with {len(scoped_services)} services")
def is_registered(self, service_type: Type) -> bool:
"""Check if a service type is registered."""
with self._lock:
return service_type in self._services
def get_registration_info(self, service_type: Type) -> Optional[ServiceDescriptor]:
"""Get registration information for a service."""
with self._lock:
return self._services.get(service_type)
def get_registered_services(self) -> Dict[Type, ServiceDescriptor]:
"""Get all registered services."""
with self._lock:
return self._services.copy()
def clear_singletons(self) -> None:
"""Clear all singleton instances."""
with self._lock:
singleton_count = len(self._singletons)
self._singletons.clear()
logger.info(f"Cleared {singleton_count} singleton services")
def get_stats(self) -> Dict[str, Any]:
"""Get container statistics."""
with self._lock:
lifetime_counts = {}
for descriptor in self._services.values():
lifetime = descriptor.lifetime.value
lifetime_counts[lifetime] = lifetime_counts.get(lifetime, 0) + 1
return {
"registered_services": len(self._services),
"active_singletons": len(self._singletons),
"active_scopes": len(self._scoped_services),
"lifetime_breakdown": lifetime_counts
}
class ServiceScope:
"""
Service scope for managing scoped service lifetimes.
This class provides a context for scoped services, ensuring they
are properly disposed when the scope ends.
"""
def __init__(self, container: ServiceContainer, scope_id: str):
"""Initialize service scope."""
self.container = container
self.scope_id = scope_id
def resolve(self, service_type: Type[T]) -> T:
"""Resolve a service within this scope."""
return self.container.resolve(service_type, self.scope_id)
def __enter__(self) -> 'ServiceScope':
"""Enter the scope context."""
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit the scope context and dispose services."""
self.container.dispose_scope(self.scope_id)
class DetectorWorkerContainer:
"""
Pre-configured dependency injection container for the detector worker.
This class sets up all the standard services and managers used by
the detection worker system.
"""
def __init__(self):
"""Initialize the detector worker container."""
self.container = ServiceContainer()
self._register_core_services()
def _register_core_services(self) -> None:
"""Register all core services and managers."""
# Register state managers as singletons
self.container.register_singleton(
ModelStateManager,
instance=ModelStateManager()
)
self.container.register_singleton(
StreamStateManager,
instance=StreamStateManager()
)
self.container.register_singleton(
SessionStateManager,
instance=SessionStateManager()
)
self.container.register_singleton(
CacheStateManager,
instance=CacheStateManager()
)
self.container.register_singleton(
CameraStateManager,
instance=CameraStateManager()
)
self.container.register_singleton(
PipelineStateManager,
instance=PipelineStateManager()
)
# Register other core services
self._register_detection_services()
self._register_communication_services()
self._register_storage_services()
logger.info("Registered all core services in dependency container")
def _register_detection_services(self) -> None:
"""Register detection-related services."""
# These will be registered when the modules are imported
try:
from ..detection.yolo_detector import YOLODetector
from ..detection.tracking_manager import TrackingManager
from ..detection.stability_validator import StabilityValidator
self.container.register_transient(YOLODetector)
self.container.register_singleton(TrackingManager)
self.container.register_transient(StabilityValidator)
except ImportError as e:
logger.warning(f"Could not register detection services: {e}")
def _register_communication_services(self) -> None:
"""Register communication-related services."""
try:
from ..communication.websocket_handler import WebSocketHandler
from ..communication.message_processor import MessageProcessor
from ..communication.response_formatter import ResponseFormatter
self.container.register_transient(WebSocketHandler)
self.container.register_singleton(MessageProcessor)
self.container.register_singleton(ResponseFormatter)
except ImportError as e:
logger.warning(f"Could not register communication services: {e}")
def _register_storage_services(self) -> None:
"""Register storage-related services."""
try:
from ..storage.database_manager import DatabaseManager
from ..storage.redis_client import RedisClientManager
from ..storage.session_cache import SessionCacheManager
self.container.register_transient(DatabaseManager)
self.container.register_singleton(RedisClientManager)
self.container.register_singleton(SessionCacheManager)
except ImportError as e:
logger.warning(f"Could not register storage services: {e}")
def get_container(self) -> ServiceContainer:
"""Get the underlying service container."""
return self.container
def resolve(self, service_type: Type[T]) -> T:
"""Resolve a service from the container."""
return self.container.resolve(service_type)
def create_scope(self, scope_id: str) -> ServiceScope:
"""Create a new service scope."""
return self.container.create_scope(scope_id)
# Global container instance
_global_container: Optional[DetectorWorkerContainer] = None
_container_lock = threading.Lock()
def get_container() -> DetectorWorkerContainer:
"""Get or create the global dependency injection container."""
global _global_container
if _global_container is None:
with _container_lock:
if _global_container is None:
_global_container = DetectorWorkerContainer()
logger.info("Created global dependency injection container")
return _global_container
def resolve_service(service_type: Type[T]) -> T:
"""Convenience function to resolve a service from the global container."""
container = get_container()
return container.resolve(service_type)
def create_service_scope(scope_id: str) -> ServiceScope:
"""Convenience function to create a service scope."""
container = get_container()
return container.create_scope(scope_id)

View file

@ -122,6 +122,16 @@ class InvalidStateError(DetectorWorkerError):
pass
class StateManagerError(DetectorWorkerError):
"""Raised when state manager operations fail."""
pass
class DependencyInjectionError(DetectorWorkerError):
"""Raised when dependency injection operations fail."""
pass
# ===== ERROR CONTEXT HELPERS =====
def add_error_context(exception: Exception, **context) -> DetectorWorkerError:

View file

@ -0,0 +1,767 @@
"""
Singleton managers to replace global state variables.
This module provides singleton managers that replace the global dictionaries
and variables used throughout the detection worker, enabling proper dependency
injection and cleaner state management.
"""
import threading
import time
import logging
from typing import Dict, Any, Optional, List, Set, Callable
from dataclasses import dataclass, field
from contextlib import contextmanager
from .exceptions import StateManagerError
# Setup logging
logger = logging.getLogger("detector_worker.singleton_managers")
class SingletonMeta(type):
"""Metaclass for thread-safe singleton implementation."""
_instances = {}
_lock: threading.Lock = threading.Lock()
def __call__(cls, *args, **kwargs):
with cls._lock:
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
return cls._instances[cls]
@dataclass
class ModelInfo:
"""Information about a loaded model."""
model_id: str
model_tree: Any
camera_id: str
loaded_at: float = field(default_factory=time.time)
reference_count: int = 0
@dataclass
class StreamInfo:
"""Information about an active stream."""
camera_id: str
subscription_id: str
config: Dict[str, Any]
created_at: float = field(default_factory=time.time)
active: bool = True
@dataclass
class SessionInfo:
"""Session information."""
session_id: str
display_id: str
camera_id: str
created_at: float = field(default_factory=time.time)
last_activity: float = field(default_factory=time.time)
data: Dict[str, Any] = field(default_factory=dict)
class ModelStateManager(metaclass=SingletonMeta):
"""
Singleton manager for model state (replaces global 'models' dict).
Thread-safe management of loaded ML models with reference counting
and automatic cleanup.
"""
def __init__(self):
"""Initialize the model state manager."""
self._models: Dict[str, Dict[str, ModelInfo]] = {} # camera_id -> {model_id -> ModelInfo}
self._lock = threading.RLock()
@contextmanager
def _thread_safe(self):
"""Context manager for thread-safe operations."""
with self._lock:
yield
def load_model(self, camera_id: str, model_id: str, model_tree: Any) -> None:
"""
Load a model for a camera.
Args:
camera_id: Camera identifier
model_id: Model identifier
model_tree: Loaded model tree
"""
with self._thread_safe():
if camera_id not in self._models:
self._models[camera_id] = {}
model_info = ModelInfo(
model_id=model_id,
model_tree=model_tree,
camera_id=camera_id,
reference_count=1
)
if model_id in self._models[camera_id]:
# Increment reference count for existing model
self._models[camera_id][model_id].reference_count += 1
else:
self._models[camera_id][model_id] = model_info
logger.debug(f"Loaded model {model_id} for camera {camera_id}")
def get_model(self, camera_id: str, model_id: str) -> Optional[Any]:
"""Get a model tree for a camera."""
with self._thread_safe():
camera_models = self._models.get(camera_id, {})
model_info = camera_models.get(model_id)
return model_info.model_tree if model_info else None
def get_camera_models(self, camera_id: str) -> Dict[str, Any]:
"""Get all models for a camera."""
with self._thread_safe():
camera_models = self._models.get(camera_id, {})
return {mid: info.model_tree for mid, info in camera_models.items()}
def unload_model(self, camera_id: str, model_id: str) -> bool:
"""
Unload a model (decrement reference count).
Returns:
True if model was completely removed, False if still referenced
"""
with self._thread_safe():
if camera_id not in self._models:
return True
camera_models = self._models[camera_id]
if model_id not in camera_models:
return True
model_info = camera_models[model_id]
model_info.reference_count -= 1
if model_info.reference_count <= 0:
del camera_models[model_id]
logger.debug(f"Unloaded model {model_id} for camera {camera_id}")
# Clean up camera entry if no models left
if not camera_models:
del self._models[camera_id]
return True
return False
def unload_camera_models(self, camera_id: str) -> None:
"""Unload all models for a camera."""
with self._thread_safe():
if camera_id in self._models:
model_count = len(self._models[camera_id])
del self._models[camera_id]
logger.debug(f"Unloaded {model_count} models for camera {camera_id}")
def get_all_models(self) -> Dict[str, Dict[str, Any]]:
"""Get all loaded models (for backward compatibility)."""
with self._thread_safe():
result = {}
for camera_id, models in self._models.items():
result[camera_id] = {mid: info.model_tree for mid, info in models.items()}
return result
def clear_all(self) -> None:
"""Clear all models."""
with self._thread_safe():
model_count = sum(len(models) for models in self._models.values())
self._models.clear()
logger.info(f"Cleared {model_count} models from state manager")
def get_stats(self) -> Dict[str, Any]:
"""Get model statistics."""
with self._thread_safe():
total_models = sum(len(models) for models in self._models.values())
total_cameras = len(self._models)
return {
"total_models": total_models,
"total_cameras": total_cameras,
"cameras": list(self._models.keys())
}
class StreamStateManager(metaclass=SingletonMeta):
"""
Singleton manager for stream state (replaces global stream-related dicts).
Manages streams, camera_streams, subscription_to_camera mappings.
"""
def __init__(self):
"""Initialize stream state manager."""
self._streams: Dict[str, StreamInfo] = {} # camera_id -> StreamInfo
self._camera_streams: Dict[str, Dict[str, Any]] = {} # camera_url -> shared_stream_info
self._subscription_to_camera: Dict[str, str] = {} # subscription_id -> camera_id
self._lock = threading.RLock()
@contextmanager
def _thread_safe(self):
"""Context manager for thread-safe operations."""
with self._lock:
yield
def add_stream(
self,
camera_id: str,
subscription_id: str,
config: Dict[str, Any]
) -> None:
"""Add a new stream."""
with self._thread_safe():
stream_info = StreamInfo(
camera_id=camera_id,
subscription_id=subscription_id,
config=config.copy()
)
self._streams[camera_id] = stream_info
self._subscription_to_camera[subscription_id] = camera_id
logger.debug(f"Added stream for camera {camera_id} with subscription {subscription_id}")
def remove_stream(self, camera_id: str) -> Optional[StreamInfo]:
"""Remove a stream."""
with self._thread_safe():
stream_info = self._streams.pop(camera_id, None)
if stream_info:
# Clean up subscription mapping
self._subscription_to_camera.pop(stream_info.subscription_id, None)
logger.debug(f"Removed stream for camera {camera_id}")
return stream_info
def get_stream(self, camera_id: str) -> Optional[StreamInfo]:
"""Get stream information."""
with self._thread_safe():
return self._streams.get(camera_id)
def get_camera_by_subscription(self, subscription_id: str) -> Optional[str]:
"""Get camera ID by subscription ID."""
with self._thread_safe():
return self._subscription_to_camera.get(subscription_id)
def add_shared_stream(self, camera_url: str, stream_data: Dict[str, Any]) -> None:
"""Add a shared camera stream."""
with self._thread_safe():
self._camera_streams[camera_url] = stream_data
def get_shared_stream(self, camera_url: str) -> Optional[Dict[str, Any]]:
"""Get shared stream data."""
with self._thread_safe():
return self._camera_streams.get(camera_url)
def remove_shared_stream(self, camera_url: str) -> Optional[Dict[str, Any]]:
"""Remove shared stream."""
with self._thread_safe():
return self._camera_streams.pop(camera_url, None)
def get_all_streams(self) -> Dict[str, Dict[str, Any]]:
"""Get all streams (for backward compatibility)."""
with self._thread_safe():
return {cid: info.config for cid, info in self._streams.items()}
def get_all_camera_streams(self) -> Dict[str, Dict[str, Any]]:
"""Get all shared camera streams."""
with self._thread_safe():
return self._camera_streams.copy()
def get_subscription_mappings(self) -> Dict[str, str]:
"""Get subscription to camera mappings."""
with self._thread_safe():
return self._subscription_to_camera.copy()
def clear_all(self) -> None:
"""Clear all stream data."""
with self._thread_safe():
stream_count = len(self._streams)
camera_stream_count = len(self._camera_streams)
subscription_count = len(self._subscription_to_camera)
self._streams.clear()
self._camera_streams.clear()
self._subscription_to_camera.clear()
logger.info(f"Cleared {stream_count} streams, {camera_stream_count} camera streams, {subscription_count} subscriptions")
def get_stats(self) -> Dict[str, Any]:
"""Get stream statistics."""
with self._thread_safe():
return {
"active_streams": len(self._streams),
"shared_camera_streams": len(self._camera_streams),
"subscription_mappings": len(self._subscription_to_camera)
}
class SessionStateManager(metaclass=SingletonMeta):
"""
Singleton manager for session state (replaces session-related global dicts).
Manages session_ids, session_detections, session_to_camera, detection_timestamps.
"""
def __init__(self, session_ttl: float = 600.0): # 10 minutes default TTL
"""Initialize session state manager."""
self._session_ids: Dict[str, str] = {} # display_id -> session_id
self._session_detections: Dict[str, Dict[str, Any]] = {} # session_id -> detection_data
self._session_to_camera: Dict[str, str] = {} # session_id -> camera_id
self._detection_timestamps: Dict[str, float] = {} # session_id -> timestamp
self._session_ttl = session_ttl
self._lock = threading.RLock()
@contextmanager
def _thread_safe(self):
"""Context manager for thread-safe operations."""
with self._lock:
yield
def set_session_id(self, display_id: str, session_id: str) -> None:
"""Set session ID for a display."""
with self._thread_safe():
self._session_ids[display_id] = session_id
logger.debug(f"Set session {session_id} for display {display_id}")
def get_session_id(self, display_id: str) -> Optional[str]:
"""Get session ID for a display."""
with self._thread_safe():
return self._session_ids.get(display_id)
def create_session(
self,
session_id: str,
camera_id: str,
detection_data: Dict[str, Any]
) -> None:
"""Create a new session with detection data."""
with self._thread_safe():
self._session_detections[session_id] = detection_data.copy()
self._session_to_camera[session_id] = camera_id
self._detection_timestamps[session_id] = time.time()
logger.debug(f"Created session {session_id} for camera {camera_id}")
def get_session_detection(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get detection data for a session."""
with self._thread_safe():
return self._session_detections.get(session_id)
def update_session_detection(self, session_id: str, detection_data: Dict[str, Any]) -> None:
"""Update detection data for a session."""
with self._thread_safe():
if session_id in self._session_detections:
self._session_detections[session_id].update(detection_data)
self._detection_timestamps[session_id] = time.time()
def get_camera_by_session(self, session_id: str) -> Optional[str]:
"""Get camera ID by session ID."""
with self._thread_safe():
return self._session_to_camera.get(session_id)
def remove_session(self, session_id: str) -> bool:
"""Remove a session completely."""
with self._thread_safe():
removed = False
if session_id in self._session_detections:
del self._session_detections[session_id]
removed = True
if session_id in self._session_to_camera:
del self._session_to_camera[session_id]
if session_id in self._detection_timestamps:
del self._detection_timestamps[session_id]
if removed:
logger.debug(f"Removed session {session_id}")
return removed
def cleanup_expired_sessions(self) -> int:
"""Clean up expired sessions based on TTL."""
with self._thread_safe():
current_time = time.time()
expired_sessions = []
for session_id, timestamp in self._detection_timestamps.items():
if current_time - timestamp > self._session_ttl:
expired_sessions.append(session_id)
for session_id in expired_sessions:
self.remove_session(session_id)
if expired_sessions:
logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")
return len(expired_sessions)
def get_all_session_ids(self) -> Dict[str, str]:
"""Get all session IDs (for backward compatibility)."""
with self._thread_safe():
return self._session_ids.copy()
def get_all_session_detections(self) -> Dict[str, Dict[str, Any]]:
"""Get all session detections."""
with self._thread_safe():
return self._session_detections.copy()
def clear_all(self) -> None:
"""Clear all session data."""
with self._thread_safe():
session_count = len(self._session_ids)
detection_count = len(self._session_detections)
self._session_ids.clear()
self._session_detections.clear()
self._session_to_camera.clear()
self._detection_timestamps.clear()
logger.info(f"Cleared {session_count} session IDs and {detection_count} session detections")
def get_stats(self) -> Dict[str, Any]:
"""Get session statistics."""
with self._thread_safe():
current_time = time.time()
active_sessions = sum(
1 for timestamp in self._detection_timestamps.values()
if current_time - timestamp <= self._session_ttl
)
return {
"total_display_sessions": len(self._session_ids),
"total_detection_sessions": len(self._session_detections),
"active_sessions": active_sessions,
"session_ttl": self._session_ttl
}
class CacheStateManager(metaclass=SingletonMeta):
"""
Singleton manager for cache state (replaces cache-related global dicts).
Manages cached_detections, cached_full_pipeline_results, latest_frames.
"""
def __init__(self):
"""Initialize cache state manager."""
self._cached_detections: Dict[str, Dict[str, Any]] = {} # camera_id -> detection_data
self._cached_pipeline_results: Dict[str, Dict[str, Any]] = {} # camera_id -> pipeline_result
self._latest_frames: Dict[str, Any] = {} # camera_id -> frame_data
self._frame_skip_flags: Dict[str, bool] = {} # camera_id -> skip_flag
self._lock = threading.RLock()
@contextmanager
def _thread_safe(self):
"""Context manager for thread-safe operations."""
with self._lock:
yield
def cache_detection(self, camera_id: str, detection_data: Dict[str, Any]) -> None:
"""Cache detection result for a camera."""
with self._thread_safe():
self._cached_detections[camera_id] = detection_data.copy()
def get_cached_detection(self, camera_id: str) -> Optional[Dict[str, Any]]:
"""Get cached detection for a camera."""
with self._thread_safe():
return self._cached_detections.get(camera_id)
def clear_cached_detection(self, camera_id: str) -> bool:
"""Clear cached detection for a camera."""
with self._thread_safe():
return self._cached_detections.pop(camera_id, None) is not None
def cache_pipeline_result(self, camera_id: str, result: Dict[str, Any]) -> None:
"""Cache pipeline result for a camera."""
with self._thread_safe():
self._cached_pipeline_results[camera_id] = result.copy()
def get_cached_pipeline_result(self, camera_id: str) -> Optional[Dict[str, Any]]:
"""Get cached pipeline result for a camera."""
with self._thread_safe():
return self._cached_pipeline_results.get(camera_id)
def clear_cached_pipeline_result(self, camera_id: str) -> bool:
"""Clear cached pipeline result for a camera."""
with self._thread_safe():
return self._cached_pipeline_results.pop(camera_id, None) is not None
def set_latest_frame(self, camera_id: str, frame_data: Any) -> None:
"""Set latest frame for a camera."""
with self._thread_safe():
self._latest_frames[camera_id] = frame_data
def get_latest_frame(self, camera_id: str) -> Optional[Any]:
"""Get latest frame for a camera."""
with self._thread_safe():
return self._latest_frames.get(camera_id)
def set_frame_skip_flag(self, camera_id: str, skip: bool) -> None:
"""Set frame skip flag for a camera."""
with self._thread_safe():
self._frame_skip_flags[camera_id] = skip
def get_frame_skip_flag(self, camera_id: str) -> bool:
"""Get frame skip flag for a camera."""
with self._thread_safe():
return self._frame_skip_flags.get(camera_id, False)
def clear_camera_cache(self, camera_id: str) -> None:
"""Clear all cache data for a camera."""
with self._thread_safe():
self._cached_detections.pop(camera_id, None)
self._cached_pipeline_results.pop(camera_id, None)
self._latest_frames.pop(camera_id, None)
self._frame_skip_flags.pop(camera_id, None)
logger.debug(f"Cleared cache for camera {camera_id}")
def get_all_cached_detections(self) -> Dict[str, Dict[str, Any]]:
"""Get all cached detections (for backward compatibility)."""
with self._thread_safe():
return self._cached_detections.copy()
def get_all_latest_frames(self) -> Dict[str, Any]:
"""Get all latest frames (for backward compatibility)."""
with self._thread_safe():
return self._latest_frames.copy()
def clear_all(self) -> None:
"""Clear all cache data."""
with self._thread_safe():
detection_count = len(self._cached_detections)
pipeline_count = len(self._cached_pipeline_results)
frame_count = len(self._latest_frames)
self._cached_detections.clear()
self._cached_pipeline_results.clear()
self._latest_frames.clear()
self._frame_skip_flags.clear()
logger.info(f"Cleared {detection_count} cached detections, {pipeline_count} pipeline results, {frame_count} frames")
def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics."""
with self._thread_safe():
return {
"cached_detections": len(self._cached_detections),
"cached_pipeline_results": len(self._cached_pipeline_results),
"latest_frames": len(self._latest_frames),
"frame_skip_flags": len(self._frame_skip_flags)
}
class CameraStateManager(metaclass=SingletonMeta):
"""
Singleton manager for camera connection state.
Manages camera_states and connection monitoring.
"""
def __init__(self):
"""Initialize camera state manager."""
self._camera_states: Dict[str, Dict[str, Any]] = {} # camera_id -> state_info
self._lock = threading.RLock()
@contextmanager
def _thread_safe(self):
"""Context manager for thread-safe operations."""
with self._lock:
yield
def set_camera_connected(self, camera_id: str, connected: bool = True) -> None:
"""Set camera connection state."""
with self._thread_safe():
if camera_id not in self._camera_states:
self._camera_states[camera_id] = {}
self._camera_states[camera_id].update({
"connected": connected,
"last_update": time.time(),
"disconnection_notified": False if connected else self._camera_states[camera_id].get("disconnection_notified", False),
"reconnection_notified": False if not connected else self._camera_states[camera_id].get("reconnection_notified", False)
})
logger.debug(f"Camera {camera_id} connection state: {connected}")
def is_camera_connected(self, camera_id: str) -> bool:
"""Check if camera is connected."""
with self._thread_safe():
state = self._camera_states.get(camera_id, {})
return state.get("connected", True) # Default to connected
def should_notify_disconnection(self, camera_id: str) -> bool:
"""Check if disconnection should be notified."""
with self._thread_safe():
state = self._camera_states.get(camera_id, {})
return not state.get("connected", True) and not state.get("disconnection_notified", False)
def mark_disconnection_notified(self, camera_id: str) -> None:
"""Mark disconnection as notified."""
with self._thread_safe():
if camera_id in self._camera_states:
self._camera_states[camera_id]["disconnection_notified"] = True
def should_notify_reconnection(self, camera_id: str) -> bool:
"""Check if reconnection should be notified."""
with self._thread_safe():
state = self._camera_states.get(camera_id, {})
return state.get("connected", True) and not state.get("reconnection_notified", True)
def mark_reconnection_notified(self, camera_id: str) -> None:
"""Mark reconnection as notified."""
with self._thread_safe():
if camera_id in self._camera_states:
self._camera_states[camera_id]["reconnection_notified"] = True
def get_camera_state(self, camera_id: str) -> Dict[str, Any]:
"""Get full camera state."""
with self._thread_safe():
return self._camera_states.get(camera_id, {}).copy()
def remove_camera_state(self, camera_id: str) -> bool:
"""Remove camera state."""
with self._thread_safe():
return self._camera_states.pop(camera_id, None) is not None
def get_all_camera_states(self) -> Dict[str, Dict[str, Any]]:
"""Get all camera states (for backward compatibility)."""
with self._thread_safe():
return self._camera_states.copy()
def clear_all(self) -> None:
"""Clear all camera states."""
with self._thread_safe():
state_count = len(self._camera_states)
self._camera_states.clear()
logger.info(f"Cleared {state_count} camera states")
def get_stats(self) -> Dict[str, Any]:
"""Get camera state statistics."""
with self._thread_safe():
connected_count = sum(
1 for state in self._camera_states.values()
if state.get("connected", True)
)
return {
"total_cameras": len(self._camera_states),
"connected_cameras": connected_count,
"disconnected_cameras": len(self._camera_states) - connected_count
}
class PipelineStateManager(metaclass=SingletonMeta):
"""
Singleton manager for pipeline state (replaces session_pipeline_states).
Manages session pipeline states and mode switching.
"""
def __init__(self):
"""Initialize pipeline state manager."""
self._pipeline_states: Dict[str, Dict[str, Any]] = {} # camera_id -> pipeline_state
self._lock = threading.RLock()
@contextmanager
def _thread_safe(self):
"""Context manager for thread-safe operations."""
with self._lock:
yield
def get_or_init_state(self, camera_id: str) -> Dict[str, Any]:
"""Get or initialize pipeline state for a camera."""
with self._thread_safe():
if camera_id not in self._pipeline_states:
self._pipeline_states[camera_id] = {
"mode": "validation_detecting",
"backend_session_id": None,
"yolo_inference_enabled": True,
"progression_stage": None,
"validated_detection": None,
"created_at": time.time()
}
return self._pipeline_states[camera_id].copy()
def update_mode(self, camera_id: str, mode: str, session_id: Optional[str] = None) -> None:
"""Update pipeline mode for a camera."""
with self._thread_safe():
state = self.get_or_init_state(camera_id) # This creates if not exists
self._pipeline_states[camera_id]["mode"] = mode
if session_id is not None:
self._pipeline_states[camera_id]["backend_session_id"] = session_id
logger.debug(f"Updated pipeline mode for camera {camera_id}: {mode}")
def set_yolo_inference_enabled(self, camera_id: str, enabled: bool) -> None:
"""Set YOLO inference enabled state."""
with self._thread_safe():
state = self.get_or_init_state(camera_id) # This creates if not exists
self._pipeline_states[camera_id]["yolo_inference_enabled"] = enabled
def set_progression_stage(self, camera_id: str, stage: str) -> None:
"""Set progression stage for a camera."""
with self._thread_safe():
state = self.get_or_init_state(camera_id) # This creates if not exists
self._pipeline_states[camera_id]["progression_stage"] = stage
def set_validated_detection(self, camera_id: str, detection: Optional[Dict[str, Any]]) -> None:
"""Set validated detection for a camera."""
with self._thread_safe():
state = self.get_or_init_state(camera_id) # This creates if not exists
self._pipeline_states[camera_id]["validated_detection"] = detection
def get_state(self, camera_id: str) -> Dict[str, Any]:
"""Get pipeline state for a camera."""
with self._thread_safe():
return self.get_or_init_state(camera_id)
def remove_state(self, camera_id: str) -> bool:
"""Remove pipeline state for a camera."""
with self._thread_safe():
return self._pipeline_states.pop(camera_id, None) is not None
def get_all_states(self) -> Dict[str, Dict[str, Any]]:
"""Get all pipeline states (for backward compatibility)."""
with self._thread_safe():
return self._pipeline_states.copy()
def clear_all(self) -> None:
"""Clear all pipeline states."""
with self._thread_safe():
state_count = len(self._pipeline_states)
self._pipeline_states.clear()
logger.info(f"Cleared {state_count} pipeline states")
def get_stats(self) -> Dict[str, Any]:
"""Get pipeline state statistics."""
with self._thread_safe():
mode_counts = {}
for state in self._pipeline_states.values():
mode = state.get("mode", "unknown")
mode_counts[mode] = mode_counts.get(mode, 0) + 1
return {
"total_pipeline_states": len(self._pipeline_states),
"mode_breakdown": mode_counts
}
# Global singleton instances (for backward compatibility and easy access)
model_state_manager = ModelStateManager()
stream_state_manager = StreamStateManager()
session_state_manager = SessionStateManager()
cache_state_manager = CacheStateManager()
camera_state_manager = CameraStateManager()
pipeline_state_manager = PipelineStateManager()