Fix: update message type to current implementation

This commit is contained in:
ziesorx 2025-09-12 22:16:06 +07:00
parent b940790e4a
commit 96ecc321ec
3 changed files with 549 additions and 230 deletions

View file

@ -141,18 +141,44 @@ WebSocketConnection() # Per-client connection wrapper
#### `message_processor.py` - Message Processing Pipeline #### `message_processor.py` - Message Processing Pipeline
```python ```python
MessageProcessor() MessageProcessor()
├── process_message() # Main message dispatcher ├── parse_message() # Message parsing and validation
├── _handle_subscribe() # Process subscription requests ├── _validate_set_subscription_list() # Validate declarative subscriptions
├── _handle_unsubscribe() # Process unsubscription ├── _validate_set_session() # Session management validation
├── _handle_state_request() # System state requests ├── _validate_patch_session() # Patch session validation
└── _handle_session_ops() # Session management operations ├── _validate_progression_stage() # Progression stage validation
└── _validate_patch_session_result() # Backend response validation
MessageType(Enum) # Supported message types MessageType(Enum) # Protocol-compliant message types per worker.md
├── SUBSCRIBE ├── SET_SUBSCRIPTION_LIST # Primary declarative subscription command
├── UNSUBSCRIBE ├── REQUEST_STATE # System state requests
├── REQUEST_STATE ├── SET_SESSION_ID # Session association (supports null clearing)
├── SET_SESSION_ID ├── PATCH_SESSION # Session modification requests
└── PATCH_SESSION ├── SET_PROGRESSION_STAGE # Real-time progression updates
├── PATCH_SESSION_RESULT # Backend responses to patch requests
├── IMAGE_DETECTION # Detection results (worker->backend)
└── STATE_REPORT # Heartbeat messages (worker->backend)
# Protocol-Compliant Payload Classes
SubscriptionObject() # Individual subscription per worker.md specification
├── subscription_identifier # Format: "displayId;cameraId"
├── rtsp_url # Required RTSP stream URL
├── model_url # Required fresh model URL (1-hour TTL)
├── model_id, model_name # Model identification
├── snapshot_url # Optional HTTP snapshot URL
├── snapshot_interval # Required if snapshot_url provided
└── crop_x1/y1/x2/y2 # Optional crop coordinates
SetSubscriptionListPayload() # Declarative subscription command payload
└── subscriptions: List[SubscriptionObject] # Complete desired state
SessionPayload() # Session management payload
├── display_identifier # Target display ID
├── session_id # Session ID (can be null for clearing)
└── data # Optional session patch data
ProgressionPayload() # Progression stage payload
├── display_identifier # Target display ID
└── progression_stage # Stage: welcome|car_fueling|car_waitpayment|null
``` ```
### Stream Management (`detector_worker/streams/`) ### Stream Management (`detector_worker/streams/`)
@ -452,105 +478,313 @@ sequenceDiagram
return stream_manager.get_latest_frame(camera_id) return stream_manager.get_latest_frame(camera_id)
``` ```
## Protocol Compliance (worker.md Implementation)
### Key Protocol Features Implemented
The WebSocket communication layer has been **fully updated** to comply with the worker.md protocol specification, replacing deprecated patterns with modern declarative subscription management.
#### ✅ **Protocol-Compliant Message Types**
| Message Type | Direction | Purpose | Status |
|--------------|-----------|---------|---------|
| `setSubscriptionList` | Backend→Worker | **Primary subscription command** - declarative management | ✅ **Implemented** |
| `setSessionId` | Backend→Worker | Associate session with display (supports null clearing) | ✅ **Implemented** |
| `setProgressionStage` | Backend→Worker | Real-time progression updates for context-aware processing | ✅ **Implemented** |
| `requestState` | Backend→Worker | Request immediate state report | ✅ **Implemented** |
| `patchSessionResult` | Backend→Worker | Response to worker's patch session request | ✅ **Implemented** |
| `stateReport` | Worker→Backend | **Heartbeat with performance metrics** (every 2 seconds) | ✅ **Implemented** |
| `imageDetection` | Worker→Backend | Real-time detection results with session context | ✅ **Implemented** |
| `patchSession` | Worker→Backend | Request modification to session data | ✅ **Implemented** |
#### ❌ **Deprecated Message Types Removed**
- `subscribe` - Replaced by declarative `setSubscriptionList`
- `unsubscribe` - Handled by empty `setSubscriptionList` array
#### 🔧 **Key Protocol Implementations**
##### 1. **Declarative Subscription Management**
```python
# setSubscriptionList provides complete desired state
{
"type": "setSubscriptionList",
"subscriptions": [
{
"subscriptionIdentifier": "display-001;cam-001",
"rtspUrl": "rtsp://192.168.1.100/stream1",
"modelUrl": "http://storage/models/vehicle-id.mpta?token=fresh-token",
"modelId": 201,
"modelName": "Vehicle Identification",
"cropX1": 100, "cropY1": 200, "cropX2": 300, "cropY2": 400
}
]
}
```
**Worker Reconciliation Logic:**
- **Add new subscriptions** not in current state
- **Remove obsolete subscriptions** not in desired state
- **Update existing subscriptions** with fresh model URLs (handles S3 expiration)
- **Single stream optimization** - share RTSP streams across multiple subscriptions
##### 2. **Protocol-Compliant State Reports**
```python
# stateReport with flat structure per worker.md specification
{
"type": "stateReport",
"cpuUsage": 75.5,
"memoryUsage": 40.2,
"gpuUsage": 60.0,
"gpuMemoryUsage": 25.1,
"cameraConnections": [
{
"subscriptionIdentifier": "display-001;cam-001",
"modelId": 101,
"modelName": "General Object Detection",
"online": true,
"cropX1": 100, "cropY1": 200, "cropX2": 300, "cropY2": 400
}
]
}
```
##### 3. **Session Context in Detection Results**
```python
# imageDetection with sessionId for proper session linking
{
"type": "imageDetection",
"subscriptionIdentifier": "display-001;cam-001",
"timestamp": "2025-09-12T10:00:00.000Z",
"sessionId": 12345, # Critical for CMS session tracking
"data": {
"detection": {
"carBrand": "Honda",
"carModel": "CR-V",
"licensePlateText": "ABC-123"
},
"modelId": 201,
"modelName": "Vehicle Identification"
}
}
```
#### 🚀 **Enhanced Features**
1. **State Recovery**: Complete subscription restoration on worker reconnection
2. **Fresh Model URLs**: Automatic S3 URL refresh via subscription updates
3. **Multi-Display Support**: Proper display identifier parsing and session management
4. **Load Balancing Ready**: Backend can distribute subscriptions across workers
5. **Session Continuity**: Session IDs maintained across worker disconnections
## WebSocket Communication Flow ## WebSocket Communication Flow
### Client Connection Lifecycle ### Client Connection Lifecycle (Protocol-Compliant)
```mermaid ```mermaid
sequenceDiagram sequenceDiagram
participant Client as WebSocket Client participant Backend as CMS Backend
participant WS as WebSocketHandler participant WS as WebSocketHandler
participant CM as ConnectionManager
participant MP as MessageProcessor participant MP as MessageProcessor
participant SM as StreamManager participant SM as StreamManager
participant HB as Heartbeat Loop
Client->>WS: WebSocket Connection Backend->>WS: WebSocket Connection
WS->>WS: handle_websocket() WS->>WS: handle_connection()
WS->>CM: add_connection() WS->>HB: start_heartbeat_loop()
CM->>CM: create WebSocketConnection WS->>Backend: Connection Accepted
WS->>Client: Connection Accepted
loop Message Processing loop Heartbeat (every 2 seconds)
Client->>WS: JSON Message HB->>Backend: stateReport {cpuUsage, memoryUsage, gpuUsage, cameraConnections[]}
WS->>MP: process_message() end
alt Subscribe Message loop Message Processing (Protocol Commands)
MP->>SM: create_stream() Backend->>WS: JSON Message
SM->>SM: initialize StreamReader WS->>MP: parse_message()
MP->>Client: subscribeAck
else Unsubscribe Message alt setSubscriptionList (Declarative)
MP->>SM: remove_stream() MP->>MP: validate_subscription_list()
SM->>SM: cleanup StreamReader WS->>WS: reconcile_subscriptions()
MP->>Client: unsubscribeAck WS->>SM: add/remove/update_streams()
else State Request SM->>SM: handle_stream_lifecycle()
MP->>MP: collect_system_state() WS->>HB: update_camera_connections()
MP->>Client: stateReport
else setSessionId
MP->>MP: validate_set_session()
WS->>WS: store_session_for_display()
WS->>WS: apply_to_all_display_subscriptions()
else setProgressionStage
MP->>MP: validate_progression_stage()
WS->>WS: store_stage_for_display()
WS->>WS: apply_context_aware_processing()
else requestState
WS->>Backend: stateReport (immediate response)
else patchSessionResult
MP->>MP: validate_patch_result()
WS->>WS: log_patch_response()
end end
end end
Client->>WS: Disconnect Backend->>WS: Disconnect
WS->>CM: remove_connection() WS->>SM: cleanup_all_streams()
WS->>SM: cleanup_client_streams() WS->>HB: stop_heartbeat_loop()
``` ```
### Message Processing Detail ### Message Processing Detail
#### 1. Subscribe Message Flow (`message_processor.py:125-185`) #### 1. setSubscriptionList Flow (`websocket_handler.py:355-453`) - Protocol Compliant
```python ```python
async def _handle_subscribe(self, payload: Dict, client_id: str) -> Dict: async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None:
"""Process subscription request""" """Handle setSubscriptionList command - declarative subscription management"""
# 1. Extract subscription parameters subscriptions = data.get("subscriptions", [])
subscription_id = payload["subscriptionIdentifier"]
stream_url = payload.get("rtspUrl") or payload.get("snapshotUrl")
model_url = payload["modelUrl"]
# 2. Create stream configuration # 1. Get current and desired subscription states
stream_config = StreamConfig( current_subscriptions = set(subscription_to_camera.keys())
stream_url=stream_url, desired_subscriptions = set()
stream_type="rtsp" if "rtsp" in stream_url else "http_snapshot", subscription_configs = {}
crop_region=[payload.get("cropX1"), payload.get("cropY1"),
payload.get("cropX2"), payload.get("cropY2")]
)
# 3. Load ML pipeline for sub_config in subscriptions:
pipeline_config = await pipeline_loader.load_from_url(model_url) sub_id = sub_config.get("subscriptionIdentifier")
if sub_id:
desired_subscriptions.add(sub_id)
subscription_configs[sub_id] = sub_config
# 4. Create stream (with sharing if same URL) # Extract display ID for session management
stream_info = await stream_manager.create_stream( parts = sub_id.split(";")
camera_id=subscription_id.split(';')[1], if len(parts) >= 2:
config=stream_config, display_id = parts[0]
subscription_id=subscription_id self.display_identifiers.add(display_id)
)
# 5. Register client subscription # 2. Calculate reconciliation changes
connection_manager.add_subscription(client_id, subscription_id) to_add = desired_subscriptions - current_subscriptions
to_remove = current_subscriptions - desired_subscriptions
to_update = desired_subscriptions & current_subscriptions
return {"type": "subscribeAck", "status": "success", # 3. Remove obsolete subscriptions
"subscriptionId": subscription_id} for sub_id in to_remove:
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
subscription_to_camera.pop(sub_id, None)
self.session_cache.clear_session(camera_id)
# 4. Add new subscriptions
for sub_id in to_add:
await self._start_subscription(sub_id, subscription_configs[sub_id])
# 5. Update existing subscriptions (handle S3 URL refresh)
for sub_id in to_update:
current_config = subscription_to_camera.get(sub_id)
new_config = subscription_configs[sub_id]
# Restart if model URL changed (handles S3 expiration)
current_model_url = getattr(current_config, 'model_url', None)
new_model_url = new_config.get("modelUrl")
if current_model_url != new_model_url:
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
await self._start_subscription(sub_id, new_config)
async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None:
"""Start individual subscription with protocol-compliant configuration"""
# Extract camera ID from subscription identifier
parts = subscription_id.split(";")
camera_id = parts[1] if len(parts) >= 2 else subscription_id
# Store subscription mapping
subscription_to_camera[subscription_id] = camera_id
# Start camera stream with full config
await self.stream_manager.start_stream(camera_id, config)
# Load ML model from fresh URL
model_id = config.get("modelId")
model_url = config.get("modelUrl")
if model_id and model_url:
await self.model_manager.load_model(camera_id, model_id, model_url)
``` ```
#### 2. Detection Result Broadcasting (`websocket_handler.py:245-265`) #### 2. Detection Result Broadcasting (`websocket_handler.py:589-615`) - Protocol Compliant
```python ```python
async def broadcast_detection_result(self, subscription_id: str, async def _send_detection_result(self, camera_id: str, stream_info: Dict[str, Any],
detection_result: Dict): detection_result: DetectionResult) -> None:
"""Broadcast detection to subscribed clients""" """Send detection result with protocol-compliant format"""
message = { # Get session ID for this display (protocol requirement)
subscription_id = stream_info["subscriptionIdentifier"]
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
session_id = self.session_ids.get(display_id) # Can be None for null sessions
# Protocol-compliant imageDetection format per worker.md
detection_data = {
"type": "imageDetection", "type": "imageDetection",
"payload": { "subscriptionIdentifier": subscription_id, # Required at root level
"subscriptionId": subscription_id, "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"detections": detection_result["detections"], "sessionId": session_id, # Required for session linking
"timestamp": detection_result["timestamp"], "data": {
"modelInfo": detection_result["model_info"] "detection": detection_result.to_dict(), # Flat detection object
"modelId": stream_info["modelId"],
"modelName": stream_info["modelName"]
} }
} }
await self.connection_manager.broadcast_to_subscription( # Send to backend via WebSocket
subscription_id, message await self.websocket.send_json(detection_data)
) ```
#### 3. State Report Broadcasting (`websocket_handler.py:201-209`) - Protocol Compliant
```python
async def _send_heartbeat(self) -> None:
"""Send protocol-compliant stateReport every 2 seconds"""
while self.connected:
# Get system metrics
metrics = get_system_metrics()
# Build cameraConnections array as required by protocol
camera_connections = []
with self.stream_manager.streams_lock:
for camera_id, stream_info in self.stream_manager.streams.items():
is_online = self.stream_manager.is_stream_active(camera_id)
connection_info = {
"subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
"modelId": stream_info.get("modelId", 0),
"modelName": stream_info.get("modelName", "Unknown Model"),
"online": is_online
}
# Add crop coordinates if available (optional per protocol)
for coord in ["cropX1", "cropY1", "cropX2", "cropY2"]:
if coord in stream_info:
connection_info[coord] = stream_info[coord]
camera_connections.append(connection_info)
# Protocol-compliant stateReport format (worker.md lines 169-189)
state_data = {
"type": "stateReport", # Message type
"cpuUsage": metrics.get("cpu_percent", 0), # CPU percentage
"memoryUsage": metrics.get("memory_percent", 0), # Memory percentage
"gpuUsage": metrics.get("gpu_percent", 0), # GPU percentage
"gpuMemoryUsage": metrics.get("gpu_memory_percent", 0), # GPU memory
"cameraConnections": camera_connections # Camera details array
}
await self.websocket.send_json(state_data)
await asyncio.sleep(2) # 2-second heartbeat interval per protocol
``` ```
## Detection Pipeline Flow ## Detection Pipeline Flow

