Fix: Websocket communication misunderstanding error
commit 42a8325faf (parent 9967bff6dc)
8 changed files with 1109 additions and 63 deletions
```diff
@@ -227,6 +227,7 @@ class ModelManager:
     async def _get_model_path(self, model_url: str, model_id: str) -> str:
         """
         Get local path for a model, downloading if necessary.
+        Uses model_id subfolder structure: models/{model_id}/
 
         Args:
             model_url: URL or local path to model
```
```diff
@@ -246,14 +247,18 @@ class ModelManager:
         if parsed.scheme == 'file':
             return parsed.path
 
-        # For HTTP/HTTPS URLs, download to cache
+        # For HTTP/HTTPS URLs, download to cache with model_id subfolder
         if parsed.scheme in ['http', 'https']:
+            # Create model_id subfolder structure
+            model_dir = os.path.join(self.models_dir, str(model_id))
+            os.makedirs(model_dir, exist_ok=True)
+
             # Generate cache filename
             filename = os.path.basename(parsed.path)
             if not filename:
-                filename = f"{model_id}.mpta"
+                filename = f"model_{model_id}.mpta"
 
-            cache_path = os.path.join(self.models_dir, filename)
+            cache_path = os.path.join(model_dir, filename)
 
             # Check if already cached
             if os.path.exists(cache_path):
```
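With this change every model gets its own cache directory keyed by `model_id`, so two models that happen to ship a file with the same name can no longer overwrite each other's download. A minimal standalone sketch of the new path logic (the function name and `models_dir` parameter are illustrative, not part of the diff):

```python
import os
from urllib.parse import urlparse

def cache_path_for(models_dir: str, model_url: str, model_id: str) -> str:
    """Mirror the diff's layout: models/{model_id}/{filename}."""
    model_dir = os.path.join(models_dir, str(model_id))
    os.makedirs(model_dir, exist_ok=True)
    # Fall back to a synthetic name when the URL path carries no filename.
    filename = os.path.basename(urlparse(model_url).path) or f"model_{model_id}.mpta"
    return os.path.join(model_dir, filename)

# cache_path_for("models", "https://example.com/weights/detector.mpta", "52")
# -> "models/52/detector.mpta"
```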
```diff
@@ -261,7 +266,7 @@ class ModelManager:
                 return cache_path
 
             # Download model
-            logger.info(f"Downloading model from {model_url}")
+            logger.info(f"Downloading model {model_id} from {model_url}")
             await self._download_model(model_url, cache_path)
             return cache_path
 
```
```diff
@@ -270,7 +275,7 @@ class ModelManager:
 
     async def _download_model(self, url: str, destination: str) -> None:
         """
-        Download a model file from URL.
+        Download a model file from URL with enhanced HTTP request logging.
 
         Args:
             url: URL to download from
```
```diff
@@ -278,9 +283,20 @@ class ModelManager:
         """
         import aiohttp
        import aiofiles
+        import time
+
+        # Import HTTP logger
+        from ..utils.logging_utils import get_http_logger
+        http_logger = get_http_logger()
+
+        start_time = time.time()
+        correlation_id = None
 
         try:
             async with aiohttp.ClientSession() as session:
+                # Log request start
+                correlation_id = http_logger.log_request_start("GET", url)
+
                 async with session.get(url) as response:
                     response.raise_for_status()
 
```
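The download now reports to an HTTP logger imported from `..utils.logging_utils`, whose implementation is not part of this diff. Judging only from the three calls made here (`log_request_start`, `log_download_progress`, `log_request_end`), a minimal object satisfying that interface might look like this hypothetical sketch:

```python
import logging
import uuid

logger = logging.getLogger("http")

class HTTPLogger:
    """Hypothetical stand-in for utils.logging_utils.get_http_logger()."""

    def log_request_start(self, method: str, url: str) -> str:
        # A short random id ties all log lines of one request together.
        correlation_id = uuid.uuid4().hex[:8]
        logger.info("[%s] %s %s started", correlation_id, method, url)
        return correlation_id

    def log_download_progress(self, downloaded: int, total: int,
                              percent: float, correlation_id: str) -> None:
        logger.info("[%s] %d/%d bytes (%.1f%%)",
                    correlation_id, downloaded, total, percent)

    def log_request_end(self, status: int, size, duration_ms: float,
                        correlation_id: str) -> None:
        logger.info("[%s] finished: status=%s bytes=%s duration=%.0fms",
                    correlation_id, status, size, duration_ms)
```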
```diff
@@ -293,22 +309,39 @@ class ModelManager:
                     # Download to temporary file first
                     temp_path = f"{destination}.tmp"
                     downloaded = 0
+                    last_progress_log = 0
 
                     async with aiofiles.open(temp_path, 'wb') as f:
                         async for chunk in response.content.iter_chunked(8192):
                             await f.write(chunk)
                             downloaded += len(chunk)
 
-                            # Log progress
-                            if total_size and downloaded % (1024 * 1024) == 0:
+                            # Log progress at 10% intervals
+                            if total_size and downloaded > 0:
                                 progress = (downloaded / total_size) * 100
-                                logger.info(f"Download progress: {progress:.1f}%")
+                                if progress >= last_progress_log + 10 and progress <= 100:
+                                    logger.info(f"Download progress: {progress:.1f}%")
+                                    http_logger.log_download_progress(
+                                        downloaded, total_size, progress, correlation_id
+                                    )
+                                    last_progress_log = progress
 
                     # Move to final destination
                     os.rename(temp_path, destination)
 
+                    # Log successful completion
+                    duration_ms = (time.time() - start_time) * 1000
+                    http_logger.log_request_end(
+                        response.status, downloaded, duration_ms, correlation_id
+                    )
                     logger.info(f"Model downloaded successfully to {destination}")
 
+        except Exception as e:
+            # Log failed completion
+            if correlation_id:
+                duration_ms = (time.time() - start_time) * 1000
+                http_logger.log_request_end(500, None, duration_ms, correlation_id)
+
             # Clean up temporary file if exists
             temp_path = f"{destination}.tmp"
             if os.path.exists(temp_path):
```
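The old progress condition, `downloaded % (1024 * 1024) == 0`, relies on the running byte total landing exactly on a 1 MiB boundary, which only holds while every chunk arrives at the full 8192 bytes; one short chunk and progress is never logged again. The replacement throttles output to roughly 10% steps instead. The throttle can be exercised in isolation (a sketch driven by a fake chunk stream, not the repo's code):

```python
def progress_steps(total_size: int, chunk_sizes) -> list:
    """Replay the 10%-interval throttle over a sequence of chunk lengths."""
    downloaded = 0
    last_progress_log = 0
    logged = []
    for chunk_len in chunk_sizes:
        downloaded += chunk_len
        if total_size and downloaded > 0:
            progress = (downloaded / total_size) * 100
            if last_progress_log + 10 <= progress <= 100:
                logged.append(round(progress, 1))
                last_progress_log = progress
    return logged

# progress_steps(1000, [100] * 10) -> [10.0, 20.0, ..., 100.0]
```

Note also that the download still stages into `{destination}.tmp` and promotes it with `os.rename`, which is atomic on POSIX filesystems, so an interrupted download never leaves a partial file at the final path. The remaining hunks touch the pipeline loader (`class PipelineLoader`), where the same commit hardens MPTA extraction and pipeline.json discovery.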
```diff
@@ -10,6 +10,7 @@ import logging
 import zipfile
 import tempfile
 import shutil
+import traceback
 from typing import Dict, Any, Optional, List, Tuple
 from dataclasses import dataclass, field
 from pathlib import Path
```
```diff
@@ -113,36 +114,85 @@ class PipelineLoader:
             ModelLoadError: If loading fails
         """
         try:
+            logger.info(f"🔍 Loading pipeline from MPTA file: {mpta_path}")
+
+            # Verify MPTA file exists
+            if not os.path.exists(mpta_path):
+                raise ModelLoadError(f"MPTA file not found: {mpta_path}")
+
+            # Check if it's actually a zip file
+            if not zipfile.is_zipfile(mpta_path):
+                raise ModelLoadError(f"File is not a valid ZIP/MPTA archive: {mpta_path}")
+
             # Extract MPTA if not already extracted
             extracted_dir = await self._extract_mpta(mpta_path)
+            logger.info(f"📂 MPTA extracted to: {extracted_dir}")
+
+            # List contents of extracted directory for debugging
+            if os.path.exists(extracted_dir):
+                contents = os.listdir(extracted_dir)
+                logger.info(f"📋 Extracted contents: {contents}")
+            else:
+                raise ModelLoadError(f"Extraction failed - directory not found: {extracted_dir}")
 
-            # Load pipeline configuration
-            pipeline_json_path = os.path.join(extracted_dir, "pipeline.json")
-            if not os.path.exists(pipeline_json_path):
-                raise ModelLoadError(f"pipeline.json not found in {mpta_path}")
+            # First check if pipeline.json exists in a subdirectory (most common case)
+            pipeline_json_path = None
+
+            logger.info(f"🔍 Looking for pipeline.json in extracted directory: {extracted_dir}")
+
+            # Look for pipeline.json in subdirectories first (common case)
+            for root, _, files in os.walk(extracted_dir):
+                if "pipeline.json" in files:
+                    pipeline_json_path = os.path.join(root, "pipeline.json")
+                    logger.info(f"✅ Found pipeline.json at: {pipeline_json_path}")
+                    break
+
+            # If not found in subdirectories, try root level
+            if not pipeline_json_path:
+                root_pipeline_json = os.path.join(extracted_dir, "pipeline.json")
+                if os.path.exists(root_pipeline_json):
+                    pipeline_json_path = root_pipeline_json
+                    logger.info(f"✅ Found pipeline.json at root: {pipeline_json_path}")
+
+            if not pipeline_json_path:
+                # List all files in extracted directory for debugging
+                all_files = []
+                for root, _, files in os.walk(extracted_dir):
+                    for file in files:
+                        all_files.append(os.path.join(root, file))
+
+                raise ModelLoadError(f"pipeline.json not found in extracted MPTA. "
+                                     f"Extracted to: {extracted_dir}. "
+                                     f"Files found: {all_files}")
 
             with open(pipeline_json_path, 'r') as f:
                 config_data = json.load(f)
+            logger.info(f"📋 Pipeline config loaded from: {pipeline_json_path}")
 
-            # Parse pipeline configuration
-            pipeline_config = self._parse_pipeline_config(config_data, extracted_dir)
+            # Parse pipeline configuration (use extracted directory as base)
+            base_dir = os.path.dirname(pipeline_json_path)
+            pipeline_config = self._parse_pipeline_config(config_data, base_dir)
 
             # Validate pipeline
             self._validate_pipeline(pipeline_config)
 
             # Load models for the pipeline
-            await self._load_pipeline_models(pipeline_config.root, extracted_dir)
+            await self._load_pipeline_models(pipeline_config.root, base_dir)
 
-            logger.info(f"Successfully loaded pipeline from {mpta_path}")
+            logger.info(f"✅ Successfully loaded pipeline from {mpta_path}")
             return pipeline_config.root
 
         except Exception as e:
-            logger.error(f"Failed to load pipeline from {mpta_path}: {e}")
+            logger.error(f"❌ Failed to load pipeline from {mpta_path}: {e}")
+            traceback.print_exc()
             raise ModelLoadError(f"Failed to load pipeline: {e}")
 
     async def _extract_mpta(self, mpta_path: str) -> str:
         """
-        Extract MPTA file to temporary directory.
+        Extract MPTA file to model_id based directory structure.
+        For models/{model_id}/ structure, extracts to the same directory as the MPTA file.
 
         Args:
             mpta_path: Path to MPTA file
```
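Two fixes ride together here. The loader no longer assumes `pipeline.json` sits at the extraction root, and the actual bug fix is `base_dir = os.path.dirname(pipeline_json_path)`: relative model paths in pipeline.json now resolve next to the file that names them, not at the extraction root. One detail worth noting: `os.walk` yields the top-level directory itself on its first iteration, so the walk already covers the root and the explicit root-level fallback is effectively redundant, though harmless. The discovery logic condenses to (illustrative helper, not in the diff):

```python
import os
from typing import Optional

def find_pipeline_json(extracted_dir: str) -> Optional[str]:
    # os.walk visits extracted_dir itself first, then its subdirectories,
    # so a single loop covers both the root level and nested layouts.
    for root, _, files in os.walk(extracted_dir):
        if "pipeline.json" in files:
            return os.path.join(root, "pipeline.json")
    return None
```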
```diff
@@ -156,28 +206,74 @@ class PipelineLoader:
         if os.path.exists(extracted_dir):
             return extracted_dir
 
-        # Create extraction directory
+        # Determine extraction directory
+        # If MPTA is in models/{model_id}/ structure, extract there
+        # Otherwise use temporary directory
+        mpta_dir = os.path.dirname(mpta_path)
         mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
-        extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")
+
+        # Check if this is in models/{model_id}/ structure
+        if "models/" in mpta_dir and mpta_dir.count("/") >= 1:
+            # Extract directly to the models/{model_id}/ directory
+            extracted_dir = mpta_dir  # Extract directly where the MPTA file is
+        else:
+            # Use temporary directory for non-model files
+            extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")
 
         # Extract MPTA
-        logger.info(f"Extracting MPTA file: {mpta_path}")
+        logger.info(f"📦 Extracting MPTA file: {mpta_path}")
+        logger.info(f"📂 Extraction target: {extracted_dir}")
 
         try:
+            # Verify it's a valid zip file before extracting
             with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
-                # Clean existing directory if exists
-                if os.path.exists(extracted_dir):
-                    shutil.rmtree(extracted_dir)
-
-                os.makedirs(extracted_dir)
+                # List contents for debugging
+                file_list = zip_ref.namelist()
+                logger.info(f"📋 ZIP file contents ({len(file_list)} files): {file_list[:10]}{'...' if len(file_list) > 10 else ''}")
+
+                # For models/{model_id}/ structure, only clean extracted contents, not the MPTA file
+                if "models/" in extracted_dir and mpta_path.startswith(extracted_dir):
+                    # Clean only the extracted subdirectories, keep the MPTA file
+                    for item in os.listdir(extracted_dir):
+                        item_path = os.path.join(extracted_dir, item)
+                        if os.path.isdir(item_path):
+                            logger.info(f"🧹 Cleaning existing extracted directory: {item_path}")
+                            shutil.rmtree(item_path)
+                        elif not item.endswith('.mpta'):
+                            # Remove non-MPTA files that might be leftover extracts
+                            logger.info(f"🧹 Cleaning leftover file: {item_path}")
+                            os.remove(item_path)
+                else:
+                    # For temp directories, clean everything
+                    if os.path.exists(extracted_dir):
+                        logger.info(f"🧹 Cleaning existing extraction directory: {extracted_dir}")
+                        shutil.rmtree(extracted_dir)
+
+                os.makedirs(extracted_dir, exist_ok=True)
+
+                # Extract all files
+                logger.info(f"📤 Extracting {len(file_list)} files...")
                 zip_ref.extractall(extracted_dir)
 
+                # Verify extraction worked
+                extracted_files = []
+                for root, dirs, files in os.walk(extracted_dir):
+                    for file in files:
+                        extracted_files.append(os.path.join(root, file))
+
+                logger.info(f"✅ Extraction completed - {len(extracted_files)} files extracted")
+                logger.info(f"📋 Sample extracted files: {extracted_files[:5]}{'...' if len(extracted_files) > 5 else ''}")
+
                 self.extracted_paths[mpta_path] = extracted_dir
-                logger.info(f"Extracted to: {extracted_dir}")
+                logger.info(f"✅ MPTA successfully extracted to: {extracted_dir}")
 
                 return extracted_dir
 
+        except zipfile.BadZipFile as e:
+            logger.error(f"❌ Invalid ZIP file: {mpta_path}")
+            raise ModelLoadError(f"Invalid ZIP/MPTA file: {e}")
         except Exception as e:
+            logger.error(f"❌ Failed to extract MPTA: {e}")
             raise ModelLoadError(f"Failed to extract MPTA: {e}")
 
     def _parse_pipeline_config(
```
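The `"models/" in mpta_dir` test is a substring heuristic: it also fires on any path that merely contains those characters (for example `/data/old_models/52`) and never fires on Windows-style backslash paths. A stricter version of the same intent, as a sketch using `pathlib` (hypothetical helper, not part of this commit):

```python
from pathlib import Path

def is_model_layout(mpta_path: str) -> bool:
    """True when the MPTA file sits directly in .../models/<model_id>/."""
    parts = Path(mpta_path).resolve().parent.parts
    return len(parts) >= 2 and parts[-2] == "models"

# is_model_layout("models/52/detector.mpta")      -> True
# is_model_layout("/tmp/downloads/detector.mpta") -> False
```

Extracting in place also explains the careful cleanup above: the `.mpta` archive lives in the very directory being extracted into, so the old blanket `shutil.rmtree` would have deleted the archive itself mid-run.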
```diff
@@ -328,9 +424,10 @@ class PipelineLoader:
         if node.model_path and not os.path.exists(node.model_path):
             raise PipelineError(f"Model file not found: {node.model_path}")
 
-        # Validate cropping configuration
+        # Validate cropping configuration - be more lenient for backward compatibility
         if node.crop and not node.crop_class:
-            raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass")
+            logger.warning(f"Node {node.model_id} has crop=true but no cropClass - will disable cropping")
+            node.crop = False  # Disable cropping instead of failing
 
         # Validate confidence
         if not 0 <= node.min_confidence <= 1:
```
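The last hunk swaps a hard failure for warn-and-degrade: pipeline configs written before `cropClass` existed keep loading, with cropping disabled instead of the whole pipeline being rejected. The pattern in isolation (standalone sketch; the real node dataclass has more fields than shown):

```python
import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)

@dataclass
class Node:
    model_id: str
    crop: bool = False
    crop_class: Optional[str] = None
    min_confidence: float = 0.5

def validate(node: Node) -> None:
    # Recoverable misconfiguration: warn and fall back.
    if node.crop and not node.crop_class:
        logger.warning("Node %s has crop=true but no cropClass - disabling cropping",
                       node.model_id)
        node.crop = False
    # Unrecoverable misconfiguration: still fail fast.
    if not 0 <= node.min_confidence <= 1:
        raise ValueError(f"Node {node.model_id}: min_confidence must be in [0, 1]")
```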