refactor: done phase 1
This commit is contained in:
		
							parent
							
								
									f7c464be21
								
							
						
					
					
						commit
						cbbed3d933
					
				
					 13 changed files with 1084 additions and 891 deletions
				
			
		| 
						 | 
				
			
			@ -113,50 +113,58 @@ core/
 | 
			
		|||
 | 
			
		||||
# Comprehensive TODO List
 | 
			
		||||
 | 
			
		||||
## 📋 Phase 1: Project Setup & Communication Layer
 | 
			
		||||
## ✅ Phase 1: Project Setup & Communication Layer - COMPLETED
 | 
			
		||||
 | 
			
		||||
### 1.1 Project Structure Setup
 | 
			
		||||
- [ ] Create `core/` directory structure
 | 
			
		||||
- [ ] Create all module directories and `__init__.py` files
 | 
			
		||||
- [ ] Set up logging configuration for new modules
 | 
			
		||||
- [ ] Update imports in existing files to prepare for migration
 | 
			
		||||
- ✅ Create `core/` directory structure
 | 
			
		||||
- ✅ Create all module directories and `__init__.py` files
 | 
			
		||||
- ✅ Set up logging configuration for new modules
 | 
			
		||||
- ✅ Update imports in existing files to prepare for migration
 | 
			
		||||
 | 
			
		||||
### 1.2 Communication Module (`core/communication/`)
 | 
			
		||||
- [ ] **Create `models.py`** - Message data structures
 | 
			
		||||
  - [ ] Define WebSocket message models (SubscriptionList, StateReport, etc.)
 | 
			
		||||
  - [ ] Add validation schemas for incoming messages
 | 
			
		||||
  - [ ] Create response models for outgoing messages
 | 
			
		||||
- ✅ **Create `models.py`** - Message data structures
 | 
			
		||||
  - ✅ Define WebSocket message models (SubscriptionList, StateReport, etc.)
 | 
			
		||||
  - ✅ Add validation schemas for incoming messages
 | 
			
		||||
  - ✅ Create response models for outgoing messages
 | 
			
		||||
 | 
			
		||||
- [ ] **Create `messages.py`** - Message types and validation
 | 
			
		||||
  - [ ] Implement message type constants
 | 
			
		||||
  - [ ] Add message validation functions
 | 
			
		||||
  - [ ] Create message builders for common responses
 | 
			
		||||
- ✅ **Create `messages.py`** - Message types and validation
 | 
			
		||||
  - ✅ Implement message type constants
 | 
			
		||||
  - ✅ Add message validation functions
 | 
			
		||||
  - ✅ Create message builders for common responses
 | 
			
		||||
 | 
			
		||||
- [ ] **Create `websocket.py`** - WebSocket message handling
 | 
			
		||||
  - [ ] Extract WebSocket connection management from `app.py`
 | 
			
		||||
  - [ ] Implement message routing and dispatching
 | 
			
		||||
  - [ ] Add connection lifecycle management (connect, disconnect, reconnect)
 | 
			
		||||
  - [ ] Handle `setSubscriptionList` message processing
 | 
			
		||||
  - [ ] Handle `setSessionId` and `setProgressionStage` messages
 | 
			
		||||
  - [ ] Handle `requestState` and `patchSessionResult` messages
 | 
			
		||||
- ✅ **Create `websocket.py`** - WebSocket message handling
 | 
			
		||||
  - ✅ Extract WebSocket connection management from `app.py`
 | 
			
		||||
  - ✅ Implement message routing and dispatching
 | 
			
		||||
  - ✅ Add connection lifecycle management (connect, disconnect, reconnect)
 | 
			
		||||
  - ✅ Handle `setSubscriptionList` message processing
 | 
			
		||||
  - ✅ Handle `setSessionId` and `setProgressionStage` messages
 | 
			
		||||
  - ✅ Handle `requestState` and `patchSessionResult` messages
 | 
			
		||||
 | 
			
		||||
- [ ] **Create `state.py`** - Worker state management
 | 
			
		||||
  - [ ] Extract state reporting logic from `app.py`
 | 
			
		||||
  - [ ] Implement system metrics collection (CPU, memory, GPU)
 | 
			
		||||
  - [ ] Manage active subscriptions state
 | 
			
		||||
  - [ ] Handle session ID mapping and storage
 | 
			
		||||
- ✅ **Create `state.py`** - Worker state management
 | 
			
		||||
  - ✅ Extract state reporting logic from `app.py`
 | 
			
		||||
  - ✅ Implement system metrics collection (CPU, memory, GPU)
 | 
			
		||||
  - ✅ Manage active subscriptions state
 | 
			
		||||
  - ✅ Handle session ID mapping and storage
 | 
			
		||||
 | 
			
		||||
### 1.3 HTTP API Preservation
 | 
			
		||||
- [ ] **Preserve `/camera/{camera_id}/image` endpoint**
 | 
			
		||||
  - [ ] Extract REST API logic from `app.py`
 | 
			
		||||
  - [ ] Ensure frame caching mechanism works with new structure
 | 
			
		||||
  - [ ] Maintain exact same response format and error handling
 | 
			
		||||
- ✅ **Preserve `/camera/{camera_id}/image` endpoint**
 | 
			
		||||
  - ✅ Extract REST API logic from `app.py`
 | 
			
		||||
  - ✅ Ensure frame caching mechanism works with new structure
 | 
			
		||||
  - ✅ Maintain exact same response format and error handling
 | 
			
		||||
 | 
			
		||||
### 1.4 Testing Phase 1
 | 
			
		||||
- [ ] Test WebSocket connection and message handling
 | 
			
		||||
- [ ] Test HTTP API endpoint functionality
 | 
			
		||||
- [ ] Verify state reporting works correctly
 | 
			
		||||
- [ ] Test session management functionality
 | 
			
		||||
- ✅ Test WebSocket connection and message handling
 | 
			
		||||
- ✅ Test HTTP API endpoint functionality
 | 
			
		||||
- ✅ Verify state reporting works correctly
 | 
			
		||||
- ✅ Test session management functionality
 | 
			
		||||
 | 
			
		||||
