From 96ecc321eca618e445f223e32832d0d08f887914 Mon Sep 17 00:00:00 2001 From: ziesorx Date: Fri, 12 Sep 2025 22:16:06 +0700 Subject: [PATCH] Fix: update message type to current implementation --- ARCHITECTURE.md | 384 ++++++++++++++---- .../communication/message_processor.py | 174 ++++---- .../communication/websocket_handler.py | 221 ++++++---- 3 files changed, 549 insertions(+), 230 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index b166a71..3ed37ec 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -141,18 +141,44 @@ WebSocketConnection() # Per-client connection wrapper #### `message_processor.py` - Message Processing Pipeline ```python MessageProcessor() - ├── process_message() # Main message dispatcher - ├── _handle_subscribe() # Process subscription requests - ├── _handle_unsubscribe() # Process unsubscription - ├── _handle_state_request() # System state requests - └── _handle_session_ops() # Session management operations + ├── parse_message() # Message parsing and validation + ├── _validate_set_subscription_list() # Validate declarative subscriptions + ├── _validate_set_session() # Session management validation + ├── _validate_patch_session() # Patch session validation + ├── _validate_progression_stage() # Progression stage validation + └── _validate_patch_session_result() # Backend response validation -MessageType(Enum) # Supported message types - ├── SUBSCRIBE - ├── UNSUBSCRIBE - ├── REQUEST_STATE - ├── SET_SESSION_ID - └── PATCH_SESSION +MessageType(Enum) # Protocol-compliant message types per worker.md + ├── SET_SUBSCRIPTION_LIST # Primary declarative subscription command + ├── REQUEST_STATE # System state requests + ├── SET_SESSION_ID # Session association (supports null clearing) + ├── PATCH_SESSION # Session modification requests + ├── SET_PROGRESSION_STAGE # Real-time progression updates + ├── PATCH_SESSION_RESULT # Backend responses to patch requests + ├── IMAGE_DETECTION # Detection results (worker->backend) + └── STATE_REPORT # Heartbeat messages (worker->backend) + +# Protocol-Compliant Payload Classes +SubscriptionObject() # Individual subscription per worker.md specification + ├── subscription_identifier # Format: "displayId;cameraId" + ├── rtsp_url # Required RTSP stream URL + ├── model_url # Required fresh model URL (1-hour TTL) + ├── model_id, model_name # Model identification + ├── snapshot_url # Optional HTTP snapshot URL + ├── snapshot_interval # Required if snapshot_url provided + └── crop_x1/y1/x2/y2 # Optional crop coordinates + +SetSubscriptionListPayload() # Declarative subscription command payload + └── subscriptions: List[SubscriptionObject] # Complete desired state + +SessionPayload() # Session management payload + ├── display_identifier # Target display ID + ├── session_id # Session ID (can be null for clearing) + └── data # Optional session patch data + +ProgressionPayload() # Progression stage payload + ├── display_identifier # Target display ID + └── progression_stage # Stage: welcome|car_fueling|car_waitpayment|null ``` ### Stream Management (`detector_worker/streams/`) @@ -452,105 +478,313 @@ sequenceDiagram return stream_manager.get_latest_frame(camera_id) ``` +## Protocol Compliance (worker.md Implementation) + +### Key Protocol Features Implemented + +The WebSocket communication layer has been **fully updated** to comply with the worker.md protocol specification, replacing deprecated patterns with modern declarative subscription management. + +#### ✅ **Protocol-Compliant Message Types** + +| Message Type | Direction | Purpose | Status | +|--------------|-----------|---------|---------| +| `setSubscriptionList` | Backend→Worker | **Primary subscription command** - declarative management | ✅ **Implemented** | +| `setSessionId` | Backend→Worker | Associate session with display (supports null clearing) | ✅ **Implemented** | +| `setProgressionStage` | Backend→Worker | Real-time progression updates for context-aware processing | ✅ **Implemented** | +| `requestState` | Backend→Worker | Request immediate state report | ✅ **Implemented** | +| `patchSessionResult` | Backend→Worker | Response to worker's patch session request | ✅ **Implemented** | +| `stateReport` | Worker→Backend | **Heartbeat with performance metrics** (every 2 seconds) | ✅ **Implemented** | +| `imageDetection` | Worker→Backend | Real-time detection results with session context | ✅ **Implemented** | +| `patchSession` | Worker→Backend | Request modification to session data | ✅ **Implemented** | + +#### ❌ **Deprecated Message Types Removed** + +- `subscribe` - Replaced by declarative `setSubscriptionList` +- `unsubscribe` - Handled by empty `setSubscriptionList` array + +#### 🔧 **Key Protocol Implementations** + +##### 1. **Declarative Subscription Management** +```python +# setSubscriptionList provides complete desired state +{ + "type": "setSubscriptionList", + "subscriptions": [ + { + "subscriptionIdentifier": "display-001;cam-001", + "rtspUrl": "rtsp://192.168.1.100/stream1", + "modelUrl": "http://storage/models/vehicle-id.mpta?token=fresh-token", + "modelId": 201, + "modelName": "Vehicle Identification", + "cropX1": 100, "cropY1": 200, "cropX2": 300, "cropY2": 400 + } + ] +} +``` + +**Worker Reconciliation Logic:** +- **Add new subscriptions** not in current state +- **Remove obsolete subscriptions** not in desired state +- **Update existing subscriptions** with fresh model URLs (handles S3 expiration) +- **Single stream optimization** - share RTSP streams across multiple subscriptions + +##### 2. **Protocol-Compliant State Reports** +```python +# stateReport with flat structure per worker.md specification +{ + "type": "stateReport", + "cpuUsage": 75.5, + "memoryUsage": 40.2, + "gpuUsage": 60.0, + "gpuMemoryUsage": 25.1, + "cameraConnections": [ + { + "subscriptionIdentifier": "display-001;cam-001", + "modelId": 101, + "modelName": "General Object Detection", + "online": true, + "cropX1": 100, "cropY1": 200, "cropX2": 300, "cropY2": 400 + } + ] +} +``` + +##### 3. **Session Context in Detection Results** +```python +# imageDetection with sessionId for proper session linking +{ + "type": "imageDetection", + "subscriptionIdentifier": "display-001;cam-001", + "timestamp": "2025-09-12T10:00:00.000Z", + "sessionId": 12345, # Critical for CMS session tracking + "data": { + "detection": { + "carBrand": "Honda", + "carModel": "CR-V", + "licensePlateText": "ABC-123" + }, + "modelId": 201, + "modelName": "Vehicle Identification" + } +} +``` + +#### 🚀 **Enhanced Features** + +1. **State Recovery**: Complete subscription restoration on worker reconnection +2. **Fresh Model URLs**: Automatic S3 URL refresh via subscription updates +3. **Multi-Display Support**: Proper display identifier parsing and session management +4. **Load Balancing Ready**: Backend can distribute subscriptions across workers +5. **Session Continuity**: Session IDs maintained across worker disconnections + ## WebSocket Communication Flow -### Client Connection Lifecycle +### Client Connection Lifecycle (Protocol-Compliant) ```mermaid sequenceDiagram - participant Client as WebSocket Client + participant Backend as CMS Backend participant WS as WebSocketHandler - participant CM as ConnectionManager participant MP as MessageProcessor participant SM as StreamManager + participant HB as Heartbeat Loop - Client->>WS: WebSocket Connection - WS->>WS: handle_websocket() - WS->>CM: add_connection() - CM->>CM: create WebSocketConnection - WS->>Client: Connection Accepted + Backend->>WS: WebSocket Connection + WS->>WS: handle_connection() + WS->>HB: start_heartbeat_loop() + WS->>Backend: Connection Accepted - loop Message Processing - Client->>WS: JSON Message - WS->>MP: process_message() + loop Heartbeat (every 2 seconds) + HB->>Backend: stateReport {cpuUsage, memoryUsage, gpuUsage, cameraConnections[]} + end + + loop Message Processing (Protocol Commands) + Backend->>WS: JSON Message + WS->>MP: parse_message() - alt Subscribe Message - MP->>SM: create_stream() - SM->>SM: initialize StreamReader - MP->>Client: subscribeAck - else Unsubscribe Message - MP->>SM: remove_stream() - SM->>SM: cleanup StreamReader - MP->>Client: unsubscribeAck - else State Request - MP->>MP: collect_system_state() - MP->>Client: stateReport + alt setSubscriptionList (Declarative) + MP->>MP: validate_subscription_list() + WS->>WS: reconcile_subscriptions() + WS->>SM: add/remove/update_streams() + SM->>SM: handle_stream_lifecycle() + WS->>HB: update_camera_connections() + + else setSessionId + MP->>MP: validate_set_session() + WS->>WS: store_session_for_display() + WS->>WS: apply_to_all_display_subscriptions() + + else setProgressionStage + MP->>MP: validate_progression_stage() + WS->>WS: store_stage_for_display() + WS->>WS: apply_context_aware_processing() + + else requestState + WS->>Backend: stateReport (immediate response) + + else patchSessionResult + MP->>MP: validate_patch_result() + WS->>WS: log_patch_response() end end - Client->>WS: Disconnect - WS->>CM: remove_connection() - WS->>SM: cleanup_client_streams() + Backend->>WS: Disconnect + WS->>SM: cleanup_all_streams() + WS->>HB: stop_heartbeat_loop() ``` ### Message Processing Detail -#### 1. Subscribe Message Flow (`message_processor.py:125-185`) +#### 1. setSubscriptionList Flow (`websocket_handler.py:355-453`) - Protocol Compliant ```python -async def _handle_subscribe(self, payload: Dict, client_id: str) -> Dict: - """Process subscription request""" +async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None: + """Handle setSubscriptionList command - declarative subscription management""" - # 1. Extract subscription parameters - subscription_id = payload["subscriptionIdentifier"] - stream_url = payload.get("rtspUrl") or payload.get("snapshotUrl") - model_url = payload["modelUrl"] + subscriptions = data.get("subscriptions", []) - # 2. Create stream configuration - stream_config = StreamConfig( - stream_url=stream_url, - stream_type="rtsp" if "rtsp" in stream_url else "http_snapshot", - crop_region=[payload.get("cropX1"), payload.get("cropY1"), - payload.get("cropX2"), payload.get("cropY2")] - ) + # 1. Get current and desired subscription states + current_subscriptions = set(subscription_to_camera.keys()) + desired_subscriptions = set() + subscription_configs = {} - # 3. Load ML pipeline - pipeline_config = await pipeline_loader.load_from_url(model_url) + for sub_config in subscriptions: + sub_id = sub_config.get("subscriptionIdentifier") + if sub_id: + desired_subscriptions.add(sub_id) + subscription_configs[sub_id] = sub_config + + # Extract display ID for session management + parts = sub_id.split(";") + if len(parts) >= 2: + display_id = parts[0] + self.display_identifiers.add(display_id) - # 4. Create stream (with sharing if same URL) - stream_info = await stream_manager.create_stream( - camera_id=subscription_id.split(';')[1], - config=stream_config, - subscription_id=subscription_id - ) + # 2. Calculate reconciliation changes + to_add = desired_subscriptions - current_subscriptions + to_remove = current_subscriptions - desired_subscriptions + to_update = desired_subscriptions & current_subscriptions - # 5. Register client subscription - connection_manager.add_subscription(client_id, subscription_id) + # 3. Remove obsolete subscriptions + for sub_id in to_remove: + camera_id = subscription_to_camera.get(sub_id) + if camera_id: + await self.stream_manager.stop_stream(camera_id) + self.model_manager.unload_models(camera_id) + subscription_to_camera.pop(sub_id, None) + self.session_cache.clear_session(camera_id) - return {"type": "subscribeAck", "status": "success", - "subscriptionId": subscription_id} + # 4. Add new subscriptions + for sub_id in to_add: + await self._start_subscription(sub_id, subscription_configs[sub_id]) + + # 5. Update existing subscriptions (handle S3 URL refresh) + for sub_id in to_update: + current_config = subscription_to_camera.get(sub_id) + new_config = subscription_configs[sub_id] + + # Restart if model URL changed (handles S3 expiration) + current_model_url = getattr(current_config, 'model_url', None) + new_model_url = new_config.get("modelUrl") + + if current_model_url != new_model_url: + camera_id = subscription_to_camera.get(sub_id) + if camera_id: + await self.stream_manager.stop_stream(camera_id) + self.model_manager.unload_models(camera_id) + await self._start_subscription(sub_id, new_config) + +async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None: + """Start individual subscription with protocol-compliant configuration""" + + # Extract camera ID from subscription identifier + parts = subscription_id.split(";") + camera_id = parts[1] if len(parts) >= 2 else subscription_id + + # Store subscription mapping + subscription_to_camera[subscription_id] = camera_id + + # Start camera stream with full config + await self.stream_manager.start_stream(camera_id, config) + + # Load ML model from fresh URL + model_id = config.get("modelId") + model_url = config.get("modelUrl") + if model_id and model_url: + await self.model_manager.load_model(camera_id, model_id, model_url) ``` -#### 2. Detection Result Broadcasting (`websocket_handler.py:245-265`) +#### 2. Detection Result Broadcasting (`websocket_handler.py:589-615`) - Protocol Compliant ```python -async def broadcast_detection_result(self, subscription_id: str, - detection_result: Dict): - """Broadcast detection to subscribed clients""" +async def _send_detection_result(self, camera_id: str, stream_info: Dict[str, Any], + detection_result: DetectionResult) -> None: + """Send detection result with protocol-compliant format""" - message = { + # Get session ID for this display (protocol requirement) + subscription_id = stream_info["subscriptionIdentifier"] + display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id + session_id = self.session_ids.get(display_id) # Can be None for null sessions + + # Protocol-compliant imageDetection format per worker.md + detection_data = { "type": "imageDetection", - "payload": { - "subscriptionId": subscription_id, - "detections": detection_result["detections"], - "timestamp": detection_result["timestamp"], - "modelInfo": detection_result["model_info"] + "subscriptionIdentifier": subscription_id, # Required at root level + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "sessionId": session_id, # Required for session linking + "data": { + "detection": detection_result.to_dict(), # Flat detection object + "modelId": stream_info["modelId"], + "modelName": stream_info["modelName"] } } - await self.connection_manager.broadcast_to_subscription( - subscription_id, message - ) + # Send to backend via WebSocket + await self.websocket.send_json(detection_data) +``` + +#### 3. State Report Broadcasting (`websocket_handler.py:201-209`) - Protocol Compliant + +```python +async def _send_heartbeat(self) -> None: + """Send protocol-compliant stateReport every 2 seconds""" + + while self.connected: + # Get system metrics + metrics = get_system_metrics() + + # Build cameraConnections array as required by protocol + camera_connections = [] + with self.stream_manager.streams_lock: + for camera_id, stream_info in self.stream_manager.streams.items(): + is_online = self.stream_manager.is_stream_active(camera_id) + + connection_info = { + "subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id), + "modelId": stream_info.get("modelId", 0), + "modelName": stream_info.get("modelName", "Unknown Model"), + "online": is_online + } + + # Add crop coordinates if available (optional per protocol) + for coord in ["cropX1", "cropY1", "cropX2", "cropY2"]: + if coord in stream_info: + connection_info[coord] = stream_info[coord] + + camera_connections.append(connection_info) + + # Protocol-compliant stateReport format (worker.md lines 169-189) + state_data = { + "type": "stateReport", # Message type + "cpuUsage": metrics.get("cpu_percent", 0), # CPU percentage + "memoryUsage": metrics.get("memory_percent", 0), # Memory percentage + "gpuUsage": metrics.get("gpu_percent", 0), # GPU percentage + "gpuMemoryUsage": metrics.get("gpu_memory_percent", 0), # GPU memory + "cameraConnections": camera_connections # Camera details array + } + + await self.websocket.send_json(state_data) + await asyncio.sleep(2) # 2-second heartbeat interval per protocol ``` ## Detection Pipeline Flow diff --git a/detector_worker/communication/message_processor.py b/detector_worker/communication/message_processor.py index fba0cfd..3185fa5 100644 --- a/detector_worker/communication/message_processor.py +++ b/detector_worker/communication/message_processor.py @@ -7,7 +7,7 @@ It provides a clean separation between message handling and business logic. import json import logging import time -from typing import Dict, Any, Optional, Callable, Tuple +from typing import Dict, Any, Optional, Callable, Tuple, List from enum import Enum from dataclasses import dataclass, field @@ -19,13 +19,13 @@ msg_proc_logger = logging.getLogger("websocket.message_processor") # Detailed m class MessageType(Enum): - """Enumeration of supported WebSocket message types.""" - SUBSCRIBE = "subscribe" - UNSUBSCRIBE = "unsubscribe" + """Enumeration of supported WebSocket message types per worker.md protocol.""" + SET_SUBSCRIPTION_LIST = "setSubscriptionList" REQUEST_STATE = "requestState" SET_SESSION_ID = "setSessionId" PATCH_SESSION = "patchSession" SET_PROGRESSION_STAGE = "setProgressionStage" + PATCH_SESSION_RESULT = "patchSessionResult" IMAGE_DETECTION = "imageDetection" STATE_REPORT = "stateReport" ACK = "ack" @@ -41,58 +41,47 @@ class ProgressionStage(Enum): @dataclass -class SubscribePayload: - """Payload for subscription messages.""" +class SubscriptionObject: + """Subscription object per worker.md protocol specification.""" subscription_identifier: str + rtsp_url: str + model_url: str model_id: int model_name: str - model_url: str - rtsp_url: Optional[str] = None snapshot_url: Optional[str] = None snapshot_interval: Optional[int] = None crop_x1: Optional[int] = None crop_y1: Optional[int] = None crop_x2: Optional[int] = None crop_y2: Optional[int] = None - extra_params: Dict[str, Any] = field(default_factory=dict) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SubscribePayload": - """Create SubscribePayload from dictionary.""" - # Extract known fields - known_fields = { - "subscription_identifier": data.get("subscriptionIdentifier"), - "model_id": data.get("modelId"), - "model_name": data.get("modelName"), - "model_url": data.get("modelUrl"), - "rtsp_url": data.get("rtspUrl"), - "snapshot_url": data.get("snapshotUrl"), - "snapshot_interval": data.get("snapshotInterval"), - "crop_x1": data.get("cropX1"), - "crop_y1": data.get("cropY1"), - "crop_x2": data.get("cropX2"), - "crop_y2": data.get("cropY2") - } - - # Extract extra parameters - extra_params = {k: v for k, v in data.items() - if k not in ["subscriptionIdentifier", "modelId", "modelName", - "modelUrl", "rtspUrl", "snapshotUrl", "snapshotInterval", - "cropX1", "cropY1", "cropX2", "cropY2"]} - - return cls(**known_fields, extra_params=extra_params) + def from_dict(cls, data: Dict[str, Any]) -> "SubscriptionObject": + """Create SubscriptionObject from dictionary.""" + return cls( + subscription_identifier=data.get("subscriptionIdentifier", ""), + rtsp_url=data.get("rtspUrl", ""), + model_url=data.get("modelUrl", ""), + model_id=data.get("modelId", 0), + model_name=data.get("modelName", ""), + snapshot_url=data.get("snapshotUrl"), + snapshot_interval=data.get("snapshotInterval"), + crop_x1=data.get("cropX1"), + crop_y1=data.get("cropY1"), + crop_x2=data.get("cropX2"), + crop_y2=data.get("cropY2") + ) def to_dict(self) -> Dict[str, Any]: """Convert to dictionary format for stream configuration.""" result = { "subscriptionIdentifier": self.subscription_identifier, + "rtspUrl": self.rtsp_url, + "modelUrl": self.model_url, "modelId": self.model_id, - "modelName": self.model_name, - "modelUrl": self.model_url + "modelName": self.model_name } - if self.rtsp_url: - result["rtspUrl"] = self.rtsp_url if self.snapshot_url: result["snapshotUrl"] = self.snapshot_url if self.snapshot_interval is not None: @@ -107,12 +96,22 @@ class SubscribePayload: if self.crop_y2 is not None: result["cropY2"] = self.crop_y2 - # Add any extra parameters - result.update(self.extra_params) - return result +@dataclass +class SetSubscriptionListPayload: + """Payload for setSubscriptionList command per worker.md protocol.""" + subscriptions: List[SubscriptionObject] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SetSubscriptionListPayload": + """Create SetSubscriptionListPayload from dictionary.""" + subscriptions_data = data.get("subscriptions", []) + subscriptions = [SubscriptionObject.from_dict(sub) for sub in subscriptions_data] + return cls(subscriptions=subscriptions) + + @dataclass class SessionPayload: """Payload for session-related messages.""" @@ -142,11 +141,11 @@ class MessageProcessor: def __init__(self): """Initialize the message processor.""" self.validators: Dict[MessageType, Callable] = { - MessageType.SUBSCRIBE: self._validate_subscribe, - MessageType.UNSUBSCRIBE: self._validate_unsubscribe, + MessageType.SET_SUBSCRIPTION_LIST: self._validate_set_subscription_list, MessageType.SET_SESSION_ID: self._validate_set_session, MessageType.PATCH_SESSION: self._validate_patch_session, - MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage + MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage, + MessageType.PATCH_SESSION_RESULT: self._validate_patch_session_result } def parse_message(self, raw_message: str) -> Tuple[MessageType, Dict[str, Any]]: @@ -222,58 +221,69 @@ class MessageProcessor: msg_proc_logger.error(f"❌ VALIDATION FAILED FOR {msg_type.value}: {e}") raise - def _validate_subscribe(self, data: Dict[str, Any]) -> SubscribePayload: - """Validate subscribe message.""" + def _validate_set_subscription_list(self, data: Dict[str, Any]) -> SetSubscriptionListPayload: + """Validate setSubscriptionList message per worker.md protocol.""" + subscriptions_data = data.get("subscriptions", []) + + if not isinstance(subscriptions_data, list): + raise ValidationError("subscriptions must be an array") + + # Validate each subscription object + for i, sub_data in enumerate(subscriptions_data): + if not isinstance(sub_data, dict): + raise ValidationError(f"Subscription {i} must be an object") + + # Required fields per protocol + required = ["subscriptionIdentifier", "rtspUrl", "modelUrl", "modelId", "modelName"] + missing = [field for field in required if not sub_data.get(field)] + if missing: + raise ValidationError(f"Subscription {i} missing required fields: {', '.join(missing)}") + + # Validate snapshot interval if snapshot URL is provided + if sub_data.get("snapshotUrl") and not sub_data.get("snapshotInterval"): + raise ValidationError(f"Subscription {i}: snapshotInterval is required when using snapshotUrl") + + # Validate crop coordinates + crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"] + crop_values = [sub_data.get(coord) for coord in crop_coords] + if any(v is not None for v in crop_values): + # If any crop coordinate is set, all must be set + if not all(v is not None for v in crop_values): + raise ValidationError(f"Subscription {i}: All crop coordinates must be specified if any are set") + + # Validate coordinate values + x1, y1, x2, y2 = crop_values + if x2 <= x1 or y2 <= y1: + raise ValidationError(f"Subscription {i}: Invalid crop coordinates: x2 must be > x1, y2 must be > y1") + + return SetSubscriptionListPayload.from_dict(data) + + def _validate_patch_session_result(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Validate patchSessionResult message.""" payload = data.get("payload", {}) - # Required fields - required = ["subscriptionIdentifier", "modelId", "modelName", "modelUrl"] - missing = [field for field in required if not payload.get(field)] - if missing: - raise ValidationError(f"Missing required fields: {', '.join(missing)}") + # Required fields per protocol + session_id = payload.get("sessionId") + if session_id is None: + raise ValidationError("Missing required field: sessionId") - # Must have either rtspUrl or snapshotUrl - if not payload.get("rtspUrl") and not payload.get("snapshotUrl"): - raise ValidationError("Must provide either rtspUrl or snapshotUrl") - - # Validate snapshot interval if snapshot URL is provided - if payload.get("snapshotUrl") and not payload.get("snapshotInterval"): - raise ValidationError("snapshotInterval is required when using snapshotUrl") - - # Validate crop coordinates - crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"] - crop_values = [payload.get(coord) for coord in crop_coords] - if any(v is not None for v in crop_values): - # If any crop coordinate is set, all must be set - if not all(v is not None for v in crop_values): - raise ValidationError("All crop coordinates must be specified if any are set") - - # Validate coordinate values - x1, y1, x2, y2 = crop_values - if x2 <= x1 or y2 <= y1: - raise ValidationError("Invalid crop coordinates: x2 must be > x1, y2 must be > y1") - - return SubscribePayload.from_dict(payload) - - def _validate_unsubscribe(self, data: Dict[str, Any]) -> Dict[str, Any]: - """Validate unsubscribe message.""" - payload = data.get("payload", {}) - - if not payload.get("subscriptionIdentifier"): - raise ValidationError("Missing required field: subscriptionIdentifier") + success = payload.get("success") + if success is None: + raise ValidationError("Missing required field: success") return payload def _validate_set_session(self, data: Dict[str, Any]) -> SessionPayload: - """Validate setSessionId message.""" + """Validate setSessionId message per worker.md protocol.""" payload = data.get("payload", {}) display_id = payload.get("displayIdentifier") - session_id = payload.get("sessionId") + session_id = payload.get("sessionId") # Can be null to clear session if not display_id: raise ValidationError("Missing required field: displayIdentifier") - if not session_id: + # sessionId can be null to clear the session per protocol + if "sessionId" not in payload: raise ValidationError("Missing required field: sessionId") return SessionPayload(display_id, session_id) diff --git a/detector_worker/communication/websocket_handler.py b/detector_worker/communication/websocket_handler.py index e86199b..bdfcad8 100644 --- a/detector_worker/communication/websocket_handler.py +++ b/detector_worker/communication/websocket_handler.py @@ -83,16 +83,17 @@ class WebSocketHandler: # Message handlers self.message_handlers: Dict[str, MessageHandler] = { - "subscribe": self._handle_subscribe, - "unsubscribe": self._handle_unsubscribe, + "setSubscriptionList": self._handle_set_subscription_list, "requestState": self._handle_request_state, "setSessionId": self._handle_set_session, "patchSession": self._handle_patch_session, - "setProgressionStage": self._handle_set_progression_stage + "setProgressionStage": self._handle_set_progression_stage, + "patchSessionResult": self._handle_patch_session_result } # Session and display management self.session_ids: Dict[str, str] = {} # display_identifier -> session_id + self.progression_stages: Dict[str, str] = {} # display_identifier -> progression_stage self.display_identifiers: Set[str] = set() # Camera monitor @@ -171,22 +172,40 @@ class WebSocketHandler: # Get system metrics metrics = get_system_metrics() - # Get active streams info - active_streams = self.stream_manager.get_active_streams() - active_models = self.model_manager.get_loaded_models() + # Build cameraConnections array as required by protocol + camera_connections = [] + with self.stream_manager.streams_lock: + for camera_id, stream_info in self.stream_manager.streams.items(): + # Check if camera is online + is_online = self.stream_manager.is_stream_active(camera_id) + + connection_info = { + "subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id), + "modelId": stream_info.get("modelId", 0), + "modelName": stream_info.get("modelName", "Unknown Model"), + "online": is_online + } + + # Add crop coordinates if available + if "cropX1" in stream_info: + connection_info["cropX1"] = stream_info["cropX1"] + if "cropY1" in stream_info: + connection_info["cropY1"] = stream_info["cropY1"] + if "cropX2" in stream_info: + connection_info["cropX2"] = stream_info["cropX2"] + if "cropY2" in stream_info: + connection_info["cropY2"] = stream_info["cropY2"] + + camera_connections.append(connection_info) + # Protocol-compliant stateReport format (worker.md lines 169-189) state_data = { "type": "stateReport", - "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - "data": { - "activeStreams": len(active_streams), - "loadedModels": len(active_models), - "cpuUsage": metrics.get("cpu_percent", 0), - "memoryUsage": metrics.get("memory_percent", 0), - "gpuUsage": metrics.get("gpu_percent", 0), - "gpuMemory": metrics.get("gpu_memory_percent", 0), - "uptime": time.time() - metrics.get("start_time", time.time()) - } + "cpuUsage": metrics.get("cpu_percent", 0), + "memoryUsage": metrics.get("memory_percent", 0), + "gpuUsage": metrics.get("gpu_percent", 0), + "gpuMemoryUsage": metrics.get("gpu_memory_percent", 0), # Fixed field name + "cameraConnections": camera_connections } # Compact JSON for RX/TX logging @@ -351,75 +370,105 @@ class WebSocketHandler: traceback.print_exc() return persistent_data - async def _handle_subscribe(self, data: Dict[str, Any]) -> None: - """Handle stream subscription request.""" - payload = data.get("payload", {}) - subscription_id = payload.get("subscriptionIdentifier") + async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None: + """ + Handle setSubscriptionList command - declarative subscription management. + + This is the primary subscription command per worker.md protocol. + Workers must reconcile the new subscription list with current state. + """ + subscriptions = data.get("subscriptions", []) - if not subscription_id: - logger.error("Missing subscriptionIdentifier in subscribe payload") - return - try: - # Extract display and camera IDs - parts = subscription_id.split(";") - if len(parts) >= 2: - display_id = parts[0] - camera_id = parts[1] - self.display_identifiers.add(display_id) - else: - camera_id = subscription_id + # Get current subscription identifiers + current_subscriptions = set(subscription_to_camera.keys()) + + # Get desired subscription identifiers + desired_subscriptions = set() + subscription_configs = {} + + for sub_config in subscriptions: + sub_id = sub_config.get("subscriptionIdentifier") + if sub_id: + desired_subscriptions.add(sub_id) + subscription_configs[sub_id] = sub_config + + # Extract display ID for session management + parts = sub_id.split(";") + if len(parts) >= 2: + display_id = parts[0] + self.display_identifiers.add(display_id) + + # Calculate changes needed + to_add = desired_subscriptions - current_subscriptions + to_remove = current_subscriptions - desired_subscriptions + to_update = desired_subscriptions & current_subscriptions + + logger.info(f"Subscription reconciliation: add={len(to_add)}, remove={len(to_remove)}, update={len(to_update)}") + + # Remove obsolete subscriptions + for sub_id in to_remove: + camera_id = subscription_to_camera.get(sub_id) + if camera_id: + await self.stream_manager.stop_stream(camera_id) + self.model_manager.unload_models(camera_id) + subscription_to_camera.pop(sub_id, None) + self.session_cache.clear_session(camera_id) + logger.info(f"Removed subscription: {sub_id}") + + # Add new subscriptions + for sub_id in to_add: + await self._start_subscription(sub_id, subscription_configs[sub_id]) + logger.info(f"Added subscription: {sub_id}") + + # Update existing subscriptions if needed + for sub_id in to_update: + # Check if configuration changed (model URL, crop coordinates, etc.) + current_config = subscription_to_camera.get(sub_id) + new_config = subscription_configs[sub_id] + # For now, restart subscription if model URL changed (handles S3 expiration) + current_model_url = getattr(current_config, 'model_url', None) if current_config else None + new_model_url = new_config.get("modelUrl") + + if current_model_url != new_model_url: + # Restart with new configuration + camera_id = subscription_to_camera.get(sub_id) + if camera_id: + await self.stream_manager.stop_stream(camera_id) + self.model_manager.unload_models(camera_id) + + await self._start_subscription(sub_id, new_config) + logger.info(f"Updated subscription: {sub_id}") + + logger.info(f"Subscription list reconciliation completed. Active: {len(desired_subscriptions)}") + + except Exception as e: + logger.error(f"Error handling setSubscriptionList: {e}") + traceback.print_exc() + + async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None: + """Start a single subscription with given configuration.""" + try: + # Extract camera ID from subscription identifier + parts = subscription_id.split(";") + camera_id = parts[1] if len(parts) >= 2 else subscription_id + # Store subscription mapping subscription_to_camera[subscription_id] = camera_id # Start camera stream - await self.stream_manager.start_stream(camera_id, payload) + await self.stream_manager.start_stream(camera_id, config) # Load model - model_id = payload.get("modelId") - model_url = payload.get("modelUrl") + model_id = config.get("modelId") + model_url = config.get("modelUrl") if model_id and model_url: await self.model_manager.load_model(camera_id, model_id, model_url) - logger.info(f"Subscribed to stream: {subscription_id}") - except Exception as e: - logger.error(f"Error handling subscription: {e}") - traceback.print_exc() - - async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None: - """Handle stream unsubscription request.""" - payload = data.get("payload", {}) - subscription_id = payload.get("subscriptionIdentifier") - - if not subscription_id: - logger.error("Missing subscriptionIdentifier in unsubscribe payload") - return - - try: - # Get camera ID from subscription - camera_id = subscription_to_camera.get(subscription_id) - if not camera_id: - logger.warning(f"No camera found for subscription: {subscription_id}") - return - - # Stop stream - await self.stream_manager.stop_stream(camera_id) - - # Unload model - self.model_manager.unload_models(camera_id) - - # Clean up mappings - subscription_to_camera.pop(subscription_id, None) - - # Clean up session state - self.session_cache.clear_session(camera_id) - - logger.info(f"Unsubscribed from stream: {subscription_id}") - - except Exception as e: - logger.error(f"Error handling unsubscription: {e}") + logger.error(f"Error starting subscription {subscription_id}: {e}") + raise traceback.print_exc() async def _handle_request_state(self, data: Dict[str, Any]) -> None: @@ -534,6 +583,26 @@ class WebSocketHandler: elif progression_stage in ["car_wait_staff"]: pipeline_state["progression_stage"] = progression_stage logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}") + + # Store progression stage for this display + if display_id and progression_stage is not None: + if progression_stage: + self.progression_stages[display_id] = progression_stage + else: + # Clear progression stage if null + self.progression_stages.pop(display_id, None) + + async def _handle_patch_session_result(self, data: Dict[str, Any]) -> None: + """Handle patchSessionResult message from backend.""" + payload = data.get("payload", {}) + session_id = payload.get("sessionId") + success = payload.get("success", False) + message = payload.get("message", "") + + if success: + logger.info(f"Patch session {session_id} successful: {message}") + else: + logger.warning(f"Patch session {session_id} failed: {message}") async def _send_detection_result( self, @@ -542,10 +611,16 @@ class WebSocketHandler: detection_result: DetectionResult ) -> None: """Send detection result over WebSocket.""" + # Get session ID for this display + subscription_id = stream_info["subscriptionIdentifier"] + display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id + session_id = self.session_ids.get(display_id) + detection_data = { "type": "imageDetection", - "subscriptionIdentifier": stream_info["subscriptionIdentifier"], + "subscriptionIdentifier": subscription_id, "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "sessionId": session_id, # Required by protocol "data": { "detection": detection_result.to_dict(), "modelId": stream_info["modelId"],