Refactor: done phase 2
This commit is contained in:
parent
8222e82dd7
commit
aa10d5a55c
6 changed files with 1337 additions and 23 deletions
361
core/models/manager.py
Normal file
361
core/models/manager.py
Normal file
|
@ -0,0 +1,361 @@
|
|||
"""
|
||||
Model Manager Module - Handles MPTA download, extraction, and model loading
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import zipfile
|
||||
import json
|
||||
import hashlib
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Any, Set
|
||||
from threading import Lock
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelManager:
    """Manages MPTA model downloads, extraction, and caching.

    Models live under ``models_dir/<model_id>/``. An MPTA file is a ZIP
    archive; a model counts as "available" once an extracted directory
    containing ``pipeline.json`` exists under its model-id directory.
    """

    def __init__(self, models_dir: str = "models"):
        """
        Initialize the Model Manager

        Args:
            models_dir: Base directory for storing models
        """
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)

        # Track downloaded models to avoid duplicates
        self._downloaded_models: Set[int] = set()
        # model_id -> path of the extracted directory containing pipeline.json
        self._model_paths: Dict[int, Path] = {}
        # Serializes downloads so two callers can't fetch the same model at once
        self._download_lock = Lock()

        # Scan existing models
        self._scan_existing_models()

        logger.info(f"ModelManager initialized with models directory: {self.models_dir}")
        logger.info(f"Found existing models: {list(self._downloaded_models)}")

    def _scan_existing_models(self) -> None:
        """Scan the models directory for existing downloaded models"""
        if not self.models_dir.exists():
            return

        for model_dir in self.models_dir.iterdir():
            # Only numeric directory names are model-id directories
            if model_dir.is_dir() and model_dir.name.isdigit():
                model_id = int(model_dir.name)
                # Check if extraction was successful by looking for pipeline.json
                extracted_dirs = list(model_dir.glob("*/pipeline.json"))
                if extracted_dirs:
                    self._downloaded_models.add(model_id)
                    # Store path to the extracted model directory
                    self._model_paths[model_id] = extracted_dirs[0].parent
                    logger.debug(f"Found existing model {model_id} at {extracted_dirs[0].parent}")

    def get_model_path(self, model_id: int) -> Optional[Path]:
        """
        Get the path to an extracted model directory

        Args:
            model_id: The model ID

        Returns:
            Path to the extracted model directory or None if not found
        """
        return self._model_paths.get(model_id)

    def is_model_downloaded(self, model_id: int) -> bool:
        """
        Check if a model has already been downloaded and extracted

        Args:
            model_id: The model ID to check

        Returns:
            True if the model is already available
        """
        return model_id in self._downloaded_models

    def ensure_model(self, model_id: int, model_url: str, model_name: Optional[str] = None) -> Optional[Path]:
        """
        Ensure a model is downloaded and extracted, downloading if necessary

        Args:
            model_id: The model ID
            model_url: URL to download the MPTA file from
            model_name: Optional model name for logging

        Returns:
            Path to the extracted model directory or None if failed
        """
        # Check if already downloaded
        if self.is_model_downloaded(model_id):
            logger.info(f"Model {model_id} already available at {self._model_paths[model_id]}")
            return self._model_paths[model_id]

        # Download and extract with lock to prevent concurrent downloads of same model
        with self._download_lock:
            # Double-check after acquiring lock (another thread may have won the race)
            if self.is_model_downloaded(model_id):
                return self._model_paths[model_id]

            logger.info(f"Model {model_id} not found locally, downloading from {model_url}")

            # Create model directory
            model_dir = self.models_dir / str(model_id)
            model_dir.mkdir(parents=True, exist_ok=True)

            # Extract filename from URL
            mpta_filename = self._extract_filename_from_url(model_url, model_name, model_id)
            mpta_path = model_dir / mpta_filename

            # Download MPTA file
            if not self._download_mpta(model_url, mpta_path):
                logger.error(f"Failed to download model {model_id}")
                return None

            # Extract MPTA file
            extracted_path = self._extract_mpta(mpta_path, model_dir)
            if not extracted_path:
                logger.error(f"Failed to extract model {model_id}")
                return None

            # Mark as downloaded and store path
            self._downloaded_models.add(model_id)
            self._model_paths[model_id] = extracted_path

            logger.info(f"Successfully prepared model {model_id} at {extracted_path}")
            return extracted_path

    def _extract_filename_from_url(self, url: str, model_name: Optional[str] = None,
                                   model_id: Optional[int] = None) -> str:
        """
        Extract a suitable filename from the URL

        Args:
            url: The URL to extract filename from
            model_name: Optional model name
            model_id: Optional model ID

        Returns:
            A suitable filename for the MPTA file
        """
        parsed = urlparse(url)
        path = parsed.path

        # Try to get filename from path; only trust it if it looks like an MPTA file
        if path:
            filename = os.path.basename(path)
            if filename and filename.endswith('.mpta'):
                return filename

        # Fallback to constructed name
        if model_name:
            return f"{model_name}-{model_id}.mpta"
        else:
            return f"model-{model_id}.mpta"

    def _download_mpta(self, url: str, dest_path: Path) -> bool:
        """
        Download an MPTA file from a URL

        Args:
            url: URL to download from
            dest_path: Destination path for the file

        Returns:
            True if successful, False otherwise
        """

        def _cleanup_partial() -> None:
            # Remove a partially-written file so a later retry starts clean
            if dest_path.exists():
                dest_path.unlink()

        try:
            logger.info(f"Starting download of model from {url}")
            logger.debug(f"Download destination: {dest_path}")

            response = requests.get(url, stream=True, timeout=300)
            if response.status_code != 200:
                logger.error(f"Failed to download MPTA file (status {response.status_code})")
                return False

            # content-length may be absent; 0 disables progress logging below
            file_size = int(response.headers.get('content-length', 0))
            logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")

            downloaded = 0
            last_log_percent = 0

            with open(dest_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)

                        # Log progress every 10%
                        if file_size > 0:
                            percent = int(downloaded * 100 / file_size)
                            if percent >= last_log_percent + 10:
                                logger.debug(f"Download progress: {percent}%")
                                last_log_percent = percent

            logger.info(f"Successfully downloaded MPTA file to {dest_path}")
            return True

        except requests.RequestException as e:
            logger.error(f"Network error downloading MPTA: {str(e)}", exc_info=True)
            _cleanup_partial()
            return False
        except Exception as e:
            logger.error(f"Unexpected error downloading MPTA: {str(e)}", exc_info=True)
            _cleanup_partial()
            return False

    def _extract_mpta(self, mpta_path: Path, target_dir: Path) -> Optional[Path]:
        """
        Extract an MPTA (ZIP) file to the target directory

        Args:
            mpta_path: Path to the MPTA file
            target_dir: Directory to extract to

        Returns:
            Path to the extracted model directory containing pipeline.json, or None if failed
        """
        try:
            if not mpta_path.exists():
                logger.error(f"MPTA file not found: {mpta_path}")
                return None

            logger.info(f"Extracting MPTA file from {mpta_path} to {target_dir}")

            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                # Get list of files
                file_list = zip_ref.namelist()
                logger.debug(f"Files in MPTA archive: {len(file_list)} files")

                # Extract all files
                zip_ref.extractall(target_dir)

            logger.info(f"Successfully extracted MPTA file to {target_dir}")

            # Find the directory containing pipeline.json
            pipeline_files = list(target_dir.glob("*/pipeline.json"))
            if not pipeline_files:
                # Check if pipeline.json is in root
                if (target_dir / "pipeline.json").exists():
                    logger.info(f"Found pipeline.json in root of {target_dir}")
                    return target_dir
                logger.error(f"No pipeline.json found after extraction in {target_dir}")
                return None

            # Return the directory containing pipeline.json
            extracted_dir = pipeline_files[0].parent
            logger.info(f"Extracted model to {extracted_dir}")

            # Keep the MPTA file for reference but could delete if space is a concern
            # mpta_path.unlink()
            # logger.debug(f"Removed MPTA file after extraction: {mpta_path}")

            return extracted_dir

        except zipfile.BadZipFile as e:
            logger.error(f"Invalid ZIP/MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"Failed to extract MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None

    def load_pipeline_config(self, model_id: int) -> Optional[Dict[str, Any]]:
        """
        Load the pipeline.json configuration for a model

        Args:
            model_id: The model ID

        Returns:
            The pipeline configuration dictionary or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            logger.error(f"Model {model_id} not found")
            return None

        pipeline_path = model_path / "pipeline.json"
        if not pipeline_path.exists():
            logger.error(f"pipeline.json not found for model {model_id}")
            return None

        try:
            with open(pipeline_path, 'r') as f:
                config = json.load(f)
            logger.debug(f"Loaded pipeline config for model {model_id}")
            return config
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in pipeline.json for model {model_id}: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Failed to load pipeline.json for model {model_id}: {str(e)}")
            return None

    def get_model_file_path(self, model_id: int, filename: str) -> Optional[Path]:
        """
        Get the full path to a model file (e.g., .pt file)

        Args:
            model_id: The model ID
            filename: The filename within the model directory

        Returns:
            Full path to the model file or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            return None

        file_path = model_path / filename
        if not file_path.exists():
            # Fixed: the message previously logged a literal "(unknown)" instead
            # of the requested filename, making the error useless for debugging.
            logger.error(f"Model file {filename} not found in model {model_id}")
            return None

        return file_path

    def cleanup_model(self, model_id: int) -> bool:
        """
        Remove a downloaded model to free up space

        Args:
            model_id: The model ID to remove

        Returns:
            True if successful, False otherwise
        """
        if model_id not in self._downloaded_models:
            logger.warning(f"Model {model_id} not in downloaded models")
            return False

        import shutil

        try:
            model_dir = self.models_dir / str(model_id)
            if model_dir.exists():
                shutil.rmtree(model_dir)
                logger.info(f"Removed model directory: {model_dir}")

            self._downloaded_models.discard(model_id)
            self._model_paths.pop(model_id, None)
            return True

        except Exception as e:
            logger.error(f"Failed to cleanup model {model_id}: {str(e)}")
            return False

    def get_all_downloaded_models(self) -> Set[int]:
        """
        Get a set of all downloaded model IDs

        Returns:
            Set of model IDs that are currently downloaded
        """
        # Return a copy so callers can't mutate internal tracking state
        return self._downloaded_models.copy()
|
Loading…
Add table
Add a link
Reference in a new issue