Refactor: done phase 2
This commit is contained in:
parent
8222e82dd7
commit
aa10d5a55c
6 changed files with 1337 additions and 23 deletions
|
@ -17,6 +17,7 @@ from .models import (
|
|||
RequestStateMessage, PatchSessionResultMessage
|
||||
)
|
||||
from .state import worker_state, SystemMetrics
|
||||
from ..models import ModelManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -24,6 +25,9 @@ logger = logging.getLogger(__name__)
|
|||
HEARTBEAT_INTERVAL = 2.0 # seconds
|
||||
WORKER_TIMEOUT_MS = 10000  # NOTE(review): presumably milliseconds, per the _MS suffix — confirm against usage
|
||||
|
||||
# Global model manager instance
|
||||
model_manager = ModelManager()
|
||||
|
||||
|
||||
class WebSocketHandler:
|
||||
"""
|
||||
|
@ -184,7 +188,10 @@ class WebSocketHandler:
|
|||
# Update worker state with new subscriptions
|
||||
worker_state.set_subscriptions(message.subscriptions)
|
||||
|
||||
# TODO: Phase 2 - Integrate with model management and streaming
|
||||
# Phase 2: Download and manage models
|
||||
await self._ensure_models(message.subscriptions)
|
||||
|
||||
# TODO: Phase 3 - Integrate with streaming management
|
||||
# For now, just log the subscription changes
|
||||
for subscription in message.subscriptions:
|
||||
logger.info(f" Subscription: {subscription.subscriptionIdentifier} -> "
|
||||
|
@ -198,6 +205,79 @@ class WebSocketHandler:
|
|||
|
||||
logger.info("Subscription list updated successfully")
|
||||
|
||||
async def _ensure_models(self, subscriptions) -> None:
    """Ensure every model referenced by ``subscriptions`` is available.

    Subscriptions are de-duplicated by ``modelId``; the first subscription
    seen for a given id supplies that model's URL and name. Each unique
    model is then passed to ``_ensure_single_model`` and all of them are
    awaited concurrently. Per-model failures are logged, never raised.

    Args:
        subscriptions: Iterable of subscription objects exposing the
            ``modelId``, ``modelUrl`` and ``modelName`` attributes.
    """
    # Extract unique model requirements (first occurrence of an id wins).
    unique_models = {}
    for subscription in subscriptions:
        model_id = subscription.modelId
        if model_id not in unique_models:
            unique_models[model_id] = {
                'model_url': subscription.modelUrl,
                'model_name': subscription.modelName,
            }

    logger.info(f"[Model Management] Processing {len(unique_models)} unique models: {list(unique_models.keys())}")

    # Nothing to download: skip the gather and the summary line.
    if not unique_models:
        return

    # Check and download models concurrently.
    download_tasks = [
        asyncio.create_task(
            self._ensure_single_model(model_id, model_info['model_url'], model_info['model_name'])
        )
        for model_id, model_info in unique_models.items()
    ]

    # return_exceptions=True keeps one failing model from aborting the rest;
    # failures come back as exception instances in the results list.
    results = await asyncio.gather(*download_tasks, return_exceptions=True)

    # Tasks were created in dict order, so results line up with the keys.
    success_count = 0
    for model_id, result in zip(unique_models, results):
        # BaseException (not Exception): a cancelled task is reported as
        # CancelledError, which is not an Exception subclass on Python 3.8+
        # and would otherwise be truthy and counted as a success.
        if isinstance(result, BaseException):
            logger.error(f"[Model Management] Failed to ensure model {model_id}: {result}")
        elif result:
            success_count += 1
            logger.info(f"[Model Management] Model {model_id} ready for use")
        else:
            logger.error(f"[Model Management] Failed to ensure model {model_id}")

    logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models")
|
||||
|
||||
async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool:
    """Ensure a single model is downloaded and available.

    Returns immediately when ``model_manager.is_model_downloaded`` reports
    the model as present; otherwise runs ``model_manager.ensure_model`` on
    the event loop's default executor and waits for it.

    Args:
        model_id: Identifier of the model to ensure.
        model_url: Location the model is fetched from.
        model_name: Human-readable name, used in log messages and passed
            through to ``model_manager.ensure_model``.

    Returns:
        True when the model was already available or ``ensure_model``
        returned a truthy path; False when it returned a falsy value or
        any ``Exception`` was raised (the exception is logged, not re-raised).
    """
    try:
        # Check if model is already available
        if model_manager.is_model_downloaded(model_id):
            logger.info(f"[Model Management] Model {model_id} ({model_name}) already available")
            return True

        logger.info(f"[Model Management] Downloading model {model_id} ({model_name}) from {model_url}")

        # ensure_model is a synchronous call; run it on the default executor
        # so the event loop is not blocked while it works. run_in_executor
        # is used rather than asyncio.to_thread (3.9+) for compatibility
        # with older interpreters.
        loop = asyncio.get_running_loop()
        model_path = await loop.run_in_executor(
            None,
            model_manager.ensure_model,
            model_id,
            model_url,
            model_name,
        )

        if model_path:
            logger.info(f"[Model Management] Successfully prepared model {model_id} at {model_path}")
            return True

        logger.error(f"[Model Management] Failed to prepare model {model_id}")
        return False

    except Exception as e:
        # logger.exception == logger.error(..., exc_info=True): keeps the traceback.
        logger.exception(f"[Model Management] Exception ensuring model {model_id}: {e}")
        return False
|
||||
|
||||
async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
|
||||
"""Handle setSessionId message."""
|
||||
display_identifier = message.payload.displayIdentifier
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue