Fix: WebSocket communication misunderstanding

ziesorx 2025-09-13 00:19:41 +07:00
parent 9967bff6dc
commit 42a8325faf
8 changed files with 1109 additions and 63 deletions


@@ -227,6 +227,7 @@ class ModelManager:
     async def _get_model_path(self, model_url: str, model_id: str) -> str:
         """
         Get local path for a model, downloading if necessary.
+        Uses model_id subfolder structure: models/{model_id}/

         Args:
             model_url: URL or local path to model
@@ -246,14 +247,18 @@ class ModelManager:
         if parsed.scheme == 'file':
             return parsed.path

-        # For HTTP/HTTPS URLs, download to cache
+        # For HTTP/HTTPS URLs, download to cache with model_id subfolder
         if parsed.scheme in ['http', 'https']:
+            # Create model_id subfolder structure
+            model_dir = os.path.join(self.models_dir, str(model_id))
+            os.makedirs(model_dir, exist_ok=True)
+
             # Generate cache filename
             filename = os.path.basename(parsed.path)
             if not filename:
-                filename = f"{model_id}.mpta"
+                filename = f"model_{model_id}.mpta"

-            cache_path = os.path.join(self.models_dir, filename)
+            cache_path = os.path.join(model_dir, filename)

             # Check if already cached
             if os.path.exists(cache_path):
@@ -261,7 +266,7 @@ class ModelManager:
                 return cache_path

             # Download model
-            logger.info(f"Downloading model from {model_url}")
+            logger.info(f"Downloading model {model_id} from {model_url}")
             await self._download_model(model_url, cache_path)

             return cache_path
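
Taken together, the hunks above move cached downloads from a flat models/ directory to one subfolder per model, and rename the fallback file from {model_id}.mpta to model_{model_id}.mpta. A minimal sketch of the resulting path logic (get_cache_path is a hypothetical helper, not part of this commit):

    import os
    from urllib.parse import urlparse

    def get_cache_path(models_dir: str, model_id: str, model_url: str) -> str:
        # One subfolder per model: models/{model_id}/<filename>
        model_dir = os.path.join(models_dir, str(model_id))
        os.makedirs(model_dir, exist_ok=True)
        # Fall back to model_{model_id}.mpta when the URL path ends in "/"
        filename = os.path.basename(urlparse(model_url).path) or f"model_{model_id}.mpta"
        return os.path.join(model_dir, filename)

    # e.g. -> "models/42/model_42.mpta"
    print(get_cache_path("models", "42", "https://example.com/download/"))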
@@ -270,7 +275,7 @@ class ModelManager:
     async def _download_model(self, url: str, destination: str) -> None:
         """
-        Download a model file from URL.
+        Download a model file from URL with enhanced HTTP request logging.

         Args:
             url: URL to download from
@@ -278,9 +283,20 @@ class ModelManager:
         """
         import aiohttp
         import aiofiles
+        import time
+
+        # Import HTTP logger
+        from ..utils.logging_utils import get_http_logger
+        http_logger = get_http_logger()
+
+        start_time = time.time()
+        correlation_id = None

         try:
             async with aiohttp.ClientSession() as session:
+                # Log request start
+                correlation_id = http_logger.log_request_start("GET", url)
+
                 async with session.get(url) as response:
                     response.raise_for_status()
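
The implementation of get_http_logger is not part of this diff; only its call surface is visible here (log_request_start returning a correlation ID, plus log_download_progress and log_request_end). A rough sketch of a logger matching those calls, assuming a short uuid-based correlation ID:

    import logging
    import uuid

    logger = logging.getLogger("http")

    class HTTPLogger:
        def log_request_start(self, method: str, url: str) -> str:
            # Tag the request so later progress/end records can be tied back to it
            correlation_id = uuid.uuid4().hex[:8]
            logger.info(f"[{correlation_id}] {method} {url} started")
            return correlation_id

        def log_download_progress(self, downloaded: int, total: int,
                                  percent: float, correlation_id: str) -> None:
            logger.info(f"[{correlation_id}] {downloaded}/{total} bytes ({percent:.1f}%)")

        def log_request_end(self, status: int, size, duration_ms: float,
                            correlation_id: str) -> None:
            logger.info(f"[{correlation_id}] done: status={status} bytes={size} "
                        f"({duration_ms:.0f} ms)")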
@@ -293,22 +309,39 @@ class ModelManager:
                     # Download to temporary file first
                     temp_path = f"{destination}.tmp"
                     downloaded = 0
+                    last_progress_log = 0

                     async with aiofiles.open(temp_path, 'wb') as f:
                         async for chunk in response.content.iter_chunked(8192):
                             await f.write(chunk)
                             downloaded += len(chunk)

-                            # Log progress
-                            if total_size and downloaded % (1024 * 1024) == 0:
+                            # Log progress at 10% intervals
+                            if total_size and downloaded > 0:
                                 progress = (downloaded / total_size) * 100
-                                logger.info(f"Download progress: {progress:.1f}%")
+                                if progress >= last_progress_log + 10 and progress <= 100:
+                                    logger.info(f"Download progress: {progress:.1f}%")
+                                    http_logger.log_download_progress(
+                                        downloaded, total_size, progress, correlation_id
+                                    )
+                                    last_progress_log = progress

                     # Move to final destination
                     os.rename(temp_path, destination)

+                    # Log successful completion
+                    duration_ms = (time.time() - start_time) * 1000
+                    http_logger.log_request_end(
+                        response.status, downloaded, duration_ms, correlation_id
+                    )
+
             logger.info(f"Model downloaded successfully to {destination}")

         except Exception as e:
+            # Log failed completion
+            if correlation_id:
+                duration_ms = (time.time() - start_time) * 1000
+                http_logger.log_request_end(500, None, duration_ms, correlation_id)
+
             # Clean up temporary file if exists
             temp_path = f"{destination}.tmp"
             if os.path.exists(temp_path):
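
The old check (downloaded % (1024 * 1024) == 0) only fired when the running byte count landed exactly on a 1 MiB boundary, which iter_chunked does not guarantee, and would have logged every megabyte on large files. The new version logs at roughly 10-point steps; note that last_progress_log stores the exact percentage, so messages land about 10 points apart (e.g. 10.3%, 20.5%) rather than exactly at 10%, 20%, and so on. The throttling logic in isolation (progress_update is a hypothetical name):

    from typing import Optional, Tuple

    def progress_update(downloaded: int, total_size: int,
                        last_logged: float) -> Tuple[Optional[float], float]:
        # Returns (percent_to_log, new_last_logged); percent_to_log is None
        # when this update should stay quiet
        if not total_size or downloaded <= 0:
            return None, last_logged
        percent = (downloaded / total_size) * 100
        if last_logged + 10 <= percent <= 100:
            return percent, percent  # stores the exact value, not a rounded decade
        return None, last_logged

    # 300 KiB of a 1 MiB file -> logs ~29.3% on the first call
    print(progress_update(300 * 1024, 1024 * 1024, 0))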


@@ -10,6 +10,7 @@ import logging
 import zipfile
 import tempfile
 import shutil
+import traceback
 from typing import Dict, Any, Optional, List, Tuple
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -113,36 +114,85 @@ class PipelineLoader:
             ModelLoadError: If loading fails
         """
         try:
+            logger.info(f"🔍 Loading pipeline from MPTA file: {mpta_path}")
+
+            # Verify MPTA file exists
+            if not os.path.exists(mpta_path):
+                raise ModelLoadError(f"MPTA file not found: {mpta_path}")
+
+            # Check if it's actually a zip file
+            if not zipfile.is_zipfile(mpta_path):
+                raise ModelLoadError(f"File is not a valid ZIP/MPTA archive: {mpta_path}")
+
             # Extract MPTA if not already extracted
             extracted_dir = await self._extract_mpta(mpta_path)
+            logger.info(f"📂 MPTA extracted to: {extracted_dir}")
+
+            # List contents of extracted directory for debugging
+            if os.path.exists(extracted_dir):
+                contents = os.listdir(extracted_dir)
+                logger.info(f"📋 Extracted contents: {contents}")
+            else:
+                raise ModelLoadError(f"Extraction failed - directory not found: {extracted_dir}")

             # Load pipeline configuration
-            pipeline_json_path = os.path.join(extracted_dir, "pipeline.json")
-            if not os.path.exists(pipeline_json_path):
-                raise ModelLoadError(f"pipeline.json not found in {mpta_path}")
+            # First check if pipeline.json exists in a subdirectory (most common case)
+            pipeline_json_path = None
+            logger.info(f"🔍 Looking for pipeline.json in extracted directory: {extracted_dir}")
+
+            # Look for pipeline.json in subdirectories first (common case)
+            for root, _, files in os.walk(extracted_dir):
+                if "pipeline.json" in files:
+                    pipeline_json_path = os.path.join(root, "pipeline.json")
+                    logger.info(f"✅ Found pipeline.json at: {pipeline_json_path}")
+                    break
+
+            # If not found in subdirectories, try root level
+            if not pipeline_json_path:
+                root_pipeline_json = os.path.join(extracted_dir, "pipeline.json")
+                if os.path.exists(root_pipeline_json):
+                    pipeline_json_path = root_pipeline_json
+                    logger.info(f"✅ Found pipeline.json at root: {pipeline_json_path}")
+
+            if not pipeline_json_path:
+                # List all files in extracted directory for debugging
+                all_files = []
+                for root, _, files in os.walk(extracted_dir):
+                    for file in files:
+                        all_files.append(os.path.join(root, file))
+                raise ModelLoadError(f"pipeline.json not found in extracted MPTA. "
+                                     f"Extracted to: {extracted_dir}. "
+                                     f"Files found: {all_files}")

             with open(pipeline_json_path, 'r') as f:
                 config_data = json.load(f)

-            # Parse pipeline configuration
-            pipeline_config = self._parse_pipeline_config(config_data, extracted_dir)
+            logger.info(f"📋 Pipeline config loaded from: {pipeline_json_path}")
+
+            # Parse pipeline configuration (use extracted directory as base)
+            base_dir = os.path.dirname(pipeline_json_path)
+            pipeline_config = self._parse_pipeline_config(config_data, base_dir)

             # Validate pipeline
             self._validate_pipeline(pipeline_config)

             # Load models for the pipeline
-            await self._load_pipeline_models(pipeline_config.root, extracted_dir)
+            await self._load_pipeline_models(pipeline_config.root, base_dir)

             logger.info(f"Successfully loaded pipeline from {mpta_path}")
             return pipeline_config.root

         except Exception as e:
-            logger.error(f"Failed to load pipeline from {mpta_path}: {e}")
+            logger.error(f"❌ Failed to load pipeline from {mpta_path}: {e}")
+            traceback.print_exc()
             raise ModelLoadError(f"Failed to load pipeline: {e}")

     async def _extract_mpta(self, mpta_path: str) -> str:
         """
-        Extract MPTA file to temporary directory.
+        Extract MPTA file to model_id based directory structure.
+        For models/{model_id}/ structure, extracts to the same directory as the MPTA file.

         Args:
             mpta_path: Path to MPTA file
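
One observation on the search order above: os.walk yields the top-level directory on its first iteration, so the loop labelled "subdirectories first" already checks the extraction root, and the separate root-level fallback can never find anything the walk missed. The discovery logic condensed into a standalone sketch (find_pipeline_json is a hypothetical name):

    import os
    from typing import Optional

    def find_pipeline_json(extracted_dir: str) -> Optional[str]:
        # os.walk visits extracted_dir itself first, then nested folders,
        # so a single walk covers both flat and wrapped ZIP layouts
        for root, _, files in os.walk(extracted_dir):
            if "pipeline.json" in files:
                return os.path.join(root, "pipeline.json")
        return None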
@@ -156,28 +206,74 @@
         if os.path.exists(extracted_dir):
             return extracted_dir

-        # Create extraction directory
+        # Determine extraction directory
+        # If MPTA is in models/{model_id}/ structure, extract there
+        # Otherwise use temporary directory
+        mpta_dir = os.path.dirname(mpta_path)
         mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
-        extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")
+
+        # Check if this is in models/{model_id}/ structure
+        if "models/" in mpta_dir and mpta_dir.count("/") >= 1:
+            # Extract directly to the models/{model_id}/ directory
+            extracted_dir = mpta_dir  # Extract directly where the MPTA file is
+        else:
+            # Use temporary directory for non-model files
+            extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")

         # Extract MPTA
-        logger.info(f"Extracting MPTA file: {mpta_path}")
+        logger.info(f"📦 Extracting MPTA file: {mpta_path}")
+        logger.info(f"📂 Extraction target: {extracted_dir}")

         try:
+            # Verify it's a valid zip file before extracting
             with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
-                # Clean existing directory if exists
-                if os.path.exists(extracted_dir):
-                    shutil.rmtree(extracted_dir)
-                os.makedirs(extracted_dir)
+                # List contents for debugging
+                file_list = zip_ref.namelist()
+                logger.info(f"📋 ZIP file contents ({len(file_list)} files): {file_list[:10]}{'...' if len(file_list) > 10 else ''}")
+
+                # For models/{model_id}/ structure, only clean extracted contents, not the MPTA file
+                if "models/" in extracted_dir and mpta_path.startswith(extracted_dir):
+                    # Clean only the extracted subdirectories, keep the MPTA file
+                    for item in os.listdir(extracted_dir):
+                        item_path = os.path.join(extracted_dir, item)
+                        if os.path.isdir(item_path):
+                            logger.info(f"🧹 Cleaning existing extracted directory: {item_path}")
+                            shutil.rmtree(item_path)
+                        elif not item.endswith('.mpta'):
+                            # Remove non-MPTA files that might be leftover extracts
+                            logger.info(f"🧹 Cleaning leftover file: {item_path}")
+                            os.remove(item_path)
+                else:
+                    # For temp directories, clean everything
+                    if os.path.exists(extracted_dir):
+                        logger.info(f"🧹 Cleaning existing extraction directory: {extracted_dir}")
+                        shutil.rmtree(extracted_dir)
+                    os.makedirs(extracted_dir, exist_ok=True)
+
+                # Extract all files
+                logger.info(f"📤 Extracting {len(file_list)} files...")
                 zip_ref.extractall(extracted_dir)

+                # Verify extraction worked
+                extracted_files = []
+                for root, dirs, files in os.walk(extracted_dir):
+                    for file in files:
+                        extracted_files.append(os.path.join(root, file))
+
+                logger.info(f"✅ Extraction completed - {len(extracted_files)} files extracted")
+                logger.info(f"📋 Sample extracted files: {extracted_files[:5]}{'...' if len(extracted_files) > 5 else ''}")
+
             self.extracted_paths[mpta_path] = extracted_dir
-            logger.info(f"Extracted to: {extracted_dir}")
+            logger.info(f"✅ MPTA successfully extracted to: {extracted_dir}")

             return extracted_dir

+        except zipfile.BadZipFile as e:
+            logger.error(f"❌ Invalid ZIP file: {mpta_path}")
+            raise ModelLoadError(f"Invalid ZIP/MPTA file: {e}")
         except Exception as e:
+            logger.error(f"❌ Failed to extract MPTA: {e}")
             raise ModelLoadError(f"Failed to extract MPTA: {e}")

     def _parse_pipeline_config(
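
The extraction-directory decision reads more clearly in isolation. Two caveats worth noting: the "models/" substring test is separator-sensitive (a Windows-style backslash path would not match), and the extra mpta_dir.count("/") >= 1 condition is already implied by the substring test, so it adds nothing. A compact sketch (choose_extraction_dir is a hypothetical helper):

    import os

    def choose_extraction_dir(mpta_path: str, temp_dir: str) -> str:
        mpta_dir = os.path.dirname(mpta_path)
        mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
        if "models/" in mpta_dir:
            # In-place: unpack next to the .mpta inside models/{model_id}/
            return mpta_dir
        # Anything else goes to a throwaway temp location
        return os.path.join(temp_dir, f"mpta_{mpta_name}")

    print(choose_extraction_dir("models/42/model_42.mpta", "/tmp"))  # models/42
    print(choose_extraction_dir("downloads/model.mpta", "/tmp"))     # /tmp/mpta_model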
@@ -328,9 +424,10 @@ class PipelineLoader:
         if node.model_path and not os.path.exists(node.model_path):
             raise PipelineError(f"Model file not found: {node.model_path}")

-        # Validate cropping configuration
+        # Validate cropping configuration - be more lenient for backward compatibility
         if node.crop and not node.crop_class:
-            raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass")
+            logger.warning(f"Node {node.model_id} has crop=true but no cropClass - will disable cropping")
+            node.crop = False  # Disable cropping instead of failing

         # Validate confidence
         if not 0 <= node.min_confidence <= 1:
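
The validation change swaps fail-fast for degrade-and-warn: an inconsistent crop configuration is corrected in place instead of aborting the load. The pattern in isolation, assuming node is any object with crop, crop_class, and model_id attributes:

    import logging

    logger = logging.getLogger(__name__)

    def validate_crop(node) -> None:
        # Lenient check: disable cropping rather than raising, so older
        # pipelines that predate cropClass keep loading
        if node.crop and not node.crop_class:
            logger.warning(f"Node {node.model_id} has crop=true but no cropClass "
                           f"- will disable cropping")
            node.crop = False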