"""
|
|
Model Manager Module - Handles MPTA download, extraction, and model loading
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
import zipfile
|
|
import json
|
|
import hashlib
|
|
import requests
|
|
from pathlib import Path
|
|
from typing import Dict, Optional, Any, Set
|
|
from threading import Lock
|
|
from urllib.parse import urlparse, parse_qs
|
|
|
|
logger = logging.getLogger(__name__)


class ModelManager:
    """Manages MPTA model downloads, extraction, and caching"""

    def __init__(self, models_dir: str = "models"):
        """
        Initialize the Model Manager

        Args:
            models_dir: Base directory for storing models
        """
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)

        # Track downloaded models to avoid duplicates
        self._downloaded_models: Set[int] = set()
        self._model_paths: Dict[int, Path] = {}
        self._download_lock = Lock()

        # Scan existing models
        self._scan_existing_models()

        logger.info(f"ModelManager initialized with models directory: {self.models_dir}")
        logger.info(f"Found existing models: {list(self._downloaded_models)}")

    def _scan_existing_models(self) -> None:
        """Scan the models directory for existing downloaded models"""
        if not self.models_dir.exists():
            return

        for model_dir in self.models_dir.iterdir():
            if model_dir.is_dir() and model_dir.name.isdigit():
                model_id = int(model_dir.name)
                # Check if extraction was successful by looking for pipeline.json
                extracted_dirs = list(model_dir.glob("*/pipeline.json"))
                if extracted_dirs:
                    self._downloaded_models.add(model_id)
                    # Store path to the extracted model directory
                    self._model_paths[model_id] = extracted_dirs[0].parent
                    logger.debug(f"Found existing model {model_id} at {extracted_dirs[0].parent}")

    def get_model_path(self, model_id: int) -> Optional[Path]:
        """
        Get the path to an extracted model directory

        Args:
            model_id: The model ID

        Returns:
            Path to the extracted model directory or None if not found
        """
        return self._model_paths.get(model_id)

    def is_model_downloaded(self, model_id: int) -> bool:
        """
        Check if a model has already been downloaded and extracted

        Args:
            model_id: The model ID to check

        Returns:
            True if the model is already available
        """
        return model_id in self._downloaded_models

    def ensure_model(self, model_id: int, model_url: str, model_name: Optional[str] = None) -> Optional[Path]:
        """
        Ensure a model is downloaded and extracted, downloading if necessary

        Args:
            model_id: The model ID
            model_url: URL to download the MPTA file from
            model_name: Optional model name for logging

        Returns:
            Path to the extracted model directory or None if failed
        """
        # Check if already downloaded
        if self.is_model_downloaded(model_id):
            logger.info(f"Model {model_id} already available at {self._model_paths[model_id]}")
            return self._model_paths[model_id]

        # Download and extract under a lock to prevent concurrent downloads of the same model
        with self._download_lock:
            # Double-check after acquiring the lock
            if self.is_model_downloaded(model_id):
                return self._model_paths[model_id]

            logger.info(f"Model {model_id} not found locally, downloading from {model_url}")

            # Create model directory
            model_dir = self.models_dir / str(model_id)
            model_dir.mkdir(parents=True, exist_ok=True)

            # Extract filename from URL
            mpta_filename = self._extract_filename_from_url(model_url, model_name, model_id)
            mpta_path = model_dir / mpta_filename

            # Download MPTA file
            if not self._download_mpta(model_url, mpta_path):
                logger.error(f"Failed to download model {model_id}")
                return None

            # Extract MPTA file
            extracted_path = self._extract_mpta(mpta_path, model_dir)
            if not extracted_path:
                logger.error(f"Failed to extract model {model_id}")
                return None

            # Mark as downloaded and store path
            self._downloaded_models.add(model_id)
            self._model_paths[model_id] = extracted_path

            logger.info(f"Successfully prepared model {model_id} at {extracted_path}")
            return extracted_path

    def _extract_filename_from_url(self, url: str, model_name: Optional[str] = None, model_id: Optional[int] = None) -> str:
        """
        Derive a suitable filename for the MPTA file from the URL

        Args:
            url: The URL to extract the filename from
            model_name: Optional model name
            model_id: Optional model ID

        Returns:
            A suitable filename for the MPTA file
        """
        parsed = urlparse(url)
        path = parsed.path

        # Try to get the filename from the URL path
        if path:
            filename = os.path.basename(path)
            if filename and filename.endswith('.mpta'):
                return filename

        # Fall back to a constructed name
        if model_name:
            return f"{model_name}-{model_id}.mpta"
        return f"model-{model_id}.mpta"

    def _download_mpta(self, url: str, dest_path: Path) -> bool:
        """
        Download an MPTA file from a URL

        Args:
            url: URL to download from
            dest_path: Destination path for the file

        Returns:
            True if successful, False otherwise
        """
        try:
            logger.info(f"Starting download of model from {url}")
            logger.debug(f"Download destination: {dest_path}")

            response = requests.get(url, stream=True, timeout=300)
            if response.status_code != 200:
                logger.error(f"Failed to download MPTA file (status {response.status_code})")
                return False

            file_size = int(response.headers.get('content-length', 0))
            logger.info(f"Model file size: {file_size / 1024 / 1024:.2f} MB")

            downloaded = 0
            last_log_percent = 0

            with open(dest_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)

                        # Log progress every 10%
                        if file_size > 0:
                            percent = int(downloaded * 100 / file_size)
                            if percent >= last_log_percent + 10:
                                logger.debug(f"Download progress: {percent}%")
                                last_log_percent = percent

            logger.info(f"Successfully downloaded MPTA file to {dest_path}")
            return True

        except requests.RequestException as e:
            logger.error(f"Network error downloading MPTA: {str(e)}", exc_info=True)
            # Clean up partial download
            if dest_path.exists():
                dest_path.unlink()
            return False
        except Exception as e:
            logger.error(f"Unexpected error downloading MPTA: {str(e)}", exc_info=True)
            # Clean up partial download
            if dest_path.exists():
                dest_path.unlink()
            return False
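
    # Note: `hashlib` is imported at module level but never used in this file,
    # which suggests checksum verification of downloads was planned. The helper
    # below is an illustrative sketch of how that could look; the method name
    # and the idea of comparing against a caller-supplied SHA-256 digest are
    # assumptions, not part of the original API, and nothing calls this yet.
    def _verify_checksum(self, file_path: Path, expected_sha256: str) -> bool:
        """Sketch: verify a downloaded file against an expected SHA-256 digest."""
        sha256 = hashlib.sha256()
        try:
            with open(file_path, 'rb') as f:
                # Read in chunks so large MPTA files need not fit in memory
                for chunk in iter(lambda: f.read(8192), b''):
                    sha256.update(chunk)
        except OSError as e:
            logger.error(f"Could not read {file_path} for checksum verification: {e}")
            return False
        if sha256.hexdigest() != expected_sha256.lower():
            logger.error(f"Checksum mismatch for {file_path}")
            return False
        return True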

    def _extract_mpta(self, mpta_path: Path, target_dir: Path) -> Optional[Path]:
        """
        Extract an MPTA (ZIP) file to the target directory

        Args:
            mpta_path: Path to the MPTA file
            target_dir: Directory to extract to

        Returns:
            Path to the extracted model directory containing pipeline.json, or None if failed
        """
        try:
            if not mpta_path.exists():
                logger.error(f"MPTA file not found: {mpta_path}")
                return None

            logger.info(f"Extracting MPTA file from {mpta_path} to {target_dir}")

            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                # Get list of files
                file_list = zip_ref.namelist()
                logger.debug(f"Files in MPTA archive: {len(file_list)} files")

                # Extract all files
                zip_ref.extractall(target_dir)

            logger.info(f"Successfully extracted MPTA file to {target_dir}")

            # Find the directory containing pipeline.json
            pipeline_files = list(target_dir.glob("*/pipeline.json"))
            if not pipeline_files:
                # Check if pipeline.json is in the root
                if (target_dir / "pipeline.json").exists():
                    logger.info(f"Found pipeline.json in root of {target_dir}")
                    return target_dir
                logger.error(f"No pipeline.json found after extraction in {target_dir}")
                return None

            # Return the directory containing pipeline.json
            extracted_dir = pipeline_files[0].parent
            logger.info(f"Extracted model to {extracted_dir}")

            # Keep the MPTA file for reference; it could be deleted if space is a concern:
            # mpta_path.unlink()
            # logger.debug(f"Removed MPTA file after extraction: {mpta_path}")

            return extracted_dir

        except zipfile.BadZipFile as e:
            logger.error(f"Invalid ZIP/MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"Failed to extract MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None

    def load_pipeline_config(self, model_id: int) -> Optional[Dict[str, Any]]:
        """
        Load the pipeline.json configuration for a model

        Args:
            model_id: The model ID

        Returns:
            The pipeline configuration dictionary or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            logger.error(f"Model {model_id} not found")
            return None

        pipeline_path = model_path / "pipeline.json"
        if not pipeline_path.exists():
            logger.error(f"pipeline.json not found for model {model_id}")
            return None

        try:
            with open(pipeline_path, 'r') as f:
                config = json.load(f)
                logger.debug(f"Loaded pipeline config for model {model_id}")
                return config
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in pipeline.json for model {model_id}: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Failed to load pipeline.json for model {model_id}: {str(e)}")
            return None

    def get_model_file_path(self, model_id: int, filename: str) -> Optional[Path]:
        """
        Get the full path to a model file (e.g., a .pt file)

        Args:
            model_id: The model ID
            filename: The filename within the model directory

        Returns:
            Full path to the model file or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            return None

        file_path = model_path / filename
        if not file_path.exists():
            logger.error(f"Model file {filename} not found in model {model_id}")
            return None

        return file_path

    def cleanup_model(self, model_id: int) -> bool:
        """
        Remove a downloaded model to free up space

        Args:
            model_id: The model ID to remove

        Returns:
            True if successful, False otherwise
        """
        if model_id not in self._downloaded_models:
            logger.warning(f"Model {model_id} not in downloaded models")
            return False

        try:
            model_dir = self.models_dir / str(model_id)
            if model_dir.exists():
                import shutil
                shutil.rmtree(model_dir)
                logger.info(f"Removed model directory: {model_dir}")

            self._downloaded_models.discard(model_id)
            self._model_paths.pop(model_id, None)
            return True

        except Exception as e:
            logger.error(f"Failed to cleanup model {model_id}: {str(e)}")
            return False

    def get_all_downloaded_models(self) -> Set[int]:
        """
        Get a set of all downloaded model IDs

        Returns:
            Set of model IDs that are currently downloaded
        """
        return self._downloaded_models.copy()
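
    # Illustrative convenience helper (an assumption, not part of the original
    # API): remove every downloaded model using the primitives above. Safe to
    # iterate while removing because get_all_downloaded_models returns a copy.
    def cleanup_all_models(self) -> int:
        """Sketch: remove all downloaded models; returns how many were removed."""
        removed = 0
        for model_id in self.get_all_downloaded_models():
            if self.cleanup_model(model_id):
                removed += 1
        return removed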

    def get_pipeline_config(self, model_id: int) -> Optional[Any]:
        """
        Get the parsed pipeline configuration for a model.

        Args:
            model_id: The model ID

        Returns:
            PipelineParser object if parsing succeeds, None otherwise
        """
        try:
            if model_id not in self._downloaded_models:
                logger.warning(f"Model {model_id} not downloaded")
                return None

            model_path = self._model_paths.get(model_id)
            if not model_path:
                logger.warning(f"Model path not found for model {model_id}")
                return None

            # Import here to avoid circular imports
            from .pipeline import PipelineParser

            # Load pipeline.json
            pipeline_file = model_path / "pipeline.json"
            if not pipeline_file.exists():
                logger.warning(f"No pipeline.json found for model {model_id}")
                return None

            # Create a PipelineParser object and parse the configuration
            pipeline_parser = PipelineParser()
            success = pipeline_parser.parse(pipeline_file)

            if success:
                return pipeline_parser
            else:
                logger.error(f"Failed to parse pipeline.json for model {model_id}")
                return None

        except Exception as e:
            logger.error(f"Error getting pipeline config for model {model_id}: {e}", exc_info=True)
            return None

    def get_yolo_model(self, model_id: int, model_filename: str) -> Optional[Any]:
        """
        Create a YOLOWrapper instance for a specific model file.

        Args:
            model_id: The model ID
            model_filename: The .pt model filename

        Returns:
            YOLOWrapper instance if successful, None otherwise
        """
        try:
            # Get the model file path
            model_file_path = self.get_model_file_path(model_id, model_filename)
            if not model_file_path or not model_file_path.exists():
                logger.error(f"Model file {model_filename} not found for model {model_id}")
                return None

            # Import here to avoid circular imports
            from .inference import YOLOWrapper

            # Create YOLOWrapper instance
            yolo_model = YOLOWrapper(
                model_path=model_file_path,
                model_id=f"{model_id}_{model_filename}",
                device=None  # Auto-detect device
            )

            logger.info(f"Created YOLOWrapper for model {model_id}: {model_filename}")
            return yolo_model

        except Exception as e:
            logger.error(f"Error creating YOLO model for {model_id}:{model_filename}: {e}", exc_info=True)
            return None
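

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the model ID and URL below are
    # placeholders, not real endpoints, so the download is expected to fail
    # unless they are replaced with a reachable MPTA file.
    logging.basicConfig(level=logging.INFO)

    manager = ModelManager("models")
    model_dir = manager.ensure_model(
        model_id=1,
        model_url="https://example.com/models/example-1.mpta",  # placeholder URL
        model_name="example",
    )
    if model_dir:
        # Once extracted, the pipeline config and model files become available
        config = manager.load_pipeline_config(1)
        print(f"Model extracted to {model_dir}, pipeline config loaded: {config is not None}")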