"""
|
|
Pipeline loader module.
|
|
|
|
This module handles loading and parsing of MPTA (Machine Learning Pipeline Archive)
|
|
files, which contain model configurations and pipeline definitions.
|
|
"""

import os
import json
import logging
import zipfile
import tempfile
import shutil
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field

from ..core.exceptions import ModelLoadError, PipelineError

# Setup logging
logger = logging.getLogger("detector_worker.pipeline_loader")


@dataclass
class PipelineNode:
    """Represents a node in the pipeline tree."""
    model_id: str
    model_file: str
    model_path: Optional[str] = None
    model: Optional[Any] = None  # Loaded model instance

    # Node configuration
    multi_class: bool = False
    expected_classes: List[str] = field(default_factory=list)
    trigger_classes: List[str] = field(default_factory=list)
    min_confidence: float = 0.5
    max_detections: Optional[int] = None

    # Cropping configuration
    crop: bool = False
    crop_class: Optional[str] = None
    crop_expand_ratio: float = 1.0

    # Actions configuration
    actions: List[Dict[str, Any]] = field(default_factory=list)
    parallel_actions: List[Dict[str, Any]] = field(default_factory=list)

    # Branch configuration
    branches: List['PipelineNode'] = field(default_factory=list)
    parallel: bool = False

    # Detection settings
    yolo_settings: Dict[str, Any] = field(default_factory=dict)
    track_classes: Optional[List[str]] = None

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)
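
# A minimal hand-built node tree (illustrative; in practice nodes are produced
# by PipelineLoader._parse_pipeline_node from pipeline.json rather than
# constructed by hand):
#
#     root = PipelineNode(model_id="detector", model_file="detector.pt",
#                         min_confidence=0.6, crop=True, crop_class="car")
#     root.branches.append(
#         PipelineNode(model_id="classifier", model_file="classifier.pt"))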


@dataclass
class PipelineConfig:
    """Pipeline configuration from pipeline.json."""
    pipeline_id: str
    version: str = "1.0"
    description: str = ""

    # Database configuration
    database_config: Optional[Dict[str, Any]] = None

    # Redis configuration
    redis_config: Optional[Dict[str, Any]] = None

    # Global settings
    global_settings: Dict[str, Any] = field(default_factory=dict)

    # Root pipeline node
    root: Optional[PipelineNode] = None


class PipelineLoader:
    """
    Loads and manages ML pipeline configurations.

    This class handles:
    - MPTA file extraction and parsing
    - Pipeline configuration validation
    - Model file management
    - Pipeline tree construction
    - Resource cleanup
    """

    def __init__(self, temp_dir: Optional[str] = None):
        """
        Initialize the pipeline loader.

        Args:
            temp_dir: Temporary directory for extracting MPTA files
        """
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.extracted_paths: Dict[str, str] = {}  # mpta_path -> extracted_dir
        self.loaded_models: Dict[str, Any] = {}  # model_path -> model_instance

    async def load_pipeline(self, mpta_path: str) -> PipelineNode:
        """
        Load a pipeline from an MPTA file.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Root pipeline node

        Raises:
            ModelLoadError: If loading fails
        """
        try:
            # Extract MPTA if not already extracted
            extracted_dir = await self._extract_mpta(mpta_path)

            # Load pipeline configuration
            pipeline_json_path = os.path.join(extracted_dir, "pipeline.json")
            if not os.path.exists(pipeline_json_path):
                raise ModelLoadError(f"pipeline.json not found in {mpta_path}")

            with open(pipeline_json_path, 'r') as f:
                config_data = json.load(f)

            # Parse pipeline configuration
            pipeline_config = self._parse_pipeline_config(config_data, extracted_dir)

            # Validate pipeline
            self._validate_pipeline(pipeline_config)

            # Load models for the pipeline
            await self._load_pipeline_models(pipeline_config.root, extracted_dir)

            logger.info(f"Successfully loaded pipeline from {mpta_path}")
            return pipeline_config.root

        except ModelLoadError:
            # Re-raise our own errors without double-wrapping them
            logger.error(f"Failed to load pipeline from {mpta_path}")
            raise
        except Exception as e:
            logger.error(f"Failed to load pipeline from {mpta_path}: {e}")
            raise ModelLoadError(f"Failed to load pipeline: {e}") from e

    async def _extract_mpta(self, mpta_path: str) -> str:
        """
        Extract an MPTA file to the temporary directory.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Path to extracted directory
        """
        # Check if already extracted
        if mpta_path in self.extracted_paths:
            extracted_dir = self.extracted_paths[mpta_path]
            if os.path.exists(extracted_dir):
                return extracted_dir

        # Create extraction directory
        mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
        extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")

        # Extract MPTA
        logger.info(f"Extracting MPTA file: {mpta_path}")

        try:
            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                # Clean up any stale directory from a previous extraction
                if os.path.exists(extracted_dir):
                    shutil.rmtree(extracted_dir)

                os.makedirs(extracted_dir)
                zip_ref.extractall(extracted_dir)

            self.extracted_paths[mpta_path] = extracted_dir
            logger.info(f"Extracted to: {extracted_dir}")

            return extracted_dir

        except Exception as e:
            raise ModelLoadError(f"Failed to extract MPTA: {e}") from e

    def _parse_pipeline_config(
        self,
        config_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineConfig:
        """
        Parse pipeline configuration from JSON.

        Args:
            config_data: Pipeline JSON data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline configuration
        """
        # Create pipeline config
        pipeline_config = PipelineConfig(
            pipeline_id=config_data.get("pipelineId", "unknown"),
            version=config_data.get("version", "1.0"),
            description=config_data.get("description", "")
        )

        # Parse database config
        if "database" in config_data:
            pipeline_config.database_config = config_data["database"]

        # Parse Redis config
        if "redis" in config_data:
            pipeline_config.redis_config = config_data["redis"]

        # Parse global settings
        if "globalSettings" in config_data:
            pipeline_config.global_settings = config_data["globalSettings"]

        # Parse pipeline tree
        if "pipeline" in config_data:
            pipeline_config.root = self._parse_pipeline_node(
                config_data["pipeline"], base_dir
            )
        elif "root" in config_data:
            pipeline_config.root = self._parse_pipeline_node(
                config_data["root"], base_dir
            )
        else:
            raise PipelineError("No pipeline or root node found in configuration")

        return pipeline_config

    def _parse_pipeline_node(
        self,
        node_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineNode:
        """
        Parse a pipeline node from configuration.

        Args:
            node_data: Node configuration data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline node
        """
        # Create node
        node = PipelineNode(
            model_id=node_data.get("modelId", ""),
            model_file=node_data.get("modelFile", "")
        )

        # Set model path
        if node.model_file:
            node.model_path = os.path.join(base_dir, node.model_file)

        # Parse configuration
        node.multi_class = node_data.get("multiClass", False)
        node.expected_classes = node_data.get("expectedClasses", [])
        node.trigger_classes = node_data.get("triggerClasses", [])
        node.min_confidence = node_data.get("minConfidence", 0.5)
        node.max_detections = node_data.get("maxDetections")

        # Parse cropping
        node.crop = node_data.get("crop", False)
        node.crop_class = node_data.get("cropClass")
        node.crop_expand_ratio = node_data.get("cropExpandRatio", 1.0)

        # Parse actions
        node.actions = node_data.get("actions", [])
        node.parallel_actions = node_data.get("parallelActions", [])

        # Parse YOLO settings ("yoloSettings" is preferred; "detectionSettings"
        # is accepted as an alternative key)
        if "yoloSettings" in node_data:
            node.yolo_settings = node_data["yoloSettings"]
        elif "detectionSettings" in node_data:
            node.yolo_settings = node_data["detectionSettings"]

        # Parse tracking
        node.track_classes = node_data.get("trackClasses")

        # Parse metadata
        node.metadata = node_data.get("metadata", {})

        # Parse branches recursively
        branches_data = node_data.get("branches", [])
        node.parallel = node_data.get("parallel", False)

        for branch_data in branches_data:
            branch_node = self._parse_pipeline_node(branch_data, base_dir)
            node.branches.append(branch_node)

        return node

    def _validate_pipeline(self, pipeline_config: PipelineConfig) -> None:
        """
        Validate pipeline configuration.

        Args:
            pipeline_config: Pipeline configuration to validate

        Raises:
            PipelineError: If validation fails
        """
        if not pipeline_config.root:
            raise PipelineError("Pipeline has no root node")

        # Validate root node
        self._validate_node(pipeline_config.root)

    def _validate_node(self, node: PipelineNode) -> None:
        """
        Validate a pipeline node.

        Args:
            node: Node to validate

        Raises:
            PipelineError: If validation fails
        """
        # Check required fields
        if not node.model_id:
            raise PipelineError("Node missing modelId")

        if not node.model_file and not node.model:
            raise PipelineError(f"Node {node.model_id} missing modelFile")

        # Validate that the model path exists
        if node.model_path and not os.path.exists(node.model_path):
            raise PipelineError(f"Model file not found: {node.model_path}")

        # Validate cropping configuration
        if node.crop and not node.crop_class:
            raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass")

        # Validate confidence
        if not 0 <= node.min_confidence <= 1:
            raise PipelineError(f"Invalid minConfidence: {node.min_confidence}")

        # Validate branches recursively
        for branch in node.branches:
            self._validate_node(branch)

    async def _load_pipeline_models(
        self,
        node: PipelineNode,
        base_dir: str
    ) -> None:
        """
        Load models for a pipeline node and its branches.

        Args:
            node: Pipeline node
            base_dir: Base directory for models
        """
        # Load model for this node if a path is specified
        if node.model_path:
            node.model = await self._load_model(node.model_path, node.model_id)

        # Load models for branches
        for branch in node.branches:
            await self._load_pipeline_models(branch, base_dir)

    async def _load_model(self, model_path: str, model_id: str) -> Any:
        """
        Load a single model file.

        Args:
            model_path: Path to model file
            model_id: Model identifier

        Returns:
            Loaded model instance
        """
        # Check if already loaded
        if model_path in self.loaded_models:
            logger.info(f"Using cached model: {model_id}")
            return self.loaded_models[model_path]

        try:
            # Import here to avoid a circular dependency
            from ultralytics import YOLO

            logger.info(f"Loading model: {model_id} from {model_path}")

            # Load YOLO model
            model = YOLO(model_path)

            # Cache the model
            self.loaded_models[model_path] = model

            return model

        except Exception as e:
            raise ModelLoadError(f"Failed to load model {model_id}: {e}") from e

    def cleanup_model(self, model_id: str) -> None:
        """
        Clean up resources for a specific model.

        Args:
            model_id: Model identifier to clean up
        """
        # Drop cached models whose path contains the model id
        models_to_remove = [
            path for path in self.loaded_models if model_id in path
        ]

        for path in models_to_remove:
            self.loaded_models.pop(path, None)
            logger.info(f"Cleaned up model: {path}")

    def cleanup_all(self) -> None:
        """Clean up all resources."""
        # Clear loaded models
        self.loaded_models.clear()

        # Clean up extracted directories
        for extracted_dir in self.extracted_paths.values():
            if os.path.exists(extracted_dir):
                try:
                    shutil.rmtree(extracted_dir)
                    logger.info(f"Cleaned up extracted directory: {extracted_dir}")
                except Exception as e:
                    logger.error(f"Failed to clean up {extracted_dir}: {e}")

        self.extracted_paths.clear()

    def get_node_info(self, node: PipelineNode, level: int = 0) -> str:
        """
        Get formatted information about a pipeline node.

        Args:
            node: Pipeline node
            level: Indentation level

        Returns:
            Formatted node information
        """
        indent = "  " * level
        info = []

        info.append(f"{indent}Model: {node.model_id}")
        info.append(f"{indent}  File: {node.model_file}")
        info.append(f"{indent}  Multi-class: {node.multi_class}")

        if node.expected_classes:
            info.append(f"{indent}  Expected: {', '.join(node.expected_classes)}")
        if node.trigger_classes:
            info.append(f"{indent}  Triggers: {', '.join(node.trigger_classes)}")

        info.append(f"{indent}  Confidence: {node.min_confidence}")

        if node.crop:
            info.append(f"{indent}  Crop: {node.crop_class} (ratio: {node.crop_expand_ratio})")

        if node.actions:
            info.append(f"{indent}  Actions: {len(node.actions)}")
        if node.parallel_actions:
            info.append(f"{indent}  Parallel Actions: {len(node.parallel_actions)}")

        if node.branches:
            info.append(f"{indent}  Branches ({len(node.branches)}):")
            for branch in node.branches:
                info.append(self.get_node_info(branch, level + 2))

        return "\n".join(info)
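
# get_node_info() produces an indented summary roughly like the following
# (illustrative output for a detector node with one classifier branch):
#
#     Model: detector
#       File: detector.pt
#       Multi-class: False
#       Confidence: 0.6
#       Branches (1):
#         Model: classifier
#           File: classifier.pt
#           Multi-class: False
#           Confidence: 0.5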


# Global pipeline loader instance
_pipeline_loader: Optional[PipelineLoader] = None


def get_pipeline_loader(temp_dir: Optional[str] = None) -> PipelineLoader:
    """
    Get or create the global pipeline loader instance.

    Note: temp_dir only takes effect on the first call, when the
    singleton is created.
    """
    global _pipeline_loader
    if _pipeline_loader is None:
        _pipeline_loader = PipelineLoader(temp_dir)
    return _pipeline_loader


# Convenience functions
async def load_pipeline_from_mpta(mpta_path: str) -> PipelineNode:
    """Load a pipeline from an MPTA file."""
    loader = get_pipeline_loader()
    return await loader.load_pipeline(mpta_path)
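

if __name__ == "__main__":
    # Minimal usage sketch: load a pipeline, print its tree, then release
    # resources. The .mpta path below is a placeholder for illustration,
    # not a file shipped with this module.
    import asyncio

    async def _demo() -> None:
        root = await load_pipeline_from_mpta("models/example.mpta")
        print(get_pipeline_loader().get_node_info(root))
        get_pipeline_loader().cleanup_all()

    asyncio.run(_demo())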