""" Model manager module. This module handles ML model loading, caching, and lifecycle management for the detection worker. """ import os import logging import threading from typing import Dict, Any, Optional, List, Set, Tuple from urllib.parse import urlparse import traceback from ..core.config import MODELS_DIR from ..core.exceptions import ModelLoadError # Setup logging logger = logging.getLogger("detector_worker.model_manager") class ModelRegistry: """ Registry for loaded models. Maintains a reference count for each model to enable sharing between multiple cameras. """ def __init__(self): """Initialize the model registry.""" self.models: Dict[str, Dict[str, Any]] = {} # model_id -> model_info self.references: Dict[str, Set[str]] = {} # model_id -> set of camera_ids self.lock = threading.Lock() def register_model(self, model_id: str, model_data: Any, model_path: str) -> None: """ Register a model in the registry. Args: model_id: Unique model identifier model_data: Loaded model data model_path: Path to model file """ with self.lock: self.models[model_id] = { "model": model_data, "path": model_path, "loaded_at": os.path.getmtime(model_path) } if model_id not in self.references: self.references[model_id] = set() def add_reference(self, model_id: str, camera_id: str) -> None: """Add a reference to a model from a camera.""" with self.lock: if model_id in self.references: self.references[model_id].add(camera_id) def remove_reference(self, model_id: str, camera_id: str) -> bool: """ Remove a reference to a model from a camera. Returns: True if model has no more references and can be unloaded """ with self.lock: if model_id in self.references: self.references[model_id].discard(camera_id) return len(self.references[model_id]) == 0 return True def get_model(self, model_id: str) -> Optional[Any]: """Get a model from the registry.""" with self.lock: model_info = self.models.get(model_id) return model_info["model"] if model_info else None def unregister_model(self, model_id: str) -> None: """Remove a model from the registry.""" with self.lock: self.models.pop(model_id, None) self.references.pop(model_id, None) def get_loaded_models(self) -> List[str]: """Get list of loaded model IDs.""" with self.lock: return list(self.models.keys()) def get_reference_count(self, model_id: str) -> int: """Get number of references to a model.""" with self.lock: return len(self.references.get(model_id, set())) def clear(self) -> None: """Clear all models from registry.""" with self.lock: self.models.clear() self.references.clear() class ModelManager: """ Manages ML model loading, caching, and lifecycle. This class handles: - Model downloading and caching - Model loading with proper error handling - Reference counting for model sharing - Model cleanup and memory management - Pipeline model tree management """ def __init__(self, models_dir: str = MODELS_DIR): """ Initialize the model manager. Args: models_dir: Directory to cache downloaded models """ self.models_dir = models_dir self.registry = ModelRegistry() self.models_lock = threading.Lock() # Camera to models mapping self.camera_models: Dict[str, Dict[str, Any]] = {} # camera_id -> {model_id -> model_tree} # Pipeline loader will be injected self.pipeline_loader = None # Create models directory if it doesn't exist os.makedirs(self.models_dir, exist_ok=True) def set_pipeline_loader(self, pipeline_loader: Any) -> None: """ Set the pipeline loader instance. Args: pipeline_loader: Pipeline loader to use for loading models """ self.pipeline_loader = pipeline_loader async def load_model( self, camera_id: str, model_id: str, model_url: str, force_reload: bool = False ) -> Any: """ Load a model for a specific camera. Args: camera_id: Camera identifier model_id: Model identifier model_url: URL or path to model file force_reload: Force reload even if cached Returns: Loaded model tree Raises: ModelLoadError: If model loading fails """ if not self.pipeline_loader: raise ModelLoadError("Pipeline loader not initialized") try: # Check if model is already loaded for this camera with self.models_lock: if camera_id in self.camera_models and model_id in self.camera_models[camera_id]: if not force_reload: logger.info(f"Model {model_id} already loaded for camera {camera_id}") return self.camera_models[camera_id][model_id] # Check if model is in registry cached_model = self.registry.get_model(model_id) if cached_model and not force_reload: # Add reference and return cached model self.registry.add_reference(model_id, camera_id) with self.models_lock: if camera_id not in self.camera_models: self.camera_models[camera_id] = {} self.camera_models[camera_id][model_id] = cached_model logger.info(f"Using cached model {model_id} for camera {camera_id}") return cached_model # Download or locate model file model_path = await self._get_model_path(model_url, model_id) # Load model using pipeline loader logger.info(f"Loading model {model_id} from {model_path}") model_tree = await self.pipeline_loader.load_pipeline(model_path) # Register in registry self.registry.register_model(model_id, model_tree, model_path) self.registry.add_reference(model_id, camera_id) # Store in camera models with self.models_lock: if camera_id not in self.camera_models: self.camera_models[camera_id] = {} self.camera_models[camera_id][model_id] = model_tree logger.info(f"Successfully loaded model {model_id} for camera {camera_id}") return model_tree except Exception as e: logger.error(f"Failed to load model {model_id}: {e}") traceback.print_exc() raise ModelLoadError(f"Failed to load model {model_id}: {e}") async def _get_model_path(self, model_url: str, model_id: str) -> str: """ Get local path for a model, downloading if necessary. Args: model_url: URL or local path to model model_id: Model identifier Returns: Local file path to model """ # Check if it's already a local path if os.path.exists(model_url): return model_url # Parse URL parsed = urlparse(model_url) # Check if it's a file:// URL if parsed.scheme == 'file': return parsed.path # For HTTP/HTTPS URLs, download to cache if parsed.scheme in ['http', 'https']: # Generate cache filename filename = os.path.basename(parsed.path) if not filename: filename = f"{model_id}.mpta" cache_path = os.path.join(self.models_dir, filename) # Check if already cached if os.path.exists(cache_path): logger.info(f"Using cached model file: {cache_path}") return cache_path # Download model logger.info(f"Downloading model from {model_url}") await self._download_model(model_url, cache_path) return cache_path # For other schemes or no scheme, assume local path return model_url async def _download_model(self, url: str, destination: str) -> None: """ Download a model file from URL. Args: url: URL to download from destination: Local path to save to """ import aiohttp import aiofiles try: async with aiohttp.ClientSession() as session: async with session.get(url) as response: response.raise_for_status() # Get total size if available total_size = response.headers.get('Content-Length') if total_size: total_size = int(total_size) logger.info(f"Downloading {total_size / (1024*1024):.2f} MB") # Download to temporary file first temp_path = f"{destination}.tmp" downloaded = 0 async with aiofiles.open(temp_path, 'wb') as f: async for chunk in response.content.iter_chunked(8192): await f.write(chunk) downloaded += len(chunk) # Log progress if total_size and downloaded % (1024 * 1024) == 0: progress = (downloaded / total_size) * 100 logger.info(f"Download progress: {progress:.1f}%") # Move to final destination os.rename(temp_path, destination) logger.info(f"Model downloaded successfully to {destination}") except Exception as e: # Clean up temporary file if exists temp_path = f"{destination}.tmp" if os.path.exists(temp_path): os.remove(temp_path) raise ModelLoadError(f"Failed to download model: {e}") def get_model(self, camera_id: str, model_id: str) -> Optional[Any]: """ Get a loaded model for a camera. Args: camera_id: Camera identifier model_id: Model identifier Returns: Model tree if loaded, None otherwise """ with self.models_lock: camera_models = self.camera_models.get(camera_id, {}) return camera_models.get(model_id) def unload_models(self, camera_id: str) -> None: """ Unload all models for a camera. Args: camera_id: Camera identifier """ with self.models_lock: if camera_id not in self.camera_models: return # Remove references for each model for model_id in self.camera_models[camera_id]: should_unload = self.registry.remove_reference(model_id, camera_id) if should_unload: logger.info(f"Unloading model {model_id} (no more references)") self.registry.unregister_model(model_id) # Clean up model if pipeline loader supports it if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_model'): try: self.pipeline_loader.cleanup_model(model_id) except Exception as e: logger.error(f"Error cleaning up model {model_id}: {e}") # Remove camera entry del self.camera_models[camera_id] logger.info(f"Unloaded all models for camera {camera_id}") def cleanup_all_models(self) -> None: """Clean up all loaded models.""" logger.info("Cleaning up all loaded models") with self.models_lock: # Get list of cameras to clean up cameras = list(self.camera_models.keys()) # Unload models for each camera for camera_id in cameras: self.unload_models(camera_id) # Clear registry self.registry.clear() # Clean up pipeline loader if it has cleanup if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_all'): try: self.pipeline_loader.cleanup_all() except Exception as e: logger.error(f"Error in pipeline loader cleanup: {e}") logger.info("Model cleanup completed") def get_loaded_models(self) -> Dict[str, List[str]]: """ Get information about loaded models. Returns: Dictionary mapping model IDs to list of camera IDs using them """ result = {} with self.models_lock: for model_id in self.registry.get_loaded_models(): cameras = [] for camera_id, models in self.camera_models.items(): if model_id in models: cameras.append(camera_id) result[model_id] = cameras return result def get_model_stats(self) -> Dict[str, Any]: """ Get statistics about loaded models. Returns: Dictionary with model statistics """ with self.models_lock: total_models = len(self.registry.get_loaded_models()) total_cameras = len(self.camera_models) # Count total model instances total_instances = sum( len(models) for models in self.camera_models.values() ) # Get cache size cache_size = 0 if os.path.exists(self.models_dir): for filename in os.listdir(self.models_dir): filepath = os.path.join(self.models_dir, filename) if os.path.isfile(filepath): cache_size += os.path.getsize(filepath) return { "total_models": total_models, "total_cameras": total_cameras, "total_instances": total_instances, "cache_size_mb": round(cache_size / (1024 * 1024), 2), "models_dir": self.models_dir } # Global model manager instance _model_manager = None def get_model_manager() -> ModelManager: """Get or create the global model manager instance.""" global _model_manager if _model_manager is None: _model_manager = ModelManager() return _model_manager # Convenience functions for backward compatibility def initialize_model_manager(models_dir: str = MODELS_DIR) -> ModelManager: """Initialize the global model manager.""" global _model_manager _model_manager = ModelManager(models_dir) return _model_manager