Refactor: done phase 4
This commit is contained in:
parent
7e8034c6e5
commit
9e4c23c75c
8 changed files with 1533 additions and 37 deletions
|
@ -18,6 +18,8 @@ from .models import (
|
|||
)
|
||||
from .state import worker_state, SystemMetrics
|
||||
from ..models import ModelManager
|
||||
from ..streaming.manager import shared_stream_manager
|
||||
from ..tracking.integration import TrackingPipelineIntegration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -199,17 +201,8 @@ class WebSocketHandler:
|
|||
# Phase 2: Download and manage models
|
||||
await self._ensure_models(message.subscriptions)
|
||||
|
||||
# TODO: Phase 3 - Integrate with streaming management
|
||||
# For now, just log the subscription changes
|
||||
for subscription in message.subscriptions:
|
||||
logger.info(f" Subscription: {subscription.subscriptionIdentifier} -> "
|
||||
f"Model {subscription.modelId} ({subscription.modelName})")
|
||||
if subscription.rtspUrl:
|
||||
logger.debug(f" RTSP: {subscription.rtspUrl}")
|
||||
if subscription.snapshotUrl:
|
||||
logger.debug(f" Snapshot: {subscription.snapshotUrl} ({subscription.snapshotInterval}ms)")
|
||||
if subscription.modelUrl:
|
||||
logger.debug(f" Model: {subscription.modelUrl}")
|
||||
# Phase 3 & 4: Integrate with streaming management and tracking
|
||||
await self._update_stream_subscriptions(message.subscriptions)
|
||||
|
||||
logger.info("Subscription list updated successfully")
|
||||
|
||||
|
@ -260,6 +253,168 @@ class WebSocketHandler:
|
|||
|
||||
logger.info(f"[Model Management] Successfully ensured {success_count}/{len(unique_models)} models")
|
||||
|
||||
async def _update_stream_subscriptions(self, subscriptions) -> None:
|
||||
"""Update streaming subscriptions with tracking integration."""
|
||||
try:
|
||||
# Convert subscriptions to the format expected by StreamManager
|
||||
subscription_payloads = []
|
||||
for subscription in subscriptions:
|
||||
payload = {
|
||||
'subscriptionIdentifier': subscription.subscriptionIdentifier,
|
||||
'rtspUrl': subscription.rtspUrl,
|
||||
'snapshotUrl': subscription.snapshotUrl,
|
||||
'snapshotInterval': subscription.snapshotInterval,
|
||||
'modelId': subscription.modelId,
|
||||
'modelUrl': subscription.modelUrl,
|
||||
'modelName': subscription.modelName
|
||||
}
|
||||
# Add crop coordinates if present
|
||||
if hasattr(subscription, 'cropX1'):
|
||||
payload.update({
|
||||
'cropX1': subscription.cropX1,
|
||||
'cropY1': subscription.cropY1,
|
||||
'cropX2': subscription.cropX2,
|
||||
'cropY2': subscription.cropY2
|
||||
})
|
||||
subscription_payloads.append(payload)
|
||||
|
||||
# Reconcile subscriptions with StreamManager
|
||||
logger.info("[Streaming] Reconciling stream subscriptions with tracking")
|
||||
reconcile_result = await self._reconcile_subscriptions_with_tracking(subscription_payloads)
|
||||
|
||||
logger.info(f"[Streaming] Subscription reconciliation complete: "
|
||||
f"added={reconcile_result.get('added', 0)}, "
|
||||
f"removed={reconcile_result.get('removed', 0)}, "
|
||||
f"failed={reconcile_result.get('failed', 0)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating stream subscriptions: {e}", exc_info=True)
|
||||
|
||||
async def _reconcile_subscriptions_with_tracking(self, target_subscriptions) -> dict:
|
||||
"""Reconcile subscriptions with tracking integration."""
|
||||
try:
|
||||
# First, we need to create tracking integrations for each unique model
|
||||
tracking_integrations = {}
|
||||
|
||||
for subscription_payload in target_subscriptions:
|
||||
model_id = subscription_payload['modelId']
|
||||
|
||||
# Create tracking integration if not already created
|
||||
if model_id not in tracking_integrations:
|
||||
# Get pipeline configuration for this model
|
||||
pipeline_parser = model_manager.get_pipeline_config(model_id)
|
||||
if pipeline_parser:
|
||||
# Create tracking integration
|
||||
tracking_integration = TrackingPipelineIntegration(
|
||||
pipeline_parser, model_manager
|
||||
)
|
||||
|
||||
# Initialize tracking model
|
||||
success = await tracking_integration.initialize_tracking_model()
|
||||
if success:
|
||||
tracking_integrations[model_id] = tracking_integration
|
||||
logger.info(f"[Tracking] Created tracking integration for model {model_id}")
|
||||
else:
|
||||
logger.warning(f"[Tracking] Failed to initialize tracking for model {model_id}")
|
||||
else:
|
||||
logger.warning(f"[Tracking] No pipeline config found for model {model_id}")
|
||||
|
||||
# Now reconcile with StreamManager, adding tracking integrations
|
||||
current_subscription_ids = set()
|
||||
for subscription_info in shared_stream_manager.get_all_subscriptions():
|
||||
current_subscription_ids.add(subscription_info.subscription_id)
|
||||
|
||||
target_subscription_ids = {sub['subscriptionIdentifier'] for sub in target_subscriptions}
|
||||
|
||||
# Find subscriptions to remove and add
|
||||
to_remove = current_subscription_ids - target_subscription_ids
|
||||
to_add = target_subscription_ids - current_subscription_ids
|
||||
|
||||
# Remove old subscriptions
|
||||
removed_count = 0
|
||||
for subscription_id in to_remove:
|
||||
if shared_stream_manager.remove_subscription(subscription_id):
|
||||
removed_count += 1
|
||||
logger.info(f"[Streaming] Removed subscription {subscription_id}")
|
||||
|
||||
# Add new subscriptions with tracking
|
||||
added_count = 0
|
||||
failed_count = 0
|
||||
for subscription_payload in target_subscriptions:
|
||||
subscription_id = subscription_payload['subscriptionIdentifier']
|
||||
if subscription_id in to_add:
|
||||
success = await self._add_subscription_with_tracking(
|
||||
subscription_payload, tracking_integrations
|
||||
)
|
||||
if success:
|
||||
added_count += 1
|
||||
logger.info(f"[Streaming] Added subscription {subscription_id} with tracking")
|
||||
else:
|
||||
failed_count += 1
|
||||
logger.error(f"[Streaming] Failed to add subscription {subscription_id}")
|
||||
|
||||
return {
|
||||
'removed': removed_count,
|
||||
'added': added_count,
|
||||
'failed': failed_count,
|
||||
'total_active': len(shared_stream_manager.get_all_subscriptions())
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in subscription reconciliation with tracking: {e}", exc_info=True)
|
||||
return {'removed': 0, 'added': 0, 'failed': 0, 'total_active': 0}
|
||||
|
||||
async def _add_subscription_with_tracking(self, payload, tracking_integrations) -> bool:
|
||||
"""Add a subscription with tracking integration."""
|
||||
try:
|
||||
from ..streaming.manager import StreamConfig
|
||||
|
||||
subscription_id = payload['subscriptionIdentifier']
|
||||
camera_id = subscription_id.split(';')[-1]
|
||||
model_id = payload['modelId']
|
||||
|
||||
# Get tracking integration for this model
|
||||
tracking_integration = tracking_integrations.get(model_id)
|
||||
|
||||
# Extract crop coordinates if present
|
||||
crop_coords = None
|
||||
if all(key in payload for key in ['cropX1', 'cropY1', 'cropX2', 'cropY2']):
|
||||
crop_coords = (
|
||||
payload['cropX1'],
|
||||
payload['cropY1'],
|
||||
payload['cropX2'],
|
||||
payload['cropY2']
|
||||
)
|
||||
|
||||
# Create stream configuration
|
||||
stream_config = StreamConfig(
|
||||
camera_id=camera_id,
|
||||
rtsp_url=payload.get('rtspUrl'),
|
||||
snapshot_url=payload.get('snapshotUrl'),
|
||||
snapshot_interval=payload.get('snapshotInterval', 5000),
|
||||
max_retries=3,
|
||||
save_test_frames=False # Disable frame saving, focus on tracking
|
||||
)
|
||||
|
||||
# Add subscription to StreamManager with tracking
|
||||
success = shared_stream_manager.add_subscription(
|
||||
subscription_id=subscription_id,
|
||||
stream_config=stream_config,
|
||||
crop_coords=crop_coords,
|
||||
model_id=model_id,
|
||||
model_url=payload.get('modelUrl'),
|
||||
tracking_integration=tracking_integration
|
||||
)
|
||||
|
||||
if success and tracking_integration:
|
||||
logger.info(f"[Tracking] Subscription {subscription_id} configured with tracking for model {model_id}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding subscription with tracking: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _ensure_single_model(self, model_id: int, model_url: str, model_name: str) -> bool:
|
||||
"""Ensure a single model is downloaded and available."""
|
||||
try:
|
||||
|
@ -303,6 +458,9 @@ class WebSocketHandler:
|
|||
# Update worker state
|
||||
worker_state.set_session_id(display_identifier, session_id)
|
||||
|
||||
# Update tracking integrations with session ID
|
||||
shared_stream_manager.set_session_id(display_identifier, session_id)
|
||||
|
||||
async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None:
|
||||
"""Handle setProgressionStage message."""
|
||||
display_identifier = message.payload.displayIdentifier
|
||||
|
@ -313,6 +471,14 @@ class WebSocketHandler:
|
|||
# Update worker state
|
||||
worker_state.set_progression_stage(display_identifier, stage)
|
||||
|
||||
# If stage indicates session is cleared/finished, clear from tracking
|
||||
if stage in ['finished', 'cleared', 'idle']:
|
||||
# Get session ID for this display and clear it
|
||||
session_id = worker_state.get_session_id(display_identifier)
|
||||
if session_id:
|
||||
shared_stream_manager.clear_session_id(session_id)
|
||||
logger.info(f"[Tracking] Cleared session {session_id} due to progression stage: {stage}")
|
||||
|
||||
async def _handle_request_state(self, message: RequestStateMessage) -> None:
|
||||
"""Handle requestState message by sending immediate state report."""
|
||||
logger.debug("[RX Processing] requestState - sending immediate state report")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue