Fix: WebSocket communication misunderstanding error

This commit is contained in:
ziesorx 2025-09-13 00:19:41 +07:00
parent 9967bff6dc
commit 42a8325faf
8 changed files with 1109 additions and 63 deletions

View file

@@ -10,6 +10,7 @@ import logging
import zipfile
import tempfile
import shutil
import traceback
from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass, field
from pathlib import Path
@@ -113,36 +114,85 @@ class PipelineLoader:
ModelLoadError: If loading fails
"""
try:
logger.info(f"🔍 Loading pipeline from MPTA file: {mpta_path}")
# Verify MPTA file exists
if not os.path.exists(mpta_path):
raise ModelLoadError(f"MPTA file not found: {mpta_path}")
# Check if it's actually a zip file
if not zipfile.is_zipfile(mpta_path):
raise ModelLoadError(f"File is not a valid ZIP/MPTA archive: {mpta_path}")
# Extract MPTA if not already extracted
extracted_dir = await self._extract_mpta(mpta_path)
logger.info(f"📂 MPTA extracted to: {extracted_dir}")
# List contents of extracted directory for debugging
if os.path.exists(extracted_dir):
contents = os.listdir(extracted_dir)
logger.info(f"📋 Extracted contents: {contents}")
else:
raise ModelLoadError(f"Extraction failed - directory not found: {extracted_dir}")
# Load pipeline configuration
pipeline_json_path = os.path.join(extracted_dir, "pipeline.json")
if not os.path.exists(pipeline_json_path):
raise ModelLoadError(f"pipeline.json not found in {mpta_path}")
# First check if pipeline.json exists in a subdirectory (most common case)
pipeline_json_path = None
logger.info(f"🔍 Looking for pipeline.json in extracted directory: {extracted_dir}")
# Look for pipeline.json in subdirectories first (common case)
for root, _, files in os.walk(extracted_dir):
if "pipeline.json" in files:
pipeline_json_path = os.path.join(root, "pipeline.json")
logger.info(f"✅ Found pipeline.json at: {pipeline_json_path}")
break
# If not found in subdirectories, try root level
if not pipeline_json_path:
root_pipeline_json = os.path.join(extracted_dir, "pipeline.json")
if os.path.exists(root_pipeline_json):
pipeline_json_path = root_pipeline_json
logger.info(f"✅ Found pipeline.json at root: {pipeline_json_path}")
if not pipeline_json_path:
# List all files in extracted directory for debugging
all_files = []
for root, _, files in os.walk(extracted_dir):
for file in files:
all_files.append(os.path.join(root, file))
raise ModelLoadError(f"pipeline.json not found in extracted MPTA. "
f"Extracted to: {extracted_dir}. "
f"Files found: {all_files}")
with open(pipeline_json_path, 'r') as f:
config_data = json.load(f)
# Parse pipeline configuration
pipeline_config = self._parse_pipeline_config(config_data, extracted_dir)
logger.info(f"📋 Pipeline config loaded from: {pipeline_json_path}")
# Parse pipeline configuration (use extracted directory as base)
base_dir = os.path.dirname(pipeline_json_path)
pipeline_config = self._parse_pipeline_config(config_data, base_dir)
# Validate pipeline
self._validate_pipeline(pipeline_config)
# Load models for the pipeline
await self._load_pipeline_models(pipeline_config.root, extracted_dir)
await self._load_pipeline_models(pipeline_config.root, base_dir)
logger.info(f"Successfully loaded pipeline from {mpta_path}")
logger.info(f"Successfully loaded pipeline from {mpta_path}")
return pipeline_config.root
except Exception as e:
logger.error(f"Failed to load pipeline from {mpta_path}: {e}")
logger.error(f"❌ Failed to load pipeline from {mpta_path}: {e}")
traceback.print_exc()
raise ModelLoadError(f"Failed to load pipeline: {e}")
async def _extract_mpta(self, mpta_path: str) -> str:
"""
Extract MPTA file to temporary directory.
Extract MPTA file to model_id based directory structure.
For models/{model_id}/ structure, extracts to the same directory as the MPTA file.
Args:
mpta_path: Path to MPTA file
@@ -156,28 +206,74 @@ class PipelineLoader:
if os.path.exists(extracted_dir):
return extracted_dir
# Create extraction directory
# Determine extraction directory
# If MPTA is in models/{model_id}/ structure, extract there
# Otherwise use temporary directory
mpta_dir = os.path.dirname(mpta_path)
mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")
# Check if this is in models/{model_id}/ structure
if "models/" in mpta_dir and mpta_dir.count("/") >= 1:
# Extract directly to the models/{model_id}/ directory
extracted_dir = mpta_dir # Extract directly where the MPTA file is
else:
# Use temporary directory for non-model files
extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")
# Extract MPTA
logger.info(f"Extracting MPTA file: {mpta_path}")
logger.info(f"📦 Extracting MPTA file: {mpta_path}")
logger.info(f"📂 Extraction target: {extracted_dir}")
try:
# Verify it's a valid zip file before extracting
with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
# Clean existing directory if exists
if os.path.exists(extracted_dir):
shutil.rmtree(extracted_dir)
os.makedirs(extracted_dir)
# List contents for debugging
file_list = zip_ref.namelist()
logger.info(f"📋 ZIP file contents ({len(file_list)} files): {file_list[:10]}{'...' if len(file_list) > 10 else ''}")
# For models/{model_id}/ structure, only clean extracted contents, not the MPTA file
if "models/" in extracted_dir and mpta_path.startswith(extracted_dir):
# Clean only the extracted subdirectories, keep the MPTA file
for item in os.listdir(extracted_dir):
item_path = os.path.join(extracted_dir, item)
if os.path.isdir(item_path):
logger.info(f"🧹 Cleaning existing extracted directory: {item_path}")
shutil.rmtree(item_path)
elif not item.endswith('.mpta'):
# Remove non-MPTA files that might be leftover extracts
logger.info(f"🧹 Cleaning leftover file: {item_path}")
os.remove(item_path)
else:
# For temp directories, clean everything
if os.path.exists(extracted_dir):
logger.info(f"🧹 Cleaning existing extraction directory: {extracted_dir}")
shutil.rmtree(extracted_dir)
os.makedirs(extracted_dir, exist_ok=True)
# Extract all files
logger.info(f"📤 Extracting {len(file_list)} files...")
zip_ref.extractall(extracted_dir)
# Verify extraction worked
extracted_files = []
for root, dirs, files in os.walk(extracted_dir):
for file in files:
extracted_files.append(os.path.join(root, file))
logger.info(f"✅ Extraction completed - {len(extracted_files)} files extracted")
logger.info(f"📋 Sample extracted files: {extracted_files[:5]}{'...' if len(extracted_files) > 5 else ''}")
self.extracted_paths[mpta_path] = extracted_dir
logger.info(f"Extracted to: {extracted_dir}")
logger.info(f"✅ MPTA successfully extracted to: {extracted_dir}")
return extracted_dir
except zipfile.BadZipFile as e:
logger.error(f"❌ Invalid ZIP file: {mpta_path}")
raise ModelLoadError(f"Invalid ZIP/MPTA file: {e}")
except Exception as e:
logger.error(f"❌ Failed to extract MPTA: {e}")
raise ModelLoadError(f"Failed to extract MPTA: {e}")
def _parse_pipeline_config(
@@ -328,9 +424,10 @@ class PipelineLoader:
if node.model_path and not os.path.exists(node.model_path):
raise PipelineError(f"Model file not found: {node.model_path}")
# Validate cropping configuration
# Validate cropping configuration - be more lenient for backward compatibility
if node.crop and not node.crop_class:
raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass")
logger.warning(f"Node {node.model_id} has crop=true but no cropClass - will disable cropping")
node.crop = False # Disable cropping instead of failing
# Validate confidence
if not 0 <= node.min_confidence <= 1: