Refactor: Phase 4: Communication Layer

ziesorx 2025-09-12 15:26:31 +07:00
parent cdeaaf4a4f
commit 54f21672aa
6 changed files with 2876 additions and 0 deletions


@@ -0,0 +1,441 @@
"""
Model manager module.
This module handles ML model loading, caching, and lifecycle management
for the detection worker.
"""
import os
import time
import logging
import threading
from typing import Dict, Any, Optional, List, Set
from urllib.parse import urlparse
import traceback
from ..core.config import MODELS_DIR
from ..core.exceptions import ModelLoadError
# Setup logging
logger = logging.getLogger("detector_worker.model_manager")
class ModelRegistry:
"""
Registry for loaded models.
Maintains a reference count for each model to enable sharing
between multiple cameras.
"""
def __init__(self):
"""Initialize the model registry."""
self.models: Dict[str, Dict[str, Any]] = {} # model_id -> model_info
self.references: Dict[str, Set[str]] = {} # model_id -> set of camera_ids
self.lock = threading.Lock()
def register_model(self, model_id: str, model_data: Any, model_path: str) -> None:
"""
Register a model in the registry.
Args:
model_id: Unique model identifier
model_data: Loaded model data
model_path: Path to model file
"""
with self.lock:
self.models[model_id] = {
"model": model_data,
"path": model_path,
"loaded_at": os.path.getmtime(model_path)
}
if model_id not in self.references:
self.references[model_id] = set()
def add_reference(self, model_id: str, camera_id: str) -> None:
"""Add a reference to a model from a camera."""
with self.lock:
if model_id in self.references:
self.references[model_id].add(camera_id)
def remove_reference(self, model_id: str, camera_id: str) -> bool:
"""
Remove a reference to a model from a camera.
Returns:
True if model has no more references and can be unloaded
"""
with self.lock:
if model_id in self.references:
self.references[model_id].discard(camera_id)
return len(self.references[model_id]) == 0
return True
def get_model(self, model_id: str) -> Optional[Any]:
"""Get a model from the registry."""
with self.lock:
model_info = self.models.get(model_id)
return model_info["model"] if model_info else None
def unregister_model(self, model_id: str) -> None:
"""Remove a model from the registry."""
with self.lock:
self.models.pop(model_id, None)
self.references.pop(model_id, None)
def get_loaded_models(self) -> List[str]:
"""Get list of loaded model IDs."""
with self.lock:
return list(self.models.keys())
def get_reference_count(self, model_id: str) -> int:
"""Get number of references to a model."""
with self.lock:
return len(self.references.get(model_id, set()))
def clear(self) -> None:
"""Clear all models from registry."""
with self.lock:
self.models.clear()
self.references.clear()
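
# Illustrative sketch, not part of the original commit: how the registry's
# reference counting decides when a shared model can be unloaded. __file__ is
# used as a stand-in model path that is guaranteed to exist; any object can
# serve as the "model".
def _registry_refcount_example() -> None:
    registry = ModelRegistry()
    registry.register_model("yolo-car", object(), __file__)
    registry.add_reference("yolo-car", "camera-1")
    registry.add_reference("yolo-car", "camera-2")
    # One camera detaches: a reference remains, so the model stays loaded
    assert registry.remove_reference("yolo-car", "camera-1") is False
    # The last camera detaches: no references left, safe to unload
    assert registry.remove_reference("yolo-car", "camera-2") is True
    registry.unregister_model("yolo-car")
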
class ModelManager:
"""
Manages ML model loading, caching, and lifecycle.
This class handles:
- Model downloading and caching
- Model loading with proper error handling
- Reference counting for model sharing
- Model cleanup and memory management
- Pipeline model tree management
"""
def __init__(self, models_dir: str = MODELS_DIR):
"""
Initialize the model manager.
Args:
models_dir: Directory to cache downloaded models
"""
self.models_dir = models_dir
self.registry = ModelRegistry()
self.models_lock = threading.Lock()
# Camera to models mapping
self.camera_models: Dict[str, Dict[str, Any]] = {} # camera_id -> {model_id -> model_tree}
# Pipeline loader will be injected
self.pipeline_loader = None
# Create models directory if it doesn't exist
os.makedirs(self.models_dir, exist_ok=True)
def set_pipeline_loader(self, pipeline_loader: Any) -> None:
"""
Set the pipeline loader instance.
Args:
pipeline_loader: Pipeline loader to use for loading models
"""
self.pipeline_loader = pipeline_loader
async def load_model(
self,
camera_id: str,
model_id: str,
model_url: str,
force_reload: bool = False
) -> Any:
"""
Load a model for a specific camera.
Args:
camera_id: Camera identifier
model_id: Model identifier
model_url: URL or path to model file
force_reload: Force reload even if cached
Returns:
Loaded model tree
Raises:
ModelLoadError: If model loading fails
"""
if not self.pipeline_loader:
raise ModelLoadError("Pipeline loader not initialized")
try:
# Check if model is already loaded for this camera
with self.models_lock:
if camera_id in self.camera_models and model_id in self.camera_models[camera_id]:
if not force_reload:
logger.info(f"Model {model_id} already loaded for camera {camera_id}")
return self.camera_models[camera_id][model_id]
# Check if model is in registry
cached_model = self.registry.get_model(model_id)
if cached_model and not force_reload:
# Add reference and return cached model
self.registry.add_reference(model_id, camera_id)
with self.models_lock:
if camera_id not in self.camera_models:
self.camera_models[camera_id] = {}
self.camera_models[camera_id][model_id] = cached_model
logger.info(f"Using cached model {model_id} for camera {camera_id}")
return cached_model
# Download or locate model file
model_path = await self._get_model_path(model_url, model_id)
# Load model using pipeline loader
logger.info(f"Loading model {model_id} from {model_path}")
model_tree = await self.pipeline_loader.load_pipeline(model_path)
# Register in registry
self.registry.register_model(model_id, model_tree, model_path)
self.registry.add_reference(model_id, camera_id)
# Store in camera models
with self.models_lock:
if camera_id not in self.camera_models:
self.camera_models[camera_id] = {}
self.camera_models[camera_id][model_id] = model_tree
logger.info(f"Successfully loaded model {model_id} for camera {camera_id}")
return model_tree
except Exception as e:
logger.error(f"Failed to load model {model_id}: {e}")
traceback.print_exc()
raise ModelLoadError(f"Failed to load model {model_id}: {e}")
async def _get_model_path(self, model_url: str, model_id: str) -> str:
"""
Get local path for a model, downloading if necessary.
Args:
model_url: URL or local path to model
model_id: Model identifier
Returns:
Local file path to model
"""
# Check if it's already a local path
if os.path.exists(model_url):
return model_url
# Parse URL
parsed = urlparse(model_url)
# Check if it's a file:// URL
if parsed.scheme == 'file':
return parsed.path
# For HTTP/HTTPS URLs, download to cache
if parsed.scheme in ['http', 'https']:
# Generate cache filename
filename = os.path.basename(parsed.path)
if not filename:
filename = f"{model_id}.mpta"
cache_path = os.path.join(self.models_dir, filename)
# Check if already cached
if os.path.exists(cache_path):
logger.info(f"Using cached model file: {cache_path}")
return cache_path
# Download model
logger.info(f"Downloading model from {model_url}")
await self._download_model(model_url, cache_path)
return cache_path
# For other schemes or no scheme, assume local path
return model_url
async def _download_model(self, url: str, destination: str) -> None:
"""
Download a model file from URL.
Args:
url: URL to download from
destination: Local path to save to
"""
import aiohttp
import aiofiles
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
response.raise_for_status()
# Get total size if available
total_size = response.headers.get('Content-Length')
if total_size:
total_size = int(total_size)
logger.info(f"Downloading {total_size / (1024*1024):.2f} MB")
                    # Download to a temporary file first so a failed transfer
                    # never leaves a partial file at the final destination
                    temp_path = f"{destination}.tmp"
                    downloaded = 0
                    last_logged_mb = 0
                    async with aiofiles.open(temp_path, 'wb') as f:
                        async for chunk in response.content.iter_chunked(8192):
                            await f.write(chunk)
                            downloaded += len(chunk)
                            # Log once per megabyte; an exact modulo check would
                            # skip boundaries whenever a chunk comes in short
                            current_mb = downloaded // (1024 * 1024)
                            if total_size and current_mb > last_logged_mb:
                                last_logged_mb = current_mb
                                progress = (downloaded / total_size) * 100
                                logger.info(f"Download progress: {progress:.1f}%")
                    # Move into place; os.replace overwrites atomically even on
                    # Windows, where os.rename fails if the target exists
                    os.replace(temp_path, destination)
logger.info(f"Model downloaded successfully to {destination}")
except Exception as e:
# Clean up temporary file if exists
temp_path = f"{destination}.tmp"
if os.path.exists(temp_path):
os.remove(temp_path)
raise ModelLoadError(f"Failed to download model: {e}")
def get_model(self, camera_id: str, model_id: str) -> Optional[Any]:
"""
Get a loaded model for a camera.
Args:
camera_id: Camera identifier
model_id: Model identifier
Returns:
Model tree if loaded, None otherwise
"""
with self.models_lock:
camera_models = self.camera_models.get(camera_id, {})
return camera_models.get(model_id)
def unload_models(self, camera_id: str) -> None:
"""
Unload all models for a camera.
Args:
camera_id: Camera identifier
"""
with self.models_lock:
if camera_id not in self.camera_models:
return
# Remove references for each model
for model_id in self.camera_models[camera_id]:
should_unload = self.registry.remove_reference(model_id, camera_id)
if should_unload:
logger.info(f"Unloading model {model_id} (no more references)")
self.registry.unregister_model(model_id)
# Clean up model if pipeline loader supports it
if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_model'):
try:
self.pipeline_loader.cleanup_model(model_id)
except Exception as e:
logger.error(f"Error cleaning up model {model_id}: {e}")
# Remove camera entry
del self.camera_models[camera_id]
logger.info(f"Unloaded all models for camera {camera_id}")
    def cleanup_all_models(self) -> None:
        """Clean up all loaded models."""
        logger.info("Cleaning up all loaded models")
        # Snapshot camera IDs under the lock, then release it before calling
        # unload_models(), which re-acquires the same non-reentrant lock and
        # would otherwise deadlock
        with self.models_lock:
            cameras = list(self.camera_models.keys())
        # Unload models for each camera
        for camera_id in cameras:
            self.unload_models(camera_id)
        # Clear registry
        self.registry.clear()
# Clean up pipeline loader if it has cleanup
if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_all'):
try:
self.pipeline_loader.cleanup_all()
except Exception as e:
logger.error(f"Error in pipeline loader cleanup: {e}")
logger.info("Model cleanup completed")
def get_loaded_models(self) -> Dict[str, List[str]]:
"""
Get information about loaded models.
Returns:
Dictionary mapping model IDs to list of camera IDs using them
"""
result = {}
with self.models_lock:
for model_id in self.registry.get_loaded_models():
cameras = []
for camera_id, models in self.camera_models.items():
if model_id in models:
cameras.append(camera_id)
result[model_id] = cameras
return result
def get_model_stats(self) -> Dict[str, Any]:
"""
Get statistics about loaded models.
Returns:
Dictionary with model statistics
"""
with self.models_lock:
total_models = len(self.registry.get_loaded_models())
total_cameras = len(self.camera_models)
# Count total model instances
total_instances = sum(
len(models) for models in self.camera_models.values()
)
# Get cache size
cache_size = 0
if os.path.exists(self.models_dir):
for filename in os.listdir(self.models_dir):
filepath = os.path.join(self.models_dir, filename)
if os.path.isfile(filepath):
cache_size += os.path.getsize(filepath)
return {
"total_models": total_models,
"total_cameras": total_cameras,
"total_instances": total_instances,
"cache_size_mb": round(cache_size / (1024 * 1024), 2),
"models_dir": self.models_dir
}
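
# Illustrative sketch, not part of the original commit: wiring a minimal stub
# pipeline loader into ModelManager to exercise the load/share/unload flow
# without real model files. The stub class and the /tmp/models directory are
# assumptions; __file__ again stands in for a model path that exists.
class _StubPipelineLoader:
    """Minimal stand-in for the real pipeline loader used in the sketch."""
    async def load_pipeline(self, model_path: str) -> Dict[str, Any]:
        # Stand-in for MPTA parsing; any object works as a "model tree"
        return {"path": model_path}

async def _model_manager_example() -> None:
    manager = ModelManager(models_dir="/tmp/models")
    manager.set_pipeline_loader(_StubPipelineLoader())
    # The first load goes through the stub; the second camera reuses the
    # registry-cached tree instead of loading again
    tree1 = await manager.load_model("camera-1", "det-v1", __file__)
    tree2 = await manager.load_model("camera-2", "det-v1", __file__)
    assert tree1 is tree2
    manager.unload_models("camera-1")  # model kept: camera-2 still holds it
    manager.unload_models("camera-2")  # last reference gone: model unloaded
# Run with: asyncio.run(_model_manager_example())
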
# Global model manager instance
_model_manager: Optional[ModelManager] = None
def get_model_manager() -> ModelManager:
"""Get or create the global model manager instance."""
global _model_manager
if _model_manager is None:
_model_manager = ModelManager()
return _model_manager
# Convenience functions for backward compatibility
def initialize_model_manager(models_dir: str = MODELS_DIR) -> ModelManager:
"""Initialize the global model manager."""
global _model_manager
_model_manager = ModelManager(models_dir)
return _model_manager
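
# Illustrative sketch, not part of the original commit: the worker is expected
# to call initialize_model_manager() once at startup (the models_dir value here
# is an assumption) and fetch the shared instance everywhere else.
def _singleton_example() -> None:
    manager = initialize_model_manager(models_dir="/tmp/models")
    assert get_model_manager() is manager  # later callers share the instance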


@@ -0,0 +1,486 @@
"""
Pipeline loader module.
This module handles loading and parsing of MPTA (Machine Learning Pipeline Archive)
files, which contain model configurations and pipeline definitions.
"""
import os
import json
import logging
import zipfile
import tempfile
import shutil
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
from ..core.exceptions import ModelLoadError, PipelineError
# Setup logging
logger = logging.getLogger("detector_worker.pipeline_loader")
@dataclass
class PipelineNode:
"""Represents a node in the pipeline tree."""
model_id: str
model_file: str
model_path: Optional[str] = None
model: Optional[Any] = None # Loaded model instance
# Node configuration
multi_class: bool = False
expected_classes: List[str] = field(default_factory=list)
trigger_classes: List[str] = field(default_factory=list)
min_confidence: float = 0.5
max_detections: Optional[int] = None
# Cropping configuration
crop: bool = False
crop_class: Optional[str] = None
crop_expand_ratio: float = 1.0
# Actions configuration
actions: List[Dict[str, Any]] = field(default_factory=list)
parallel_actions: List[Dict[str, Any]] = field(default_factory=list)
# Branch configuration
branches: List['PipelineNode'] = field(default_factory=list)
parallel: bool = False
# Detection settings
yolo_settings: Dict[str, Any] = field(default_factory=dict)
track_classes: Optional[List[str]] = None
# Metadata
metadata: Dict[str, Any] = field(default_factory=dict)
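
# Illustrative sketch, not part of the original commit: a two-level tree built
# by hand, mirroring what _parse_pipeline_node() produces. A detector crops
# each detected car and hands the crop to a classifier branch; every model ID
# and file name here is invented.
def _example_pipeline_tree() -> PipelineNode:
    classifier = PipelineNode(
        model_id="car-brand-cls",
        model_file="brand_cls.pt",
        trigger_classes=["car"],
        min_confidence=0.7,
    )
    return PipelineNode(
        model_id="car-detector",
        model_file="car_det.pt",
        expected_classes=["car"],
        crop=True,
        crop_class="car",
        crop_expand_ratio=1.2,
        branches=[classifier],
    )
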
@dataclass
class PipelineConfig:
"""Pipeline configuration from pipeline.json."""
pipeline_id: str
version: str = "1.0"
description: str = ""
# Database configuration
database_config: Optional[Dict[str, Any]] = None
# Redis configuration
redis_config: Optional[Dict[str, Any]] = None
# Global settings
global_settings: Dict[str, Any] = field(default_factory=dict)
# Root pipeline node
root: Optional[PipelineNode] = None
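
# Illustrative sketch, not part of the original commit: the pipeline.json shape
# that _parse_pipeline_config() below understands, written as a Python dict.
# Key names mirror the parser's lookups; the concrete values are invented.
_EXAMPLE_PIPELINE_JSON: Dict[str, Any] = {
    "pipelineId": "forecourt-monitor",
    "version": "1.0",
    "redis": {"host": "localhost", "port": 6379},
    "pipeline": {
        "modelId": "car-detector",
        "modelFile": "car_det.pt",
        "minConfidence": 0.6,
        "crop": True,
        "cropClass": "car",
        "branches": [
            {"modelId": "car-brand-cls", "modelFile": "brand_cls.pt"},
        ],
    },
}
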
class PipelineLoader:
"""
Loads and manages ML pipeline configurations.
This class handles:
- MPTA file extraction and parsing
- Pipeline configuration validation
- Model file management
- Pipeline tree construction
- Resource cleanup
"""
def __init__(self, temp_dir: Optional[str] = None):
"""
Initialize the pipeline loader.
Args:
temp_dir: Temporary directory for extracting MPTA files
"""
self.temp_dir = temp_dir or tempfile.gettempdir()
self.extracted_paths: Dict[str, str] = {} # mpta_path -> extracted_dir
self.loaded_models: Dict[str, Any] = {} # model_path -> model_instance
async def load_pipeline(self, mpta_path: str) -> PipelineNode:
"""
Load a pipeline from an MPTA file.
Args:
mpta_path: Path to MPTA file
Returns:
Root pipeline node
Raises:
ModelLoadError: If loading fails
"""
try:
# Extract MPTA if not already extracted
extracted_dir = await self._extract_mpta(mpta_path)
# Load pipeline configuration
pipeline_json_path = os.path.join(extracted_dir, "pipeline.json")
if not os.path.exists(pipeline_json_path):
raise ModelLoadError(f"pipeline.json not found in {mpta_path}")
with open(pipeline_json_path, 'r') as f:
config_data = json.load(f)
# Parse pipeline configuration
pipeline_config = self._parse_pipeline_config(config_data, extracted_dir)
# Validate pipeline
self._validate_pipeline(pipeline_config)
# Load models for the pipeline
await self._load_pipeline_models(pipeline_config.root, extracted_dir)
logger.info(f"Successfully loaded pipeline from {mpta_path}")
return pipeline_config.root
except Exception as e:
logger.error(f"Failed to load pipeline from {mpta_path}: {e}")
raise ModelLoadError(f"Failed to load pipeline: {e}")
async def _extract_mpta(self, mpta_path: str) -> str:
"""
Extract MPTA file to temporary directory.
Args:
mpta_path: Path to MPTA file
Returns:
Path to extracted directory
"""
# Check if already extracted
if mpta_path in self.extracted_paths:
extracted_dir = self.extracted_paths[mpta_path]
if os.path.exists(extracted_dir):
return extracted_dir
# Create extraction directory
mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")
# Extract MPTA
logger.info(f"Extracting MPTA file: {mpta_path}")
try:
with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
# Clean existing directory if exists
if os.path.exists(extracted_dir):
shutil.rmtree(extracted_dir)
os.makedirs(extracted_dir)
zip_ref.extractall(extracted_dir)
self.extracted_paths[mpta_path] = extracted_dir
logger.info(f"Extracted to: {extracted_dir}")
return extracted_dir
except Exception as e:
raise ModelLoadError(f"Failed to extract MPTA: {e}")
def _parse_pipeline_config(
self,
config_data: Dict[str, Any],
base_dir: str
) -> PipelineConfig:
"""
Parse pipeline configuration from JSON.
Args:
config_data: Pipeline JSON data
base_dir: Base directory for model files
Returns:
Parsed pipeline configuration
"""
# Create pipeline config
pipeline_config = PipelineConfig(
pipeline_id=config_data.get("pipelineId", "unknown"),
version=config_data.get("version", "1.0"),
description=config_data.get("description", "")
)
# Parse database config
if "database" in config_data:
pipeline_config.database_config = config_data["database"]
# Parse Redis config
if "redis" in config_data:
pipeline_config.redis_config = config_data["redis"]
# Parse global settings
if "globalSettings" in config_data:
pipeline_config.global_settings = config_data["globalSettings"]
# Parse pipeline tree
if "pipeline" in config_data:
pipeline_config.root = self._parse_pipeline_node(
config_data["pipeline"], base_dir
)
elif "root" in config_data:
pipeline_config.root = self._parse_pipeline_node(
config_data["root"], base_dir
)
else:
raise PipelineError("No pipeline or root node found in configuration")
return pipeline_config
def _parse_pipeline_node(
self,
node_data: Dict[str, Any],
base_dir: str
) -> PipelineNode:
"""
Parse a pipeline node from configuration.
Args:
node_data: Node configuration data
base_dir: Base directory for model files
Returns:
Parsed pipeline node
"""
# Create node
node = PipelineNode(
model_id=node_data.get("modelId", ""),
model_file=node_data.get("modelFile", "")
)
# Set model path
if node.model_file:
node.model_path = os.path.join(base_dir, node.model_file)
# Parse configuration
node.multi_class = node_data.get("multiClass", False)
node.expected_classes = node_data.get("expectedClasses", [])
node.trigger_classes = node_data.get("triggerClasses", [])
node.min_confidence = node_data.get("minConfidence", 0.5)
node.max_detections = node_data.get("maxDetections")
# Parse cropping
node.crop = node_data.get("crop", False)
node.crop_class = node_data.get("cropClass")
node.crop_expand_ratio = node_data.get("cropExpandRatio", 1.0)
# Parse actions
node.actions = node_data.get("actions", [])
node.parallel_actions = node_data.get("parallelActions", [])
# Parse YOLO settings
if "yoloSettings" in node_data:
node.yolo_settings = node_data["yoloSettings"]
elif "detectionSettings" in node_data:
node.yolo_settings = node_data["detectionSettings"]
# Parse tracking
node.track_classes = node_data.get("trackClasses")
# Parse metadata
node.metadata = node_data.get("metadata", {})
# Parse branches
branches_data = node_data.get("branches", [])
node.parallel = node_data.get("parallel", False)
for branch_data in branches_data:
branch_node = self._parse_pipeline_node(branch_data, base_dir)
node.branches.append(branch_node)
return node
def _validate_pipeline(self, pipeline_config: PipelineConfig) -> None:
"""
Validate pipeline configuration.
Args:
pipeline_config: Pipeline configuration to validate
Raises:
PipelineError: If validation fails
"""
if not pipeline_config.root:
raise PipelineError("Pipeline has no root node")
# Validate root node
self._validate_node(pipeline_config.root)
def _validate_node(self, node: PipelineNode) -> None:
"""
Validate a pipeline node.
Args:
node: Node to validate
Raises:
PipelineError: If validation fails
"""
# Check required fields
if not node.model_id:
raise PipelineError("Node missing modelId")
if not node.model_file and not node.model:
raise PipelineError(f"Node {node.model_id} missing modelFile")
# Validate model path exists
if node.model_path and not os.path.exists(node.model_path):
raise PipelineError(f"Model file not found: {node.model_path}")
# Validate cropping configuration
if node.crop and not node.crop_class:
raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass")
# Validate confidence
if not 0 <= node.min_confidence <= 1:
raise PipelineError(f"Invalid minConfidence: {node.min_confidence}")
# Validate branches
for branch in node.branches:
self._validate_node(branch)
async def _load_pipeline_models(
self,
node: PipelineNode,
base_dir: str
) -> None:
"""
Load models for a pipeline node and its branches.
Args:
node: Pipeline node
base_dir: Base directory for models
"""
# Load model for this node if path is specified
if node.model_path:
node.model = await self._load_model(node.model_path, node.model_id)
# Load models for branches
for branch in node.branches:
await self._load_pipeline_models(branch, base_dir)
async def _load_model(self, model_path: str, model_id: str) -> Any:
"""
Load a single model file.
Args:
model_path: Path to model file
model_id: Model identifier
Returns:
Loaded model instance
"""
# Check if already loaded
if model_path in self.loaded_models:
logger.info(f"Using cached model: {model_id}")
return self.loaded_models[model_path]
try:
# Import here to avoid circular dependency
from ultralytics import YOLO
logger.info(f"Loading model: {model_id} from {model_path}")
# Load YOLO model
model = YOLO(model_path)
# Cache the model
self.loaded_models[model_path] = model
return model
except Exception as e:
raise ModelLoadError(f"Failed to load model {model_id}: {e}")
def cleanup_model(self, model_id: str) -> None:
"""
Clean up resources for a specific model.
Args:
model_id: Model identifier to clean up
"""
        # Cache keys are file paths, so match the model ID by substring;
        # collect first to avoid mutating the dict while iterating it
        models_to_remove = [
            path for path in self.loaded_models if model_id in path
        ]
        for path in models_to_remove:
            self.loaded_models.pop(path, None)
            logger.info(f"Cleaned up model: {path}")
def cleanup_all(self) -> None:
"""Clean up all resources."""
# Clear loaded models
self.loaded_models.clear()
# Clean up extracted directories
for mpta_path, extracted_dir in self.extracted_paths.items():
if os.path.exists(extracted_dir):
try:
shutil.rmtree(extracted_dir)
logger.info(f"Cleaned up extracted directory: {extracted_dir}")
except Exception as e:
logger.error(f"Failed to clean up {extracted_dir}: {e}")
self.extracted_paths.clear()
def get_node_info(self, node: PipelineNode, level: int = 0) -> str:
"""
Get formatted information about a pipeline node.
Args:
node: Pipeline node
level: Indentation level
Returns:
Formatted node information
"""
indent = " " * level
info = []
info.append(f"{indent}Model: {node.model_id}")
info.append(f"{indent} File: {node.model_file}")
info.append(f"{indent} Multi-class: {node.multi_class}")
if node.expected_classes:
info.append(f"{indent} Expected: {', '.join(node.expected_classes)}")
if node.trigger_classes:
info.append(f"{indent} Triggers: {', '.join(node.trigger_classes)}")
info.append(f"{indent} Confidence: {node.min_confidence}")
if node.crop:
info.append(f"{indent} Crop: {node.crop_class} (ratio: {node.crop_expand_ratio})")
if node.actions:
info.append(f"{indent} Actions: {len(node.actions)}")
if node.parallel_actions:
info.append(f"{indent} Parallel Actions: {len(node.parallel_actions)}")
if node.branches:
info.append(f"{indent} Branches ({len(node.branches)}):")
for branch in node.branches:
info.append(self.get_node_info(branch, level + 2))
return "\n".join(info)
# Global pipeline loader instance
_pipeline_loader: Optional[PipelineLoader] = None
def get_pipeline_loader(temp_dir: Optional[str] = None) -> PipelineLoader:
"""Get or create the global pipeline loader instance."""
global _pipeline_loader
if _pipeline_loader is None:
_pipeline_loader = PipelineLoader(temp_dir)
return _pipeline_loader
# Convenience functions
async def load_pipeline_from_mpta(mpta_path: str) -> PipelineNode:
"""Load a pipeline from an MPTA file."""
loader = get_pipeline_loader()
return await loader.load_pipeline(mpta_path)
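
# Illustrative sketch, not part of the original commit: end-to-end loading of
# a real archive. The .mpta path below is a placeholder; the archive must
# contain pipeline.json plus every model file it references.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        root = await load_pipeline_from_mpta("/data/models/detector.mpta")
        print(get_pipeline_loader().get_node_info(root))

    asyncio.run(_demo())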