Refactor: Phase 4: Communication Layer
This commit is contained in:
parent cdeaaf4a4f
commit 54f21672aa
6 changed files with 2876 additions and 0 deletions
486 detector_worker/models/pipeline_loader.py Normal file
@@ -0,0 +1,486 @@
"""
|
||||
Pipeline loader module.
|
||||
|
||||
This module handles loading and parsing of MPTA (Machine Learning Pipeline Archive)
|
||||
files, which contain model configurations and pipeline definitions.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import zipfile
|
||||
import tempfile
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from ..core.exceptions import ModelLoadError, PipelineError
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger("detector_worker.pipeline_loader")
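
# For orientation: an MPTA archive is a ZIP whose top level holds a
# pipeline.json plus the model files it references. A minimal illustrative
# pipeline.json (the values are made up; the key names are the ones parsed
# by _parse_pipeline_config / _parse_pipeline_node below):
#
#     {
#         "pipelineId": "vehicle-pipeline",
#         "version": "1.0",
#         "pipeline": {
#             "modelId": "car_detector",
#             "modelFile": "car_detector.pt",
#             "minConfidence": 0.6,
#             "crop": true,
#             "cropClass": "car",
#             "branches": [
#                 {"modelId": "brand_classifier", "modelFile": "brand_classifier.pt"}
#             ]
#         }
#     }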


@dataclass
class PipelineNode:
    """Represents a node in the pipeline tree."""
    model_id: str
    model_file: str
    model_path: Optional[str] = None
    model: Optional[Any] = None  # Loaded model instance

    # Node configuration
    multi_class: bool = False
    expected_classes: List[str] = field(default_factory=list)
    trigger_classes: List[str] = field(default_factory=list)
    min_confidence: float = 0.5
    max_detections: Optional[int] = None

    # Cropping configuration
    crop: bool = False
    crop_class: Optional[str] = None
    crop_expand_ratio: float = 1.0

    # Actions configuration
    actions: List[Dict[str, Any]] = field(default_factory=list)
    parallel_actions: List[Dict[str, Any]] = field(default_factory=list)

    # Branch configuration
    branches: List['PipelineNode'] = field(default_factory=list)
    parallel: bool = False

    # Detection settings
    yolo_settings: Dict[str, Any] = field(default_factory=dict)
    track_classes: Optional[List[str]] = None

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class PipelineConfig:
    """Pipeline configuration from pipeline.json."""
    pipeline_id: str
    version: str = "1.0"
    description: str = ""

    # Database configuration
    database_config: Optional[Dict[str, Any]] = None

    # Redis configuration
    redis_config: Optional[Dict[str, Any]] = None

    # Global settings
    global_settings: Dict[str, Any] = field(default_factory=dict)

    # Root pipeline node
    root: Optional[PipelineNode] = None


class PipelineLoader:
    """
    Loads and manages ML pipeline configurations.

    This class handles:
    - MPTA file extraction and parsing
    - Pipeline configuration validation
    - Model file management
    - Pipeline tree construction
    - Resource cleanup
    """

    def __init__(self, temp_dir: Optional[str] = None):
        """
        Initialize the pipeline loader.

        Args:
            temp_dir: Temporary directory for extracting MPTA files
        """
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.extracted_paths: Dict[str, str] = {}  # mpta_path -> extracted_dir
        self.loaded_models: Dict[str, Any] = {}    # model_path -> model_instance

    async def load_pipeline(self, mpta_path: str) -> PipelineNode:
        """
        Load a pipeline from an MPTA file.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Root pipeline node

        Raises:
            ModelLoadError: If loading fails
        """
        try:
            # Extract MPTA if not already extracted
            extracted_dir = await self._extract_mpta(mpta_path)

            # Load pipeline configuration
            pipeline_json_path = os.path.join(extracted_dir, "pipeline.json")
            if not os.path.exists(pipeline_json_path):
                raise ModelLoadError(f"pipeline.json not found in {mpta_path}")

            with open(pipeline_json_path, 'r') as f:
                config_data = json.load(f)

            # Parse pipeline configuration
            pipeline_config = self._parse_pipeline_config(config_data, extracted_dir)

            # Validate pipeline
            self._validate_pipeline(pipeline_config)

            # Load models for the pipeline
            await self._load_pipeline_models(pipeline_config.root, extracted_dir)

            logger.info(f"Successfully loaded pipeline from {mpta_path}")
            return pipeline_config.root

        except ModelLoadError:
            # Already a ModelLoadError; re-raise without double-wrapping
            raise
        except Exception as e:
            logger.error(f"Failed to load pipeline from {mpta_path}: {e}")
            raise ModelLoadError(f"Failed to load pipeline: {e}") from e

    async def _extract_mpta(self, mpta_path: str) -> str:
        """
        Extract MPTA file to temporary directory.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Path to extracted directory
        """
        # Check if already extracted
        if mpta_path in self.extracted_paths:
            extracted_dir = self.extracted_paths[mpta_path]
            if os.path.exists(extracted_dir):
                return extracted_dir

        # Create extraction directory
        mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
        extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")

        # Extract MPTA
        logger.info(f"Extracting MPTA file: {mpta_path}")

        try:
            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                # Clean up any existing directory first
                if os.path.exists(extracted_dir):
                    shutil.rmtree(extracted_dir)

                os.makedirs(extracted_dir)
                zip_ref.extractall(extracted_dir)

            self.extracted_paths[mpta_path] = extracted_dir
            logger.info(f"Extracted to: {extracted_dir}")

            return extracted_dir

        except Exception as e:
            raise ModelLoadError(f"Failed to extract MPTA: {e}") from e

    def _parse_pipeline_config(
        self,
        config_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineConfig:
        """
        Parse pipeline configuration from JSON.

        Args:
            config_data: Pipeline JSON data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline configuration
        """
        # Create pipeline config
        pipeline_config = PipelineConfig(
            pipeline_id=config_data.get("pipelineId", "unknown"),
            version=config_data.get("version", "1.0"),
            description=config_data.get("description", "")
        )

        # Parse database config
        if "database" in config_data:
            pipeline_config.database_config = config_data["database"]

        # Parse Redis config
        if "redis" in config_data:
            pipeline_config.redis_config = config_data["redis"]

        # Parse global settings
        if "globalSettings" in config_data:
            pipeline_config.global_settings = config_data["globalSettings"]

        # Parse pipeline tree
        if "pipeline" in config_data:
            pipeline_config.root = self._parse_pipeline_node(
                config_data["pipeline"], base_dir
            )
        elif "root" in config_data:
            pipeline_config.root = self._parse_pipeline_node(
                config_data["root"], base_dir
            )
        else:
            raise PipelineError("No pipeline or root node found in configuration")

        return pipeline_config

    def _parse_pipeline_node(
        self,
        node_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineNode:
        """
        Parse a pipeline node from configuration.

        Args:
            node_data: Node configuration data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline node
        """
        # Create node
        node = PipelineNode(
            model_id=node_data.get("modelId", ""),
            model_file=node_data.get("modelFile", "")
        )

        # Set model path
        if node.model_file:
            node.model_path = os.path.join(base_dir, node.model_file)

        # Parse configuration
        node.multi_class = node_data.get("multiClass", False)
        node.expected_classes = node_data.get("expectedClasses", [])
        node.trigger_classes = node_data.get("triggerClasses", [])
        node.min_confidence = node_data.get("minConfidence", 0.5)
        node.max_detections = node_data.get("maxDetections")

        # Parse cropping
        node.crop = node_data.get("crop", False)
        node.crop_class = node_data.get("cropClass")
        node.crop_expand_ratio = node_data.get("cropExpandRatio", 1.0)

        # Parse actions
        node.actions = node_data.get("actions", [])
        node.parallel_actions = node_data.get("parallelActions", [])

        # Parse YOLO settings
        if "yoloSettings" in node_data:
            node.yolo_settings = node_data["yoloSettings"]
        elif "detectionSettings" in node_data:
            node.yolo_settings = node_data["detectionSettings"]

        # Parse tracking
        node.track_classes = node_data.get("trackClasses")

        # Parse metadata
        node.metadata = node_data.get("metadata", {})

        # Parse branches
        branches_data = node_data.get("branches", [])
        node.parallel = node_data.get("parallel", False)

        for branch_data in branches_data:
            branch_node = self._parse_pipeline_node(branch_data, base_dir)
            node.branches.append(branch_node)

        return node

    def _validate_pipeline(self, pipeline_config: PipelineConfig) -> None:
        """
        Validate pipeline configuration.

        Args:
            pipeline_config: Pipeline configuration to validate

        Raises:
            PipelineError: If validation fails
        """
        if not pipeline_config.root:
            raise PipelineError("Pipeline has no root node")

        # Validate root node
        self._validate_node(pipeline_config.root)

    def _validate_node(self, node: PipelineNode) -> None:
        """
        Validate a pipeline node.

        Args:
            node: Node to validate

        Raises:
            PipelineError: If validation fails
        """
        # Check required fields
        if not node.model_id:
            raise PipelineError("Node missing modelId")

        if not node.model_file and not node.model:
            raise PipelineError(f"Node {node.model_id} missing modelFile")

        # Validate model path exists
        if node.model_path and not os.path.exists(node.model_path):
            raise PipelineError(f"Model file not found: {node.model_path}")

        # Validate cropping configuration
        if node.crop and not node.crop_class:
            raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass")

        # Validate confidence
        if not 0 <= node.min_confidence <= 1:
            raise PipelineError(f"Invalid minConfidence: {node.min_confidence}")

        # Validate branches
        for branch in node.branches:
            self._validate_node(branch)
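
    # Illustration (not executed here): given a node whose modelFile exists
    # in the archive but whose config sets "crop": true without a
    # "cropClass", _validate_node raises
    #     PipelineError("Node <modelId> has crop=true but no cropClass")
    # before any model weights are loaded, since validation runs ahead of
    # _load_pipeline_models in load_pipeline().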

    async def _load_pipeline_models(
        self,
        node: PipelineNode,
        base_dir: str
    ) -> None:
        """
        Load models for a pipeline node and its branches.

        Args:
            node: Pipeline node
            base_dir: Base directory for models
        """
        # Load model for this node if path is specified
        if node.model_path:
            node.model = await self._load_model(node.model_path, node.model_id)

        # Load models for branches
        for branch in node.branches:
            await self._load_pipeline_models(branch, base_dir)

    async def _load_model(self, model_path: str, model_id: str) -> Any:
        """
        Load a single model file.

        Args:
            model_path: Path to model file
            model_id: Model identifier

        Returns:
            Loaded model instance
        """
        # Check if already loaded
        if model_path in self.loaded_models:
            logger.info(f"Using cached model: {model_id}")
            return self.loaded_models[model_path]

        try:
            # Import here to avoid circular dependency
            from ultralytics import YOLO

            logger.info(f"Loading model: {model_id} from {model_path}")

            # Load YOLO model
            model = YOLO(model_path)

            # Cache the model
            self.loaded_models[model_path] = model

            return model

        except Exception as e:
            raise ModelLoadError(f"Failed to load model {model_id}: {e}") from e

    def cleanup_model(self, model_id: str) -> None:
        """
        Clean up resources for a specific model.

        Args:
            model_id: Model identifier to clean up
        """
        # Drop cached models whose path contains the model ID
        models_to_remove = []
        for path, model in self.loaded_models.items():
            if model_id in path:
                models_to_remove.append(path)

        for path in models_to_remove:
            self.loaded_models.pop(path, None)
            logger.info(f"Cleaned up model: {path}")

    def cleanup_all(self) -> None:
        """Clean up all resources."""
        # Clear loaded models
        self.loaded_models.clear()

        # Clean up extracted directories
        for mpta_path, extracted_dir in self.extracted_paths.items():
            if os.path.exists(extracted_dir):
                try:
                    shutil.rmtree(extracted_dir)
                    logger.info(f"Cleaned up extracted directory: {extracted_dir}")
                except Exception as e:
                    logger.error(f"Failed to clean up {extracted_dir}: {e}")

        self.extracted_paths.clear()

    def get_node_info(self, node: PipelineNode, level: int = 0) -> str:
        """
        Get formatted information about a pipeline node.

        Args:
            node: Pipeline node
            level: Indentation level

        Returns:
            Formatted node information
        """
        indent = "  " * level
        info = []

        info.append(f"{indent}Model: {node.model_id}")
        info.append(f"{indent}  File: {node.model_file}")
        info.append(f"{indent}  Multi-class: {node.multi_class}")

        if node.expected_classes:
            info.append(f"{indent}  Expected: {', '.join(node.expected_classes)}")
        if node.trigger_classes:
            info.append(f"{indent}  Triggers: {', '.join(node.trigger_classes)}")

        info.append(f"{indent}  Confidence: {node.min_confidence}")

        if node.crop:
            info.append(f"{indent}  Crop: {node.crop_class} (ratio: {node.crop_expand_ratio})")

        if node.actions:
            info.append(f"{indent}  Actions: {len(node.actions)}")
        if node.parallel_actions:
            info.append(f"{indent}  Parallel Actions: {len(node.parallel_actions)}")

        if node.branches:
            info.append(f"{indent}  Branches ({len(node.branches)}):")
            for branch in node.branches:
                info.append(self.get_node_info(branch, level + 2))

        return "\n".join(info)
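
    # For a detector with one classification branch, get_node_info() returns
    # a tree shaped like the following (values are illustrative, and exact
    # indentation depends on the level parameter):
    #
    #   Model: car_detector
    #     File: car_detector.pt
    #     Multi-class: False
    #     Triggers: car
    #     Confidence: 0.6
    #     Crop: car (ratio: 1.0)
    #     Branches (1):
    #         Model: brand_classifier
    #           File: brand_classifier.pt
    #           Multi-class: False
    #           Confidence: 0.5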


# Global pipeline loader instance
_pipeline_loader: Optional[PipelineLoader] = None


def get_pipeline_loader(temp_dir: Optional[str] = None) -> PipelineLoader:
    """
    Get or create the global pipeline loader instance.

    Note: temp_dir only takes effect on the first call, when the
    singleton is created.
    """
    global _pipeline_loader
    if _pipeline_loader is None:
        _pipeline_loader = PipelineLoader(temp_dir)
    return _pipeline_loader


# Convenience functions
async def load_pipeline_from_mpta(mpta_path: str) -> PipelineNode:
    """Load a pipeline from an MPTA file."""
    loader = get_pipeline_loader()
    return await loader.load_pipeline(mpta_path)
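
# A minimal usage sketch (illustrative only: "/models/detection.mpta" is a
# hypothetical path, and the asyncio wiring is an assumption, not part of
# this module's contract):
#
#     import asyncio
#
#     async def main():
#         root = await load_pipeline_from_mpta("/models/detection.mpta")
#         loader = get_pipeline_loader()
#         print(loader.get_node_info(root))
#         loader.cleanup_all()
#
#     asyncio.run(main())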