Refactor: Phase 4: Communication Layer
parent cdeaaf4a4f
commit 54f21672aa
6 changed files with 2876 additions and 0 deletions
detector_worker/models/model_manager.py: 441 additions (new file)
@@ -0,0 +1,441 @@
"""
|
||||
Model manager module.
|
||||
|
||||
This module handles ML model loading, caching, and lifecycle management
|
||||
for the detection worker.
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Any, Optional, List, Set, Tuple
|
||||
from urllib.parse import urlparse
|
||||
import traceback
|
||||
|
||||
from ..core.config import MODELS_DIR
|
||||
from ..core.exceptions import ModelLoadError
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger("detector_worker.model_manager")
|
||||
|
||||
|
||||
class ModelRegistry:
    """
    Registry for loaded models.

    Maintains a reference count for each model to enable sharing
    between multiple cameras.
    """

    def __init__(self):
        """Initialize the model registry."""
        self.models: Dict[str, Dict[str, Any]] = {}  # model_id -> model_info
        self.references: Dict[str, Set[str]] = {}  # model_id -> set of camera_ids
        self.lock = threading.Lock()

    def register_model(self, model_id: str, model_data: Any, model_path: str) -> None:
        """
        Register a model in the registry.

        Args:
            model_id: Unique model identifier
            model_data: Loaded model data
            model_path: Path to model file
        """
        with self.lock:
            self.models[model_id] = {
                "model": model_data,
                "path": model_path,
                "loaded_at": time.time()  # registration timestamp, not file mtime
            }
            if model_id not in self.references:
                self.references[model_id] = set()

    def add_reference(self, model_id: str, camera_id: str) -> None:
        """Add a reference to a model from a camera."""
        with self.lock:
            if model_id in self.references:
                self.references[model_id].add(camera_id)

    def remove_reference(self, model_id: str, camera_id: str) -> bool:
        """
        Remove a reference to a model from a camera.

        Returns:
            True if the model has no more references and can be unloaded
        """
        with self.lock:
            if model_id in self.references:
                self.references[model_id].discard(camera_id)
                return len(self.references[model_id]) == 0
            return True

    def get_model(self, model_id: str) -> Optional[Any]:
        """Get a model from the registry."""
        with self.lock:
            model_info = self.models.get(model_id)
            return model_info["model"] if model_info else None

    def unregister_model(self, model_id: str) -> None:
        """Remove a model from the registry."""
        with self.lock:
            self.models.pop(model_id, None)
            self.references.pop(model_id, None)

    def get_loaded_models(self) -> List[str]:
        """Get list of loaded model IDs."""
        with self.lock:
            return list(self.models.keys())

    def get_reference_count(self, model_id: str) -> int:
        """Get number of references to a model."""
        with self.lock:
            return len(self.references.get(model_id, set()))

    def clear(self) -> None:
        """Clear all models from registry."""
        with self.lock:
            self.models.clear()
            self.references.clear()

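# A minimal sketch of the reference-counting contract above (illustrative
# only; the model ID and camera IDs here are hypothetical):
#
#     registry = ModelRegistry()
#     registry.register_model("yolo-v8", model_tree, "/models/yolo-v8.mpta")
#     registry.add_reference("yolo-v8", "cam-1")
#     registry.add_reference("yolo-v8", "cam-2")
#     registry.remove_reference("yolo-v8", "cam-1")  # False: cam-2 still holds it
#     registry.remove_reference("yolo-v8", "cam-2")  # True: safe to unload
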
class ModelManager:
    """
    Manages ML model loading, caching, and lifecycle.

    This class handles:
    - Model downloading and caching
    - Model loading with proper error handling
    - Reference counting for model sharing
    - Model cleanup and memory management
    - Pipeline model tree management
    """

    def __init__(self, models_dir: str = MODELS_DIR):
        """
        Initialize the model manager.

        Args:
            models_dir: Directory to cache downloaded models
        """
        self.models_dir = models_dir
        self.registry = ModelRegistry()
        self.models_lock = threading.Lock()

        # Camera to models mapping
        self.camera_models: Dict[str, Dict[str, Any]] = {}  # camera_id -> {model_id -> model_tree}

        # Pipeline loader will be injected
        self.pipeline_loader = None

        # Create models directory if it doesn't exist
        os.makedirs(self.models_dir, exist_ok=True)

    def set_pipeline_loader(self, pipeline_loader: Any) -> None:
        """
        Set the pipeline loader instance.

        Args:
            pipeline_loader: Pipeline loader to use for loading models
        """
        self.pipeline_loader = pipeline_loader

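    # Note: load_model() below raises ModelLoadError until a loader has been
    # injected via set_pipeline_loader(); the only interface this module
    # relies on is an awaitable load_pipeline(model_path) returning a model tree.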
    async def load_model(
        self,
        camera_id: str,
        model_id: str,
        model_url: str,
        force_reload: bool = False
    ) -> Any:
        """
        Load a model for a specific camera.

        Args:
            camera_id: Camera identifier
            model_id: Model identifier
            model_url: URL or path to model file
            force_reload: Force reload even if cached

        Returns:
            Loaded model tree

        Raises:
            ModelLoadError: If model loading fails
        """
        if not self.pipeline_loader:
            raise ModelLoadError("Pipeline loader not initialized")

        try:
            # Check if model is already loaded for this camera
            with self.models_lock:
                if camera_id in self.camera_models and model_id in self.camera_models[camera_id]:
                    if not force_reload:
                        logger.info(f"Model {model_id} already loaded for camera {camera_id}")
                        return self.camera_models[camera_id][model_id]

            # Check if model is in registry
            cached_model = self.registry.get_model(model_id)
            if cached_model and not force_reload:
                # Add reference and return cached model
                self.registry.add_reference(model_id, camera_id)
                with self.models_lock:
                    if camera_id not in self.camera_models:
                        self.camera_models[camera_id] = {}
                    self.camera_models[camera_id][model_id] = cached_model
                logger.info(f"Using cached model {model_id} for camera {camera_id}")
                return cached_model

            # Download or locate model file
            model_path = await self._get_model_path(model_url, model_id)

            # Load model using pipeline loader
            logger.info(f"Loading model {model_id} from {model_path}")
            model_tree = await self.pipeline_loader.load_pipeline(model_path)

            # Register in registry
            self.registry.register_model(model_id, model_tree, model_path)
            self.registry.add_reference(model_id, camera_id)

            # Store in camera models
            with self.models_lock:
                if camera_id not in self.camera_models:
                    self.camera_models[camera_id] = {}
                self.camera_models[camera_id][model_id] = model_tree

            logger.info(f"Successfully loaded model {model_id} for camera {camera_id}")
            return model_tree

        except Exception as e:
            logger.error(f"Failed to load model {model_id}: {e}")
            traceback.print_exc()
            raise ModelLoadError(f"Failed to load model {model_id}: {e}") from e

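    # Call-site sketch for the method above (illustrative; the model ID and
    # URL are assumptions, not values used anywhere in this module):
    #
    #     manager = ModelManager()
    #     manager.set_pipeline_loader(loader)  # loader: hypothetical pipeline loader
    #     model_tree = await manager.load_model(
    #         camera_id="cam-1",
    #         model_id="yolo-v8",
    #         model_url="https://example.com/models/yolo-v8.mpta",
    #     )
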
    async def _get_model_path(self, model_url: str, model_id: str) -> str:
        """
        Get local path for a model, downloading if necessary.

        Args:
            model_url: URL or local path to model
            model_id: Model identifier

        Returns:
            Local file path to model
        """
        # Check if it's already a local path
        if os.path.exists(model_url):
            return model_url

        # Parse URL
        parsed = urlparse(model_url)

        # Check if it's a file:// URL
        if parsed.scheme == 'file':
            return parsed.path

        # For HTTP/HTTPS URLs, download to cache
        if parsed.scheme in ['http', 'https']:
            # Generate cache filename
            filename = os.path.basename(parsed.path)
            if not filename:
                filename = f"{model_id}.mpta"

            cache_path = os.path.join(self.models_dir, filename)

            # Check if already cached
            if os.path.exists(cache_path):
                logger.info(f"Using cached model file: {cache_path}")
                return cache_path

            # Download model
            logger.info(f"Downloading model from {model_url}")
            await self._download_model(model_url, cache_path)
            return cache_path

        # For other schemes or no scheme, assume local path
        return model_url

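    # Resolution order used above, summarized: an existing local path is
    # returned as-is; file:// URLs map to their path; http(s) URLs are
    # downloaded once and cached under models_dir; anything else falls
    # through as a local path and failure surfaces at load time.
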
    async def _download_model(self, url: str, destination: str) -> None:
        """
        Download a model file from URL.

        Args:
            url: URL to download from
            destination: Local path to save to
        """
        import aiohttp
        import aiofiles

        temp_path = f"{destination}.tmp"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    response.raise_for_status()

                    # Get total size if available
                    total_size = response.headers.get('Content-Length')
                    if total_size:
                        total_size = int(total_size)
                        logger.info(f"Downloading {total_size / (1024*1024):.2f} MB")

                    # Download to a temporary file first so a partial download
                    # never overwrites a good cached copy
                    downloaded = 0
                    async with aiofiles.open(temp_path, 'wb') as f:
                        async for chunk in response.content.iter_chunked(8192):
                            await f.write(chunk)
                            downloaded += len(chunk)

                            # Log progress roughly once per megabyte
                            if total_size and downloaded % (1024 * 1024) == 0:
                                progress = (downloaded / total_size) * 100
                                logger.info(f"Download progress: {progress:.1f}%")

            # Move to final destination
            os.rename(temp_path, destination)
            logger.info(f"Model downloaded successfully to {destination}")

        except Exception as e:
            # Clean up temporary file if it exists
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise ModelLoadError(f"Failed to download model: {e}") from e

    def get_model(self, camera_id: str, model_id: str) -> Optional[Any]:
        """
        Get a loaded model for a camera.

        Args:
            camera_id: Camera identifier
            model_id: Model identifier

        Returns:
            Model tree if loaded, None otherwise
        """
        with self.models_lock:
            camera_models = self.camera_models.get(camera_id, {})
            return camera_models.get(model_id)

    def unload_models(self, camera_id: str) -> None:
        """
        Unload all models for a camera.

        Args:
            camera_id: Camera identifier
        """
        with self.models_lock:
            if camera_id not in self.camera_models:
                return

            # Remove references for each model
            for model_id in self.camera_models[camera_id]:
                should_unload = self.registry.remove_reference(model_id, camera_id)

                if should_unload:
                    logger.info(f"Unloading model {model_id} (no more references)")
                    self.registry.unregister_model(model_id)

                    # Clean up model if pipeline loader supports it
                    if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_model'):
                        try:
                            self.pipeline_loader.cleanup_model(model_id)
                        except Exception as e:
                            logger.error(f"Error cleaning up model {model_id}: {e}")

            # Remove camera entry
            del self.camera_models[camera_id]
            logger.info(f"Unloaded all models for camera {camera_id}")

    def cleanup_all_models(self) -> None:
        """Clean up all loaded models."""
        logger.info("Cleaning up all loaded models")

        # Snapshot the camera list under the lock, then release it before
        # calling unload_models(), which re-acquires the same non-reentrant lock
        with self.models_lock:
            cameras = list(self.camera_models.keys())

        # Unload models for each camera
        for camera_id in cameras:
            self.unload_models(camera_id)

        # Clear registry
        self.registry.clear()

        # Clean up pipeline loader if it has cleanup
        if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_all'):
            try:
                self.pipeline_loader.cleanup_all()
            except Exception as e:
                logger.error(f"Error in pipeline loader cleanup: {e}")

        logger.info("Model cleanup completed")

    def get_loaded_models(self) -> Dict[str, List[str]]:
        """
        Get information about loaded models.

        Returns:
            Dictionary mapping model IDs to the list of camera IDs using them
        """
        result = {}

        with self.models_lock:
            for model_id in self.registry.get_loaded_models():
                cameras = []
                for camera_id, models in self.camera_models.items():
                    if model_id in models:
                        cameras.append(camera_id)
                result[model_id] = cameras

        return result

    def get_model_stats(self) -> Dict[str, Any]:
        """
        Get statistics about loaded models.

        Returns:
            Dictionary with model statistics
        """
        with self.models_lock:
            total_models = len(self.registry.get_loaded_models())
            total_cameras = len(self.camera_models)

            # Count total model instances
            total_instances = sum(
                len(models) for models in self.camera_models.values()
            )

            # Get cache size on disk
            cache_size = 0
            if os.path.exists(self.models_dir):
                for filename in os.listdir(self.models_dir):
                    filepath = os.path.join(self.models_dir, filename)
                    if os.path.isfile(filepath):
                        cache_size += os.path.getsize(filepath)

            return {
                "total_models": total_models,
                "total_cameras": total_cameras,
                "total_instances": total_instances,
                "cache_size_mb": round(cache_size / (1024 * 1024), 2),
                "models_dir": self.models_dir
            }
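    # Example return value of get_model_stats() (illustrative numbers only):
    #
    #     {
    #         "total_models": 2,
    #         "total_cameras": 3,
    #         "total_instances": 4,
    #         "cache_size_mb": 153.42,
    #         "models_dir": "models",
    #     }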

# Global model manager instance
_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """Get or create the global model manager instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager


# Convenience function for backward compatibility
def initialize_model_manager(models_dir: str = MODELS_DIR) -> ModelManager:
    """Initialize (or re-initialize) the global model manager."""
    global _model_manager
    _model_manager = ModelManager(models_dir)
    return _model_manager
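A minimal end-to-end sketch of how the new module is driven (the stub loader
and model path below are placeholders for illustration, not part of this
commit):

    import asyncio

    from detector_worker.models.model_manager import initialize_model_manager

    class StubLoader:
        """Stand-in for the real pipeline loader (assumption for this sketch)."""
        async def load_pipeline(self, model_path):
            return {"path": model_path}  # a real loader returns a model tree

    async def main():
        manager = initialize_model_manager(models_dir="models")
        manager.set_pipeline_loader(StubLoader())

        tree = await manager.load_model(
            camera_id="cam-1",
            model_id="yolo-v8",
            model_url="models/yolo-v8.mpta",  # assumes this file exists locally
        )
        print(manager.get_model_stats())

        manager.unload_models("cam-1")
        manager.cleanup_all_models()

    asyncio.run(main())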