""" Pipeline loader module. This module handles loading and parsing of MPTA (Machine Learning Pipeline Archive) files, which contain model configurations and pipeline definitions. """ import os import json import logging import zipfile import tempfile import shutil from typing import Dict, Any, Optional, List, Tuple from dataclasses import dataclass, field from pathlib import Path from ..core.exceptions import ModelLoadError, PipelineError # Setup logging logger = logging.getLogger("detector_worker.pipeline_loader") @dataclass class PipelineNode: """Represents a node in the pipeline tree.""" model_id: str model_file: str model_path: Optional[str] = None model: Optional[Any] = None # Loaded model instance # Node configuration multi_class: bool = False expected_classes: List[str] = field(default_factory=list) trigger_classes: List[str] = field(default_factory=list) min_confidence: float = 0.5 max_detections: Optional[int] = None # Cropping configuration crop: bool = False crop_class: Optional[str] = None crop_expand_ratio: float = 1.0 # Actions configuration actions: List[Dict[str, Any]] = field(default_factory=list) parallel_actions: List[Dict[str, Any]] = field(default_factory=list) # Branch configuration branches: List['PipelineNode'] = field(default_factory=list) parallel: bool = False # Detection settings yolo_settings: Dict[str, Any] = field(default_factory=dict) track_classes: Optional[List[str]] = None # Metadata metadata: Dict[str, Any] = field(default_factory=dict) @dataclass class PipelineConfig: """Pipeline configuration from pipeline.json.""" pipeline_id: str version: str = "1.0" description: str = "" # Database configuration database_config: Optional[Dict[str, Any]] = None # Redis configuration redis_config: Optional[Dict[str, Any]] = None # Global settings global_settings: Dict[str, Any] = field(default_factory=dict) # Root pipeline node root: Optional[PipelineNode] = None class PipelineLoader: """ Loads and manages ML pipeline configurations. This class handles: - MPTA file extraction and parsing - Pipeline configuration validation - Model file management - Pipeline tree construction - Resource cleanup """ def __init__(self, temp_dir: Optional[str] = None): """ Initialize the pipeline loader. Args: temp_dir: Temporary directory for extracting MPTA files """ self.temp_dir = temp_dir or tempfile.gettempdir() self.extracted_paths: Dict[str, str] = {} # mpta_path -> extracted_dir self.loaded_models: Dict[str, Any] = {} # model_path -> model_instance async def load_pipeline(self, mpta_path: str) -> PipelineNode: """ Load a pipeline from an MPTA file. 


class PipelineLoader:
    """
    Loads and manages ML pipeline configurations.

    This class handles:
    - MPTA file extraction and parsing
    - Pipeline configuration validation
    - Model file management
    - Pipeline tree construction
    - Resource cleanup
    """

    def __init__(self, temp_dir: Optional[str] = None):
        """
        Initialize the pipeline loader.

        Args:
            temp_dir: Temporary directory for extracting MPTA files
        """
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.extracted_paths: Dict[str, str] = {}  # mpta_path -> extracted_dir
        self.loaded_models: Dict[str, Any] = {}  # model_path -> model_instance

    async def load_pipeline(self, mpta_path: str) -> PipelineNode:
        """
        Load a pipeline from an MPTA file.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Root pipeline node

        Raises:
            ModelLoadError: If loading fails
        """
        try:
            # Extract MPTA if not already extracted
            extracted_dir = await self._extract_mpta(mpta_path)

            # Load pipeline configuration
            pipeline_json_path = os.path.join(extracted_dir, "pipeline.json")
            if not os.path.exists(pipeline_json_path):
                raise ModelLoadError(f"pipeline.json not found in {mpta_path}")

            with open(pipeline_json_path, 'r') as f:
                config_data = json.load(f)

            # Parse pipeline configuration
            pipeline_config = self._parse_pipeline_config(config_data, extracted_dir)

            # Validate pipeline
            self._validate_pipeline(pipeline_config)

            # Load models for the pipeline
            await self._load_pipeline_models(pipeline_config.root, extracted_dir)

            logger.info(f"Successfully loaded pipeline from {mpta_path}")
            return pipeline_config.root

        except Exception as e:
            logger.error(f"Failed to load pipeline from {mpta_path}: {e}")
            raise ModelLoadError(f"Failed to load pipeline: {e}") from e

    async def _extract_mpta(self, mpta_path: str) -> str:
        """
        Extract MPTA file to temporary directory.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Path to extracted directory
        """
        # Check if already extracted
        if mpta_path in self.extracted_paths:
            extracted_dir = self.extracted_paths[mpta_path]
            if os.path.exists(extracted_dir):
                return extracted_dir

        # Create extraction directory
        mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
        extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")

        # Extract MPTA
        logger.info(f"Extracting MPTA file: {mpta_path}")

        try:
            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                # Clean existing directory if it exists
                if os.path.exists(extracted_dir):
                    shutil.rmtree(extracted_dir)
                os.makedirs(extracted_dir)
                zip_ref.extractall(extracted_dir)

            self.extracted_paths[mpta_path] = extracted_dir
            logger.info(f"Extracted to: {extracted_dir}")
            return extracted_dir

        except Exception as e:
            raise ModelLoadError(f"Failed to extract MPTA: {e}") from e
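
    # NOTE: pipeline.json uses camelCase keys (e.g. "minConfidence",
    # "triggerClasses"); the configuration parsers in this class map them
    # onto the snake_case dataclass fields defined above.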

    def _parse_pipeline_config(
        self,
        config_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineConfig:
        """
        Parse pipeline configuration from JSON.

        Args:
            config_data: Pipeline JSON data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline configuration
        """
        # Create pipeline config
        pipeline_config = PipelineConfig(
            pipeline_id=config_data.get("pipelineId", "unknown"),
            version=config_data.get("version", "1.0"),
            description=config_data.get("description", "")
        )

        # Parse database config
        if "database" in config_data:
            pipeline_config.database_config = config_data["database"]

        # Parse Redis config
        if "redis" in config_data:
            pipeline_config.redis_config = config_data["redis"]

        # Parse global settings
        if "globalSettings" in config_data:
            pipeline_config.global_settings = config_data["globalSettings"]

        # Parse pipeline tree
        if "pipeline" in config_data:
            pipeline_config.root = self._parse_pipeline_node(
                config_data["pipeline"], base_dir
            )
        elif "root" in config_data:
            pipeline_config.root = self._parse_pipeline_node(
                config_data["root"], base_dir
            )
        else:
            raise PipelineError("No pipeline or root node found in configuration")

        return pipeline_config

    def _parse_pipeline_node(
        self,
        node_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineNode:
        """
        Parse a pipeline node from configuration.

        Args:
            node_data: Node configuration data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline node
        """
        # Create node
        node = PipelineNode(
            model_id=node_data.get("modelId", ""),
            model_file=node_data.get("modelFile", "")
        )

        # Set model path
        if node.model_file:
            node.model_path = os.path.join(base_dir, node.model_file)

        # Parse configuration
        node.multi_class = node_data.get("multiClass", False)
        node.expected_classes = node_data.get("expectedClasses", [])
        node.trigger_classes = node_data.get("triggerClasses", [])
        node.min_confidence = node_data.get("minConfidence", 0.5)
        node.max_detections = node_data.get("maxDetections")

        # Parse cropping
        node.crop = node_data.get("crop", False)
        node.crop_class = node_data.get("cropClass")
        node.crop_expand_ratio = node_data.get("cropExpandRatio", 1.0)

        # Parse actions
        node.actions = node_data.get("actions", [])
        node.parallel_actions = node_data.get("parallelActions", [])

        # Parse YOLO settings
        if "yoloSettings" in node_data:
            node.yolo_settings = node_data["yoloSettings"]
        elif "detectionSettings" in node_data:
            node.yolo_settings = node_data["detectionSettings"]

        # Parse tracking
        node.track_classes = node_data.get("trackClasses")

        # Parse metadata
        node.metadata = node_data.get("metadata", {})

        # Parse branches
        branches_data = node_data.get("branches", [])
        node.parallel = node_data.get("parallel", False)

        for branch_data in branches_data:
            branch_node = self._parse_pipeline_node(branch_data, base_dir)
            node.branches.append(branch_node)

        return node

    def _validate_pipeline(self, pipeline_config: PipelineConfig) -> None:
        """
        Validate pipeline configuration.

        Args:
            pipeline_config: Pipeline configuration to validate

        Raises:
            PipelineError: If validation fails
        """
        if not pipeline_config.root:
            raise PipelineError("Pipeline has no root node")

        # Validate root node
        self._validate_node(pipeline_config.root)

    def _validate_node(self, node: PipelineNode) -> None:
        """
        Validate a pipeline node.

        Args:
            node: Node to validate

        Raises:
            PipelineError: If validation fails
        """
        # Check required fields
        if not node.model_id:
            raise PipelineError("Node missing modelId")

        if not node.model_file and not node.model:
            raise PipelineError(f"Node {node.model_id} missing modelFile")

        # Validate model path exists
        if node.model_path and not os.path.exists(node.model_path):
            raise PipelineError(f"Model file not found: {node.model_path}")

        # Validate cropping configuration
        if node.crop and not node.crop_class:
            raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass")

        # Validate confidence
        if not 0 <= node.min_confidence <= 1:
            raise PipelineError(f"Invalid minConfidence: {node.min_confidence}")

        # Validate branches
        for branch in node.branches:
            self._validate_node(branch)

    async def _load_pipeline_models(
        self,
        node: PipelineNode,
        base_dir: str
    ) -> None:
        """
        Load models for a pipeline node and its branches.

        Args:
            node: Pipeline node
            base_dir: Base directory for models
        """
        # Load model for this node if path is specified
        if node.model_path:
            node.model = await self._load_model(node.model_path, node.model_id)

        # Load models for branches
        for branch in node.branches:
            await self._load_pipeline_models(branch, base_dir)
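
    # NOTE: Models are cached per file path in self.loaded_models, so a model
    # file shared by several pipeline nodes is only loaded from disk once and
    # the same instance is reused for every node that references it.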
    async def _load_model(self, model_path: str, model_id: str) -> Any:
        """
        Load a single model file.

        Args:
            model_path: Path to model file
            model_id: Model identifier

        Returns:
            Loaded model instance
        """
        # Check if already loaded
        if model_path in self.loaded_models:
            logger.info(f"Using cached model: {model_id}")
            return self.loaded_models[model_path]

        try:
            # Import here to avoid circular dependency
            from ultralytics import YOLO

            logger.info(f"Loading model: {model_id} from {model_path}")

            # Load YOLO model
            model = YOLO(model_path)

            # Cache the model
            self.loaded_models[model_path] = model

            return model

        except Exception as e:
            raise ModelLoadError(f"Failed to load model {model_id}: {e}") from e

    def cleanup_model(self, model_id: str) -> None:
        """
        Clean up resources for a specific model.

        Args:
            model_id: Model identifier to clean up
        """
        # Clean up loaded models whose path contains the model ID
        models_to_remove = []
        for path, model in self.loaded_models.items():
            if model_id in path:
                models_to_remove.append(path)

        for path in models_to_remove:
            self.loaded_models.pop(path, None)
            logger.info(f"Cleaned up model: {path}")

    def cleanup_all(self) -> None:
        """Clean up all resources."""
        # Clear loaded models
        self.loaded_models.clear()

        # Clean up extracted directories
        for mpta_path, extracted_dir in self.extracted_paths.items():
            if os.path.exists(extracted_dir):
                try:
                    shutil.rmtree(extracted_dir)
                    logger.info(f"Cleaned up extracted directory: {extracted_dir}")
                except Exception as e:
                    logger.error(f"Failed to clean up {extracted_dir}: {e}")

        self.extracted_paths.clear()

    def get_node_info(self, node: PipelineNode, level: int = 0) -> str:
        """
        Get formatted information about a pipeline node.

        Args:
            node: Pipeline node
            level: Indentation level

        Returns:
            Formatted node information
        """
        indent = " " * level
        info = []

        info.append(f"{indent}Model: {node.model_id}")
        info.append(f"{indent} File: {node.model_file}")
        info.append(f"{indent} Multi-class: {node.multi_class}")

        if node.expected_classes:
            info.append(f"{indent} Expected: {', '.join(node.expected_classes)}")

        if node.trigger_classes:
            info.append(f"{indent} Triggers: {', '.join(node.trigger_classes)}")

        info.append(f"{indent} Confidence: {node.min_confidence}")

        if node.crop:
            info.append(f"{indent} Crop: {node.crop_class} (ratio: {node.crop_expand_ratio})")

        if node.actions:
            info.append(f"{indent} Actions: {len(node.actions)}")

        if node.parallel_actions:
            info.append(f"{indent} Parallel Actions: {len(node.parallel_actions)}")

        if node.branches:
            info.append(f"{indent} Branches ({len(node.branches)}):")
            for branch in node.branches:
                info.append(self.get_node_info(branch, level + 2))

        return "\n".join(info)


# Global pipeline loader instance
_pipeline_loader: Optional[PipelineLoader] = None


def get_pipeline_loader(temp_dir: Optional[str] = None) -> PipelineLoader:
    """Get or create the global pipeline loader instance."""
    global _pipeline_loader

    if _pipeline_loader is None:
        _pipeline_loader = PipelineLoader(temp_dir)

    return _pipeline_loader


# Convenience functions
async def load_pipeline_from_mpta(mpta_path: str) -> PipelineNode:
    """Load a pipeline from an MPTA file."""
    loader = get_pipeline_loader()
    return await loader.load_pipeline(mpta_path)
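

# Example usage (a minimal sketch; the archive path below is hypothetical,
# and loading assumes the ultralytics package is installed):
#
#   import asyncio
#
#   async def main() -> None:
#       root = await load_pipeline_from_mpta("models/example.mpta")
#       print(get_pipeline_loader().get_node_info(root))
#       get_pipeline_loader().cleanup_all()
#
#   asyncio.run(main())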