Refactor: done phase 2

This commit is contained in:
ziesorx 2025-09-23 16:13:11 +07:00
parent 8222e82dd7
commit aa10d5a55c
6 changed files with 1337 additions and 23 deletions

View file

@ -17,6 +17,7 @@ from .models import (
RequestStateMessage, PatchSessionResultMessage
)
from .state import worker_state, SystemMetrics
from ..models import ModelManager
logger = logging.getLogger(__name__)
@ -24,6 +25,9 @@ logger = logging.getLogger(__name__)
HEARTBEAT_INTERVAL = 2.0 # seconds
WORKER_TIMEOUT_MS = 10000
# Global model manager instance
model_manager = ModelManager()
class WebSocketHandler:
"""
@ -184,7 +188,10 @@ class WebSocketHandler:
# Update worker state with new subscriptions
worker_state.set_subscriptions(message.subscriptions)
# TODO: Phase 2 - Integrate with model management and streaming
# Phase 2: Download and manage models
await self._ensure_models(message.subscriptions)
# TODO: Phase 3 - Integrate with streaming management
# For now, just log the subscription changes
for subscription in message.subscriptions:
logger.info(f" Subscription: {subscription.subscriptionIdentifier} -> "
@ -198,6 +205,79 @@ class WebSocketHandler:
logger.info("Subscription list updated successfully")
async def _ensure_models(self, subscriptions) -> None:
    """Ensure every model required by *subscriptions* is downloaded and available.

    Deduplicates the subscription list by ``modelId`` (first occurrence wins),
    downloads any missing models concurrently, and logs a per-model
    success/failure summary. Individual failures are logged but never raised.

    Args:
        subscriptions: Iterable of subscription objects exposing
            ``modelId``, ``modelUrl`` and ``modelName`` attributes.
    """
    # Extract unique model requirements (first subscription wins per modelId).
    unique_models = {}
    for subscription in subscriptions:
        model_id = subscription.modelId
        if model_id not in unique_models:
            unique_models[model_id] = {
                'model_url': subscription.modelUrl,
                'model_name': subscription.modelName
            }

    logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}")

    # Check and download models concurrently.
    download_tasks = [
        asyncio.create_task(
            self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name'])
        )
        for model_id, model_info in unique_models.items()
    ]

    # Wait for all downloads to complete.
    if download_tasks:
        results = await asyncio.gather(*download_tasks, return_exceptions=True)

        # Pair each result with its model id via zip instead of rebuilding
        # list(unique_models.keys()) and indexing on every iteration: dicts
        # preserve insertion order, which matches gather's result order.
        success_count = 0
        for model_id, result in zip(unique_models, results):
            if isinstance(result, Exception):
                logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}")
            elif result:
                success_count += 1
                logger.info(f"[Model Management] Model {model_id} ready for use")
            else:
                logger.error(f"[Model Management] Failed to ensure model {model_id}")

        logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models")
async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool:
    """Ensure one model is downloaded and available for use.

    Checks the local cache first; otherwise downloads and extracts the model
    in the default thread-pool executor so the event loop is not blocked.

    Args:
        model_id: Numeric identifier of the model.
        model_url: URL to download the model from when it is not cached.
        model_name: Human-readable model name (used only in log messages).

    Returns:
        True when the model is ready for use, False otherwise. Never raises:
        any exception is logged (with traceback) and reported as failure.
    """
    try:
        # Fast path: model already present locally.
        if model_manager.is_model_downloaded(model_id):
            logger.info(f"[Model Management] Model {model_id} ({model_name}) already available")
            return True

        logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}")

        # Run the blocking download/extract off the event loop.
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated in this context since Python 3.10
        # and get_running_loop() has been available since 3.7.
        loop = asyncio.get_running_loop()
        model_path = await loop.run_in_executor(
            None,
            model_manager.ensure_model,
            model_id,
            model_url,
            model_name
        )

        if model_path:
            logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}")
            return True
        logger.error(f"[Model Management] Failed to prepare model {model_id}")
        return False

    except Exception as e:
        logger.error(f"[Model Management] Exception ensuring model {model_id}: {str(e)}", exc_info=True)
        return False
async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
"""Handle setSessionId message."""
display_identifier = message.payload.displayIdentifier