489 lines
No EOL
18 KiB
Python
489 lines
No EOL
18 KiB
Python
"""
|
|
Model manager module.
|
|
|
|
This module handles ML model loading, caching, and lifecycle management
|
|
for the detection worker.
|
|
"""
|
|
import os
|
|
import logging
|
|
import threading
|
|
from typing import Dict, Any, Optional, List, Set, Tuple, TYPE_CHECKING
|
|
from urllib.parse import urlparse
|
|
import traceback
|
|
|
|
from ..core.config import MODELS_DIR
|
|
from ..core.exceptions import ModelLoadError
|
|
|
|
if TYPE_CHECKING:
|
|
from .pipeline_loader import PipelineLoader
|
|
|
|
# Setup logging
|
|
logger = logging.getLogger("detector_worker.model_manager")
|
|
|
|
|
|
class ModelRegistry:
|
|
"""
|
|
Registry for loaded models.
|
|
|
|
Maintains a reference count for each model to enable sharing
|
|
between multiple cameras.
|
|
"""
|
|
|
|
def __init__(self):
|
|
"""Initialize the model registry."""
|
|
self.models: Dict[str, Dict[str, Any]] = {} # model_id -> model_info
|
|
self.references: Dict[str, Set[str]] = {} # model_id -> set of camera_ids
|
|
self.lock = threading.Lock()
|
|
|
|
def register_model(self, model_id: str, model_data: Any, model_path: str) -> None:
|
|
"""
|
|
Register a model in the registry.
|
|
|
|
Args:
|
|
model_id: Unique model identifier
|
|
model_data: Loaded model data
|
|
model_path: Path to model file
|
|
"""
|
|
with self.lock:
|
|
self.models[model_id] = {
|
|
"model": model_data,
|
|
"path": model_path,
|
|
"loaded_at": os.path.getmtime(model_path)
|
|
}
|
|
if model_id not in self.references:
|
|
self.references[model_id] = set()
|
|
|
|
def add_reference(self, model_id: str, camera_id: str) -> None:
|
|
"""Add a reference to a model from a camera."""
|
|
with self.lock:
|
|
if model_id in self.references:
|
|
self.references[model_id].add(camera_id)
|
|
|
|
def remove_reference(self, model_id: str, camera_id: str) -> bool:
|
|
"""
|
|
Remove a reference to a model from a camera.
|
|
|
|
Returns:
|
|
True if model has no more references and can be unloaded
|
|
"""
|
|
with self.lock:
|
|
if model_id in self.references:
|
|
self.references[model_id].discard(camera_id)
|
|
return len(self.references[model_id]) == 0
|
|
return True
|
|
|
|
def get_model(self, model_id: str) -> Optional[Any]:
|
|
"""Get a model from the registry."""
|
|
with self.lock:
|
|
model_info = self.models.get(model_id)
|
|
return model_info["model"] if model_info else None
|
|
|
|
def unregister_model(self, model_id: str) -> None:
|
|
"""Remove a model from the registry."""
|
|
with self.lock:
|
|
self.models.pop(model_id, None)
|
|
self.references.pop(model_id, None)
|
|
|
|
def get_loaded_models(self) -> List[str]:
|
|
"""Get list of loaded model IDs."""
|
|
with self.lock:
|
|
return list(self.models.keys())
|
|
|
|
def get_reference_count(self, model_id: str) -> int:
|
|
"""Get number of references to a model."""
|
|
with self.lock:
|
|
return len(self.references.get(model_id, set()))
|
|
|
|
def clear(self) -> None:
|
|
"""Clear all models from registry."""
|
|
with self.lock:
|
|
self.models.clear()
|
|
self.references.clear()
|
|
|
|
|
|
class ModelManager:
|
|
"""
|
|
Manages ML model loading, caching, and lifecycle.
|
|
|
|
This class handles:
|
|
- Model downloading and caching
|
|
- Model loading with proper error handling
|
|
- Reference counting for model sharing
|
|
- Model cleanup and memory management
|
|
- Pipeline model tree management
|
|
"""
|
|
|
|
def __init__(self, pipeline_loader: Optional['PipelineLoader'] = None, models_dir: str = MODELS_DIR):
|
|
"""
|
|
Initialize the model manager.
|
|
|
|
Args:
|
|
pipeline_loader: Pipeline loader for handling MPTA archives (injected via DI)
|
|
models_dir: Directory to cache downloaded models
|
|
"""
|
|
self.models_dir = models_dir
|
|
self.registry = ModelRegistry()
|
|
self.models_lock = threading.Lock()
|
|
|
|
# Camera to models mapping
|
|
self.camera_models: Dict[str, Dict[str, Any]] = {} # camera_id -> {model_id -> model_tree}
|
|
|
|
# Pipeline loader injected via dependency injection
|
|
self.pipeline_loader = pipeline_loader
|
|
|
|
# If pipeline_loader is None, try to resolve it from the container
|
|
if self.pipeline_loader is None:
|
|
try:
|
|
from ..core.dependency_injection import get_container
|
|
from .pipeline_loader import PipelineLoader
|
|
container = get_container()
|
|
self.pipeline_loader = container.resolve(PipelineLoader)
|
|
logger.info("PipelineLoader resolved from dependency container")
|
|
except Exception as e:
|
|
logger.warning(f"Could not resolve PipelineLoader from container: {e}")
|
|
|
|
# Create models directory if it doesn't exist
|
|
os.makedirs(self.models_dir, exist_ok=True)
|
|
|
|
def set_pipeline_loader(self, pipeline_loader: Any) -> None:
|
|
"""
|
|
Set the pipeline loader instance.
|
|
|
|
Args:
|
|
pipeline_loader: Pipeline loader to use for loading models
|
|
"""
|
|
self.pipeline_loader = pipeline_loader
|
|
|
|
async def load_model(
|
|
self,
|
|
camera_id: str,
|
|
model_id: str,
|
|
model_url: str,
|
|
force_reload: bool = False
|
|
) -> Any:
|
|
"""
|
|
Load a model for a specific camera.
|
|
|
|
Args:
|
|
camera_id: Camera identifier
|
|
model_id: Model identifier
|
|
model_url: URL or path to model file
|
|
force_reload: Force reload even if cached
|
|
|
|
Returns:
|
|
Loaded model tree
|
|
|
|
Raises:
|
|
ModelLoadError: If model loading fails
|
|
"""
|
|
if not self.pipeline_loader:
|
|
raise ModelLoadError("Pipeline loader not initialized")
|
|
|
|
try:
|
|
# Check if model is already loaded for this camera
|
|
with self.models_lock:
|
|
if camera_id in self.camera_models and model_id in self.camera_models[camera_id]:
|
|
if not force_reload:
|
|
logger.info(f"Model {model_id} already loaded for camera {camera_id}")
|
|
return self.camera_models[camera_id][model_id]
|
|
|
|
# Check if model is in registry
|
|
cached_model = self.registry.get_model(model_id)
|
|
if cached_model and not force_reload:
|
|
# Add reference and return cached model
|
|
self.registry.add_reference(model_id, camera_id)
|
|
with self.models_lock:
|
|
if camera_id not in self.camera_models:
|
|
self.camera_models[camera_id] = {}
|
|
self.camera_models[camera_id][model_id] = cached_model
|
|
logger.info(f"Using cached model {model_id} for camera {camera_id}")
|
|
return cached_model
|
|
|
|
# Download or locate model file
|
|
model_path = await self._get_model_path(model_url, model_id)
|
|
|
|
# Load model using pipeline loader
|
|
logger.info(f"Loading model {model_id} from {model_path}")
|
|
model_tree = await self.pipeline_loader.load_pipeline(model_path)
|
|
|
|
# Register in registry
|
|
self.registry.register_model(model_id, model_tree, model_path)
|
|
self.registry.add_reference(model_id, camera_id)
|
|
|
|
# Store in camera models
|
|
with self.models_lock:
|
|
if camera_id not in self.camera_models:
|
|
self.camera_models[camera_id] = {}
|
|
self.camera_models[camera_id][model_id] = model_tree
|
|
|
|
logger.info(f"Successfully loaded model {model_id} for camera {camera_id}")
|
|
return model_tree
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load model {model_id}: {e}")
|
|
traceback.print_exc()
|
|
raise ModelLoadError(f"Failed to load model {model_id}: {e}")
|
|
|
|
async def _get_model_path(self, model_url: str, model_id: str) -> str:
|
|
"""
|
|
Get local path for a model, downloading if necessary.
|
|
Uses model_id subfolder structure: models/{model_id}/
|
|
|
|
Args:
|
|
model_url: URL or local path to model
|
|
model_id: Model identifier
|
|
|
|
Returns:
|
|
Local file path to model
|
|
"""
|
|
# Check if it's already a local path
|
|
if os.path.exists(model_url):
|
|
return model_url
|
|
|
|
# Parse URL
|
|
parsed = urlparse(model_url)
|
|
|
|
# Check if it's a file:// URL
|
|
if parsed.scheme == 'file':
|
|
return parsed.path
|
|
|
|
# For HTTP/HTTPS URLs, download to cache with model_id subfolder
|
|
if parsed.scheme in ['http', 'https']:
|
|
# Create model_id subfolder structure
|
|
model_dir = os.path.join(self.models_dir, str(model_id))
|
|
os.makedirs(model_dir, exist_ok=True)
|
|
|
|
# Generate cache filename
|
|
filename = os.path.basename(parsed.path)
|
|
if not filename:
|
|
filename = f"model_{model_id}.mpta"
|
|
|
|
cache_path = os.path.join(model_dir, filename)
|
|
|
|
# Check if already cached
|
|
if os.path.exists(cache_path):
|
|
logger.info(f"Using cached model file: {cache_path}")
|
|
return cache_path
|
|
|
|
# Download model
|
|
logger.info(f"Downloading model {model_id} from {model_url}")
|
|
await self._download_model(model_url, cache_path)
|
|
return cache_path
|
|
|
|
# For other schemes or no scheme, assume local path
|
|
return model_url
|
|
|
|
async def _download_model(self, url: str, destination: str) -> None:
|
|
"""
|
|
Download a model file from URL with enhanced HTTP request logging.
|
|
|
|
Args:
|
|
url: URL to download from
|
|
destination: Local path to save to
|
|
"""
|
|
import aiohttp
|
|
import aiofiles
|
|
import time
|
|
|
|
# Import HTTP logger
|
|
from ..utils.logging_utils import get_http_logger
|
|
http_logger = get_http_logger()
|
|
|
|
start_time = time.time()
|
|
correlation_id = None
|
|
|
|
try:
|
|
async with aiohttp.ClientSession() as session:
|
|
# Log request start
|
|
correlation_id = http_logger.log_request_start("GET", url)
|
|
|
|
async with session.get(url) as response:
|
|
response.raise_for_status()
|
|
|
|
# Get total size if available
|
|
total_size = response.headers.get('Content-Length')
|
|
if total_size:
|
|
total_size = int(total_size)
|
|
logger.info(f"Downloading {total_size / (1024*1024):.2f} MB")
|
|
|
|
# Download to temporary file first
|
|
temp_path = f"{destination}.tmp"
|
|
downloaded = 0
|
|
last_progress_log = 0
|
|
|
|
async with aiofiles.open(temp_path, 'wb') as f:
|
|
async for chunk in response.content.iter_chunked(8192):
|
|
await f.write(chunk)
|
|
downloaded += len(chunk)
|
|
|
|
# Log progress at 10% intervals
|
|
if total_size and downloaded > 0:
|
|
progress = (downloaded / total_size) * 100
|
|
if progress >= last_progress_log + 10 and progress <= 100:
|
|
logger.info(f"Download progress: {progress:.1f}%")
|
|
http_logger.log_download_progress(
|
|
downloaded, total_size, progress, correlation_id
|
|
)
|
|
last_progress_log = progress
|
|
|
|
# Move to final destination
|
|
os.rename(temp_path, destination)
|
|
|
|
# Log successful completion
|
|
duration_ms = (time.time() - start_time) * 1000
|
|
http_logger.log_request_end(
|
|
response.status, downloaded, duration_ms, correlation_id
|
|
)
|
|
logger.info(f"Model downloaded successfully to {destination}")
|
|
|
|
except Exception as e:
|
|
# Log failed completion
|
|
if correlation_id:
|
|
duration_ms = (time.time() - start_time) * 1000
|
|
http_logger.log_request_end(500, None, duration_ms, correlation_id)
|
|
|
|
# Clean up temporary file if exists
|
|
temp_path = f"{destination}.tmp"
|
|
if os.path.exists(temp_path):
|
|
os.remove(temp_path)
|
|
raise ModelLoadError(f"Failed to download model: {e}")
|
|
|
|
def get_model(self, camera_id: str, model_id: str) -> Optional[Any]:
|
|
"""
|
|
Get a loaded model for a camera.
|
|
|
|
Args:
|
|
camera_id: Camera identifier
|
|
model_id: Model identifier
|
|
|
|
Returns:
|
|
Model tree if loaded, None otherwise
|
|
"""
|
|
with self.models_lock:
|
|
camera_models = self.camera_models.get(camera_id, {})
|
|
return camera_models.get(model_id)
|
|
|
|
def unload_models(self, camera_id: str) -> None:
|
|
"""
|
|
Unload all models for a camera.
|
|
|
|
Args:
|
|
camera_id: Camera identifier
|
|
"""
|
|
with self.models_lock:
|
|
if camera_id not in self.camera_models:
|
|
return
|
|
|
|
# Remove references for each model
|
|
for model_id in self.camera_models[camera_id]:
|
|
should_unload = self.registry.remove_reference(model_id, camera_id)
|
|
|
|
if should_unload:
|
|
logger.info(f"Unloading model {model_id} (no more references)")
|
|
self.registry.unregister_model(model_id)
|
|
|
|
# Clean up model if pipeline loader supports it
|
|
if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_model'):
|
|
try:
|
|
self.pipeline_loader.cleanup_model(model_id)
|
|
except Exception as e:
|
|
logger.error(f"Error cleaning up model {model_id}: {e}")
|
|
|
|
# Remove camera entry
|
|
del self.camera_models[camera_id]
|
|
logger.info(f"Unloaded all models for camera {camera_id}")
|
|
|
|
def cleanup_all_models(self) -> None:
|
|
"""Clean up all loaded models."""
|
|
logger.info("Cleaning up all loaded models")
|
|
|
|
with self.models_lock:
|
|
# Get list of cameras to clean up
|
|
cameras = list(self.camera_models.keys())
|
|
|
|
# Unload models for each camera
|
|
for camera_id in cameras:
|
|
self.unload_models(camera_id)
|
|
|
|
# Clear registry
|
|
self.registry.clear()
|
|
|
|
# Clean up pipeline loader if it has cleanup
|
|
if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_all'):
|
|
try:
|
|
self.pipeline_loader.cleanup_all()
|
|
except Exception as e:
|
|
logger.error(f"Error in pipeline loader cleanup: {e}")
|
|
|
|
logger.info("Model cleanup completed")
|
|
|
|
def get_loaded_models(self) -> Dict[str, List[str]]:
|
|
"""
|
|
Get information about loaded models.
|
|
|
|
Returns:
|
|
Dictionary mapping model IDs to list of camera IDs using them
|
|
"""
|
|
result = {}
|
|
|
|
with self.models_lock:
|
|
for model_id in self.registry.get_loaded_models():
|
|
cameras = []
|
|
for camera_id, models in self.camera_models.items():
|
|
if model_id in models:
|
|
cameras.append(camera_id)
|
|
result[model_id] = cameras
|
|
|
|
return result
|
|
|
|
def get_model_stats(self) -> Dict[str, Any]:
|
|
"""
|
|
Get statistics about loaded models.
|
|
|
|
Returns:
|
|
Dictionary with model statistics
|
|
"""
|
|
with self.models_lock:
|
|
total_models = len(self.registry.get_loaded_models())
|
|
total_cameras = len(self.camera_models)
|
|
|
|
# Count total model instances
|
|
total_instances = sum(
|
|
len(models) for models in self.camera_models.values()
|
|
)
|
|
|
|
# Get cache size
|
|
cache_size = 0
|
|
if os.path.exists(self.models_dir):
|
|
for filename in os.listdir(self.models_dir):
|
|
filepath = os.path.join(self.models_dir, filename)
|
|
if os.path.isfile(filepath):
|
|
cache_size += os.path.getsize(filepath)
|
|
|
|
return {
|
|
"total_models": total_models,
|
|
"total_cameras": total_cameras,
|
|
"total_instances": total_instances,
|
|
"cache_size_mb": round(cache_size / (1024 * 1024), 2),
|
|
"models_dir": self.models_dir
|
|
}
|
|
|
|
|
|
# Global model manager instance
|
|
_model_manager = None
|
|
|
|
|
|
def get_model_manager() -> ModelManager:
|
|
"""Get or create the global model manager instance."""
|
|
global _model_manager
|
|
if _model_manager is None:
|
|
_model_manager = ModelManager()
|
|
return _model_manager
|
|
|
|
|
|
# Convenience functions for backward compatibility
|
|
def initialize_model_manager(models_dir: str = MODELS_DIR) -> ModelManager:
|
|
"""Initialize the global model manager."""
|
|
global _model_manager
|
|
_model_manager = ModelManager(models_dir)
|
|
return _model_manager |