diff --git a/detector_worker/communication/message_processor.py b/detector_worker/communication/message_processor.py new file mode 100644 index 0000000..a6167e6 --- /dev/null +++ b/detector_worker/communication/message_processor.py @@ -0,0 +1,509 @@ +""" +Message processor module. + +This module handles validation, parsing, and processing of WebSocket messages. +It provides a clean separation between message handling and business logic. +""" +import json +import logging +import time +from typing import Dict, Any, Optional, Callable, Tuple +from enum import Enum +from dataclasses import dataclass, field + +from ..core.exceptions import ValidationError, MessageProcessingError + +# Setup logging +logger = logging.getLogger("detector_worker.message_processor") + + +class MessageType(Enum): + """Enumeration of supported WebSocket message types.""" + SUBSCRIBE = "subscribe" + UNSUBSCRIBE = "unsubscribe" + REQUEST_STATE = "requestState" + SET_SESSION_ID = "setSessionId" + PATCH_SESSION = "patchSession" + SET_PROGRESSION_STAGE = "setProgressionStage" + IMAGE_DETECTION = "imageDetection" + STATE_REPORT = "stateReport" + ACK = "ack" + ERROR = "error" + + +class ProgressionStage(Enum): + """Enumeration of progression stages.""" + WELCOME = "welcome" + CAR_FUELING = "car_fueling" + CAR_WAIT_PAYMENT = "car_waitpayment" + CAR_WAIT_STAFF = "car_wait_staff" + + +@dataclass +class SubscribePayload: + """Payload for subscription messages.""" + subscription_identifier: str + model_id: int + model_name: str + model_url: str + rtsp_url: Optional[str] = None + snapshot_url: Optional[str] = None + snapshot_interval: Optional[int] = None + crop_x1: Optional[int] = None + crop_y1: Optional[int] = None + crop_x2: Optional[int] = None + crop_y2: Optional[int] = None + extra_params: Dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SubscribePayload": + """Create SubscribePayload from dictionary.""" + # Extract known fields + known_fields = { + "subscription_identifier": data.get("subscriptionIdentifier"), + "model_id": data.get("modelId"), + "model_name": data.get("modelName"), + "model_url": data.get("modelUrl"), + "rtsp_url": data.get("rtspUrl"), + "snapshot_url": data.get("snapshotUrl"), + "snapshot_interval": data.get("snapshotInterval"), + "crop_x1": data.get("cropX1"), + "crop_y1": data.get("cropY1"), + "crop_x2": data.get("cropX2"), + "crop_y2": data.get("cropY2") + } + + # Extract extra parameters + extra_params = {k: v for k, v in data.items() + if k not in ["subscriptionIdentifier", "modelId", "modelName", + "modelUrl", "rtspUrl", "snapshotUrl", "snapshotInterval", + "cropX1", "cropY1", "cropX2", "cropY2"]} + + return cls(**known_fields, extra_params=extra_params) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format for stream configuration.""" + result = { + "subscriptionIdentifier": self.subscription_identifier, + "modelId": self.model_id, + "modelName": self.model_name, + "modelUrl": self.model_url + } + + if self.rtsp_url: + result["rtspUrl"] = self.rtsp_url + if self.snapshot_url: + result["snapshotUrl"] = self.snapshot_url + if self.snapshot_interval is not None: + result["snapshotInterval"] = self.snapshot_interval + + if self.crop_x1 is not None: + result["cropX1"] = self.crop_x1 + if self.crop_y1 is not None: + result["cropY1"] = self.crop_y1 + if self.crop_x2 is not None: + result["cropX2"] = self.crop_x2 + if self.crop_y2 is not None: + result["cropY2"] = self.crop_y2 + + # Add any extra parameters + 
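# (extra_params round-trips backend-specific camelCase keys untouched,
# so fields this worker does not model survive from_dict -> to_dict)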
result.update(self.extra_params) + + return result + + +@dataclass +class SessionPayload: + """Payload for session-related messages.""" + display_identifier: str + session_id: str + data: Optional[Dict[str, Any]] = None + + +@dataclass +class ProgressionPayload: + """Payload for progression stage messages.""" + display_identifier: str + progression_stage: str + + +class MessageProcessor: + """ + Processes and validates WebSocket messages. + + This class handles: + - Message validation and parsing + - Payload extraction and type conversion + - Error handling and response generation + - Message routing to appropriate handlers + """ + + def __init__(self): + """Initialize the message processor.""" + self.validators: Dict[MessageType, Callable] = { + MessageType.SUBSCRIBE: self._validate_subscribe, + MessageType.UNSUBSCRIBE: self._validate_unsubscribe, + MessageType.SET_SESSION_ID: self._validate_set_session, + MessageType.PATCH_SESSION: self._validate_patch_session, + MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage + } + + def parse_message(self, raw_message: str) -> Tuple[MessageType, Dict[str, Any]]: + """ + Parse a raw WebSocket message. + + Args: + raw_message: Raw message string + + Returns: + Tuple of (MessageType, parsed_data) + + Raises: + MessageProcessingError: If message is invalid + """ + try: + data = json.loads(raw_message) + except json.JSONDecodeError as e: + raise MessageProcessingError(f"Invalid JSON: {e}") + + # Validate message structure + if not isinstance(data, dict): + raise MessageProcessingError("Message must be a JSON object") + + # Extract message type + msg_type_str = data.get("type") + if not msg_type_str: + raise MessageProcessingError("Missing 'type' field in message") + + # Convert to MessageType enum + try: + msg_type = MessageType(msg_type_str) + except ValueError: + raise MessageProcessingError(f"Unknown message type: {msg_type_str}") + + return msg_type, data + + def validate_message(self, msg_type: MessageType, data: Dict[str, Any]) -> Any: + """ + Validate message payload based on type. 
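Example (a hedged sketch; "raw" stands for any JSON text received from
the WebSocket):

    processor = MessageProcessor()
    msg_type, data = processor.parse_message(raw)
    payload = processor.validate_message(msg_type, data)
    # subscribe messages come back as a SubscribePayload; types without
    # a registered validator fall through with the raw payload dict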
+ + Args: + msg_type: Message type + data: Message data + + Returns: + Validated payload object + + Raises: + ValidationError: If validation fails + """ + # Get validator for message type + validator = self.validators.get(msg_type) + if not validator: + # No validation needed for some message types + return data.get("payload", {}) + + return validator(data) + + def _validate_subscribe(self, data: Dict[str, Any]) -> SubscribePayload: + """Validate subscribe message.""" + payload = data.get("payload", {}) + + # Required fields + required = ["subscriptionIdentifier", "modelId", "modelName", "modelUrl"] + missing = [field for field in required if not payload.get(field)] + if missing: + raise ValidationError(f"Missing required fields: {', '.join(missing)}") + + # Must have either rtspUrl or snapshotUrl + if not payload.get("rtspUrl") and not payload.get("snapshotUrl"): + raise ValidationError("Must provide either rtspUrl or snapshotUrl") + + # Validate snapshot interval if snapshot URL is provided + if payload.get("snapshotUrl") and not payload.get("snapshotInterval"): + raise ValidationError("snapshotInterval is required when using snapshotUrl") + + # Validate crop coordinates + crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"] + crop_values = [payload.get(coord) for coord in crop_coords] + if any(v is not None for v in crop_values): + # If any crop coordinate is set, all must be set + if not all(v is not None for v in crop_values): + raise ValidationError("All crop coordinates must be specified if any are set") + + # Validate coordinate values + x1, y1, x2, y2 = crop_values + if x2 <= x1 or y2 <= y1: + raise ValidationError("Invalid crop coordinates: x2 must be > x1, y2 must be > y1") + + return SubscribePayload.from_dict(payload) + + def _validate_unsubscribe(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Validate unsubscribe message.""" + payload = data.get("payload", {}) + + if not payload.get("subscriptionIdentifier"): + raise ValidationError("Missing required field: subscriptionIdentifier") + + return payload + + def _validate_set_session(self, data: Dict[str, Any]) -> SessionPayload: + """Validate setSessionId message.""" + payload = data.get("payload", {}) + + display_id = payload.get("displayIdentifier") + session_id = payload.get("sessionId") + + if not display_id: + raise ValidationError("Missing required field: displayIdentifier") + if not session_id: + raise ValidationError("Missing required field: sessionId") + + return SessionPayload(display_id, session_id) + + def _validate_patch_session(self, data: Dict[str, Any]) -> SessionPayload: + """Validate patchSession message.""" + payload = data.get("payload", {}) + + session_id = payload.get("sessionId") + patch_data = payload.get("data") + + if not session_id: + raise ValidationError("Missing required field: sessionId") + + # Display identifier is optional for patch + display_id = payload.get("displayIdentifier", "") + + return SessionPayload(display_id, session_id, patch_data) + + def _validate_progression_stage(self, data: Dict[str, Any]) -> ProgressionPayload: + """Validate setProgressionStage message.""" + payload = data.get("payload", {}) + + display_id = payload.get("displayIdentifier") + stage = payload.get("progressionStage") + + if not display_id: + raise ValidationError("Missing required field: displayIdentifier") + if not stage: + raise ValidationError("Missing required field: progressionStage") + + # Validate stage value + valid_stages = [s.value for s in ProgressionStage] + if stage not in valid_stages: + 
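# Unknown stages are warned about but still accepted, so a newer
# backend stage does not break workers built against an older enum.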
logger.warning(f"Unknown progression stage: {stage}") + + return ProgressionPayload(display_id, stage) + + def create_ack_response( + self, + request_id: Optional[str] = None, + message: str = "Success", + data: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Create an acknowledgment response. + + Args: + request_id: Request ID to acknowledge + message: Success message + data: Additional response data + + Returns: + ACK response dictionary + """ + response = { + "type": MessageType.ACK.value, + "requestId": request_id or str(int(time.time() * 1000)), + "code": "200", + "data": { + "message": message + } + } + + if data: + response["data"].update(data) + + return response + + def create_error_response( + self, + request_id: Optional[str] = None, + code: str = "400", + message: str = "Error", + details: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Create an error response. + + Args: + request_id: Request ID that caused error + code: Error code + message: Error message + details: Additional error details + + Returns: + Error response dictionary + """ + response = { + "type": MessageType.ERROR.value, + "requestId": request_id or str(int(time.time() * 1000)), + "code": code, + "data": { + "message": message + } + } + + if details: + response["data"]["details"] = details + + return response + + def create_detection_message( + self, + subscription_id: str, + detection_result: Optional[Dict[str, Any]], + model_id: int, + model_name: str, + timestamp: Optional[str] = None + ) -> Dict[str, Any]: + """ + Create a detection result message. + + Args: + subscription_id: Subscription identifier + detection_result: Detection result data (None for disconnection) + model_id: Model ID + model_name: Model name + timestamp: Optional timestamp (defaults to current UTC) + + Returns: + Detection message dictionary + """ + if timestamp is None: + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + + return { + "type": MessageType.IMAGE_DETECTION.value, + "subscriptionIdentifier": subscription_id, + "timestamp": timestamp, + "data": { + "detection": detection_result, + "modelId": model_id, + "modelName": model_name + } + } + + def create_state_report( + self, + active_streams: int, + loaded_models: int, + cpu_usage: float = 0.0, + memory_usage: float = 0.0, + gpu_usage: float = 0.0, + gpu_memory: float = 0.0, + uptime: float = 0.0, + additional_metrics: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Create a state report message. + + Args: + active_streams: Number of active streams + loaded_models: Number of loaded models + cpu_usage: CPU usage percentage + memory_usage: Memory usage percentage + gpu_usage: GPU usage percentage + gpu_memory: GPU memory usage percentage + uptime: System uptime in seconds + additional_metrics: Additional metrics to include + + Returns: + State report dictionary + """ + report = { + "type": MessageType.STATE_REPORT.value, + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "data": { + "activeStreams": active_streams, + "loadedModels": loaded_models, + "cpuUsage": cpu_usage, + "memoryUsage": memory_usage, + "gpuUsage": gpu_usage, + "gpuMemory": gpu_memory, + "uptime": uptime + } + } + + if additional_metrics: + report["data"].update(additional_metrics) + + return report + + def extract_camera_id(self, subscription_id: str) -> Tuple[Optional[str], str]: + """ + Extract display ID and camera ID from subscription identifier. 
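Example (illustrative identifiers):

    >>> MessageProcessor().extract_camera_id("display-001;camera-001")
    ('display-001', 'camera-001')
    >>> MessageProcessor().extract_camera_id("camera-001")
    (None, 'camera-001')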
+ + Args: + subscription_id: Subscription identifier (format: "display_id;camera_id") + + Returns: + Tuple of (display_id, camera_id) + """ + parts = subscription_id.split(";") + if len(parts) >= 2: + return parts[0], parts[1] + else: + return None, subscription_id + + +# Singleton instance for backward compatibility +_message_processor = MessageProcessor() + +# Convenience functions +def parse_websocket_message(raw_message: str) -> Tuple[MessageType, Dict[str, Any]]: + """Parse a raw WebSocket message.""" + return _message_processor.parse_message(raw_message) + +def validate_message_payload(msg_type: MessageType, data: Dict[str, Any]) -> Any: + """Validate message payload based on type.""" + return _message_processor.validate_message(msg_type, data) + +def create_ack_response( + request_id: Optional[str] = None, + message: str = "Success", + data: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: + """Create an acknowledgment response.""" + return _message_processor.create_ack_response(request_id, message, data) + +def create_error_response( + request_id: Optional[str] = None, + code: str = "400", + message: str = "Error", + details: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: + """Create an error response.""" + return _message_processor.create_error_response(request_id, code, message, details) + +def create_detection_message( + subscription_id: str, + detection_result: Optional[Dict[str, Any]], + model_id: int, + model_name: str, + timestamp: Optional[str] = None +) -> Dict[str, Any]: + """Create a detection result message.""" + return _message_processor.create_detection_message( + subscription_id, detection_result, model_id, model_name, timestamp + ) + +def create_state_report( + active_streams: int, + loaded_models: int, + **kwargs +) -> Dict[str, Any]: + """Create a state report message.""" + return _message_processor.create_state_report( + active_streams, loaded_models, **kwargs + ) \ No newline at end of file diff --git a/detector_worker/communication/response_formatter.py b/detector_worker/communication/response_formatter.py new file mode 100644 index 0000000..9fb06c8 --- /dev/null +++ b/detector_worker/communication/response_formatter.py @@ -0,0 +1,439 @@ +""" +Response formatter module. + +This module handles formatting of detection results and other responses +for WebSocket transmission, ensuring consistent output format. +""" +import json +import logging +import time +from typing import Dict, Any, List, Optional, Union +from datetime import datetime +from decimal import Decimal + +from ..detection.detection_result import DetectionResult + +# Setup logging +logger = logging.getLogger("detector_worker.response_formatter") + + +class ResponseFormatter: + """ + Formats responses for WebSocket transmission. + + This class handles: + - Detection result formatting + - Object serialization + - Response structure standardization + - Special type handling (Decimal, datetime, etc.) + """ + + def __init__(self, compact: bool = True): + """ + Initialize the response formatter. + + Args: + compact: Whether to use compact JSON formatting + """ + self.compact = compact + self.separators = (',', ':') if compact else (',', ': ') + + def format_detection_result( + self, + detection_result: Union[DetectionResult, Dict[str, Any], None], + include_tracking: bool = True, + include_metadata: bool = True + ) -> Optional[Dict[str, Any]]: + """ + Format a detection result for transmission. 
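Example (a minimal sketch with made-up values):

    formatter = ResponseFormatter()
    formatted = formatter.format_detection_result({
        "detections": [
            {"class": "car", "confidence": 0.87654, "bbox": [10.2, 20.7, 110.0, 220.4]}
        ]
    })
    # confidence is rounded to 4 decimal places, bbox is coerced to
    # ints, and None values / empty collections are stripped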
+ + Args: + detection_result: Detection result to format + include_tracking: Whether to include tracking information + include_metadata: Whether to include metadata + + Returns: + Formatted detection dictionary or None + """ + if detection_result is None: + return None + + # Handle DetectionResult object + if isinstance(detection_result, DetectionResult): + result = detection_result.to_dict() + else: + result = detection_result + + # Format detections + if "detections" in result: + result["detections"] = self._format_detections( + result["detections"], + include_tracking + ) + + # Format license plate results + if "license_plate_results" in result: + result["license_plate_results"] = self._format_license_plates( + result["license_plate_results"] + ) + + # Format session info + if "session_info" in result: + result["session_info"] = self._format_session_info( + result["session_info"] + ) + + # Remove metadata if not needed + if not include_metadata and "metadata" in result: + result.pop("metadata", None) + + # Clean up None values + return self._clean_dict(result) + + def _format_detections( + self, + detections: List[Dict[str, Any]], + include_tracking: bool = True + ) -> List[Dict[str, Any]]: + """Format detection objects.""" + formatted = [] + + for detection in detections: + formatted_det = { + "class": detection.get("class"), + "confidence": self._format_float(detection.get("confidence")), + "bbox": self._format_bbox(detection.get("bbox", [])) + } + + # Add tracking info if available and requested + if include_tracking: + if "track_id" in detection: + formatted_det["track_id"] = detection["track_id"] + if "track_age" in detection: + formatted_det["track_age"] = detection["track_age"] + if "stability_score" in detection: + formatted_det["stability_score"] = self._format_float( + detection["stability_score"] + ) + + # Add additional properties + for key, value in detection.items(): + if key not in ["class", "confidence", "bbox", "track_id", + "track_age", "stability_score"]: + formatted_det[key] = self._format_value(value) + + formatted.append(formatted_det) + + return formatted + + def _format_license_plates( + self, + license_plates: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Format license plate results.""" + formatted = [] + + for plate in license_plates: + formatted_plate = { + "text": plate.get("text", ""), + "confidence": self._format_float(plate.get("confidence")), + "bbox": self._format_bbox(plate.get("bbox", [])) + } + + # Add tracking info if available + if "track_id" in plate: + formatted_plate["track_id"] = plate["track_id"] + + # Add type if available + if "type" in plate: + formatted_plate["type"] = plate["type"] + + formatted.append(formatted_plate) + + return formatted + + def _format_session_info(self, session_info: Dict[str, Any]) -> Dict[str, Any]: + """Format session information.""" + formatted = { + "session_id": session_info.get("session_id"), + "display_id": session_info.get("display_id"), + "camera_id": session_info.get("camera_id") + } + + # Add timestamps + if "created_at" in session_info: + formatted["created_at"] = self._format_timestamp( + session_info["created_at"] + ) + if "updated_at" in session_info: + formatted["updated_at"] = self._format_timestamp( + session_info["updated_at"] + ) + + # Add any additional session data + for key, value in session_info.items(): + if key not in ["session_id", "display_id", "camera_id", + "created_at", "updated_at"]: + formatted[key] = self._format_value(value) + + return self._clean_dict(formatted) + + def 
_format_bbox(self, bbox: List[Union[int, float]]) -> List[int]: + """Format bounding box coordinates.""" + if len(bbox) != 4: + return [] + return [int(coord) for coord in bbox] + + def _format_float(self, value: Union[float, Decimal, None], precision: int = 4) -> Optional[float]: + """Format floating point values.""" + if value is None: + return None + if isinstance(value, Decimal): + value = float(value) + return round(float(value), precision) + + def _format_timestamp(self, timestamp: Union[str, datetime, float, None]) -> Optional[str]: + """Format timestamp to ISO format.""" + if timestamp is None: + return None + + # Handle string timestamps + if isinstance(timestamp, str): + return timestamp + + # Handle datetime objects + if isinstance(timestamp, datetime): + return timestamp.strftime("%Y-%m-%dT%H:%M:%SZ") + + # Handle Unix timestamps + if isinstance(timestamp, (int, float)): + dt = datetime.fromtimestamp(timestamp) + return dt.strftime("%Y-%m-%dT%H:%M:%SZ") + + return str(timestamp) + + def _format_value(self, value: Any) -> Any: + """Format arbitrary values for JSON serialization.""" + # Handle None + if value is None: + return None + + # Handle basic types + if isinstance(value, (str, int, bool)): + return value + + # Handle float/Decimal + if isinstance(value, (float, Decimal)): + return self._format_float(value) + + # Handle datetime + if isinstance(value, datetime): + return self._format_timestamp(value) + + # Handle lists + if isinstance(value, list): + return [self._format_value(item) for item in value] + + # Handle dictionaries + if isinstance(value, dict): + return {k: self._format_value(v) for k, v in value.items()} + + # Convert to string for other types + return str(value) + + def _clean_dict(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Remove None values and empty collections from dictionary.""" + cleaned = {} + + for key, value in data.items(): + # Skip None values + if value is None: + continue + + # Skip empty lists + if isinstance(value, list) and len(value) == 0: + continue + + # Skip empty dictionaries + if isinstance(value, dict) and len(value) == 0: + continue + + # Recursively clean dictionaries + if isinstance(value, dict): + cleaned_value = self._clean_dict(value) + if cleaned_value: + cleaned[key] = cleaned_value + else: + cleaned[key] = value + + return cleaned + + def format_tracking_state(self, tracking_state: Dict[str, Any]) -> Dict[str, Any]: + """ + Format tracking state information. + + Args: + tracking_state: Tracking state dictionary + + Returns: + Formatted tracking state + """ + formatted = {} + + # Format tracker state + if "active_tracks" in tracking_state: + formatted["active_tracks"] = len(tracking_state["active_tracks"]) + + if "tracks" in tracking_state: + formatted["tracks"] = [] + for track in tracking_state["tracks"]: + formatted_track = { + "track_id": track.get("track_id"), + "class": track.get("class"), + "age": track.get("age", 0), + "hits": track.get("hits", 0), + "misses": track.get("misses", 0) + } + + if "bbox" in track: + formatted_track["bbox"] = self._format_bbox(track["bbox"]) + + if "confidence" in track: + formatted_track["confidence"] = self._format_float(track["confidence"]) + + formatted["tracks"].append(formatted_track) + + return self._clean_dict(formatted) + + def format_error_details(self, error: Exception) -> Dict[str, Any]: + """ + Format error details for response. 
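Example (illustrative):

    try:
        raise ValueError("bad payload")
    except ValueError as exc:
        details = ResponseFormatter().format_error_details(exc)
    # details == {"error_type": "ValueError", "error_message": "bad payload"}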
+ + Args: + error: Exception to format + + Returns: + Formatted error details + """ + details = { + "error_type": type(error).__name__, + "error_message": str(error) + } + + # Add additional context if available + if hasattr(error, "details"): + details["details"] = error.details + + if hasattr(error, "code"): + details["error_code"] = error.code + + return details + + def to_json(self, data: Any) -> str: + """ + Convert data to JSON string. + + Args: + data: Data to serialize + + Returns: + JSON string + """ + return json.dumps( + data, + separators=self.separators, + default=self._json_default + ) + + def _json_default(self, obj: Any) -> Any: + """Default JSON serialization for custom types.""" + # Handle Decimal + if isinstance(obj, Decimal): + return float(obj) + + # Handle datetime + if isinstance(obj, datetime): + return obj.strftime("%Y-%m-%dT%H:%M:%SZ") + + # Handle bytes + if isinstance(obj, bytes): + return obj.decode('utf-8', errors='ignore') + + # Default to string representation + return str(obj) + + def format_pipeline_result( + self, + pipeline_result: Dict[str, Any], + include_intermediate: bool = False + ) -> Dict[str, Any]: + """ + Format pipeline execution result. + + Args: + pipeline_result: Pipeline result dictionary + include_intermediate: Whether to include intermediate results + + Returns: + Formatted pipeline result + """ + formatted = { + "success": pipeline_result.get("success", False), + "pipeline_id": pipeline_result.get("pipeline_id"), + "execution_time": self._format_float( + pipeline_result.get("execution_time", 0) + ) + } + + # Add final detection result + if "detection_result" in pipeline_result: + formatted["detection_result"] = self.format_detection_result( + pipeline_result["detection_result"] + ) + + # Add branch results + if "branch_results" in pipeline_result and include_intermediate: + formatted["branch_results"] = {} + for branch_id, result in pipeline_result["branch_results"].items(): + formatted["branch_results"][branch_id] = self._format_value(result) + + # Add executed actions + if "executed_actions" in pipeline_result: + formatted["executed_actions"] = pipeline_result["executed_actions"] + + # Add errors if any + if "errors" in pipeline_result: + formatted["errors"] = [ + self.format_error_details(error) if isinstance(error, Exception) + else error + for error in pipeline_result["errors"] + ] + + return self._clean_dict(formatted) + + +# Singleton instance +_formatter = ResponseFormatter() + +# Convenience functions for backward compatibility +def format_detection_for_websocket( + detection_result: Union[DetectionResult, Dict[str, Any], None], + include_tracking: bool = True, + include_metadata: bool = True +) -> Optional[Dict[str, Any]]: + """Format detection result for WebSocket transmission.""" + return _formatter.format_detection_result( + detection_result, include_tracking, include_metadata + ) + +def format_tracking_state(tracking_state: Dict[str, Any]) -> Dict[str, Any]: + """Format tracking state information.""" + return _formatter.format_tracking_state(tracking_state) + +def format_error_response(error: Exception) -> Dict[str, Any]: + """Format error details for response.""" + return _formatter.format_error_details(error) + +def to_compact_json(data: Any) -> str: + """Convert data to compact JSON string.""" + return _formatter.to_json(data) \ No newline at end of file diff --git a/detector_worker/communication/websocket_handler.py b/detector_worker/communication/websocket_handler.py new file mode 100644 index 0000000..8f0e641 --- 
/dev/null +++ b/detector_worker/communication/websocket_handler.py @@ -0,0 +1,622 @@ +""" +WebSocket handler module. + +This module manages WebSocket connections, message processing, heartbeat functionality, +and coordination between stream processing and detection pipelines. +""" +import asyncio +import json +import logging +import time +import traceback +import uuid +from typing import Dict, Any, Optional, Callable, List, Set +from contextlib import asynccontextmanager + +from fastapi import WebSocket +from websockets.exceptions import ConnectionClosedError, WebSocketDisconnect + +from ..core.config import config, subscription_to_camera, latest_frames +from ..core.constants import HEARTBEAT_INTERVAL +from ..core.exceptions import WebSocketError, StreamError +from ..streams.stream_manager import StreamManager +from ..streams.camera_monitor import CameraConnectionMonitor +from ..detection.detection_result import DetectionResult +from ..models.model_manager import ModelManager +from ..pipeline.pipeline_executor import PipelineExecutor +from ..storage.session_cache import SessionCache +from ..storage.redis_client import RedisClientManager +from ..utils.system_monitor import get_system_metrics + +# Setup logging +logger = logging.getLogger("detector_worker.websocket_handler") +ws_logger = logging.getLogger("websocket") + +# Type definitions for callbacks +MessageHandler = Callable[[Dict[str, Any]], asyncio.coroutine] +DetectionHandler = Callable[[str, Dict[str, Any], Any, WebSocket, Any, Dict[str, Any]], asyncio.coroutine] + + +class WebSocketHandler: + """ + Manages WebSocket connections and message processing for the detection worker. + + This class handles: + - WebSocket lifecycle management + - Message routing and processing + - Heartbeat/state reporting + - Stream subscription management + - Detection result forwarding + - Session management + """ + + def __init__( + self, + stream_manager: StreamManager, + model_manager: ModelManager, + pipeline_executor: PipelineExecutor, + session_cache: SessionCache, + redis_client: Optional[RedisClientManager] = None + ): + """ + Initialize the WebSocket handler. + + Args: + stream_manager: Manager for camera streams + model_manager: Manager for ML models + pipeline_executor: Pipeline execution engine + session_cache: Session state cache + redis_client: Optional Redis client for pub/sub + """ + self.stream_manager = stream_manager + self.model_manager = model_manager + self.pipeline_executor = pipeline_executor + self.session_cache = session_cache + self.redis_client = redis_client + + # Connection state + self.websocket: Optional[WebSocket] = None + self.connected: bool = False + self.tasks: List[asyncio.Task] = [] + + # Message handlers + self.message_handlers: Dict[str, MessageHandler] = { + "subscribe": self._handle_subscribe, + "unsubscribe": self._handle_unsubscribe, + "requestState": self._handle_request_state, + "setSessionId": self._handle_set_session, + "patchSession": self._handle_patch_session, + "setProgressionStage": self._handle_set_progression_stage + } + + # Session and display management + self.session_ids: Dict[str, str] = {} # display_identifier -> session_id + self.display_identifiers: Set[str] = set() + + # Camera monitor + self.camera_monitor = CameraConnectionMonitor() + + async def handle_connection(self, websocket: WebSocket) -> None: + """ + Main entry point for handling a WebSocket connection. 
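Example (a wiring sketch; the endpoint path and the dependency
instances are assumptions, not part of this module):

    @app.websocket("/")
    async def detect(websocket: WebSocket):
        handler = WebSocketHandler(stream_manager, model_manager,
                                   pipeline_executor, session_cache)
        await handler.handle_connection(websocket)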
+ + Args: + websocket: The WebSocket connection to handle + """ + try: + await websocket.accept() + self.websocket = websocket + self.connected = True + + logger.info("WebSocket connection accepted") + + # Create concurrent tasks + stream_task = asyncio.create_task(self._process_streams()) + heartbeat_task = asyncio.create_task(self._send_heartbeat()) + message_task = asyncio.create_task(self._process_messages()) + + self.tasks = [stream_task, heartbeat_task, message_task] + + # Wait for tasks to complete + await asyncio.gather(heartbeat_task, message_task) + + except Exception as e: + logger.error(f"Error in WebSocket handler: {e}") + + finally: + self.connected = False + await self._cleanup() + + async def _cleanup(self) -> None: + """Clean up resources when connection closes.""" + logger.info("Cleaning up WebSocket connection") + + # Cancel all tasks + for task in self.tasks: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # Clean up streams + await self.stream_manager.cleanup_all_streams() + + # Clean up models + self.model_manager.cleanup_all_models() + + # Clear session state + self.session_cache.clear_all_sessions() + self.session_ids.clear() + self.display_identifiers.clear() + + # Clear camera states + self.camera_monitor.clear_all_states() + + logger.info("WebSocket cleanup completed") + + async def _send_heartbeat(self) -> None: + """Send periodic heartbeat/state reports to maintain connection.""" + while self.connected: + try: + # Get system metrics + metrics = get_system_metrics() + + # Get active streams info + active_streams = self.stream_manager.get_active_streams() + active_models = self.model_manager.get_loaded_models() + + state_data = { + "type": "stateReport", + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "data": { + "activeStreams": len(active_streams), + "loadedModels": len(active_models), + "cpuUsage": metrics.get("cpu_percent", 0), + "memoryUsage": metrics.get("memory_percent", 0), + "gpuUsage": metrics.get("gpu_percent", 0), + "gpuMemory": metrics.get("gpu_memory_percent", 0), + "uptime": time.time() - metrics.get("start_time", time.time()) + } + } + + ws_logger.info(f"TX -> {json.dumps(state_data, separators=(',', ':'))}") + await self.websocket.send_json(state_data) + + await asyncio.sleep(HEARTBEAT_INTERVAL) + + except (WebSocketDisconnect, ConnectionClosedError): + logger.info("WebSocket disconnected during heartbeat") + break + except Exception as e: + logger.error(f"Error sending heartbeat: {e}") + break + + async def _process_messages(self) -> None: + """Process incoming WebSocket messages.""" + while self.connected: + try: + text_data = await self.websocket.receive_text() + ws_logger.info(f"RX <- {text_data}") + + data = json.loads(text_data) + msg_type = data.get("type") + + if msg_type in self.message_handlers: + handler = self.message_handlers[msg_type] + await handler(data) + else: + logger.error(f"Unknown message type: {msg_type}") + + except json.JSONDecodeError: + logger.error("Received invalid JSON message") + except (WebSocketDisconnect, ConnectionClosedError) as e: + logger.warning(f"WebSocket disconnected: {e}") + break + except Exception as e: + logger.error(f"Error handling message: {e}") + traceback.print_exc() + break + + async def _process_streams(self) -> None: + """Process active camera streams and run detection pipelines.""" + while self.connected: + try: + active_streams = self.stream_manager.get_active_streams() + + if active_streams: + # Process each active stream + tasks = [] + 
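# Fan out one detection task per camera that currently has a fresh
# frame and a loaded model; cameras missing either are skipped below.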
for camera_id, stream_info in active_streams.items(): + # Get latest frame + frame = self.stream_manager.get_latest_frame(camera_id) + if frame is None: + continue + + # Get model for this camera + model_id = stream_info.get("modelId") + if not model_id: + continue + + model_tree = self.model_manager.get_model(camera_id, model_id) + if not model_tree: + continue + + # Create detection task + persistent_data = self.session_cache.get_persistent_data(camera_id) + task = asyncio.create_task( + self._handle_detection( + camera_id, stream_info, frame, + self.websocket, model_tree, persistent_data + ) + ) + tasks.append(task) + + # Wait for all detection tasks + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Update persistent data + for i, (camera_id, _) in enumerate(active_streams.items()): + if i < len(results) and isinstance(results[i], dict): + self.session_cache.update_persistent_data(camera_id, results[i]) + + # Polling interval + poll_interval = config.get("poll_interval_ms", 100) / 1000.0 + await asyncio.sleep(poll_interval) + + except asyncio.CancelledError: + logger.info("Stream processing cancelled") + break + except Exception as e: + logger.error(f"Error in stream processing: {e}") + await asyncio.sleep(1) + + async def _handle_detection( + self, + camera_id: str, + stream_info: Dict[str, Any], + frame: Any, + websocket: WebSocket, + model_tree: Any, + persistent_data: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Handle detection for a single camera frame. + + Args: + camera_id: Camera identifier + stream_info: Stream configuration + frame: Video frame to process + websocket: WebSocket connection + model_tree: Model pipeline tree + persistent_data: Persistent data for this camera + + Returns: + Updated persistent data + """ + try: + # Check camera connection state + if self.camera_monitor.should_notify_disconnection(camera_id): + await self._send_disconnection_notification(camera_id, stream_info) + return persistent_data + + # Apply crop if specified + cropped_frame = self._apply_crop(frame, stream_info) + + # Get session pipeline state + pipeline_state = self.session_cache.get_session_pipeline_state(camera_id) + + # Run detection pipeline + detection_result = await self.pipeline_executor.execute_pipeline( + camera_id, + stream_info, + cropped_frame, + model_tree, + persistent_data, + pipeline_state + ) + + # Send detection result + if detection_result: + await self._send_detection_result( + camera_id, stream_info, detection_result + ) + + # Handle camera reconnection + if self.camera_monitor.should_notify_reconnection(camera_id): + self.camera_monitor.mark_reconnection_notified(camera_id) + logger.info(f"Camera {camera_id} reconnected successfully") + + return persistent_data + + except Exception as e: + logger.error(f"Error in detection handling for camera {camera_id}: {e}") + traceback.print_exc() + return persistent_data + + async def _handle_subscribe(self, data: Dict[str, Any]) -> None: + """Handle stream subscription request.""" + payload = data.get("payload", {}) + subscription_id = payload.get("subscriptionIdentifier") + + if not subscription_id: + logger.error("Missing subscriptionIdentifier in subscribe payload") + return + + try: + # Extract display and camera IDs + parts = subscription_id.split(";") + if len(parts) >= 2: + display_id = parts[0] + camera_id = parts[1] + self.display_identifiers.add(display_id) + else: + camera_id = subscription_id + + # Store subscription mapping + subscription_to_camera[subscription_id] = camera_id + + # 
Start camera stream + await self.stream_manager.start_stream(camera_id, payload) + + # Load model + model_id = payload.get("modelId") + model_url = payload.get("modelUrl") + if model_id and model_url: + await self.model_manager.load_model(camera_id, model_id, model_url) + + logger.info(f"Subscribed to stream: {subscription_id}") + + except Exception as e: + logger.error(f"Error handling subscription: {e}") + traceback.print_exc() + + async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None: + """Handle stream unsubscription request.""" + payload = data.get("payload", {}) + subscription_id = payload.get("subscriptionIdentifier") + + if not subscription_id: + logger.error("Missing subscriptionIdentifier in unsubscribe payload") + return + + try: + # Get camera ID from subscription + camera_id = subscription_to_camera.get(subscription_id) + if not camera_id: + logger.warning(f"No camera found for subscription: {subscription_id}") + return + + # Stop stream + await self.stream_manager.stop_stream(camera_id) + + # Unload model + self.model_manager.unload_models(camera_id) + + # Clean up mappings + subscription_to_camera.pop(subscription_id, None) + + # Clean up session state + self.session_cache.clear_session(camera_id) + + logger.info(f"Unsubscribed from stream: {subscription_id}") + + except Exception as e: + logger.error(f"Error handling unsubscription: {e}") + traceback.print_exc() + + async def _handle_request_state(self, data: Dict[str, Any]) -> None: + """Handle state request message.""" + # Send immediate state report + await self._send_heartbeat() + + async def _handle_set_session(self, data: Dict[str, Any]) -> None: + """Handle setSessionId message.""" + payload = data.get("payload", {}) + display_id = payload.get("displayIdentifier") + session_id = payload.get("sessionId") + + if display_id and session_id: + self.session_ids[display_id] = session_id + + # Update session for all cameras of this display + with self.stream_manager.streams_lock: + for camera_id, stream in self.stream_manager.streams.items(): + if stream["subscriptionIdentifier"].startswith(display_id + ";"): + self.session_cache.update_session_id(camera_id, session_id) + + # Send acknowledgment + response = { + "type": "ack", + "requestId": data.get("requestId", str(uuid.uuid4())), + "code": "200", + "data": { + "message": f"Session ID set for display {display_id}", + "sessionId": session_id + } + } + ws_logger.info(f"TX -> {json.dumps(response, separators=(',', ':'))}") + await self.websocket.send_json(response) + + logger.info(f"Set session {session_id} for display {display_id}") + + async def _handle_patch_session(self, data: Dict[str, Any]) -> None: + """Handle patchSession message.""" + payload = data.get("payload", {}) + session_id = payload.get("sessionId") + patch_data = payload.get("data", {}) + + if session_id: + # Store patch data (could be used for session updates) + logger.info(f"Received patch for session {session_id}: {patch_data}") + + # Send acknowledgment + response = { + "type": "ack", + "requestId": data.get("requestId", str(uuid.uuid4())), + "code": "200", + "data": { + "message": f"Session {session_id} patched successfully", + "sessionId": session_id, + "patchData": patch_data + } + } + ws_logger.info(f"TX -> {json.dumps(response, separators=(',', ':'))}") + await self.websocket.send_json(response) + + async def _handle_set_progression_stage(self, data: Dict[str, Any]) -> None: + """Handle setProgressionStage message.""" + payload = data.get("payload", {}) + display_id = 
payload.get("displayIdentifier") + progression_stage = payload.get("progressionStage") + + logger.info(f"🏁 PROGRESSION STAGE RECEIVED: displayId={display_id}, stage={progression_stage}") + + if not display_id: + logger.warning("Missing displayIdentifier in setProgressionStage") + return + + # Find all cameras for this display + affected_cameras = [] + with self.stream_manager.streams_lock: + for camera_id, stream in self.stream_manager.streams.items(): + if stream["subscriptionIdentifier"].startswith(display_id + ";"): + affected_cameras.append(camera_id) + + logger.debug(f"🎯 Found {len(affected_cameras)} cameras for display {display_id}: {affected_cameras}") + + # Update progression stage for each camera + for camera_id in affected_cameras: + pipeline_state = self.session_cache.get_or_init_session_pipeline_state(camera_id) + current_mode = pipeline_state.get("mode", "validation_detecting") + + if progression_stage == "car_fueling": + # Stop YOLO inference during fueling + if current_mode == "lightweight": + pipeline_state["yolo_inference_enabled"] = False + pipeline_state["progression_stage"] = "car_fueling" + logger.info(f"⏸️ Camera {camera_id}: YOLO inference DISABLED for car_fueling stage") + else: + logger.debug(f"πŸ“Š Camera {camera_id}: car_fueling received but not in lightweight mode (mode: {current_mode})") + + elif progression_stage == "car_waitpayment": + # Resume YOLO inference for absence counter + pipeline_state["yolo_inference_enabled"] = True + pipeline_state["progression_stage"] = "car_waitpayment" + logger.info(f"▢️ Camera {camera_id}: YOLO inference RE-ENABLED for car_waitpayment stage") + + elif progression_stage == "welcome": + # Ignore welcome messages during car_waitpayment + current_progression = pipeline_state.get("progression_stage") + if current_progression == "car_waitpayment": + logger.info(f"🚫 Camera {camera_id}: IGNORING welcome stage (currently in car_waitpayment)") + else: + pipeline_state["progression_stage"] = "welcome" + logger.info(f"πŸŽ‰ Camera {camera_id}: Progression stage set to welcome") + + elif progression_stage in ["car_wait_staff"]: + pipeline_state["progression_stage"] = progression_stage + logger.info(f"πŸ“‹ Camera {camera_id}: Progression stage set to {progression_stage}") + + async def _send_detection_result( + self, + camera_id: str, + stream_info: Dict[str, Any], + detection_result: DetectionResult + ) -> None: + """Send detection result over WebSocket.""" + detection_data = { + "type": "imageDetection", + "subscriptionIdentifier": stream_info["subscriptionIdentifier"], + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "data": { + "detection": detection_result.to_dict(), + "modelId": stream_info["modelId"], + "modelName": stream_info["modelName"] + } + } + + try: + ws_logger.info(f"TX -> {json.dumps(detection_data, separators=(',', ':'))}") + await self.websocket.send_json(detection_data) + except RuntimeError as e: + if "websocket.close" in str(e): + logger.warning(f"WebSocket closed - cannot send detection for camera {camera_id}") + else: + raise + + async def _send_disconnection_notification( + self, + camera_id: str, + stream_info: Dict[str, Any] + ) -> None: + """Send camera disconnection notification.""" + logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null") + + # Clear cached data + self.session_cache.clear_session(camera_id) + + # Send null detection + detection_data = { + "type": "imageDetection", + "subscriptionIdentifier": stream_info["subscriptionIdentifier"], + 
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "data": { + "detection": None, + "modelId": stream_info["modelId"], + "modelName": stream_info["modelName"] + } + } + + try: + ws_logger.info(f"TX -> {json.dumps(detection_data, separators=(',', ':'))}") + await self.websocket.send_json(detection_data) + except RuntimeError as e: + if "websocket.close" in str(e): + logger.warning(f"WebSocket closed - cannot send disconnection signal for camera {camera_id}") + else: + raise + + self.camera_monitor.mark_disconnection_notified(camera_id) + logger.info(f"πŸ“‘ SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session") + + def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any: + """Apply crop to frame if crop coordinates are specified.""" + crop_coords = [ + stream_info.get("cropX1"), + stream_info.get("cropY1"), + stream_info.get("cropX2"), + stream_info.get("cropY2") + ] + + if all(coord is not None for coord in crop_coords): + x1, y1, x2, y2 = crop_coords + return frame[y1:y2, x1:x2] + + return frame + + +# Convenience function for backward compatibility +async def handle_websocket_connection( + websocket: WebSocket, + stream_manager: StreamManager, + model_manager: ModelManager, + pipeline_executor: PipelineExecutor, + session_cache: SessionCache, + redis_client: Optional[RedisClientManager] = None +) -> None: + """ + Handle a WebSocket connection using the WebSocketHandler. + + This is a convenience function that creates a handler instance + and processes the connection. + """ + handler = WebSocketHandler( + stream_manager, + model_manager, + pipeline_executor, + session_cache, + redis_client + ) + await handler.handle_connection(websocket) \ No newline at end of file diff --git a/detector_worker/models/model_manager.py b/detector_worker/models/model_manager.py new file mode 100644 index 0000000..3de3b97 --- /dev/null +++ b/detector_worker/models/model_manager.py @@ -0,0 +1,441 @@ +""" +Model manager module. + +This module handles ML model loading, caching, and lifecycle management +for the detection worker. +""" +import os +import logging +import threading +from typing import Dict, Any, Optional, List, Set, Tuple +from urllib.parse import urlparse +import traceback + +from ..core.config import MODELS_DIR +from ..core.exceptions import ModelLoadError + +# Setup logging +logger = logging.getLogger("detector_worker.model_manager") + + +class ModelRegistry: + """ + Registry for loaded models. + + Maintains a reference count for each model to enable sharing + between multiple cameras. + """ + + def __init__(self): + """Initialize the model registry.""" + self.models: Dict[str, Dict[str, Any]] = {} # model_id -> model_info + self.references: Dict[str, Set[str]] = {} # model_id -> set of camera_ids + self.lock = threading.Lock() + + def register_model(self, model_id: str, model_data: Any, model_path: str) -> None: + """ + Register a model in the registry. 
+ + Args: + model_id: Unique model identifier + model_data: Loaded model data + model_path: Path to model file + """ + with self.lock: + self.models[model_id] = { + "model": model_data, + "path": model_path, + "loaded_at": os.path.getmtime(model_path) + } + if model_id not in self.references: + self.references[model_id] = set() + + def add_reference(self, model_id: str, camera_id: str) -> None: + """Add a reference to a model from a camera.""" + with self.lock: + if model_id in self.references: + self.references[model_id].add(camera_id) + + def remove_reference(self, model_id: str, camera_id: str) -> bool: + """ + Remove a reference to a model from a camera. + + Returns: + True if model has no more references and can be unloaded + """ + with self.lock: + if model_id in self.references: + self.references[model_id].discard(camera_id) + return len(self.references[model_id]) == 0 + return True + + def get_model(self, model_id: str) -> Optional[Any]: + """Get a model from the registry.""" + with self.lock: + model_info = self.models.get(model_id) + return model_info["model"] if model_info else None + + def unregister_model(self, model_id: str) -> None: + """Remove a model from the registry.""" + with self.lock: + self.models.pop(model_id, None) + self.references.pop(model_id, None) + + def get_loaded_models(self) -> List[str]: + """Get list of loaded model IDs.""" + with self.lock: + return list(self.models.keys()) + + def get_reference_count(self, model_id: str) -> int: + """Get number of references to a model.""" + with self.lock: + return len(self.references.get(model_id, set())) + + def clear(self) -> None: + """Clear all models from registry.""" + with self.lock: + self.models.clear() + self.references.clear() + + +class ModelManager: + """ + Manages ML model loading, caching, and lifecycle. + + This class handles: + - Model downloading and caching + - Model loading with proper error handling + - Reference counting for model sharing + - Model cleanup and memory management + - Pipeline model tree management + """ + + def __init__(self, models_dir: str = MODELS_DIR): + """ + Initialize the model manager. + + Args: + models_dir: Directory to cache downloaded models + """ + self.models_dir = models_dir + self.registry = ModelRegistry() + self.models_lock = threading.Lock() + + # Camera to models mapping + self.camera_models: Dict[str, Dict[str, Any]] = {} # camera_id -> {model_id -> model_tree} + + # Pipeline loader will be injected + self.pipeline_loader = None + + # Create models directory if it doesn't exist + os.makedirs(self.models_dir, exist_ok=True) + + def set_pipeline_loader(self, pipeline_loader: Any) -> None: + """ + Set the pipeline loader instance. + + Args: + pipeline_loader: Pipeline loader to use for loading models + """ + self.pipeline_loader = pipeline_loader + + async def load_model( + self, + camera_id: str, + model_id: str, + model_url: str, + force_reload: bool = False + ) -> Any: + """ + Load a model for a specific camera. 
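Example (a sketch; assumes a pipeline loader has been injected and
uses a placeholder URL):

    manager = get_model_manager()
    manager.set_pipeline_loader(loader)
    tree = await manager.load_model(
        "camera-001", "model-42", "https://models.example.com/detect.mpta"
    )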
+ + Args: + camera_id: Camera identifier + model_id: Model identifier + model_url: URL or path to model file + force_reload: Force reload even if cached + + Returns: + Loaded model tree + + Raises: + ModelLoadError: If model loading fails + """ + if not self.pipeline_loader: + raise ModelLoadError("Pipeline loader not initialized") + + try: + # Check if model is already loaded for this camera + with self.models_lock: + if camera_id in self.camera_models and model_id in self.camera_models[camera_id]: + if not force_reload: + logger.info(f"Model {model_id} already loaded for camera {camera_id}") + return self.camera_models[camera_id][model_id] + + # Check if model is in registry + cached_model = self.registry.get_model(model_id) + if cached_model and not force_reload: + # Add reference and return cached model + self.registry.add_reference(model_id, camera_id) + with self.models_lock: + if camera_id not in self.camera_models: + self.camera_models[camera_id] = {} + self.camera_models[camera_id][model_id] = cached_model + logger.info(f"Using cached model {model_id} for camera {camera_id}") + return cached_model + + # Download or locate model file + model_path = await self._get_model_path(model_url, model_id) + + # Load model using pipeline loader + logger.info(f"Loading model {model_id} from {model_path}") + model_tree = await self.pipeline_loader.load_pipeline(model_path) + + # Register in registry + self.registry.register_model(model_id, model_tree, model_path) + self.registry.add_reference(model_id, camera_id) + + # Store in camera models + with self.models_lock: + if camera_id not in self.camera_models: + self.camera_models[camera_id] = {} + self.camera_models[camera_id][model_id] = model_tree + + logger.info(f"Successfully loaded model {model_id} for camera {camera_id}") + return model_tree + + except Exception as e: + logger.error(f"Failed to load model {model_id}: {e}") + traceback.print_exc() + raise ModelLoadError(f"Failed to load model {model_id}: {e}") + + async def _get_model_path(self, model_url: str, model_id: str) -> str: + """ + Get local path for a model, downloading if necessary. + + Args: + model_url: URL or local path to model + model_id: Model identifier + + Returns: + Local file path to model + """ + # Check if it's already a local path + if os.path.exists(model_url): + return model_url + + # Parse URL + parsed = urlparse(model_url) + + # Check if it's a file:// URL + if parsed.scheme == 'file': + return parsed.path + + # For HTTP/HTTPS URLs, download to cache + if parsed.scheme in ['http', 'https']: + # Generate cache filename + filename = os.path.basename(parsed.path) + if not filename: + filename = f"{model_id}.mpta" + + cache_path = os.path.join(self.models_dir, filename) + + # Check if already cached + if os.path.exists(cache_path): + logger.info(f"Using cached model file: {cache_path}") + return cache_path + + # Download model + logger.info(f"Downloading model from {model_url}") + await self._download_model(model_url, cache_path) + return cache_path + + # For other schemes or no scheme, assume local path + return model_url + + async def _download_model(self, url: str, destination: str) -> None: + """ + Download a model file from URL. 
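The file is first written to "<destination>.tmp" and renamed into
place, so the cache check in _get_model_path never sees a partially
downloaded model.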
+ + Args: + url: URL to download from + destination: Local path to save to + """ + import aiohttp + import aiofiles + + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + response.raise_for_status() + + # Get total size if available + total_size = response.headers.get('Content-Length') + if total_size: + total_size = int(total_size) + logger.info(f"Downloading {total_size / (1024*1024):.2f} MB") + + # Download to temporary file first + temp_path = f"{destination}.tmp" + downloaded = 0 + + async with aiofiles.open(temp_path, 'wb') as f: + async for chunk in response.content.iter_chunked(8192): + await f.write(chunk) + downloaded += len(chunk) + + # Log progress + if total_size and downloaded % (1024 * 1024) == 0: + progress = (downloaded / total_size) * 100 + logger.info(f"Download progress: {progress:.1f}%") + + # Move to final destination + os.rename(temp_path, destination) + logger.info(f"Model downloaded successfully to {destination}") + + except Exception as e: + # Clean up temporary file if exists + temp_path = f"{destination}.tmp" + if os.path.exists(temp_path): + os.remove(temp_path) + raise ModelLoadError(f"Failed to download model: {e}") + + def get_model(self, camera_id: str, model_id: str) -> Optional[Any]: + """ + Get a loaded model for a camera. + + Args: + camera_id: Camera identifier + model_id: Model identifier + + Returns: + Model tree if loaded, None otherwise + """ + with self.models_lock: + camera_models = self.camera_models.get(camera_id, {}) + return camera_models.get(model_id) + + def unload_models(self, camera_id: str) -> None: + """ + Unload all models for a camera. + + Args: + camera_id: Camera identifier + """ + with self.models_lock: + if camera_id not in self.camera_models: + return + + # Remove references for each model + for model_id in self.camera_models[camera_id]: + should_unload = self.registry.remove_reference(model_id, camera_id) + + if should_unload: + logger.info(f"Unloading model {model_id} (no more references)") + self.registry.unregister_model(model_id) + + # Clean up model if pipeline loader supports it + if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_model'): + try: + self.pipeline_loader.cleanup_model(model_id) + except Exception as e: + logger.error(f"Error cleaning up model {model_id}: {e}") + + # Remove camera entry + del self.camera_models[camera_id] + logger.info(f"Unloaded all models for camera {camera_id}") + + def cleanup_all_models(self) -> None: + """Clean up all loaded models.""" + logger.info("Cleaning up all loaded models") + + with self.models_lock: + # Get list of cameras to clean up + cameras = list(self.camera_models.keys()) + + # Unload models for each camera + for camera_id in cameras: + self.unload_models(camera_id) + + # Clear registry + self.registry.clear() + + # Clean up pipeline loader if it has cleanup + if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_all'): + try: + self.pipeline_loader.cleanup_all() + except Exception as e: + logger.error(f"Error in pipeline loader cleanup: {e}") + + logger.info("Model cleanup completed") + + def get_loaded_models(self) -> Dict[str, List[str]]: + """ + Get information about loaded models. 
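Example (illustrative output):

    >>> manager.get_loaded_models()
    {'model-42': ['camera-001', 'camera-002']}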
+ + Returns: + Dictionary mapping model IDs to list of camera IDs using them + """ + result = {} + + with self.models_lock: + for model_id in self.registry.get_loaded_models(): + cameras = [] + for camera_id, models in self.camera_models.items(): + if model_id in models: + cameras.append(camera_id) + result[model_id] = cameras + + return result + + def get_model_stats(self) -> Dict[str, Any]: + """ + Get statistics about loaded models. + + Returns: + Dictionary with model statistics + """ + with self.models_lock: + total_models = len(self.registry.get_loaded_models()) + total_cameras = len(self.camera_models) + + # Count total model instances + total_instances = sum( + len(models) for models in self.camera_models.values() + ) + + # Get cache size + cache_size = 0 + if os.path.exists(self.models_dir): + for filename in os.listdir(self.models_dir): + filepath = os.path.join(self.models_dir, filename) + if os.path.isfile(filepath): + cache_size += os.path.getsize(filepath) + + return { + "total_models": total_models, + "total_cameras": total_cameras, + "total_instances": total_instances, + "cache_size_mb": round(cache_size / (1024 * 1024), 2), + "models_dir": self.models_dir + } + + +# Global model manager instance +_model_manager = None + + +def get_model_manager() -> ModelManager: + """Get or create the global model manager instance.""" + global _model_manager + if _model_manager is None: + _model_manager = ModelManager() + return _model_manager + + +# Convenience functions for backward compatibility +def initialize_model_manager(models_dir: str = MODELS_DIR) -> ModelManager: + """Initialize the global model manager.""" + global _model_manager + _model_manager = ModelManager(models_dir) + return _model_manager \ No newline at end of file diff --git a/detector_worker/models/pipeline_loader.py b/detector_worker/models/pipeline_loader.py new file mode 100644 index 0000000..01d072a --- /dev/null +++ b/detector_worker/models/pipeline_loader.py @@ -0,0 +1,486 @@ +""" +Pipeline loader module. + +This module handles loading and parsing of MPTA (Machine Learning Pipeline Archive) +files, which contain model configurations and pipeline definitions. 
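A typical archive layout (illustrative; the model file name and
extension are assumptions):

    example.mpta  (a ZIP archive)
        pipeline.json   pipeline tree plus optional database/redis config
        detector.pt     model weights referenced by the pipeline nodes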
+""" +import os +import json +import logging +import zipfile +import tempfile +import shutil +from typing import Dict, Any, Optional, List, Tuple +from dataclasses import dataclass, field +from pathlib import Path + +from ..core.exceptions import ModelLoadError, PipelineError + +# Setup logging +logger = logging.getLogger("detector_worker.pipeline_loader") + + +@dataclass +class PipelineNode: + """Represents a node in the pipeline tree.""" + model_id: str + model_file: str + model_path: Optional[str] = None + model: Optional[Any] = None # Loaded model instance + + # Node configuration + multi_class: bool = False + expected_classes: List[str] = field(default_factory=list) + trigger_classes: List[str] = field(default_factory=list) + min_confidence: float = 0.5 + max_detections: Optional[int] = None + + # Cropping configuration + crop: bool = False + crop_class: Optional[str] = None + crop_expand_ratio: float = 1.0 + + # Actions configuration + actions: List[Dict[str, Any]] = field(default_factory=list) + parallel_actions: List[Dict[str, Any]] = field(default_factory=list) + + # Branch configuration + branches: List['PipelineNode'] = field(default_factory=list) + parallel: bool = False + + # Detection settings + yolo_settings: Dict[str, Any] = field(default_factory=dict) + track_classes: Optional[List[str]] = None + + # Metadata + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class PipelineConfig: + """Pipeline configuration from pipeline.json.""" + pipeline_id: str + version: str = "1.0" + description: str = "" + + # Database configuration + database_config: Optional[Dict[str, Any]] = None + + # Redis configuration + redis_config: Optional[Dict[str, Any]] = None + + # Global settings + global_settings: Dict[str, Any] = field(default_factory=dict) + + # Root pipeline node + root: Optional[PipelineNode] = None + + +class PipelineLoader: + """ + Loads and manages ML pipeline configurations. + + This class handles: + - MPTA file extraction and parsing + - Pipeline configuration validation + - Model file management + - Pipeline tree construction + - Resource cleanup + """ + + def __init__(self, temp_dir: Optional[str] = None): + """ + Initialize the pipeline loader. + + Args: + temp_dir: Temporary directory for extracting MPTA files + """ + self.temp_dir = temp_dir or tempfile.gettempdir() + self.extracted_paths: Dict[str, str] = {} # mpta_path -> extracted_dir + self.loaded_models: Dict[str, Any] = {} # model_path -> model_instance + + async def load_pipeline(self, mpta_path: str) -> PipelineNode: + """ + Load a pipeline from an MPTA file. 
+ + Args: + mpta_path: Path to MPTA file + + Returns: + Root pipeline node + + Raises: + ModelLoadError: If loading fails + """ + try: + # Extract MPTA if not already extracted + extracted_dir = await self._extract_mpta(mpta_path) + + # Load pipeline configuration + pipeline_json_path = os.path.join(extracted_dir, "pipeline.json") + if not os.path.exists(pipeline_json_path): + raise ModelLoadError(f"pipeline.json not found in {mpta_path}") + + with open(pipeline_json_path, 'r') as f: + config_data = json.load(f) + + # Parse pipeline configuration + pipeline_config = self._parse_pipeline_config(config_data, extracted_dir) + + # Validate pipeline + self._validate_pipeline(pipeline_config) + + # Load models for the pipeline + await self._load_pipeline_models(pipeline_config.root, extracted_dir) + + logger.info(f"Successfully loaded pipeline from {mpta_path}") + return pipeline_config.root + + except Exception as e: + logger.error(f"Failed to load pipeline from {mpta_path}: {e}") + raise ModelLoadError(f"Failed to load pipeline: {e}") + + async def _extract_mpta(self, mpta_path: str) -> str: + """ + Extract MPTA file to temporary directory. + + Args: + mpta_path: Path to MPTA file + + Returns: + Path to extracted directory + """ + # Check if already extracted + if mpta_path in self.extracted_paths: + extracted_dir = self.extracted_paths[mpta_path] + if os.path.exists(extracted_dir): + return extracted_dir + + # Create extraction directory + mpta_name = os.path.splitext(os.path.basename(mpta_path))[0] + extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}") + + # Extract MPTA + logger.info(f"Extracting MPTA file: {mpta_path}") + + try: + with zipfile.ZipFile(mpta_path, 'r') as zip_ref: + # Clean existing directory if exists + if os.path.exists(extracted_dir): + shutil.rmtree(extracted_dir) + + os.makedirs(extracted_dir) + zip_ref.extractall(extracted_dir) + + self.extracted_paths[mpta_path] = extracted_dir + logger.info(f"Extracted to: {extracted_dir}") + + return extracted_dir + + except Exception as e: + raise ModelLoadError(f"Failed to extract MPTA: {e}") + + def _parse_pipeline_config( + self, + config_data: Dict[str, Any], + base_dir: str + ) -> PipelineConfig: + """ + Parse pipeline configuration from JSON. + + Args: + config_data: Pipeline JSON data + base_dir: Base directory for model files + + Returns: + Parsed pipeline configuration + """ + # Create pipeline config + pipeline_config = PipelineConfig( + pipeline_id=config_data.get("pipelineId", "unknown"), + version=config_data.get("version", "1.0"), + description=config_data.get("description", "") + ) + + # Parse database config + if "database" in config_data: + pipeline_config.database_config = config_data["database"] + + # Parse Redis config + if "redis" in config_data: + pipeline_config.redis_config = config_data["redis"] + + # Parse global settings + if "globalSettings" in config_data: + pipeline_config.global_settings = config_data["globalSettings"] + + # Parse pipeline tree + if "pipeline" in config_data: + pipeline_config.root = self._parse_pipeline_node( + config_data["pipeline"], base_dir + ) + elif "root" in config_data: + pipeline_config.root = self._parse_pipeline_node( + config_data["root"], base_dir + ) + else: + raise PipelineError("No pipeline or root node found in configuration") + + return pipeline_config + + def _parse_pipeline_node( + self, + node_data: Dict[str, Any], + base_dir: str + ) -> PipelineNode: + """ + Parse a pipeline node from configuration. 
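+
+        Example node_data, as it appears in pipeline.json (illustrative
+        values):
+
+            {
+                "modelId": "car_detect",
+                "modelFile": "car_detect.pt",
+                "minConfidence": 0.6,
+                "crop": true,
+                "cropClass": "car",
+                "branches": []
+            }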
+ + Args: + node_data: Node configuration data + base_dir: Base directory for model files + + Returns: + Parsed pipeline node + """ + # Create node + node = PipelineNode( + model_id=node_data.get("modelId", ""), + model_file=node_data.get("modelFile", "") + ) + + # Set model path + if node.model_file: + node.model_path = os.path.join(base_dir, node.model_file) + + # Parse configuration + node.multi_class = node_data.get("multiClass", False) + node.expected_classes = node_data.get("expectedClasses", []) + node.trigger_classes = node_data.get("triggerClasses", []) + node.min_confidence = node_data.get("minConfidence", 0.5) + node.max_detections = node_data.get("maxDetections") + + # Parse cropping + node.crop = node_data.get("crop", False) + node.crop_class = node_data.get("cropClass") + node.crop_expand_ratio = node_data.get("cropExpandRatio", 1.0) + + # Parse actions + node.actions = node_data.get("actions", []) + node.parallel_actions = node_data.get("parallelActions", []) + + # Parse YOLO settings + if "yoloSettings" in node_data: + node.yolo_settings = node_data["yoloSettings"] + elif "detectionSettings" in node_data: + node.yolo_settings = node_data["detectionSettings"] + + # Parse tracking + node.track_classes = node_data.get("trackClasses") + + # Parse metadata + node.metadata = node_data.get("metadata", {}) + + # Parse branches + branches_data = node_data.get("branches", []) + node.parallel = node_data.get("parallel", False) + + for branch_data in branches_data: + branch_node = self._parse_pipeline_node(branch_data, base_dir) + node.branches.append(branch_node) + + return node + + def _validate_pipeline(self, pipeline_config: PipelineConfig) -> None: + """ + Validate pipeline configuration. + + Args: + pipeline_config: Pipeline configuration to validate + + Raises: + PipelineError: If validation fails + """ + if not pipeline_config.root: + raise PipelineError("Pipeline has no root node") + + # Validate root node + self._validate_node(pipeline_config.root) + + def _validate_node(self, node: PipelineNode) -> None: + """ + Validate a pipeline node. + + Args: + node: Node to validate + + Raises: + PipelineError: If validation fails + """ + # Check required fields + if not node.model_id: + raise PipelineError("Node missing modelId") + + if not node.model_file and not node.model: + raise PipelineError(f"Node {node.model_id} missing modelFile") + + # Validate model path exists + if node.model_path and not os.path.exists(node.model_path): + raise PipelineError(f"Model file not found: {node.model_path}") + + # Validate cropping configuration + if node.crop and not node.crop_class: + raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass") + + # Validate confidence + if not 0 <= node.min_confidence <= 1: + raise PipelineError(f"Invalid minConfidence: {node.min_confidence}") + + # Validate branches + for branch in node.branches: + self._validate_node(branch) + + async def _load_pipeline_models( + self, + node: PipelineNode, + base_dir: str + ) -> None: + """ + Load models for a pipeline node and its branches. + + Args: + node: Pipeline node + base_dir: Base directory for models + """ + # Load model for this node if path is specified + if node.model_path: + node.model = await self._load_model(node.model_path, node.model_id) + + # Load models for branches + for branch in node.branches: + await self._load_pipeline_models(branch, base_dir) + + async def _load_model(self, model_path: str, model_id: str) -> Any: + """ + Load a single model file. 
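+
+        Instances are cached by model_path, so repeated calls for the same
+        file return the same YOLO object (see cleanup_model()).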
+ + Args: + model_path: Path to model file + model_id: Model identifier + + Returns: + Loaded model instance + """ + # Check if already loaded + if model_path in self.loaded_models: + logger.info(f"Using cached model: {model_id}") + return self.loaded_models[model_path] + + try: + # Import here to avoid circular dependency + from ultralytics import YOLO + + logger.info(f"Loading model: {model_id} from {model_path}") + + # Load YOLO model + model = YOLO(model_path) + + # Cache the model + self.loaded_models[model_path] = model + + return model + + except Exception as e: + raise ModelLoadError(f"Failed to load model {model_id}: {e}") + + def cleanup_model(self, model_id: str) -> None: + """ + Clean up resources for a specific model. + + Args: + model_id: Model identifier to clean up + """ + # Clean up loaded models + models_to_remove = [] + for path, model in self.loaded_models.items(): + if model_id in path: + models_to_remove.append(path) + + for path in models_to_remove: + self.loaded_models.pop(path, None) + logger.info(f"Cleaned up model: {path}") + + def cleanup_all(self) -> None: + """Clean up all resources.""" + # Clear loaded models + self.loaded_models.clear() + + # Clean up extracted directories + for mpta_path, extracted_dir in self.extracted_paths.items(): + if os.path.exists(extracted_dir): + try: + shutil.rmtree(extracted_dir) + logger.info(f"Cleaned up extracted directory: {extracted_dir}") + except Exception as e: + logger.error(f"Failed to clean up {extracted_dir}: {e}") + + self.extracted_paths.clear() + + def get_node_info(self, node: PipelineNode, level: int = 0) -> str: + """ + Get formatted information about a pipeline node. + + Args: + node: Pipeline node + level: Indentation level + + Returns: + Formatted node information + """ + indent = " " * level + info = [] + + info.append(f"{indent}Model: {node.model_id}") + info.append(f"{indent} File: {node.model_file}") + info.append(f"{indent} Multi-class: {node.multi_class}") + + if node.expected_classes: + info.append(f"{indent} Expected: {', '.join(node.expected_classes)}") + if node.trigger_classes: + info.append(f"{indent} Triggers: {', '.join(node.trigger_classes)}") + + info.append(f"{indent} Confidence: {node.min_confidence}") + + if node.crop: + info.append(f"{indent} Crop: {node.crop_class} (ratio: {node.crop_expand_ratio})") + + if node.actions: + info.append(f"{indent} Actions: {len(node.actions)}") + if node.parallel_actions: + info.append(f"{indent} Parallel Actions: {len(node.parallel_actions)}") + + if node.branches: + info.append(f"{indent} Branches ({len(node.branches)}):") + for branch in node.branches: + info.append(self.get_node_info(branch, level + 2)) + + return "\n".join(info) + + +# Global pipeline loader instance +_pipeline_loader = None + + +def get_pipeline_loader(temp_dir: Optional[str] = None) -> PipelineLoader: + """Get or create the global pipeline loader instance.""" + global _pipeline_loader + if _pipeline_loader is None: + _pipeline_loader = PipelineLoader(temp_dir) + return _pipeline_loader + + +# Convenience functions +async def load_pipeline_from_mpta(mpta_path: str) -> PipelineNode: + """Load a pipeline from an MPTA file.""" + loader = get_pipeline_loader() + return await loader.load_pipeline(mpta_path) \ No newline at end of file diff --git a/detector_worker/utils/system_monitor.py b/detector_worker/utils/system_monitor.py new file mode 100644 index 0000000..78f020f --- /dev/null +++ b/detector_worker/utils/system_monitor.py @@ -0,0 +1,379 @@ +""" +System monitoring utilities. 
+
+This module provides functions to monitor system resources including
+CPU, memory, and GPU usage.
+"""
+import logging
+import time
+import psutil
+from typing import Dict, Any, Optional, List
+
+try:
+    import pynvml
+    NVIDIA_AVAILABLE = True
+except ImportError:
+    NVIDIA_AVAILABLE = False
+
+# Setup logging
+logger = logging.getLogger("detector_worker.system_monitor")
+
+# Global initialization flag
+_nvidia_initialized = False
+
+# Process start time
+_start_time = time.time()
+
+
+def initialize_nvidia_monitoring() -> bool:
+    """
+    Initialize NVIDIA GPU monitoring.
+
+    Returns:
+        True if initialization successful, False otherwise
+    """
+    global _nvidia_initialized
+
+    if not NVIDIA_AVAILABLE:
+        return False
+
+    if _nvidia_initialized:
+        return True
+
+    try:
+        pynvml.nvmlInit()
+        _nvidia_initialized = True
+        logger.info("NVIDIA GPU monitoring initialized")
+        return True
+    except Exception as e:
+        logger.warning(f"Failed to initialize NVIDIA monitoring: {e}")
+        return False
+
+
+def get_gpu_info() -> List[Dict[str, Any]]:
+    """
+    Get information about available GPUs.
+
+    Returns:
+        List of GPU information dictionaries
+    """
+    if not NVIDIA_AVAILABLE or not _nvidia_initialized:
+        return []
+
+    gpu_info = []
+
+    try:
+        device_count = pynvml.nvmlDeviceGetCount()
+
+        for i in range(device_count):
+            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+
+            # Get GPU name (older pynvml returns bytes, newer returns str)
+            name = pynvml.nvmlDeviceGetName(handle)
+            if isinstance(name, bytes):
+                name = name.decode('utf-8')
+
+            # Get memory info
+            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+            total_memory = mem_info.total / (1024 ** 3)  # Convert to GB
+            used_memory = mem_info.used / (1024 ** 3)
+            free_memory = mem_info.free / (1024 ** 3)
+
+            # Get utilization
+            try:
+                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
+                gpu_util = util.gpu
+                memory_util = util.memory
+            except Exception:
+                gpu_util = 0
+                memory_util = 0
+
+            # Get temperature
+            try:
+                temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
+            except Exception:
+                temp = 0
+
+            gpu_info.append({
+                "index": i,
+                "name": name,
+                "gpu_utilization": gpu_util,
+                "memory_utilization": memory_util,
+                "total_memory_gb": round(total_memory, 2),
+                "used_memory_gb": round(used_memory, 2),
+                "free_memory_gb": round(free_memory, 2),
+                "temperature": temp
+            })
+
+    except Exception as e:
+        logger.error(f"Error getting GPU info: {e}")
+
+    return gpu_info
+
+
+def get_gpu_metrics() -> Dict[str, float]:
+    """
+    Get current metrics for the first GPU.
+
+    Returns:
+        Dictionary with GPU utilization and memory usage
+    """
+    if not NVIDIA_AVAILABLE or not _nvidia_initialized:
+        return {"gpu_percent": 0.0, "gpu_memory_percent": 0.0}
+
+    try:
+        # Get first GPU metrics
+        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+
+        # Get utilization
+        util = pynvml.nvmlDeviceGetUtilizationRates(handle)
+
+        # Get memory info
+        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        memory_percent = (mem_info.used / mem_info.total) * 100
+
+        return {
+            "gpu_percent": float(util.gpu),
+            "gpu_memory_percent": float(memory_percent)
+        }
+
+    except Exception as e:
+        logger.debug(f"Error getting GPU metrics: {e}")
+        return {"gpu_percent": 0.0, "gpu_memory_percent": 0.0}
+
+
+def get_cpu_metrics() -> Dict[str, Any]:
+    """
+    Get current CPU metrics.
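+
+    Example (keys are fixed; values vary by machine):
+
+        >>> sorted(get_cpu_metrics())
+        ['cpu_count', 'cpu_per_core', 'cpu_percent']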
+
+    Returns:
+        Dictionary with overall and per-core CPU usage
+    """
+    try:
+        # One blocking sample covers both views: per-core percentages
+        # directly, and the overall figure as their mean (two separate
+        # interval=0.1 calls would block for 0.2s in total)
+        cpu_per_core = psutil.cpu_percent(interval=0.1, percpu=True)
+        cpu_percent = sum(cpu_per_core) / len(cpu_per_core) if cpu_per_core else 0.0
+
+        return {
+            "cpu_percent": cpu_percent,
+            "cpu_per_core": cpu_per_core,
+            "cpu_count": psutil.cpu_count() or 0
+        }
+    except Exception as e:
+        logger.error(f"Error getting CPU metrics: {e}")
+        return {"cpu_percent": 0.0, "cpu_per_core": [], "cpu_count": 0}
+
+
+def get_memory_metrics() -> Dict[str, Any]:
+    """
+    Get current memory metrics.
+
+    Returns:
+        Dictionary with memory usage information
+    """
+    try:
+        # Get virtual memory info
+        virtual_mem = psutil.virtual_memory()
+
+        # Get swap memory info
+        swap_mem = psutil.swap_memory()
+
+        return {
+            "memory_percent": virtual_mem.percent,
+            "memory_used_gb": round(virtual_mem.used / (1024 ** 3), 2),
+            "memory_available_gb": round(virtual_mem.available / (1024 ** 3), 2),
+            "memory_total_gb": round(virtual_mem.total / (1024 ** 3), 2),
+            "swap_percent": swap_mem.percent,
+            "swap_used_gb": round(swap_mem.used / (1024 ** 3), 2)
+        }
+    except Exception as e:
+        logger.error(f"Error getting memory metrics: {e}")
+        return {
+            "memory_percent": 0.0,
+            "memory_used_gb": 0.0,
+            "memory_available_gb": 0.0,
+            "memory_total_gb": 0.0,
+            "swap_percent": 0.0,
+            "swap_used_gb": 0.0
+        }
+
+
+def get_disk_metrics(path: str = "/") -> Dict[str, Any]:
+    """
+    Get disk usage metrics for specified path.
+
+    Args:
+        path: Path to check disk usage for
+
+    Returns:
+        Dictionary with disk usage information
+    """
+    try:
+        disk_usage = psutil.disk_usage(path)
+
+        return {
+            "disk_percent": disk_usage.percent,
+            "disk_used_gb": round(disk_usage.used / (1024 ** 3), 2),
+            "disk_free_gb": round(disk_usage.free / (1024 ** 3), 2),
+            "disk_total_gb": round(disk_usage.total / (1024 ** 3), 2)
+        }
+    except Exception as e:
+        logger.error(f"Error getting disk metrics: {e}")
+        return {
+            "disk_percent": 0.0,
+            "disk_used_gb": 0.0,
+            "disk_free_gb": 0.0,
+            "disk_total_gb": 0.0
+        }
+
+
+def get_process_metrics() -> Dict[str, Any]:
+    """
+    Get current process metrics.
+
+    Returns:
+        Dictionary with process-specific metrics
+    """
+    try:
+        process = psutil.Process()
+
+        # Get process info
+        with process.oneshot():
+            cpu_percent = process.cpu_percent()
+            memory_info = process.memory_info()
+            memory_percent = process.memory_percent()
+            num_threads = process.num_threads()
+
+        # Count files the process currently has open (open_files() lists
+        # regular files only, not every file descriptor)
+        try:
+            num_fds = len(process.open_files())
+        except Exception:
+            num_fds = 0
+
+        return {
+            "process_cpu_percent": cpu_percent,
+            "process_memory_mb": round(memory_info.rss / (1024 ** 2), 2),
+            "process_memory_percent": round(memory_percent, 2),
+            "process_threads": num_threads,
+            "process_open_files": num_fds
+        }
+    except Exception as e:
+        logger.error(f"Error getting process metrics: {e}")
+        return {
+            "process_cpu_percent": 0.0,
+            "process_memory_mb": 0.0,
+            "process_memory_percent": 0.0,
+            "process_threads": 0,
+            "process_open_files": 0
+        }
+
+
+def get_network_metrics() -> Dict[str, Any]:
+    """
+    Get network I/O metrics.
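+
+    Note: counters come from psutil.net_io_counters() and are cumulative
+    since boot; compute deltas between calls to derive rates.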
+
+    Returns:
+        Dictionary with network statistics
+    """
+    try:
+        net_io = psutil.net_io_counters()
+
+        return {
+            "bytes_sent": net_io.bytes_sent,
+            "bytes_recv": net_io.bytes_recv,
+            "packets_sent": net_io.packets_sent,
+            "packets_recv": net_io.packets_recv,
+            "errin": net_io.errin,
+            "errout": net_io.errout,
+            "dropin": net_io.dropin,
+            "dropout": net_io.dropout
+        }
+    except Exception as e:
+        logger.error(f"Error getting network metrics: {e}")
+        return {
+            "bytes_sent": 0,
+            "bytes_recv": 0,
+            "packets_sent": 0,
+            "packets_recv": 0,
+            "errin": 0,
+            "errout": 0,
+            "dropin": 0,
+            "dropout": 0
+        }
+
+
+def get_system_metrics() -> Dict[str, Any]:
+    """
+    Get an aggregated snapshot of system metrics.
+
+    Returns:
+        Dictionary combining timestamp, uptime, CPU, memory, GPU, and
+        process metrics. Disk and network metrics are available
+        separately via get_disk_metrics() and get_network_metrics().
+    """
+    # Initialize GPU monitoring if not done
+    if NVIDIA_AVAILABLE and not _nvidia_initialized:
+        initialize_nvidia_monitoring()
+
+    metrics = {
+        "timestamp": time.time(),
+        "uptime": time.time() - _start_time,
+        "start_time": _start_time
+    }
+
+    # Get CPU metrics
+    cpu_metrics = get_cpu_metrics()
+    metrics.update(cpu_metrics)
+
+    # Get memory metrics
+    memory_metrics = get_memory_metrics()
+    metrics.update(memory_metrics)
+
+    # Get GPU metrics
+    gpu_metrics = get_gpu_metrics()
+    metrics.update(gpu_metrics)
+
+    # Get process metrics
+    process_metrics = get_process_metrics()
+    metrics.update(process_metrics)
+
+    return metrics
+
+
+def get_resource_summary() -> str:
+    """
+    Get a formatted summary of system resources.
+
+    Returns:
+        Formatted string with resource summary
+    """
+    metrics = get_system_metrics()
+
+    summary = []
+    summary.append(f"CPU: {metrics['cpu_percent']:.1f}%")
+    summary.append(f"Memory: {metrics['memory_percent']:.1f}% ({metrics['memory_used_gb']:.1f}GB/{metrics['memory_total_gb']:.1f}GB)")
+
+    if metrics['gpu_percent'] > 0:
+        summary.append(f"GPU: {metrics['gpu_percent']:.1f}%")
+        summary.append(f"GPU Memory: {metrics['gpu_memory_percent']:.1f}%")
+
+    summary.append(f"Process Memory: {metrics['process_memory_mb']:.1f}MB")
+    summary.append(f"Threads: {metrics['process_threads']}")
+
+    return " | ".join(summary)
+
+
+def cleanup_nvidia_monitoring():
+    """Clean up NVIDIA monitoring resources."""
+    global _nvidia_initialized
+
+    if NVIDIA_AVAILABLE and _nvidia_initialized:
+        try:
+            pynvml.nvmlShutdown()
+            _nvidia_initialized = False
+            logger.info("NVIDIA GPU monitoring cleaned up")
+        except Exception as e:
+            logger.error(f"Error cleaning up NVIDIA monitoring: {e}")
+
+
+# Initialize on import
+if NVIDIA_AVAILABLE:
+    initialize_nvidia_monitoring()
\ No newline at end of file
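
Usage sketch for the new modules above (assumptions: the detector_worker package is importable and a "models/example.mpta" archive exists; both names are illustrative). It exercises only functions introduced in this diff: the pipeline-loader convenience API and the system-monitor summary.

    import asyncio

    from detector_worker.models.pipeline_loader import (
        get_pipeline_loader,
        load_pipeline_from_mpta,
    )
    from detector_worker.utils.system_monitor import get_resource_summary


    async def main() -> None:
        # Load the pipeline tree from an MPTA archive and print its layout
        root = await load_pipeline_from_mpta("models/example.mpta")
        print(get_pipeline_loader().get_node_info(root))

        # One-line CPU / memory / GPU / process report
        print(get_resource_summary())


    if __name__ == "__main__":
        asyncio.run(main())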