""" Model Manager Module - Handles MPTA download, extraction, and model loading """ import os import logging import zipfile import json import hashlib import requests from pathlib import Path from typing import Dict, Optional, Any, Set from threading import Lock from urllib.parse import urlparse, parse_qs logger = logging.getLogger(__name__) class ModelManager: """Manages MPTA model downloads, extraction, and caching""" def __init__(self, models_dir: str = "models"): """ Initialize the Model Manager Args: models_dir: Base directory for storing models """ self.models_dir = Path(models_dir) self.models_dir.mkdir(parents=True, exist_ok=True) # Track downloaded models to avoid duplicates self._downloaded_models: Set[int] = set() self._model_paths: Dict[int, Path] = {} self._download_lock = Lock() # Scan existing models self._scan_existing_models() logger.info(f"ModelManager initialized with models directory: {self.models_dir}") logger.info(f"Found existing models: {list(self._downloaded_models)}") def _scan_existing_models(self) -> None: """Scan the models directory for existing downloaded models""" if not self.models_dir.exists(): return for model_dir in self.models_dir.iterdir(): if model_dir.is_dir() and model_dir.name.isdigit(): model_id = int(model_dir.name) # Check if extraction was successful by looking for pipeline.json extracted_dirs = list(model_dir.glob("*/pipeline.json")) if extracted_dirs: self._downloaded_models.add(model_id) # Store path to the extracted model directory self._model_paths[model_id] = extracted_dirs[0].parent logger.debug(f"Found existing model {model_id} at {extracted_dirs[0].parent}") def get_model_path(self, model_id: int) -> Optional[Path]: """ Get the path to an extracted model directory Args: model_id: The model ID Returns: Path to the extracted model directory or None if not found """ return self._model_paths.get(model_id) def is_model_downloaded(self, model_id: int) -> bool: """ Check if a model has already been downloaded and extracted Args: model_id: The model ID to check Returns: True if the model is already available """ return model_id in self._downloaded_models def ensure_model(self, model_id: int, model_url: str, model_name: str = None) -> Optional[Path]: """ Ensure a model is downloaded and extracted, downloading if necessary Args: model_id: The model ID model_url: URL to download the MPTA file from model_name: Optional model name for logging Returns: Path to the extracted model directory or None if failed """ # Check if already downloaded if self.is_model_downloaded(model_id): logger.info(f"Model {model_id} already available at {self._model_paths[model_id]}") return self._model_paths[model_id] # Download and extract with lock to prevent concurrent downloads of same model with self._download_lock: # Double-check after acquiring lock if self.is_model_downloaded(model_id): return self._model_paths[model_id] logger.info(f"Model {model_id} not found locally, downloading from {model_url}") # Create model directory model_dir = self.models_dir / str(model_id) model_dir.mkdir(parents=True, exist_ok=True) # Extract filename from URL mpta_filename = self._extract_filename_from_url(model_url, model_name, model_id) mpta_path = model_dir / mpta_filename # Download MPTA file if not self._download_mpta(model_url, mpta_path): logger.error(f"Failed to download model {model_id}") return None # Extract MPTA file extracted_path = self._extract_mpta(mpta_path, model_dir) if not extracted_path: logger.error(f"Failed to extract model {model_id}") return None # Mark as 
            self._downloaded_models.add(model_id)
            self._model_paths[model_id] = extracted_path

            logger.info(f"Successfully prepared model {model_id} at {extracted_path}")
            return extracted_path

    def _extract_filename_from_url(self, url: str, model_name: Optional[str] = None,
                                   model_id: Optional[int] = None) -> str:
        """
        Extract a suitable filename from the URL

        Args:
            url: The URL to extract the filename from
            model_name: Optional model name
            model_id: Optional model ID

        Returns:
            A suitable filename for the MPTA file
        """
        parsed = urlparse(url)
        path = parsed.path

        # Try to get the filename from the URL path
        if path:
            filename = os.path.basename(path)
            if filename and filename.endswith('.mpta'):
                return filename

        # Fall back to a constructed name
        if model_name:
            return f"{model_name}-{model_id}.mpta"
        return f"model-{model_id}.mpta"

    def _download_mpta(self, url: str, dest_path: Path) -> bool:
        """
        Download an MPTA file from a URL

        Args:
            url: URL to download from
            dest_path: Destination path for the file

        Returns:
            True if successful, False otherwise
        """
        try:
            logger.info(f"Starting download of model from {url}")
            logger.debug(f"Download destination: {dest_path}")

            response = requests.get(url, stream=True, timeout=300)
            if response.status_code != 200:
                logger.error(f"Failed to download MPTA file (status {response.status_code})")
                return False

            file_size = int(response.headers.get('content-length', 0))
            logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")

            downloaded = 0
            last_log_percent = 0
            with open(dest_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)

                        # Log progress every 10%
                        if file_size > 0:
                            percent = int(downloaded * 100 / file_size)
                            if percent >= last_log_percent + 10:
                                logger.debug(f"Download progress: {percent}%")
                                last_log_percent = percent

            logger.info(f"Successfully downloaded MPTA file to {dest_path}")
            return True

        except requests.RequestException as e:
            logger.error(f"Network error downloading MPTA: {str(e)}", exc_info=True)
            # Clean up partial download
            if dest_path.exists():
                dest_path.unlink()
            return False
        except Exception as e:
            logger.error(f"Unexpected error downloading MPTA: {str(e)}", exc_info=True)
            # Clean up partial download
            if dest_path.exists():
                dest_path.unlink()
            return False

    def _extract_mpta(self, mpta_path: Path, target_dir: Path) -> Optional[Path]:
        """
        Extract an MPTA (ZIP) file to the target directory

        Args:
            mpta_path: Path to the MPTA file
            target_dir: Directory to extract to

        Returns:
            Path to the extracted model directory containing pipeline.json,
            or None if failed
        """
        try:
            if not mpta_path.exists():
                logger.error(f"MPTA file not found: {mpta_path}")
                return None

            logger.info(f"Extracting MPTA file from {mpta_path} to {target_dir}")

            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                # Get list of files
                file_list = zip_ref.namelist()
                logger.debug(f"Files in MPTA archive: {len(file_list)} files")

                # Extract all files
                zip_ref.extractall(target_dir)

            logger.info(f"Successfully extracted MPTA file to {target_dir}")

            # Find the directory containing pipeline.json
            pipeline_files = list(target_dir.glob("*/pipeline.json"))
            if not pipeline_files:
                # Check if pipeline.json is in the root
                if (target_dir / "pipeline.json").exists():
                    logger.info(f"Found pipeline.json in root of {target_dir}")
                    return target_dir
                logger.error(f"No pipeline.json found after extraction in {target_dir}")
                return None

            # Return the directory containing pipeline.json
            extracted_dir = pipeline_files[0].parent
            logger.info(f"Extracted model to {extracted_dir}")

            # Keep the MPTA file for reference; it could be deleted if disk space is a concern
            # mpta_path.unlink()
            # logger.debug(f"Removed MPTA file after extraction: {mpta_path}")

            return extracted_dir

        except zipfile.BadZipFile as e:
            logger.error(f"Invalid ZIP/MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"Failed to extract MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None

    def load_pipeline_config(self, model_id: int) -> Optional[Dict[str, Any]]:
        """
        Load the pipeline.json configuration for a model

        Args:
            model_id: The model ID

        Returns:
            The pipeline configuration dictionary or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            logger.error(f"Model {model_id} not found")
            return None

        pipeline_path = model_path / "pipeline.json"
        if not pipeline_path.exists():
            logger.error(f"pipeline.json not found for model {model_id}")
            return None

        try:
            with open(pipeline_path, 'r') as f:
                config = json.load(f)
            logger.debug(f"Loaded pipeline config for model {model_id}")
            return config
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in pipeline.json for model {model_id}: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Failed to load pipeline.json for model {model_id}: {str(e)}")
            return None

    def get_model_file_path(self, model_id: int, filename: str) -> Optional[Path]:
        """
        Get the full path to a model file (e.g., a .pt file)

        Args:
            model_id: The model ID
            filename: The filename within the model directory

        Returns:
            Full path to the model file or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            return None

        file_path = model_path / filename
        if not file_path.exists():
            logger.error(f"Model file {filename} not found in model {model_id}")
            return None

        return file_path

    def cleanup_model(self, model_id: int) -> bool:
        """
        Remove a downloaded model to free up space

        Args:
            model_id: The model ID to remove

        Returns:
            True if successful, False otherwise
        """
        if model_id not in self._downloaded_models:
            logger.warning(f"Model {model_id} not in downloaded models")
            return False

        try:
            model_dir = self.models_dir / str(model_id)
            if model_dir.exists():
                shutil.rmtree(model_dir)
                logger.info(f"Removed model directory: {model_dir}")

            self._downloaded_models.discard(model_id)
            self._model_paths.pop(model_id, None)
            return True
        except Exception as e:
            logger.error(f"Failed to cleanup model {model_id}: {str(e)}")
            return False

    def get_all_downloaded_models(self) -> Set[int]:
        """
        Get a set of all downloaded model IDs

        Returns:
            Set of model IDs that are currently downloaded
        """
        return self._downloaded_models.copy()

    def get_pipeline_config(self, model_id: int) -> Optional[Any]:
        """
        Get the pipeline configuration for a model

        Args:
            model_id: The model ID

        Returns:
            PipelineParser object if found and parsed successfully, None otherwise
        """
        try:
            if model_id not in self._downloaded_models:
                logger.warning(f"Model {model_id} not downloaded")
                return None

            model_path = self._model_paths.get(model_id)
            if not model_path:
                logger.warning(f"Model path not found for model {model_id}")
                return None

            # Import here to avoid circular imports
            from .pipeline import PipelineParser

            # Load pipeline.json
            pipeline_file = model_path / "pipeline.json"
            if not pipeline_file.exists():
                logger.warning(f"No pipeline.json found for model {model_id}")
                return None

            # Create a PipelineParser and parse the configuration
            pipeline_parser = PipelineParser()
            success = pipeline_parser.parse(pipeline_file)
            if success:
                return pipeline_parser

            logger.error(f"Failed to parse pipeline.json for model {model_id}")
            return None

        except Exception as e:
            logger.error(f"Error getting pipeline config for model {model_id}: {e}", exc_info=True)
            return None

    def get_yolo_model(self, model_id: int, model_filename: str) -> Optional[Any]:
        """
        Create a YOLOWrapper instance for a specific model file

        Args:
            model_id: The model ID
            model_filename: The .pt model filename

        Returns:
            YOLOWrapper instance if successful, None otherwise
        """
        try:
            # Get the model file path
            model_file_path = self.get_model_file_path(model_id, model_filename)
            if not model_file_path or not model_file_path.exists():
                logger.error(f"Model file {model_filename} not found for model {model_id}")
                return None

            # Import here to avoid circular imports
            from .inference import YOLOWrapper

            # Create YOLOWrapper instance
            yolo_model = YOLOWrapper(
                model_path=model_file_path,
                model_id=f"{model_id}_{model_filename}",
                device=None  # Auto-detect device
            )

            logger.info(f"Created YOLOWrapper for model {model_id}: {model_filename}")
            return yolo_model

        except Exception as e:
            logger.error(f"Error creating YOLO model for {model_id}:{model_filename}: {e}", exc_info=True)
            return None
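

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The model ID, URL, and .pt filename below
# are placeholders, not real values; the example assumes the sibling pipeline
# and inference modules imported above are present, so it must be run as a
# module within the package (e.g., `python -m <package>.model_manager`) for
# the relative imports to resolve. It shows the intended flow: ensure_model()
# downloads and extracts the MPTA once, after which the pipeline config and a
# YOLO wrapper are served from the cached copy on disk.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = ModelManager(models_dir="models")

    # Hypothetical model ID and download URL; replace with real values.
    model_path = manager.ensure_model(
        model_id=42,
        model_url="https://example.com/models/detector-42.mpta",
        model_name="detector",
    )

    if model_path:
        config = manager.load_pipeline_config(42)
        print(f"Model extracted to {model_path}; pipeline keys: {list(config or {})}")

        # "detector.pt" is a placeholder; it must match a file inside the archive.
        yolo = manager.get_yolo_model(42, "detector.pt")
        if yolo:
            print("YOLO model loaded and ready for inference")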