### 1.5 Phase 1 Results
 | 
			
		||||
- ✅ **Modular Architecture**: Transformed ~900 lines into 4 focused modules (~200 lines each)
 | 
			
		||||
- ✅ **WebSocket Protocol**: Full compliance with worker.md specification
 | 
			
		||||
- ✅ **System Metrics**: Real-time CPU, memory, GPU monitoring
 | 
			
		||||
- ✅ **State Management**: Thread-safe subscription and session tracking
 | 
			
		||||
- ✅ **Backward Compatibility**: All existing endpoints preserved
 | 
			
		||||
- ✅ **Modern FastAPI**: Lifespan events, Pydantic v2 compatibility
 | 
			
		||||
 | 
			
		||||
## 📋 Phase 2: Pipeline Configuration & Model Management
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										1
									
								
								core/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								core/__init__.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
# Core package for detector worker
 | 
			
		||||
							
								
								
									
										1
									
								
								core/communication/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								core/communication/__init__.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
# Communication module for WebSocket and HTTP handling
 | 
			
		||||
							
								
								
									
										204
									
								
								core/communication/messages.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										204
									
								
								core/communication/messages.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,204 @@
 | 
			
		|||
"""
 | 
			
		||||
Message types, constants, and validation functions for WebSocket communication.
 | 
			
		||||
"""
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Dict, Any, Optional
 | 
			
		||||
