Fix: update message type to current implementation
This commit is contained in:
parent
b940790e4a
commit
96ecc321ec
3 changed files with 549 additions and 230 deletions
384
ARCHITECTURE.md
384
ARCHITECTURE.md
|
@ -141,18 +141,44 @@ WebSocketConnection() # Per-client connection wrapper
|
|||
#### `message_processor.py` - Message Processing Pipeline
|
||||
```python
|
||||
MessageProcessor()
|
||||
├── process_message() # Main message dispatcher
|
||||
├── _handle_subscribe() # Process subscription requests
|
||||
├── _handle_unsubscribe() # Process unsubscription
|
||||
├── _handle_state_request() # System state requests
|
||||
└── _handle_session_ops() # Session management operations
|
||||
├── parse_message() # Message parsing and validation
|
||||
├── _validate_set_subscription_list() # Validate declarative subscriptions
|
||||
├── _validate_set_session() # Session management validation
|
||||
├── _validate_patch_session() # Patch session validation
|
||||
├── _validate_progression_stage() # Progression stage validation
|
||||
└── _validate_patch_session_result() # Backend response validation
|
||||
|
||||
MessageType(Enum) # Supported message types
|
||||
├── SUBSCRIBE
|
||||
├── UNSUBSCRIBE
|
||||
├── REQUEST_STATE
|
||||
├── SET_SESSION_ID
|
||||
└── PATCH_SESSION
|
||||
MessageType(Enum) # Protocol-compliant message types per worker.md
|
||||
├── SET_SUBSCRIPTION_LIST # Primary declarative subscription command
|
||||
├── REQUEST_STATE # System state requests
|
||||
├── SET_SESSION_ID # Session association (supports null clearing)
|
||||
├── PATCH_SESSION # Session modification requests
|
||||
├── SET_PROGRESSION_STAGE # Real-time progression updates
|
||||
├── PATCH_SESSION_RESULT # Backend responses to patch requests
|
||||
├── IMAGE_DETECTION # Detection results (worker->backend)
|
||||
└── STATE_REPORT # Heartbeat messages (worker->backend)
|
||||
|
||||
# Protocol-Compliant Payload Classes
|
||||
SubscriptionObject() # Individual subscription per worker.md specification
|
||||
├── subscription_identifier # Format: "displayId;cameraId"
|
||||
├── rtsp_url # Required RTSP stream URL
|
||||
├── model_url # Required fresh model URL (1-hour TTL)
|
||||
├── model_id, model_name # Model identification
|
||||
├── snapshot_url # Optional HTTP snapshot URL
|
||||
├── snapshot_interval # Required if snapshot_url provided
|
||||
└── crop_x1/y1/x2/y2 # Optional crop coordinates
|
||||
|
||||
SetSubscriptionListPayload() # Declarative subscription command payload
|
||||
└── subscriptions: List[SubscriptionObject] # Complete desired state
|
||||
|
||||
SessionPayload() # Session management payload
|
||||
├── display_identifier # Target display ID
|
||||
├── session_id # Session ID (can be null for clearing)
|
||||
└── data # Optional session patch data
|
||||
|
||||
ProgressionPayload() # Progression stage payload
|
||||
├── display_identifier # Target display ID
|
||||
└── progression_stage # Stage: welcome|car_fueling|car_waitpayment|null
|
||||
```
|
||||
|
||||
### Stream Management (`detector_worker/streams/`)
|
||||
|
@ -452,105 +478,313 @@ sequenceDiagram
|
|||
return stream_manager.get_latest_frame(camera_id)
|
||||
```
|
||||
|
||||
## Protocol Compliance (worker.md Implementation)
|
||||
|
||||
### Key Protocol Features Implemented
|
||||
|
||||
The WebSocket communication layer has been **fully updated** to comply with the worker.md protocol specification, replacing deprecated patterns with modern declarative subscription management.
|
||||
|
||||
#### ✅ **Protocol-Compliant Message Types**
|
||||
|
||||
| Message Type | Direction | Purpose | Status |
|
||||
|--------------|-----------|---------|---------|
|
||||
| `setSubscriptionList` | Backend→Worker | **Primary subscription command** - declarative management | ✅ **Implemented** |
|
||||
| `setSessionId` | Backend→Worker | Associate session with display (supports null clearing) | ✅ **Implemented** |
|
||||
| `setProgressionStage` | Backend→Worker | Real-time progression updates for context-aware processing | ✅ **Implemented** |
|
||||
| `requestState` | Backend→Worker | Request immediate state report | ✅ **Implemented** |
|
||||
| `patchSessionResult` | Backend→Worker | Response to worker's patch session request | ✅ **Implemented** |
|
||||
| `stateReport` | Worker→Backend | **Heartbeat with performance metrics** (every 2 seconds) | ✅ **Implemented** |
|
||||
| `imageDetection` | Worker→Backend | Real-time detection results with session context | ✅ **Implemented** |
|
||||
| `patchSession` | Worker→Backend | Request modification to session data | ✅ **Implemented** |
|
||||
|
||||
#### ❌ **Deprecated Message Types Removed**
|
||||
|
||||
- `subscribe` - Replaced by declarative `setSubscriptionList`
|
||||
- `unsubscribe` - Handled by empty `setSubscriptionList` array
|
||||
|
||||
#### 🔧 **Key Protocol Implementations**
|
||||
|
||||
##### 1. **Declarative Subscription Management**
|
||||
```python
|
||||
# setSubscriptionList provides complete desired state
|
||||
{
|
||||
"type": "setSubscriptionList",
|
||||
"subscriptions": [
|
||||
{
|
||||
"subscriptionIdentifier": "display-001;cam-001",
|
||||
"rtspUrl": "rtsp://192.168.1.100/stream1",
|
||||
"modelUrl": "http://storage/models/vehicle-id.mpta?token=fresh-token",
|
||||
"modelId": 201,
|
||||
"modelName": "Vehicle Identification",
|
||||
"cropX1": 100, "cropY1": 200, "cropX2": 300, "cropY2": 400
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Worker Reconciliation Logic:**
|
||||
- **Add new subscriptions** not in current state
|
||||
- **Remove obsolete subscriptions** not in desired state
|
||||
- **Update existing subscriptions** with fresh model URLs (handles S3 expiration)
|
||||
- **Single stream optimization** - share RTSP streams across multiple subscriptions
|
||||
|
||||
##### 2. **Protocol-Compliant State Reports**
|
||||
```python
|
||||
# stateReport with flat structure per worker.md specification
|
||||
{
|
||||
"type": "stateReport",
|
||||
"cpuUsage": 75.5,
|
||||
"memoryUsage": 40.2,
|
||||
"gpuUsage": 60.0,
|
||||
"gpuMemoryUsage": 25.1,
|
||||
"cameraConnections": [
|
||||
{
|
||||
"subscriptionIdentifier": "display-001;cam-001",
|
||||
"modelId": 101,
|
||||
"modelName": "General Object Detection",
|
||||
"online": true,
|
||||
"cropX1": 100, "cropY1": 200, "cropX2": 300, "cropY2": 400
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
##### 3. **Session Context in Detection Results**
|
||||
```python
|
||||
# imageDetection with sessionId for proper session linking
|
||||
{
|
||||
"type": "imageDetection",
|
||||
"subscriptionIdentifier": "display-001;cam-001",
|
||||
"timestamp": "2025-09-12T10:00:00.000Z",
|
||||
"sessionId": 12345, # Critical for CMS session tracking
|
||||
"data": {
|
||||
"detection": {
|
||||
"carBrand": "Honda",
|
||||
"carModel": "CR-V",
|
||||
"licensePlateText": "ABC-123"
|
||||
},
|
||||
"modelId": 201,
|
||||
"modelName": "Vehicle Identification"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 🚀 **Enhanced Features**
|
||||
|
||||
1. **State Recovery**: Complete subscription restoration on worker reconnection
|
||||
2. **Fresh Model URLs**: Automatic S3 URL refresh via subscription updates
|
||||
3. **Multi-Display Support**: Proper display identifier parsing and session management
|
||||
4. **Load Balancing Ready**: Backend can distribute subscriptions across workers
|
||||
5. **Session Continuity**: Session IDs maintained across worker disconnections
|
||||
|
||||
## WebSocket Communication Flow
|
||||
|
||||
### Client Connection Lifecycle
|
||||
### Client Connection Lifecycle (Protocol-Compliant)
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Client as WebSocket Client
|
||||
participant Backend as CMS Backend
|
||||
participant WS as WebSocketHandler
|
||||
participant CM as ConnectionManager
|
||||
participant MP as MessageProcessor
|
||||
participant SM as StreamManager
|
||||
participant HB as Heartbeat Loop
|
||||
|
||||
Client->>WS: WebSocket Connection
|
||||
WS->>WS: handle_websocket()
|
||||
WS->>CM: add_connection()
|
||||
CM->>CM: create WebSocketConnection
|
||||
WS->>Client: Connection Accepted
|
||||
Backend->>WS: WebSocket Connection
|
||||
WS->>WS: handle_connection()
|
||||
WS->>HB: start_heartbeat_loop()
|
||||
WS->>Backend: Connection Accepted
|
||||
|
||||
loop Message Processing
|
||||
Client->>WS: JSON Message
|
||||
WS->>MP: process_message()
|
||||
loop Heartbeat (every 2 seconds)
|
||||
HB->>Backend: stateReport {cpuUsage, memoryUsage, gpuUsage, cameraConnections[]}
|
||||
end
|
||||
|
||||
alt Subscribe Message
|
||||
MP->>SM: create_stream()
|
||||
SM->>SM: initialize StreamReader
|
||||
MP->>Client: subscribeAck
|
||||
else Unsubscribe Message
|
||||
MP->>SM: remove_stream()
|
||||
SM->>SM: cleanup StreamReader
|
||||
MP->>Client: unsubscribeAck
|
||||
else State Request
|
||||
MP->>MP: collect_system_state()
|
||||
MP->>Client: stateReport
|
||||
loop Message Processing (Protocol Commands)
|
||||
Backend->>WS: JSON Message
|
||||
WS->>MP: parse_message()
|
||||
|
||||
alt setSubscriptionList (Declarative)
|
||||
MP->>MP: validate_subscription_list()
|
||||
WS->>WS: reconcile_subscriptions()
|
||||
WS->>SM: add/remove/update_streams()
|
||||
SM->>SM: handle_stream_lifecycle()
|
||||
WS->>HB: update_camera_connections()
|
||||
|
||||
else setSessionId
|
||||
MP->>MP: validate_set_session()
|
||||
WS->>WS: store_session_for_display()
|
||||
WS->>WS: apply_to_all_display_subscriptions()
|
||||
|
||||
else setProgressionStage
|
||||
MP->>MP: validate_progression_stage()
|
||||
WS->>WS: store_stage_for_display()
|
||||
WS->>WS: apply_context_aware_processing()
|
||||
|
||||
else requestState
|
||||
WS->>Backend: stateReport (immediate response)
|
||||
|
||||
else patchSessionResult
|
||||
MP->>MP: validate_patch_result()
|
||||
WS->>WS: log_patch_response()
|
||||
end
|
||||
end
|
||||
|
||||
Client->>WS: Disconnect
|
||||
WS->>CM: remove_connection()
|
||||
WS->>SM: cleanup_client_streams()
|
||||
Backend->>WS: Disconnect
|
||||
WS->>SM: cleanup_all_streams()
|
||||
WS->>HB: stop_heartbeat_loop()
|
||||
```
|
||||
|
||||
### Message Processing Detail
|
||||
|
||||
#### 1. Subscribe Message Flow (`message_processor.py:125-185`)
|
||||
#### 1. setSubscriptionList Flow (`websocket_handler.py:355-453`) - Protocol Compliant
|
||||
|
||||
```python
|
||||
async def _handle_subscribe(self, payload: Dict, client_id: str) -> Dict:
|
||||
"""Process subscription request"""
|
||||
async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle setSubscriptionList command - declarative subscription management"""
|
||||
|
||||
# 1. Extract subscription parameters
|
||||
subscription_id = payload["subscriptionIdentifier"]
|
||||
stream_url = payload.get("rtspUrl") or payload.get("snapshotUrl")
|
||||
model_url = payload["modelUrl"]
|
||||
subscriptions = data.get("subscriptions", [])
|
||||
|
||||
# 2. Create stream configuration
|
||||
stream_config = StreamConfig(
|
||||
stream_url=stream_url,
|
||||
stream_type="rtsp" if "rtsp" in stream_url else "http_snapshot",
|
||||
crop_region=[payload.get("cropX1"), payload.get("cropY1"),
|
||||
payload.get("cropX2"), payload.get("cropY2")]
|
||||
)
|
||||
# 1. Get current and desired subscription states
|
||||
current_subscriptions = set(subscription_to_camera.keys())
|
||||
desired_subscriptions = set()
|
||||
subscription_configs = {}
|
||||
|
||||
# 3. Load ML pipeline
|
||||
pipeline_config = await pipeline_loader.load_from_url(model_url)
|
||||
for sub_config in subscriptions:
|
||||
sub_id = sub_config.get("subscriptionIdentifier")
|
||||
if sub_id:
|
||||
desired_subscriptions.add(sub_id)
|
||||
subscription_configs[sub_id] = sub_config
|
||||
|
||||
# 4. Create stream (with sharing if same URL)
|
||||
stream_info = await stream_manager.create_stream(
|
||||
camera_id=subscription_id.split(';')[1],
|
||||
config=stream_config,
|
||||
subscription_id=subscription_id
|
||||
)
|
||||
# Extract display ID for session management
|
||||
parts = sub_id.split(";")
|
||||
if len(parts) >= 2:
|
||||
display_id = parts[0]
|
||||
self.display_identifiers.add(display_id)
|
||||
|
||||
# 5. Register client subscription
|
||||
connection_manager.add_subscription(client_id, subscription_id)
|
||||
# 2. Calculate reconciliation changes
|
||||
to_add = desired_subscriptions - current_subscriptions
|
||||
to_remove = current_subscriptions - desired_subscriptions
|
||||
to_update = desired_subscriptions & current_subscriptions
|
||||
|
||||
return {"type": "subscribeAck", "status": "success",
|
||||
"subscriptionId": subscription_id}
|
||||
# 3. Remove obsolete subscriptions
|
||||
for sub_id in to_remove:
|
||||
camera_id = subscription_to_camera.get(sub_id)
|
||||
if camera_id:
|
||||
await self.stream_manager.stop_stream(camera_id)
|
||||
self.model_manager.unload_models(camera_id)
|
||||
subscription_to_camera.pop(sub_id, None)
|
||||
self.session_cache.clear_session(camera_id)
|
||||
|
||||
# 4. Add new subscriptions
|
||||
for sub_id in to_add:
|
||||
await self._start_subscription(sub_id, subscription_configs[sub_id])
|
||||
|
||||
# 5. Update existing subscriptions (handle S3 URL refresh)
|
||||
for sub_id in to_update:
|
||||
current_config = subscription_to_camera.get(sub_id)
|
||||
new_config = subscription_configs[sub_id]
|
||||
|
||||
# Restart if model URL changed (handles S3 expiration)
|
||||
current_model_url = getattr(current_config, 'model_url', None)
|
||||
new_model_url = new_config.get("modelUrl")
|
||||
|
||||
if current_model_url != new_model_url:
|
||||
camera_id = subscription_to_camera.get(sub_id)
|
||||
if camera_id:
|
||||
await self.stream_manager.stop_stream(camera_id)
|
||||
self.model_manager.unload_models(camera_id)
|
||||
await self._start_subscription(sub_id, new_config)
|
||||
|
||||
async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None:
|
||||
"""Start individual subscription with protocol-compliant configuration"""
|
||||
|
||||
# Extract camera ID from subscription identifier
|
||||
parts = subscription_id.split(";")
|
||||
camera_id = parts[1] if len(parts) >= 2 else subscription_id
|
||||
|
||||
# Store subscription mapping
|
||||
subscription_to_camera[subscription_id] = camera_id
|
||||
|
||||
# Start camera stream with full config
|
||||
await self.stream_manager.start_stream(camera_id, config)
|
||||
|
||||
# Load ML model from fresh URL
|
||||
model_id = config.get("modelId")
|
||||
model_url = config.get("modelUrl")
|
||||
if model_id and model_url:
|
||||
await self.model_manager.load_model(camera_id, model_id, model_url)
|
||||
```
|
||||
|
||||
#### 2. Detection Result Broadcasting (`websocket_handler.py:245-265`)
|
||||
#### 2. Detection Result Broadcasting (`websocket_handler.py:589-615`) - Protocol Compliant
|
||||
|
||||
```python
|
||||
async def broadcast_detection_result(self, subscription_id: str,
|
||||
detection_result: Dict):
|
||||
"""Broadcast detection to subscribed clients"""
|
||||
async def _send_detection_result(self, camera_id: str, stream_info: Dict[str, Any],
|
||||
detection_result: DetectionResult) -> None:
|
||||
"""Send detection result with protocol-compliant format"""
|
||||
|
||||
message = {
|
||||
# Get session ID for this display (protocol requirement)
|
||||
subscription_id = stream_info["subscriptionIdentifier"]
|
||||
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
|
||||
session_id = self.session_ids.get(display_id) # Can be None for null sessions
|
||||
|
||||
# Protocol-compliant imageDetection format per worker.md
|
||||
detection_data = {
|
||||
"type": "imageDetection",
|
||||
"payload": {
|
||||
"subscriptionId": subscription_id,
|
||||
"detections": detection_result["detections"],
|
||||
"timestamp": detection_result["timestamp"],
|
||||
"modelInfo": detection_result["model_info"]
|
||||
"subscriptionIdentifier": subscription_id, # Required at root level
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
"sessionId": session_id, # Required for session linking
|
||||
"data": {
|
||||
"detection": detection_result.to_dict(), # Flat detection object
|
||||
"modelId": stream_info["modelId"],
|
||||
"modelName": stream_info["modelName"]
|
||||
}
|
||||
}
|
||||
|
||||
await self.connection_manager.broadcast_to_subscription(
|
||||
subscription_id, message
|
||||
)
|
||||
# Send to backend via WebSocket
|
||||
await self.websocket.send_json(detection_data)
|
||||
```
|
||||
|
||||
#### 3. State Report Broadcasting (`websocket_handler.py:201-209`) - Protocol Compliant
|
||||
|
||||
```python
|
||||
async def _send_heartbeat(self) -> None:
|
||||
"""Send protocol-compliant stateReport every 2 seconds"""
|
||||
|
||||
while self.connected:
|
||||
# Get system metrics
|
||||
metrics = get_system_metrics()
|
||||
|
||||
# Build cameraConnections array as required by protocol
|
||||
camera_connections = []
|
||||
with self.stream_manager.streams_lock:
|
||||
for camera_id, stream_info in self.stream_manager.streams.items():
|
||||
is_online = self.stream_manager.is_stream_active(camera_id)
|
||||
|
||||
connection_info = {
|
||||
"subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
|
||||
"modelId": stream_info.get("modelId", 0),
|
||||
"modelName": stream_info.get("modelName", "Unknown Model"),
|
||||
"online": is_online
|
||||
}
|
||||
|
||||
# Add crop coordinates if available (optional per protocol)
|
||||
for coord in ["cropX1", "cropY1", "cropX2", "cropY2"]:
|
||||
if coord in stream_info:
|
||||
connection_info[coord] = stream_info[coord]
|
||||
|
||||
camera_connections.append(connection_info)
|
||||
|
||||
# Protocol-compliant stateReport format (worker.md lines 169-189)
|
||||
state_data = {
|
||||
"type": "stateReport", # Message type
|
||||
"cpuUsage": metrics.get("cpu_percent", 0), # CPU percentage
|
||||
"memoryUsage": metrics.get("memory_percent", 0), # Memory percentage
|
||||
"gpuUsage": metrics.get("gpu_percent", 0), # GPU percentage
|
||||
"gpuMemoryUsage": metrics.get("gpu_memory_percent", 0), # GPU memory
|
||||
"cameraConnections": camera_connections # Camera details array
|
||||
}
|
||||
|
||||
await self.websocket.send_json(state_data)
|
||||
await asyncio.sleep(2) # 2-second heartbeat interval per protocol
|
||||
```
|
||||
|
||||
## Detection Pipeline Flow
|
||||
|
|
|
@ -7,7 +7,7 @@ It provides a clean separation between message handling and business logic.
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any, Optional, Callable, Tuple
|
||||
from typing import Dict, Any, Optional, Callable, Tuple, List
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
@ -19,13 +19,13 @@ msg_proc_logger = logging.getLogger("websocket.message_processor") # Detailed m
|
|||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Enumeration of supported WebSocket message types."""
|
||||
SUBSCRIBE = "subscribe"
|
||||
UNSUBSCRIBE = "unsubscribe"
|
||||
"""Enumeration of supported WebSocket message types per worker.md protocol."""
|
||||
SET_SUBSCRIPTION_LIST = "setSubscriptionList"
|
||||
REQUEST_STATE = "requestState"
|
||||
SET_SESSION_ID = "setSessionId"
|
||||
PATCH_SESSION = "patchSession"
|
||||
SET_PROGRESSION_STAGE = "setProgressionStage"
|
||||
PATCH_SESSION_RESULT = "patchSessionResult"
|
||||
IMAGE_DETECTION = "imageDetection"
|
||||
STATE_REPORT = "stateReport"
|
||||
ACK = "ack"
|
||||
|
@ -41,58 +41,47 @@ class ProgressionStage(Enum):
|
|||
|
||||
|
||||
@dataclass
|
||||
class SubscribePayload:
|
||||
"""Payload for subscription messages."""
|
||||
class SubscriptionObject:
|
||||
"""Subscription object per worker.md protocol specification."""
|
||||
subscription_identifier: str
|
||||
rtsp_url: str
|
||||
model_url: str
|
||||
model_id: int
|
||||
model_name: str
|
||||
model_url: str
|
||||
rtsp_url: Optional[str] = None
|
||||
snapshot_url: Optional[str] = None
|
||||
snapshot_interval: Optional[int] = None
|
||||
crop_x1: Optional[int] = None
|
||||
crop_y1: Optional[int] = None
|
||||
crop_x2: Optional[int] = None
|
||||
crop_y2: Optional[int] = None
|
||||
extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SubscribePayload":
|
||||
"""Create SubscribePayload from dictionary."""
|
||||
# Extract known fields
|
||||
known_fields = {
|
||||
"subscription_identifier": data.get("subscriptionIdentifier"),
|
||||
"model_id": data.get("modelId"),
|
||||
"model_name": data.get("modelName"),
|
||||
"model_url": data.get("modelUrl"),
|
||||
"rtsp_url": data.get("rtspUrl"),
|
||||
"snapshot_url": data.get("snapshotUrl"),
|
||||
"snapshot_interval": data.get("snapshotInterval"),
|
||||
"crop_x1": data.get("cropX1"),
|
||||
"crop_y1": data.get("cropY1"),
|
||||
"crop_x2": data.get("cropX2"),
|
||||
"crop_y2": data.get("cropY2")
|
||||
}
|
||||
|
||||
# Extract extra parameters
|
||||
extra_params = {k: v for k, v in data.items()
|
||||
if k not in ["subscriptionIdentifier", "modelId", "modelName",
|
||||
"modelUrl", "rtspUrl", "snapshotUrl", "snapshotInterval",
|
||||
"cropX1", "cropY1", "cropX2", "cropY2"]}
|
||||
|
||||
return cls(**known_fields, extra_params=extra_params)
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SubscriptionObject":
|
||||
"""Create SubscriptionObject from dictionary."""
|
||||
return cls(
|
||||
subscription_identifier=data.get("subscriptionIdentifier", ""),
|
||||
rtsp_url=data.get("rtspUrl", ""),
|
||||
model_url=data.get("modelUrl", ""),
|
||||
model_id=data.get("modelId", 0),
|
||||
model_name=data.get("modelName", ""),
|
||||
snapshot_url=data.get("snapshotUrl"),
|
||||
snapshot_interval=data.get("snapshotInterval"),
|
||||
crop_x1=data.get("cropX1"),
|
||||
crop_y1=data.get("cropY1"),
|
||||
crop_x2=data.get("cropX2"),
|
||||
crop_y2=data.get("cropY2")
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format for stream configuration."""
|
||||
result = {
|
||||
"subscriptionIdentifier": self.subscription_identifier,
|
||||
"rtspUrl": self.rtsp_url,
|
||||
"modelUrl": self.model_url,
|
||||
"modelId": self.model_id,
|
||||
"modelName": self.model_name,
|
||||
"modelUrl": self.model_url
|
||||
"modelName": self.model_name
|
||||
}
|
||||
|
||||
if self.rtsp_url:
|
||||
result["rtspUrl"] = self.rtsp_url
|
||||
if self.snapshot_url:
|
||||
result["snapshotUrl"] = self.snapshot_url
|
||||
if self.snapshot_interval is not None:
|
||||
|
@ -107,12 +96,22 @@ class SubscribePayload:
|
|||
if self.crop_y2 is not None:
|
||||
result["cropY2"] = self.crop_y2
|
||||
|
||||
# Add any extra parameters
|
||||
result.update(self.extra_params)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetSubscriptionListPayload:
|
||||
"""Payload for setSubscriptionList command per worker.md protocol."""
|
||||
subscriptions: List[SubscriptionObject]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SetSubscriptionListPayload":
|
||||
"""Create SetSubscriptionListPayload from dictionary."""
|
||||
subscriptions_data = data.get("subscriptions", [])
|
||||
subscriptions = [SubscriptionObject.from_dict(sub) for sub in subscriptions_data]
|
||||
return cls(subscriptions=subscriptions)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionPayload:
|
||||
"""Payload for session-related messages."""
|
||||
|
@ -142,11 +141,11 @@ class MessageProcessor:
|
|||
def __init__(self):
|
||||
"""Initialize the message processor."""
|
||||
self.validators: Dict[MessageType, Callable] = {
|
||||
MessageType.SUBSCRIBE: self._validate_subscribe,
|
||||
MessageType.UNSUBSCRIBE: self._validate_unsubscribe,
|
||||
MessageType.SET_SUBSCRIPTION_LIST: self._validate_set_subscription_list,
|
||||
MessageType.SET_SESSION_ID: self._validate_set_session,
|
||||
MessageType.PATCH_SESSION: self._validate_patch_session,
|
||||
MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage
|
||||
MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage,
|
||||
MessageType.PATCH_SESSION_RESULT: self._validate_patch_session_result
|
||||
}
|
||||
|
||||
def parse_message(self, raw_message: str) -> Tuple[MessageType, Dict[str, Any]]:
|
||||
|
@ -222,58 +221,69 @@ class MessageProcessor:
|
|||
msg_proc_logger.error(f"❌ VALIDATION FAILED FOR {msg_type.value}: {e}")
|
||||
raise
|
||||
|
||||
def _validate_subscribe(self, data: Dict[str, Any]) -> SubscribePayload:
|
||||
"""Validate subscribe message."""
|
||||
payload = data.get("payload", {})
|
||||
def _validate_set_subscription_list(self, data: Dict[str, Any]) -> SetSubscriptionListPayload:
|
||||
"""Validate setSubscriptionList message per worker.md protocol."""
|
||||
subscriptions_data = data.get("subscriptions", [])
|
||||
|
||||
# Required fields
|
||||
required = ["subscriptionIdentifier", "modelId", "modelName", "modelUrl"]
|
||||
missing = [field for field in required if not payload.get(field)]
|
||||
if not isinstance(subscriptions_data, list):
|
||||
raise ValidationError("subscriptions must be an array")
|
||||
|
||||
# Validate each subscription object
|
||||
for i, sub_data in enumerate(subscriptions_data):
|
||||
if not isinstance(sub_data, dict):
|
||||
raise ValidationError(f"Subscription {i} must be an object")
|
||||
|
||||
# Required fields per protocol
|
||||
required = ["subscriptionIdentifier", "rtspUrl", "modelUrl", "modelId", "modelName"]
|
||||
missing = [field for field in required if not sub_data.get(field)]
|
||||
if missing:
|
||||
raise ValidationError(f"Missing required fields: {', '.join(missing)}")
|
||||
|
||||
# Must have either rtspUrl or snapshotUrl
|
||||
if not payload.get("rtspUrl") and not payload.get("snapshotUrl"):
|
||||
raise ValidationError("Must provide either rtspUrl or snapshotUrl")
|
||||
raise ValidationError(f"Subscription {i} missing required fields: {', '.join(missing)}")
|
||||
|
||||
# Validate snapshot interval if snapshot URL is provided
|
||||
if payload.get("snapshotUrl") and not payload.get("snapshotInterval"):
|
||||
raise ValidationError("snapshotInterval is required when using snapshotUrl")
|
||||
if sub_data.get("snapshotUrl") and not sub_data.get("snapshotInterval"):
|
||||
raise ValidationError(f"Subscription {i}: snapshotInterval is required when using snapshotUrl")
|
||||
|
||||
# Validate crop coordinates
|
||||
crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"]
|
||||
crop_values = [payload.get(coord) for coord in crop_coords]
|
||||
crop_values = [sub_data.get(coord) for coord in crop_coords]
|
||||
if any(v is not None for v in crop_values):
|
||||
# If any crop coordinate is set, all must be set
|
||||
if not all(v is not None for v in crop_values):
|
||||
raise ValidationError("All crop coordinates must be specified if any are set")
|
||||
raise ValidationError(f"Subscription {i}: All crop coordinates must be specified if any are set")
|
||||
|
||||
# Validate coordinate values
|
||||
x1, y1, x2, y2 = crop_values
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
raise ValidationError("Invalid crop coordinates: x2 must be > x1, y2 must be > y1")
|
||||
raise ValidationError(f"Subscription {i}: Invalid crop coordinates: x2 must be > x1, y2 must be > y1")
|
||||
|
||||
return SubscribePayload.from_dict(payload)
|
||||
return SetSubscriptionListPayload.from_dict(data)
|
||||
|
||||
def _validate_unsubscribe(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate unsubscribe message."""
|
||||
def _validate_patch_session_result(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate patchSessionResult message."""
|
||||
payload = data.get("payload", {})
|
||||
|
||||
if not payload.get("subscriptionIdentifier"):
|
||||
raise ValidationError("Missing required field: subscriptionIdentifier")
|
||||
# Required fields per protocol
|
||||
session_id = payload.get("sessionId")
|
||||
if session_id is None:
|
||||
raise ValidationError("Missing required field: sessionId")
|
||||
|
||||
success = payload.get("success")
|
||||
if success is None:
|
||||
raise ValidationError("Missing required field: success")
|
||||
|
||||
return payload
|
||||
|
||||
def _validate_set_session(self, data: Dict[str, Any]) -> SessionPayload:
|
||||
"""Validate setSessionId message."""
|
||||
"""Validate setSessionId message per worker.md protocol."""
|
||||
payload = data.get("payload", {})
|
||||
|
||||
display_id = payload.get("displayIdentifier")
|
||||
session_id = payload.get("sessionId")
|
||||
session_id = payload.get("sessionId") # Can be null to clear session
|
||||
|
||||
if not display_id:
|
||||
raise ValidationError("Missing required field: displayIdentifier")
|
||||
if not session_id:
|
||||
# sessionId can be null to clear the session per protocol
|
||||
if "sessionId" not in payload:
|
||||
raise ValidationError("Missing required field: sessionId")
|
||||
|
||||
return SessionPayload(display_id, session_id)
|
||||
|
|
|
@ -83,16 +83,17 @@ class WebSocketHandler:
|
|||
|
||||
# Message handlers
|
||||
self.message_handlers: Dict[str, MessageHandler] = {
|
||||
"subscribe": self._handle_subscribe,
|
||||
"unsubscribe": self._handle_unsubscribe,
|
||||
"setSubscriptionList": self._handle_set_subscription_list,
|
||||
"requestState": self._handle_request_state,
|
||||
"setSessionId": self._handle_set_session,
|
||||
"patchSession": self._handle_patch_session,
|
||||
"setProgressionStage": self._handle_set_progression_stage
|
||||
"setProgressionStage": self._handle_set_progression_stage,
|
||||
"patchSessionResult": self._handle_patch_session_result
|
||||
}
|
||||
|
||||
# Session and display management
|
||||
self.session_ids: Dict[str, str] = {} # display_identifier -> session_id
|
||||
self.progression_stages: Dict[str, str] = {} # display_identifier -> progression_stage
|
||||
self.display_identifiers: Set[str] = set()
|
||||
|
||||
# Camera monitor
|
||||
|
@ -171,22 +172,40 @@ class WebSocketHandler:
|
|||
# Get system metrics
|
||||
metrics = get_system_metrics()
|
||||
|
||||
# Get active streams info
|
||||
active_streams = self.stream_manager.get_active_streams()
|
||||
active_models = self.model_manager.get_loaded_models()
|
||||
# Build cameraConnections array as required by protocol
|
||||
camera_connections = []
|
||||
with self.stream_manager.streams_lock:
|
||||
for camera_id, stream_info in self.stream_manager.streams.items():
|
||||
# Check if camera is online
|
||||
is_online = self.stream_manager.is_stream_active(camera_id)
|
||||
|
||||
connection_info = {
|
||||
"subscriptionIdentifier": stream_info.get("subscriptionIdentifier", camera_id),
|
||||
"modelId": stream_info.get("modelId", 0),
|
||||
"modelName": stream_info.get("modelName", "Unknown Model"),
|
||||
"online": is_online
|
||||
}
|
||||
|
||||
# Add crop coordinates if available
|
||||
if "cropX1" in stream_info:
|
||||
connection_info["cropX1"] = stream_info["cropX1"]
|
||||
if "cropY1" in stream_info:
|
||||
connection_info["cropY1"] = stream_info["cropY1"]
|
||||
if "cropX2" in stream_info:
|
||||
connection_info["cropX2"] = stream_info["cropX2"]
|
||||
if "cropY2" in stream_info:
|
||||
connection_info["cropY2"] = stream_info["cropY2"]
|
||||
|
||||
camera_connections.append(connection_info)
|
||||
|
||||
# Protocol-compliant stateReport format (worker.md lines 169-189)
|
||||
state_data = {
|
||||
"type": "stateReport",
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
"data": {
|
||||
"activeStreams": len(active_streams),
|
||||
"loadedModels": len(active_models),
|
||||
"cpuUsage": metrics.get("cpu_percent", 0),
|
||||
"memoryUsage": metrics.get("memory_percent", 0),
|
||||
"gpuUsage": metrics.get("gpu_percent", 0),
|
||||
"gpuMemory": metrics.get("gpu_memory_percent", 0),
|
||||
"uptime": time.time() - metrics.get("start_time", time.time())
|
||||
}
|
||||
"gpuMemoryUsage": metrics.get("gpu_memory_percent", 0), # Fixed field name
|
||||
"cameraConnections": camera_connections
|
||||
}
|
||||
|
||||
# Compact JSON for RX/TX logging
|
||||
|
@ -351,75 +370,105 @@ class WebSocketHandler:
|
|||
traceback.print_exc()
|
||||
return persistent_data
|
||||
|
||||
async def _handle_subscribe(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle stream subscription request."""
|
||||
payload = data.get("payload", {})
|
||||
subscription_id = payload.get("subscriptionIdentifier")
|
||||
async def _handle_set_subscription_list(self, data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Handle setSubscriptionList command - declarative subscription management.
|
||||
|
||||
if not subscription_id:
|
||||
logger.error("Missing subscriptionIdentifier in subscribe payload")
|
||||
return
|
||||
This is the primary subscription command per worker.md protocol.
|
||||
Workers must reconcile the new subscription list with current state.
|
||||
"""
|
||||
subscriptions = data.get("subscriptions", [])
|
||||
|
||||
try:
|
||||
# Extract display and camera IDs
|
||||
parts = subscription_id.split(";")
|
||||
# Get current subscription identifiers
|
||||
current_subscriptions = set(subscription_to_camera.keys())
|
||||
|
||||
# Get desired subscription identifiers
|
||||
desired_subscriptions = set()
|
||||
subscription_configs = {}
|
||||
|
||||
for sub_config in subscriptions:
|
||||
sub_id = sub_config.get("subscriptionIdentifier")
|
||||
if sub_id:
|
||||
desired_subscriptions.add(sub_id)
|
||||
subscription_configs[sub_id] = sub_config
|
||||
|
||||
# Extract display ID for session management
|
||||
parts = sub_id.split(";")
|
||||
if len(parts) >= 2:
|
||||
display_id = parts[0]
|
||||
camera_id = parts[1]
|
||||
self.display_identifiers.add(display_id)
|
||||
else:
|
||||
camera_id = subscription_id
|
||||
|
||||
# Calculate changes needed
|
||||
to_add = desired_subscriptions - current_subscriptions
|
||||
to_remove = current_subscriptions - desired_subscriptions
|
||||
to_update = desired_subscriptions & current_subscriptions
|
||||
|
||||
logger.info(f"Subscription reconciliation: add={len(to_add)}, remove={len(to_remove)}, update={len(to_update)}")
|
||||
|
||||
# Remove obsolete subscriptions
|
||||
for sub_id in to_remove:
|
||||
camera_id = subscription_to_camera.get(sub_id)
|
||||
if camera_id:
|
||||
await self.stream_manager.stop_stream(camera_id)
|
||||
self.model_manager.unload_models(camera_id)
|
||||
subscription_to_camera.pop(sub_id, None)
|
||||
self.session_cache.clear_session(camera_id)
|
||||
logger.info(f"Removed subscription: {sub_id}")
|
||||
|
||||
# Add new subscriptions
|
||||
for sub_id in to_add:
|
||||
await self._start_subscription(sub_id, subscription_configs[sub_id])
|
||||
logger.info(f"Added subscription: {sub_id}")
|
||||
|
||||
# Update existing subscriptions if needed
|
||||
for sub_id in to_update:
|
||||
# Check if configuration changed (model URL, crop coordinates, etc.)
|
||||
current_config = subscription_to_camera.get(sub_id)
|
||||
new_config = subscription_configs[sub_id]
|
||||
|
||||
# For now, restart subscription if model URL changed (handles S3 expiration)
|
||||
current_model_url = getattr(current_config, 'model_url', None) if current_config else None
|
||||
new_model_url = new_config.get("modelUrl")
|
||||
|
||||
if current_model_url != new_model_url:
|
||||
# Restart with new configuration
|
||||
camera_id = subscription_to_camera.get(sub_id)
|
||||
if camera_id:
|
||||
await self.stream_manager.stop_stream(camera_id)
|
||||
self.model_manager.unload_models(camera_id)
|
||||
|
||||
await self._start_subscription(sub_id, new_config)
|
||||
logger.info(f"Updated subscription: {sub_id}")
|
||||
|
||||
logger.info(f"Subscription list reconciliation completed. Active: {len(desired_subscriptions)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling setSubscriptionList: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def _start_subscription(self, subscription_id: str, config: Dict[str, Any]) -> None:
|
||||
"""Start a single subscription with given configuration."""
|
||||
try:
|
||||
# Extract camera ID from subscription identifier
|
||||
parts = subscription_id.split(";")
|
||||
camera_id = parts[1] if len(parts) >= 2 else subscription_id
|
||||
|
||||
# Store subscription mapping
|
||||
subscription_to_camera[subscription_id] = camera_id
|
||||
|
||||
# Start camera stream
|
||||
await self.stream_manager.start_stream(camera_id, payload)
|
||||
await self.stream_manager.start_stream(camera_id, config)
|
||||
|
||||
# Load model
|
||||
model_id = payload.get("modelId")
|
||||
model_url = payload.get("modelUrl")
|
||||
model_id = config.get("modelId")
|
||||
model_url = config.get("modelUrl")
|
||||
if model_id and model_url:
|
||||
await self.model_manager.load_model(camera_id, model_id, model_url)
|
||||
|
||||
logger.info(f"Subscribed to stream: {subscription_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling subscription: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle stream unsubscription request."""
|
||||
payload = data.get("payload", {})
|
||||
subscription_id = payload.get("subscriptionIdentifier")
|
||||
|
||||
if not subscription_id:
|
||||
logger.error("Missing subscriptionIdentifier in unsubscribe payload")
|
||||
return
|
||||
|
||||
try:
|
||||
# Get camera ID from subscription
|
||||
camera_id = subscription_to_camera.get(subscription_id)
|
||||
if not camera_id:
|
||||
logger.warning(f"No camera found for subscription: {subscription_id}")
|
||||
return
|
||||
|
||||
# Stop stream
|
||||
await self.stream_manager.stop_stream(camera_id)
|
||||
|
||||
# Unload model
|
||||
self.model_manager.unload_models(camera_id)
|
||||
|
||||
# Clean up mappings
|
||||
subscription_to_camera.pop(subscription_id, None)
|
||||
|
||||
# Clean up session state
|
||||
self.session_cache.clear_session(camera_id)
|
||||
|
||||
logger.info(f"Unsubscribed from stream: {subscription_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling unsubscription: {e}")
|
||||
logger.error(f"Error starting subscription {subscription_id}: {e}")
|
||||
raise
|
||||
traceback.print_exc()
|
||||
|
||||
async def _handle_request_state(self, data: Dict[str, Any]) -> None:
|
||||
|
@ -535,6 +584,26 @@ class WebSocketHandler:
|
|||
pipeline_state["progression_stage"] = progression_stage
|
||||
logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}")
|
||||
|
||||
# Store progression stage for this display
|
||||
if display_id and progression_stage is not None:
|
||||
if progression_stage:
|
||||
self.progression_stages[display_id] = progression_stage
|
||||
else:
|
||||
# Clear progression stage if null
|
||||
self.progression_stages.pop(display_id, None)
|
||||
|
||||
async def _handle_patch_session_result(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle patchSessionResult message from backend."""
|
||||
payload = data.get("payload", {})
|
||||
session_id = payload.get("sessionId")
|
||||
success = payload.get("success", False)
|
||||
message = payload.get("message", "")
|
||||
|
||||
if success:
|
||||
logger.info(f"Patch session {session_id} successful: {message}")
|
||||
else:
|
||||
logger.warning(f"Patch session {session_id} failed: {message}")
|
||||
|
||||
async def _send_detection_result(
|
||||
self,
|
||||
camera_id: str,
|
||||
|
@ -542,10 +611,16 @@ class WebSocketHandler:
|
|||
detection_result: DetectionResult
|
||||
) -> None:
|
||||
"""Send detection result over WebSocket."""
|
||||
# Get session ID for this display
|
||||
subscription_id = stream_info["subscriptionIdentifier"]
|
||||
display_id = subscription_id.split(";")[0] if ";" in subscription_id else subscription_id
|
||||
session_id = self.session_ids.get(display_id)
|
||||
|
||||
detection_data = {
|
||||
"type": "imageDetection",
|
||||
"subscriptionIdentifier": stream_info["subscriptionIdentifier"],
|
||||
"subscriptionIdentifier": subscription_id,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
"sessionId": session_id, # Required by protocol
|
||||
"data": {
|
||||
"detection": detection_result.to_dict(),
|
||||
"modelId": stream_info["modelId"],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue