Fix: update message type to current implementation

This commit is contained in:
ziesorx 2025-09-12 22:16:06 +07:00
parent b940790e4a
commit 96ecc321ec
3 changed files with 549 additions and 230 deletions

View file

@ -141,18 +141,44 @@ WebSocketConnection() # Per-client connection wrapper
#### `message_processor.py` - Message Processing Pipeline
```python
MessageProcessor()
├── process_message() # Main message dispatcher
├── _handle_subscribe() # Process subscription requests
├── _handle_unsubscribe() # Process unsubscription
├── _handle_state_request() # System state requests
└── _handle_session_ops() # Session management operations
├── parse_message() # Message parsing and validation
├── _validate_set_subscription_list() # Validate declarative subscriptions
├── _validate_set_session() # Session management validation
├── _validate_patch_session() # Patch session validation
├── _validate_progression_stage() # Progression stage validation
└── _validate_patch_session_result() # Backend response validation
MessageType(Enum) # Supported message types
├── SUBSCRIBE
├── UNSUBSCRIBE
├── REQUEST_STATE
├── SET_SESSION_ID
└── PATCH_SESSION
MessageType(Enum) # Protocol-compliant message types per worker.md
├── SET_SUBSCRIPTION_LIST # Primary declarative subscription command
├── REQUEST_STATE # System state requests
├── SET_SESSION_ID # Session association (supports null clearing)
├── PATCH_SESSION # Session modification requests
├── SET_PROGRESSION_STAGE # Real-time progression updates
├── PATCH_SESSION_RESULT # Backend responses to patch requests
├── IMAGE_DETECTION # Detection results (worker->backend)
└── STATE_REPORT # Heartbeat messages (worker->backend)
# Protocol-Compliant Payload Classes
SubscriptionObject() # Individual subscription per worker.md specification
├── subscription_identifier # Format: "displayId;cameraId"
├── rtsp_url # Required RTSP stream URL
├── model_url # Required fresh model URL (1-hour TTL)
├── model_id, model_name # Model identification
├── snapshot_url # Optional HTTP snapshot URL
├── snapshot_interval # Required if snapshot_url provided
└── crop_x1/y1/x2/y2 # Optional crop coordinates
SetSubscriptionListPayload() # Declarative subscription command payload
└── subscriptions: List[SubscriptionObject] # Complete desired state
SessionPayload() # Session management payload
├── display_identifier # Target display ID
├── session_id # Session ID (can be null for clearing)
└── data # Optional session patch data
ProgressionPayload() # Progression stage payload
├── display_identifier # Target display ID
└── progression_stage # Stage: welcome|car_fueling|car_waitpayment|null
```
### Stream Management (`detector_worker/streams/`)
@ -452,105 +478,313 @@ sequenceDiagram
return stream_manager.get_latest_frame(camera_id)
```
## Protocol Compliance (worker.md Implementation)
### Key Protocol Features Implemented
The WebSocket communication layer has been **fully updated** to comply with the worker.md protocol specification, replacing deprecated patterns with modern declarative subscription management.
#### ✅ **Protocol-Compliant Message Types**
| Message Type | Direction | Purpose | Status |
|--------------|-----------|---------|---------|
| `setSubscriptionList` | Backend→Worker | **Primary subscription command** - declarative management | ✅ **Implemented** |
| `setSessionId` | Backend→Worker | Associate session with display (supports null clearing) | ✅ **Implemented** |
| `setProgressionStage` | Backend→Worker | Real-time progression updates for context-aware processing | ✅ **Implemented** |
| `requestState` | Backend→Worker | Request immediate state report | ✅ **Implemented** |
| `patchSessionResult` | Backend→Worker | Response to worker's patch session request | ✅ **Implemented** |
| `stateReport` | Worker→Backend | **Heartbeat with performance metrics** (every 2 seconds) | ✅ **Implemented** |
| `imageDetection` | Worker→Backend | Real-time detection results with session context | ✅ **Implemented** |
| `patchSession` | Worker→Backend | Request modification to session data | ✅ **Implemented** |
#### ❌ **Deprecated Message Types Removed**
- `subscribe` - Replaced by declarative `setSubscriptionList`
- `unsubscribe` - Handled by empty `setSubscriptionList` array
#### 🔧 **Key Protocol Implementations**
##### 1. **Declarative Subscription Management**
```python
# setSubscriptionList provides complete desired state
{
"type": "setSubscriptionList",
"subscriptions": [
{
"subscriptionIdentifier": "display-001;cam-001",
"rtspUrl": "rtsp://192.168.1.100/stream1",
"modelUrl": "http://storage/models/vehicle-id.mpta?token=fresh-token",
"modelId": 201,
"modelName": "Vehicle Identification",
"cropX1": 100, "cropY1": 200, "cropX2": 300, "cropY2": 400
}
]
}
```
**Worker Reconciliation Logic:**
- **Add new subscriptions** not in current state
- **Remove obsolete subscriptions** not in desired state
- **Update existing subscriptions** with fresh model URLs (handles S3 expiration)
- **Single stream optimization** - share RTSP streams across multiple subscriptions
##### 2. **Protocol-Compliant State Reports**
```python
# stateReport with flat structure per worker.md specification
{
"type": "stateReport",
"cpuUsage": 75.5,
"memoryUsage": 40.2,
"gpuUsage": 60.0,
"gpuMemoryUsage": 25.1,
"cameraConnections": [
{
"subscriptionIdentifier": "display-001;cam-001",
"modelId": 101,
"modelName": "General Object Detection",
"online": true,
"cropX1": 100, "cropY1": 200, "cropX2": 300, "cropY2": 400
}
]
}
```
##### 3. **Session Context in Detection Results**
```python
# imageDetection with sessionId for proper session linking
{
"type": "imageDetection",
"subscriptionIdentifier": "display-001;cam-001",
"timestamp": "2025-09-12T10:00:00.000Z",
"sessionId": 12345, # Critical for CMS session tracking
"data": {
"detection": {
"carBrand": "Honda",
"carModel": "CR-V",
"licensePlateText": "ABC-123"
},
"modelId": 201,
"modelName": "Vehicle Identification"
}
}
```
#### 🚀 **Enhanced Features**
1. **State Recovery**: Complete subscription restoration on worker reconnection
2. **Fresh Model URLs**: Automatic S3 URL refresh via subscription updates
3. **Multi-Display Support**: Proper display identifier parsing and session management
4. **Load Balancing Ready**: Backend can distribute subscriptions across workers
5. **Session Continuity**: Session IDs maintained across worker disconnections
## WebSocket Communication Flow
### Client Connection Lifecycle
### Client Connection Lifecycle (Protocol-Compliant)
```mermaid
sequenceDiagram
participant Client as WebSocket Client
participant Backend as CMS Backend
participant WS as WebSocketHandler
participant CM as ConnectionManager
participant MP as MessageProcessor
participant SM as StreamManager
participant HB as Heartbeat Loop
Client->>WS: WebSocket Connection
WS->>WS: handle_websocket()
WS->>CM: add_connection()
CM->>CM: create WebSocketConnection
WS->>Client: Connection Accepted
Backend->>WS: WebSocket Connection
WS->>WS: handle_connection()
WS->>HB: start_heartbeat_loop()
WS->>Backend: Connection Accepted
loop Message Processing
Client->>WS: JSON Message
WS->>MP: process_message()
loop Heartbeat (every 2 seconds)
HB->>Backend: stateReport {cpuUsage, memoryUsage, gpuUsage, cameraConnections[]}
end
alt Subscribe Message
MP->>SM: create_stream()
SM->>SM: initialize StreamReader
MP->>Client: subscribeAck
else Unsubscribe Message
MP->>SM: remove_stream()
SM->>SM: cleanup StreamReader
MP->>Client: unsubscribeAck
else State Request
MP->>MP: collect_system_state()
MP->>Client: stateReport
loop Message Processing (Protocol Commands)
Backend->>WS: JSON Message
WS->>MP: parse_message()
alt setSubscriptionList (Declarative)
MP->>MP: validate_subscription_list()
WS->>WS: reconcile_subscriptions()
WS->>SM: add/remove/update_streams()
SM->>SM: handle_stream_lifecycle()
WS->>HB: update_camera_connections()
else setSessionId
MP->>MP: validate_set_session()
WS->>WS: store_session_for_display()
WS->>WS: apply_to_all_display_subscriptions()
else setProgressionStage
MP->>MP: validate_progression_stage()
WS->>WS: store_stage_for_display()
WS->>WS: apply_context_aware_processing()
else requestState
WS->>Backend: stateReport (immediate response)
else patchSessionResult
MP->>MP: validate_patch_result()
WS->>WS: log_patch_response()
end
end
Client->>WS: Disconnect
WS->>CM: remove_connection()
WS->>SM: cleanup_client_streams()
Backend->>WS: Disconnect
WS->>SM: cleanup_all_streams()
WS->>HB: stop_heartbeat_loop()
```
### Message Processing Detail
#### 1. Subscribe Message Flow (`message_processor.py:125-185`)
#### 1. setSubscriptionList Flow (`websocket_handler.py:355-453`) - Protocol Compliant
```python
async def _handle_subscribe(self, payload: Dict, client_id: str) -> Dict:
"""Process subscription request"""
async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None:
"""Handle setSubscriptionList command - declarative subscription management"""
# 1. Extract subscription parameters
subscription_id = payload["subscriptionIdentifier"]
stream_url = payload.get("rtspUrl") or payload.get("snapshotUrl")
model_url = payload["modelUrl"]
subscriptions = data.get("subscriptions", [])
# 2. Create stream configuration
stream_config = StreamConfig(
stream_url=stream_url,
stream_type="rtsp" if "rtsp" in stream_url else "http_snapshot",
crop_region=[payload.get("cropX1"), payload.get("cropY1"),
payload.get("cropX2"), payload.get("cropY2")]
)
# 1. Get current and desired subscription states
current_subscriptions = set(subscription_to_camera.keys())
desired_subscriptions = set()
subscription_configs = {}
# 3. Load ML pipeline
pipeline_config = await pipeline_loader.load_from_url(model_url)
for sub_config in subscriptions:
sub_id = sub_config.get("subscriptionIdentifier")
if sub_id:
desired_subscriptions.add(sub_id)
subscription_configs[sub_id] = sub_config
# 4. Create stream (with sharing if same URL)
stream_info = await stream_manager.create_stream(
camera_id=subscription_id.split(';')[1],
config=stream_config,
subscription_id=subscription_id
)
# Extract display ID for session management
parts = sub_id.split(";")
if len(parts) >= 2:
display_id = parts[0]
self.display_identifiers.add(display_id)
# 5. Register client subscription
connection_manager.add_subscription(client_id, subscription_id)
# 2. Calculate reconciliation changes
to_add = desired_subscriptions - current_subscriptions
to_remove = current_subscriptions - desired_subscriptions
to_update = desired_subscriptions & current_subscriptions
return {"type": "subscribeAck", "status": "success",
"subscriptionId": subscription_id}
# 3. Remove obsolete subscriptions
for sub_id in to_remove:
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
subscription_to_camera.pop(sub_id, None)
self.session_cache.clear_session(camera_id)
# 4. Add new subscriptions
for sub_id in to_add:
await self._start_subscription(sub_id, subscription_configs[sub_id])
# 5. Update existing subscriptions (handle S3 URL refresh)
for sub_id in to_update:
current_config = subscription_to_camera.get(sub_id)
new_config = subscription_configs[sub_id]
# Restart if model URL changed (handles S3 expiration)
current_model_url = getattr(current_config, 'model_url', None)
new_model_url = new_config.get("modelUrl")
if current_model_url != new_model_url:
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
await self._start_subscription(sub_id, new_config)
async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None:
"""Start individual subscription with protocol-compliant configuration"""
# Extract camera ID from subscription identifier
parts = subscription_id.split(";")
camera_id = parts[1] if len(parts) >= 2 else subscription_id
# Store subscription mapping
subscription_to_camera[subscription_id] = camera_id
# Start camera stream with full config
await self.stream_manager.start_stream(camera_id, config)
# Load ML model from fresh URL
model_id = config.get("modelId")
model_url = config.get("modelUrl")
if model_id and model_url:
await self.model_manager.load_model(camera_id, model_id, model_url)
```
#### 2. Detection Result Broadcasting (`websocket_handler.py:245-265`)
#### 2. Detection Result Broadcasting (`websocket_handler.py:589-615`) - Protocol Compliant
```python
async def broadcast_detection_result(self, subscription_id: str,
detection_result: Dict):
"""Broadcast detection to subscribed clients"""
async def _send_detection_result(self, camera_id: str, stream_info: Dict[str, Any],
detection_result: DetectionResult) -> None:
"""Send detection result with protocol-compliant format"""
message = {
# Get session ID for this display (protocol requirement)
subscription_id = stream_info["subscriptionIdentifier"]
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
session_id = self.session_ids.get(display_id) # Can be None for null sessions
# Protocol-compliant imageDetection format per worker.md
detection_data = {
"type": "imageDetection",
"payload": {
"subscriptionId": subscription_id,
"detections": detection_result["detections"],
"timestamp": detection_result["timestamp"],
"modelInfo": detection_result["model_info"]
"subscriptionIdentifier": subscription_id, # Required at root level
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"sessionId": session_id, # Required for session linking
"data": {
"detection": detection_result.to_dict(), # Flat detection object
"modelId": stream_info["modelId"],
"modelName": stream_info["modelName"]
}
}
await self.connection_manager.broadcast_to_subscription(
subscription_id, message
)
# Send to backend via WebSocket
await self.websocket.send_json(detection_data)
```
#### 3. State Report Broadcasting (`websocket_handler.py:201-209`) - Protocol Compliant
```python
async def _send_heartbeat(self) -> None:
"""Send protocol-compliant stateReport every 2 seconds"""
while self.connected:
# Get system metrics
metrics = get_system_metrics()
# Build cameraConnections array as required by protocol
camera_connections = []
with self.stream_manager.streams_lock:
for camera_id, stream_info in self.stream_manager.streams.items():
is_online = self.stream_manager.is_stream_active(camera_id)
connection_info = {
"subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
"modelId": stream_info.get("modelId", 0),
"modelName": stream_info.get("modelName", "Unknown Model"),
"online": is_online
}
# Add crop coordinates if available (optional per protocol)
for coord in ["cropX1", "cropY1", "cropX2", "cropY2"]:
if coord in stream_info:
connection_info[coord] = stream_info[coord]
camera_connections.append(connection_info)
# Protocol-compliant stateReport format (worker.md lines 169-189)
state_data = {
"type": "stateReport", # Message type
"cpuUsage": metrics.get("cpu_percent", 0), # CPU percentage
"memoryUsage": metrics.get("memory_percent", 0), # Memory percentage
"gpuUsage": metrics.get("gpu_percent", 0), # GPU percentage
"gpuMemoryUsage": metrics.get("gpu_memory_percent", 0), # GPU memory
"cameraConnections": camera_connections # Camera details array
}
await self.websocket.send_json(state_data)
await asyncio.sleep(2) # 2-second heartbeat interval per protocol
```
## Detection Pipeline Flow

View file

@ -7,7 +7,7 @@ It provides a clean separation between message handling and business logic.
import json
import logging
import time
from typing import Dict, Any, Optional, Callable, Tuple
from typing import Dict, Any, Optional, Callable, Tuple, List
from enum import Enum
from dataclasses import dataclass, field
@ -19,13 +19,13 @@ msg_proc_logger = logging.getLogger("websocket.message_processor") # Detailed m
class MessageType(Enum):
"""Enumeration of supported WebSocket message types."""
SUBSCRIBE = "subscribe"
UNSUBSCRIBE = "unsubscribe"
"""Enumeration of supported WebSocket message types per worker.md protocol."""
SET_SUBSCRIPTION_LIST = "setSubscriptionList"
REQUEST_STATE = "requestState"
SET_SESSION_ID = "setSessionId"
PATCH_SESSION = "patchSession"
SET_PROGRESSION_STAGE = "setProgressionStage"
PATCH_SESSION_RESULT = "patchSessionResult"
IMAGE_DETECTION = "imageDetection"
STATE_REPORT = "stateReport"
ACK = "ack"
@ -41,58 +41,47 @@ class ProgressionStage(Enum):
@dataclass
class SubscribePayload:
"""Payload for subscription messages."""
class SubscriptionObject:
"""Subscription object per worker.md protocol specification."""
subscription_identifier: str
rtsp_url: str
model_url: str
model_id: int
model_name: str
model_url: str
rtsp_url: Optional[str] = None
snapshot_url: Optional[str] = None
snapshot_interval: Optional[int] = None
crop_x1: Optional[int] = None
crop_y1: Optional[int] = None
crop_x2: Optional[int] = None
crop_y2: Optional[int] = None
extra_params: Dict[str, Any] = field(default_factory=dict)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SubscribePayload":
"""Create SubscribePayload from dictionary."""
# Extract known fields
known_fields = {
"subscription_identifier": data.get("subscriptionIdentifier"),
"model_id": data.get("modelId"),
"model_name": data.get("modelName"),
"model_url": data.get("modelUrl"),
"rtsp_url": data.get("rtspUrl"),
"snapshot_url": data.get("snapshotUrl"),
"snapshot_interval": data.get("snapshotInterval"),
"crop_x1": data.get("cropX1"),
"crop_y1": data.get("cropY1"),
"crop_x2": data.get("cropX2"),
"crop_y2": data.get("cropY2")
}
# Extract extra parameters
extra_params = {k: v for k, v in data.items()
if k not in ["subscriptionIdentifier", "modelId", "modelName",
"modelUrl", "rtspUrl", "snapshotUrl", "snapshotInterval",
"cropX1", "cropY1", "cropX2", "cropY2"]}
return cls(**known_fields, extra_params=extra_params)
def from_dict(cls, data: Dict[str, Any]) -> "SubscriptionObject":
"""Create SubscriptionObject from dictionary."""
return cls(
subscription_identifier=data.get("subscriptionIdentifier", ""),
rtsp_url=data.get("rtspUrl", ""),
model_url=data.get("modelUrl", ""),
model_id=data.get("modelId", 0),
model_name=data.get("modelName", ""),
snapshot_url=data.get("snapshotUrl"),
snapshot_interval=data.get("snapshotInterval"),
crop_x1=data.get("cropX1"),
crop_y1=data.get("cropY1"),
crop_x2=data.get("cropX2"),
crop_y2=data.get("cropY2")
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format for stream configuration."""
result = {
"subscriptionIdentifier": self.subscription_identifier,
"rtspUrl": self.rtsp_url,
"modelUrl": self.model_url,
"modelId": self.model_id,
"modelName": self.model_name,
"modelUrl": self.model_url
"modelName": self.model_name
}
if self.rtsp_url:
result["rtspUrl"] = self.rtsp_url
if self.snapshot_url:
result["snapshotUrl"] = self.snapshot_url
if self.snapshot_interval is not None:
@ -107,12 +96,22 @@ class SubscribePayload:
if self.crop_y2 is not None:
result["cropY2"] = self.crop_y2
# Add any extra parameters
result.update(self.extra_params)
return result
@dataclass
class SetSubscriptionListPayload:
"""Payload for setSubscriptionList command per worker.md protocol."""
subscriptions: List[SubscriptionObject]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SetSubscriptionListPayload":
"""Create SetSubscriptionListPayload from dictionary."""
subscriptions_data = data.get("subscriptions", [])
subscriptions = [SubscriptionObject.from_dict(sub) for sub in subscriptions_data]
return cls(subscriptions=subscriptions)
@dataclass
class SessionPayload:
"""Payload for session-related messages."""
@ -142,11 +141,11 @@ class MessageProcessor:
def __init__(self):
"""Initialize the message processor."""
self.validators: Dict[MessageType, Callable] = {
MessageType.SUBSCRIBE: self._validate_subscribe,
MessageType.UNSUBSCRIBE: self._validate_unsubscribe,
MessageType.SET_SUBSCRIPTION_LIST: self._validate_set_subscription_list,
MessageType.SET_SESSION_ID: self._validate_set_session,
MessageType.PATCH_SESSION: self._validate_patch_session,
MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage
MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage,
MessageType.PATCH_SESSION_RESULT: self._validate_patch_session_result
}
def parse_message(self, raw_message: str) -> Tuple[MessageType, Dict[str, Any]]:
@ -222,58 +221,69 @@ class MessageProcessor:
msg_proc_logger.error(f"❌ VALIDATION FAILED FOR {msg_type.value}: {e}")
raise
def _validate_subscribe(self, data: Dict[str, Any]) -> SubscribePayload:
"""Validate subscribe message."""
def _validate_set_subscription_list(self, data: Dict[str, Any]) -> SetSubscriptionListPayload:
"""Validate setSubscriptionList message per worker.md protocol."""
subscriptions_data = data.get("subscriptions", [])
if not isinstance(subscriptions_data, list):
raise ValidationError("subscriptions must be an array")
# Validate each subscription object
for i, sub_data in enumerate(subscriptions_data):
if not isinstance(sub_data, dict):
raise ValidationError(f"Subscription {i} must be an object")
# Required fields per protocol
required = ["subscriptionIdentifier", "rtspUrl", "modelUrl", "modelId", "modelName"]
missing = [field for field in required if not sub_data.get(field)]
if missing:
raise ValidationError(f"Subscription {i} missing required fields: {', '.join(missing)}")
# Validate snapshot interval if snapshot URL is provided
if sub_data.get("snapshotUrl") and not sub_data.get("snapshotInterval"):
raise ValidationError(f"Subscription {i}: snapshotInterval is required when using snapshotUrl")
# Validate crop coordinates
crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"]
crop_values = [sub_data.get(coord) for coord in crop_coords]
if any(v is not None for v in crop_values):
# If any crop coordinate is set, all must be set
if not all(v is not None for v in crop_values):
raise ValidationError(f"Subscription {i}: All crop coordinates must be specified if any are set")
# Validate coordinate values
x1, y1, x2, y2 = crop_values
if x2 <= x1 or y2 <= y1:
raise ValidationError(f"Subscription {i}: Invalid crop coordinates: x2 must be > x1, y2 must be > y1")
return SetSubscriptionListPayload.from_dict(data)
def _validate_patch_session_result(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Validate patchSessionResult message."""
payload = data.get("payload", {})
# Required fields
required = ["subscriptionIdentifier", "modelId", "modelName", "modelUrl"]
missing = [field for field in required if not payload.get(field)]
if missing:
raise ValidationError(f"Missing required fields: {', '.join(missing)}")
# Required fields per protocol
session_id = payload.get("sessionId")
if session_id is None:
raise ValidationError("Missing required field: sessionId")
# Must have either rtspUrl or snapshotUrl
if not payload.get("rtspUrl") and not payload.get("snapshotUrl"):
raise ValidationError("Must provide either rtspUrl or snapshotUrl")
# Validate snapshot interval if snapshot URL is provided
if payload.get("snapshotUrl") and not payload.get("snapshotInterval"):
raise ValidationError("snapshotInterval is required when using snapshotUrl")
# Validate crop coordinates
crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"]
crop_values = [payload.get(coord) for coord in crop_coords]
if any(v is not None for v in crop_values):
# If any crop coordinate is set, all must be set
if not all(v is not None for v in crop_values):
raise ValidationError("All crop coordinates must be specified if any are set")
# Validate coordinate values
x1, y1, x2, y2 = crop_values
if x2 <= x1 or y2 <= y1:
raise ValidationError("Invalid crop coordinates: x2 must be > x1, y2 must be > y1")
return SubscribePayload.from_dict(payload)
def _validate_unsubscribe(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Validate unsubscribe message."""
payload = data.get("payload", {})
if not payload.get("subscriptionIdentifier"):
raise ValidationError("Missing required field: subscriptionIdentifier")
success = payload.get("success")
if success is None:
raise ValidationError("Missing required field: success")
return payload
def _validate_set_session(self, data: Dict[str, Any]) -> SessionPayload:
"""Validate setSessionId message."""
"""Validate setSessionId message per worker.md protocol."""
payload = data.get("payload", {})
display_id = payload.get("displayIdentifier")
session_id = payload.get("sessionId")
session_id = payload.get("sessionId") # Can be null to clear session
if not display_id:
raise ValidationError("Missing required field: displayIdentifier")
if not session_id:
# sessionId can be null to clear the session per protocol
if "sessionId" not in payload:
raise ValidationError("Missing required field: sessionId")
return SessionPayload(display_id, session_id)

View file

@ -83,16 +83,17 @@ class WebSocketHandler:
# Message handlers
self.message_handlers: Dict[str, MessageHandler] = {
"subscribe": self._handle_subscribe,
"unsubscribe": self._handle_unsubscribe,
"setSubscriptionList": self._handle_set_subscription_list,
"requestState": self._handle_request_state,
"setSessionId": self._handle_set_session,
"patchSession": self._handle_patch_session,
"setProgressionStage": self._handle_set_progression_stage
"setProgressionStage": self._handle_set_progression_stage,
"patchSessionResult": self._handle_patch_session_result
}
# Session and display management
self.session_ids: Dict[str, str] = {} # display_identifier -> session_id
self.progression_stages: Dict[str, str] = {} # display_identifier -> progression_stage
self.display_identifiers: Set[str] = set()
# Camera monitor
@ -171,22 +172,40 @@ class WebSocketHandler:
# Get system metrics
metrics = get_system_metrics()
# Get active streams info
active_streams = self.stream_manager.get_active_streams()
active_models = self.model_manager.get_loaded_models()
# Build cameraConnections array as required by protocol
camera_connections = []
with self.stream_manager.streams_lock:
for camera_id, stream_info in self.stream_manager.streams.items():
# Check if camera is online
is_online = self.stream_manager.is_stream_active(camera_id)
connection_info = {
"subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
"modelId": stream_info.get("modelId", 0),
"modelName": stream_info.get("modelName", "Unknown Model"),
"online": is_online
}
# Add crop coordinates if available
if "cropX1" in stream_info:
connection_info["cropX1"] = stream_info["cropX1"]
if "cropY1" in stream_info:
connection_info["cropY1"] = stream_info["cropY1"]
if "cropX2" in stream_info:
connection_info["cropX2"] = stream_info["cropX2"]
if "cropY2" in stream_info:
connection_info["cropY2"] = stream_info["cropY2"]
camera_connections.append(connection_info)
# Protocol-compliant stateReport format (worker.md lines 169-189)
state_data = {
"type": "stateReport",
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"data": {
"activeStreams": len(active_streams),
"loadedModels": len(active_models),
"cpuUsage": metrics.get("cpu_percent", 0),
"memoryUsage": metrics.get("memory_percent", 0),
"gpuUsage": metrics.get("gpu_percent", 0),
"gpuMemory": metrics.get("gpu_memory_percent", 0),
"uptime": time.time() - metrics.get("start_time", time.time())
}
"cpuUsage": metrics.get("cpu_percent", 0),
"memoryUsage": metrics.get("memory_percent", 0),
"gpuUsage": metrics.get("gpu_percent", 0),
"gpuMemoryUsage": metrics.get("gpu_memory_percent", 0), # Fixed field name
"cameraConnections": camera_connections
}
# Compact JSON for RX/TX logging
@ -351,75 +370,105 @@ class WebSocketHandler:
traceback.print_exc()
return persistent_data
async def _handle_subscribe(self, data: Dict[str, Any]) -> None:
"""Handle stream subscription request."""
payload = data.get("payload", {})
subscription_id = payload.get("subscriptionIdentifier")
async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None:
"""
Handle setSubscriptionList command - declarative subscription management.
if not subscription_id:
logger.error("Missing subscriptionIdentifier in subscribe payload")
return
This is the primary subscription command per worker.md protocol.
Workers must reconcile the new subscription list with current state.
"""
subscriptions = data.get("subscriptions", [])
try:
# Extract display and camera IDs
# Get current subscription identifiers
current_subscriptions = set(subscription_to_camera.keys())
# Get desired subscription identifiers
desired_subscriptions = set()
subscription_configs = {}
for sub_config in subscriptions:
sub_id = sub_config.get("subscriptionIdentifier")
if sub_id:
desired_subscriptions.add(sub_id)
subscription_configs[sub_id] = sub_config
# Extract display ID for session management
parts = sub_id.split(";")
if len(parts) >= 2:
display_id = parts[0]
self.display_identifiers.add(display_id)
# Calculate changes needed
to_add = desired_subscriptions - current_subscriptions
to_remove = current_subscriptions - desired_subscriptions
to_update = desired_subscriptions & current_subscriptions
logger.info(f"Subscription reconciliation: add={len(to_add)}, remove={len(to_remove)}, update={len(to_update)}")
# Remove obsolete subscriptions
for sub_id in to_remove:
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
subscription_to_camera.pop(sub_id, None)
self.session_cache.clear_session(camera_id)
logger.info(f"Removed subscription: {sub_id}")
# Add new subscriptions
for sub_id in to_add:
await self._start_subscription(sub_id, subscription_configs[sub_id])
logger.info(f"Added subscription: {sub_id}")
# Update existing subscriptions if needed
for sub_id in to_update:
# Check if configuration changed (model URL, crop coordinates, etc.)
current_config = subscription_to_camera.get(sub_id)
new_config = subscription_configs[sub_id]
# For now, restart subscription if model URL changed (handles S3 expiration)
current_model_url = getattr(current_config, 'model_url', None) if current_config else None
new_model_url = new_config.get("modelUrl")
if current_model_url != new_model_url:
# Restart with new configuration
camera_id = subscription_to_camera.get(sub_id)
if camera_id:
await self.stream_manager.stop_stream(camera_id)
self.model_manager.unload_models(camera_id)
await self._start_subscription(sub_id, new_config)
logger.info(f"Updated subscription: {sub_id}")
logger.info(f"Subscription list reconciliation completed. Active: {len(desired_subscriptions)}")
except Exception as e:
logger.error(f"Error handling setSubscriptionList: {e}")
traceback.print_exc()
async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None:
"""Start a single subscription with given configuration."""
try:
# Extract camera ID from subscription identifier
parts = subscription_id.split(";")
if len(parts) >= 2:
display_id = parts[0]
camera_id = parts[1]
self.display_identifiers.add(display_id)
else:
camera_id = subscription_id
camera_id = parts[1] if len(parts) >= 2 else subscription_id
# Store subscription mapping
subscription_to_camera[subscription_id] = camera_id
# Start camera stream
await self.stream_manager.start_stream(camera_id, payload)
await self.stream_manager.start_stream(camera_id, config)
# Load model
model_id = payload.get("modelId")
model_url = payload.get("modelUrl")
model_id = config.get("modelId")
model_url = config.get("modelUrl")
if model_id and model_url:
await self.model_manager.load_model(camera_id, model_id, model_url)
logger.info(f"Subscribed to stream: {subscription_id}")
except Exception as e:
logger.error(f"Error handling subscription: {e}")
traceback.print_exc()
async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
"""Handle stream unsubscription request."""
payload = data.get("payload", {})
subscription_id = payload.get("subscriptionIdentifier")
if not subscription_id:
logger.error("Missing subscriptionIdentifier in unsubscribe payload")
return
try:
# Get camera ID from subscription
camera_id = subscription_to_camera.get(subscription_id)
if not camera_id:
logger.warning(f"No camera found for subscription: {subscription_id}")
return
# Stop stream
await self.stream_manager.stop_stream(camera_id)
# Unload model
self.model_manager.unload_models(camera_id)
# Clean up mappings
subscription_to_camera.pop(subscription_id, None)
# Clean up session state
self.session_cache.clear_session(camera_id)
logger.info(f"Unsubscribed from stream: {subscription_id}")
except Exception as e:
logger.error(f"Error handling unsubscription: {e}")
logger.error(f"Error starting subscription {subscription_id}: {e}")
raise
traceback.print_exc()
async def _handle_request_state(self, data: Dict[str, Any]) -> None:
@ -535,6 +584,26 @@ class WebSocketHandler:
pipeline_state["progression_stage"] = progression_stage
logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}")
# Store progression stage for this display
if display_id and progression_stage is not None:
if progression_stage:
self.progression_stages[display_id] = progression_stage
else:
# Clear progression stage if null
self.progression_stages.pop(display_id, None)
async def _handle_patch_session_result(self, data: Dict[str, Any]) -> None:
"""Handle patchSessionResult message from backend."""
payload = data.get("payload", {})
session_id = payload.get("sessionId")
success = payload.get("success", False)
message = payload.get("message", "")
if success:
logger.info(f"Patch session {session_id} successful: {message}")
else:
logger.warning(f"Patch session {session_id} failed: {message}")
async def _send_detection_result(
self,
camera_id: str,
@ -542,10 +611,16 @@ class WebSocketHandler:
detection_result: DetectionResult
) -> None:
"""Send detection result over WebSocket."""
# Get session ID for this display
subscription_id = stream_info["subscriptionIdentifier"]
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
session_id = self.session_ids.get(display_id)
detection_data = {
"type": "imageDetection",
"subscriptionIdentifier": stream_info["subscriptionIdentifier"],
"subscriptionIdentifier": subscription_id,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"sessionId": session_id, # Required by protocol
"data": {
"detection": detection_result.to_dict(),
"modelId": stream_info["modelId"],