python-detector-worker/detector_worker/models/pipeline_loader.py

583 lines
No EOL
21 KiB
Python

"""
Pipeline loader module.
This module handles loading and parsing of MPTA (Machine Learning Pipeline Archive)
files, which contain model configurations and pipeline definitions.
"""
import os
import json
import logging
import zipfile
import tempfile
import shutil
import traceback
from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass, field
from pathlib import Path
from ..core.exceptions import ModelLoadError, PipelineError
# Setup logging
# Module-level logger. It is registered under the fixed name
# "detector_worker.pipeline_loader" rather than __name__.
logger = logging.getLogger("detector_worker.pipeline_loader")
@dataclass
class PipelineNode:
    """
    A single node of a parsed pipeline tree.

    Carries the model reference for the node together with its detection,
    cropping, action and branching settings. Child nodes live in
    ``branches``, so a root node describes the whole tree.
    """

    # -- Model reference -------------------------------------------------
    model_id: str
    model_file: str
    model_path: Optional[str] = None  # resolved location of model_file, if known
    model: Optional[Any] = None  # model instance once it has been loaded

    # -- Class filtering and confidence ----------------------------------
    multi_class: bool = False
    expected_classes: List[str] = field(default_factory=list)
    trigger_classes: List[str] = field(default_factory=list)
    min_confidence: float = 0.5
    max_detections: Optional[int] = None

    # -- Cropping ----------------------------------------------------------
    crop: bool = False
    crop_class: Optional[str] = None
    crop_expand_ratio: float = 1.0

    # -- Actions (sequential and parallel) ---------------------------------
    actions: List[Dict[str, Any]] = field(default_factory=list)
    parallel_actions: List[Dict[str, Any]] = field(default_factory=list)

    # -- Child nodes ---------------------------------------------------------
    branches: List['PipelineNode'] = field(default_factory=list)
    parallel: bool = False

    # -- Detector settings ---------------------------------------------------
    yolo_settings: Dict[str, Any] = field(default_factory=dict)
    track_classes: Optional[List[str]] = None

    # -- Free-form extra data ------------------------------------------------
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class PipelineConfig:
    """
    Top-level settings parsed from a pipeline.json file.

    Besides identification fields it keeps the raw database/Redis sections
    and global settings, plus the root of the pipeline node tree.
    """

    # -- Identification --------------------------------------------------------
    pipeline_id: str
    version: str = "1.0"
    description: str = ""

    # -- External services (raw JSON sections; None when absent) -------------
    database_config: Optional[Dict[str, Any]] = None
    redis_config: Optional[Dict[str, Any]] = None

    # -- Pipeline-wide options -------------------------------------------------
    global_settings: Dict[str, Any] = field(default_factory=dict)

    # -- Tree entry point ------------------------------------------------------
    root: Optional[PipelineNode] = None