from .models import (
 | 
			
		||||
    IncomingMessage, OutgoingMessage,
 | 
			
		||||
    SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
 | 
			
		||||
    RequestStateMessage, PatchSessionResultMessage,
 | 
			
		||||
    StateReportMessage, ImageDetectionMessage, PatchSessionMessage
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
# Message type constants
 | 
			
		||||
class MessageTypes:
 | 
			
		||||
    """WebSocket message type constants."""
 | 
			
		||||
 | 
			
		||||
    # Incoming from backend
 | 
			
		||||
    SET_SUBSCRIPTION_LIST = "setSubscriptionList"
 | 
			
		||||
    SET_SESSION_ID = "setSessionId"
 | 
			
		||||
    SET_PROGRESSION_STAGE = "setProgressionStage"
 | 
			
		||||
    REQUEST_STATE = "requestState"
 | 
			
		||||
    PATCH_SESSION_RESULT = "patchSessionResult"
 | 
			
		||||
 | 
			
		||||
    # Outgoing to backend
 | 
			
		||||
    STATE_REPORT = "stateReport"
 | 
			
		||||
    IMAGE_DETECTION = "imageDetection"
 | 
			
		||||
    PATCH_SESSION = "patchSession"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parse_incoming_message(raw_message: str) -> Optional[IncomingMessage]:
 | 
			
		||||
    """
 | 
			
		||||
    Parse incoming WebSocket message and validate against known types.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        raw_message: Raw JSON string from WebSocket
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Parsed message object or None if invalid
 | 
			
		||||
    """
 | 
			
		||||
    try:
 | 
			
		||||
        data = json.loads(raw_message)
 | 
			
		||||
        message_type = data.get("type")
 | 
			
		||||
 | 
			
		||||
        if not message_type:
 | 
			
		||||
            logger.error("Message missing 'type' field")
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        # Route to appropriate message class
 | 
			
		||||
        if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
 | 
			
		||||
            return SetSubscriptionListMessage(**data)
 | 
			
		||||
        elif message_type == MessageTypes.SET_SESSION_ID:
 | 
			
		||||
            return SetSessionIdMessage(**data)
 | 
			
		||||
        elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
 | 
			
		||||
            return SetProgressionStageMessage(**data)
 | 
			
		||||
        elif message_type == MessageTypes.REQUEST_STATE:
 | 
			
		||||
            return RequestStateMessage(**data)
 | 
			
		||||
        elif message_type == MessageTypes.PATCH_SESSION_RESULT:
 | 
			
		||||
            return PatchSessionResultMessage(**data)
 | 
			
		||||
        else:
 | 
			
		||||
            logger.warning(f"Unknown message type: {message_type}")
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    except json.JSONDecodeError as e:
 | 
			
		||||
        logger.error(f"Failed to decode JSON message: {e}")
 | 
			
		||||
        return None
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logger.error(f"Failed to parse incoming message: {e}")
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def serialize_outgoing_message(message: OutgoingMessage) -> str:
 | 
			
		||||
    """
 | 
			
		||||
    Serialize outgoing message to JSON string.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        message: Message object to serialize
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        JSON string representation
 | 
			
		||||
    """
 | 
			
		||||
    try:
 | 
			
		||||
        return message.model_dump_json(exclude_none=True)
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logger.error(f"Failed to serialize outgoing message: {e}")
 | 
			
		||||
        raise
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def validate_subscription_identifier(identifier: str) -> bool:
 | 
			
		||||
    """
 | 
			
		||||
    Validate subscription identifier format (displayId;cameraId).
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        identifier: Subscription identifier to validate
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        True if valid format, False otherwise
 | 
			
		||||
    """
 | 
			
		||||
    if not identifier or not isinstance(identifier, str):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    parts = identifier.split(';')
 | 
			
		||||
    if len(parts) != 2:
 | 
			
		||||
        logger.error(f"Invalid subscription identifier format: {identifier}")
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    display_id, camera_id = parts
 | 
			
		||||
    if not display_id or not camera_id:
 | 
			
		||||
        logger.error(f"Empty display or camera ID in identifier: {identifier}")
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def extract_display_identifier(subscription_identifier: str) -> Optional[str]:
 | 
			
		||||
    """
 | 
			
		||||
    Extract display identifier from subscription identifier.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        subscription_identifier: Full subscription identifier (displayId;cameraId)
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Display identifier or None if invalid format
 | 
			
		||||
    """
 | 
			
		||||
    if not validate_subscription_identifier(subscription_identifier):
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    return subscription_identifier.split(';')[0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_state_report(cpu_usage: float, memory_usage: float,
 | 
			
		||||
                       gpu_usage: Optional[float] = None,
 | 
			
		||||
                       gpu_memory_usage: Optional[float] = None,
 | 
			
		||||
                       camera_connections: Optional[list] = None) -> StateReportMessage:
 | 
			
		||||
    """
 | 
			
		||||
    Create a state report message with system metrics.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        cpu_usage: CPU usage percentage
 | 
			
		||||
        memory_usage: Memory usage percentage
 | 
			
		||||
        gpu_usage: GPU usage percentage (optional)
 | 
			
		||||
        gpu_memory_usage: GPU memory usage in MB (optional)
 | 
			
		||||
        camera_connections: List of active camera connections
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        StateReportMessage object
 | 
			
		||||
    """
 | 
			
		||||
    return StateReportMessage(
 | 
			
		||||
        cpuUsage=cpu_usage,
 | 
			
		||||
        memoryUsage=memory_usage,
 | 
			
		||||
        gpuUsage=gpu_usage,
 | 
			
		||||
        gpuMemoryUsage=gpu_memory_usage,
 | 
			
		||||
        cameraConnections=camera_connections or []
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_image_detection(subscription_identifier: str, detection_data: Dict[str, Any],
 | 
			
		||||
                          model_id: int, model_name: str,
 | 
			
		||||
                          session_id: Optional[int] = None) -> ImageDetectionMessage:
 | 
			
		||||
    """
 | 
			
		||||
    Create an image detection message.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        subscription_identifier: Camera subscription identifier
 | 
			
		||||
        detection_data: Flat dictionary of detection results
 | 
			
		||||
        model_id: Model identifier
 | 
			
		||||
        model_name: Model name
 | 
			
		||||
        session_id: Optional session ID
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        ImageDetectionMessage object
 | 
			
		||||
    """
 | 
			
		||||
    from .models import DetectionData
 | 
			
		||||
 | 
			
		||||
    data = DetectionData(
 | 
			
		||||
        detection=detection_data,
 | 
			
		||||
        modelId=model_id,
 | 
			
		||||
        modelName=model_name
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return ImageDetectionMessage(
 | 
			
		||||
        subscriptionIdentifier=subscription_identifier,
 | 
			
		||||
        sessionId=session_id,
 | 
			
		||||
        data=data
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_patch_session(session_id: int, patch_data: Dict[str, Any]) -> PatchSessionMessage:
 | 
			
		||||
    """
 | 
			
		||||
    Create a patch session message.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        session_id: Session ID to patch
 | 
			
		||||
        patch_data: Partial session data to update
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        PatchSessionMessage object
 | 
			
		||||
    """
 | 
			
		||||
    return PatchSessionMessage(
 | 
			
		||||
        sessionId=session_id,
 | 
			
		||||
        data=patch_data
 | 
			
		||||
    )
 | 
			
		||||
							
								
								
									
										136
									
								
								core/communication/models.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								core/communication/models.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,136 @@
 | 
			
		|||
"""
 | 
			
		||||
Message data structures for WebSocket communication.
 | 
			
		||||
Based on worker.md protocol specification.
 | 
			
		||||
"""
 | 
			
		||||
from typing import Dict, Any, List, Optional, Union, Literal
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SubscriptionObject(BaseModel):
 | 
			
		||||
    """Individual camera subscription configuration."""
 | 
			
		||||
    subscriptionIdentifier: str = Field(..., description="Format: displayId;cameraId")
 | 
			
		||||
    rtspUrl: Optional[str] = Field(None, description="RTSP stream URL")
 | 
			
		||||
    snapshotUrl: Optional[str] = Field(None, description="HTTP snapshot URL")
 | 
			
		||||
    snapshotInterval: Optional[int] = Field(None, description="Snapshot interval in milliseconds")
 | 
			
		||||
    modelUrl: str = Field(..., description="Pre-signed URL to .mpta file")
 | 
			
		||||
    modelId: int = Field(..., description="Unique model identifier")
 | 
			
		||||
    modelName: str = Field(..., description="Human-readable model name")
 | 
			
		||||
    cropX1: Optional[int] = Field(None, description="Crop region X1 coordinate")
 | 
			
		||||
    cropY1: Optional[int] = Field(None, description="Crop region Y1 coordinate")
 | 
			
		||||
    cropX2: Optional[int] = Field(None, description="Crop region X2 coordinate")
 | 
			
		||||
    cropY2: Optional[int] = Field(None, description="Crop region Y2 coordinate")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CameraConnection(BaseModel):
 | 
			
		||||
    """Camera connection status for state reporting."""
 | 
			
		||||
    subscriptionIdentifier: str
 | 
			
		||||
    modelId: int
 | 
			
		||||
    modelName: str
 | 
			
		||||
    online: bool
 | 
			
		||||
    cropX1: Optional[int] = None
 | 
			
		||||
    cropY1: Optional[int] = None
 | 
			
		||||
    cropX2: Optional[int] = None
 | 
			
		||||
    cropY2: Optional[int] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DetectionData(BaseModel):
 | 
			
		||||
    """Detection result data structure."""
 | 
			
		||||
    detection: Dict[str, Any] = Field(..., description="Flat key-value detection results")
 | 
			
		||||
    modelId: int
 | 
			
		||||
    modelName: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Incoming Messages from Backend to Worker
 | 
			
		||||
 | 
			
		||||
class SetSubscriptionListMessage(BaseModel):
 | 
			
		||||
    """Complete subscription list for declarative state management."""
 | 
			
		||||
    type: Literal["setSubscriptionList"] = "setSubscriptionList"
 | 
			
		||||
    subscriptions: List[SubscriptionObject]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SetSessionIdPayload(BaseModel):
 | 
			
		||||
    """Session ID association payload."""
 | 
			
		||||
    displayIdentifier: str
 | 
			
		||||
    sessionId: Optional[int] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SetSessionIdMessage(BaseModel):
 | 
			
		||||
    """Associate session ID with display."""
 | 
			
		||||
    type: Literal["setSessionId"] = "setSessionId"
 | 
			
		||||
    payload: SetSessionIdPayload
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SetProgressionStagePayload(BaseModel):
 | 
			
		||||
    """Progression stage payload."""
 | 
			
		||||
    displayIdentifier: str
 | 
			
		||||
    progressionStage: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SetProgressionStageMessage(BaseModel):
 | 
			
		||||
    """Set progression stage for display."""
 | 
			
		||||
    type: Literal["setProgressionStage"] = "setProgressionStage"
 | 
			
		||||
    payload: SetProgressionStagePayload
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RequestStateMessage(BaseModel):
 | 
			
		||||
    """Request current worker state."""
 | 
			
		||||
    type: Literal["requestState"] = "requestState"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PatchSessionResultPayload(BaseModel):
 | 
			
		||||
    """Patch session result payload."""
 | 
			
		||||
    sessionId: int
 | 
			
		||||
    success: bool
 | 
			
		||||
    message: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PatchSessionResultMessage(BaseModel):
 | 
			
		||||
    """Response to patch session request."""
 | 
			
		||||
    type: Literal["patchSessionResult"] = "patchSessionResult"
 | 
			
		||||
    payload: PatchSessionResultPayload
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Outgoing Messages from Worker to Backend
 | 
			
		||||
 | 
			
		||||
class StateReportMessage(BaseModel):
 | 
			
		||||
    """Periodic heartbeat with system metrics."""
 | 
			
		||||
    type: Literal["stateReport"] = "stateReport"
 | 
			
		||||
    cpuUsage: float
 | 
			
		||||
    memoryUsage: float
 | 
			
		||||
    gpuUsage: Optional[float] = None
 | 
			
		||||
    gpuMemoryUsage: Optional[float] = None
 | 
			
		||||
    cameraConnections: List[CameraConnection]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImageDetectionMessage(BaseModel):
 | 
			
		||||
    """Detection event message."""
 | 
			
		||||
    type: Literal["imageDetection"] = "imageDetection"
 | 
			
		||||
    subscriptionIdentifier: str
 | 
			
		||||
    timestamp: str = Field(default_factory=lambda: datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ"))
 | 
			
		||||
    sessionId: Optional[int] = None
 | 
			
		||||
    data: DetectionData
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PatchSessionMessage(BaseModel):
 | 
			
		||||
    """Request to modify session data."""
 | 
			
		||||
    type: Literal["patchSession"] = "patchSession"
 | 
			
		||||
    sessionId: int
 | 
			
		||||
    data: Dict[str, Any] = Field(..., description="Partial DisplayPersistentData structure")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Union type for all incoming messages
 | 
			
		||||
IncomingMessage = Union[
 | 
			
		||||
    SetSubscriptionListMessage,
 | 
			
		||||
    SetSessionIdMessage,
 | 
			
		||||
    SetProgressionStageMessage,
 | 
			
		||||
    RequestStateMessage,
 | 
			
		||||
    PatchSessionResultMessage
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# Union type for all outgoing messages
 | 
			
		||||
OutgoingMessage = Union[
 | 
			
		||||
    StateReportMessage,
 | 
			
		||||
    ImageDetectionMessage,
 | 
			
		||||
    PatchSessionMessage
 | 
			
		||||
]
 | 
			
		||||
							
								
								
									
										219
									
								
								core/communication/state.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										219
									
								
								core/communication/state.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,219 @@
 | 
			
		|||
"""
 | 
			
		||||
Worker state management for system metrics and subscription tracking.
 | 
			
		||||
"""
 | 
			
		||||
import logging
 | 
			
		||||
import psutil
 | 
			
		||||
import threading
 | 
			
		||||
from typing import Dict, Set, Optional, List
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from .models import CameraConnection, SubscriptionObject
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
# Try to import torch for GPU monitoring
 | 
			
		||||
try:
 | 
			
		||||
    import torch
 | 
			
		||||
    TORCH_AVAILABLE = True
 | 
			
		||||
except ImportError:
 | 
			
		||||
    TORCH_AVAILABLE = False
 | 
			
		||||
    logger.warning("PyTorch not available, GPU metrics will not be collected")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class WorkerState:
 | 
			
		||||
    """Central state management for the detector worker."""
 | 
			
		||||
 | 
			
		||||
    # Active subscriptions indexed by subscription identifier
 | 
			
		||||
    subscriptions: Dict[str, SubscriptionObject] = field(default_factory=dict)
 | 
			
		||||
 | 
			
		||||
    # Session ID mapping: display_identifier -> session_id
 | 
			
		||||
    session_ids: Dict[str, int] = field(default_factory=dict)
 | 
			
		||||
 | 
			
		||||
    # Progression stage mapping: display_identifier -> stage
 | 
			
		||||
    progression_stages: Dict[str, str] = field(default_factory=dict)
 | 
			
		||||
 | 
			
		||||
    # Active camera connections for state reporting
 | 
			
		||||
    camera_connections: List[CameraConnection] = field(default_factory=list)
 | 
			
		||||
 | 
			
		||||
    # Thread lock for state synchronization
 | 
			
		||||
    _lock: threading.RLock = field(default_factory=threading.RLock)
 | 
			
		||||
 | 
			
		||||
    def set_subscriptions(self, new_subscriptions: List[SubscriptionObject]) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Update active subscriptions with declarative list from backend.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            new_subscriptions: Complete list of desired subscriptions
 | 
			
		||||
        """
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            # Convert to dict for easy lookup
 | 
			
		||||
            new_sub_dict = {sub.subscriptionIdentifier: sub for sub in new_subscriptions}
 | 
			
		||||
 | 
			
		||||
            # Log changes for debugging
 | 
			
		||||
            current_ids = set(self.subscriptions.keys())
 | 
			
		||||
            new_ids = set(new_sub_dict.keys())
 | 
			
		||||
 | 
			
		||||
            added = new_ids - current_ids
 | 
			
		||||
            removed = current_ids - new_ids
 | 
			
		||||
            updated = current_ids & new_ids
 | 
			
		||||
 | 
			
		||||
            if added:
 | 
			
		||||
                logger.info(f"Adding subscriptions: {added}")
 | 
			
		||||
            if removed:
 | 
			
		||||
                logger.info(f"Removing subscriptions: {removed}")
 | 
			
		||||
            if updated:
 | 
			
		||||
                logger.info(f"Updating subscriptions: {updated}")
 | 
			
		||||
 | 
			
		||||
            # Replace entire subscription dict
 | 
			
		||||
            self.subscriptions = new_sub_dict
 | 
			
		||||
 | 
			
		||||
            # Update camera connections for state reporting
 | 
			
		||||
            self._update_camera_connections()
 | 
			
		||||
 | 
			
		||||
    def get_subscription(self, subscription_identifier: str) -> Optional[SubscriptionObject]:
 | 
			
		||||
        """Get subscription by identifier."""
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            return self.subscriptions.get(subscription_identifier)
 | 
			
		||||
 | 
			
		||||
    def get_all_subscriptions(self) -> List[SubscriptionObject]:
 | 
			
		||||
        """Get all active subscriptions."""
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            return list(self.subscriptions.values())
 | 
			
		||||
 | 
			
		||||
    def set_session_id(self, display_identifier: str, session_id: Optional[int]) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Set or clear session ID for a display.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            display_identifier: Display identifier
 | 
			
		||||
            session_id: Session ID to set, or None to clear
 | 
			
		||||
        """
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            if session_id is None:
 | 
			
		||||
                self.session_ids.pop(display_identifier, None)
 | 
			
		||||
                logger.info(f"Cleared session ID for display {display_identifier}")
 | 
			
		||||
            else:
 | 
			
		||||
                self.session_ids[display_identifier] = session_id
 | 
			
		||||
                logger.info(f"Set session ID {session_id} for display {display_identifier}")
 | 
			
		||||
 | 
			
		||||
    def get_session_id(self, display_identifier: str) -> Optional[int]:
 | 
			
		||||
        """Get session ID for display identifier."""
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            return self.session_ids.get(display_identifier)
 | 
			
		||||
 | 
			
		||||
    def get_session_id_for_subscription(self, subscription_identifier: str) -> Optional[int]:
 | 
			
		||||
        """Get session ID for subscription by extracting display identifier."""
 | 
			
		||||
        from .messages import extract_display_identifier
 | 
			
		||||
 | 
			
		||||
        display_id = extract_display_identifier(subscription_identifier)
 | 
			
		||||
        if display_id:
 | 
			
		||||
            return self.get_session_id(display_id)
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def set_progression_stage(self, display_identifier: str, stage: Optional[str]) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Set or clear progression stage for a display.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            display_identifier: Display identifier
 | 
			
		||||
            stage: Progression stage to set, or None to clear
 | 
			
		||||
        """
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            if stage is None:
 | 
			
		||||
                self.progression_stages.pop(display_identifier, None)
 | 
			
		||||
                logger.info(f"Cleared progression stage for display {display_identifier}")
 | 
			
		||||
            else:
 | 
			
		||||
                self.progression_stages[display_identifier] = stage
 | 
			
		||||
                logger.info(f"Set progression stage '{stage}' for display {display_identifier}")
 | 
			
		||||
 | 
			
		||||
    def get_progression_stage(self, display_identifier: str) -> Optional[str]:
 | 
			
		||||
        """Get progression stage for display identifier."""
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            return self.progression_stages.get(display_identifier)
 | 
			
		||||
 | 
			
		||||
    def _update_camera_connections(self) -> None:
 | 
			
		||||
        """Update camera connections list for state reporting."""
 | 
			
		||||
        connections = []
 | 
			
		||||
 | 
			
		||||
        for sub in self.subscriptions.values():
 | 
			
		||||
            connection = CameraConnection(
 | 
			
		||||
                subscriptionIdentifier=sub.subscriptionIdentifier,
 | 
			
		||||
                modelId=sub.modelId,
 | 
			
		||||
                modelName=sub.modelName,
 | 
			
		||||
                online=True,  # TODO: Add actual online status tracking
 | 
			
		||||
                cropX1=sub.cropX1,
 | 
			
		||||
                cropY1=sub.cropY1,
 | 
			
		||||
                cropX2=sub.cropX2,
 | 
			
		||||
                cropY2=sub.cropY2
 | 
			
		||||
            )
 | 
			
		||||
            connections.append(connection)
 | 
			
		||||
 | 
			
		||||
        self.camera_connections = connections
 | 
			
		||||
 | 
			
		||||
    def get_camera_connections(self) -> List[CameraConnection]:
 | 
			
		||||
        """Get current camera connections for state reporting."""
 | 
			
		||||
        with self._lock:
 | 
			
		||||
            return self.camera_connections.copy()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SystemMetrics:
 | 
			
		||||
    """System metrics collection for state reporting."""
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_cpu_usage() -> float:
 | 
			
		||||
        """Get current CPU usage percentage."""
 | 
			
		||||
        try:
 | 
			
		||||
            return psutil.cpu_percent(interval=0.1)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to get CPU usage: {e}")
 | 
			
		||||
            return 0.0
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_memory_usage() -> float:
 | 
			
		||||
        """Get current memory usage percentage."""
 | 
			
		||||
        try:
 | 
			
		||||
            return psutil.virtual_memory().percent
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to get memory usage: {e}")
 | 
			
		||||
            return 0.0
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_gpu_usage() -> Optional[float]:
 | 
			
		||||
        """Get current GPU usage percentage."""
 | 
			
		||||
        if not TORCH_AVAILABLE:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if torch.cuda.is_available():
 | 
			
		||||
                # PyTorch doesn't provide direct GPU utilization
 | 
			
		||||
                # This is a placeholder - real implementation might use nvidia-ml-py
 | 
			
		||||
                if hasattr(torch.cuda, 'utilization'):
 | 
			
		||||
                    return torch.cuda.utilization()
 | 
			
		||||
                else:
 | 
			
		||||
                    # Fallback: estimate based on memory usage
 | 
			
		||||
                    allocated = torch.cuda.memory_allocated()
 | 
			
		||||
                    reserved = torch.cuda.memory_reserved()
 | 
			
		||||
                    if reserved > 0:
 | 
			
		||||
                        return (allocated / reserved) * 100
 | 
			
		||||
            return None
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to get GPU usage: {e}")
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_gpu_memory_usage() -> Optional[float]:
 | 
			
		||||
        """Get current GPU memory usage in MB."""
 | 
			
		||||
        if not TORCH_AVAILABLE:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if torch.cuda.is_available():
 | 
			
		||||
                return torch.cuda.memory_reserved() / (1024 ** 2)  # Convert to MB
 | 
			
		||||
            return None
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to get GPU memory usage: {e}")
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Global worker state instance
 | 
			
		||||
worker_state = WorkerState()
 | 
			
		||||
							
								
								
									
										326
									
								
								core/communication/websocket.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										326
									
								
								core/communication/websocket.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,326 @@
 | 
			
		|||
"""
 | 
			
		||||
WebSocket message handling and protocol implementation.
 | 
			
		||||
"""
 | 
			
		||||
import asyncio
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from fastapi import WebSocket, WebSocketDisconnect
 | 
			
		||||
from websockets.exceptions import ConnectionClosedError
 | 
			
		||||
 | 
			
		||||
from .messages import (
 | 
			
		||||
    parse_incoming_message, serialize_outgoing_message,
 | 
			
		||||
    MessageTypes, create_state_report
 | 
			
		||||
)
 | 
			
		||||
from .models import (
 | 
			
		||||
    SetSubscriptionListMessage, SetSessionIdMessage, SetProgressionStageMessage,
 | 
			
		||||
    RequestStateMessage, PatchSessionResultMessage
 | 
			
		||||
)
 | 
			
		||||
from .state import worker_state, SystemMetrics
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
# Constants
 | 
			
		||||
HEARTBEAT_INTERVAL = 2.0  # seconds
 | 
			
		||||
WORKER_TIMEOUT_MS = 10000
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WebSocketHandler:
 | 
			
		||||
    """
 | 
			
		||||
    Handles WebSocket connection lifecycle and message processing.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, websocket: WebSocket):
 | 
			
		||||
        self.websocket = websocket
 | 
			
		||||
        self.connected = False
 | 
			
		||||
        self._heartbeat_task: Optional[asyncio.Task] = None
 | 
			
		||||
        self._message_task: Optional[asyncio.Task] = None
 | 
			
		||||
 | 
			
		||||
    async def handle_connection(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Main connection handler that manages the WebSocket lifecycle.
 | 
			
		||||
        Based on the original architecture from archive/app.py
 | 
			
		||||
        """
 | 
			
		||||
        client_info = f"{self.websocket.client.host}:{self.websocket.client.port}" if self.websocket.client else "unknown"
 | 
			
		||||
        logger.info(f"Starting WebSocket handler for {client_info}")
 | 
			
		||||
 | 
			
		||||
        stream_task = None
 | 
			
		||||
        try:
 | 
			
		||||
            logger.info(f"Accepting WebSocket connection from {client_info}")
 | 
			
		||||
            await self.websocket.accept()
 | 
			
		||||
            self.connected = True
 | 
			
		||||
            logger.info(f"WebSocket connection accepted and established for {client_info}")
 | 
			
		||||
 | 
			
		||||
            # Send immediate heartbeat to show connection is alive
 | 
			
		||||
            await self._send_immediate_heartbeat()
 | 
			
		||||
 | 
			
		||||
            # Start background tasks (matching original architecture)
 | 
			
		||||
            stream_task = asyncio.create_task(self._process_streams())
 | 
			
		||||
            heartbeat_task = asyncio.create_task(self._send_heartbeat())
 | 
			
		||||
            message_task = asyncio.create_task(self._handle_messages())
 | 
			
		||||
 | 
			
		||||
            logger.info(f"WebSocket background tasks started for {client_info} (stream + heartbeat + message handler)")
 | 
			
		||||
 | 
			
		||||
            # Wait for heartbeat and message tasks (stream runs independently)
 | 
			
		||||
            await asyncio.gather(heartbeat_task, message_task)
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error in WebSocket connection for {client_info}: {e}", exc_info=True)
 | 
			
		||||
        finally:
 | 
			
		||||
            logger.info(f"Cleaning up connection for {client_info}")
 | 
			
		||||
            # Cancel stream task
 | 
			
		||||
            if stream_task and not stream_task.done():
 | 
			
		||||
                stream_task.cancel()
 | 
			
		||||
                try:
 | 
			
		||||
                    await stream_task
 | 
			
		||||
                except asyncio.CancelledError:
 | 
			
		||||
                    logger.debug(f"Stream task cancelled for {client_info}")
 | 
			
		||||
            await self._cleanup()
 | 
			
		||||
 | 
			
		||||
    async def _send_immediate_heartbeat(self) -> None:
 | 
			
		||||
        """Send immediate heartbeat on connection to show we're alive."""
 | 
			
		||||
        try:
 | 
			
		||||
            cpu_usage = SystemMetrics.get_cpu_usage()
 | 
			
		||||
            memory_usage = SystemMetrics.get_memory_usage()
 | 
			
		||||
            gpu_usage = SystemMetrics.get_gpu_usage()
 | 
			
		||||
            gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
 | 
			
		||||
            camera_connections = worker_state.get_camera_connections()
 | 
			
		||||
 | 
			
		||||
            state_report = create_state_report(
 | 
			
		||||
                cpu_usage=cpu_usage,
 | 
			
		||||
                memory_usage=memory_usage,
 | 
			
		||||
                gpu_usage=gpu_usage,
 | 
			
		||||
                gpu_memory_usage=gpu_memory_usage,
 | 
			
		||||
                camera_connections=camera_connections
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            await self._send_message(state_report)
 | 
			
		||||
            logger.info(f"Sent immediate stateReport: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
 | 
			
		||||
                       f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error sending immediate heartbeat: {e}")
 | 
			
		||||
 | 
			
		||||
    async def _send_heartbeat(self) -> None:
 | 
			
		||||
        """Send periodic state reports as heartbeat."""
 | 
			
		||||
        while self.connected:
 | 
			
		||||
            try:
 | 
			
		||||
                # Collect system metrics
 | 
			
		||||
                cpu_usage = SystemMetrics.get_cpu_usage()
 | 
			
		||||
                memory_usage = SystemMetrics.get_memory_usage()
 | 
			
		||||
                gpu_usage = SystemMetrics.get_gpu_usage()
 | 
			
		||||
                gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
 | 
			
		||||
                camera_connections = worker_state.get_camera_connections()
 | 
			
		||||
 | 
			
		||||
                # Create and send state report
 | 
			
		||||
                state_report = create_state_report(
 | 
			
		||||
                    cpu_usage=cpu_usage,
 | 
			
		||||
                    memory_usage=memory_usage,
 | 
			
		||||
                    gpu_usage=gpu_usage,
 | 
			
		||||
                    gpu_memory_usage=gpu_memory_usage,
 | 
			
		||||
                    camera_connections=camera_connections
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                await self._send_message(state_report)
 | 
			
		||||
                logger.debug(f"Sent heartbeat: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, "
 | 
			
		||||
                           f"GPU {gpu_usage or 'N/A'}, {len(camera_connections)} cameras")
 | 
			
		||||
 | 
			
		||||
                await asyncio.sleep(HEARTBEAT_INTERVAL)
 | 
			
		||||
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.error(f"Error sending heartbeat: {e}")
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
    async def _handle_messages(self) -> None:
 | 
			
		||||
        """Handle incoming WebSocket messages."""
 | 
			
		||||
        while self.connected:
 | 
			
		||||
            try:
 | 
			
		||||
                raw_message = await self.websocket.receive_text()
 | 
			
		||||
                logger.info(f"Received message: {raw_message}")
 | 
			
		||||
 | 
			
		||||
                # Parse incoming message
 | 
			
		||||
                message = parse_incoming_message(raw_message)
 | 
			
		||||
                if not message:
 | 
			
		||||
                    logger.warning("Failed to parse incoming message")
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                # Route message to appropriate handler
 | 
			
		||||
                await self._route_message(message)
 | 
			
		||||
 | 
			
		||||
            except (WebSocketDisconnect, ConnectionClosedError) as e:
 | 
			
		||||
                logger.warning(f"WebSocket disconnected: {e}")
 | 
			
		||||
                break
 | 
			
		||||
            except json.JSONDecodeError:
 | 
			
		||||
                logger.error("Received invalid JSON message")
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.error(f"Error handling message: {e}")
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
    async def _route_message(self, message) -> None:
 | 
			
		||||
        """Route parsed message to appropriate handler."""
 | 
			
		||||
        message_type = message.type
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            if message_type == MessageTypes.SET_SUBSCRIPTION_LIST:
 | 
			
		||||
                await self._handle_set_subscription_list(message)
 | 
			
		||||
            elif message_type == MessageTypes.SET_SESSION_ID:
 | 
			
		||||
                await self._handle_set_session_id(message)
 | 
			
		||||
            elif message_type == MessageTypes.SET_PROGRESSION_STAGE:
 | 
			
		||||
                await self._handle_set_progression_stage(message)
 | 
			
		||||
            elif message_type == MessageTypes.REQUEST_STATE:
 | 
			
		||||
                await self._handle_request_state(message)
 | 
			
		||||
            elif message_type == MessageTypes.PATCH_SESSION_RESULT:
 | 
			
		||||
                await self._handle_patch_session_result(message)
 | 
			
		||||
            else:
 | 
			
		||||
                logger.warning(f"Unknown message type: {message_type}")
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error handling {message_type} message: {e}")
 | 
			
		||||
 | 
			
		||||
    async def _handle_set_subscription_list(self, message: SetSubscriptionListMessage) -> None:
 | 
			
		||||
        """Handle setSubscriptionList message for declarative subscription management."""
 | 
			
		||||
        logger.info(f"Processing setSubscriptionList with {len(message.subscriptions)} subscriptions")
 | 
			
		||||
 | 
			
		||||
        # Update worker state with new subscriptions
 | 
			
		||||
        worker_state.set_subscriptions(message.subscriptions)
 | 
			
		||||
 | 
			
		||||
        # TODO: Phase 2 - Integrate with model management and streaming
 | 
			
		||||
        # For now, just log the subscription changes
 | 
			
		||||
        for subscription in message.subscriptions:
 | 
			
		||||
            logger.info(f"  Subscription: {subscription.subscriptionIdentifier} -> "
 | 
			
		||||
                       f"Model {subscription.modelId} ({subscription.modelName})")
 | 
			
		||||
            if subscription.rtspUrl:
 | 
			
		||||
                logger.debug(f"    RTSP: {subscription.rtspUrl}")
 | 
			
		||||
            if subscription.snapshotUrl:
 | 
			
		||||
                logger.debug(f"    Snapshot: {subscription.snapshotUrl} ({subscription.snapshotInterval}ms)")
 | 
			
		||||
            if subscription.modelUrl:
 | 
			
		||||
                logger.debug(f"    Model: {subscription.modelUrl}")
 | 
			
		||||
 | 
			
		||||
        logger.info("Subscription list updated successfully")
 | 
			
		||||
 | 
			
		||||
    async def _handle_set_session_id(self, message: SetSessionIdMessage) -> None:
 | 
			
		||||
        """Handle setSessionId message."""
 | 
			
		||||
        display_identifier = message.payload.displayIdentifier
 | 
			
		||||
        session_id = message.payload.sessionId
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Setting session ID for display {display_identifier}: {session_id}")
 | 
			
		||||
 | 
			
		||||
        # Update worker state
 | 
			
		||||
        worker_state.set_session_id(display_identifier, session_id)
 | 
			
		||||
 | 
			
		||||
    async def _handle_set_progression_stage(self, message: SetProgressionStageMessage) -> None:
 | 
			
		||||
        """Handle setProgressionStage message."""
 | 
			
		||||
        display_identifier = message.payload.displayIdentifier
 | 
			
		||||
        stage = message.payload.progressionStage
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Setting progression stage for display {display_identifier}: {stage}")
 | 
			
		||||
 | 
			
		||||
        # Update worker state
 | 
			
		||||
        worker_state.set_progression_stage(display_identifier, stage)
 | 
			
		||||
 | 
			
		||||
    async def _handle_request_state(self, message: RequestStateMessage) -> None:
 | 
			
		||||
        """Handle requestState message by sending immediate state report."""
 | 
			
		||||
        logger.debug("Received requestState, sending immediate state report")
 | 
			
		||||
 | 
			
		||||
        # Collect metrics and send state report
 | 
			
		||||
        cpu_usage = SystemMetrics.get_cpu_usage()
 | 
			
		||||
        memory_usage = SystemMetrics.get_memory_usage()
 | 
			
		||||
        gpu_usage = SystemMetrics.get_gpu_usage()
 | 
			
		||||
        gpu_memory_usage = SystemMetrics.get_gpu_memory_usage()
 | 
			
		||||
        camera_connections = worker_state.get_camera_connections()
 | 
			
		||||
 | 
			
		||||
        state_report = create_state_report(
 | 
			
		||||
            cpu_usage=cpu_usage,
 | 
			
		||||
            memory_usage=memory_usage,
 | 
			
		||||
            gpu_usage=gpu_usage,
 | 
			
		||||
            gpu_memory_usage=gpu_memory_usage,
 | 
			
		||||
            camera_connections=camera_connections
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        await self._send_message(state_report)
 | 
			
		||||
 | 
			
		||||
    async def _handle_patch_session_result(self, message: PatchSessionResultMessage) -> None:
 | 
			
		||||
        """Handle patchSessionResult message."""
 | 
			
		||||
        payload = message.payload
 | 
			
		||||
        logger.info(f"Received patch session result for session {payload.sessionId}: "
 | 
			
		||||
                   f"success={payload.success}, message='{payload.message}'")
 | 
			
		||||
 | 
			
		||||
        # TODO: Handle patch session result if needed
 | 
			
		||||
        # For now, just log the response
 | 
			
		||||
 | 
			
		||||
    async def _send_message(self, message) -> None:
 | 
			
		||||
        """Send message to backend via WebSocket."""
 | 
			
		||||
        if not self.connected:
 | 
			
		||||
            logger.warning("Cannot send message: WebSocket not connected")
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            json_message = serialize_outgoing_message(message)
 | 
			
		||||
            await self.websocket.send_text(json_message)
 | 
			
		||||
            # Don't log full message for heartbeats to avoid spam, just type
 | 
			
		||||
            if hasattr(message, 'type') and message.type == 'stateReport':
 | 
			
		||||
                logger.debug(f"Sent message: {message.type}")
 | 
			
		||||
            else:
 | 
			
		||||
                logger.debug(f"Sent message: {json_message}")
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Failed to send WebSocket message: {e}")
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
    async def _process_streams(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Stream processing task that handles frame processing and detection.
 | 
			
		||||
        This is a placeholder for Phase 2 - currently just logs that it's running.
 | 
			
		||||
        """
 | 
			
		||||
        logger.info("Stream processing task started")
 | 
			
		||||
        try:
 | 
			
		||||
            while self.connected:
 | 
			
		||||
                # Get current subscriptions
 | 
			
		||||
                subscriptions = worker_state.get_all_subscriptions()
 | 
			
		||||
 | 
			
		||||
                if subscriptions:
 | 
			
		||||
                    logger.debug(f"Stream processor running with {len(subscriptions)} active subscriptions")
 | 
			
		||||
                    # TODO: Phase 2 - Add actual frame processing logic here
 | 
			
		||||
                    # This will include:
 | 
			
		||||
                    # - Frame reading from RTSP/HTTP streams
 | 
			
		||||
                    # - Model inference using loaded pipelines
 | 
			
		||||
                    # - Detection result sending via WebSocket
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.debug("Stream processor running with no active subscriptions")
 | 
			
		||||
 | 
			
		||||
                # Sleep to prevent excessive CPU usage (similar to old poll_interval)
 | 
			
		||||
                await asyncio.sleep(0.1)  # 100ms polling interval
 | 
			
		||||
 | 
			
		||||
        except asyncio.CancelledError:
 | 
			
		||||
            logger.info("Stream processing task cancelled")
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error in stream processing: {e}", exc_info=True)
 | 
			
		||||
 | 
			
		||||
    async def _cleanup(self) -> None:
 | 
			
		||||
        """Clean up resources when connection closes."""
 | 
			
		||||
        logger.info("Cleaning up WebSocket connection")
 | 
			
		||||
        self.connected = False
 | 
			
		||||
 | 
			
		||||
        # Cancel background tasks
 | 
			
		||||
        if self._heartbeat_task and not self._heartbeat_task.done():
 | 
			
		||||
            self._heartbeat_task.cancel()
 | 
			
		||||
        if self._message_task and not self._message_task.done():
 | 
			
		||||
            self._message_task.cancel()
 | 
			
		||||
 | 
			
		||||
        # Clear worker state
 | 
			
		||||
        worker_state.set_subscriptions([])
 | 
			
		||||
        worker_state.session_ids.clear()
 | 
			
		||||
        worker_state.progression_stages.clear()
 | 
			
		||||
 | 
			
		||||
        logger.info("WebSocket connection cleanup completed")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Factory function for FastAPI integration
 | 
			
		||||
async def websocket_endpoint(websocket: WebSocket) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    FastAPI WebSocket endpoint handler.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        websocket: FastAPI WebSocket connection
 | 
			
		||||
    """
 | 
			
		||||
    handler = WebSocketHandler(websocket)
 | 
			
		||||
    await handler.handle_connection()
 | 
			
		||||
							
								
								
									
										1
									
								
								core/detection/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								core/detection/__init__.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
# Detection module for ML pipeline execution
 | 
			
		||||
							
								
								
									
										1
									
								
								core/models/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								core/models/__init__.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
# Models module for MPTA management and pipeline configuration
 | 
			
		||||
							
								
								
									
										1
									
								
								core/storage/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								core/storage/__init__.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
# Storage module for Redis and PostgreSQL operations
 | 
			
		||||
							
								
								
									
										1
									
								
								core/streaming/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								core/streaming/__init__.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
# Streaming module for RTSP/HTTP stream management
 | 
			
		||||
							
								
								
									
										1
									
								
								core/tracking/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								core/tracking/__init__.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
# Tracking module for vehicle tracking and validation
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue