204 lines
No EOL
6 KiB
Python
204 lines
No EOL
6 KiB
Python
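"""
Message types, constants, and validation functions for WebSocket communication,
plus helpers that parse incoming messages, serialize outgoing messages, and
build outgoing message objects.
"""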
|
|
import json
|
|
import logging
|
|
from typing import Dict, Any, Optional
|
|
from .models import (
|
|
IncomingMessage, OutgoingMessage,
|
|
SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
|
|
RequestStateMessage, PatchSessionResultMessage,
|
|
StateReportMessage, ImageDetectionMessage, PatchSessionMessage
|
|
)
|
|
|
|
# Module-level logger, named after this module via logging.getLogger(__name__).
logger = logging.getLogger(__name__)
|
|
|
|
# String identifiers carried in the "type" field of WebSocket messages.
class MessageTypes:
    """Namespace of WebSocket message ``type`` values.

    The first group lists messages received from the backend; the second
    group lists messages sent to the backend.
    """

    # Received from the backend.
    SET_SUBSCRIPTION_LIST: str = "setSubscriptionList"
    SET_SESSION_ID: str = "setSessionId"
    SET_PROGRESSION_STAGE: str = "setProgressionStage"
    REQUEST_STATE: str = "requestState"
    PATCH_SESSION_RESULT: str = "patchSessionResult"

    # Sent to the backend.
    STATE_REPORT: str = "stateReport"
    IMAGE_DETECTION: str = "imageDetection"
    PATCH_SESSION: str = "patchSession"
|
|
|
|
|
|
def parse_incoming_message(raw_message: str) -> Optional[IncomingMessage]:
    """
    Parse an incoming WebSocket message and validate it against known types.

    The payload must be a JSON object whose ``type`` field equals one of the
    incoming ``MessageTypes`` values. The whole object is then passed as
    keyword arguments to the matching message class.

    Args:
        raw_message: Raw JSON string from WebSocket

    Returns:
        Parsed message object, or None if the payload is not valid JSON, is
        not a JSON object, has a missing/empty ``type``, names an unknown
        type, or the message class rejects it. Every failure is logged.
    """
    try:
        data = json.loads(raw_message)
    except json.JSONDecodeError as e:
        logger.error("Failed to decode JSON message: %s", e)
        return None
    except Exception as e:
        # e.g. TypeError when raw_message is not str/bytes/bytearray.
        logger.error("Failed to parse incoming message: %s", e)
        return None

    # Valid JSON need not be an object (it may be a list, number, string...).
    # Reject those explicitly rather than failing on ``data.get`` below.
    if not isinstance(data, dict):
        logger.error(
            "Failed to parse incoming message: expected a JSON object, got %s",
            type(data).__name__,
        )
        return None

    message_type = data.get("type")
    if not message_type:
        logger.error("Message missing 'type' field")
        return None

    # Dispatch table: incoming message type -> message class.
    message_classes = {
        MessageTypes.SET_SUBSCRIPTION_LIST: SetSubscriptionListMessage,
        MessageTypes.SET_SESSION_ID: SetSessionIdMessage,
        MessageTypes.SET_PROGRESSION_STAGE: SetProgressionStageMessage,
        MessageTypes.REQUEST_STATE: RequestStateMessage,
        MessageTypes.PATCH_SESSION_RESULT: PatchSessionResultMessage,
    }

    # A non-string "type" (e.g. a JSON list) may be unhashable and would raise
    # TypeError on dict lookup; such values can never match a known type anyway.
    message_class = (
        message_classes.get(message_type)
        if isinstance(message_type, str)
        else None
    )
    if message_class is None:
        logger.warning("Unknown message type: %s", message_type)
        return None

    try:
        return message_class(**data)
    except Exception as e:
        # Construction/validation failure of the message model; honor the
        # "return None on invalid input" contract instead of propagating.
        logger.error("Failed to parse incoming message: %s", e)
        return None
|
|
|
|
|
|
def serialize_outgoing_message(message: OutgoingMessage) -> str:
    """
    Turn an outgoing message object into its JSON wire form.

    Fields whose value is None are omitted from the output
    (``model_dump_json(exclude_none=True)``). Serialization errors are logged
    and then re-raised unchanged.

    Args:
        message: Outgoing message object to encode

    Returns:
        The JSON text produced by the message's ``model_dump_json``
    """
    try:
        encoded = message.model_dump_json(exclude_none=True)
    except Exception as err:
        logger.error(f"Failed to serialize outgoing message: {err}")
        raise
    return encoded
|
|
|
|
|
|
def validate_subscription_identifier(identifier: str) -> bool:
    """
    Check that a subscription identifier has the ``displayId;cameraId`` shape.

    The identifier must contain exactly one ``;``, and the text on each side of
    it must be non-empty. Empty or non-string values are rejected silently.
    Malformed non-empty strings are rejected with an ERROR log entry.

    Args:
        identifier: Subscription identifier to validate

    Returns:
        True when the identifier is well-formed, False otherwise
    """
    # Empty values and non-strings are rejected without logging.
    if not identifier or not isinstance(identifier, str):
        return False

    # Exactly one separator <=> split(';') yields exactly two parts.
    if identifier.count(';') != 1:
        logger.error(f"Invalid subscription identifier format: {identifier}")
        return False

    display_part, _, camera_part = identifier.partition(';')
    if display_part and camera_part:
        return True

    logger.error(f"Empty display or camera ID in identifier: {identifier}")
    return False
|
|
|
|
|
|
def extract_display_identifier(subscription_identifier: str) -> Optional[str]:
    """
    Return the display part of a ``displayId;cameraId`` subscription identifier.

    The identifier is first checked with ``validate_subscription_identifier``.
    A malformed identifier yields None.

    Args:
        subscription_identifier: Full subscription identifier (displayId;cameraId)

    Returns:
        The text before the ``;`` separator, or None if the identifier is invalid
    """
    if validate_subscription_identifier(subscription_identifier):
        display_part, _, _ = subscription_identifier.partition(';')
        return display_part
    return None
|
|
|
|
|
|
def create_state_report(cpu_usage: float, memory_usage: float,
                        gpu_usage: Optional[float] = None,
                        gpu_memory_usage: Optional[float] = None,
                        camera_connections: Optional[list] = None) -> StateReportMessage:
    """
    Build a ``StateReportMessage`` carrying the given system metrics.

    Args:
        cpu_usage: CPU usage percentage
        memory_usage: Memory usage percentage
        gpu_usage: GPU usage percentage, if available
        gpu_memory_usage: GPU memory usage in MB, if available
        camera_connections: Active camera connections. A falsy value (None or
            an empty sequence) is replaced by a new empty list.

    Returns:
        The populated StateReportMessage
    """
    connections = camera_connections if camera_connections else []

    return StateReportMessage(
        cpuUsage=cpu_usage,
        memoryUsage=memory_usage,
        gpuUsage=gpu_usage,
        gpuMemoryUsage=gpu_memory_usage,
        cameraConnections=connections,
    )
|
|
|
|
|
|
def create_image_detection(subscription_identifier: str, detection_data: Dict[str, Any],
                           model_id: int, model_name: str,
                           session_id: Optional[int] = None) -> ImageDetectionMessage:
    """
    Build an ``ImageDetectionMessage`` wrapping detection results.

    The detection dict and the model metadata are packed into a
    ``DetectionData`` object, which becomes the message's ``data`` field.

    Args:
        subscription_identifier: Camera subscription identifier
        detection_data: Flat dictionary of detection results
        model_id: Model identifier
        model_name: Model name
        session_id: Session ID to attach, if any

    Returns:
        The populated ImageDetectionMessage
    """
    from .models import DetectionData

    return ImageDetectionMessage(
        subscriptionIdentifier=subscription_identifier,
        sessionId=session_id,
        data=DetectionData(
            detection=detection_data,
            modelId=model_id,
            modelName=model_name,
        ),
    )
|
|
|
|
|
|
def create_patch_session(session_id: int, patch_data: Dict[str, Any]) -> PatchSessionMessage:
    """
    Build a ``PatchSessionMessage`` that requests a partial session update.

    Args:
        session_id: ID of the session to patch
        patch_data: Partial session fields to update

    Returns:
        The populated PatchSessionMessage
    """
    fields = {
        "sessionId": session_id,
        "data": patch_data,
    }
    return PatchSessionMessage(**fields)