class PipelineLoader:
    """
    Loads and manages ML pipeline configurations.

    This class handles:
    - MPTA file extraction and parsing
    - Pipeline configuration validation
    - Model file management
    - Pipeline tree construction
    - Resource cleanup

    Extracted directories are cached per MPTA path and loaded model
    instances per model file path, for the lifetime of the loader.
    """

    def __init__(self, temp_dir: Optional[str] = None):
        """
        Initialize the pipeline loader.

        Args:
            temp_dir: Directory used for extracting MPTA files that are not
                stored under a ``models/`` directory. Defaults to
                ``tempfile.gettempdir()``.
        """
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.extracted_paths: Dict[str, str] = {}  # mpta_path -> extracted_dir
        self.loaded_models: Dict[str, Any] = {}  # model_path -> model_instance

    async def load_pipeline(self, mpta_path: str) -> PipelineNode:
        """
        Load a pipeline from an MPTA file.

        Extracts the archive, locates ``pipeline.json``, parses and validates
        the node tree, then loads the model of every node.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Root pipeline node

        Raises:
            ModelLoadError: If any step fails. Errors from inner steps
                (including PipelineError from validation) are wrapped, with
                the original exception chained as ``__cause__``.
        """
        try:
            logger.info(f"🔍 Loading pipeline from MPTA file: {mpta_path}")

            if not os.path.exists(mpta_path):
                raise ModelLoadError(f"MPTA file not found: {mpta_path}")
            if not zipfile.is_zipfile(mpta_path):
                raise ModelLoadError(f"File is not a valid ZIP/MPTA archive: {mpta_path}")

            # Extract MPTA (reuses a cached extraction when still present)
            extracted_dir = await self._extract_mpta(mpta_path)
            logger.info(f"📂 MPTA extracted to: {extracted_dir}")

            if not os.path.exists(extracted_dir):
                raise ModelLoadError(f"Extraction failed - directory not found: {extracted_dir}")
            logger.info(f"📋 Extracted contents: {os.listdir(extracted_dir)}")

            pipeline_json_path = self._find_pipeline_json(extracted_dir)

            # JSON text is UTF-8 by specification; don't depend on the locale.
            with open(pipeline_json_path, 'r', encoding='utf-8') as f:
                config_data = json.load(f)
            logger.info(f"📋 Pipeline config loaded from: {pipeline_json_path}")

            # Model files are resolved relative to the directory holding pipeline.json
            base_dir = os.path.dirname(pipeline_json_path)
            pipeline_config = self._parse_pipeline_config(config_data, base_dir)
            self._validate_pipeline(pipeline_config)
            await self._load_pipeline_models(pipeline_config.root, base_dir)

            logger.info(f"✅ Successfully loaded pipeline from {mpta_path}")
            return pipeline_config.root

        except Exception as e:
            # logger.exception records the traceback through logging instead
            # of printing it to stderr.
            logger.exception(f"❌ Failed to load pipeline from {mpta_path}: {e}")
            raise ModelLoadError(f"Failed to load pipeline: {e}") from e

    def _find_pipeline_json(self, extracted_dir: str) -> str:
        """
        Locate ``pipeline.json`` under an extracted MPTA directory.

        ``os.walk`` is top-down, so a root-level pipeline.json is found before
        any nested one. Subdirectories are visited in sorted order so the
        choice is deterministic when several exist.

        Args:
            extracted_dir: Directory the MPTA was extracted into

        Returns:
            Path to the first pipeline.json found

        Raises:
            ModelLoadError: If no pipeline.json exists under extracted_dir.
        """
        logger.info(f"🔍 Looking for pipeline.json in extracted directory: {extracted_dir}")
        all_files: List[str] = []
        for root, dirs, files in os.walk(extracted_dir):
            dirs.sort()  # in-place sort controls os.walk's descent order
            if "pipeline.json" in files:
                pipeline_json_path = os.path.join(root, "pipeline.json")
                logger.info(f"✅ Found pipeline.json at: {pipeline_json_path}")
                return pipeline_json_path
            # Collected only for the error message below
            all_files.extend(os.path.join(root, name) for name in files)

        raise ModelLoadError(f"pipeline.json not found in extracted MPTA. "
                             f"Extracted to: {extracted_dir}. "
                             f"Files found: {all_files}")

    async def _extract_mpta(self, mpta_path: str) -> str:
        """
        Extract MPTA file to model_id based directory structure.

        For models/{model_id}/ structure, extracts to the same directory as
        the MPTA file (previous extraction leftovers there are removed, but
        ``.mpta`` files are kept). Otherwise extracts to
        ``<temp_dir>/mpta_<archive name>``, replacing any previous contents.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Path to extracted directory

        Raises:
            ModelLoadError: If the archive is invalid or extraction fails.
        """
        # Reuse a previous extraction if it is still on disk
        cached_dir = self.extracted_paths.get(mpta_path)
        if cached_dir and os.path.exists(cached_dir):
            return cached_dir

        mpta_dir = os.path.dirname(mpta_path)
        mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]

        # MPTAs under models/{model_id}/ are extracted in place next to the
        # archive; anything else goes to a per-archive temp directory.
        if "models/" in mpta_dir:
            extracted_dir = mpta_dir
        else:
            extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")

        logger.info(f"📦 Extracting MPTA file: {mpta_path}")
        logger.info(f"📂 Extraction target: {extracted_dir}")

        try:
            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                file_list = zip_ref.namelist()
                logger.info(f"📋 ZIP file contents ({len(file_list)} files): {file_list[:10]}{'...' if len(file_list) > 10 else ''}")

                if "models/" in extracted_dir and mpta_path.startswith(extracted_dir):
                    # In-place extraction: clear previous extracts, keep the MPTA file(s)
                    for item in os.listdir(extracted_dir):
                        item_path = os.path.join(extracted_dir, item)
                        if os.path.isdir(item_path):
                            logger.info(f"🧹 Cleaning existing extracted directory: {item_path}")
                            shutil.rmtree(item_path)
                        elif not item.endswith('.mpta'):
                            logger.info(f"🧹 Cleaning leftover file: {item_path}")
                            os.remove(item_path)
                else:
                    # Temp extraction: start from an empty directory
                    if os.path.exists(extracted_dir):
                        logger.info(f"🧹 Cleaning existing extraction directory: {extracted_dir}")
                        shutil.rmtree(extracted_dir)
                    os.makedirs(extracted_dir, exist_ok=True)

                logger.info(f"📤 Extracting {len(file_list)} files...")
                zip_ref.extractall(extracted_dir)

            # Verify extraction worked
            extracted_files = [
                os.path.join(root, name)
                for root, _, files in os.walk(extracted_dir)
                for name in files
            ]
            logger.info(f"✅ Extraction completed - {len(extracted_files)} files extracted")
            logger.info(f"📋 Sample extracted files: {extracted_files[:5]}{'...' if len(extracted_files) > 5 else ''}")

            self.extracted_paths[mpta_path] = extracted_dir
            logger.info(f"✅ MPTA successfully extracted to: {extracted_dir}")
            return extracted_dir

        except zipfile.BadZipFile as e:
            logger.error(f"❌ Invalid ZIP file: {mpta_path}")
            raise ModelLoadError(f"Invalid ZIP/MPTA file: {e}") from e
        except Exception as e:
            logger.error(f"❌ Failed to extract MPTA: {e}")
            raise ModelLoadError(f"Failed to extract MPTA: {e}") from e

    def _parse_pipeline_config(
        self,
        config_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineConfig:
        """
        Parse pipeline configuration from JSON.

        Args:
            config_data: Pipeline JSON data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline configuration

        Raises:
            PipelineError: If neither a "pipeline" nor a "root" node is present.
        """
        pipeline_config = PipelineConfig(
            pipeline_id=config_data.get("pipelineId", "unknown"),
            version=config_data.get("version", "1.0"),
            description=config_data.get("description", ""),
        )

        if "database" in config_data:
            pipeline_config.database_config = config_data["database"]
        if "redis" in config_data:
            pipeline_config.redis_config = config_data["redis"]
        if "globalSettings" in config_data:
            pipeline_config.global_settings = config_data["globalSettings"]

        # "pipeline" takes precedence; "root" is accepted as an alternative key
        if "pipeline" in config_data:
            root_data = config_data["pipeline"]
        elif "root" in config_data:
            root_data = config_data["root"]
        else:
            raise PipelineError("No pipeline or root node found in configuration")

        pipeline_config.root = self._parse_pipeline_node(root_data, base_dir)
        return pipeline_config

    def _parse_pipeline_node(
        self,
        node_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineNode:
        """
        Parse a pipeline node (and, recursively, its branches) from configuration.

        A JSON ``null`` for a list/dict field or for ``minConfidence`` is
        treated like a missing key, so later code can iterate and compare
        without None checks.

        Args:
            node_data: Node configuration data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline node
        """
        node = PipelineNode(
            model_id=node_data.get("modelId", ""),
            model_file=node_data.get("modelFile", ""),
        )
        if node.model_file:
            node.model_path = os.path.join(base_dir, node.model_file)

        # Class filtering / confidence
        node.multi_class = node_data.get("multiClass", False)
        node.expected_classes = node_data.get("expectedClasses") or []
        node.trigger_classes = node_data.get("triggerClasses") or []
        min_confidence = node_data.get("minConfidence")
        node.min_confidence = 0.5 if min_confidence is None else min_confidence
        node.max_detections = node_data.get("maxDetections")

        # Cropping
        node.crop = node_data.get("crop", False)
        node.crop_class = node_data.get("cropClass")
        node.crop_expand_ratio = node_data.get("cropExpandRatio", 1.0)

        # Actions
        node.actions = node_data.get("actions") or []
        node.parallel_actions = node_data.get("parallelActions") or []

        # YOLO settings ("detectionSettings" is accepted as an alternative key)
        if "yoloSettings" in node_data:
            node.yolo_settings = node_data["yoloSettings"]
        elif "detectionSettings" in node_data:
            node.yolo_settings = node_data["detectionSettings"]

        node.track_classes = node_data.get("trackClasses")
        node.metadata = node_data.get("metadata") or {}

        # Branches
        node.parallel = node_data.get("parallel", False)
        node.branches = [
            self._parse_pipeline_node(branch_data, base_dir)
            for branch_data in (node_data.get("branches") or [])
        ]
        return node

    def _validate_pipeline(self, pipeline_config: PipelineConfig) -> None:
        """
        Validate pipeline configuration.

        Args:
            pipeline_config: Pipeline configuration to validate

        Raises:
            PipelineError: If validation fails
        """
        if not pipeline_config.root:
            raise PipelineError("Pipeline has no root node")
        self._validate_node(pipeline_config.root)

    def _validate_node(self, node: PipelineNode) -> None:
        """
        Validate a pipeline node and, recursively, its branches.

        Note: a node with ``crop=True`` but no ``cropClass`` is not rejected;
        its ``crop`` flag is set to False in place.

        Args:
            node: Node to validate

        Raises:
            PipelineError: If validation fails
        """
        if not node.model_id:
            raise PipelineError("Node missing modelId")
        if not node.model_file and not node.model:
            raise PipelineError(f"Node {node.model_id} missing modelFile")
        if node.model_path and not os.path.exists(node.model_path):
            raise PipelineError(f"Model file not found: {node.model_path}")

        # Lenient for backward compatibility: disable cropping instead of failing
        if node.crop and not node.crop_class:
            logger.warning(f"Node {node.model_id} has crop=true but no cropClass - will disable cropping")
            node.crop = False

        # Reject non-numeric values explicitly instead of failing with TypeError
        confidence = node.min_confidence
        if not isinstance(confidence, (int, float)) or not 0 <= confidence <= 1:
            raise PipelineError(f"Invalid minConfidence: {node.min_confidence}")

        for branch in node.branches:
            self._validate_node(branch)

    async def _load_pipeline_models(
        self,
        node: PipelineNode,
        base_dir: str
    ) -> None:
        """
        Load models for a pipeline node and its branches (sequentially, depth-first).

        Args:
            node: Pipeline node
            base_dir: Base directory for models (passed through to branches)
        """
        if node.model_path:
            node.model = await self._load_model(node.model_path, node.model_id)

        for branch in node.branches:
            await self._load_pipeline_models(branch, base_dir)

    async def _load_model(self, model_path: str, model_id: str) -> Any:
        """
        Load a single model file, reusing a cached instance for the same path.

        Args:
            model_path: Path to model file
            model_id: Model identifier (used for logging and error messages)

        Returns:
            Loaded model instance

        Raises:
            ModelLoadError: If ultralytics cannot be imported or loading fails.
        """
        if model_path in self.loaded_models:
            logger.info(f"Using cached model: {model_id}")
            return self.loaded_models[model_path]

        try:
            # Import here to avoid circular dependency
            from ultralytics import YOLO

            logger.info(f"Loading model: {model_id} from {model_path}")
            model = YOLO(model_path)
            self.loaded_models[model_path] = model
            return model
        except Exception as e:
            raise ModelLoadError(f"Failed to load model {model_id}: {e}") from e

    @staticmethod
    def _path_belongs_to_model(path: str, model_id: str) -> bool:
        """
        Return True if ``model_id`` names whole path components of ``path``
        (e.g. "52" for "models/52/x.pt", or "models/52") or equals the file stem.

        Whole-component matching prevents "1" from matching "models/12/..."
        or "detector_v1.pt", which a plain substring test would do.
        """
        path_obj = Path(path)
        if path_obj.stem == model_id:
            return True
        parts = path_obj.parts
        id_parts = Path(model_id).parts
        width = len(id_parts)
        if width == 0:
            # Empty id: match nothing rather than everything
            return False
        return any(
            parts[i:i + width] == id_parts
            for i in range(len(parts) - width + 1)
        )

    def cleanup_model(self, model_id: str) -> None:
        """
        Drop cached model instances belonging to a specific model.

        A cached entry is removed when ``model_id`` matches whole components
        of its model path or the model file's stem. Extracted directories
        are not touched.

        Args:
            model_id: Model identifier to clean up
        """
        models_to_remove = [
            path for path in self.loaded_models
            if self._path_belongs_to_model(path, model_id)
        ]
        for path in models_to_remove:
            self.loaded_models.pop(path, None)
            logger.info(f"Cleaned up model: {path}")

    def cleanup_all(self) -> None:
        """
        Clean up all resources: clear the model cache and delete every
        directory this loader extracted. Deletion failures are logged, not raised.
        """
        self.loaded_models.clear()

        # NOTE(review): for archives extracted in place (models/{model_id}/),
        # extracted_dir is the archive's own directory, so this also removes
        # the .mpta file itself — confirm that is intended.
        for extracted_dir in self.extracted_paths.values():
            if os.path.exists(extracted_dir):
                try:
                    shutil.rmtree(extracted_dir)
                    logger.info(f"Cleaned up extracted directory: {extracted_dir}")
                except Exception as e:
                    logger.error(f"Failed to clean up {extracted_dir}: {e}")

        self.extracted_paths.clear()

    def get_node_info(self, node: PipelineNode, level: int = 0) -> str:
        """
        Get formatted information about a pipeline node.

        Args:
            node: Pipeline node
            level: Indentation level (number of leading spaces); each branch
                is rendered two spaces deeper than its parent.

        Returns:
            Formatted node information, one attribute per line
        """
        indent = " " * level
        info = [
            f"{indent}Model: {node.model_id}",
            f"{indent}  File: {node.model_file}",
            f"{indent}  Multi-class: {node.multi_class}",
        ]

        if node.expected_classes:
            info.append(f"{indent}  Expected: {', '.join(node.expected_classes)}")
        if node.trigger_classes:
            info.append(f"{indent}  Triggers: {', '.join(node.trigger_classes)}")

        info.append(f"{indent}  Confidence: {node.min_confidence}")

        if node.crop:
            info.append(f"{indent}  Crop: {node.crop_class} (ratio: {node.crop_expand_ratio})")
        if node.actions:
            info.append(f"{indent}  Actions: {len(node.actions)}")
        if node.parallel_actions:
            info.append(f"{indent}  Parallel Actions: {len(node.parallel_actions)}")

        if node.branches:
            info.append(f"{indent}  Branches ({len(node.branches)}):")
            for branch in node.branches:
                info.append(self.get_node_info(branch, level + 2))

        return "\n".join(info)
# Process-wide PipelineLoader, created lazily on first request.
_pipeline_loader = None


def get_pipeline_loader(temp_dir: Optional[str] = None) -> PipelineLoader:
    """
    Return the shared PipelineLoader, creating it on the first call.

    ``temp_dir`` only takes effect on the call that creates the instance;
    later calls return the existing loader unchanged.
    """
    global _pipeline_loader
    if _pipeline_loader is not None:
        return _pipeline_loader
    _pipeline_loader = PipelineLoader(temp_dir)
    return _pipeline_loader
# Convenience functions
async def load_pipeline_from_mpta(mpta_path: str) -> PipelineNode:
    """Load a pipeline from an MPTA file using the shared loader instance."""
    return await get_pipeline_loader().load_pipeline(mpta_path)