View file

@ -7,7 +7,7 @@ It provides a clean separation between message handling and business logic.
import json import json
import logging import logging
import time import time
from typing import Dict, Any, Optional, Callable, Tuple from typing import Dict, Any, Optional, Callable, Tuple, List
from enum import Enum from enum import Enum
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -19,13 +19,13 @@ msg_proc_logger = logging.getLogger("websocket.message_processor") # Detailed m
class MessageType(Enum): class MessageType(Enum):
"""Enumeration of supported WebSocket message types.""" """Enumeration of supported WebSocket message types per worker.md protocol."""
SUBSCRIBE = "subscribe" SET_SUBSCRIPTION_LIST = "setSubscriptionList"
UNSUBSCRIBE = "unsubscribe"
REQUEST_STATE = "requestState" REQUEST_STATE = "requestState"
SET_SESSION_ID = "setSessionId" SET_SESSION_ID = "setSessionId"
PATCH_SESSION = "patchSession" PATCH_SESSION = "patchSession"
SET_PROGRESSION_STAGE = "setProgressionStage" SET_PROGRESSION_STAGE = "setProgressionStage"
PATCH_SESSION_RESULT = "patchSessionResult"
IMAGE_DETECTION = "imageDetection" IMAGE_DETECTION = "imageDetection"
STATE_REPORT = "stateReport" STATE_REPORT = "stateReport"
ACK = "ack" ACK = "ack"
@ -41,58 +41,47 @@ class ProgressionStage(Enum):
@dataclass @dataclass
class SubscribePayload: class SubscriptionObject:
"""Payload for subscription messages.""" """Subscription object per worker.md protocol specification."""
subscription_identifier: str subscription_identifier: str
rtsp_url: str
model_url: str
model_id: int model_id: int
model_name: str model_name: str
model_url: str
rtsp_url: Optional[str] = None
snapshot_url: Optional[str] = None snapshot_url: Optional[str] = None
snapshot_interval: Optional[int] = None snapshot_interval: Optional[int] = None
crop_x1: Optional[int] = None crop_x1: Optional[int] = None
crop_y1: Optional[int] = None crop_y1: Optional[int] = None
crop_x2: Optional[int] = None crop_x2: Optional[int] = None
crop_y2: Optional[int] = None crop_y2: Optional[int] = None
extra_params: Dict[str, Any] = field(default_factory=dict)
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SubscribePayload": def from_dict(cls, data: Dict[str, Any]) -> "SubscriptionObject":
"""Create SubscribePayload from dictionary.""" """Create SubscriptionObject from dictionary."""
# Extract known fields return cls(
known_fields = { subscription_identifier=data.get("subscriptionIdentifier", ""),
"subscription_identifier": data.get("subscriptionIdentifier"), rtsp_url=data.get("rtspUrl", ""),
"model_id": data.get("modelId"), model_url=data.get("modelUrl", ""),
"model_name": data.get("modelName"), model_id=data.get("modelId", 0),
"model_url": data.get("modelUrl"), model_name=data.get("modelName", ""),
"rtsp_url": data.get("rtspUrl"), snapshot_url=data.get("snapshotUrl"),
"snapshot_url": data.get("snapshotUrl"), snapshot_interval=data.get("snapshotInterval"),
"snapshot_interval": data.get("snapshotInterval"), crop_x1=data.get("cropX1"),
"crop_x1": data.get("cropX1"), crop_y1=data.get("cropY1"),
"crop_y1": data.get("cropY1"), crop_x2=data.get("cropX2"),
"crop_x2": data.get("cropX2"), crop_y2=data.get("cropY2")
"crop_y2": data.get("cropY2") )
}
# Extract extra parameters
extra_params = {k: v for k, v in data.items()
if k not in ["subscriptionIdentifier", "modelId", "modelName",
"modelUrl", "rtspUrl", "snapshotUrl", "snapshotInterval",
"cropX1", "cropY1", "cropX2", "cropY2"]}
return cls(**known_fields, extra_params=extra_params)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format for stream configuration.""" """Convert to dictionary format for stream configuration."""
result = { result = {
"subscriptionIdentifier": self.subscription_identifier, "subscriptionIdentifier": self.subscription_identifier,
"rtspUrl": self.rtsp_url,
"modelUrl": self.model_url,
"modelId": self.model_id, "modelId": self.model_id,
"modelName": self.model_name, "modelName": self.model_name
"modelUrl": self.model_url
} }
if self.rtsp_url:
result["rtspUrl"] = self.rtsp_url
if self.snapshot_url: if self.snapshot_url:
result["snapshotUrl"] = self.snapshot_url result["snapshotUrl"] = self.snapshot_url
if self.snapshot_interval is not None: if self.snapshot_interval is not None:
@ -107,12 +96,22 @@ class SubscribePayload:
if self.crop_y2 is not None: if self.crop_y2 is not None:
result["cropY2"] = self.crop_y2 result["cropY2"] = self.crop_y2
# Add any extra parameters
result.update(self.extra_params)
return result return result
@dataclass
class SetSubscriptionListPayload:
"""Payload for setSubscriptionList command per worker.md protocol."""
subscriptions: List[SubscriptionObject]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SetSubscriptionListPayload":
"""Create SetSubscriptionListPayload from dictionary."""
subscriptions_data = data.get("subscriptions", [])
subscriptions = [SubscriptionObject.from_dict(sub) for sub in subscriptions_data]
return cls(subscriptions=subscriptions)
@dataclass @dataclass
class SessionPayload: class SessionPayload:
"""Payload for session-related messages.""" """Payload for session-related messages."""
@ -142,11 +141,11 @@ class MessageProcessor:
def __init__(self): def __init__(self):
"""Initialize the message processor.""" """Initialize the message processor."""
self.validators: Dict[MessageType, Callable] = { self.validators: Dict[MessageType, Callable] = {
MessageType.SUBSCRIBE: self._validate_subscribe, MessageType.SET_SUBSCRIPTION_LIST: self._validate_set_subscription_list,
MessageType.UNSUBSCRIBE: self._validate_unsubscribe,
MessageType.SET_SESSION_ID: self._validate_set_session, MessageType.SET_SESSION_ID: self._validate_set_session,
MessageType.PATCH_SESSION: self._validate_patch_session, MessageType.PATCH_SESSION: self._validate_patch_session,
MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage,
MessageType.PATCH_SESSION_RESULT: self._validate_patch_session_result
} }
def parse_message(self, raw_message: str) -> Tuple[MessageType, Dict[str, Any]]: def parse_message(self, raw_message: str) -> Tuple[MessageType, Dict[str, Any]]:
@ -222,58 +221,69 @@ class MessageProcessor:
msg_proc_logger.error(f"❌ VALIDATION FAILED FOR {msg_type.value}: {e}") msg_proc_logger.error(f"❌ VALIDATION FAILED FOR {msg_type.value}: {e}")
raise raise
def _validate_subscribe(self, data: Dict[str, Any]) -> SubscribePayload: def _validate_set_subscription_list(self, data: Dict[str, Any]) -> SetSubscriptionListPayload:
"""Validate subscribe message.""" """Validate setSubscriptionList message per worker.md protocol."""
subscriptions_data = data.get("subscriptions", [])
if not isinstance(subscriptions_data, list):
raise ValidationError("subscriptions must be an array")
# Validate each subscription object
for i, sub_data in enumerate(subscriptions_data):
if not isinstance(sub_data, dict):
raise ValidationError(f"Subscription {i} must be an object")
# Required fields per protocol
required = ["subscriptionIdentifier", "rtspUrl", "modelUrl", "modelId", "modelName"]
missing = [field for field in required if not sub_data.get(field)]
if missing:
raise ValidationError(f"Subscription {i} missing required fields: {', '.join(missing)}")
# Validate snapshot interval if snapshot URL is provided
if sub_data.get("snapshotUrl") and not sub_data.get("snapshotInterval"):
raise ValidationError(f"Subscription {i}: snapshotInterval is required when using snapshotUrl")
# Validate crop coordinates
crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"]
crop_values = [sub_data.get(coord) for coord in crop_coords]
if any(v is not None for v in crop_values):
# If any crop coordinate is set, all must be set
if not all(v is not None for v in crop_values):
raise ValidationError(f"Subscription {i}: All crop coordinates must be specified if any are set")
# Validate coordinate values
x1, y1, x2, y2 = crop_values
if x2 <= x1 or y2 <= y1:
raise ValidationError(f"Subscription {i}: Invalid crop coordinates: x2 must be > x1, y2 must be > y1")
return SetSubscriptionListPayload.from_dict(data)
def _validate_patch_session_result(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Validate patchSessionResult message."""
payload = data.get("payload", {}) payload = data.get("payload", {})
# Required fields # Required fields per protocol
required = ["subscriptionIdentifier", "modelId", "modelName", "modelUrl"] session_id = payload.get("sessionId")
missing = [field for field in required if not payload.get(field)] if session_id is None:
if missing: raise ValidationError("Missing required field: sessionId")
raise ValidationError(f"Missing required fields: {', '.join(missing)}")
# Must have either rtspUrl or snapshotUrl success = payload.get("success")
if not payload.get("rtspUrl") and not payload.get("snapshotUrl"): if success is None:
raise ValidationError("Must provide either rtspUrl or snapshotUrl") raise ValidationError("Missing required field: success")
# Validate snapshot interval if snapshot URL is provided
if payload.get("snapshotUrl") and not payload.get("snapshotInterval"):
raise ValidationError("snapshotInterval is required when using snapshotUrl")
# Validate crop coordinates
crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"]
crop_values = [payload.get(coord) for coord in crop_coords]
if any(v is not None for v in crop_values):
# If any crop coordinate is set, all must be set
if not all(v is not None for v in crop_values):
raise ValidationError("All crop coordinates must be specified if any are set")
# Validate coordinate values
x1, y1, x2, y2 = crop_values
if x2 <= x1 or y2 <= y1:
raise ValidationError("Invalid crop coordinates: x2 must be > x1, y2 must be > y1")
return SubscribePayload.from_dict(payload)
def _validate_unsubscribe(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Validate unsubscribe message."""
payload = data.get("payload", {})
if not payload.get("subscriptionIdentifier"):
raise ValidationError("Missing required field: subscriptionIdentifier")
return payload return payload
def _validate_set_session(self, data: Dict[str, Any]) -> SessionPayload: def _validate_set_session(self, data: Dict[str, Any]) -> SessionPayload:
"""Validate setSessionId message.""" """Validate setSessionId message per worker.md protocol."""
payload = data.get("payload", {}) payload = data.get("payload", {})
display_id = payload.get("displayIdentifier") display_id = payload.get("displayIdentifier")
session_id = payload.get("sessionId") session_id = payload.get("sessionId") # Can be null to clear session
if not display_id: if not display_id:
raise ValidationError("Missing required field: displayIdentifier") raise ValidationError("Missing required field: displayIdentifier")
if not session_id: # sessionId can be null to clear the session per protocol
if "sessionId" not in payload:
raise ValidationError("Missing required field: sessionId") raise ValidationError("Missing required field: sessionId")
return SessionPayload(display_id, session_id) return SessionPayload(display_id, session_id)

View file

@ -83,16 +83,17 @@ class WebSocketHandler:
# Message handlers # Message handlers
self.message_handlers: Dict[str, MessageHandler] = { self.message_handlers: Dict[str, MessageHandler] = {
"subscribe": self._handle_subscribe, "setSubscriptionList": self._handle_set_subscription_list,
"unsubscribe": self._handle_unsubscribe,
"requestState": self._handle_request_state, "requestState": self._handle_request_state,
"setSessionId": self._handle_set_session, "setSessionId": self._handle_set_session,
"patchSession": self._handle_patch_session, "patchSession": self._handle_patch_session,
"setProgressionStage": self._handle_set_progression_stage "setProgressionStage": self._handle_set_progression_stage,
"patchSessionResult": self._handle_patch_session_result
} }
# Session and display management # Session and display management
self.session_ids: Dict[str, str] = {} # display_identifier -> session_id self.session_ids: Dict[str, str] = {} # display_identifier -> session_id
self.progression_stages: Dict[str, str] = {} # display_identifier -> progression_stage
self.display_identifiers: Set[str] = set() self.display_identifiers: Set[str] = set()
# Camera monitor # Camera monitor
@ -171,22 +172,40 @@ class WebSocketHandler:
# Get system metrics # Get system metrics
metrics = get_system_metrics() metrics = get_system_metrics()
# Get active streams info # Build cameraConnections array as required by protocol
active_streams = self.stream_manager.get_active_streams() camera_connections = []
active_models = self.model_manager.get_loaded_models() with self.stream_manager.streams_lock:
for camera_id, stream_info in self.stream_manager.streams.items():
# Check if camera is online
is_online = self.stream_manager.is_stream_active(camera_id)
connection_info = {
"subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
"modelId": stream_info.get("modelId", 0),
"modelName": stream_info.get("modelName", "Unknown Model"),
"online": is_online
}
# Add crop coordinates if available
if "cropX1" in stream_info:
connection_info["cropX1"] = stream_info["cropX1"]
if "cropY1" in stream_info:
connection_info["cropY1"] = stream_info["cropY1"]
if "cropX2" in stream_info:
connection_info["cropX2"] = stream_info["cropX2"]
if "cropY2" in stream_info:
connection_info["cropY2"] = stream_info["cropY2"]
camera_connections.append(connection_info)
# Protocol-compliant stateReport format (worker.md lines 169-189)
state_data = { state_data = {
"type": "stateReport", "type": "stateReport",
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "cpuUsage": metrics.get("cpu_percent", 0),
"data": { "memoryUsage": metrics.get("memory_percent", 0),
"activeStreams": len(active_streams), "gpuUsage": metrics.get("gpu_percent", 0),
"loadedModels": len(active_models), "gpuMemoryUsage": metrics.get("gpu_memory_percent", 0), # Fixed field name
"cpuUsage": metrics.get("cpu_percent", 0), "cameraConnections": camera_connections
"memoryUsage": metrics.get("memory_percent", 0),
"gpuUsage": metrics.get("gpu_percent", 0),
"gpuMemory": metrics.get("gpu_memory_percent", 0),
"uptime": time.time() - metrics.get("start_time", time.time())
}
} }
# Compact JSON for RX/TX logging # Compact JSON for RX/TX logging
@ -351,75 +370,105 @@ class WebSocketHandler:
traceback.print_exc() traceback.print_exc()
return persistent_data return persistent_data
async def _handle_subscribe(self, data: Dict[str, Any]) -> None: async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None:
"""Handle stream subscription request.""" """
payload = data.get("payload", {}) Handle setSubscriptionList command - declarative subscription management.
subscription_id = payload.get("subscriptionIdentifier")
if not subscription_id: This is the primary subscription command per worker.md protocol.
logger.error("Missing subscriptionIdentifier in subscribe payload") Workers must reconcile the new subscription list with current state.
return """
subscriptions = data.get("subscriptions", [])
try: try:
# Extract display and camera IDs # Get current subscription identifiers
current_subscriptions = set(subscription_to_camera.keys())
# Get desired subscription identifiers
desired_subscriptions = set()
subscription_configs = {}
for sub_config in subscriptions:
sub_id = sub_config.get("subscriptionIdentifier")
if sub_id:
desired_subscriptions.add(sub_id)
subscription_configs[sub_id] = sub_config
# Extract display ID for session management
parts = sub_id.split(";")
if len(parts) >= 2:
display_id = parts[0]
self.display_identifiers.add(display_id)
# Calculate changes needed
to_add = desired_subscriptions - current_subscriptions
to_remove = current_subscriptions - desired_subscriptions
to_update = desired_subscriptions & current_subscriptions
logger.info(f"Subscription reconciliation: add={len(to_add)}, remove={len(to_remove)}, update={len(to_update)}")
# Remove obsolete subscriptions
for sub_id in to_remove:
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
subscription_to_camera.pop(sub_id, None)
self.session_cache.clear_session(camera_id)
logger.info(f"Removed subscription: {sub_id}")
# Add new subscriptions
for sub_id in to_add:
await self._start_subscription(sub_id, subscription_configs[sub_id])
logger.info(f"Added subscription: {sub_id}")
# Update existing subscriptions if needed
for sub_id in to_update:
# Check if configuration changed (model URL, crop coordinates, etc.)
current_config = subscription_to_camera.get(sub_id)
new_config = subscription_configs[sub_id]
# For now, restart subscription if model URL changed (handles S3 expiration)
current_model_url = getattr(current_config, 'model_url', None) if current_config else None
new_model_url = new_config.get("modelUrl")
if current_model_url != new_model_url:
# Restart with new configuration
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
await self._start_subscription(sub_id, new_config)
logger.info(f"Updated subscription: {sub_id}")
logger.info(f"Subscription list reconciliation completed. Active: {len(desired_subscriptions)}")
except Exception as e:
logger.error(f"Error handling setSubscriptionList: {e}")
traceback.print_exc()
async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None:
"""Start a single subscription with given configuration."""
try:
# Extract camera ID from subscription identifier
parts = subscription_id.split(";") parts = subscription_id.split(";")
if len(parts) >= 2: camera_id = parts[1] if len(parts) >= 2 else subscription_id
display_id = parts[0]
camera_id = parts[1]
self.display_identifiers.add(display_id)
else:
camera_id = subscription_id
# Store subscription mapping # Store subscription mapping
subscription_to_camera[subscription_id] = camera_id subscription_to_camera[subscription_id] = camera_id
# Start camera stream # Start camera stream
await self.stream_manager.start_stream(camera_id, payload) await self.stream_manager.start_stream(camera_id, config)
# Load model # Load model
model_id = payload.get("modelId") model_id = config.get("modelId")
model_url = payload.get("modelUrl") model_url = config.get("modelUrl")
if model_id and model_url: if model_id and model_url:
await self.model_manager.load_model(camera_id, model_id, model_url) await self.model_manager.load_model(camera_id, model_id, model_url)
logger.info(f"Subscribed to stream: {subscription_id}")
except Exception as e: except Exception as e:
logger.error(f"Error handling subscription: {e}") logger.error(f"Error starting subscription {subscription_id}: {e}")
traceback.print_exc() raise
async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
"""Handle stream unsubscription request."""
payload = data.get("payload", {})
subscription_id = payload.get("subscriptionIdentifier")
if not subscription_id:
logger.error("Missing subscriptionIdentifier in unsubscribe payload")
return
try:
# Get camera ID from subscription
camera_id = subscription_to_camera.get(subscription_id)
if not camera_id:
logger.warning(f"No camera found for subscription: {subscription_id}")
return
# Stop stream
await self.stream_manager.stop_stream(camera_id)
# Unload model
self.model_manager.unload_models(camera_id)
# Clean up mappings
subscription_to_camera.pop(subscription_id, None)
# Clean up session state
self.session_cache.clear_session(camera_id)
logger.info(f"Unsubscribed from stream: {subscription_id}")
except Exception as e:
logger.error(f"Error handling unsubscription: {e}")
traceback.print_exc() traceback.print_exc()
async def _handle_request_state(self, data: Dict[str, Any]) -> None: async def _handle_request_state(self, data: Dict[str, Any]) -> None:
@ -535,6 +584,26 @@ class WebSocketHandler:
pipeline_state["progression_stage"] = progression_stage pipeline_state["progression_stage"] = progression_stage
logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}") logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}")
# Store progression stage for this display
if display_id and progression_stage is not None:
if progression_stage:
self.progression_stages[display_id] = progression_stage
else:
# Clear progression stage if null
self.progression_stages.pop(display_id, None)
async def _handle_patch_session_result(self, data: Dict[str, Any]) -> None:
"""Handle patchSessionResult message from backend."""
payload = data.get("payload", {})
session_id = payload.get("sessionId")
success = payload.get("success", False)
message = payload.get("message", "")
if success:
logger.info(f"Patch session {session_id} successful: {message}")
else:
logger.warning(f"Patch session {session_id} failed: {message}")
async def _send_detection_result( async def _send_detection_result(
self, self,
camera_id: str, camera_id: str,
@ -542,10 +611,16 @@ class WebSocketHandler:
detection_result: DetectionResult detection_result: DetectionResult
) -> None: ) -> None:
"""Send detection result over WebSocket.""" """Send detection result over WebSocket."""
# Get session ID for this display
subscription_id = stream_info["subscriptionIdentifier"]
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
session_id = self.session_ids.get(display_id)
detection_data = { detection_data = {
"type": "imageDetection", "type": "imageDetection",
"subscriptionIdentifier": stream_info["subscriptionIdentifier"], "subscriptionIdentifier": subscription_id,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"sessionId": session_id, # Required by protocol
"data": { "data": {
"detection": detection_result.to_dict(), "detection": detection_result.to_dict(),
"modelId": stream_info["modelId"], "modelId": stream_info["modelId"],