Refactor: Phase 4: Communication Layer
parent cdeaaf4a4f
commit 54f21672aa
6 changed files with 2876 additions and 0 deletions
509
detector_worker/communication/message_processor.py
Normal file
@@ -0,0 +1,509 @@
"""
Message processor module.

This module handles validation, parsing, and processing of WebSocket messages.
It provides a clean separation between message handling and business logic.
"""
import json
import logging
import time
from typing import Dict, Any, Optional, Callable, Tuple
from enum import Enum
from dataclasses import dataclass, field

from ..core.exceptions import ValidationError, MessageProcessingError

# Setup logging
logger = logging.getLogger("detector_worker.message_processor")


class MessageType(Enum):
    """Enumeration of supported WebSocket message types."""
    SUBSCRIBE = "subscribe"
    UNSUBSCRIBE = "unsubscribe"
    REQUEST_STATE = "requestState"
    SET_SESSION_ID = "setSessionId"
    PATCH_SESSION = "patchSession"
    SET_PROGRESSION_STAGE = "setProgressionStage"
    IMAGE_DETECTION = "imageDetection"
    STATE_REPORT = "stateReport"
    ACK = "ack"
    ERROR = "error"


class ProgressionStage(Enum):
    """Enumeration of progression stages."""
    WELCOME = "welcome"
    CAR_FUELING = "car_fueling"
    CAR_WAIT_PAYMENT = "car_waitpayment"
    CAR_WAIT_STAFF = "car_wait_staff"


@dataclass
class SubscribePayload:
    """Payload for subscription messages."""
    subscription_identifier: str
    model_id: int
    model_name: str
    model_url: str
    rtsp_url: Optional[str] = None
    snapshot_url: Optional[str] = None
    snapshot_interval: Optional[int] = None
    crop_x1: Optional[int] = None
    crop_y1: Optional[int] = None
    crop_x2: Optional[int] = None
    crop_y2: Optional[int] = None
    extra_params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SubscribePayload":
        """Create a SubscribePayload from a dictionary."""
        # Extract known fields
        known_fields = {
            "subscription_identifier": data.get("subscriptionIdentifier"),
            "model_id": data.get("modelId"),
            "model_name": data.get("modelName"),
            "model_url": data.get("modelUrl"),
            "rtsp_url": data.get("rtspUrl"),
            "snapshot_url": data.get("snapshotUrl"),
            "snapshot_interval": data.get("snapshotInterval"),
            "crop_x1": data.get("cropX1"),
            "crop_y1": data.get("cropY1"),
            "crop_x2": data.get("cropX2"),
            "crop_y2": data.get("cropY2")
        }

        # Extract extra parameters
        extra_params = {k: v for k, v in data.items()
                        if k not in ["subscriptionIdentifier", "modelId", "modelName",
                                     "modelUrl", "rtspUrl", "snapshotUrl", "snapshotInterval",
                                     "cropX1", "cropY1", "cropX2", "cropY2"]}

        return cls(**known_fields, extra_params=extra_params)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format for stream configuration."""
        result = {
            "subscriptionIdentifier": self.subscription_identifier,
            "modelId": self.model_id,
            "modelName": self.model_name,
            "modelUrl": self.model_url
        }

        if self.rtsp_url:
            result["rtspUrl"] = self.rtsp_url
        if self.snapshot_url:
            result["snapshotUrl"] = self.snapshot_url
        if self.snapshot_interval is not None:
            result["snapshotInterval"] = self.snapshot_interval

        if self.crop_x1 is not None:
            result["cropX1"] = self.crop_x1
        if self.crop_y1 is not None:
            result["cropY1"] = self.crop_y1
        if self.crop_x2 is not None:
            result["cropX2"] = self.crop_x2
        if self.crop_y2 is not None:
            result["cropY2"] = self.crop_y2

        # Add any extra parameters
        result.update(self.extra_params)

        return result
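
# Illustrative example (values are hypothetical, not from this commit): a
# doctest-style round trip through SubscribePayload. Unknown camelCase keys
# land in extra_params and are re-emitted by to_dict():
#
#   >>> p = SubscribePayload.from_dict({
#   ...     "subscriptionIdentifier": "display-001;cam-001",
#   ...     "modelId": 42, "modelName": "vehicle", "modelUrl": "http://models/v.pt",
#   ...     "rtspUrl": "rtsp://example/stream", "siteId": 7})
#   >>> p.extra_params
#   {'siteId': 7}
#   >>> p.to_dict()["siteId"]
#   7
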
@dataclass
class SessionPayload:
    """Payload for session-related messages."""
    display_identifier: str
    session_id: str
    data: Optional[Dict[str, Any]] = None


@dataclass
class ProgressionPayload:
    """Payload for progression stage messages."""
    display_identifier: str
    progression_stage: str


class MessageProcessor:
    """
    Processes and validates WebSocket messages.

    This class handles:
    - Message validation and parsing
    - Payload extraction and type conversion
    - Error handling and response generation
    - Message routing to appropriate handlers
    """

    def __init__(self):
        """Initialize the message processor."""
        self.validators: Dict[MessageType, Callable] = {
            MessageType.SUBSCRIBE: self._validate_subscribe,
            MessageType.UNSUBSCRIBE: self._validate_unsubscribe,
            MessageType.SET_SESSION_ID: self._validate_set_session,
            MessageType.PATCH_SESSION: self._validate_patch_session,
            MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage
        }

    def parse_message(self, raw_message: str) -> Tuple[MessageType, Dict[str, Any]]:
        """
        Parse a raw WebSocket message.

        Args:
            raw_message: Raw message string

        Returns:
            Tuple of (MessageType, parsed_data)

        Raises:
            MessageProcessingError: If the message is invalid
        """
        try:
            data = json.loads(raw_message)
        except json.JSONDecodeError as e:
            raise MessageProcessingError(f"Invalid JSON: {e}") from e

        # Validate message structure
        if not isinstance(data, dict):
            raise MessageProcessingError("Message must be a JSON object")

        # Extract message type
        msg_type_str = data.get("type")
        if not msg_type_str:
            raise MessageProcessingError("Missing 'type' field in message")

        # Convert to MessageType enum
        try:
            msg_type = MessageType(msg_type_str)
        except ValueError:
            raise MessageProcessingError(f"Unknown message type: {msg_type_str}")

        return msg_type, data

    def validate_message(self, msg_type: MessageType, data: Dict[str, Any]) -> Any:
        """
        Validate a message payload based on its type.

        Args:
            msg_type: Message type
            data: Message data

        Returns:
            Validated payload object

        Raises:
            ValidationError: If validation fails
        """
        # Get validator for message type
        validator = self.validators.get(msg_type)
        if not validator:
            # No validation needed for some message types
            return data.get("payload", {})

        return validator(data)

    def _validate_subscribe(self, data: Dict[str, Any]) -> SubscribePayload:
        """Validate a subscribe message."""
        payload = data.get("payload", {})

        # Required fields ("name" avoids shadowing dataclasses.field)
        required = ["subscriptionIdentifier", "modelId", "modelName", "modelUrl"]
        missing = [name for name in required if not payload.get(name)]
        if missing:
            raise ValidationError(f"Missing required fields: {', '.join(missing)}")

        # Must have either rtspUrl or snapshotUrl
        if not payload.get("rtspUrl") and not payload.get("snapshotUrl"):
            raise ValidationError("Must provide either rtspUrl or snapshotUrl")

        # Validate snapshot interval if a snapshot URL is provided
        if payload.get("snapshotUrl") and not payload.get("snapshotInterval"):
            raise ValidationError("snapshotInterval is required when using snapshotUrl")

        # Validate crop coordinates
        crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"]
        crop_values = [payload.get(coord) for coord in crop_coords]
        if any(v is not None for v in crop_values):
            # If any crop coordinate is set, all must be set
            if not all(v is not None for v in crop_values):
                raise ValidationError("All crop coordinates must be specified if any are set")

            # Validate coordinate values
            x1, y1, x2, y2 = crop_values
            if x2 <= x1 or y2 <= y1:
                raise ValidationError("Invalid crop coordinates: x2 must be > x1, y2 must be > y1")

        return SubscribePayload.from_dict(payload)
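
    # Illustrative example (hypothetical payload): the crop check above is
    # all-or-nothing, so supplying only cropX1/cropY1 fails validation:
    #
    #   >>> MessageProcessor()._validate_subscribe({"payload": {
    #   ...     "subscriptionIdentifier": "d1;c1", "modelId": 1, "modelName": "m",
    #   ...     "modelUrl": "http://m", "rtspUrl": "rtsp://x", "cropX1": 0, "cropY1": 0}})
    #   Traceback (most recent call last):
    #     ...
    #   ValidationError: All crop coordinates must be specified if any are set
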
    def _validate_unsubscribe(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate an unsubscribe message."""
        payload = data.get("payload", {})

        if not payload.get("subscriptionIdentifier"):
            raise ValidationError("Missing required field: subscriptionIdentifier")

        return payload

    def _validate_set_session(self, data: Dict[str, Any]) -> SessionPayload:
        """Validate a setSessionId message."""
        payload = data.get("payload", {})

        display_id = payload.get("displayIdentifier")
        session_id = payload.get("sessionId")

        if not display_id:
            raise ValidationError("Missing required field: displayIdentifier")
        if not session_id:
            raise ValidationError("Missing required field: sessionId")

        return SessionPayload(display_id, session_id)

    def _validate_patch_session(self, data: Dict[str, Any]) -> SessionPayload:
        """Validate a patchSession message."""
        payload = data.get("payload", {})

        session_id = payload.get("sessionId")
        patch_data = payload.get("data")

        if not session_id:
            raise ValidationError("Missing required field: sessionId")

        # Display identifier is optional for patch
        display_id = payload.get("displayIdentifier", "")

        return SessionPayload(display_id, session_id, patch_data)

    def _validate_progression_stage(self, data: Dict[str, Any]) -> ProgressionPayload:
        """Validate a setProgressionStage message."""
        payload = data.get("payload", {})

        display_id = payload.get("displayIdentifier")
        stage = payload.get("progressionStage")

        if not display_id:
            raise ValidationError("Missing required field: displayIdentifier")
        if not stage:
            raise ValidationError("Missing required field: progressionStage")

        # Warn on unknown stage values but still pass them through
        valid_stages = [s.value for s in ProgressionStage]
        if stage not in valid_stages:
            logger.warning(f"Unknown progression stage: {stage}")

        return ProgressionPayload(display_id, stage)

    def create_ack_response(
        self,
        request_id: Optional[str] = None,
        message: str = "Success",
        data: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Create an acknowledgment response.

        Args:
            request_id: Request ID to acknowledge
            message: Success message
            data: Additional response data

        Returns:
            ACK response dictionary
        """
        response = {
            "type": MessageType.ACK.value,
            "requestId": request_id or str(int(time.time() * 1000)),
            "code": "200",
            "data": {
                "message": message
            }
        }

        if data:
            response["data"].update(data)

        return response

    def create_error_response(
        self,
        request_id: Optional[str] = None,
        code: str = "400",
        message: str = "Error",
        details: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Create an error response.

        Args:
            request_id: Request ID that caused the error
            code: Error code
            message: Error message
            details: Additional error details

        Returns:
            Error response dictionary
        """
        response = {
            "type": MessageType.ERROR.value,
            "requestId": request_id or str(int(time.time() * 1000)),
            "code": code,
            "data": {
                "message": message
            }
        }

        if details:
            response["data"]["details"] = details

        return response
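
    # Illustrative example: the envelope both factories produce; requestId
    # falls back to a millisecond timestamp when the caller supplies none:
    #
    #   >>> MessageProcessor().create_ack_response("req-1", "Subscribed")
    #   {'type': 'ack', 'requestId': 'req-1', 'code': '200', 'data': {'message': 'Subscribed'}}
    #   >>> MessageProcessor().create_error_response("req-2", "404", "Unknown subscription")
    #   {'type': 'error', 'requestId': 'req-2', 'code': '404', 'data': {'message': 'Unknown subscription'}}
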
    def create_detection_message(
        self,
        subscription_id: str,
        detection_result: Optional[Dict[str, Any]],
        model_id: int,
        model_name: str,
        timestamp: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Create a detection result message.

        Args:
            subscription_id: Subscription identifier
            detection_result: Detection result data (None for disconnection)
            model_id: Model ID
            model_name: Model name
            timestamp: Optional timestamp (defaults to current UTC)

        Returns:
            Detection message dictionary
        """
        if timestamp is None:
            timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

        return {
            "type": MessageType.IMAGE_DETECTION.value,
            "subscriptionIdentifier": subscription_id,
            "timestamp": timestamp,
            "data": {
                "detection": detection_result,
                "modelId": model_id,
                "modelName": model_name
            }
        }

    def create_state_report(
        self,
        active_streams: int,
        loaded_models: int,
        cpu_usage: float = 0.0,
        memory_usage: float = 0.0,
        gpu_usage: float = 0.0,
        gpu_memory: float = 0.0,
        uptime: float = 0.0,
        additional_metrics: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Create a state report message.

        Args:
            active_streams: Number of active streams
            loaded_models: Number of loaded models
            cpu_usage: CPU usage percentage
            memory_usage: Memory usage percentage
            gpu_usage: GPU usage percentage
            gpu_memory: GPU memory usage percentage
            uptime: System uptime in seconds
            additional_metrics: Additional metrics to include

        Returns:
            State report dictionary
        """
        report = {
            "type": MessageType.STATE_REPORT.value,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "data": {
                "activeStreams": active_streams,
                "loadedModels": loaded_models,
                "cpuUsage": cpu_usage,
                "memoryUsage": memory_usage,
                "gpuUsage": gpu_usage,
                "gpuMemory": gpu_memory,
                "uptime": uptime
            }
        }

        if additional_metrics:
            report["data"].update(additional_metrics)

        return report

    def extract_camera_id(self, subscription_id: str) -> Tuple[Optional[str], str]:
        """
        Extract display ID and camera ID from a subscription identifier.

        Args:
            subscription_id: Subscription identifier (format: "display_id;camera_id")

        Returns:
            Tuple of (display_id, camera_id)
        """
        parts = subscription_id.split(";")
        if len(parts) >= 2:
            return parts[0], parts[1]
        else:
            return None, subscription_id


# Singleton instance for backward compatibility
_message_processor = MessageProcessor()


# Convenience functions
def parse_websocket_message(raw_message: str) -> Tuple[MessageType, Dict[str, Any]]:
    """Parse a raw WebSocket message."""
    return _message_processor.parse_message(raw_message)


def validate_message_payload(msg_type: MessageType, data: Dict[str, Any]) -> Any:
    """Validate a message payload based on its type."""
    return _message_processor.validate_message(msg_type, data)


def create_ack_response(
    request_id: Optional[str] = None,
    message: str = "Success",
    data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Create an acknowledgment response."""
    return _message_processor.create_ack_response(request_id, message, data)


def create_error_response(
    request_id: Optional[str] = None,
    code: str = "400",
    message: str = "Error",
    details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Create an error response."""
    return _message_processor.create_error_response(request_id, code, message, details)


def create_detection_message(
    subscription_id: str,
    detection_result: Optional[Dict[str, Any]],
    model_id: int,
    model_name: str,
    timestamp: Optional[str] = None
) -> Dict[str, Any]:
    """Create a detection result message."""
    return _message_processor.create_detection_message(
        subscription_id, detection_result, model_id, model_name, timestamp
    )


def create_state_report(
    active_streams: int,
    loaded_models: int,
    **kwargs
) -> Dict[str, Any]:
    """Create a state report message."""
    return _message_processor.create_state_report(
        active_streams, loaded_models, **kwargs
    )
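
A minimal sketch of how a caller might wire the module-level helpers together;
the raw message and identifiers below are made up for illustration:

    from detector_worker.communication.message_processor import (
        parse_websocket_message, validate_message_payload, create_ack_response
    )

    raw = ('{"type": "unsubscribe", "requestId": "req-7", '
           '"payload": {"subscriptionIdentifier": "display-001;cam-001"}}')

    msg_type, data = parse_websocket_message(raw)       # (MessageType.UNSUBSCRIBE, dict)
    payload = validate_message_payload(msg_type, data)  # raises ValidationError if malformed
    ack = create_ack_response(data.get("requestId"), "Unsubscribed")
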
439
detector_worker/communication/response_formatter.py
Normal file
@@ -0,0 +1,439 @@
"""
Response formatter module.

This module handles formatting of detection results and other responses
for WebSocket transmission, ensuring consistent output format.
"""
import json
import logging
import time
from typing import Dict, Any, List, Optional, Union
from datetime import datetime, timezone
from decimal import Decimal

from ..detection.detection_result import DetectionResult

# Setup logging
logger = logging.getLogger("detector_worker.response_formatter")


class ResponseFormatter:
    """
    Formats responses for WebSocket transmission.

    This class handles:
    - Detection result formatting
    - Object serialization
    - Response structure standardization
    - Special type handling (Decimal, datetime, etc.)
    """

    def __init__(self, compact: bool = True):
        """
        Initialize the response formatter.

        Args:
            compact: Whether to use compact JSON formatting
        """
        self.compact = compact
        self.separators = (',', ':') if compact else (',', ': ')

    def format_detection_result(
        self,
        detection_result: Union[DetectionResult, Dict[str, Any], None],
        include_tracking: bool = True,
        include_metadata: bool = True
    ) -> Optional[Dict[str, Any]]:
        """
        Format a detection result for transmission.

        Args:
            detection_result: Detection result to format
            include_tracking: Whether to include tracking information
            include_metadata: Whether to include metadata

        Returns:
            Formatted detection dictionary or None
        """
        if detection_result is None:
            return None

        # Handle DetectionResult objects; shallow-copy plain dicts so the
        # caller's data is not mutated in place
        if isinstance(detection_result, DetectionResult):
            result = detection_result.to_dict()
        else:
            result = dict(detection_result)

        # Format detections
        if "detections" in result:
            result["detections"] = self._format_detections(
                result["detections"],
                include_tracking
            )

        # Format license plate results
        if "license_plate_results" in result:
            result["license_plate_results"] = self._format_license_plates(
                result["license_plate_results"]
            )

        # Format session info
        if "session_info" in result:
            result["session_info"] = self._format_session_info(
                result["session_info"]
            )

        # Remove metadata if not needed
        if not include_metadata:
            result.pop("metadata", None)

        # Clean up None values
        return self._clean_dict(result)
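
    # Illustrative example (hypothetical detection dict): confidences are
    # rounded to four places, bboxes coerced to ints, metadata dropped on request:
    #
    #   >>> ResponseFormatter().format_detection_result(
    #   ...     {"detections": [{"class": "car", "confidence": 0.87654,
    #   ...                      "bbox": [10.2, 20.7, 110.0, 220.0]}],
    #   ...      "metadata": {"frame_id": 1}},
    #   ...     include_metadata=False)
    #   {'detections': [{'class': 'car', 'confidence': 0.8765, 'bbox': [10, 20, 110, 220]}]}
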
    def _format_detections(
        self,
        detections: List[Dict[str, Any]],
        include_tracking: bool = True
    ) -> List[Dict[str, Any]]:
        """Format detection objects."""
        formatted = []

        for detection in detections:
            formatted_det = {
                "class": detection.get("class"),
                "confidence": self._format_float(detection.get("confidence")),
                "bbox": self._format_bbox(detection.get("bbox", []))
            }

            # Add tracking info if available and requested
            if include_tracking:
                if "track_id" in detection:
                    formatted_det["track_id"] = detection["track_id"]
                if "track_age" in detection:
                    formatted_det["track_age"] = detection["track_age"]
                if "stability_score" in detection:
                    formatted_det["stability_score"] = self._format_float(
                        detection["stability_score"]
                    )

            # Add additional properties
            for key, value in detection.items():
                if key not in ["class", "confidence", "bbox", "track_id",
                               "track_age", "stability_score"]:
                    formatted_det[key] = self._format_value(value)

            formatted.append(formatted_det)

        return formatted

    def _format_license_plates(
        self,
        license_plates: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Format license plate results."""
        formatted = []

        for plate in license_plates:
            formatted_plate = {
                "text": plate.get("text", ""),
                "confidence": self._format_float(plate.get("confidence")),
                "bbox": self._format_bbox(plate.get("bbox", []))
            }

            # Add tracking info if available
            if "track_id" in plate:
                formatted_plate["track_id"] = plate["track_id"]

            # Add type if available
            if "type" in plate:
                formatted_plate["type"] = plate["type"]

            formatted.append(formatted_plate)

        return formatted

    def _format_session_info(self, session_info: Dict[str, Any]) -> Dict[str, Any]:
        """Format session information."""
        formatted = {
            "session_id": session_info.get("session_id"),
            "display_id": session_info.get("display_id"),
            "camera_id": session_info.get("camera_id")
        }

        # Add timestamps
        if "created_at" in session_info:
            formatted["created_at"] = self._format_timestamp(
                session_info["created_at"]
            )
        if "updated_at" in session_info:
            formatted["updated_at"] = self._format_timestamp(
                session_info["updated_at"]
            )

        # Add any additional session data
        for key, value in session_info.items():
            if key not in ["session_id", "display_id", "camera_id",
                           "created_at", "updated_at"]:
                formatted[key] = self._format_value(value)

        return self._clean_dict(formatted)

    def _format_bbox(self, bbox: List[Union[int, float]]) -> List[int]:
        """Format bounding box coordinates."""
        if len(bbox) != 4:
            return []
        return [int(coord) for coord in bbox]

    def _format_float(self, value: Union[float, Decimal, None], precision: int = 4) -> Optional[float]:
        """Format floating point values."""
        if value is None:
            return None
        if isinstance(value, Decimal):
            value = float(value)
        return round(float(value), precision)

    def _format_timestamp(self, timestamp: Union[str, datetime, float, None]) -> Optional[str]:
        """Format a timestamp to ISO format."""
        if timestamp is None:
            return None

        # Handle string timestamps
        if isinstance(timestamp, str):
            return timestamp

        # Handle datetime objects
        if isinstance(timestamp, datetime):
            return timestamp.strftime("%Y-%m-%dT%H:%M:%SZ")

        # Handle Unix timestamps; interpret in UTC so the "Z" suffix is accurate
        if isinstance(timestamp, (int, float)):
            dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
            return dt.strftime("%Y-%m-%dT%H:%M:%SZ")

        return str(timestamp)
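
    # Illustrative example: the accepted timestamp shapes all normalize to the
    # same ISO-8601/Z format (Unix timestamps are interpreted as UTC):
    #
    #   >>> f = ResponseFormatter()
    #   >>> f._format_timestamp(0)
    #   '1970-01-01T00:00:00Z'
    #   >>> f._format_timestamp(datetime(2025, 1, 2, 3, 4, 5))
    #   '2025-01-02T03:04:05Z'
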
    def _format_value(self, value: Any) -> Any:
        """Format arbitrary values for JSON serialization."""
        # Handle None
        if value is None:
            return None

        # Handle basic types
        if isinstance(value, (str, int, bool)):
            return value

        # Handle float/Decimal
        if isinstance(value, (float, Decimal)):
            return self._format_float(value)

        # Handle datetime
        if isinstance(value, datetime):
            return self._format_timestamp(value)

        # Handle lists
        if isinstance(value, list):
            return [self._format_value(item) for item in value]

        # Handle dictionaries
        if isinstance(value, dict):
            return {k: self._format_value(v) for k, v in value.items()}

        # Convert to string for other types
        return str(value)

    def _clean_dict(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Remove None values and empty collections from a dictionary."""
        cleaned = {}

        for key, value in data.items():
            # Skip None values
            if value is None:
                continue

            # Skip empty lists
            if isinstance(value, list) and len(value) == 0:
                continue

            # Skip empty dictionaries
            if isinstance(value, dict) and len(value) == 0:
                continue

            # Recursively clean dictionaries
            if isinstance(value, dict):
                cleaned_value = self._clean_dict(value)
                if cleaned_value:
                    cleaned[key] = cleaned_value
            else:
                cleaned[key] = value

        return cleaned
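
    # Illustrative example: None values, empty collections, and dicts that
    # become empty after recursive cleaning are all pruned:
    #
    #   >>> ResponseFormatter()._clean_dict({"a": 1, "b": None, "c": [], "d": {"e": None}})
    #   {'a': 1}
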
    def format_tracking_state(self, tracking_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Format tracking state information.

        Args:
            tracking_state: Tracking state dictionary

        Returns:
            Formatted tracking state
        """
        formatted = {}

        # Format tracker state
        if "active_tracks" in tracking_state:
            formatted["active_tracks"] = len(tracking_state["active_tracks"])

        if "tracks" in tracking_state:
            formatted["tracks"] = []
            for track in tracking_state["tracks"]:
                formatted_track = {
                    "track_id": track.get("track_id"),
                    "class": track.get("class"),
                    "age": track.get("age", 0),
                    "hits": track.get("hits", 0),
                    "misses": track.get("misses", 0)
                }

                if "bbox" in track:
                    formatted_track["bbox"] = self._format_bbox(track["bbox"])

                if "confidence" in track:
                    formatted_track["confidence"] = self._format_float(track["confidence"])

                formatted["tracks"].append(formatted_track)

        return self._clean_dict(formatted)

    def format_error_details(self, error: Exception) -> Dict[str, Any]:
        """
        Format error details for a response.

        Args:
            error: Exception to format

        Returns:
            Formatted error details
        """
        details = {
            "error_type": type(error).__name__,
            "error_message": str(error)
        }

        # Add additional context if available
        if hasattr(error, "details"):
            details["details"] = error.details

        if hasattr(error, "code"):
            details["error_code"] = error.code

        return details

    def to_json(self, data: Any) -> str:
        """
        Convert data to a JSON string.

        Args:
            data: Data to serialize

        Returns:
            JSON string
        """
        return json.dumps(
            data,
            separators=self.separators,
            default=self._json_default
        )

    def _json_default(self, obj: Any) -> Any:
        """Default JSON serialization for custom types."""
        # Handle Decimal
        if isinstance(obj, Decimal):
            return float(obj)

        # Handle datetime
        if isinstance(obj, datetime):
            return obj.strftime("%Y-%m-%dT%H:%M:%SZ")

        # Handle bytes
        if isinstance(obj, bytes):
            return obj.decode('utf-8', errors='ignore')

        # Default to string representation
        return str(obj)

    def format_pipeline_result(
        self,
        pipeline_result: Dict[str, Any],
        include_intermediate: bool = False
    ) -> Dict[str, Any]:
        """
        Format a pipeline execution result.

        Args:
            pipeline_result: Pipeline result dictionary
            include_intermediate: Whether to include intermediate results

        Returns:
            Formatted pipeline result
        """
        formatted = {
            "success": pipeline_result.get("success", False),
            "pipeline_id": pipeline_result.get("pipeline_id"),
            "execution_time": self._format_float(
                pipeline_result.get("execution_time", 0)
            )
        }

        # Add final detection result
        if "detection_result" in pipeline_result:
            formatted["detection_result"] = self.format_detection_result(
                pipeline_result["detection_result"]
            )

        # Add branch results
        if "branch_results" in pipeline_result and include_intermediate:
            formatted["branch_results"] = {}
            for branch_id, result in pipeline_result["branch_results"].items():
                formatted["branch_results"][branch_id] = self._format_value(result)

        # Add executed actions
        if "executed_actions" in pipeline_result:
            formatted["executed_actions"] = pipeline_result["executed_actions"]

        # Add errors if any
        if "errors" in pipeline_result:
            formatted["errors"] = [
                self.format_error_details(error) if isinstance(error, Exception)
                else error
                for error in pipeline_result["errors"]
            ]

        return self._clean_dict(formatted)


# Singleton instance
_formatter = ResponseFormatter()


# Convenience functions for backward compatibility
def format_detection_for_websocket(
    detection_result: Union[DetectionResult, Dict[str, Any], None],
    include_tracking: bool = True,
    include_metadata: bool = True
) -> Optional[Dict[str, Any]]:
    """Format a detection result for WebSocket transmission."""
    return _formatter.format_detection_result(
        detection_result, include_tracking, include_metadata
    )


def format_tracking_state(tracking_state: Dict[str, Any]) -> Dict[str, Any]:
    """Format tracking state information."""
    return _formatter.format_tracking_state(tracking_state)


def format_error_response(error: Exception) -> Dict[str, Any]:
    """Format error details for a response."""
    return _formatter.format_error_details(error)


def to_compact_json(data: Any) -> str:
    """Convert data to a compact JSON string."""
    return _formatter.to_json(data)
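
A short usage sketch of the compact serializer; the Decimal and bytes values
are illustrative:

    from decimal import Decimal
    from detector_worker.communication.response_formatter import to_compact_json

    to_compact_json({"confidence": Decimal("0.91"), "raw": b"abc"})
    # -> '{"confidence":0.91,"raw":"abc"}'
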
622
detector_worker/communication/websocket_handler.py
Normal file
@@ -0,0 +1,622 @@
"""
WebSocket handler module.

This module manages WebSocket connections, message processing, heartbeat functionality,
and coordination between stream processing and detection pipelines.
"""
import asyncio
import json
import logging
import time
import traceback
import uuid
from typing import Dict, Any, Optional, Callable, List, Set, Awaitable
from contextlib import asynccontextmanager

from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError

from ..core.config import config, subscription_to_camera, latest_frames
from ..core.constants import HEARTBEAT_INTERVAL
from ..core.exceptions import WebSocketError, StreamError
from ..streams.stream_manager import StreamManager
from ..streams.camera_monitor import CameraConnectionMonitor
from ..detection.detection_result import DetectionResult
from ..models.model_manager import ModelManager
from ..pipeline.pipeline_executor import PipelineExecutor
from ..storage.session_cache import SessionCache
from ..storage.redis_client import RedisClientManager
from ..utils.system_monitor import get_system_metrics

# Setup logging
logger = logging.getLogger("detector_worker.websocket_handler")
ws_logger = logging.getLogger("websocket")

# Type definitions for callbacks (asyncio.coroutine is a deprecated decorator,
# not a type; Awaitable is the correct annotation for async handlers)
MessageHandler = Callable[[Dict[str, Any]], Awaitable[None]]
DetectionHandler = Callable[[str, Dict[str, Any], Any, WebSocket, Any, Dict[str, Any]], Awaitable[Dict[str, Any]]]


class WebSocketHandler:
    """
    Manages WebSocket connections and message processing for the detection worker.

    This class handles:
    - WebSocket lifecycle management
    - Message routing and processing
    - Heartbeat/state reporting
    - Stream subscription management
    - Detection result forwarding
    - Session management
    """

    def __init__(
        self,
        stream_manager: StreamManager,
        model_manager: ModelManager,
        pipeline_executor: PipelineExecutor,
        session_cache: SessionCache,
        redis_client: Optional[RedisClientManager] = None
    ):
        """
        Initialize the WebSocket handler.

        Args:
            stream_manager: Manager for camera streams
            model_manager: Manager for ML models
            pipeline_executor: Pipeline execution engine
            session_cache: Session state cache
            redis_client: Optional Redis client for pub/sub
        """
        self.stream_manager = stream_manager
        self.model_manager = model_manager
        self.pipeline_executor = pipeline_executor
        self.session_cache = session_cache
        self.redis_client = redis_client

        # Connection state
        self.websocket: Optional[WebSocket] = None
        self.connected: bool = False
        self.tasks: List[asyncio.Task] = []

        # Message handlers
        self.message_handlers: Dict[str, MessageHandler] = {
            "subscribe": self._handle_subscribe,
            "unsubscribe": self._handle_unsubscribe,
            "requestState": self._handle_request_state,
            "setSessionId": self._handle_set_session,
            "patchSession": self._handle_patch_session,
            "setProgressionStage": self._handle_set_progression_stage
        }

        # Session and display management
        self.session_ids: Dict[str, str] = {}  # display_identifier -> session_id
        self.display_identifiers: Set[str] = set()

        # Camera monitor
        self.camera_monitor = CameraConnectionMonitor()

    async def handle_connection(self, websocket: WebSocket) -> None:
        """
        Main entry point for handling a WebSocket connection.

        Args:
            websocket: The WebSocket connection to handle
        """
        try:
            await websocket.accept()
            self.websocket = websocket
            self.connected = True

            logger.info("WebSocket connection accepted")

            # Create concurrent tasks
            stream_task = asyncio.create_task(self._process_streams())
            heartbeat_task = asyncio.create_task(self._send_heartbeat())
            message_task = asyncio.create_task(self._process_messages())

            self.tasks = [stream_task, heartbeat_task, message_task]

            # Wait for the heartbeat and message tasks; the stream task is
            # cancelled in _cleanup()
            await asyncio.gather(heartbeat_task, message_task)

        except Exception as e:
            logger.error(f"Error in WebSocket handler: {e}")

        finally:
            self.connected = False
            await self._cleanup()

    async def _cleanup(self) -> None:
        """Clean up resources when the connection closes."""
        logger.info("Cleaning up WebSocket connection")

        # Cancel all tasks
        for task in self.tasks:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Clean up streams
        await self.stream_manager.cleanup_all_streams()

        # Clean up models
        self.model_manager.cleanup_all_models()

        # Clear session state
        self.session_cache.clear_all_sessions()
        self.session_ids.clear()
        self.display_identifiers.clear()

        # Clear camera states
        self.camera_monitor.clear_all_states()

        logger.info("WebSocket cleanup completed")

    async def _send_state_report(self) -> None:
        """Build and send a single state report."""
        # Get system metrics
        metrics = get_system_metrics()

        # Get active streams info
        active_streams = self.stream_manager.get_active_streams()
        active_models = self.model_manager.get_loaded_models()

        state_data = {
            "type": "stateReport",
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "data": {
                "activeStreams": len(active_streams),
                "loadedModels": len(active_models),
                "cpuUsage": metrics.get("cpu_percent", 0),
                "memoryUsage": metrics.get("memory_percent", 0),
                "gpuUsage": metrics.get("gpu_percent", 0),
                "gpuMemory": metrics.get("gpu_memory_percent", 0),
                "uptime": time.time() - metrics.get("start_time", time.time())
            }
        }

        ws_logger.info(f"TX -> {json.dumps(state_data, separators=(',', ':'))}")
        await self.websocket.send_json(state_data)

    async def _send_heartbeat(self) -> None:
        """Send periodic heartbeat/state reports to maintain the connection."""
        while self.connected:
            try:
                await self._send_state_report()
                await asyncio.sleep(HEARTBEAT_INTERVAL)

            except (WebSocketDisconnect, ConnectionClosedError):
                logger.info("WebSocket disconnected during heartbeat")
                break
            except Exception as e:
                logger.error(f"Error sending heartbeat: {e}")
                break
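
    # Note: the state report assumes get_system_metrics() returns at least the
    # keys read above. A stub with this shape (values are made up) is enough
    # to exercise the handler locally:
    #
    #   def get_system_metrics():
    #       return {"cpu_percent": 12.5, "memory_percent": 41.0, "gpu_percent": 0.0,
    #               "gpu_memory_percent": 0.0, "start_time": 1700000000.0}
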
    async def _process_messages(self) -> None:
        """Process incoming WebSocket messages."""
        while self.connected:
            try:
                text_data = await self.websocket.receive_text()
                ws_logger.info(f"RX <- {text_data}")

                data = json.loads(text_data)
                msg_type = data.get("type")

                if msg_type in self.message_handlers:
                    handler = self.message_handlers[msg_type]
                    await handler(data)
                else:
                    logger.error(f"Unknown message type: {msg_type}")

            except json.JSONDecodeError:
                logger.error("Received invalid JSON message")
            except (WebSocketDisconnect, ConnectionClosedError) as e:
                logger.warning(f"WebSocket disconnected: {e}")
                break
            except Exception as e:
                logger.error(f"Error handling message: {e}")
                traceback.print_exc()
                break

    async def _process_streams(self) -> None:
        """Process active camera streams and run detection pipelines."""
        while self.connected:
            try:
                active_streams = self.stream_manager.get_active_streams()

                if active_streams:
                    # Process each active stream; remember which camera each
                    # task belongs to, since cameras without a frame or model
                    # are skipped and would otherwise shift the result indices
                    tasks = []
                    task_camera_ids = []
                    for camera_id, stream_info in active_streams.items():
                        # Get latest frame
                        frame = self.stream_manager.get_latest_frame(camera_id)
                        if frame is None:
                            continue

                        # Get model for this camera
                        model_id = stream_info.get("modelId")
                        if not model_id:
                            continue

                        model_tree = self.model_manager.get_model(camera_id, model_id)
                        if not model_tree:
                            continue

                        # Create detection task
                        persistent_data = self.session_cache.get_persistent_data(camera_id)
                        task = asyncio.create_task(
                            self._handle_detection(
                                camera_id, stream_info, frame,
                                self.websocket, model_tree, persistent_data
                            )
                        )
                        tasks.append(task)
                        task_camera_ids.append(camera_id)

                    # Wait for all detection tasks
                    if tasks:
                        results = await asyncio.gather(*tasks, return_exceptions=True)

                        # Update persistent data for the cameras that actually ran
                        for camera_id, result in zip(task_camera_ids, results):
                            if isinstance(result, dict):
                                self.session_cache.update_persistent_data(camera_id, result)

                # Polling interval
                poll_interval = config.get("poll_interval_ms", 100) / 1000.0
                await asyncio.sleep(poll_interval)

            except asyncio.CancelledError:
                logger.info("Stream processing cancelled")
                break
            except Exception as e:
                logger.error(f"Error in stream processing: {e}")
                await asyncio.sleep(1)
    async def _handle_detection(
        self,
        camera_id: str,
        stream_info: Dict[str, Any],
        frame: Any,
        websocket: WebSocket,
        model_tree: Any,
        persistent_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Handle detection for a single camera frame.

        Args:
            camera_id: Camera identifier
            stream_info: Stream configuration
            frame: Video frame to process
            websocket: WebSocket connection
            model_tree: Model pipeline tree
            persistent_data: Persistent data for this camera

        Returns:
            Updated persistent data
        """
        try:
            # Check camera connection state
            if self.camera_monitor.should_notify_disconnection(camera_id):
                await self._send_disconnection_notification(camera_id, stream_info)
                return persistent_data

            # Apply crop if specified
            cropped_frame = self._apply_crop(frame, stream_info)

            # Get session pipeline state
            pipeline_state = self.session_cache.get_session_pipeline_state(camera_id)

            # Run detection pipeline
            detection_result = await self.pipeline_executor.execute_pipeline(
                camera_id,
                stream_info,
                cropped_frame,
                model_tree,
                persistent_data,
                pipeline_state
            )

            # Send detection result
            if detection_result:
                await self._send_detection_result(
                    camera_id, stream_info, detection_result
                )

            # Handle camera reconnection
            if self.camera_monitor.should_notify_reconnection(camera_id):
                self.camera_monitor.mark_reconnection_notified(camera_id)
                logger.info(f"Camera {camera_id} reconnected successfully")

            return persistent_data

        except Exception as e:
            logger.error(f"Error in detection handling for camera {camera_id}: {e}")
            traceback.print_exc()
            return persistent_data

    async def _handle_subscribe(self, data: Dict[str, Any]) -> None:
        """Handle a stream subscription request."""
        payload = data.get("payload", {})
        subscription_id = payload.get("subscriptionIdentifier")

        if not subscription_id:
            logger.error("Missing subscriptionIdentifier in subscribe payload")
            return

        try:
            # Extract display and camera IDs
            parts = subscription_id.split(";")
            if len(parts) >= 2:
                display_id = parts[0]
                camera_id = parts[1]
                self.display_identifiers.add(display_id)
            else:
                camera_id = subscription_id

            # Store subscription mapping
            subscription_to_camera[subscription_id] = camera_id

            # Start camera stream
            await self.stream_manager.start_stream(camera_id, payload)

            # Load model
            model_id = payload.get("modelId")
            model_url = payload.get("modelUrl")
            if model_id and model_url:
                await self.model_manager.load_model(camera_id, model_id, model_url)

            logger.info(f"Subscribed to stream: {subscription_id}")

        except Exception as e:
            logger.error(f"Error handling subscription: {e}")
            traceback.print_exc()

    async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
        """Handle a stream unsubscription request."""
        payload = data.get("payload", {})
        subscription_id = payload.get("subscriptionIdentifier")

        if not subscription_id:
            logger.error("Missing subscriptionIdentifier in unsubscribe payload")
            return

        try:
            # Get camera ID from subscription
            camera_id = subscription_to_camera.get(subscription_id)
            if not camera_id:
                logger.warning(f"No camera found for subscription: {subscription_id}")
                return

            # Stop stream
            await self.stream_manager.stop_stream(camera_id)

            # Unload model
            self.model_manager.unload_models(camera_id)

            # Clean up mappings
            subscription_to_camera.pop(subscription_id, None)

            # Clean up session state
            self.session_cache.clear_session(camera_id)

            logger.info(f"Unsubscribed from stream: {subscription_id}")

        except Exception as e:
            logger.error(f"Error handling unsubscription: {e}")
            traceback.print_exc()

    async def _handle_request_state(self, data: Dict[str, Any]) -> None:
        """Handle a state request message."""
        # Send an immediate one-shot state report; calling _send_heartbeat()
        # here would start a second endless heartbeat loop
        await self._send_state_report()
    async def _handle_set_session(self, data: Dict[str, Any]) -> None:
        """Handle a setSessionId message."""
        payload = data.get("payload", {})
        display_id = payload.get("displayIdentifier")
        session_id = payload.get("sessionId")

        if display_id and session_id:
            self.session_ids[display_id] = session_id

            # Update session for all cameras of this display
            with self.stream_manager.streams_lock:
                for camera_id, stream in self.stream_manager.streams.items():
                    if stream["subscriptionIdentifier"].startswith(display_id + ";"):
                        self.session_cache.update_session_id(camera_id, session_id)

            # Send acknowledgment
            response = {
                "type": "ack",
                "requestId": data.get("requestId", str(uuid.uuid4())),
                "code": "200",
                "data": {
                    "message": f"Session ID set for display {display_id}",
                    "sessionId": session_id
                }
            }
            ws_logger.info(f"TX -> {json.dumps(response, separators=(',', ':'))}")
            await self.websocket.send_json(response)

            logger.info(f"Set session {session_id} for display {display_id}")

    async def _handle_patch_session(self, data: Dict[str, Any]) -> None:
        """Handle a patchSession message."""
        payload = data.get("payload", {})
        session_id = payload.get("sessionId")
        patch_data = payload.get("data", {})

        if session_id:
            # Store patch data (could be used for session updates)
            logger.info(f"Received patch for session {session_id}: {patch_data}")

            # Send acknowledgment
            response = {
                "type": "ack",
                "requestId": data.get("requestId", str(uuid.uuid4())),
                "code": "200",
                "data": {
                    "message": f"Session {session_id} patched successfully",
                    "sessionId": session_id,
                    "patchData": patch_data
                }
            }
            ws_logger.info(f"TX -> {json.dumps(response, separators=(',', ':'))}")
            await self.websocket.send_json(response)

    async def _handle_set_progression_stage(self, data: Dict[str, Any]) -> None:
        """Handle a setProgressionStage message."""
        payload = data.get("payload", {})
        display_id = payload.get("displayIdentifier")
        progression_stage = payload.get("progressionStage")

        logger.info(f"🏁 PROGRESSION STAGE RECEIVED: displayId={display_id}, stage={progression_stage}")

        if not display_id:
            logger.warning("Missing displayIdentifier in setProgressionStage")
            return

        # Find all cameras for this display
        affected_cameras = []
        with self.stream_manager.streams_lock:
            for camera_id, stream in self.stream_manager.streams.items():
                if stream["subscriptionIdentifier"].startswith(display_id + ";"):
                    affected_cameras.append(camera_id)

        logger.debug(f"🎯 Found {len(affected_cameras)} cameras for display {display_id}: {affected_cameras}")

        # Update progression stage for each camera
        for camera_id in affected_cameras:
            pipeline_state = self.session_cache.get_or_init_session_pipeline_state(camera_id)
            current_mode = pipeline_state.get("mode", "validation_detecting")

            if progression_stage == "car_fueling":
                # Stop YOLO inference during fueling
                if current_mode == "lightweight":
                    pipeline_state["yolo_inference_enabled"] = False
                    pipeline_state["progression_stage"] = "car_fueling"
                    logger.info(f"⏸️ Camera {camera_id}: YOLO inference DISABLED for car_fueling stage")
                else:
                    logger.debug(f"📊 Camera {camera_id}: car_fueling received but not in lightweight mode (mode: {current_mode})")

            elif progression_stage == "car_waitpayment":
                # Resume YOLO inference for the absence counter
                pipeline_state["yolo_inference_enabled"] = True
                pipeline_state["progression_stage"] = "car_waitpayment"
                logger.info(f"▶️ Camera {camera_id}: YOLO inference RE-ENABLED for car_waitpayment stage")

            elif progression_stage == "welcome":
                # Ignore welcome messages during car_waitpayment
                current_progression = pipeline_state.get("progression_stage")
                if current_progression == "car_waitpayment":
                    logger.info(f"🚫 Camera {camera_id}: IGNORING welcome stage (currently in car_waitpayment)")
                else:
                    pipeline_state["progression_stage"] = "welcome"
                    logger.info(f"🎉 Camera {camera_id}: Progression stage set to welcome")

            elif progression_stage == "car_wait_staff":
                pipeline_state["progression_stage"] = progression_stage
                logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}")
    async def _send_detection_result(
        self,
        camera_id: str,
        stream_info: Dict[str, Any],
        detection_result: DetectionResult
    ) -> None:
        """Send detection result over WebSocket."""
        detection_data = {
            "type": "imageDetection",
            "subscriptionIdentifier": stream_info["subscriptionIdentifier"],
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "data": {
                "detection": detection_result.to_dict(),
                "modelId": stream_info["modelId"],
                "modelName": stream_info["modelName"]
            }
        }

        try:
            ws_logger.info(f"TX -> {json.dumps(detection_data, separators=(',', ':'))}")
            await self.websocket.send_json(detection_data)
        except RuntimeError as e:
            if "websocket.close" in str(e):
                logger.warning(f"WebSocket closed - cannot send detection for camera {camera_id}")
            else:
                raise

    async def _send_disconnection_notification(
        self,
        camera_id: str,
        stream_info: Dict[str, Any]
    ) -> None:
        """Send camera disconnection notification."""
        logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null")

        # Clear cached data
        self.session_cache.clear_session(camera_id)

        # Send null detection
        detection_data = {
            "type": "imageDetection",
            "subscriptionIdentifier": stream_info["subscriptionIdentifier"],
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "data": {
                "detection": None,
                "modelId": stream_info["modelId"],
                "modelName": stream_info["modelName"]
            }
        }

        try:
            ws_logger.info(f"TX -> {json.dumps(detection_data, separators=(',', ':'))}")
            await self.websocket.send_json(detection_data)
        except RuntimeError as e:
            if "websocket.close" in str(e):
                logger.warning(f"WebSocket closed - cannot send disconnection signal for camera {camera_id}")
            else:
                raise

        self.camera_monitor.mark_disconnection_notified(camera_id)
        logger.info(f"📡 SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session")

    def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any:
        """Apply crop to frame if crop coordinates are specified."""
        crop_coords = [
            stream_info.get("cropX1"),
            stream_info.get("cropY1"),
            stream_info.get("cropX2"),
            stream_info.get("cropY2")
        ]

        if all(coord is not None for coord in crop_coords):
            x1, y1, x2, y2 = crop_coords
            return frame[y1:y2, x1:x2]

        return frame

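    # Example (illustrative): with cropX1=100, cropY1=50, cropX2=300, cropY2=250
    # in stream_info, _apply_crop returns frame[50:250, 100:300]; note that numpy
    # frames index rows (y) first, then columns (x).
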
# Convenience function for backward compatibility
async def handle_websocket_connection(
    websocket: WebSocket,
    stream_manager: StreamManager,
    model_manager: ModelManager,
    pipeline_executor: PipelineExecutor,
    session_cache: SessionCache,
    redis_client: Optional[RedisClientManager] = None
) -> None:
    """
    Handle a WebSocket connection using the WebSocketHandler.

    This is a convenience function that creates a handler instance
    and processes the connection.
    """
    handler = WebSocketHandler(
        stream_manager,
        model_manager,
        pipeline_executor,
        session_cache,
        redis_client
    )
    await handler.handle_connection(websocket)
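A minimal sketch of wiring this convenience function into a FastAPI WebSocket
endpoint. The endpoint path, app setup, and manager instances are illustrative
assumptions, not part of this commit; whether handle_connection() accepts the
socket itself is also not shown here:

    from fastapi import FastAPI, WebSocket

    app = FastAPI()

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        # Managers would normally be constructed once at application startup.
        # handle_connection() may perform websocket.accept() itself; if not,
        # accept here before delegating.
        await handle_websocket_connection(
            websocket,
            stream_manager,
            model_manager,
            pipeline_executor,
            session_cache,
            redis_client=None,
        )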
441
detector_worker/models/model_manager.py
Normal file
@@ -0,0 +1,441 @@
"""
Model manager module.

This module handles ML model loading, caching, and lifecycle management
for the detection worker.
"""
import os
import logging
import threading
from typing import Dict, Any, Optional, List, Set, Tuple
from urllib.parse import urlparse
import traceback

from ..core.config import MODELS_DIR
from ..core.exceptions import ModelLoadError

# Setup logging
logger = logging.getLogger("detector_worker.model_manager")


class ModelRegistry:
    """
    Registry for loaded models.

    Maintains a reference count for each model to enable sharing
    between multiple cameras.
    """

    def __init__(self):
        """Initialize the model registry."""
        self.models: Dict[str, Dict[str, Any]] = {}  # model_id -> model_info
        self.references: Dict[str, Set[str]] = {}  # model_id -> set of camera_ids
        self.lock = threading.Lock()

    def register_model(self, model_id: str, model_data: Any, model_path: str) -> None:
        """
        Register a model in the registry.

        Args:
            model_id: Unique model identifier
            model_data: Loaded model data
            model_path: Path to model file
        """
        with self.lock:
            self.models[model_id] = {
                "model": model_data,
                "path": model_path,
                "loaded_at": os.path.getmtime(model_path)
            }
            if model_id not in self.references:
                self.references[model_id] = set()

    def add_reference(self, model_id: str, camera_id: str) -> None:
        """Add a reference to a model from a camera."""
        with self.lock:
            if model_id in self.references:
                self.references[model_id].add(camera_id)

    def remove_reference(self, model_id: str, camera_id: str) -> bool:
        """
        Remove a reference to a model from a camera.

        Returns:
            True if model has no more references and can be unloaded
        """
        with self.lock:
            if model_id in self.references:
                self.references[model_id].discard(camera_id)
                return len(self.references[model_id]) == 0
            return True

    def get_model(self, model_id: str) -> Optional[Any]:
        """Get a model from the registry."""
        with self.lock:
            model_info = self.models.get(model_id)
            return model_info["model"] if model_info else None

    def unregister_model(self, model_id: str) -> None:
        """Remove a model from the registry."""
        with self.lock:
            self.models.pop(model_id, None)
            self.references.pop(model_id, None)

    def get_loaded_models(self) -> List[str]:
        """Get list of loaded model IDs."""
        with self.lock:
            return list(self.models.keys())

    def get_reference_count(self, model_id: str) -> int:
        """Get number of references to a model."""
        with self.lock:
            return len(self.references.get(model_id, set()))

    def clear(self) -> None:
        """Clear all models from registry."""
        with self.lock:
            self.models.clear()
            self.references.clear()

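# Example (illustrative, not part of the module): reference counting lets two
# cameras share one loaded model, and remove_reference() reports when the last
# user is gone so the model can safely be unloaded:
#
#   registry = ModelRegistry()
#   registry.register_model("detector-v1", model, "/models/detector-v1.mpta")
#   registry.add_reference("detector-v1", "cam-1")
#   registry.add_reference("detector-v1", "cam-2")
#   registry.remove_reference("detector-v1", "cam-1")  # False - cam-2 still holds it
#   registry.remove_reference("detector-v1", "cam-2")  # True - safe to unload
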
class ModelManager:
    """
    Manages ML model loading, caching, and lifecycle.

    This class handles:
    - Model downloading and caching
    - Model loading with proper error handling
    - Reference counting for model sharing
    - Model cleanup and memory management
    - Pipeline model tree management
    """

    def __init__(self, models_dir: str = MODELS_DIR):
        """
        Initialize the model manager.

        Args:
            models_dir: Directory to cache downloaded models
        """
        self.models_dir = models_dir
        self.registry = ModelRegistry()
        self.models_lock = threading.Lock()

        # Camera to models mapping
        self.camera_models: Dict[str, Dict[str, Any]] = {}  # camera_id -> {model_id -> model_tree}

        # Pipeline loader will be injected
        self.pipeline_loader = None

        # Create models directory if it doesn't exist
        os.makedirs(self.models_dir, exist_ok=True)

    def set_pipeline_loader(self, pipeline_loader: Any) -> None:
        """
        Set the pipeline loader instance.

        Args:
            pipeline_loader: Pipeline loader to use for loading models
        """
        self.pipeline_loader = pipeline_loader

    async def load_model(
        self,
        camera_id: str,
        model_id: str,
        model_url: str,
        force_reload: bool = False
    ) -> Any:
        """
        Load a model for a specific camera.

        Args:
            camera_id: Camera identifier
            model_id: Model identifier
            model_url: URL or path to model file
            force_reload: Force reload even if cached

        Returns:
            Loaded model tree

        Raises:
            ModelLoadError: If model loading fails
        """
        if not self.pipeline_loader:
            raise ModelLoadError("Pipeline loader not initialized")

        try:
            # Check if model is already loaded for this camera
            with self.models_lock:
                if camera_id in self.camera_models and model_id in self.camera_models[camera_id]:
                    if not force_reload:
                        logger.info(f"Model {model_id} already loaded for camera {camera_id}")
                        return self.camera_models[camera_id][model_id]

            # Check if model is in registry
            cached_model = self.registry.get_model(model_id)
            if cached_model and not force_reload:
                # Add reference and return cached model
                self.registry.add_reference(model_id, camera_id)
                with self.models_lock:
                    if camera_id not in self.camera_models:
                        self.camera_models[camera_id] = {}
                    self.camera_models[camera_id][model_id] = cached_model
                logger.info(f"Using cached model {model_id} for camera {camera_id}")
                return cached_model

            # Download or locate model file
            model_path = await self._get_model_path(model_url, model_id)

            # Load model using pipeline loader
            logger.info(f"Loading model {model_id} from {model_path}")
            model_tree = await self.pipeline_loader.load_pipeline(model_path)

            # Register in registry
            self.registry.register_model(model_id, model_tree, model_path)
            self.registry.add_reference(model_id, camera_id)

            # Store in camera models
            with self.models_lock:
                if camera_id not in self.camera_models:
                    self.camera_models[camera_id] = {}
                self.camera_models[camera_id][model_id] = model_tree

            logger.info(f"Successfully loaded model {model_id} for camera {camera_id}")
            return model_tree

        except Exception as e:
            logger.error(f"Failed to load model {model_id}: {e}")
            traceback.print_exc()
            raise ModelLoadError(f"Failed to load model {model_id}: {e}")

    async def _get_model_path(self, model_url: str, model_id: str) -> str:
        """
        Get local path for a model, downloading if necessary.

        Args:
            model_url: URL or local path to model
            model_id: Model identifier

        Returns:
            Local file path to model
        """
        # Check if it's already a local path
        if os.path.exists(model_url):
            return model_url

        # Parse URL
        parsed = urlparse(model_url)

        # Check if it's a file:// URL
        if parsed.scheme == 'file':
            return parsed.path

        # For HTTP/HTTPS URLs, download to cache
        if parsed.scheme in ['http', 'https']:
            # Generate cache filename
            filename = os.path.basename(parsed.path)
            if not filename:
                filename = f"{model_id}.mpta"

            cache_path = os.path.join(self.models_dir, filename)

            # Check if already cached
            if os.path.exists(cache_path):
                logger.info(f"Using cached model file: {cache_path}")
                return cache_path

            # Download model
            logger.info(f"Downloading model from {model_url}")
            await self._download_model(model_url, cache_path)
            return cache_path

        # For other schemes or no scheme, assume local path
        return model_url

    async def _download_model(self, url: str, destination: str) -> None:
        """
        Download a model file from URL.

        Args:
            url: URL to download from
            destination: Local path to save to
        """
        import aiohttp
        import aiofiles

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    response.raise_for_status()

                    # Get total size if available
                    total_size = response.headers.get('Content-Length')
                    if total_size:
                        total_size = int(total_size)
                        logger.info(f"Downloading {total_size / (1024*1024):.2f} MB")

                    # Download to temporary file first
                    temp_path = f"{destination}.tmp"
                    downloaded = 0

                    async with aiofiles.open(temp_path, 'wb') as f:
                        async for chunk in response.content.iter_chunked(8192):
                            await f.write(chunk)
                            downloaded += len(chunk)

                            # Log progress
                            if total_size and downloaded % (1024 * 1024) == 0:
                                progress = (downloaded / total_size) * 100
                                logger.info(f"Download progress: {progress:.1f}%")

                    # Move to final destination
                    os.rename(temp_path, destination)
                    logger.info(f"Model downloaded successfully to {destination}")

        except Exception as e:
            # Clean up temporary file if exists
            temp_path = f"{destination}.tmp"
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise ModelLoadError(f"Failed to download model: {e}")

    def get_model(self, camera_id: str, model_id: str) -> Optional[Any]:
        """
        Get a loaded model for a camera.

        Args:
            camera_id: Camera identifier
            model_id: Model identifier

        Returns:
            Model tree if loaded, None otherwise
        """
        with self.models_lock:
            camera_models = self.camera_models.get(camera_id, {})
            return camera_models.get(model_id)

    def unload_models(self, camera_id: str) -> None:
        """
        Unload all models for a camera.

        Args:
            camera_id: Camera identifier
        """
        with self.models_lock:
            if camera_id not in self.camera_models:
                return

            # Remove references for each model
            for model_id in self.camera_models[camera_id]:
                should_unload = self.registry.remove_reference(model_id, camera_id)

                if should_unload:
                    logger.info(f"Unloading model {model_id} (no more references)")
                    self.registry.unregister_model(model_id)

                    # Clean up model if pipeline loader supports it
                    if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_model'):
                        try:
                            self.pipeline_loader.cleanup_model(model_id)
                        except Exception as e:
                            logger.error(f"Error cleaning up model {model_id}: {e}")

            # Remove camera entry
            del self.camera_models[camera_id]
            logger.info(f"Unloaded all models for camera {camera_id}")

    def cleanup_all_models(self) -> None:
        """Clean up all loaded models."""
        logger.info("Cleaning up all loaded models")

        # Get list of cameras to clean up. unload_models() acquires models_lock
        # itself, so the lock must be released before calling it (threading.Lock
        # is not reentrant and holding it here would deadlock).
        with self.models_lock:
            cameras = list(self.camera_models.keys())

        # Unload models for each camera
        for camera_id in cameras:
            self.unload_models(camera_id)

        # Clear registry
        self.registry.clear()

        # Clean up pipeline loader if it has cleanup
        if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_all'):
            try:
                self.pipeline_loader.cleanup_all()
            except Exception as e:
                logger.error(f"Error in pipeline loader cleanup: {e}")

        logger.info("Model cleanup completed")

    def get_loaded_models(self) -> Dict[str, List[str]]:
        """
        Get information about loaded models.

        Returns:
            Dictionary mapping model IDs to list of camera IDs using them
        """
        result = {}

        with self.models_lock:
            for model_id in self.registry.get_loaded_models():
                cameras = []
                for camera_id, models in self.camera_models.items():
                    if model_id in models:
                        cameras.append(camera_id)
                result[model_id] = cameras

        return result

    def get_model_stats(self) -> Dict[str, Any]:
        """
        Get statistics about loaded models.

        Returns:
            Dictionary with model statistics
        """
        with self.models_lock:
            total_models = len(self.registry.get_loaded_models())
            total_cameras = len(self.camera_models)

            # Count total model instances
            total_instances = sum(
                len(models) for models in self.camera_models.values()
            )

            # Get cache size
            cache_size = 0
            if os.path.exists(self.models_dir):
                for filename in os.listdir(self.models_dir):
                    filepath = os.path.join(self.models_dir, filename)
                    if os.path.isfile(filepath):
                        cache_size += os.path.getsize(filepath)

            return {
                "total_models": total_models,
                "total_cameras": total_cameras,
                "total_instances": total_instances,
                "cache_size_mb": round(cache_size / (1024 * 1024), 2),
                "models_dir": self.models_dir
            }


# Global model manager instance
_model_manager = None


def get_model_manager() -> ModelManager:
    """Get or create the global model manager instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager


# Convenience functions for backward compatibility
def initialize_model_manager(models_dir: str = MODELS_DIR) -> ModelManager:
    """Initialize the global model manager."""
    global _model_manager
    _model_manager = ModelManager(models_dir)
    return _model_manager
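A short usage sketch for the manager above. The model URL and identifiers are
placeholder assumptions, and load_model() is a coroutine, so it must be awaited
from async code:

    manager = initialize_model_manager("/tmp/models")
    manager.set_pipeline_loader(get_pipeline_loader())

    model_tree = await manager.load_model(
        camera_id="cam-1",
        model_id="detector-v1",
        model_url="https://example.com/models/detector-v1.mpta",
    )
    print(manager.get_model_stats())  # e.g. {"total_models": 1, ...}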
486
detector_worker/models/pipeline_loader.py
Normal file
@@ -0,0 +1,486 @@
"""
Pipeline loader module.

This module handles loading and parsing of MPTA (Machine Learning Pipeline Archive)
files, which contain model configurations and pipeline definitions.
"""
import os
import json
import logging
import zipfile
import tempfile
import shutil
from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass, field
from pathlib import Path

from ..core.exceptions import ModelLoadError, PipelineError

# Setup logging
logger = logging.getLogger("detector_worker.pipeline_loader")


@dataclass
class PipelineNode:
    """Represents a node in the pipeline tree."""
    model_id: str
    model_file: str
    model_path: Optional[str] = None
    model: Optional[Any] = None  # Loaded model instance

    # Node configuration
    multi_class: bool = False
    expected_classes: List[str] = field(default_factory=list)
    trigger_classes: List[str] = field(default_factory=list)
    min_confidence: float = 0.5
    max_detections: Optional[int] = None

    # Cropping configuration
    crop: bool = False
    crop_class: Optional[str] = None
    crop_expand_ratio: float = 1.0

    # Actions configuration
    actions: List[Dict[str, Any]] = field(default_factory=list)
    parallel_actions: List[Dict[str, Any]] = field(default_factory=list)

    # Branch configuration
    branches: List['PipelineNode'] = field(default_factory=list)
    parallel: bool = False

    # Detection settings
    yolo_settings: Dict[str, Any] = field(default_factory=dict)
    track_classes: Optional[List[str]] = None

    # Metadata
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class PipelineConfig:
    """Pipeline configuration from pipeline.json."""
    pipeline_id: str
    version: str = "1.0"
    description: str = ""

    # Database configuration
    database_config: Optional[Dict[str, Any]] = None

    # Redis configuration
    redis_config: Optional[Dict[str, Any]] = None

    # Global settings
    global_settings: Dict[str, Any] = field(default_factory=dict)

    # Root pipeline node
    root: Optional[PipelineNode] = None

class PipelineLoader:
    """
    Loads and manages ML pipeline configurations.

    This class handles:
    - MPTA file extraction and parsing
    - Pipeline configuration validation
    - Model file management
    - Pipeline tree construction
    - Resource cleanup
    """

    def __init__(self, temp_dir: Optional[str] = None):
        """
        Initialize the pipeline loader.

        Args:
            temp_dir: Temporary directory for extracting MPTA files
        """
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.extracted_paths: Dict[str, str] = {}  # mpta_path -> extracted_dir
        self.loaded_models: Dict[str, Any] = {}  # model_path -> model_instance

    async def load_pipeline(self, mpta_path: str) -> PipelineNode:
        """
        Load a pipeline from an MPTA file.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Root pipeline node

        Raises:
            ModelLoadError: If loading fails
        """
        try:
            # Extract MPTA if not already extracted
            extracted_dir = await self._extract_mpta(mpta_path)

            # Load pipeline configuration
            pipeline_json_path = os.path.join(extracted_dir, "pipeline.json")
            if not os.path.exists(pipeline_json_path):
                raise ModelLoadError(f"pipeline.json not found in {mpta_path}")

            with open(pipeline_json_path, 'r') as f:
                config_data = json.load(f)

            # Parse pipeline configuration
            pipeline_config = self._parse_pipeline_config(config_data, extracted_dir)

            # Validate pipeline
            self._validate_pipeline(pipeline_config)

            # Load models for the pipeline
            await self._load_pipeline_models(pipeline_config.root, extracted_dir)

            logger.info(f"Successfully loaded pipeline from {mpta_path}")
            return pipeline_config.root

        except Exception as e:
            logger.error(f"Failed to load pipeline from {mpta_path}: {e}")
            raise ModelLoadError(f"Failed to load pipeline: {e}")

    async def _extract_mpta(self, mpta_path: str) -> str:
        """
        Extract MPTA file to temporary directory.

        Args:
            mpta_path: Path to MPTA file

        Returns:
            Path to extracted directory
        """
        # Check if already extracted
        if mpta_path in self.extracted_paths:
            extracted_dir = self.extracted_paths[mpta_path]
            if os.path.exists(extracted_dir):
                return extracted_dir

        # Create extraction directory
        mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
        extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")

        # Extract MPTA
        logger.info(f"Extracting MPTA file: {mpta_path}")

        try:
            with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
                # Clean existing directory if exists
                if os.path.exists(extracted_dir):
                    shutil.rmtree(extracted_dir)

                os.makedirs(extracted_dir)
                zip_ref.extractall(extracted_dir)

            self.extracted_paths[mpta_path] = extracted_dir
            logger.info(f"Extracted to: {extracted_dir}")

            return extracted_dir

        except Exception as e:
            raise ModelLoadError(f"Failed to extract MPTA: {e}")

    def _parse_pipeline_config(
        self,
        config_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineConfig:
        """
        Parse pipeline configuration from JSON.

        Args:
            config_data: Pipeline JSON data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline configuration
        """
        # Create pipeline config
        pipeline_config = PipelineConfig(
            pipeline_id=config_data.get("pipelineId", "unknown"),
            version=config_data.get("version", "1.0"),
            description=config_data.get("description", "")
        )

        # Parse database config
        if "database" in config_data:
            pipeline_config.database_config = config_data["database"]

        # Parse Redis config
        if "redis" in config_data:
            pipeline_config.redis_config = config_data["redis"]

        # Parse global settings
        if "globalSettings" in config_data:
            pipeline_config.global_settings = config_data["globalSettings"]

        # Parse pipeline tree
        if "pipeline" in config_data:
            pipeline_config.root = self._parse_pipeline_node(
                config_data["pipeline"], base_dir
            )
        elif "root" in config_data:
            pipeline_config.root = self._parse_pipeline_node(
                config_data["root"], base_dir
            )
        else:
            raise PipelineError("No pipeline or root node found in configuration")

        return pipeline_config

    def _parse_pipeline_node(
        self,
        node_data: Dict[str, Any],
        base_dir: str
    ) -> PipelineNode:
        """
        Parse a pipeline node from configuration.

        Args:
            node_data: Node configuration data
            base_dir: Base directory for model files

        Returns:
            Parsed pipeline node
        """
        # Create node
        node = PipelineNode(
            model_id=node_data.get("modelId", ""),
            model_file=node_data.get("modelFile", "")
        )

        # Set model path
        if node.model_file:
            node.model_path = os.path.join(base_dir, node.model_file)

        # Parse configuration
        node.multi_class = node_data.get("multiClass", False)
        node.expected_classes = node_data.get("expectedClasses", [])
        node.trigger_classes = node_data.get("triggerClasses", [])
        node.min_confidence = node_data.get("minConfidence", 0.5)
        node.max_detections = node_data.get("maxDetections")

        # Parse cropping
        node.crop = node_data.get("crop", False)
        node.crop_class = node_data.get("cropClass")
        node.crop_expand_ratio = node_data.get("cropExpandRatio", 1.0)

        # Parse actions
        node.actions = node_data.get("actions", [])
        node.parallel_actions = node_data.get("parallelActions", [])

        # Parse YOLO settings
        if "yoloSettings" in node_data:
            node.yolo_settings = node_data["yoloSettings"]
        elif "detectionSettings" in node_data:
            node.yolo_settings = node_data["detectionSettings"]

        # Parse tracking
        node.track_classes = node_data.get("trackClasses")

        # Parse metadata
        node.metadata = node_data.get("metadata", {})

        # Parse branches
        branches_data = node_data.get("branches", [])
        node.parallel = node_data.get("parallel", False)

        for branch_data in branches_data:
            branch_node = self._parse_pipeline_node(branch_data, base_dir)
            node.branches.append(branch_node)

        return node

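    # Example (illustrative): a minimal pipeline.json accepted by the two parsing
    # methods above. The key names match what the code reads; the identifiers and
    # model files are placeholders:
    #
    #   {
    #       "pipelineId": "forecourt",
    #       "version": "1.0",
    #       "pipeline": {
    #           "modelId": "car_detector",
    #           "modelFile": "car_detector.pt",
    #           "minConfidence": 0.6,
    #           "crop": true,
    #           "cropClass": "car",
    #           "branches": [
    #               {"modelId": "brand_classifier", "modelFile": "brand.pt"}
    #           ]
    #       }
    #   }
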
    def _validate_pipeline(self, pipeline_config: PipelineConfig) -> None:
        """
        Validate pipeline configuration.

        Args:
            pipeline_config: Pipeline configuration to validate

        Raises:
            PipelineError: If validation fails
        """
        if not pipeline_config.root:
            raise PipelineError("Pipeline has no root node")

        # Validate root node
        self._validate_node(pipeline_config.root)

    def _validate_node(self, node: PipelineNode) -> None:
        """
        Validate a pipeline node.

        Args:
            node: Node to validate

        Raises:
            PipelineError: If validation fails
        """
        # Check required fields
        if not node.model_id:
            raise PipelineError("Node missing modelId")

        if not node.model_file and not node.model:
            raise PipelineError(f"Node {node.model_id} missing modelFile")

        # Validate model path exists
        if node.model_path and not os.path.exists(node.model_path):
            raise PipelineError(f"Model file not found: {node.model_path}")

        # Validate cropping configuration
        if node.crop and not node.crop_class:
            raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass")

        # Validate confidence
        if not 0 <= node.min_confidence <= 1:
            raise PipelineError(f"Invalid minConfidence: {node.min_confidence}")

        # Validate branches
        for branch in node.branches:
            self._validate_node(branch)

    async def _load_pipeline_models(
        self,
        node: PipelineNode,
        base_dir: str
    ) -> None:
        """
        Load models for a pipeline node and its branches.

        Args:
            node: Pipeline node
            base_dir: Base directory for models
        """
        # Load model for this node if path is specified
        if node.model_path:
            node.model = await self._load_model(node.model_path, node.model_id)

        # Load models for branches
        for branch in node.branches:
            await self._load_pipeline_models(branch, base_dir)

    async def _load_model(self, model_path: str, model_id: str) -> Any:
        """
        Load a single model file.

        Args:
            model_path: Path to model file
            model_id: Model identifier

        Returns:
            Loaded model instance
        """
        # Check if already loaded
        if model_path in self.loaded_models:
            logger.info(f"Using cached model: {model_id}")
            return self.loaded_models[model_path]

        try:
            # Import here to avoid circular dependency
            from ultralytics import YOLO

            logger.info(f"Loading model: {model_id} from {model_path}")

            # Load YOLO model
            model = YOLO(model_path)

            # Cache the model
            self.loaded_models[model_path] = model

            return model

        except Exception as e:
            raise ModelLoadError(f"Failed to load model {model_id}: {e}")

    def cleanup_model(self, model_id: str) -> None:
        """
        Clean up resources for a specific model.

        Args:
            model_id: Model identifier to clean up
        """
        # Clean up loaded models
        models_to_remove = []
        for path, model in self.loaded_models.items():
            if model_id in path:
                models_to_remove.append(path)

        for path in models_to_remove:
            self.loaded_models.pop(path, None)
            logger.info(f"Cleaned up model: {path}")

    def cleanup_all(self) -> None:
        """Clean up all resources."""
        # Clear loaded models
        self.loaded_models.clear()

        # Clean up extracted directories
        for mpta_path, extracted_dir in self.extracted_paths.items():
            if os.path.exists(extracted_dir):
                try:
                    shutil.rmtree(extracted_dir)
                    logger.info(f"Cleaned up extracted directory: {extracted_dir}")
                except Exception as e:
                    logger.error(f"Failed to clean up {extracted_dir}: {e}")

        self.extracted_paths.clear()

    def get_node_info(self, node: PipelineNode, level: int = 0) -> str:
        """
        Get formatted information about a pipeline node.

        Args:
            node: Pipeline node
            level: Indentation level

        Returns:
            Formatted node information
        """
        indent = "  " * level
        info = []

        info.append(f"{indent}Model: {node.model_id}")
        info.append(f"{indent}  File: {node.model_file}")
        info.append(f"{indent}  Multi-class: {node.multi_class}")

        if node.expected_classes:
            info.append(f"{indent}  Expected: {', '.join(node.expected_classes)}")
        if node.trigger_classes:
            info.append(f"{indent}  Triggers: {', '.join(node.trigger_classes)}")

        info.append(f"{indent}  Confidence: {node.min_confidence}")

        if node.crop:
            info.append(f"{indent}  Crop: {node.crop_class} (ratio: {node.crop_expand_ratio})")

        if node.actions:
            info.append(f"{indent}  Actions: {len(node.actions)}")
        if node.parallel_actions:
            info.append(f"{indent}  Parallel Actions: {len(node.parallel_actions)}")

        if node.branches:
            info.append(f"{indent}  Branches ({len(node.branches)}):")
            for branch in node.branches:
                info.append(self.get_node_info(branch, level + 2))

        return "\n".join(info)


# Global pipeline loader instance
_pipeline_loader = None


def get_pipeline_loader(temp_dir: Optional[str] = None) -> PipelineLoader:
    """Get or create the global pipeline loader instance."""
    global _pipeline_loader
    if _pipeline_loader is None:
        _pipeline_loader = PipelineLoader(temp_dir)
    return _pipeline_loader


# Convenience functions
async def load_pipeline_from_mpta(mpta_path: str) -> PipelineNode:
    """Load a pipeline from an MPTA file."""
    loader = get_pipeline_loader()
    return await loader.load_pipeline(mpta_path)
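A short usage sketch (the MPTA path is a placeholder; loading must happen in an
async context because model loading is awaited):

    loader = get_pipeline_loader()
    root = await loader.load_pipeline("/models/forecourt.mpta")
    print(loader.get_node_info(root))   # prints the pipeline tree
    loader.cleanup_all()                # remove extracted files when done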
379
detector_worker/utils/system_monitor.py
Normal file
@@ -0,0 +1,379 @@
"""
System monitoring utilities.

This module provides functions to monitor system resources including
CPU, memory, and GPU usage.
"""
import logging
import time
import psutil
from typing import Dict, Any, Optional, List

try:
    import pynvml
    NVIDIA_AVAILABLE = True
except ImportError:
    NVIDIA_AVAILABLE = False

# Setup logging
logger = logging.getLogger("detector_worker.system_monitor")

# Global initialization flag
_nvidia_initialized = False

# Process start time
_start_time = time.time()


def initialize_nvidia_monitoring() -> bool:
    """
    Initialize NVIDIA GPU monitoring.

    Returns:
        True if initialization successful, False otherwise
    """
    global _nvidia_initialized

    if not NVIDIA_AVAILABLE:
        return False

    if _nvidia_initialized:
        return True

    try:
        pynvml.nvmlInit()
        _nvidia_initialized = True
        logger.info("NVIDIA GPU monitoring initialized")
        return True
    except Exception as e:
        logger.warning(f"Failed to initialize NVIDIA monitoring: {e}")
        return False


def get_gpu_info() -> List[Dict[str, Any]]:
    """
    Get information about available GPUs.

    Returns:
        List of GPU information dictionaries
    """
    if not NVIDIA_AVAILABLE or not _nvidia_initialized:
        return []

    gpu_info = []

    try:
        device_count = pynvml.nvmlDeviceGetCount()

        for i in range(device_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)

            # Get GPU information
            name = pynvml.nvmlDeviceGetName(handle).decode('utf-8')

            # Get memory info
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            total_memory = mem_info.total / (1024 ** 3)  # Convert to GB
            used_memory = mem_info.used / (1024 ** 3)
            free_memory = mem_info.free / (1024 ** 3)

            # Get utilization
            try:
                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                gpu_util = util.gpu
                memory_util = util.memory
            except Exception:
                gpu_util = 0
                memory_util = 0

            # Get temperature
            try:
                temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
            except Exception:
                temp = 0

            gpu_info.append({
                "index": i,
                "name": name,
                "gpu_utilization": gpu_util,
                "memory_utilization": memory_util,
                "total_memory_gb": round(total_memory, 2),
                "used_memory_gb": round(used_memory, 2),
                "free_memory_gb": round(free_memory, 2),
                "temperature": temp
            })

    except Exception as e:
        logger.error(f"Error getting GPU info: {e}")

    return gpu_info


def get_gpu_metrics() -> Dict[str, float]:
    """
    Get current GPU metrics.

    Returns:
        Dictionary with GPU utilization and memory usage
    """
    if not NVIDIA_AVAILABLE or not _nvidia_initialized:
        return {"gpu_percent": 0.0, "gpu_memory_percent": 0.0}

    try:
        # Get first GPU metrics
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)

        # Get utilization
        util = pynvml.nvmlDeviceGetUtilizationRates(handle)

        # Get memory info
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        memory_percent = (mem_info.used / mem_info.total) * 100

        return {
            "gpu_percent": float(util.gpu),
            "gpu_memory_percent": float(memory_percent)
        }

    except Exception as e:
        logger.debug(f"Error getting GPU metrics: {e}")
        return {"gpu_percent": 0.0, "gpu_memory_percent": 0.0}


def get_cpu_metrics() -> Dict[str, Any]:
    """
    Get current CPU metrics.

    Returns:
        Dictionary with overall CPU usage, per-core usage, and core count
    """
    try:
        # Get CPU percent with 0.1 second interval
        cpu_percent = psutil.cpu_percent(interval=0.1)

        # Get per-core CPU usage
        cpu_per_core = psutil.cpu_percent(interval=0.1, percpu=True)

        return {
            "cpu_percent": cpu_percent,
            "cpu_per_core": cpu_per_core,
            "cpu_count": psutil.cpu_count()
        }
    except Exception as e:
        logger.error(f"Error getting CPU metrics: {e}")
        return {"cpu_percent": 0.0, "cpu_per_core": [], "cpu_count": 0}


def get_memory_metrics() -> Dict[str, Any]:
    """
    Get current memory metrics.

    Returns:
        Dictionary with memory usage information
    """
    try:
        # Get virtual memory info
        virtual_mem = psutil.virtual_memory()

        # Get swap memory info
        swap_mem = psutil.swap_memory()

        return {
            "memory_percent": virtual_mem.percent,
            "memory_used_gb": round(virtual_mem.used / (1024 ** 3), 2),
            "memory_available_gb": round(virtual_mem.available / (1024 ** 3), 2),
            "memory_total_gb": round(virtual_mem.total / (1024 ** 3), 2),
            "swap_percent": swap_mem.percent,
            "swap_used_gb": round(swap_mem.used / (1024 ** 3), 2)
        }
    except Exception as e:
        logger.error(f"Error getting memory metrics: {e}")
        return {
            "memory_percent": 0.0,
            "memory_used_gb": 0.0,
            "memory_available_gb": 0.0,
            "memory_total_gb": 0.0,
            "swap_percent": 0.0,
            "swap_used_gb": 0.0
        }


def get_disk_metrics(path: str = "/") -> Dict[str, Any]:
    """
    Get disk usage metrics for specified path.

    Args:
        path: Path to check disk usage for

    Returns:
        Dictionary with disk usage information
    """
    try:
        disk_usage = psutil.disk_usage(path)

        return {
            "disk_percent": disk_usage.percent,
            "disk_used_gb": round(disk_usage.used / (1024 ** 3), 2),
            "disk_free_gb": round(disk_usage.free / (1024 ** 3), 2),
            "disk_total_gb": round(disk_usage.total / (1024 ** 3), 2)
        }
    except Exception as e:
        logger.error(f"Error getting disk metrics: {e}")
        return {
            "disk_percent": 0.0,
            "disk_used_gb": 0.0,
            "disk_free_gb": 0.0,
            "disk_total_gb": 0.0
        }


def get_process_metrics() -> Dict[str, Any]:
    """
    Get current process metrics.

    Returns:
        Dictionary with process-specific metrics
    """
    try:
        process = psutil.Process()

        # Get process info
        with process.oneshot():
            cpu_percent = process.cpu_percent()
            memory_info = process.memory_info()
            memory_percent = process.memory_percent()
            num_threads = process.num_threads()

            # Get open file descriptors
            try:
                num_fds = len(process.open_files())
            except Exception:
                num_fds = 0

        return {
            "process_cpu_percent": cpu_percent,
            "process_memory_mb": round(memory_info.rss / (1024 ** 2), 2),
            "process_memory_percent": round(memory_percent, 2),
            "process_threads": num_threads,
            "process_open_files": num_fds
        }
    except Exception as e:
        logger.error(f"Error getting process metrics: {e}")
        return {
            "process_cpu_percent": 0.0,
            "process_memory_mb": 0.0,
            "process_memory_percent": 0.0,
            "process_threads": 0,
            "process_open_files": 0
        }


def get_network_metrics() -> Dict[str, Any]:
    """
    Get network I/O metrics.

    Returns:
        Dictionary with network statistics
    """
    try:
        net_io = psutil.net_io_counters()

        return {
            "bytes_sent": net_io.bytes_sent,
            "bytes_recv": net_io.bytes_recv,
            "packets_sent": net_io.packets_sent,
            "packets_recv": net_io.packets_recv,
            "errin": net_io.errin,
            "errout": net_io.errout,
            "dropin": net_io.dropin,
            "dropout": net_io.dropout
        }
    except Exception as e:
        logger.error(f"Error getting network metrics: {e}")
        return {
            "bytes_sent": 0,
            "bytes_recv": 0,
            "packets_sent": 0,
            "packets_recv": 0,
            "errin": 0,
            "errout": 0,
            "dropin": 0,
            "dropout": 0
        }


def get_system_metrics() -> Dict[str, Any]:
    """
    Get comprehensive system metrics.

    Returns:
        Dictionary containing all system metrics
    """
    # Initialize GPU monitoring if not done
    if NVIDIA_AVAILABLE and not _nvidia_initialized:
        initialize_nvidia_monitoring()

    metrics = {
        "timestamp": time.time(),
        "uptime": time.time() - _start_time,
        "start_time": _start_time
    }

    # Get CPU metrics
    cpu_metrics = get_cpu_metrics()
    metrics.update(cpu_metrics)

    # Get memory metrics
    memory_metrics = get_memory_metrics()
    metrics.update(memory_metrics)

    # Get GPU metrics
    gpu_metrics = get_gpu_metrics()
    metrics.update(gpu_metrics)

    # Get process metrics
    process_metrics = get_process_metrics()
    metrics.update(process_metrics)

    return metrics


def get_resource_summary() -> str:
    """
    Get a formatted summary of system resources.

    Returns:
        Formatted string with resource summary
    """
    metrics = get_system_metrics()

    summary = []
    summary.append(f"CPU: {metrics['cpu_percent']:.1f}%")
    summary.append(f"Memory: {metrics['memory_percent']:.1f}% ({metrics['memory_used_gb']:.1f}GB/{metrics['memory_total_gb']:.1f}GB)")

    if metrics['gpu_percent'] > 0:
        summary.append(f"GPU: {metrics['gpu_percent']:.1f}%")
        summary.append(f"GPU Memory: {metrics['gpu_memory_percent']:.1f}%")

    summary.append(f"Process Memory: {metrics['process_memory_mb']:.1f}MB")
    summary.append(f"Threads: {metrics['process_threads']}")

    return " | ".join(summary)


def cleanup_nvidia_monitoring():
    """Clean up NVIDIA monitoring resources."""
    global _nvidia_initialized

    if NVIDIA_AVAILABLE and _nvidia_initialized:
        try:
            pynvml.nvmlShutdown()
            _nvidia_initialized = False
            logger.info("NVIDIA GPU monitoring cleaned up")
        except Exception as e:
            logger.error(f"Error cleaning up NVIDIA monitoring: {e}")


# Initialize on import
if NVIDIA_AVAILABLE:
    initialize_nvidia_monitoring()
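A quick usage sketch (the printed values are illustrative, not real output):

    from detector_worker.utils.system_monitor import get_resource_summary, get_system_metrics

    print(get_resource_summary())
    # e.g. "CPU: 12.3% | Memory: 41.0% (6.5GB/16.0GB) | Process Memory: 512.0MB | Threads: 24"

    metrics = get_system_metrics()
    print(metrics["uptime"], metrics["cpu_percent"])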