Refactor: PHASE 6: Decoupling & Integration
This commit is contained in:
parent
6c7c4c5d9c
commit
accefde8a1
8 changed files with 2344 additions and 86 deletions
detector_worker/communication/websocket_handler.py

@@ -14,17 +14,18 @@ from typing import Dict, Any, Optional, Callable, List, Set
from contextlib import asynccontextmanager

from fastapi import WebSocket
-from websockets.exceptions import ConnectionClosedError, WebSocketDisconnect
+from fastapi.websockets import WebSocketDisconnect
+from websockets.exceptions import ConnectionClosedError

from ..core.config import config, subscription_to_camera, latest_frames
from ..core.constants import HEARTBEAT_INTERVAL
from ..core.exceptions import WebSocketError, StreamError
from ..streams.stream_manager import StreamManager
-from ..streams.camera_monitor import CameraConnectionMonitor
+from ..streams.camera_monitor import CameraMonitor
from ..detection.detection_result import DetectionResult
from ..models.model_manager import ModelManager
from ..pipeline.pipeline_executor import PipelineExecutor
-from ..storage.session_cache import SessionCache
+from ..storage.session_cache import SessionCacheManager
from ..storage.redis_client import RedisClientManager
from ..utils.system_monitor import get_system_metrics


@@ -55,7 +56,7 @@ class WebSocketHandler:
        stream_manager: StreamManager,
        model_manager: ModelManager,
        pipeline_executor: PipelineExecutor,
-        session_cache: SessionCache,
+        session_cache: SessionCacheManager,
        redis_client: Optional[RedisClientManager] = None
    ):
        """

@@ -603,7 +604,7 @@ async def handle_websocket_connection(
    stream_manager: StreamManager,
    model_manager: ModelManager,
    pipeline_executor: PipelineExecutor,
-    session_cache: SessionCache,
+    session_cache: SessionCacheManager,
    redis_client: Optional[RedisClientManager] = None
) -> None:
    """

detector_worker/core/config.py

@@ -1,22 +1,429 @@
"""
-Configuration management for detector worker.
+Centralized configuration management for the detector worker.

-This module handles application configuration loading, validation,
-and provides centralized access to configuration parameters.
+This module provides comprehensive configuration management including:
+- Environment variable support
+- Configuration validation
+- Hot-reload capabilities
+- Type-safe configuration access
"""

import json
import os
import logging
-from typing import Dict, Any, Optional
-from dataclasses import dataclass
+import threading
+from typing import Dict, Any, Optional, Union, List
+from pathlib import Path
+from dataclasses import dataclass, field
+from abc import ABC, abstractmethod

from .exceptions import ConfigurationError

# Setup logging
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("detector_worker.config")


class ConfigurationProvider(ABC):
    """Abstract base class for configuration providers."""

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """Get configuration data."""
        pass

    @abstractmethod
    def reload(self) -> bool:
        """Reload configuration data."""
        pass


class JsonFileProvider(ConfigurationProvider):
    """Configuration provider that reads from JSON files."""

    def __init__(self, file_path: str):
        """Initialize with JSON file path."""
        self.file_path = Path(file_path)
        self._config: Dict[str, Any] = {}
        self._last_modified: Optional[float] = None

    def get_config(self) -> Dict[str, Any]:
        """Get configuration from JSON file."""
        if not self._config or self._should_reload():
            self.reload()
        return self._config.copy()

    def reload(self) -> bool:
        """Reload configuration from file."""
        try:
            if self.file_path.exists():
                with open(self.file_path, 'r') as f:
                    self._config = json.load(f)
                self._last_modified = self.file_path.stat().st_mtime
                logger.debug(f"Loaded configuration from {self.file_path}")
                return True
            else:
                logger.warning(f"Configuration file not found: {self.file_path}")
                return False
        except Exception as e:
            logger.error(f"Error loading configuration from {self.file_path}: {e}")
            return False

    def _should_reload(self) -> bool:
        """Check if file has been modified since last load."""
        if not self.file_path.exists():
            return False

        current_mtime = self.file_path.stat().st_mtime
        return self._last_modified is None or current_mtime > self._last_modified


class EnvironmentProvider(ConfigurationProvider):
    """Configuration provider that reads from environment variables."""

    def __init__(self, prefix: str = "DETECTOR_"):
        """Initialize with environment variable prefix."""
        self.prefix = prefix

    def get_config(self) -> Dict[str, Any]:
        """Get configuration from environment variables."""
        config = {}

        for key, value in os.environ.items():
            if key.startswith(self.prefix):
                # Convert DETECTOR_POLL_INTERVAL_MS -> poll_interval_ms
                config_key = key[len(self.prefix):].lower()

                # Try to parse as JSON first, then as string
                try:
                    config[config_key] = json.loads(value)
                except json.JSONDecodeError:
                    config[config_key] = value

        return config

    def reload(self) -> bool:
        """Environment variables don't need explicit reload."""
        return True

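As a usage sketch (not part of this commit): the provider lower-cases the prefixed name and attempts `json.loads` first, so numeric and boolean values arrive typed while plain strings fall through unchanged.

# Hypothetical usage of EnvironmentProvider (illustration only).
import os

os.environ["DETECTOR_POLL_INTERVAL_MS"] = "50"   # valid JSON -> int
os.environ["DETECTOR_MODELS_DIR"] = "models"     # not valid JSON -> str

provider = EnvironmentProvider("DETECTOR_")
cfg = provider.get_config()
assert cfg["poll_interval_ms"] == 50             # "50" parsed by json.loads
assert cfg["models_dir"] == "models"             # left as a plain string
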
@dataclass
class DatabaseConfig:
    """Database configuration."""
    enabled: bool = False
    host: str = "localhost"
    port: int = 5432
    database: str = "detector_worker"
    username: str = "postgres"
    password: str = ""
    schema: str = "public"
    connection_pool_size: int = 10
    connection_timeout: int = 30

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DatabaseConfig':
        """Create from dictionary."""
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})


@dataclass
class RedisConfig:
    """Redis configuration."""
    enabled: bool = False
    host: str = "localhost"
    port: int = 6379
    password: Optional[str] = None
    db: int = 0
    connection_pool_size: int = 10
    connection_timeout: int = 30

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
        """Create from dictionary."""
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})


@dataclass
class StreamConfig:
    """Stream processing configuration."""
    poll_interval_ms: int = 100
    max_streams: int = 5
    target_fps: int = 10
    reconnect_interval_sec: int = 5
    max_retries: int = 3
    buffer_size: int = 1
    timeout_ms: int = 10000

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'StreamConfig':
        """Create from dictionary."""
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})


@dataclass
class ModelConfig:
    """Model configuration."""
    models_dir: str = "models"
    cache_size_mb: int = 1024
    load_timeout_sec: int = 60
    inference_timeout_sec: int = 10

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelConfig':
        """Create from dictionary."""
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})


@dataclass
class LoggingConfig:
    """Logging configuration."""
    level: str = "INFO"
    format: str = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
    file_path: Optional[str] = "detector_worker.log"
    max_file_size_mb: int = 10
    backup_count: int = 5

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'LoggingConfig':
        """Create from dictionary."""
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})

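A small illustration of the `from_dict` pattern above (not part of this commit): unknown keys are filtered against `__dataclass_fields__`, so extra entries in a config file do not raise `TypeError`.

# Unknown keys are silently dropped by from_dict (illustration only).
raw = {"host": "db.internal", "port": 5433, "unknown_key": True}
db = DatabaseConfig.from_dict(raw)   # "unknown_key" is ignored
assert db.host == "db.internal" and db.port == 5433 and db.enabled is False
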
class ConfigurationManager:
    """Centralized configuration manager with multiple providers and validation."""

    def __init__(self):
        """Initialize configuration manager."""
        self._providers: List[ConfigurationProvider] = []
        self._config: Dict[str, Any] = {}
        self._typed_configs: Dict[str, Any] = {}
        self._lock = threading.RLock()
        self._default_config = self._get_default_config()

        # Set up default providers
        self.add_provider(JsonFileProvider("config.json"))
        self.add_provider(EnvironmentProvider("DETECTOR_"))

        # Load initial configuration
        self.reload()

    def add_provider(self, provider: ConfigurationProvider) -> None:
        """Add a configuration provider."""
        with self._lock:
            self._providers.append(provider)
            logger.debug(f"Added configuration provider: {type(provider).__name__}")

    def reload(self) -> bool:
        """Reload configuration from all providers."""
        with self._lock:
            # Start with default configuration
            config = self._default_config.copy()

            # Merge configurations from providers (later providers override earlier ones)
            for provider in self._providers:
                try:
                    provider_config = provider.get_config()
                    config.update(provider_config)
                except Exception as e:
                    logger.error(f"Error loading from provider {type(provider).__name__}: {e}")

            self._config = config
            self._update_typed_configs()

            logger.debug("Configuration reloaded")
            return True

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration values."""
        return {
            "poll_interval_ms": 100,
            "max_streams": 5,
            "target_fps": 10,
            "reconnect_interval_sec": 5,
            "max_retries": 3,
            "heartbeat_interval": 2,
            "session_timeout": 600,
            "models_dir": "models",
            "log_level": "INFO",
            "database": {
                "enabled": False,
                "host": "localhost",
                "port": 5432,
                "database": "detector_worker",
                "username": "postgres",
                "password": "",
                "schema": "public"
            },
            "redis": {
                "enabled": False,
                "host": "localhost",
                "port": 6379,
                "password": None,
                "db": 0
            }
        }

    def _update_typed_configs(self) -> None:
        """Update typed configuration objects."""
        # Database configuration
        db_config = self._config.get("database", {})
        self._typed_configs["database"] = DatabaseConfig.from_dict(db_config)

        # Redis configuration
        redis_config = self._config.get("redis", {})
        self._typed_configs["redis"] = RedisConfig.from_dict(redis_config)

        # Stream configuration
        stream_config = {k: v for k, v in self._config.items()
                         if k in StreamConfig.__dataclass_fields__}
        self._typed_configs["stream"] = StreamConfig.from_dict(stream_config)

        # Model configuration
        model_config = {k: v for k, v in self._config.items()
                        if k in ModelConfig.__dataclass_fields__}
        self._typed_configs["model"] = ModelConfig.from_dict(model_config)

        # Logging configuration
        logging_config = self._config.get("logging", {})
        self._typed_configs["logging"] = LoggingConfig.from_dict(logging_config)

    def get(self, key: str, default: Any = None) -> Any:
        """Get a configuration value."""
        with self._lock:
            return self._config.get(key, default)

    def get_section(self, section: str) -> Dict[str, Any]:
        """Get a configuration section."""
        with self._lock:
            return self._config.get(section, {})

    def get_database_config(self) -> DatabaseConfig:
        """Get typed database configuration."""
        with self._lock:
            return self._typed_configs["database"]

    def get_redis_config(self) -> RedisConfig:
        """Get typed Redis configuration."""
        with self._lock:
            return self._typed_configs["redis"]

    def get_stream_config(self) -> StreamConfig:
        """Get typed stream configuration."""
        with self._lock:
            return self._typed_configs["stream"]

    def get_model_config(self) -> ModelConfig:
        """Get typed model configuration."""
        with self._lock:
            return self._typed_configs["model"]

    def get_logging_config(self) -> LoggingConfig:
        """Get typed logging configuration."""
        with self._lock:
            return self._typed_configs["logging"]

    def get_all(self) -> Dict[str, Any]:
        """Get all configuration values."""
        with self._lock:
            return self._config.copy()

    def set(self, key: str, value: Any) -> None:
        """Set a configuration value (runtime only)."""
        with self._lock:
            self._config[key] = value
            self._update_typed_configs()

    def validate(self) -> List[str]:
        """Validate configuration and return any errors."""
        errors = []

        with self._lock:
            # Validate stream configuration
            if self._config.get("poll_interval_ms", 0) <= 0:
                errors.append("poll_interval_ms must be positive")

            if self._config.get("max_streams", 0) <= 0:
                errors.append("max_streams must be positive")

            if self._config.get("target_fps", 0) <= 0:
                errors.append("target_fps must be positive")

            # Validate database configuration if enabled
            db_config = self.get_database_config()
            if db_config.enabled:
                if not db_config.host:
                    errors.append("database host is required when database is enabled")
                if not db_config.database:
                    errors.append("database name is required when database is enabled")

            # Validate Redis configuration if enabled
            redis_config = self.get_redis_config()
            if redis_config.enabled:
                if not redis_config.host:
                    errors.append("redis host is required when redis is enabled")

        return errors

    def is_valid(self) -> bool:
        """Check if configuration is valid."""
        return len(self.validate()) == 0

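A hedged sketch of typical usage of the manager above (not part of this commit; `config.json` and the `DETECTOR_` prefix are the defaults wired in `__init__`):

# Illustration only: construct, validate, and read typed config.
manager = ConfigurationManager()          # loads config.json, then DETECTOR_* env vars

errors = manager.validate()
if errors:
    raise ConfigurationError(f"Invalid configuration: {errors}")

stream_cfg = manager.get_stream_config()  # typed StreamConfig dataclass
print(stream_cfg.target_fps)              # 10 unless overridden by a provider
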
# Global configuration manager instance
_config_manager: Optional[ConfigurationManager] = None
_config_lock = threading.Lock()


def get_config_manager() -> ConfigurationManager:
    """Get or create the global configuration manager."""
    global _config_manager

    if _config_manager is None:
        with _config_lock:
            if _config_manager is None:
                _config_manager = ConfigurationManager()
                logger.info("Created global configuration manager")

    return _config_manager


# Backward compatibility - these will be replaced by singleton managers
subscription_to_camera: Dict[str, str] = {}
latest_frames: Dict[str, Any] = {}

# Configuration access
config_manager = get_config_manager()
config = config_manager.get_all()

# Constants derived from configuration
MAX_STREAMS = config_manager.get("max_streams", 5)
TARGET_FPS = config_manager.get("target_fps", 10)
POLL_INTERVAL_MS = config_manager.get("poll_interval_ms", 100)
RECONNECT_INTERVAL_SEC = config_manager.get("reconnect_interval_sec", 5)
MAX_RETRIES = config_manager.get("max_retries", 3)
MODELS_DIR = config_manager.get("models_dir", "models")


# Convenience functions for backward compatibility
def get_config(key: str, default: Any = None) -> Any:
    """Get configuration value."""
    return config_manager.get(key, default)


def reload_config() -> bool:
    """Reload configuration from all sources."""
    global config
    success = config_manager.reload()
    if success:
        config = config_manager.get_all()
    return success


def validate_config() -> List[str]:
    """Validate current configuration."""
    return config_manager.validate()


# Legacy compatibility
@dataclass
class DetectorConfig:
-    """Configuration class for detector worker parameters."""
+    """Legacy configuration class for backward compatibility."""

    # Frame processing settings
    poll_interval_ms: int = 100

@@ -38,78 +445,16 @@ class DetectorConfig:
        return 1000 / self.target_fps if self.target_fps > 0 else self.poll_interval_ms


-class ConfigManager:
-    """Centralized configuration manager."""
-
-    def __init__(self, config_file: str = "config.json"):
-        self.config_file = config_file
-        self._config: Optional[DetectorConfig] = None
-
-    def load_config(self) -> DetectorConfig:
-        """Load configuration from file with defaults fallback."""
-        if self._config is not None:
-            return self._config
-
-        config_data = {}
-
-        # Try to load from config file
-        if os.path.exists(self.config_file):
-            try:
-                with open(self.config_file, "r") as f:
-                    config_data = json.load(f)
-                logger.info(f"Loaded configuration from {self.config_file}")
-            except (json.JSONDecodeError, IOError) as e:
-                logger.warning(f"Failed to load config from {self.config_file}: {e}")
-                logger.info("Using default configuration")
-        else:
-            logger.info(f"Config file {self.config_file} not found, using defaults")
-
-        # Create config with defaults + loaded values
-        self._config = DetectorConfig(
-            poll_interval_ms=config_data.get("poll_interval_ms", 100),
-            target_fps=config_data.get("target_fps", 10),
-            max_streams=config_data.get("max_streams", 5),
-            reconnect_interval_sec=config_data.get("reconnect_interval_sec", 5),
-            max_retries=config_data.get("max_retries", 3),
-            log_level=config_data.get("log_level", "INFO"),
-            log_file=config_data.get("log_file", "detector_worker.log"),
-            websocket_log_file=config_data.get("websocket_log_file", "websocket_comm.log")
-        )
-
-        # Log configuration summary
-        self._log_config_summary()
-
-        return self._config
-
-    def _log_config_summary(self):
-        """Log configuration summary for debugging."""
-        if self._config:
-            logger.info(f"Configuration loaded:")
-            logger.info(f"  Target FPS: {self._config.target_fps}")
-            logger.info(f"  Poll interval: {self._config.poll_interval}ms")
-            logger.info(f"  Max streams: {self._config.max_streams}")
-            logger.info(f"  Max retries: {self._config.max_retries}")
-            logger.info(f"  Log level: {self._config.log_level}")
-
-    def get_config(self) -> DetectorConfig:
-        """Get current configuration, loading if necessary."""
-        if self._config is None:
-            return self.load_config()
-        return self._config
-
-    def reload_config(self) -> DetectorConfig:
-        """Force reload configuration from file."""
-        self._config = None
-        return self.load_config()
-
-
-# Global config manager instance
-_config_manager = ConfigManager()
-
-def get_config() -> DetectorConfig:
-    """Get the global configuration instance."""
-    return _config_manager.get_config()
-
-def reload_config() -> DetectorConfig:
-    """Reload configuration from file."""
-    return _config_manager.reload_config()
+    """Get legacy detector config for backward compatibility."""
+    config_data = config_manager.get_all()
+    return DetectorConfig(
+        poll_interval_ms=config_data.get("poll_interval_ms", 100),
+        target_fps=config_data.get("target_fps", 10),
+        max_streams=config_data.get("max_streams", 5),
+        reconnect_interval_sec=config_data.get("reconnect_interval_sec", 5),
+        max_retries=config_data.get("max_retries", 3),
+        log_level=config_data.get("log_level", "INFO"),
+        log_file=config_data.get("log_file", "detector_worker.log"),
+        websocket_log_file=config_data.get("websocket_log_file", "websocket_comm.log")
+    )

detector_worker/core/dependency_injection.py (new file, 514 lines)

@@ -0,0 +1,514 @@
"""
Dependency Injection Container (IoC Container).

This module provides a comprehensive dependency injection system that manages
all components and their dependencies, enabling clean decoupling and testability.
"""
import logging
import threading
from typing import Dict, Any, Type, Optional, TypeVar, Generic, Callable, Union
from dataclasses import dataclass
from enum import Enum
from abc import ABC, abstractmethod

from .exceptions import DependencyInjectionError
from .singleton_managers import (
    ModelStateManager, StreamStateManager, SessionStateManager,
    CacheStateManager, CameraStateManager, PipelineStateManager
)

# Setup logging
logger = logging.getLogger("detector_worker.dependency_injection")

T = TypeVar('T')

class ServiceLifetime(Enum):
    """Service lifetime management options."""
    SINGLETON = "singleton"
    TRANSIENT = "transient"
    SCOPED = "scoped"


@dataclass
class ServiceDescriptor:
    """Describes how to create and manage a service."""
    service_type: Type
    implementation_type: Optional[Type] = None
    factory: Optional[Callable] = None
    instance: Optional[Any] = None
    lifetime: ServiceLifetime = ServiceLifetime.TRANSIENT
    dependencies: Optional[list] = None  # None rather than a mutable default

class ServiceContainer:
    """
    Dependency injection container for managing services and their dependencies.

    This container provides comprehensive dependency injection capabilities including:
    - Service registration with different lifetimes
    - Automatic dependency resolution
    - Circular dependency detection
    - Thread-safe service creation
    """

    def __init__(self):
        """Initialize the service container."""
        self._services: Dict[Type, ServiceDescriptor] = {}
        self._singletons: Dict[Type, Any] = {}
        self._scoped_services: Dict[str, Dict[Type, Any]] = {}  # scope_id -> services
        self._lock = threading.RLock()
        self._resolution_stack: list = []

    def register_singleton(
        self,
        service_type: Type[T],
        implementation_type: Optional[Type] = None,
        factory: Optional[Callable[[], T]] = None,
        instance: Optional[T] = None
    ) -> 'ServiceContainer':
        """
        Register a singleton service.

        Args:
            service_type: The service interface/type
            implementation_type: The concrete implementation
            factory: Factory function to create the service
            instance: Pre-created instance

        Returns:
            Self for method chaining
        """
        with self._lock:
            if instance is not None:
                self._singletons[service_type] = instance

            descriptor = ServiceDescriptor(
                service_type=service_type,
                implementation_type=implementation_type,
                factory=factory,
                instance=instance,
                lifetime=ServiceLifetime.SINGLETON
            )

            self._services[service_type] = descriptor
            logger.debug(f"Registered singleton service: {service_type.__name__}")

        return self

    def register_transient(
        self,
        service_type: Type[T],
        implementation_type: Optional[Type] = None,
        factory: Optional[Callable[[], T]] = None
    ) -> 'ServiceContainer':
        """
        Register a transient service (new instance each time).

        Args:
            service_type: The service interface/type
            implementation_type: The concrete implementation
            factory: Factory function to create the service

        Returns:
            Self for method chaining
        """
        with self._lock:
            descriptor = ServiceDescriptor(
                service_type=service_type,
                implementation_type=implementation_type,
                factory=factory,
                lifetime=ServiceLifetime.TRANSIENT
            )

            self._services[service_type] = descriptor
            logger.debug(f"Registered transient service: {service_type.__name__}")

        return self

    def register_scoped(
        self,
        service_type: Type[T],
        implementation_type: Optional[Type] = None,
        factory: Optional[Callable[[], T]] = None
    ) -> 'ServiceContainer':
        """
        Register a scoped service (same instance within a scope).

        Args:
            service_type: The service interface/type
            implementation_type: The concrete implementation
            factory: Factory function to create the service

        Returns:
            Self for method chaining
        """
        with self._lock:
            descriptor = ServiceDescriptor(
                service_type=service_type,
                implementation_type=implementation_type,
                factory=factory,
                lifetime=ServiceLifetime.SCOPED
            )

            self._services[service_type] = descriptor
            logger.debug(f"Registered scoped service: {service_type.__name__}")

        return self

    def resolve(self, service_type: Type[T], scope_id: Optional[str] = None) -> T:
        """
        Resolve a service instance.

        Args:
            service_type: The service type to resolve
            scope_id: Optional scope identifier for scoped services

        Returns:
            Service instance

        Raises:
            DependencyInjectionError: If service cannot be resolved
        """
        with self._lock:
            # Check for circular dependencies
            if service_type in self._resolution_stack:
                cycle = " -> ".join(cls.__name__ for cls in self._resolution_stack)
                cycle += f" -> {service_type.__name__}"
                raise DependencyInjectionError(f"Circular dependency detected: {cycle}")

            self._resolution_stack.append(service_type)

            try:
                return self._resolve_service(service_type, scope_id)
            finally:
                self._resolution_stack.pop()

    def _resolve_service(self, service_type: Type[T], scope_id: Optional[str]) -> T:
        """Internal service resolution."""
        # Check if service is registered
        if service_type not in self._services:
            raise DependencyInjectionError(f"Service {service_type.__name__} is not registered")

        descriptor = self._services[service_type]

        # Handle singleton lifetime
        if descriptor.lifetime == ServiceLifetime.SINGLETON:
            if service_type in self._singletons:
                return self._singletons[service_type]

            instance = self._create_instance(descriptor)
            self._singletons[service_type] = instance
            return instance

        # Handle scoped lifetime
        elif descriptor.lifetime == ServiceLifetime.SCOPED:
            if scope_id is None:
                raise DependencyInjectionError(f"Scope ID required for scoped service {service_type.__name__}")

            if scope_id not in self._scoped_services:
                self._scoped_services[scope_id] = {}

            scoped_services = self._scoped_services[scope_id]

            if service_type in scoped_services:
                return scoped_services[service_type]

            instance = self._create_instance(descriptor)
            scoped_services[service_type] = instance
            return instance

        # Handle transient lifetime
        else:
            return self._create_instance(descriptor)

    def _create_instance(self, descriptor: ServiceDescriptor) -> Any:
        """Create a service instance using the descriptor."""
        # Use existing instance if available
        if descriptor.instance is not None:
            return descriptor.instance

        # Use factory if available
        if descriptor.factory is not None:
            try:
                return descriptor.factory()
            except Exception as e:
                raise DependencyInjectionError(f"Failed to create service using factory: {e}")

        # Use implementation type
        if descriptor.implementation_type is not None:
            try:
                return self._create_with_dependencies(descriptor.implementation_type)
            except Exception as e:
                raise DependencyInjectionError(f"Failed to create service {descriptor.implementation_type.__name__}: {e}")

        # Use service type directly
        try:
            return self._create_with_dependencies(descriptor.service_type)
        except Exception as e:
            raise DependencyInjectionError(f"Failed to create service {descriptor.service_type.__name__}: {e}")

    def _create_with_dependencies(self, service_type: Type) -> Any:
        """Create service instance with automatic dependency injection."""
        # Get constructor parameters
        import inspect

        try:
            signature = inspect.signature(service_type.__init__)
            parameters = list(signature.parameters.values())[1:]  # Skip 'self'

            # Resolve dependencies
            dependencies = []
            for param in parameters:
                if param.annotation != inspect.Parameter.empty:
                    # Try to resolve the parameter type
                    try:
                        dependency = self.resolve(param.annotation)
                        dependencies.append(dependency)
                    except DependencyInjectionError:
                        # If dependency cannot be resolved and has default, use default
                        if param.default != inspect.Parameter.empty:
                            dependencies.append(param.default)
                        else:
                            raise DependencyInjectionError(
                                f"Cannot resolve dependency {param.annotation.__name__} for {service_type.__name__}"
                            )
                else:
                    # Parameter without type annotation, use default if available
                    if param.default != inspect.Parameter.empty:
                        dependencies.append(param.default)
                    else:
                        raise DependencyInjectionError(
                            f"Cannot resolve untyped parameter {param.name} for {service_type.__name__}"
                        )

            return service_type(*dependencies)

        except Exception as e:
            if isinstance(e, DependencyInjectionError):
                raise
            else:
                # Try to create without dependencies
                return service_type()

    def create_scope(self, scope_id: str) -> 'ServiceScope':
        """Create a new service scope."""
        return ServiceScope(self, scope_id)

    def dispose_scope(self, scope_id: str) -> None:
        """Dispose a service scope and its services."""
        with self._lock:
            if scope_id in self._scoped_services:
                scoped_services = self._scoped_services.pop(scope_id)

                # Dispose services that implement IDisposable
                for service in scoped_services.values():
                    if hasattr(service, 'dispose'):
                        try:
                            service.dispose()
                        except Exception as e:
                            logger.error(f"Error disposing scoped service: {e}")

                logger.debug(f"Disposed scope {scope_id} with {len(scoped_services)} services")

    def is_registered(self, service_type: Type) -> bool:
        """Check if a service type is registered."""
        with self._lock:
            return service_type in self._services

    def get_registration_info(self, service_type: Type) -> Optional[ServiceDescriptor]:
        """Get registration information for a service."""
        with self._lock:
            return self._services.get(service_type)

    def get_registered_services(self) -> Dict[Type, ServiceDescriptor]:
        """Get all registered services."""
        with self._lock:
            return self._services.copy()

    def clear_singletons(self) -> None:
        """Clear all singleton instances."""
        with self._lock:
            singleton_count = len(self._singletons)
            self._singletons.clear()
            logger.info(f"Cleared {singleton_count} singleton services")

    def get_stats(self) -> Dict[str, Any]:
        """Get container statistics."""
        with self._lock:
            lifetime_counts = {}
            for descriptor in self._services.values():
                lifetime = descriptor.lifetime.value
                lifetime_counts[lifetime] = lifetime_counts.get(lifetime, 0) + 1

            return {
                "registered_services": len(self._services),
                "active_singletons": len(self._singletons),
                "active_scopes": len(self._scoped_services),
                "lifetime_breakdown": lifetime_counts
            }

class ServiceScope:
    """
    Service scope for managing scoped service lifetimes.

    This class provides a context for scoped services, ensuring they
    are properly disposed when the scope ends.
    """

    def __init__(self, container: ServiceContainer, scope_id: str):
        """Initialize service scope."""
        self.container = container
        self.scope_id = scope_id

    def resolve(self, service_type: Type[T]) -> T:
        """Resolve a service within this scope."""
        return self.container.resolve(service_type, self.scope_id)

    def __enter__(self) -> 'ServiceScope':
        """Enter the scope context."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit the scope context and dispose services."""
        self.container.dispose_scope(self.scope_id)

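A usage sketch for the scope context manager above (not part of this commit; `RequestContext` is a hypothetical scoped service):

# Illustration only: scoped resolution with automatic disposal.
class RequestContext:
    """Hypothetical per-connection service used for illustration."""
    def __init__(self):
        self.closed = False

    def dispose(self) -> None:  # called by dispose_scope at scope exit
        self.closed = True

container = ServiceContainer()
container.register_scoped(RequestContext)

with container.create_scope("conn-42") as scope:
    ctx1 = scope.resolve(RequestContext)
    ctx2 = scope.resolve(RequestContext)
    assert ctx1 is ctx2          # same instance inside the scope
assert ctx1.closed               # __exit__ disposed the scope's services
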
class DetectorWorkerContainer:
    """
    Pre-configured dependency injection container for the detector worker.

    This class sets up all the standard services and managers used by
    the detection worker system.
    """

    def __init__(self):
        """Initialize the detector worker container."""
        self.container = ServiceContainer()
        self._register_core_services()

    def _register_core_services(self) -> None:
        """Register all core services and managers."""
        # Register state managers as singletons
        self.container.register_singleton(
            ModelStateManager,
            instance=ModelStateManager()
        )

        self.container.register_singleton(
            StreamStateManager,
            instance=StreamStateManager()
        )

        self.container.register_singleton(
            SessionStateManager,
            instance=SessionStateManager()
        )

        self.container.register_singleton(
            CacheStateManager,
            instance=CacheStateManager()
        )

        self.container.register_singleton(
            CameraStateManager,
            instance=CameraStateManager()
        )

        self.container.register_singleton(
            PipelineStateManager,
            instance=PipelineStateManager()
        )

        # Register other core services
        self._register_detection_services()
        self._register_communication_services()
        self._register_storage_services()

        logger.info("Registered all core services in dependency container")

    def _register_detection_services(self) -> None:
        """Register detection-related services."""
        # These will be registered when the modules are imported
        try:
            from ..detection.yolo_detector import YOLODetector
            from ..detection.tracking_manager import TrackingManager
            from ..detection.stability_validator import StabilityValidator

            self.container.register_transient(YOLODetector)
            self.container.register_singleton(TrackingManager)
            self.container.register_transient(StabilityValidator)

        except ImportError as e:
            logger.warning(f"Could not register detection services: {e}")

    def _register_communication_services(self) -> None:
        """Register communication-related services."""
        try:
            from ..communication.websocket_handler import WebSocketHandler
            from ..communication.message_processor import MessageProcessor
            from ..communication.response_formatter import ResponseFormatter

            self.container.register_transient(WebSocketHandler)
            self.container.register_singleton(MessageProcessor)
            self.container.register_singleton(ResponseFormatter)

        except ImportError as e:
            logger.warning(f"Could not register communication services: {e}")

    def _register_storage_services(self) -> None:
        """Register storage-related services."""
        try:
            from ..storage.database_manager import DatabaseManager
            from ..storage.redis_client import RedisClientManager
            from ..storage.session_cache import SessionCacheManager

            self.container.register_transient(DatabaseManager)
            self.container.register_singleton(RedisClientManager)
            self.container.register_singleton(SessionCacheManager)

        except ImportError as e:
            logger.warning(f"Could not register storage services: {e}")

    def get_container(self) -> ServiceContainer:
        """Get the underlying service container."""
        return self.container

    def resolve(self, service_type: Type[T]) -> T:
        """Resolve a service from the container."""
        return self.container.resolve(service_type)

    def create_scope(self, scope_id: str) -> ServiceScope:
        """Create a new service scope."""
        return self.container.create_scope(scope_id)


# Global container instance
_global_container: Optional[DetectorWorkerContainer] = None
_container_lock = threading.Lock()


def get_container() -> DetectorWorkerContainer:
    """Get or create the global dependency injection container."""
    global _global_container

    if _global_container is None:
        with _container_lock:
            if _global_container is None:
                _global_container = DetectorWorkerContainer()
                logger.info("Created global dependency injection container")

    return _global_container


def resolve_service(service_type: Type[T]) -> T:
    """Convenience function to resolve a service from the global container."""
    container = get_container()
    return container.resolve(service_type)


def create_service_scope(scope_id: str) -> ServiceScope:
    """Convenience function to create a service scope."""
    container = get_container()
    return container.create_scope(scope_id)

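A sketch of annotation-driven constructor injection with the container defined above (not part of this commit; `Repository` and `Service` are hypothetical):

# Illustration only: dependencies are resolved from __init__ type annotations.
class Repository:
    def __init__(self):           # explicit no-arg __init__, no dependencies
        self.items: list = []

class Service:
    def __init__(self, repo: Repository):  # annotation drives resolution
        self.repo = repo

container = ServiceContainer()
container.register_singleton(Repository)
container.register_transient(Service)

svc = container.resolve(Service)  # Repository is created and injected automatically
assert svc.repo is container.resolve(Repository)  # singleton instance is shared
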
detector_worker/core/exceptions.py

@@ -122,6 +122,16 @@ class InvalidStateError(DetectorWorkerError):
    pass


+class StateManagerError(DetectorWorkerError):
+    """Raised when state manager operations fail."""
+    pass
+
+
+class DependencyInjectionError(DetectorWorkerError):
+    """Raised when dependency injection operations fail."""
+    pass
+
+
# ===== ERROR CONTEXT HELPERS =====

def add_error_context(exception: Exception, **context) -> DetectorWorkerError:

detector_worker/core/singleton_managers.py (new file, 767 lines)

@@ -0,0 +1,767 @@
"""
Singleton managers to replace global state variables.

This module provides singleton managers that replace the global dictionaries
and variables used throughout the detection worker, enabling proper dependency
injection and cleaner state management.
"""
import threading
import time
import logging
from typing import Dict, Any, Optional, List, Set, Callable
from dataclasses import dataclass, field
from contextlib import contextmanager

from .exceptions import StateManagerError

# Setup logging
logger = logging.getLogger("detector_worker.singleton_managers")

class SingletonMeta(type):
    """Metaclass for thread-safe singleton implementation."""
    _instances = {}
    _lock: threading.Lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        with cls._lock:
            if cls not in cls._instances:
                instance = super().__call__(*args, **kwargs)
                cls._instances[cls] = instance
            return cls._instances[cls]

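A minimal illustration of the metaclass above (not part of this commit; `Counter` is hypothetical):

# Illustration only: every construction returns the same cached instance.
class Counter(metaclass=SingletonMeta):
    def __init__(self):
        self.value = 0

a = Counter()
b = Counter()          # __init__ does not run again; the cached instance is returned
a.value += 1
assert a is b and b.value == 1
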
@dataclass
class ModelInfo:
    """Information about a loaded model."""
    model_id: str
    model_tree: Any
    camera_id: str
    loaded_at: float = field(default_factory=time.time)
    reference_count: int = 0


@dataclass
class StreamInfo:
    """Information about an active stream."""
    camera_id: str
    subscription_id: str
    config: Dict[str, Any]
    created_at: float = field(default_factory=time.time)
    active: bool = True


@dataclass
class SessionInfo:
    """Session information."""
    session_id: str
    display_id: str
    camera_id: str
    created_at: float = field(default_factory=time.time)
    last_activity: float = field(default_factory=time.time)
    data: Dict[str, Any] = field(default_factory=dict)

class ModelStateManager(metaclass=SingletonMeta):
    """
    Singleton manager for model state (replaces global 'models' dict).

    Thread-safe management of loaded ML models with reference counting
    and automatic cleanup.
    """

    def __init__(self):
        """Initialize the model state manager."""
        self._models: Dict[str, Dict[str, ModelInfo]] = {}  # camera_id -> {model_id -> ModelInfo}
        self._lock = threading.RLock()

    @contextmanager
    def _thread_safe(self):
        """Context manager for thread-safe operations."""
        with self._lock:
            yield

    def load_model(self, camera_id: str, model_id: str, model_tree: Any) -> None:
        """
        Load a model for a camera.

        Args:
            camera_id: Camera identifier
            model_id: Model identifier
            model_tree: Loaded model tree
        """
        with self._thread_safe():
            if camera_id not in self._models:
                self._models[camera_id] = {}

            model_info = ModelInfo(
                model_id=model_id,
                model_tree=model_tree,
                camera_id=camera_id,
                reference_count=1
            )

            if model_id in self._models[camera_id]:
                # Increment reference count for existing model
                self._models[camera_id][model_id].reference_count += 1
            else:
                self._models[camera_id][model_id] = model_info

            logger.debug(f"Loaded model {model_id} for camera {camera_id}")

    def get_model(self, camera_id: str, model_id: str) -> Optional[Any]:
        """Get a model tree for a camera."""
        with self._thread_safe():
            camera_models = self._models.get(camera_id, {})
            model_info = camera_models.get(model_id)
            return model_info.model_tree if model_info else None

    def get_camera_models(self, camera_id: str) -> Dict[str, Any]:
        """Get all models for a camera."""
        with self._thread_safe():
            camera_models = self._models.get(camera_id, {})
            return {mid: info.model_tree for mid, info in camera_models.items()}

    def unload_model(self, camera_id: str, model_id: str) -> bool:
        """
        Unload a model (decrement reference count).

        Returns:
            True if model was completely removed, False if still referenced
        """
        with self._thread_safe():
            if camera_id not in self._models:
                return True

            camera_models = self._models[camera_id]
            if model_id not in camera_models:
                return True

            model_info = camera_models[model_id]
            model_info.reference_count -= 1

            if model_info.reference_count <= 0:
                del camera_models[model_id]
                logger.debug(f"Unloaded model {model_id} for camera {camera_id}")

                # Clean up camera entry if no models left
                if not camera_models:
                    del self._models[camera_id]

                return True

            return False

    def unload_camera_models(self, camera_id: str) -> None:
        """Unload all models for a camera."""
        with self._thread_safe():
            if camera_id in self._models:
                model_count = len(self._models[camera_id])
                del self._models[camera_id]
                logger.debug(f"Unloaded {model_count} models for camera {camera_id}")

    def get_all_models(self) -> Dict[str, Dict[str, Any]]:
        """Get all loaded models (for backward compatibility)."""
        with self._thread_safe():
            result = {}
            for camera_id, models in self._models.items():
                result[camera_id] = {mid: info.model_tree for mid, info in models.items()}
            return result

    def clear_all(self) -> None:
        """Clear all models."""
        with self._thread_safe():
            model_count = sum(len(models) for models in self._models.values())
            self._models.clear()
            logger.info(f"Cleared {model_count} models from state manager")

    def get_stats(self) -> Dict[str, Any]:
        """Get model statistics."""
        with self._thread_safe():
            total_models = sum(len(models) for models in self._models.values())
            total_cameras = len(self._models)

            return {
                "total_models": total_models,
                "total_cameras": total_cameras,
                "cameras": list(self._models.keys())
            }

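A sketch of the reference-counting behavior above (not part of this commit; the `model_tree` dict stands in for a real model):

# Illustration only: a second load bumps the count; unload removes at zero.
mgr = ModelStateManager()
mgr.load_model("cam1", "yolo", model_tree={"stub": True})  # reference_count = 1
mgr.load_model("cam1", "yolo", model_tree={"stub": True})  # reference_count = 2

assert mgr.unload_model("cam1", "yolo") is False  # still referenced (count -> 1)
assert mgr.unload_model("cam1", "yolo") is True   # removed; camera entry cleaned up
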
class StreamStateManager(metaclass=SingletonMeta):
    """
    Singleton manager for stream state (replaces global stream-related dicts).

    Manages streams, camera_streams, subscription_to_camera mappings.
    """

    def __init__(self):
        """Initialize stream state manager."""
        self._streams: Dict[str, StreamInfo] = {}  # camera_id -> StreamInfo
        self._camera_streams: Dict[str, Dict[str, Any]] = {}  # camera_url -> shared_stream_info
        self._subscription_to_camera: Dict[str, str] = {}  # subscription_id -> camera_id
        self._lock = threading.RLock()

    @contextmanager
    def _thread_safe(self):
        """Context manager for thread-safe operations."""
        with self._lock:
            yield

    def add_stream(
        self,
        camera_id: str,
        subscription_id: str,
        config: Dict[str, Any]
    ) -> None:
        """Add a new stream."""
        with self._thread_safe():
            stream_info = StreamInfo(
                camera_id=camera_id,
                subscription_id=subscription_id,
                config=config.copy()
            )

            self._streams[camera_id] = stream_info
            self._subscription_to_camera[subscription_id] = camera_id

            logger.debug(f"Added stream for camera {camera_id} with subscription {subscription_id}")

    def remove_stream(self, camera_id: str) -> Optional[StreamInfo]:
        """Remove a stream."""
        with self._thread_safe():
            stream_info = self._streams.pop(camera_id, None)

            if stream_info:
                # Clean up subscription mapping
                self._subscription_to_camera.pop(stream_info.subscription_id, None)
                logger.debug(f"Removed stream for camera {camera_id}")

            return stream_info

    def get_stream(self, camera_id: str) -> Optional[StreamInfo]:
        """Get stream information."""
        with self._thread_safe():
            return self._streams.get(camera_id)

    def get_camera_by_subscription(self, subscription_id: str) -> Optional[str]:
        """Get camera ID by subscription ID."""
        with self._thread_safe():
            return self._subscription_to_camera.get(subscription_id)

    def add_shared_stream(self, camera_url: str, stream_data: Dict[str, Any]) -> None:
        """Add a shared camera stream."""
        with self._thread_safe():
            self._camera_streams[camera_url] = stream_data

    def get_shared_stream(self, camera_url: str) -> Optional[Dict[str, Any]]:
        """Get shared stream data."""
        with self._thread_safe():
            return self._camera_streams.get(camera_url)

    def remove_shared_stream(self, camera_url: str) -> Optional[Dict[str, Any]]:
        """Remove shared stream."""
        with self._thread_safe():
            return self._camera_streams.pop(camera_url, None)

    def get_all_streams(self) -> Dict[str, Dict[str, Any]]:
        """Get all streams (for backward compatibility)."""
        with self._thread_safe():
            return {cid: info.config for cid, info in self._streams.items()}

    def get_all_camera_streams(self) -> Dict[str, Dict[str, Any]]:
        """Get all shared camera streams."""
        with self._thread_safe():
            return self._camera_streams.copy()

    def get_subscription_mappings(self) -> Dict[str, str]:
        """Get subscription to camera mappings."""
        with self._thread_safe():
            return self._subscription_to_camera.copy()

    def clear_all(self) -> None:
        """Clear all stream data."""
        with self._thread_safe():
            stream_count = len(self._streams)
            camera_stream_count = len(self._camera_streams)
            subscription_count = len(self._subscription_to_camera)

            self._streams.clear()
            self._camera_streams.clear()
            self._subscription_to_camera.clear()

            logger.info(f"Cleared {stream_count} streams, {camera_stream_count} camera streams, {subscription_count} subscriptions")

    def get_stats(self) -> Dict[str, Any]:
        """Get stream statistics."""
        with self._thread_safe():
            return {
                "active_streams": len(self._streams),
                "shared_camera_streams": len(self._camera_streams),
                "subscription_mappings": len(self._subscription_to_camera)
            }

class SessionStateManager(metaclass=SingletonMeta):
    """
    Singleton manager for session state (replaces session-related global dicts).

    Manages session_ids, session_detections, session_to_camera, detection_timestamps.
    """

    def __init__(self, session_ttl: float = 600.0):  # 10 minutes default TTL
        """Initialize session state manager."""
        self._session_ids: Dict[str, str] = {}  # display_id -> session_id
        self._session_detections: Dict[str, Dict[str, Any]] = {}  # session_id -> detection_data
        self._session_to_camera: Dict[str, str] = {}  # session_id -> camera_id
        self._detection_timestamps: Dict[str, float] = {}  # session_id -> timestamp
        self._session_ttl = session_ttl
        self._lock = threading.RLock()

    @contextmanager
    def _thread_safe(self):
        """Context manager for thread-safe operations."""
        with self._lock:
            yield

    def set_session_id(self, display_id: str, session_id: str) -> None:
        """Set session ID for a display."""
        with self._thread_safe():
            self._session_ids[display_id] = session_id
            logger.debug(f"Set session {session_id} for display {display_id}")

    def get_session_id(self, display_id: str) -> Optional[str]:
        """Get session ID for a display."""
        with self._thread_safe():
            return self._session_ids.get(display_id)

    def create_session(
        self,
        session_id: str,
        camera_id: str,
        detection_data: Dict[str, Any]
    ) -> None:
        """Create a new session with detection data."""
        with self._thread_safe():
            self._session_detections[session_id] = detection_data.copy()
            self._session_to_camera[session_id] = camera_id
            self._detection_timestamps[session_id] = time.time()

            logger.debug(f"Created session {session_id} for camera {camera_id}")

    def get_session_detection(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get detection data for a session."""
        with self._thread_safe():
            return self._session_detections.get(session_id)

    def update_session_detection(self, session_id: str, detection_data: Dict[str, Any]) -> None:
        """Update detection data for a session."""
        with self._thread_safe():
            if session_id in self._session_detections:
                self._session_detections[session_id].update(detection_data)
                self._detection_timestamps[session_id] = time.time()

    def get_camera_by_session(self, session_id: str) -> Optional[str]:
        """Get camera ID by session ID."""
        with self._thread_safe():
            return self._session_to_camera.get(session_id)

    def remove_session(self, session_id: str) -> bool:
        """Remove a session completely."""
        with self._thread_safe():
            removed = False

            if session_id in self._session_detections:
                del self._session_detections[session_id]
                removed = True

            if session_id in self._session_to_camera:
                del self._session_to_camera[session_id]

            if session_id in self._detection_timestamps:
                del self._detection_timestamps[session_id]

            if removed:
                logger.debug(f"Removed session {session_id}")

            return removed

    def cleanup_expired_sessions(self) -> int:
        """Clean up expired sessions based on TTL."""
        with self._thread_safe():
            current_time = time.time()
            expired_sessions = []

            for session_id, timestamp in self._detection_timestamps.items():
                if current_time - timestamp > self._session_ttl:
                    expired_sessions.append(session_id)

            for session_id in expired_sessions:
                self.remove_session(session_id)

            if expired_sessions:
                logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")

            return len(expired_sessions)

    def get_all_session_ids(self) -> Dict[str, str]:
        """Get all session IDs (for backward compatibility)."""
        with self._thread_safe():
            return self._session_ids.copy()

    def get_all_session_detections(self) -> Dict[str, Dict[str, Any]]:
        """Get all session detections."""
        with self._thread_safe():
            return self._session_detections.copy()

    def clear_all(self) -> None:
        """Clear all session data."""
        with self._thread_safe():
            session_count = len(self._session_ids)
            detection_count = len(self._session_detections)

            self._session_ids.clear()
            self._session_detections.clear()
            self._session_to_camera.clear()
            self._detection_timestamps.clear()

            logger.info(f"Cleared {session_count} session IDs and {detection_count} session detections")

    def get_stats(self) -> Dict[str, Any]:
        """Get session statistics."""
        with self._thread_safe():
            current_time = time.time()
            active_sessions = sum(
                1 for timestamp in self._detection_timestamps.values()
                if current_time - timestamp <= self._session_ttl
            )

            return {
                "total_display_sessions": len(self._session_ids),
                "total_detection_sessions": len(self._session_detections),
                "active_sessions": active_sessions,
                "session_ttl": self._session_ttl
            }

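A sketch of the TTL-based expiry above (not part of this commit; the detection payload is hypothetical, and because `SingletonMeta` returns the first-constructed instance, the short TTL only takes effect if this runs first in the process):

# Illustration only: sessions older than session_ttl are swept.
import time

sessions = SessionStateManager(session_ttl=0.1)   # 100 ms TTL for demonstration
sessions.create_session("sess-1", "cam1", {"car": {"brand": "unknown"}})

time.sleep(0.2)                                   # let the session age past its TTL
expired = sessions.cleanup_expired_sessions()
assert expired == 1 and sessions.get_session_detection("sess-1") is None
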
class CacheStateManager(metaclass=SingletonMeta):
    """
    Singleton manager for cache state (replaces cache-related global dicts).

    Manages cached_detections, cached_full_pipeline_results, latest_frames.
    """

    def __init__(self):
        """Initialize cache state manager."""
        self._cached_detections: Dict[str, Dict[str, Any]] = {}  # camera_id -> detection_data
        self._cached_pipeline_results: Dict[str, Dict[str, Any]] = {}  # camera_id -> pipeline_result
        self._latest_frames: Dict[str, Any] = {}  # camera_id -> frame_data
        self._frame_skip_flags: Dict[str, bool] = {}  # camera_id -> skip_flag
        self._lock = threading.RLock()

    @contextmanager
    def _thread_safe(self):
        """Context manager for thread-safe operations."""
        with self._lock:
            yield

    def cache_detection(self, camera_id: str, detection_data: Dict[str, Any]) -> None:
        """Cache detection result for a camera."""
        with self._thread_safe():
            self._cached_detections[camera_id] = detection_data.copy()

    def get_cached_detection(self, camera_id: str) -> Optional[Dict[str, Any]]:
        """Get cached detection for a camera."""
        with self._thread_safe():
            return self._cached_detections.get(camera_id)

    def clear_cached_detection(self, camera_id: str) -> bool:
        """Clear cached detection for a camera."""
        with self._thread_safe():
            return self._cached_detections.pop(camera_id, None) is not None

    def cache_pipeline_result(self, camera_id: str, result: Dict[str, Any]) -> None:
        """Cache pipeline result for a camera."""
        with self._thread_safe():
            self._cached_pipeline_results[camera_id] = result.copy()

    def get_cached_pipeline_result(self, camera_id: str) -> Optional[Dict[str, Any]]:
        """Get cached pipeline result for a camera."""
        with self._thread_safe():
            return self._cached_pipeline_results.get(camera_id)

    def clear_cached_pipeline_result(self, camera_id: str) -> bool:
        """Clear cached pipeline result for a camera."""
        with self._thread_safe():
            return self._cached_pipeline_results.pop(camera_id, None) is not None

    def set_latest_frame(self, camera_id: str, frame_data: Any) -> None:
        """Set latest frame for a camera."""
        with self._thread_safe():
            self._latest_frames[camera_id] = frame_data

    def get_latest_frame(self, camera_id: str) -> Optional[Any]:
        """Get latest frame for a camera."""
        with self._thread_safe():
            return self._latest_frames.get(camera_id)

    def set_frame_skip_flag(self, camera_id: str, skip: bool) -> None:
        """Set frame skip flag for a camera."""
        with self._thread_safe():
            self._frame_skip_flags[camera_id] = skip

    def get_frame_skip_flag(self, camera_id: str) -> bool:
        """Get frame skip flag for a camera."""
        with self._thread_safe():
            return self._frame_skip_flags.get(camera_id, False)

    def clear_camera_cache(self, camera_id: str) -> None:
        """Clear all cache data for a camera."""
        with self._thread_safe():
            self._cached_detections.pop(camera_id, None)
            self._cached_pipeline_results.pop(camera_id, None)
            self._latest_frames.pop(camera_id, None)
            self._frame_skip_flags.pop(camera_id, None)

            logger.debug(f"Cleared cache for camera {camera_id}")

    def get_all_cached_detections(self) -> Dict[str, Dict[str, Any]]:
        """Get all cached detections (for backward compatibility)."""
        with self._thread_safe():
            return self._cached_detections.copy()

    def get_all_latest_frames(self) -> Dict[str, Any]:
        """Get all latest frames (for backward compatibility)."""
        with self._thread_safe():
            return self._latest_frames.copy()

    def clear_all(self) -> None:
        """Clear all cache data."""
        with self._thread_safe():
            detection_count = len(self._cached_detections)
            pipeline_count = len(self._cached_pipeline_results)
            frame_count = len(self._latest_frames)

            self._cached_detections.clear()
            self._cached_pipeline_results.clear()
            self._latest_frames.clear()
            self._frame_skip_flags.clear()

            logger.info(f"Cleared {detection_count} cached detections, {pipeline_count} pipeline results, {frame_count} frames")

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        with self._thread_safe():
            return {
                "cached_detections": len(self._cached_detections),
                "cached_pipeline_results": len(self._cached_pipeline_results),
                "latest_frames": len(self._latest_frames),
                "frame_skip_flags": len(self._frame_skip_flags)
            }

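
# Usage sketch (illustrative): how a frame consumer might combine the frame
# cache with the skip flag, dropping detection work while a heavy pipeline is
# busy. The camera id, placeholder frame bytes, and detection payload are
# assumptions; `cache_state_manager` is the module-level instance defined below.
def _demo_cache_roundtrip(camera_id: str = "camera_001") -> None:
    """Store a frame, honour the skip flag, then read back a cached detection."""
    cache_state_manager.set_latest_frame(camera_id, b"<jpeg bytes>")  # placeholder frame

    if not cache_state_manager.get_frame_skip_flag(camera_id):
        # Only cache a fresh detection when this camera is not being skipped.
        cache_state_manager.cache_detection(camera_id, {"class": "car", "confidence": 0.91})

    detection = cache_state_manager.get_cached_detection(camera_id)
    logger.debug(f"Cached detection for {camera_id}: {detection}")
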
class CameraStateManager(metaclass=SingletonMeta):
    """
    Singleton manager for camera connection state.

    Manages camera_states and connection monitoring.
    """

    def __init__(self):
        """Initialize camera state manager."""
        self._camera_states: Dict[str, Dict[str, Any]] = {}  # camera_id -> state_info
        self._lock = threading.RLock()

    @contextmanager
    def _thread_safe(self):
        """Context manager for thread-safe operations."""
        with self._lock:
            yield

    def set_camera_connected(self, camera_id: str, connected: bool = True) -> None:
        """Set camera connection state and re-arm the matching notification flag."""
        with self._thread_safe():
            if camera_id not in self._camera_states:
                self._camera_states[camera_id] = {}

            state = self._camera_states[camera_id]
            state.update({
                "connected": connected,
                "last_update": time.time(),
                # On reconnect, re-arm the disconnection notification;
                # on disconnect, re-arm the reconnection notification.
                "disconnection_notified": False if connected else state.get("disconnection_notified", False),
                "reconnection_notified": False if not connected else state.get("reconnection_notified", False)
            })

            logger.debug(f"Camera {camera_id} connection state: {connected}")

    def is_camera_connected(self, camera_id: str) -> bool:
        """Check if camera is connected."""
        with self._thread_safe():
            state = self._camera_states.get(camera_id, {})
            return state.get("connected", True)  # Default to connected

    def should_notify_disconnection(self, camera_id: str) -> bool:
        """Check if disconnection should be notified."""
        with self._thread_safe():
            state = self._camera_states.get(camera_id, {})
            return not state.get("connected", True) and not state.get("disconnection_notified", False)

    def mark_disconnection_notified(self, camera_id: str) -> None:
        """Mark disconnection as notified."""
        with self._thread_safe():
            if camera_id in self._camera_states:
                self._camera_states[camera_id]["disconnection_notified"] = True

    def should_notify_reconnection(self, camera_id: str) -> bool:
        """Check if reconnection should be notified."""
        with self._thread_safe():
            state = self._camera_states.get(camera_id, {})
            return state.get("connected", True) and not state.get("reconnection_notified", True)

    def mark_reconnection_notified(self, camera_id: str) -> None:
        """Mark reconnection as notified."""
        with self._thread_safe():
            if camera_id in self._camera_states:
                self._camera_states[camera_id]["reconnection_notified"] = True

    def get_camera_state(self, camera_id: str) -> Dict[str, Any]:
        """Get full camera state."""
        with self._thread_safe():
            return self._camera_states.get(camera_id, {}).copy()

    def remove_camera_state(self, camera_id: str) -> bool:
        """Remove camera state."""
        with self._thread_safe():
            return self._camera_states.pop(camera_id, None) is not None

    def get_all_camera_states(self) -> Dict[str, Dict[str, Any]]:
        """Get all camera states (for backward compatibility)."""
        with self._thread_safe():
            return self._camera_states.copy()

    def clear_all(self) -> None:
        """Clear all camera states."""
        with self._thread_safe():
            state_count = len(self._camera_states)
            self._camera_states.clear()
            logger.info(f"Cleared {state_count} camera states")

    def get_stats(self) -> Dict[str, Any]:
        """Get camera state statistics."""
        with self._thread_safe():
            connected_count = sum(
                1 for state in self._camera_states.values()
                if state.get("connected", True)
            )

            return {
                "total_cameras": len(self._camera_states),
                "connected_cameras": connected_count,
                "disconnected_cameras": len(self._camera_states) - connected_count
            }

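
# Usage sketch (illustrative): the notify-once handshake a connection monitor
# would run against this manager. Each flag fires at most once per state
# change because mark_disconnection_notified() latches it until
# set_camera_connected() re-arms it. The camera id is a placeholder.
def _demo_notify_once(camera_id: str = "camera_001") -> None:
    """Emit a single disconnect notification for a camera that dropped."""
    camera_state_manager.set_camera_connected(camera_id, connected=False)

    if camera_state_manager.should_notify_disconnection(camera_id):
        logger.warning(f"Camera {camera_id} disconnected")  # notify subscribers here
        camera_state_manager.mark_disconnection_notified(camera_id)

    # A second poll is now a no-op until the camera reconnects.
    assert not camera_state_manager.should_notify_disconnection(camera_id)
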
class PipelineStateManager(metaclass=SingletonMeta):
    """
    Singleton manager for pipeline state (replaces session_pipeline_states).

    Manages session pipeline states and mode switching.
    """

    def __init__(self):
        """Initialize pipeline state manager."""
        self._pipeline_states: Dict[str, Dict[str, Any]] = {}  # camera_id -> pipeline_state
        self._lock = threading.RLock()

    @contextmanager
    def _thread_safe(self):
        """Context manager for thread-safe operations."""
        with self._lock:
            yield

    def get_or_init_state(self, camera_id: str) -> Dict[str, Any]:
        """Get or initialize pipeline state for a camera."""
        with self._thread_safe():
            if camera_id not in self._pipeline_states:
                self._pipeline_states[camera_id] = {
                    "mode": "validation_detecting",
                    "backend_session_id": None,
                    "yolo_inference_enabled": True,
                    "progression_stage": None,
                    "validated_detection": None,
                    "created_at": time.time()
                }

            return self._pipeline_states[camera_id].copy()

    def update_mode(self, camera_id: str, mode: str, session_id: Optional[str] = None) -> None:
        """Update pipeline mode for a camera."""
        with self._thread_safe():
            self.get_or_init_state(camera_id)  # ensure state exists
            self._pipeline_states[camera_id]["mode"] = mode

            if session_id is not None:
                self._pipeline_states[camera_id]["backend_session_id"] = session_id

            logger.debug(f"Updated pipeline mode for camera {camera_id}: {mode}")

    def set_yolo_inference_enabled(self, camera_id: str, enabled: bool) -> None:
        """Set YOLO inference enabled state."""
        with self._thread_safe():
            self.get_or_init_state(camera_id)  # ensure state exists
            self._pipeline_states[camera_id]["yolo_inference_enabled"] = enabled

    def set_progression_stage(self, camera_id: str, stage: str) -> None:
        """Set progression stage for a camera."""
        with self._thread_safe():
            self.get_or_init_state(camera_id)  # ensure state exists
            self._pipeline_states[camera_id]["progression_stage"] = stage

    def set_validated_detection(self, camera_id: str, detection: Optional[Dict[str, Any]]) -> None:
        """Set validated detection for a camera."""
        with self._thread_safe():
            self.get_or_init_state(camera_id)  # ensure state exists
            self._pipeline_states[camera_id]["validated_detection"] = detection

    def get_state(self, camera_id: str) -> Dict[str, Any]:
        """Get pipeline state for a camera."""
        with self._thread_safe():
            return self.get_or_init_state(camera_id)

    def remove_state(self, camera_id: str) -> bool:
        """Remove pipeline state for a camera."""
        with self._thread_safe():
            return self._pipeline_states.pop(camera_id, None) is not None

    def get_all_states(self) -> Dict[str, Dict[str, Any]]:
        """Get all pipeline states (for backward compatibility)."""
        with self._thread_safe():
            return self._pipeline_states.copy()

    def clear_all(self) -> None:
        """Clear all pipeline states."""
        with self._thread_safe():
            state_count = len(self._pipeline_states)
            self._pipeline_states.clear()
            logger.info(f"Cleared {state_count} pipeline states")

    def get_stats(self) -> Dict[str, Any]:
        """Get pipeline state statistics."""
        with self._thread_safe():
            mode_counts = {}
            for state in self._pipeline_states.values():
                mode = state.get("mode", "unknown")
                mode_counts[mode] = mode_counts.get(mode, 0) + 1

            return {
                "total_pipeline_states": len(self._pipeline_states),
                "mode_breakdown": mode_counts
            }

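
# Usage sketch (illustrative): a mode transition when the backend accepts a
# validated detection. The "session_active" mode name, the session id, and
# the detection payload are placeholders; only "validation_detecting" (the
# default above) is taken from this module.
def _demo_mode_transition(camera_id: str = "camera_001") -> None:
    """Promote a camera from validation to an active backend session."""
    pipeline_state_manager.set_validated_detection(camera_id, {"class": "car", "confidence": 0.93})
    pipeline_state_manager.update_mode(camera_id, "session_active", session_id="sess-42")
    pipeline_state_manager.set_yolo_inference_enabled(camera_id, False)  # pause detection during the session

    state = pipeline_state_manager.get_state(camera_id)
    logger.debug(f"Camera {camera_id}: mode={state['mode']}, session={state['backend_session_id']}")
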
# Global singleton instances (for backward compatibility and easy access)
model_state_manager = ModelStateManager()
stream_state_manager = StreamStateManager()
session_state_manager = SessionStateManager()
cache_state_manager = CacheStateManager()
camera_state_manager = CameraStateManager()
pipeline_state_manager = PipelineStateManager()
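
# Sketch of the singleton guarantee (assuming SingletonMeta, defined earlier
# in this module, returns the existing instance on repeated construction):
# constructing a manager anywhere in the codebase yields the module-level
# instance above, so legacy call sites and refactored ones share one state.
def _demo_singleton_identity() -> None:
    """Show that repeated construction returns the one shared instance."""
    assert CacheStateManager() is cache_state_manager
    assert PipelineStateManager() is pipeline_state_manager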