python-detector-worker/core/models/manager.py
2025-09-23 17:56:40 +07:00

439 lines
No EOL
15 KiB
Python

"""
Model Manager Module - Handles MPTA download, extraction, and model loading
"""
import os
import logging
import zipfile
import json
import hashlib
import requests
from pathlib import Path
from typing import Dict, Optional, Any, Set
from threading import Lock
from urllib.parse import urlparse, parse_qs
logger = logging.getLogger(__name__)
class ModelManager:
    """Manages MPTA model downloads, extraction, and caching.

    Every model is stored under ``<models_dir>/<model_id>/``. A model counts
    as available once a ``pipeline.json`` is found either in one of the
    immediate subdirectories of that folder or directly inside it.
    """

    def __init__(self, models_dir: str = "models"):
        """
        Initialize the Model Manager

        Args:
            models_dir: Base directory for storing models
        """
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)

        # Track downloaded models to avoid duplicates
        self._downloaded_models: Set[int] = set()
        self._model_paths: Dict[int, Path] = {}
        self._download_lock = Lock()

        # Scan existing models
        self._scan_existing_models()

        logger.info(f"ModelManager initialized with models directory: {self.models_dir}")
        logger.info(f"Found existing models: {list(self._downloaded_models)}")

    @staticmethod
    def _find_pipeline_dir(model_dir: Path) -> Optional[Path]:
        """
        Locate the directory that holds pipeline.json for one model

        Immediate subdirectories are preferred; the model directory itself is
        used as a fallback. The same lookup is shared by the startup scan and
        by extraction so both agree on what an extracted model looks like.

        Args:
            model_dir: The per-model directory (``<models_dir>/<model_id>``)

        Returns:
            Directory containing pipeline.json, or None if there is none
        """
        # Sorted so the choice is deterministic when several candidates exist
        nested = sorted(model_dir.glob("*/pipeline.json"))
        if nested:
            return nested[0].parent
        if (model_dir / "pipeline.json").exists():
            return model_dir
        return None

    @staticmethod
    def _sanitize_filename_part(name: str) -> str:
        """
        Turn an arbitrary label into something usable as a single filename part

        Path separators, characters reserved on common filesystems and control
        characters are replaced by underscores; surrounding dots and spaces
        are stripped so the result can never be ``.``/``..`` or a path.

        Args:
            name: The raw label (e.g. a model name supplied by the caller)

        Returns:
            The cleaned label, possibly empty
        """
        unsafe = set('/\\:*?"<>|')
        cleaned = "".join(
            "_" if (ch in unsafe or ord(ch) < 32) else ch for ch in str(name)
        )
        return cleaned.strip(" .")

    @staticmethod
    def _remove_partial_file(path: Path) -> None:
        """
        Best-effort removal of a partially written download

        Args:
            path: File to delete; a missing file is not an error
        """
        try:
            path.unlink()
        except FileNotFoundError:
            pass
        except OSError as e:
            # Cleanup must never mask the original download failure
            logger.warning(f"Could not remove partial download {path}: {str(e)}")

    def _scan_existing_models(self) -> None:
        """Scan the models directory for existing downloaded models"""
        if not self.models_dir.exists():
            return

        for model_dir in self.models_dir.iterdir():
            # isdecimal() only accepts names that int() can actually parse
            if not (model_dir.is_dir() and model_dir.name.isdecimal()):
                continue

            # Extraction is considered successful only if pipeline.json exists
            extracted_dir = self._find_pipeline_dir(model_dir)
            if extracted_dir is None:
                continue

            model_id = int(model_dir.name)
            self._downloaded_models.add(model_id)
            # Store path to the extracted model directory
            self._model_paths[model_id] = extracted_dir
            logger.debug(f"Found existing model {model_id} at {extracted_dir}")

    def get_model_path(self, model_id: int) -> Optional[Path]:
        """
        Get the path to an extracted model directory

        Args:
            model_id: The model ID

        Returns:
            Path to the extracted model directory or None if not found
        """
        return self._model_paths.get(model_id)

    def is_model_downloaded(self, model_id: int) -> bool:
        """
        Check if a model has already been downloaded and extracted

        Args:
            model_id: The model ID to check

        Returns:
            True if the model is already available
        """
        return model_id in self._downloaded_models

    def ensure_model(self, model_id: int, model_url: str, model_name: Optional[str] = None) -> Optional[Path]:
        """
        Ensure a model is downloaded and extracted, downloading if necessary

        Args:
            model_id: The model ID
            model_url: URL to download the MPTA file from
            model_name: Optional model name, used for the fallback filename

        Returns:
            Path to the extracted model directory or None if failed
        """
        # Fast path without the lock. .get() avoids a KeyError if the model is
        # removed by cleanup_model() between the membership test and the lookup.
        existing = self._model_paths.get(model_id)
        if existing is not None and self.is_model_downloaded(model_id):
            logger.info(f"Model {model_id} already available at {existing}")
            return existing

        # Download and extract under the lock so the same model is not fetched twice
        with self._download_lock:
            # Re-check: another thread may have finished while we were waiting
            existing = self._model_paths.get(model_id)
            if existing is not None and self.is_model_downloaded(model_id):
                return existing

            logger.info(f"Model {model_id} not found locally, downloading from {model_url}")

            # Create model directory
            model_dir = self.models_dir / str(model_id)
            model_dir.mkdir(parents=True, exist_ok=True)

            # Extract filename from URL
            mpta_filename = self._extract_filename_from_url(model_url, model_name, model_id)
            mpta_path = model_dir / mpta_filename

            # Download MPTA file
            if not self._download_mpta(model_url, mpta_path):
                logger.error(f"Failed to download model {model_id}")
                return None

            # Extract MPTA file
            extracted_path = self._extract_mpta(mpta_path, model_dir)
            if not extracted_path:
                logger.error(f"Failed to extract model {model_id}")
                return None

            # Mark as downloaded and store path
            self._downloaded_models.add(model_id)
            self._model_paths[model_id] = extracted_path
            logger.info(f"Successfully prepared model {model_id} at {extracted_path}")
            return extracted_path

    def _extract_filename_from_url(self, url: str, model_name: Optional[str] = None, model_id: Optional[int] = None) -> str:
        """
        Extract a suitable filename from the URL

        The last path segment of the URL is used when it ends in ``.mpta``.
        Otherwise a name is built from ``model_name`` and ``model_id``; the
        model name is sanitised first so it cannot introduce path separators.

        Args:
            url: The URL to extract filename from
            model_name: Optional model name
            model_id: Optional model ID

        Returns:
            A suitable filename for the MPTA file
        """
        # Only the URL path is inspected; the query string is ignored
        url_path = urlparse(url).path
        if url_path:
            candidate = os.path.basename(url_path)
            if candidate and candidate.endswith('.mpta'):
                return candidate

        # Fallback to constructed name
        if model_name:
            safe_name = self._sanitize_filename_part(model_name)
            if safe_name:
                return f"{safe_name}-{model_id}.mpta"
        return f"model-{model_id}.mpta"

    def _download_mpta(self, url: str, dest_path: Path) -> bool:
        """
        Download an MPTA file from a URL

        Args:
            url: URL to download from
            dest_path: Destination path for the file

        Returns:
            True if successful, False otherwise
        """
        try:
            logger.info(f"Starting download of model from {url}")
            logger.debug(f"Download destination: {dest_path}")

            # Using the response as a context manager releases the streamed
            # connection on every exit path, including the non-200 return
            with requests.get(url, stream=True, timeout=300) as response:
                if response.status_code != 200:
                    logger.error(f"Failed to download MPTA file (status {response.status_code})")
                    return False

                # A missing or malformed header only disables progress logging
                try:
                    file_size = int(response.headers.get('content-length', 0))
                except (TypeError, ValueError):
                    file_size = 0

                if file_size > 0:
                    logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")
                else:
                    logger.info("Model file size: unknown (no usable Content-Length header)")

                downloaded = 0
                last_log_percent = 0
                with open(dest_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if not chunk:
                            continue
                        f.write(chunk)
                        downloaded += len(chunk)

                        # Log progress every 10%
                        if file_size > 0:
                            percent = int(downloaded * 100 / file_size)
                            if percent >= last_log_percent + 10:
                                logger.debug(f"Download progress: {percent}%")
                                last_log_percent = percent

            logger.info(f"Successfully downloaded MPTA file to {dest_path}")
            return True
        except requests.RequestException as e:
            logger.error(f"Network error downloading MPTA: {str(e)}", exc_info=True)
        except Exception as e:
            logger.error(f"Unexpected error downloading MPTA: {str(e)}", exc_info=True)

        # Only reached after a failure: clean up partial download
        self._remove_partial_file(dest_path)
        return False

    def _extract_mpta(self, mpta_path: Path, target_dir: Path) -> Optional[Path]:
        """
        Extract an MPTA (ZIP) file to the target directory

        Args:
            mpta_path: Path to the MPTA file
            target_dir: Directory to extract to

        Returns:
            Path to the extracted model directory containing pipeline.json, or None if failed
        """
        try:
            if not mpta_path.exists():
                logger.error(f"MPTA file not found: {mpta_path}")
                return None

            logger.info(f"Extracting MPTA file from {mpta_path} to {target_dir}")
            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                file_list = zip_ref.namelist()
                logger.debug(f"Files in MPTA archive: {len(file_list)} files")
                # Extract all files
                zip_ref.extractall(target_dir)
            logger.info(f"Successfully extracted MPTA file to {target_dir}")

            # Find the directory containing pipeline.json
            extracted_dir = self._find_pipeline_dir(target_dir)
            if extracted_dir is None:
                logger.error(f"No pipeline.json found after extraction in {target_dir}")
                return None

            if extracted_dir == target_dir:
                logger.info(f"Found pipeline.json in root of {target_dir}")
            else:
                logger.info(f"Extracted model to {extracted_dir}")

            # The MPTA archive is intentionally kept next to the extracted
            # files for reference; it could be deleted if space is a concern.
            return extracted_dir
        except zipfile.BadZipFile as e:
            logger.error(f"Invalid ZIP/MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"Failed to extract MPTA file {mpta_path}: {str(e)}", exc_info=True)
            return None

    def load_pipeline_config(self, model_id: int) -> Optional[Dict[str, Any]]:
        """
        Load the pipeline.json configuration for a model

        Args:
            model_id: The model ID

        Returns:
            The pipeline configuration dictionary or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            logger.error(f"Model {model_id} not found")
            return None

        pipeline_path = model_path / "pipeline.json"
        if not pipeline_path.exists():
            logger.error(f"pipeline.json not found for model {model_id}")
            return None

        try:
            # Explicit UTF-8 so parsing does not depend on the platform locale
            with open(pipeline_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            logger.debug(f"Loaded pipeline config for model {model_id}")
            return config
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in pipeline.json for model {model_id}: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Failed to load pipeline.json for model {model_id}: {str(e)}")
            return None

    def get_model_file_path(self, model_id: int, filename: str) -> Optional[Path]:
        """
        Get the full path to a model file (e.g., .pt file)

        Args:
            model_id: The model ID
            filename: The filename within the model directory; absolute paths
                and paths containing ``..`` are rejected

        Returns:
            Full path to the model file or None if not found
        """
        model_path = self.get_model_path(model_id)
        if not model_path:
            return None

        # The filename may originate from archive contents, so it must not be
        # able to point outside the extracted model directory
        relative = Path(filename)
        if relative.is_absolute() or ".." in relative.parts:
            logger.error(f"Refusing model file path outside model {model_id}: {filename}")
            return None

        file_path = model_path / relative
        if not file_path.exists():
            logger.error(f"Model file {filename} not found in model {model_id}")
            return None
        return file_path

    def cleanup_model(self, model_id: int) -> bool:
        """
        Remove a downloaded model to free up space

        Args:
            model_id: The model ID to remove

        Returns:
            True if successful, False otherwise
        """
        if model_id not in self._downloaded_models:
            logger.warning(f"Model {model_id} not in downloaded models")
            return False

        try:
            model_dir = self.models_dir / str(model_id)
            if model_dir.exists():
                import shutil
                shutil.rmtree(model_dir)
                logger.info(f"Removed model directory: {model_dir}")

            # Forget the model only after the files are gone (or were never there)
            self._downloaded_models.discard(model_id)
            self._model_paths.pop(model_id, None)
            return True
        except Exception as e:
            logger.error(f"Failed to cleanup model {model_id}: {str(e)}")
            return False

    def get_all_downloaded_models(self) -> Set[int]:
        """
        Get a set of all downloaded model IDs

        Returns:
            A copy of the set of model IDs that are currently downloaded
        """
        return self._downloaded_models.copy()

    def get_pipeline_config(self, model_id: int) -> Optional[Any]:
        """
        Get the parsed pipeline configuration for a model.

        Args:
            model_id: The model ID

        Returns:
            The PipelineParser instance that successfully parsed the model's
            pipeline.json, None otherwise
        """
        try:
            if model_id not in self._downloaded_models:
                logger.warning(f"Model {model_id} not downloaded")
                return None

            model_path = self._model_paths.get(model_id)
            if not model_path:
                logger.warning(f"Model path not found for model {model_id}")
                return None

            # Import here to avoid circular imports
            from .pipeline import PipelineParser

            # Load pipeline.json
            pipeline_file = model_path / "pipeline.json"
            if not pipeline_file.exists():
                logger.warning(f"No pipeline.json found for model {model_id}")
                return None

            # Create PipelineParser object and parse the configuration
            pipeline_parser = PipelineParser()
            if pipeline_parser.parse(pipeline_file):
                return pipeline_parser

            logger.error(f"Failed to parse pipeline.json for model {model_id}")
            return None
        except Exception as e:
            logger.error(f"Error getting pipeline config for model {model_id}: {e}", exc_info=True)
            return None

    def get_yolo_model(self, model_id: int, model_filename: str) -> Optional[Any]:
        """
        Create a YOLOWrapper instance for a specific model file.

        Args:
            model_id: The model ID
            model_filename: The .pt model filename

        Returns:
            YOLOWrapper instance if successful, None otherwise
        """
        try:
            # get_model_file_path() already verifies that the file exists
            model_file_path = self.get_model_file_path(model_id, model_filename)
            if model_file_path is None:
                logger.error(f"Model file {model_filename} not found for model {model_id}")
                return None

            # Import here to avoid circular imports
            from .inference import YOLOWrapper

            # Create YOLOWrapper instance
            yolo_model = YOLOWrapper(
                model_path=model_file_path,
                model_id=f"{model_id}_{model_filename}",
                device=None  # Auto-detect device
            )
            logger.info(f"Created YOLOWrapper for model {model_id}: {model_filename}")
            return yolo_model
        except Exception as e:
            logger.error(f"Error creating YOLO model for {model_id}:{model_filename}: {e}", exc_info=True)
            return None