Fix: update message type to current implementation

This commit is contained in:
ziesorx 2025-09-12 22:16:06 +07:00
parent b940790e4a
commit 96ecc321ec
3 changed files with 549 additions and 230 deletions

View file

@ -7,7 +7,7 @@ It provides a clean separation between message handling and business logic.
import json
import logging
import time
from typing import Dict, Any, Optional, Callable, Tuple
from typing import Dict, Any, Optional, Callable, Tuple, List
from enum import Enum
from dataclasses import dataclass, field
@ -19,13 +19,13 @@ msg_proc_logger = logging.getLogger("websocket.message_processor") # Detailed m
class MessageType(Enum):
"""Enumeration of supported WebSocket message types."""
SUBSCRIBE = "subscribe"
UNSUBSCRIBE = "unsubscribe"
"""Enumeration of supported WebSocket message types per worker.md protocol."""
SET_SUBSCRIPTION_LIST = "setSubscriptionList"
REQUEST_STATE = "requestState"
SET_SESSION_ID = "setSessionId"
PATCH_SESSION = "patchSession"
SET_PROGRESSION_STAGE = "setProgressionStage"
PATCH_SESSION_RESULT = "patchSessionResult"
IMAGE_DETECTION = "imageDetection"
STATE_REPORT = "stateReport"
ACK = "ack"
@ -41,58 +41,47 @@ class ProgressionStage(Enum):
@dataclass
class SubscribePayload:
"""Payload for subscription messages."""
class SubscriptionObject:
"""Subscription object per worker.md protocol specification."""
subscription_identifier: str
rtsp_url: str
model_url: str
model_id: int
model_name: str
model_url: str
rtsp_url: Optional[str] = None
snapshot_url: Optional[str] = None
snapshot_interval: Optional[int] = None
crop_x1: Optional[int] = None
crop_y1: Optional[int] = None
crop_x2: Optional[int] = None
crop_y2: Optional[int] = None
extra_params: Dict[str, Any] = field(default_factory=dict)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SubscribePayload":
"""Create SubscribePayload from dictionary."""
# Extract known fields
known_fields = {
"subscription_identifier": data.get("subscriptionIdentifier"),
"model_id": data.get("modelId"),
"model_name": data.get("modelName"),
"model_url": data.get("modelUrl"),
"rtsp_url": data.get("rtspUrl"),
"snapshot_url": data.get("snapshotUrl"),
"snapshot_interval": data.get("snapshotInterval"),
"crop_x1": data.get("cropX1"),
"crop_y1": data.get("cropY1"),
"crop_x2": data.get("cropX2"),
"crop_y2": data.get("cropY2")
}
# Extract extra parameters
extra_params = {k: v for k, v in data.items()
if k not in ["subscriptionIdentifier", "modelId", "modelName",
"modelUrl", "rtspUrl", "snapshotUrl", "snapshotInterval",
"cropX1", "cropY1", "cropX2", "cropY2"]}
return cls(**known_fields, extra_params=extra_params)
def from_dict(cls, data: Dict[str, Any]) -> "SubscriptionObject":
"""Create SubscriptionObject from dictionary."""
return cls(
subscription_identifier=data.get("subscriptionIdentifier", ""),
rtsp_url=data.get("rtspUrl", ""),
model_url=data.get("modelUrl", ""),
model_id=data.get("modelId", 0),
model_name=data.get("modelName", ""),
snapshot_url=data.get("snapshotUrl"),
snapshot_interval=data.get("snapshotInterval"),
crop_x1=data.get("cropX1"),
crop_y1=data.get("cropY1"),
crop_x2=data.get("cropX2"),
crop_y2=data.get("cropY2")
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format for stream configuration."""
result = {
"subscriptionIdentifier": self.subscription_identifier,
"rtspUrl": self.rtsp_url,
"modelUrl": self.model_url,
"modelId": self.model_id,
"modelName": self.model_name,
"modelUrl": self.model_url
"modelName": self.model_name
}
if self.rtsp_url:
result["rtspUrl"] = self.rtsp_url
if self.snapshot_url:
result["snapshotUrl"] = self.snapshot_url
if self.snapshot_interval is not None:
@ -107,12 +96,22 @@ class SubscribePayload:
if self.crop_y2 is not None:
result["cropY2"] = self.crop_y2
# Add any extra parameters
result.update(self.extra_params)
return result
@dataclass
class SetSubscriptionListPayload:
"""Payload for setSubscriptionList command per worker.md protocol."""
subscriptions: List[SubscriptionObject]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SetSubscriptionListPayload":
"""Create SetSubscriptionListPayload from dictionary."""
subscriptions_data = data.get("subscriptions", [])
subscriptions = [SubscriptionObject.from_dict(sub) for sub in subscriptions_data]
return cls(subscriptions=subscriptions)
@dataclass
class SessionPayload:
"""Payload for session-related messages."""
@ -142,11 +141,11 @@ class MessageProcessor:
def __init__(self):
"""Initialize the message processor."""
self.validators: Dict[MessageType, Callable] = {
MessageType.SUBSCRIBE: self._validate_subscribe,
MessageType.UNSUBSCRIBE: self._validate_unsubscribe,
MessageType.SET_SUBSCRIPTION_LIST: self._validate_set_subscription_list,
MessageType.SET_SESSION_ID: self._validate_set_session,
MessageType.PATCH_SESSION: self._validate_patch_session,
MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage
MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage,
MessageType.PATCH_SESSION_RESULT: self._validate_patch_session_result
}
def parse_message(self, raw_message: str) -> Tuple[MessageType, Dict[str, Any]]:
@ -222,58 +221,69 @@ class MessageProcessor:
msg_proc_logger.error(f"❌ VALIDATION FAILED FOR {msg_type.value}: {e}")
raise
def _validate_subscribe(self, data: Dict[str, Any]) -> SubscribePayload:
"""Validate subscribe message."""
def _validate_set_subscription_list(self, data: Dict[str, Any]) -> SetSubscriptionListPayload:
"""Validate setSubscriptionList message per worker.md protocol."""
subscriptions_data = data.get("subscriptions", [])
if not isinstance(subscriptions_data, list):
raise ValidationError("subscriptions must be an array")
# Validate each subscription object
for i, sub_data in enumerate(subscriptions_data):
if not isinstance(sub_data, dict):
raise ValidationError(f"Subscription {i} must be an object")
# Required fields per protocol
required = ["subscriptionIdentifier", "rtspUrl", "modelUrl", "modelId", "modelName"]
missing = [field for field in required if not sub_data.get(field)]
if missing:
raise ValidationError(f"Subscription {i} missing required fields: {', '.join(missing)}")
# Validate snapshot interval if snapshot URL is provided
if sub_data.get("snapshotUrl") and not sub_data.get("snapshotInterval"):
raise ValidationError(f"Subscription {i}: snapshotInterval is required when using snapshotUrl")
# Validate crop coordinates
crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"]
crop_values = [sub_data.get(coord) for coord in crop_coords]
if any(v is not None for v in crop_values):
# If any crop coordinate is set, all must be set
if not all(v is not None for v in crop_values):
raise ValidationError(f"Subscription {i}: All crop coordinates must be specified if any are set")
# Validate coordinate values
x1, y1, x2, y2 = crop_values
if x2 <= x1 or y2 <= y1:
raise ValidationError(f"Subscription {i}: Invalid crop coordinates: x2 must be > x1, y2 must be > y1")
return SetSubscriptionListPayload.from_dict(data)
def _validate_patch_session_result(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Validate patchSessionResult message."""
payload = data.get("payload", {})
# Required fields
required = ["subscriptionIdentifier", "modelId", "modelName", "modelUrl"]
missing = [field for field in required if not payload.get(field)]
if missing:
raise ValidationError(f"Missing required fields: {', '.join(missing)}")
# Required fields per protocol
session_id = payload.get("sessionId")
if session_id is None:
raise ValidationError("Missing required field: sessionId")
# Must have either rtspUrl or snapshotUrl
if not payload.get("rtspUrl") and not payload.get("snapshotUrl"):
raise ValidationError("Must provide either rtspUrl or snapshotUrl")
# Validate snapshot interval if snapshot URL is provided
if payload.get("snapshotUrl") and not payload.get("snapshotInterval"):
raise ValidationError("snapshotInterval is required when using snapshotUrl")
# Validate crop coordinates
crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"]
crop_values = [payload.get(coord) for coord in crop_coords]
if any(v is not None for v in crop_values):
# If any crop coordinate is set, all must be set
if not all(v is not None for v in crop_values):
raise ValidationError("All crop coordinates must be specified if any are set")
# Validate coordinate values
x1, y1, x2, y2 = crop_values
if x2 <= x1 or y2 <= y1:
raise ValidationError("Invalid crop coordinates: x2 must be > x1, y2 must be > y1")
return SubscribePayload.from_dict(payload)
def _validate_unsubscribe(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Validate unsubscribe message."""
payload = data.get("payload", {})
if not payload.get("subscriptionIdentifier"):
raise ValidationError("Missing required field: subscriptionIdentifier")
success = payload.get("success")
if success is None:
raise ValidationError("Missing required field: success")
return payload
def _validate_set_session(self, data: Dict[str, Any]) -> SessionPayload:
"""Validate setSessionId message."""
"""Validate setSessionId message per worker.md protocol."""
payload = data.get("payload", {})
display_id = payload.get("displayIdentifier")
session_id = payload.get("sessionId")
session_id = payload.get("sessionId") # Can be null to clear session
if not display_id:
raise ValidationError("Missing required field: displayIdentifier")
if not session_id:
# sessionId can be null to clear the session per protocol
if "sessionId" not in payload:
raise ValidationError("Missing required field: sessionId")
return SessionPayload(display_id, session_id)

View file

@ -83,16 +83,17 @@ class WebSocketHandler:
# Message handlers
self.message_handlers: Dict[str, MessageHandler] = {
"subscribe": self._handle_subscribe,
"unsubscribe": self._handle_unsubscribe,
"setSubscriptionList": self._handle_set_subscription_list,
"requestState": self._handle_request_state,
"setSessionId": self._handle_set_session,
"patchSession": self._handle_patch_session,
"setProgressionStage": self._handle_set_progression_stage
"setProgressionStage": self._handle_set_progression_stage,
"patchSessionResult": self._handle_patch_session_result
}
# Session and display management
self.session_ids: Dict[str, str] = {} # display_identifier -> session_id
self.progression_stages: Dict[str, str] = {} # display_identifier -> progression_stage
self.display_identifiers: Set[str] = set()
# Camera monitor
@ -171,22 +172,40 @@ class WebSocketHandler:
# Get system metrics
metrics = get_system_metrics()
# Get active streams info
active_streams = self.stream_manager.get_active_streams()
active_models = self.model_manager.get_loaded_models()
# Build cameraConnections array as required by protocol
camera_connections = []
with self.stream_manager.streams_lock:
for camera_id, stream_info in self.stream_manager.streams.items():
# Check if camera is online
is_online = self.stream_manager.is_stream_active(camera_id)
connection_info = {
"subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
"modelId": stream_info.get("modelId", 0),
"modelName": stream_info.get("modelName", "Unknown Model"),
"online": is_online
}
# Add crop coordinates if available
if "cropX1" in stream_info:
connection_info["cropX1"] = stream_info["cropX1"]
if "cropY1" in stream_info:
connection_info["cropY1"] = stream_info["cropY1"]
if "cropX2" in stream_info:
connection_info["cropX2"] = stream_info["cropX2"]
if "cropY2" in stream_info:
connection_info["cropY2"] = stream_info["cropY2"]
camera_connections.append(connection_info)
# Protocol-compliant stateReport format (worker.md lines 169-189)
state_data = {
"type": "stateReport",
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"data": {
"activeStreams": len(active_streams),
"loadedModels": len(active_models),
"cpuUsage": metrics.get("cpu_percent", 0),
"memoryUsage": metrics.get("memory_percent", 0),
"gpuUsage": metrics.get("gpu_percent", 0),
"gpuMemory": metrics.get("gpu_memory_percent", 0),
"uptime": time.time() - metrics.get("start_time", time.time())
}
"cpuUsage": metrics.get("cpu_percent", 0),
"memoryUsage": metrics.get("memory_percent", 0),
"gpuUsage": metrics.get("gpu_percent", 0),
"gpuMemoryUsage": metrics.get("gpu_memory_percent", 0), # Fixed field name
"cameraConnections": camera_connections
}
# Compact JSON for RX/TX logging
@ -351,75 +370,105 @@ class WebSocketHandler:
traceback.print_exc()
return persistent_data
async def _handle_subscribe(self, data: Dict[str, Any]) -> None:
"""Handle stream subscription request."""
payload = data.get("payload", {})
subscription_id = payload.get("subscriptionIdentifier")
async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None:
"""
Handle setSubscriptionList command - declarative subscription management.
This is the primary subscription command per worker.md protocol.
Workers must reconcile the new subscription list with current state.
"""
subscriptions = data.get("subscriptions", [])
if not subscription_id:
logger.error("Missing subscriptionIdentifier in subscribe payload")
return
try:
# Extract display and camera IDs
parts = subscription_id.split(";")
if len(parts) >= 2:
display_id = parts[0]
camera_id = parts[1]
self.display_identifiers.add(display_id)
else:
camera_id = subscription_id
# Get current subscription identifiers
current_subscriptions = set(subscription_to_camera.keys())
# Get desired subscription identifiers
desired_subscriptions = set()
subscription_configs = {}
for sub_config in subscriptions:
sub_id = sub_config.get("subscriptionIdentifier")
if sub_id:
desired_subscriptions.add(sub_id)
subscription_configs[sub_id] = sub_config
# Extract display ID for session management
parts = sub_id.split(";")
if len(parts) >= 2:
display_id = parts[0]
self.display_identifiers.add(display_id)
# Calculate changes needed
to_add = desired_subscriptions - current_subscriptions
to_remove = current_subscriptions - desired_subscriptions
to_update = desired_subscriptions & current_subscriptions
logger.info(f"Subscription reconciliation: add={len(to_add)}, remove={len(to_remove)}, update={len(to_update)}")
# Remove obsolete subscriptions
for sub_id in to_remove:
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
subscription_to_camera.pop(sub_id, None)
self.session_cache.clear_session(camera_id)
logger.info(f"Removed subscription: {sub_id}")
# Add new subscriptions
for sub_id in to_add:
await self._start_subscription(sub_id, subscription_configs[sub_id])
logger.info(f"Added subscription: {sub_id}")
# Update existing subscriptions if needed
for sub_id in to_update:
# Check if configuration changed (model URL, crop coordinates, etc.)
current_config = subscription_to_camera.get(sub_id)
new_config = subscription_configs[sub_id]
# For now, restart subscription if model URL changed (handles S3 expiration)
current_model_url = getattr(current_config, 'model_url', None) if current_config else None
new_model_url = new_config.get("modelUrl")
if current_model_url != new_model_url:
# Restart with new configuration
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
await self._start_subscription(sub_id, new_config)
logger.info(f"Updated subscription: {sub_id}")
logger.info(f"Subscription list reconciliation completed. Active: {len(desired_subscriptions)}")
except Exception as e:
logger.error(f"Error handling setSubscriptionList: {e}")
traceback.print_exc()
async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None:
"""Start a single subscription with given configuration."""
try:
# Extract camera ID from subscription identifier
parts = subscription_id.split(";")
camera_id = parts[1] if len(parts) >= 2 else subscription_id
# Store subscription mapping
subscription_to_camera[subscription_id] = camera_id
# Start camera stream
await self.stream_manager.start_stream(camera_id, payload)
await self.stream_manager.start_stream(camera_id, config)
# Load model
model_id = payload.get("modelId")
model_url = payload.get("modelUrl")
model_id = config.get("modelId")
model_url = config.get("modelUrl")
if model_id and model_url:
await self.model_manager.load_model(camera_id, model_id, model_url)
logger.info(f"Subscribed to stream: {subscription_id}")
except Exception as e:
logger.error(f"Error handling subscription: {e}")
traceback.print_exc()
async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
"""Handle stream unsubscription request."""
payload = data.get("payload", {})
subscription_id = payload.get("subscriptionIdentifier")
if not subscription_id:
logger.error("Missing subscriptionIdentifier in unsubscribe payload")
return
try:
# Get camera ID from subscription
camera_id = subscription_to_camera.get(subscription_id)
if not camera_id:
logger.warning(f"No camera found for subscription: {subscription_id}")
return
# Stop stream
await self.stream_manager.stop_stream(camera_id)
# Unload model
self.model_manager.unload_models(camera_id)
# Clean up mappings
subscription_to_camera.pop(subscription_id, None)
# Clean up session state
self.session_cache.clear_session(camera_id)
logger.info(f"Unsubscribed from stream: {subscription_id}")
except Exception as e:
logger.error(f"Error handling unsubscription: {e}")
logger.error(f"Error starting subscription {subscription_id}: {e}")
raise
traceback.print_exc()
async def _handle_request_state(self, data: Dict[str, Any]) -> None:
@ -534,6 +583,26 @@ class WebSocketHandler:
elif progression_stage in ["car_wait_staff"]:
pipeline_state["progression_stage"] = progression_stage
logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}")
# Store progression stage for this display
if display_id and progression_stage is not None:
if progression_stage:
self.progression_stages[display_id] = progression_stage
else:
# Clear progression stage if null
self.progression_stages.pop(display_id, None)
async def _handle_patch_session_result(self, data: Dict[str, Any]) -> None:
"""Handle patchSessionResult message from backend."""
payload = data.get("payload", {})
session_id = payload.get("sessionId")
success = payload.get("success", False)
message = payload.get("message", "")
if success:
logger.info(f"Patch session {session_id} successful: {message}")
else:
logger.warning(f"Patch session {session_id} failed: {message}")
async def _send_detection_result(
self,
@ -542,10 +611,16 @@ class WebSocketHandler:
detection_result: DetectionResult
) -> None:
"""Send detection result over WebSocket."""
# Get session ID for this display
subscription_id = stream_info["subscriptionIdentifier"]
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
session_id = self.session_ids.get(display_id)
detection_data = {
"type": "imageDetection",
"subscriptionIdentifier": stream_info["subscriptionIdentifier"],
"subscriptionIdentifier": subscription_id,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"sessionId": session_id, # Required by protocol
"data": {
"detection": detection_result.to_dict(),
"modelId": stream_info["modelId"],