Refactor: Phase 4: Communication Layer

This commit is contained in:
ziesorx 2025-09-12 15:26:31 +07:00
parent cdeaaf4a4f
commit 54f21672aa
6 changed files with 2876 additions and 0 deletions

View file

@ -0,0 +1,509 @@
"""
Message processor module.
This module handles validation, parsing, and processing of WebSocket messages.
It provides a clean separation between message handling and business logic.
"""
import json
import logging
import time
from typing import Dict, Any, Optional, Callable, Tuple
from enum import Enum
from dataclasses import dataclass, field
from ..core.exceptions import ValidationError, MessageProcessingError
# Module-level logger shared by every function and class in this module.
# The explicit dotted name places it under the "detector_worker" logger hierarchy.
logger = logging.getLogger("detector_worker.message_processor")
class MessageType(Enum):
    """Closed set of ``type`` strings understood on the WebSocket channel.

    Each member's value is the exact string carried in a message's ``type``
    field, so ``MessageType(raw_type)`` converts a decoded message to a member
    and raises ``ValueError`` for anything not listed here.
    """

    # Types that have a payload validator registered in MessageProcessor.
    SUBSCRIBE = "subscribe"
    UNSUBSCRIBE = "unsubscribe"
    SET_SESSION_ID = "setSessionId"
    PATCH_SESSION = "patchSession"
    SET_PROGRESSION_STAGE = "setProgressionStage"

    # Accepted by the parser but validated as a plain payload.
    REQUEST_STATE = "requestState"

    # Types used by the ``create_*`` message builders.
    IMAGE_DETECTION = "imageDetection"
    STATE_REPORT = "stateReport"
    ACK = "ack"
    ERROR = "error"
class ProgressionStage(Enum):
    """Known values for the ``progressionStage`` field of a setProgressionStage message.

    The values are the wire strings; note that they are not spelled uniformly
    ("car_waitpayment" has no underscore before "payment").
    """

    WELCOME = "welcome"
    CAR_FUELING = "car_fueling"
    CAR_WAIT_PAYMENT = "car_waitpayment"
    CAR_WAIT_STAFF = "car_wait_staff"
@dataclass
class SubscribePayload:
    """Typed representation of the ``payload`` object of a subscribe message.

    Field names are the snake_case counterparts of the camelCase wire keys.
    Keys of the incoming payload that do not map to a declared field are
    preserved untouched in ``extra_params``.
    """

    subscription_identifier: str
    model_id: int
    model_name: str
    model_url: str
    rtsp_url: Optional[str] = None
    snapshot_url: Optional[str] = None
    snapshot_interval: Optional[int] = None
    crop_x1: Optional[int] = None
    crop_y1: Optional[int] = None
    crop_x2: Optional[int] = None
    crop_y2: Optional[int] = None
    # Unrecognised payload keys, kept so to_dict() can emit them again.
    extra_params: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SubscribePayload":
        """Build a payload object from a camelCase wire dictionary.

        Args:
            data: Decoded ``payload`` object of a subscribe message.

        Returns:
            A new instance; wire keys absent from ``data`` become ``None``.
        """
        # Single source of truth: wire key -> dataclass attribute.
        wire_to_attr = {
            "subscriptionIdentifier": "subscription_identifier",
            "modelId": "model_id",
            "modelName": "model_name",
            "modelUrl": "model_url",
            "rtspUrl": "rtsp_url",
            "snapshotUrl": "snapshot_url",
            "snapshotInterval": "snapshot_interval",
            "cropX1": "crop_x1",
            "cropY1": "crop_y1",
            "cropX2": "crop_x2",
            "cropY2": "crop_y2",
        }
        init_kwargs = {attr: data.get(wire) for wire, attr in wire_to_attr.items()}

        # Everything that is not a known wire key travels along as-is.
        leftovers = {}
        for key, value in data.items():
            if key not in wire_to_attr:
                leftovers[key] = value

        return cls(extra_params=leftovers, **init_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Serialise back to the camelCase dictionary used for stream configuration.

        Returns:
            Dictionary with the four mandatory keys, then every optional key
            that is set, then the extra parameters.
        """
        out: Dict[str, Any] = {
            "subscriptionIdentifier": self.subscription_identifier,
            "modelId": self.model_id,
            "modelName": self.model_name,
            "modelUrl": self.model_url,
        }

        # URLs are emitted only when truthy (None and "" are both omitted).
        for wire_key, url in (
            ("rtspUrl", self.rtsp_url),
            ("snapshotUrl", self.snapshot_url),
        ):
            if url:
                out[wire_key] = url

        # Numeric options are emitted whenever they are set, including 0.
        for wire_key, number in (
            ("snapshotInterval", self.snapshot_interval),
            ("cropX1", self.crop_x1),
            ("cropY1", self.crop_y1),
            ("cropX2", self.crop_x2),
            ("cropY2", self.crop_y2),
        ):
            if number is not None:
                out[wire_key] = number

        out.update(self.extra_params)
        return out
@dataclass
class SessionPayload:
    """Validated content of a setSessionId or patchSession message."""

    # Display the session belongs to (patchSession may pass an empty string).
    display_identifier: str
    # Session identifier supplied by the sender.
    session_id: str
    # Patch body; only populated for patchSession messages.
    data: Optional[Dict[str, Any]] = None
@dataclass
class ProgressionPayload:
    """Validated content of a setProgressionStage message."""

    # Display whose progression stage is being set.
    display_identifier: str
    # Raw stage string as received; not restricted to ProgressionStage values.
    progression_stage: str
class MessageProcessor:
    """
    Processes and validates WebSocket messages.

    This class handles:
    - Message validation and parsing
    - Payload extraction and type conversion
    - Error handling and response generation
    - Message routing to appropriate handlers
    """

    def __init__(self):
        """Initialize the message processor and register per-type validators."""
        # Message types missing from this table are passed through unvalidated.
        self.validators: Dict[MessageType, Callable] = {
            MessageType.SUBSCRIBE: self._validate_subscribe,
            MessageType.UNSUBSCRIBE: self._validate_unsubscribe,
            MessageType.SET_SESSION_ID: self._validate_set_session,
            MessageType.PATCH_SESSION: self._validate_patch_session,
            MessageType.SET_PROGRESSION_STAGE: self._validate_progression_stage
        }

    @staticmethod
    def _get_payload(data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return the ``payload`` object of a message as a dictionary.

        A missing or ``null`` payload is treated as empty so the caller reports
        the specific missing field; any other non-object payload is rejected
        here instead of failing later with ``AttributeError``.

        Raises:
            ValidationError: If ``payload`` is present but not a JSON object
        """
        payload = data.get("payload")
        if payload is None:
            return {}
        if not isinstance(payload, dict):
            raise ValidationError("'payload' must be a JSON object")
        return payload

    def parse_message(self, raw_message: str) -> Tuple[MessageType, Dict[str, Any]]:
        """
        Parse a raw WebSocket message.

        Args:
            raw_message: Raw message string

        Returns:
            Tuple of (MessageType, parsed_data)

        Raises:
            MessageProcessingError: If message is invalid
        """
        try:
            data = json.loads(raw_message)
        except json.JSONDecodeError as e:
            # Chain the decoder error so the offending position stays visible.
            raise MessageProcessingError(f"Invalid JSON: {e}") from e

        # Validate message structure
        if not isinstance(data, dict):
            raise MessageProcessingError("Message must be a JSON object")

        # Extract message type
        msg_type_str = data.get("type")
        if not msg_type_str:
            raise MessageProcessingError("Missing 'type' field in message")

        # Convert to MessageType enum
        try:
            msg_type = MessageType(msg_type_str)
        except ValueError as e:
            raise MessageProcessingError(f"Unknown message type: {msg_type_str}") from e

        return msg_type, data

    def validate_message(self, msg_type: MessageType, data: Dict[str, Any]) -> Any:
        """
        Validate message payload based on type.

        Args:
            msg_type: Message type
            data: Message data

        Returns:
            Validated payload object

        Raises:
            ValidationError: If validation fails
        """
        validator = self.validators.get(msg_type)
        if not validator:
            # No validation needed for some message types
            return data.get("payload", {})
        return validator(data)

    def _validate_subscribe(self, data: Dict[str, Any]) -> SubscribePayload:
        """
        Validate subscribe message.

        Raises:
            ValidationError: If a required field is missing, no stream source is
                given, snapshotInterval is missing for a snapshot source, or the
                crop rectangle is incomplete, non-numeric or empty
        """
        payload = self._get_payload(data)

        # Required fields (falsy values count as missing)
        required = ["subscriptionIdentifier", "modelId", "modelName", "modelUrl"]
        missing = [name for name in required if not payload.get(name)]
        if missing:
            raise ValidationError(f"Missing required fields: {', '.join(missing)}")

        # Must have either rtspUrl or snapshotUrl
        if not payload.get("rtspUrl") and not payload.get("snapshotUrl"):
            raise ValidationError("Must provide either rtspUrl or snapshotUrl")

        # Validate snapshot interval if snapshot URL is provided
        if payload.get("snapshotUrl") and not payload.get("snapshotInterval"):
            raise ValidationError("snapshotInterval is required when using snapshotUrl")

        # Validate crop coordinates
        crop_coords = ["cropX1", "cropY1", "cropX2", "cropY2"]
        crop_values = [payload.get(coord) for coord in crop_coords]
        if any(v is not None for v in crop_values):
            # If any crop coordinate is set, all must be set
            if any(v is None for v in crop_values):
                raise ValidationError("All crop coordinates must be specified if any are set")
            # Reject strings/bools/etc. up front: comparing them below would
            # either raise TypeError or silently compare lexicographically.
            if not all(
                isinstance(v, (int, float)) and not isinstance(v, bool)
                for v in crop_values
            ):
                raise ValidationError("Crop coordinates must be numeric")
            x1, y1, x2, y2 = crop_values
            if x2 <= x1 or y2 <= y1:
                raise ValidationError("Invalid crop coordinates: x2 must be > x1, y2 must be > y1")

        return SubscribePayload.from_dict(payload)

    def _validate_unsubscribe(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate unsubscribe message.

        Returns:
            The payload dictionary unchanged

        Raises:
            ValidationError: If subscriptionIdentifier is missing
        """
        payload = self._get_payload(data)
        if not payload.get("subscriptionIdentifier"):
            raise ValidationError("Missing required field: subscriptionIdentifier")
        return payload

    def _validate_set_session(self, data: Dict[str, Any]) -> SessionPayload:
        """
        Validate setSessionId message.

        Raises:
            ValidationError: If displayIdentifier or sessionId is missing
        """
        payload = self._get_payload(data)
        display_id = payload.get("displayIdentifier")
        session_id = payload.get("sessionId")
        if not display_id:
            raise ValidationError("Missing required field: displayIdentifier")
        if not session_id:
            raise ValidationError("Missing required field: sessionId")
        return SessionPayload(display_id, session_id)

    def _validate_patch_session(self, data: Dict[str, Any]) -> SessionPayload:
        """
        Validate patchSession message.

        Raises:
            ValidationError: If sessionId is missing
        """
        payload = self._get_payload(data)
        session_id = payload.get("sessionId")
        patch_data = payload.get("data")
        if not session_id:
            raise ValidationError("Missing required field: sessionId")
        # Display identifier is optional for patch
        display_id = payload.get("displayIdentifier", "")
        return SessionPayload(display_id, session_id, patch_data)

    def _validate_progression_stage(self, data: Dict[str, Any]) -> ProgressionPayload:
        """
        Validate setProgressionStage message.

        An unrecognised stage is logged but still accepted.

        Raises:
            ValidationError: If displayIdentifier or progressionStage is missing
        """
        payload = self._get_payload(data)
        display_id = payload.get("displayIdentifier")
        stage = payload.get("progressionStage")
        if not display_id:
            raise ValidationError("Missing required field: displayIdentifier")
        if not stage:
            raise ValidationError("Missing required field: progressionStage")

        # Validate stage value (warning only, not an error)
        valid_stages = [s.value for s in ProgressionStage]
        if stage not in valid_stages:
            logger.warning("Unknown progression stage: %s", stage)

        return ProgressionPayload(display_id, stage)

    def create_ack_response(
        self,
        request_id: Optional[str] = None,
        message: str = "Success",
        data: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Create an acknowledgment response.

        Args:
            request_id: Request ID to acknowledge; a millisecond timestamp
                string is generated when omitted
            message: Success message
            data: Additional response data, merged into the ``data`` object

        Returns:
            ACK response dictionary
        """
        response = {
            "type": MessageType.ACK.value,
            "requestId": request_id or str(int(time.time() * 1000)),
            "code": "200",
            "data": {
                "message": message
            }
        }
        if data:
            response["data"].update(data)
        return response

    def create_error_response(
        self,
        request_id: Optional[str] = None,
        code: str = "400",
        message: str = "Error",
        details: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Create an error response.

        Args:
            request_id: Request ID that caused error; a millisecond timestamp
                string is generated when omitted
            code: Error code
            message: Error message
            details: Additional error details, stored under ``data.details``

        Returns:
            Error response dictionary
        """
        response = {
            "type": MessageType.ERROR.value,
            "requestId": request_id or str(int(time.time() * 1000)),
            "code": code,
            "data": {
                "message": message
            }
        }
        if details:
            response["data"]["details"] = details
        return response

    def create_detection_message(
        self,
        subscription_id: str,
        detection_result: Optional[Dict[str, Any]],
        model_id: int,
        model_name: str,
        timestamp: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Create a detection result message.

        Args:
            subscription_id: Subscription identifier
            detection_result: Detection result data (None for disconnection)
            model_id: Model ID
            model_name: Model name
            timestamp: Optional timestamp (defaults to current UTC)

        Returns:
            Detection message dictionary
        """
        if timestamp is None:
            timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        return {
            "type": MessageType.IMAGE_DETECTION.value,
            "subscriptionIdentifier": subscription_id,
            "timestamp": timestamp,
            "data": {
                "detection": detection_result,
                "modelId": model_id,
                "modelName": model_name
            }
        }

    def create_state_report(
        self,
        active_streams: int,
        loaded_models: int,
        cpu_usage: float = 0.0,
        memory_usage: float = 0.0,
        gpu_usage: float = 0.0,
        gpu_memory: float = 0.0,
        uptime: float = 0.0,
        additional_metrics: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Create a state report message.

        Args:
            active_streams: Number of active streams
            loaded_models: Number of loaded models
            cpu_usage: CPU usage percentage
            memory_usage: Memory usage percentage
            gpu_usage: GPU usage percentage
            gpu_memory: GPU memory usage percentage
            uptime: System uptime in seconds
            additional_metrics: Additional metrics merged into ``data``

        Returns:
            State report dictionary
        """
        report = {
            "type": MessageType.STATE_REPORT.value,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "data": {
                "activeStreams": active_streams,
                "loadedModels": loaded_models,
                "cpuUsage": cpu_usage,
                "memoryUsage": memory_usage,
                "gpuUsage": gpu_usage,
                "gpuMemory": gpu_memory,
                "uptime": uptime
            }
        }
        if additional_metrics:
            report["data"].update(additional_metrics)
        return report

    def extract_camera_id(self, subscription_id: str) -> Tuple[Optional[str], str]:
        """
        Extract display ID and camera ID from subscription identifier.

        Args:
            subscription_id: Subscription identifier (format: "display_id;camera_id")

        Returns:
            Tuple of (display_id, camera_id); display_id is None when the
            identifier contains no ";" separator
        """
        parts = subscription_id.split(";")
        if len(parts) >= 2:
            return parts[0], parts[1]
        return None, subscription_id
# Singleton instance for backward compatibility.
# Created at import time and shared by all module-level convenience functions below.
_message_processor = MessageProcessor()
# Convenience functions
def parse_websocket_message(raw_message: str) -> Tuple[MessageType, Dict[str, Any]]:
    """Parse ``raw_message`` using the shared module-level processor.

    Returns:
        Whatever ``MessageProcessor.parse_message`` returns for the input.
    """
    parsed = _message_processor.parse_message(raw_message)
    return parsed
def validate_message_payload(msg_type: MessageType, data: Dict[str, Any]) -> Any:
    """Validate ``data`` for ``msg_type`` using the shared module-level processor.

    Returns:
        Whatever ``MessageProcessor.validate_message`` returns for the input.
    """
    validated = _message_processor.validate_message(msg_type, data)
    return validated
def create_ack_response(
    request_id: Optional[str] = None,
    message: str = "Success",
    data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Build an ACK response through the shared module-level processor.

    Arguments are forwarded unchanged to ``MessageProcessor.create_ack_response``.
    """
    return _message_processor.create_ack_response(
        request_id=request_id,
        message=message,
        data=data,
    )
def create_error_response(
    request_id: Optional[str] = None,
    code: str = "400",
    message: str = "Error",
    details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Build an error response through the shared module-level processor.

    Arguments are forwarded unchanged to ``MessageProcessor.create_error_response``.
    """
    return _message_processor.create_error_response(
        request_id=request_id,
        code=code,
        message=message,
        details=details,
    )
def create_detection_message(
    subscription_id: str,
    detection_result: Optional[Dict[str, Any]],
    model_id: int,
    model_name: str,
    timestamp: Optional[str] = None
) -> Dict[str, Any]:
    """Build a detection message through the shared module-level processor.

    Arguments are forwarded unchanged to ``MessageProcessor.create_detection_message``.
    """
    return _message_processor.create_detection_message(
        subscription_id=subscription_id,
        detection_result=detection_result,
        model_id=model_id,
        model_name=model_name,
        timestamp=timestamp,
    )
def create_state_report(
    active_streams: int,
    loaded_models: int,
    **kwargs
) -> Dict[str, Any]:
    """Build a state report through the shared module-level processor.

    Args:
        active_streams: Number of active streams
        loaded_models: Number of loaded models
        **kwargs: Remaining keyword arguments accepted by
            ``MessageProcessor.create_state_report``, forwarded unchanged
    """
    report = _message_processor.create_state_report(
        active_streams, loaded_models, **kwargs
    )
    return report

View file

@ -0,0 +1,439 @@
"""
Response formatter module.
This module handles formatting of detection results and other responses
for WebSocket transmission, ensuring consistent output format.
"""
import json
import logging
import time
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from decimal import Decimal
from ..detection.detection_result import DetectionResult
# Module-level logger for the response formatter.
# The explicit dotted name places it under the "detector_worker" logger hierarchy.
logger = logging.getLogger("detector_worker.response_formatter")
class ResponseFormatter:
    """
    Formats responses for WebSocket transmission.

    This class handles:
    - Detection result formatting
    - Object serialization
    - Response structure standardization
    - Special type handling (Decimal, datetime, etc.)
    """

    def __init__(self, compact: bool = True):
        """
        Initialize the response formatter.

        Args:
            compact: Whether to use compact JSON formatting
        """
        self.compact = compact
        # Passed straight to json.dumps(separators=...).
        self.separators = (',', ':') if compact else (',', ': ')

    def format_detection_result(
        self,
        detection_result: Union["DetectionResult", Dict[str, Any], None],
        include_tracking: bool = True,
        include_metadata: bool = True
    ) -> Optional[Dict[str, Any]]:
        """
        Format a detection result for transmission.

        A dictionary argument is never modified; formatting is applied to a
        shallow copy and every nested structure that changes is rebuilt.

        Args:
            detection_result: Detection result to format
            include_tracking: Whether to include tracking information
            include_metadata: Whether to include metadata

        Returns:
            Formatted detection dictionary or None
        """
        if detection_result is None:
            return None

        # Result objects expose to_dict(); plain mappings are copied so the
        # caller's dictionary is left untouched by the rewriting below.
        to_dict = getattr(detection_result, "to_dict", None)
        if callable(to_dict):
            result = to_dict()
        else:
            result = dict(detection_result)

        # Format detections
        if "detections" in result:
            result["detections"] = self._format_detections(
                result["detections"],
                include_tracking
            )

        # Format license plate results
        if "license_plate_results" in result:
            result["license_plate_results"] = self._format_license_plates(
                result["license_plate_results"]
            )

        # Format session info
        if "session_info" in result:
            result["session_info"] = self._format_session_info(
                result["session_info"]
            )

        # Remove metadata if not needed
        if not include_metadata:
            result.pop("metadata", None)

        # Clean up None values
        return self._clean_dict(result)

    def _format_detections(
        self,
        detections: List[Dict[str, Any]],
        include_tracking: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Format detection objects.

        Tracking keys are dropped entirely when ``include_tracking`` is False;
        every other extra key is passed through ``_format_value``.
        """
        tracking_keys = ("track_id", "track_age", "stability_score")
        core_keys = ("class", "confidence", "bbox")

        formatted = []
        for detection in detections:
            formatted_det = {
                "class": detection.get("class"),
                "confidence": self._format_float(detection.get("confidence")),
                "bbox": self._format_bbox(detection.get("bbox", []))
            }

            # Add tracking info if available and requested
            if include_tracking:
                if "track_id" in detection:
                    formatted_det["track_id"] = detection["track_id"]
                if "track_age" in detection:
                    formatted_det["track_age"] = detection["track_age"]
                if "stability_score" in detection:
                    formatted_det["stability_score"] = self._format_float(
                        detection["stability_score"]
                    )

            # Add additional properties
            for key, value in detection.items():
                if key not in core_keys and key not in tracking_keys:
                    formatted_det[key] = self._format_value(value)

            formatted.append(formatted_det)
        return formatted

    def _format_license_plates(
        self,
        license_plates: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Format license plate results (text, confidence, bbox, optional track_id/type)."""
        formatted = []
        for plate in license_plates:
            formatted_plate = {
                "text": plate.get("text", ""),
                "confidence": self._format_float(plate.get("confidence")),
                "bbox": self._format_bbox(plate.get("bbox", []))
            }
            # Add tracking info if available
            if "track_id" in plate:
                formatted_plate["track_id"] = plate["track_id"]
            # Add type if available
            if "type" in plate:
                formatted_plate["type"] = plate["type"]
            formatted.append(formatted_plate)
        return formatted

    def _format_session_info(self, session_info: Dict[str, Any]) -> Dict[str, Any]:
        """Format session information; None and empty values are removed."""
        formatted = {
            "session_id": session_info.get("session_id"),
            "display_id": session_info.get("display_id"),
            "camera_id": session_info.get("camera_id")
        }

        # Add timestamps
        if "created_at" in session_info:
            formatted["created_at"] = self._format_timestamp(
                session_info["created_at"]
            )
        if "updated_at" in session_info:
            formatted["updated_at"] = self._format_timestamp(
                session_info["updated_at"]
            )

        # Add any additional session data
        handled = ("session_id", "display_id", "camera_id", "created_at", "updated_at")
        for key, value in session_info.items():
            if key not in handled:
                formatted[key] = self._format_value(value)

        return self._clean_dict(formatted)

    def _format_bbox(self, bbox: List[Union[int, float]]) -> List[int]:
        """
        Format bounding box coordinates.

        Returns:
            Four integer coordinates, or an empty list when ``bbox`` is None or
            does not contain exactly four values
        """
        # An explicit None (key present, value null) must not crash len().
        if bbox is None or len(bbox) != 4:
            return []
        return [int(coord) for coord in bbox]

    def _format_float(self, value: Union[float, Decimal, None], precision: int = 4) -> Optional[float]:
        """Round a float/Decimal to ``precision`` decimal places; None passes through."""
        if value is None:
            return None
        # float() accepts Decimal directly, so no separate branch is needed.
        return round(float(value), precision)

    def _format_timestamp(self, timestamp: Union[str, datetime, float, None]) -> Optional[str]:
        """
        Format timestamp to ISO format with a trailing "Z".

        Strings are returned unchanged. Timezone-aware datetimes and Unix
        timestamps are converted to UTC so the "Z" suffix is truthful; naive
        datetimes are formatted as-is.
        """
        from datetime import timezone  # local import: only this method needs it

        iso_format = "%Y-%m-%dT%H:%M:%SZ"

        if timestamp is None:
            return None

        # Handle string timestamps
        if isinstance(timestamp, str):
            return timestamp

        # Handle datetime objects
        if isinstance(timestamp, datetime):
            if timestamp.utcoffset() is not None:
                timestamp = timestamp.astimezone(timezone.utc)
            return timestamp.strftime(iso_format)

        # Handle Unix timestamps (seconds since the epoch, interpreted in UTC)
        if isinstance(timestamp, (int, float)):
            return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(iso_format)

        return str(timestamp)

    def _format_value(self, value: Any) -> Any:
        """Format arbitrary values for JSON serialization."""
        # Handle None
        if value is None:
            return None
        # Handle basic types
        if isinstance(value, (str, int, bool)):
            return value
        # Handle float/Decimal
        if isinstance(value, (float, Decimal)):
            return self._format_float(value)
        # Handle datetime
        if isinstance(value, datetime):
            return self._format_timestamp(value)
        # Handle lists and tuples (both become JSON arrays)
        if isinstance(value, (list, tuple)):
            return [self._format_value(item) for item in value]
        # Handle dictionaries
        if isinstance(value, dict):
            return {k: self._format_value(v) for k, v in value.items()}
        # Convert to string for other types
        return str(value)

    def _clean_dict(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Remove None values and empty collections from dictionary.

        Nested dictionaries are cleaned recursively and dropped when nothing
        is left; lists are kept as they are unless empty.
        """
        cleaned = {}
        for key, value in data.items():
            if value is None:
                continue
            if isinstance(value, (list, dict)) and len(value) == 0:
                continue
            if isinstance(value, dict):
                cleaned_value = self._clean_dict(value)
                if cleaned_value:
                    cleaned[key] = cleaned_value
            else:
                cleaned[key] = value
        return cleaned

    def format_tracking_state(self, tracking_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Format tracking state information.

        Args:
            tracking_state: Tracking state dictionary

        Returns:
            Formatted tracking state ("active_tracks" becomes a count)
        """
        formatted = {}

        # Format tracker state
        if "active_tracks" in tracking_state:
            formatted["active_tracks"] = len(tracking_state["active_tracks"])

        if "tracks" in tracking_state:
            formatted["tracks"] = []
            for track in tracking_state["tracks"]:
                formatted_track = {
                    "track_id": track.get("track_id"),
                    "class": track.get("class"),
                    "age": track.get("age", 0),
                    "hits": track.get("hits", 0),
                    "misses": track.get("misses", 0)
                }
                if "bbox" in track:
                    formatted_track["bbox"] = self._format_bbox(track["bbox"])
                if "confidence" in track:
                    formatted_track["confidence"] = self._format_float(track["confidence"])
                formatted["tracks"].append(formatted_track)

        return self._clean_dict(formatted)

    def format_error_details(self, error: Exception) -> Dict[str, Any]:
        """
        Format error details for response.

        Args:
            error: Exception to format

        Returns:
            Formatted error details
        """
        details = {
            "error_type": type(error).__name__,
            "error_message": str(error)
        }
        # Add additional context if available
        if hasattr(error, "details"):
            details["details"] = error.details
        if hasattr(error, "code"):
            details["error_code"] = error.code
        return details

    def to_json(self, data: Any) -> str:
        """
        Convert data to JSON string.

        Args:
            data: Data to serialize

        Returns:
            JSON string
        """
        return json.dumps(
            data,
            separators=self.separators,
            default=self._json_default
        )

    def _json_default(self, obj: Any) -> Any:
        """Default JSON serialization for custom types."""
        # Handle Decimal
        if isinstance(obj, Decimal):
            return float(obj)
        # Handle datetime (same UTC normalisation as _format_timestamp)
        if isinstance(obj, datetime):
            return self._format_timestamp(obj)
        # Handle bytes
        if isinstance(obj, bytes):
            return obj.decode('utf-8', errors='ignore')
        # Default to string representation
        return str(obj)

    def format_pipeline_result(
        self,
        pipeline_result: Dict[str, Any],
        include_intermediate: bool = False
    ) -> Dict[str, Any]:
        """
        Format pipeline execution result.

        Args:
            pipeline_result: Pipeline result dictionary
            include_intermediate: Whether to include intermediate results

        Returns:
            Formatted pipeline result
        """
        formatted = {
            "success": pipeline_result.get("success", False),
            "pipeline_id": pipeline_result.get("pipeline_id"),
            "execution_time": self._format_float(
                pipeline_result.get("execution_time", 0)
            )
        }

        # Add final detection result
        if "detection_result" in pipeline_result:
            formatted["detection_result"] = self.format_detection_result(
                pipeline_result["detection_result"]
            )

        # Add branch results
        if "branch_results" in pipeline_result and include_intermediate:
            formatted["branch_results"] = {
                branch_id: self._format_value(result)
                for branch_id, result in pipeline_result["branch_results"].items()
            }

        # Add executed actions
        if "executed_actions" in pipeline_result:
            formatted["executed_actions"] = pipeline_result["executed_actions"]

        # Add errors if any
        if "errors" in pipeline_result:
            formatted["errors"] = [
                self.format_error_details(error) if isinstance(error, Exception)
                else error
                for error in pipeline_result["errors"]
            ]

        return self._clean_dict(formatted)
# Singleton instance, created at import time with the default (compact) settings
# and shared by the module-level convenience functions below.
_formatter = ResponseFormatter()
# Convenience functions for backward compatibility
def format_detection_for_websocket(
    detection_result: Union[DetectionResult, Dict[str, Any], None],
    include_tracking: bool = True,
    include_metadata: bool = True
) -> Optional[Dict[str, Any]]:
    """Format a detection result through the shared module-level formatter.

    Arguments are forwarded unchanged to ``ResponseFormatter.format_detection_result``.
    """
    return _formatter.format_detection_result(
        detection_result,
        include_tracking=include_tracking,
        include_metadata=include_metadata,
    )
def format_tracking_state(tracking_state: Dict[str, Any]) -> Dict[str, Any]:
    """Format tracking state through the shared module-level formatter.

    Returns:
        Whatever ``ResponseFormatter.format_tracking_state`` returns for the input.
    """
    formatted_state = _formatter.format_tracking_state(tracking_state)
    return formatted_state
def format_error_response(error: Exception) -> Dict[str, Any]:
    """Describe ``error`` as a dictionary through the shared module-level formatter.

    Returns:
        Whatever ``ResponseFormatter.format_error_details`` returns for the input.
    """
    error_details = _formatter.format_error_details(error)
    return error_details
def to_compact_json(data: Any) -> str:
    """Serialise ``data`` to JSON through the shared module-level formatter.

    The shared formatter is constructed with its default settings, which use
    compact separators.
    """
    serialised = _formatter.to_json(data)
    return serialised

View file

@ -0,0 +1,622 @@
"""
WebSocket handler module.
This module manages WebSocket connections, message processing, heartbeat functionality,
and coordination between stream processing and detection pipelines.
"""
import asyncio
import json
import logging
import time
import traceback
import uuid
from typing import Dict, Any, Optional, Callable, List, Set, Awaitable
from contextlib import asynccontextmanager

from fastapi import WebSocket
# NOTE(review): WebSocketDisconnect is provided by Starlette/FastAPI rather than
# websockets.exceptions — confirm this import resolves in the deployed environment.
from websockets.exceptions import ConnectionClosedError, WebSocketDisconnect

from ..core.config import config, subscription_to_camera, latest_frames
from ..core.constants import HEARTBEAT_INTERVAL
from ..core.exceptions import WebSocketError, StreamError
from ..streams.stream_manager import StreamManager
from ..streams.camera_monitor import CameraConnectionMonitor
from ..detection.detection_result import DetectionResult
from ..models.model_manager import ModelManager
from ..pipeline.pipeline_executor import PipelineExecutor
from ..storage.session_cache import SessionCache
from ..storage.redis_client import RedisClientManager
from ..utils.system_monitor import get_system_metrics
# Setup logging.
# `logger` is used for handler lifecycle and error messages; `ws_logger` is used
# for the raw "TX ->" / "RX <-" message traces.
logger = logging.getLogger("detector_worker.websocket_handler")
ws_logger = logging.getLogger("websocket")
# Type definitions for callbacks.
# Both aliases describe ``async def`` callables, so the return type is an
# awaitable. (``asyncio.coroutine`` was a decorator, not a type, and no longer
# exists on Python 3.11+, which made this module fail at import time.)
# MessageHandler: receives the decoded message dictionary.
MessageHandler = Callable[[Dict[str, Any]], Awaitable[None]]
# DetectionHandler: (camera_id, stream_info, frame, websocket, model_tree,
# persistent_data) -> updated persistent data.
DetectionHandler = Callable[
    [str, Dict[str, Any], Any, WebSocket, Any, Dict[str, Any]],
    Awaitable[Dict[str, Any]],
]
class WebSocketHandler:
"""
Manages WebSocket connections and message processing for the detection worker.
This class handles:
- WebSocket lifecycle management
- Message routing and processing
- Heartbeat/state reporting
- Stream subscription management
- Detection result forwarding
- Session management
"""
def __init__(
    self,
    stream_manager: StreamManager,
    model_manager: ModelManager,
    pipeline_executor: PipelineExecutor,
    session_cache: SessionCache,
    redis_client: Optional[RedisClientManager] = None
):
    """
    Store the injected collaborators and reset all per-connection state.

    Args:
        stream_manager: Manager for camera streams
        model_manager: Manager for ML models
        pipeline_executor: Pipeline execution engine
        session_cache: Session state cache
        redis_client: Optional Redis client for pub/sub
    """
    # Injected collaborators
    self.stream_manager = stream_manager
    self.model_manager = model_manager
    self.pipeline_executor = pipeline_executor
    self.session_cache = session_cache
    self.redis_client = redis_client

    # Connection state: nothing is connected until handle_connection() runs.
    self.websocket: Optional[WebSocket] = None
    self.connected: bool = False
    self.tasks: List[asyncio.Task] = []

    # Dispatch table: incoming message "type" -> coroutine method.
    self.message_handlers: Dict[str, MessageHandler] = {
        "subscribe": self._handle_subscribe,
        "unsubscribe": self._handle_unsubscribe,
        "requestState": self._handle_request_state,
        "setSessionId": self._handle_set_session,
        "patchSession": self._handle_patch_session,
        "setProgressionStage": self._handle_set_progression_stage,
    }

    # Session and display bookkeeping
    self.session_ids: Dict[str, str] = {}  # display_identifier -> session_id
    self.display_identifiers: Set[str] = set()

    # Camera connection monitor
    self.camera_monitor = CameraConnectionMonitor()
async def handle_connection(self, websocket: WebSocket) -> None:
    """
    Main entry point for handling a WebSocket connection.

    Accepts the socket, starts the stream, heartbeat and message loops, and
    tears everything down as soon as either the heartbeat loop or the message
    loop stops. Waiting for only the first of the two to finish means a client
    disconnect is cleaned up immediately, and a dead heartbeat does not leave
    a half-alive connection behind.

    Args:
        websocket: The WebSocket connection to handle
    """
    try:
        await websocket.accept()
        self.websocket = websocket
        self.connected = True
        logger.info("WebSocket connection accepted")

        # Create concurrent tasks
        stream_task = asyncio.create_task(self._process_streams())
        heartbeat_task = asyncio.create_task(self._send_heartbeat())
        message_task = asyncio.create_task(self._process_messages())
        self.tasks = [stream_task, heartbeat_task, message_task]

        # The connection is over once either loop ends; _cleanup() cancels the rest.
        done, _pending = await asyncio.wait(
            {heartbeat_task, message_task},
            return_when=asyncio.FIRST_COMPLETED,
        )

        # asyncio.wait() does not re-raise, so surface a task failure here.
        for finished in done:
            if finished.cancelled():
                continue
            failure = finished.exception()
            if failure is not None:
                raise failure
    except Exception as e:
        logger.error(f"Error in WebSocket handler: {e}")
    finally:
        self.connected = False
        await self._cleanup()
async def _cleanup(self) -> None:
    """
    Clean up resources when connection closes.

    Cancels the background tasks, then releases streams, models, session
    state and camera states. A background task that ended with an error is
    logged instead of aborting the remaining cleanup steps.
    """
    logger.info("Cleaning up WebSocket connection")

    # Cancel all tasks first, then wait for them together.
    for task in self.tasks:
        task.cancel()
    if self.tasks:
        # return_exceptions=True: neither cancellation nor a task's own
        # failure may prevent the resource cleanup below from running.
        outcomes = await asyncio.gather(*self.tasks, return_exceptions=True)
        for outcome in outcomes:
            if isinstance(outcome, Exception) and not isinstance(
                outcome, asyncio.CancelledError
            ):
                logger.warning(f"Background task ended with error: {outcome}")
    self.tasks = []

    # Clean up streams
    await self.stream_manager.cleanup_all_streams()

    # Clean up models
    self.model_manager.cleanup_all_models()

    # Clear session state
    self.session_cache.clear_all_sessions()
    self.session_ids.clear()
    self.display_identifiers.clear()

    # Clear camera states
    self.camera_monitor.clear_all_states()

    logger.info("WebSocket cleanup completed")
async def _send_heartbeat(self) -> None:
    """
    Emit a ``stateReport`` message every ``HEARTBEAT_INTERVAL`` seconds.

    Runs while ``self.connected`` is true and stops on the first error of any
    kind, including a closed socket.
    """
    while self.connected:
        try:
            metrics = get_system_metrics()
            streams = self.stream_manager.get_active_streams()
            models = self.model_manager.get_loaded_models()

            report_time = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
            report_body = {
                "activeStreams": len(streams),
                "loadedModels": len(models),
                "cpuUsage": metrics.get("cpu_percent", 0),
                "memoryUsage": metrics.get("memory_percent", 0),
                "gpuUsage": metrics.get("gpu_percent", 0),
                "gpuMemory": metrics.get("gpu_memory_percent", 0),
                # Falls back to "now" when no start time is reported.
                "uptime": time.time() - metrics.get("start_time", time.time()),
            }
            state_data = {
                "type": "stateReport",
                "timestamp": report_time,
                "data": report_body,
            }

            ws_logger.info(f"TX -> {json.dumps(state_data, separators=(',', ':'))}")
            await self.websocket.send_json(state_data)
            await asyncio.sleep(HEARTBEAT_INTERVAL)
        except (WebSocketDisconnect, ConnectionClosedError):
            logger.info("WebSocket disconnected during heartbeat")
            break
        except Exception as e:
            logger.error(f"Error sending heartbeat: {e}")
            break
async def _process_messages(self) -> None:
    """
    Process incoming WebSocket messages.

    Each text frame is decoded as JSON and dispatched through
    ``self.message_handlers`` by its ``type`` field. Malformed input (invalid
    JSON, a non-object message, an unknown type) is logged and skipped so a
    single bad frame cannot end the receive loop; a disconnect or an error
    raised by a handler stops the loop.
    """
    while self.connected:
        try:
            text_data = await self.websocket.receive_text()
            ws_logger.info(f"RX <- {text_data}")
            data = json.loads(text_data)

            # Valid JSON that is not an object has no "type" to dispatch on.
            if not isinstance(data, dict):
                logger.error("Received message that is not a JSON object")
                continue

            msg_type = data.get("type")
            # Only strings can be dispatch keys; anything else is unknown.
            handler = (
                self.message_handlers.get(msg_type)
                if isinstance(msg_type, str)
                else None
            )
            if handler is None:
                logger.error(f"Unknown message type: {msg_type}")
                continue

            await handler(data)
        except json.JSONDecodeError:
            logger.error("Received invalid JSON message")
        except (WebSocketDisconnect, ConnectionClosedError) as e:
            logger.warning(f"WebSocket disconnected: {e}")
            break
        except Exception as e:
            # logger.exception records the traceback through the logging system.
            logger.exception(f"Error handling message: {e}")
            break
async def _process_streams(self) -> None:
    """
    Process active camera streams and run detection pipelines.

    On every polling cycle one detection task is started for each active
    stream that currently has a frame, a model id and a loaded model. The
    tasks run concurrently, and each result is written back to the persistent
    data of the camera that produced it.
    """
    while self.connected:
        try:
            active_streams = self.stream_manager.get_active_streams()
            if active_streams:
                # camera_ids[i] is the camera that tasks[i] belongs to. Streams
                # skipped below create no task, so results must be matched
                # through this list and not by position in active_streams.
                camera_ids = []
                tasks = []
                for camera_id, stream_info in active_streams.items():
                    # Get latest frame
                    frame = self.stream_manager.get_latest_frame(camera_id)
                    if frame is None:
                        continue

                    # Get model for this camera
                    model_id = stream_info.get("modelId")
                    if not model_id:
                        continue
                    model_tree = self.model_manager.get_model(camera_id, model_id)
                    if not model_tree:
                        continue

                    # Create detection task
                    persistent_data = self.session_cache.get_persistent_data(camera_id)
                    task = asyncio.create_task(
                        self._handle_detection(
                            camera_id, stream_info, frame,
                            self.websocket, model_tree, persistent_data
                        )
                    )
                    camera_ids.append(camera_id)
                    tasks.append(task)

                # Wait for all detection tasks
                if tasks:
                    results = await asyncio.gather(*tasks, return_exceptions=True)
                    # Update persistent data for the camera each result came from
                    for camera_id, result in zip(camera_ids, results):
                        if isinstance(result, dict):
                            self.session_cache.update_persistent_data(camera_id, result)
                        elif isinstance(result, Exception):
                            logger.error(
                                f"Detection task for camera {camera_id} failed: {result}"
                            )

            # Polling interval
            poll_interval = config.get("poll_interval_ms", 100) / 1000.0
            await asyncio.sleep(poll_interval)
        except asyncio.CancelledError:
            logger.info("Stream processing cancelled")
            break
        except Exception as e:
            logger.error(f"Error in stream processing: {e}")
            await asyncio.sleep(1)
async def _handle_detection(
    self,
    camera_id: str,
    stream_info: Dict[str, Any],
    frame: Any,
    websocket: WebSocket,
    model_tree: Any,
    persistent_data: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Run one detection pass for a single camera frame.

    If the camera monitor reports a pending disconnection notice, that
    notice is sent and nothing else happens. Otherwise the frame is
    cropped, the pipeline executor is run and a truthy result is sent
    over the WebSocket. A pending reconnection is then marked as notified.

    Args:
        camera_id: Camera identifier
        stream_info: Stream configuration
        frame: Video frame to process
        websocket: WebSocket connection (not read here; sending goes
            through the handler's own helpers)
        model_tree: Model pipeline tree
        persistent_data: Persistent data for this camera

    Returns:
        The ``persistent_data`` object that was passed in, on every path,
        including when an exception was caught and logged.
    """
    try:
        if self.camera_monitor.should_notify_disconnection(camera_id):
            # Camera is down: report it and skip detection for this frame.
            await self._send_disconnection_notification(camera_id, stream_info)
        else:
            region = self._apply_crop(frame, stream_info)
            state = self.session_cache.get_session_pipeline_state(camera_id)

            outcome = await self.pipeline_executor.execute_pipeline(
                camera_id,
                stream_info,
                region,
                model_tree,
                persistent_data,
                state
            )
            if outcome:
                await self._send_detection_result(camera_id, stream_info, outcome)

            # A frame was processed, so a previously lost camera is back.
            if self.camera_monitor.should_notify_reconnection(camera_id):
                self.camera_monitor.mark_reconnection_notified(camera_id)
                logger.info(f"Camera {camera_id} reconnected successfully")
    except Exception as exc:
        logger.error(f"Error in detection handling for camera {camera_id}: {exc}")
        traceback.print_exc()
    return persistent_data
async def _handle_subscribe(self, data: Dict[str, Any]) -> None:
    """Start a camera stream, and load its model, for a subscribe message.

    The payload's ``subscriptionIdentifier`` is split on ``;``. With two
    or more parts, the first is recorded as a display id and the second
    is the camera id; otherwise the whole identifier is the camera id.
    Failures after the identifier check are logged, not raised.
    """
    payload = data.get("payload", {})
    subscription_id = payload.get("subscriptionIdentifier")
    if not subscription_id:
        logger.error("Missing subscriptionIdentifier in subscribe payload")
        return

    try:
        segments = subscription_id.split(";")
        if len(segments) < 2:
            camera_id = subscription_id
        else:
            display_id, camera_id = segments[:2]
            self.display_identifiers.add(display_id)

        # Remember which camera this subscription maps to for unsubscribe.
        subscription_to_camera[subscription_id] = camera_id

        await self.stream_manager.start_stream(camera_id, payload)

        # A model is only loaded when both its id and its URL are given.
        model_id = payload.get("modelId")
        model_url = payload.get("modelUrl")
        if model_id and model_url:
            await self.model_manager.load_model(camera_id, model_id, model_url)

        logger.info(f"Subscribed to stream: {subscription_id}")
    except Exception as exc:
        logger.error(f"Error handling subscription: {exc}")
        traceback.print_exc()
async def _handle_unsubscribe(self, data: Dict[str, Any]) -> None:
    """Tear down the stream, models and session state of a subscription.

    The camera is looked up in ``subscription_to_camera``. Unknown
    subscriptions are logged as a warning and ignored. Failures during
    teardown are logged, not raised.
    """
    payload = data.get("payload", {})
    subscription_id = payload.get("subscriptionIdentifier")
    if not subscription_id:
        logger.error("Missing subscriptionIdentifier in unsubscribe payload")
        return

    try:
        camera_id = subscription_to_camera.get(subscription_id)
        if not camera_id:
            logger.warning(f"No camera found for subscription: {subscription_id}")
            return

        # Teardown order: stream, models, subscription mapping, session.
        await self.stream_manager.stop_stream(camera_id)
        self.model_manager.unload_models(camera_id)
        subscription_to_camera.pop(subscription_id, None)
        self.session_cache.clear_session(camera_id)

        logger.info(f"Unsubscribed from stream: {subscription_id}")
    except Exception as exc:
        logger.error(f"Error handling unsubscription: {exc}")
        traceback.print_exc()
async def _handle_request_state(self, data: Dict[str, Any]) -> None:
    """Answer a requestState message by sending a heartbeat right away.

    The message content in ``data`` is not inspected.
    """
    await self._send_heartbeat()
async def _handle_set_session(self, data: Dict[str, Any]) -> None:
    """Bind a session id to a display and acknowledge the request.

    Nothing happens unless the payload carries both ``displayIdentifier``
    and ``sessionId``. The session id is stored for the display and pushed
    to every stream whose subscription identifier starts with
    ``"<display>;"``. An ``ack`` message is then sent back.
    """
    payload = data.get("payload", {})
    display_id = payload.get("displayIdentifier")
    session_id = payload.get("sessionId")
    if not (display_id and session_id):
        return

    self.session_ids[display_id] = session_id

    # Propagate the session to all cameras that belong to this display.
    with self.stream_manager.streams_lock:
        for camera_id, stream in self.stream_manager.streams.items():
            if stream["subscriptionIdentifier"].startswith(display_id + ";"):
                self.session_cache.update_session_id(camera_id, session_id)

    # Echo the caller's requestId when present, otherwise a fresh UUID.
    ack = {
        "type": "ack",
        "requestId": data.get("requestId", str(uuid.uuid4())),
        "code": "200",
        "data": {
            "message": f"Session ID set for display {display_id}",
            "sessionId": session_id
        }
    }
    ws_logger.info(f"TX -> {json.dumps(ack, separators=(',', ':'))}")
    await self.websocket.send_json(ack)
    logger.info(f"Set session {session_id} for display {display_id}")
async def _handle_patch_session(self, data: Dict[str, Any]) -> None:
    """Log a session patch and acknowledge it.

    Messages without a ``sessionId`` in the payload are ignored. The
    patch data is only logged and echoed back in the ``ack``; it is not
    stored anywhere by this method.
    """
    payload = data.get("payload", {})
    session_id = payload.get("sessionId")
    patch_data = payload.get("data", {})
    if not session_id:
        return

    logger.info(f"Received patch for session {session_id}: {patch_data}")

    # Echo the caller's requestId when present, otherwise a fresh UUID.
    ack = {
        "type": "ack",
        "requestId": data.get("requestId", str(uuid.uuid4())),
        "code": "200",
        "data": {
            "message": f"Session {session_id} patched successfully",
            "sessionId": session_id,
            "patchData": patch_data
        }
    }
    ws_logger.info(f"TX -> {json.dumps(ack, separators=(',', ':'))}")
    await self.websocket.send_json(ack)
async def _handle_set_progression_stage(self, data: Dict[str, Any]) -> None:
    """Apply a progression-stage change to every camera of a display.

    Cameras are matched by subscription identifiers that start with
    ``"<display>;"``. Per camera, the session pipeline state is updated:

    - ``car_fueling``: only in ``lightweight`` mode, inference is switched
      off and the stage recorded; in any other mode nothing changes.
    - ``car_waitpayment``: inference is switched on and the stage recorded.
    - ``welcome``: recorded unless the current stage is ``car_waitpayment``.
    - ``car_wait_staff``: recorded as is.

    Any other stage value leaves the state untouched.
    """
    payload = data.get("payload", {})
    display_id = payload.get("displayIdentifier")
    progression_stage = payload.get("progressionStage")
    logger.info(f"🏁 PROGRESSION STAGE RECEIVED: displayId={display_id}, stage={progression_stage}")

    if not display_id:
        logger.warning("Missing displayIdentifier in setProgressionStage")
        return

    # Collect the cameras under the lock; update their state after it.
    with self.stream_manager.streams_lock:
        affected_cameras = [
            camera_id
            for camera_id, stream in self.stream_manager.streams.items()
            if stream["subscriptionIdentifier"].startswith(display_id + ";")
        ]
    logger.debug(f"🎯 Found {len(affected_cameras)} cameras for display {display_id}: {affected_cameras}")

    for camera_id in affected_cameras:
        pipeline_state = self.session_cache.get_or_init_session_pipeline_state(camera_id)
        current_mode = pipeline_state.get("mode", "validation_detecting")

        if progression_stage == "car_fueling":
            if current_mode != "lightweight":
                logger.debug(f"📊 Camera {camera_id}: car_fueling received but not in lightweight mode (mode: {current_mode})")
                continue
            # Stop YOLO inference during fueling
            pipeline_state["yolo_inference_enabled"] = False
            pipeline_state["progression_stage"] = "car_fueling"
            logger.info(f"⏸️ Camera {camera_id}: YOLO inference DISABLED for car_fueling stage")
        elif progression_stage == "car_waitpayment":
            # Resume YOLO inference for absence counter
            pipeline_state["yolo_inference_enabled"] = True
            pipeline_state["progression_stage"] = "car_waitpayment"
            logger.info(f"▶️ Camera {camera_id}: YOLO inference RE-ENABLED for car_waitpayment stage")
        elif progression_stage == "welcome":
            # A welcome that arrives during car_waitpayment is dropped.
            if pipeline_state.get("progression_stage") == "car_waitpayment":
                logger.info(f"🚫 Camera {camera_id}: IGNORING welcome stage (currently in car_waitpayment)")
                continue
            pipeline_state["progression_stage"] = "welcome"
            logger.info(f"🎉 Camera {camera_id}: Progression stage set to welcome")
        elif progression_stage == "car_wait_staff":
            pipeline_state["progression_stage"] = progression_stage
            logger.info(f"📋 Camera {camera_id}: Progression stage set to {progression_stage}")
async def _send_detection_result(
    self,
    camera_id: str,
    stream_info: Dict[str, Any],
    detection_result: DetectionResult
) -> None:
    """Serialize a detection result and send it as an imageDetection message.

    The timestamp is the current UTC time. A ``RuntimeError`` whose text
    contains ``"websocket.close"`` is logged as a warning; any other
    ``RuntimeError`` is re-raised.
    """
    utc_now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    message = {
        "type": "imageDetection",
        "subscriptionIdentifier": stream_info["subscriptionIdentifier"],
        "timestamp": utc_now,
        "data": {
            "detection": detection_result.to_dict(),
            "modelId": stream_info["modelId"],
            "modelName": stream_info["modelName"]
        }
    }

    try:
        ws_logger.info(f"TX -> {json.dumps(message, separators=(',', ':'))}")
        await self.websocket.send_json(message)
    except RuntimeError as exc:
        if "websocket.close" not in str(exc):
            raise
        logger.warning(f"WebSocket closed - cannot send detection for camera {camera_id}")
async def _send_disconnection_notification(
    self,
    camera_id: str,
    stream_info: Dict[str, Any]
) -> None:
    """Tell the peer that a camera dropped by sending ``detection: null``.

    The camera's cached session is cleared first. After the send attempt
    the camera monitor is told that the disconnection has been notified.
    A ``RuntimeError`` whose text contains ``"websocket.close"`` is logged
    as a warning; any other ``RuntimeError`` is re-raised before the
    monitor is updated.
    """
    logger.error(f"🚨 CAMERA DISCONNECTION DETECTED: {camera_id} - sending immediate detection: null")

    self.session_cache.clear_session(camera_id)

    utc_now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    message = {
        "type": "imageDetection",
        "subscriptionIdentifier": stream_info["subscriptionIdentifier"],
        "timestamp": utc_now,
        "data": {
            "detection": None,
            "modelId": stream_info["modelId"],
            "modelName": stream_info["modelName"]
        }
    }

    try:
        ws_logger.info(f"TX -> {json.dumps(message, separators=(',', ':'))}")
        await self.websocket.send_json(message)
    except RuntimeError as exc:
        if "websocket.close" not in str(exc):
            raise
        logger.warning(f"WebSocket closed - cannot send disconnection signal for camera {camera_id}")

    self.camera_monitor.mark_disconnection_notified(camera_id)
    logger.info(f"📡 SENT DISCONNECTION SIGNAL - detection: null for camera {camera_id}, backend should clear session")
def _apply_crop(self, frame: Any, stream_info: Dict[str, Any]) -> Any:
    """Return the crop region of ``frame``, or ``frame`` itself.

    The region is read from ``cropX1``/``cropY1``/``cropX2``/``cropY2`` in
    ``stream_info``. If any of the four is missing or ``None`` the frame
    is returned unchanged. Otherwise the result is
    ``frame[y1:y2, x1:x2]`` (rows first, then columns).
    """
    x1 = stream_info.get("cropX1")
    y1 = stream_info.get("cropY1")
    x2 = stream_info.get("cropX2")
    y2 = stream_info.get("cropY2")

    if any(value is None for value in (x1, y1, x2, y2)):
        return frame
    return frame[y1:y2, x1:x2]
# Convenience function for backward compatibility
async def handle_websocket_connection(
    websocket: WebSocket,
    stream_manager: StreamManager,
    model_manager: ModelManager,
    pipeline_executor: PipelineExecutor,
    session_cache: SessionCache,
    redis_client: Optional[RedisClientManager] = None
) -> None:
    """
    Serve one WebSocket connection with a fresh ``WebSocketHandler``.

    Builds the handler from the given collaborators, in the order they
    are listed here, and awaits its ``handle_connection`` for ``websocket``.
    """
    connection_handler = WebSocketHandler(
        stream_manager,
        model_manager,
        pipeline_executor,
        session_cache,
        redis_client
    )
    await connection_handler.handle_connection(websocket)

View file

@ -0,0 +1,441 @@
"""
Model manager module.
This module handles ML model loading, caching, and lifecycle management
for the detection worker.
"""
import os
import logging
import threading
from typing import Dict, Any, Optional, List, Set, Tuple
from urllib.parse import urlparse
import traceback
from ..core.config import MODELS_DIR
from ..core.exceptions import ModelLoadError
# Setup logging
logger = logging.getLogger("detector_worker.model_manager")
class ModelRegistry:
    """
    Lock-protected table of loaded models and the cameras that use them.

    Each model id maps to an info dict (``model``, ``path``, ``loaded_at``)
    and to a set of camera ids holding a reference, so one model can be
    shared by several cameras.
    """

    def __init__(self):
        """Create an empty registry."""
        # model_id -> {"model", "path", "loaded_at"}
        self.models: Dict[str, Dict[str, Any]] = {}
        # model_id -> camera ids currently referencing the model
        self.references: Dict[str, Set[str]] = {}
        self.lock = threading.Lock()

    def register_model(self, model_id: str, model_data: Any, model_path: str) -> None:
        """
        Store (or overwrite) the entry for a model.

        Existing references for ``model_id`` are kept. ``loaded_at`` is the
        modification time of ``model_path``, so the path must exist.

        Args:
            model_id: Unique model identifier
            model_data: Loaded model data
            model_path: Path to model file
        """
        with self.lock:
            entry = {
                "model": model_data,
                "path": model_path,
                "loaded_at": os.path.getmtime(model_path)
            }
            self.models[model_id] = entry
            self.references.setdefault(model_id, set())

    def add_reference(self, model_id: str, camera_id: str) -> None:
        """Record that a camera uses a model; unknown model ids are ignored."""
        with self.lock:
            holders = self.references.get(model_id)
            if holders is not None:
                holders.add(camera_id)

    def remove_reference(self, model_id: str, camera_id: str) -> bool:
        """
        Drop a camera's reference to a model.

        Returns:
            True if the model has no references left (or was never
            registered) and can be unloaded
        """
        with self.lock:
            holders = self.references.get(model_id)
            if holders is None:
                return True
            holders.discard(camera_id)
            return not holders

    def get_model(self, model_id: str) -> Optional[Any]:
        """Return the stored model object, or None if it is not registered."""
        with self.lock:
            entry = self.models.get(model_id)
            if not entry:
                return None
            return entry["model"]

    def unregister_model(self, model_id: str) -> None:
        """Forget a model together with its reference set."""
        with self.lock:
            self.models.pop(model_id, None)
            self.references.pop(model_id, None)

    def get_loaded_models(self) -> List[str]:
        """Return the ids of all registered models."""
        with self.lock:
            return list(self.models)

    def get_reference_count(self, model_id: str) -> int:
        """Return how many cameras reference a model (0 if unknown)."""
        with self.lock:
            return len(self.references.get(model_id, ()))

    def clear(self) -> None:
        """Remove every model and every reference."""
        with self.lock:
            self.models.clear()
            self.references.clear()
class ModelManager:
    """
    Manages ML model loading, caching, and lifecycle.

    This class handles:
    - Model downloading and caching
    - Model loading with proper error handling
    - Reference counting for model sharing
    - Model cleanup and memory management
    - Pipeline model tree management

    ``models_lock`` is a plain (non-reentrant) ``threading.Lock`` guarding
    ``camera_models``. Methods that take it must not call other methods
    that take it while it is held.
    """

    # Minimum number of bytes between two download progress log lines.
    _PROGRESS_LOG_STEP = 1024 * 1024

    def __init__(self, models_dir: str = MODELS_DIR):
        """
        Initialize the model manager.

        Args:
            models_dir: Directory to cache downloaded models; created if
                it does not exist
        """
        self.models_dir = models_dir
        self.registry = ModelRegistry()
        self.models_lock = threading.Lock()

        # camera_id -> {model_id -> model_tree}
        self.camera_models: Dict[str, Dict[str, Any]] = {}

        # Injected later through set_pipeline_loader()
        self.pipeline_loader = None

        os.makedirs(self.models_dir, exist_ok=True)

    def set_pipeline_loader(self, pipeline_loader: Any) -> None:
        """
        Set the pipeline loader instance.

        Args:
            pipeline_loader: Pipeline loader to use for loading models
        """
        self.pipeline_loader = pipeline_loader

    async def load_model(
        self,
        camera_id: str,
        model_id: str,
        model_url: str,
        force_reload: bool = False
    ) -> Any:
        """
        Load a model for a specific camera.

        Lookup order: the camera's own entry, then the shared registry,
        then a fresh load through the pipeline loader. The first two are
        skipped when ``force_reload`` is true.

        Args:
            camera_id: Camera identifier
            model_id: Model identifier
            model_url: URL or path to model file
            force_reload: Force reload even if cached

        Returns:
            Loaded model tree

        Raises:
            ModelLoadError: If no pipeline loader is set or loading fails
        """
        if not self.pipeline_loader:
            raise ModelLoadError("Pipeline loader not initialized")

        try:
            # Check if model is already loaded for this camera
            with self.models_lock:
                if camera_id in self.camera_models and model_id in self.camera_models[camera_id]:
                    if not force_reload:
                        logger.info(f"Model {model_id} already loaded for camera {camera_id}")
                        return self.camera_models[camera_id][model_id]

            # Check if model is in registry
            cached_model = self.registry.get_model(model_id)
            if cached_model and not force_reload:
                # Add reference and return cached model
                self.registry.add_reference(model_id, camera_id)
                with self.models_lock:
                    self.camera_models.setdefault(camera_id, {})[model_id] = cached_model
                logger.info(f"Using cached model {model_id} for camera {camera_id}")
                return cached_model

            # Download or locate model file
            model_path = await self._get_model_path(model_url, model_id)

            # Load model using pipeline loader
            logger.info(f"Loading model {model_id} from {model_path}")
            model_tree = await self.pipeline_loader.load_pipeline(model_path)

            # Register in registry
            self.registry.register_model(model_id, model_tree, model_path)
            self.registry.add_reference(model_id, camera_id)

            # Store in camera models
            with self.models_lock:
                self.camera_models.setdefault(camera_id, {})[model_id] = model_tree

            logger.info(f"Successfully loaded model {model_id} for camera {camera_id}")
            return model_tree
        except Exception as e:
            logger.error(f"Failed to load model {model_id}: {e}")
            traceback.print_exc()
            # Chain the cause so the original traceback is not lost.
            raise ModelLoadError(f"Failed to load model {model_id}: {e}") from e

    async def _get_model_path(self, model_url: str, model_id: str) -> str:
        """
        Get local path for a model, downloading if necessary.

        An existing local path is returned unchanged. ``file://`` URLs
        yield their path component. ``http``/``https`` URLs are cached in
        ``models_dir`` under the URL's base name (or ``<model_id>.mpta``
        when the URL has none) and downloaded only when not yet cached.
        Anything else is returned unchanged.

        Args:
            model_url: URL or local path to model
            model_id: Model identifier

        Returns:
            Local file path to model
        """
        # Check if it's already a local path
        if os.path.exists(model_url):
            return model_url

        parsed = urlparse(model_url)

        if parsed.scheme == 'file':
            return parsed.path

        if parsed.scheme in ('http', 'https'):
            filename = os.path.basename(parsed.path) or f"{model_id}.mpta"
            cache_path = os.path.join(self.models_dir, filename)

            if os.path.exists(cache_path):
                logger.info(f"Using cached model file: {cache_path}")
                return cache_path

            logger.info(f"Downloading model from {model_url}")
            await self._download_model(model_url, cache_path)
            return cache_path

        # For other schemes or no scheme, assume local path
        return model_url

    async def _download_model(self, url: str, destination: str) -> None:
        """
        Download a model file from URL.

        The data is written to ``<destination>.tmp`` and moved into place
        only after the download finished, so a partial file never appears
        under the final name. The temporary file is removed on failure.

        Args:
            url: URL to download from
            destination: Local path to save to

        Raises:
            ModelLoadError: If the download or the final move fails
        """
        import aiohttp
        import aiofiles

        temp_path = f"{destination}.tmp"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    response.raise_for_status()

                    # Get total size if available
                    total_size = response.headers.get('Content-Length')
                    if total_size:
                        total_size = int(total_size)
                        logger.info(f"Downloading {total_size / (1024*1024):.2f} MB")

                    # Download to temporary file first
                    downloaded = 0
                    next_report = self._PROGRESS_LOG_STEP
                    async with aiofiles.open(temp_path, 'wb') as f:
                        async for chunk in response.content.iter_chunked(8192):
                            await f.write(chunk)
                            downloaded += len(chunk)

                            # Chunks may be shorter than requested, so the
                            # running total rarely lands exactly on a MiB
                            # boundary; report on crossing a threshold.
                            if total_size and downloaded >= next_report:
                                progress = (downloaded / total_size) * 100
                                logger.info(f"Download progress: {progress:.1f}%")
                                next_report = downloaded + self._PROGRESS_LOG_STEP

            # Move to final destination. os.replace overwrites an existing
            # file on every platform, where os.rename fails on Windows.
            os.replace(temp_path, destination)
            logger.info(f"Model downloaded successfully to {destination}")
        except Exception as e:
            # Clean up temporary file if exists
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise ModelLoadError(f"Failed to download model: {e}") from e

    def get_model(self, camera_id: str, model_id: str) -> Optional[Any]:
        """
        Get a loaded model for a camera.

        Args:
            camera_id: Camera identifier
            model_id: Model identifier

        Returns:
            Model tree if loaded, None otherwise
        """
        with self.models_lock:
            return self.camera_models.get(camera_id, {}).get(model_id)

    def unload_models(self, camera_id: str) -> None:
        """
        Unload all models for a camera.

        Each model loses this camera's reference. A model with no
        references left is removed from the registry and, if the pipeline
        loader has a ``cleanup_model`` method, cleaned up there too.
        Cleanup errors are logged and do not stop the unload.

        Args:
            camera_id: Camera identifier
        """
        with self.models_lock:
            if camera_id not in self.camera_models:
                return

            for model_id in self.camera_models[camera_id]:
                should_unload = self.registry.remove_reference(model_id, camera_id)
                if not should_unload:
                    continue

                logger.info(f"Unloading model {model_id} (no more references)")
                self.registry.unregister_model(model_id)

                # Clean up model if pipeline loader supports it
                if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_model'):
                    try:
                        self.pipeline_loader.cleanup_model(model_id)
                    except Exception as e:
                        logger.error(f"Error cleaning up model {model_id}: {e}")

            # Remove camera entry
            del self.camera_models[camera_id]
            logger.info(f"Unloaded all models for camera {camera_id}")

    def cleanup_all_models(self) -> None:
        """Unload every camera's models and reset registry and loader."""
        logger.info("Cleaning up all loaded models")

        # Snapshot the camera ids and release the lock before unloading:
        # unload_models() takes models_lock itself, and the lock is not
        # reentrant, so calling it while holding the lock would deadlock.
        with self.models_lock:
            cameras = list(self.camera_models)

        for camera_id in cameras:
            self.unload_models(camera_id)

        self.registry.clear()

        # Clean up pipeline loader if it has cleanup
        if self.pipeline_loader and hasattr(self.pipeline_loader, 'cleanup_all'):
            try:
                self.pipeline_loader.cleanup_all()
            except Exception as e:
                logger.error(f"Error in pipeline loader cleanup: {e}")

        logger.info("Model cleanup completed")

    def get_loaded_models(self) -> Dict[str, List[str]]:
        """
        Get information about loaded models.

        Returns:
            Dictionary mapping model IDs to list of camera IDs using them
        """
        with self.models_lock:
            return {
                model_id: [
                    camera_id
                    for camera_id, models in self.camera_models.items()
                    if model_id in models
                ]
                for model_id in self.registry.get_loaded_models()
            }

    def get_model_stats(self) -> Dict[str, Any]:
        """
        Get statistics about loaded models.

        Returns:
            Dictionary with ``total_models``, ``total_cameras``,
            ``total_instances`` (model entries summed over cameras),
            ``cache_size_mb`` (size of regular files directly inside
            ``models_dir``) and ``models_dir``
        """
        with self.models_lock:
            total_models = len(self.registry.get_loaded_models())
            total_cameras = len(self.camera_models)
            total_instances = sum(
                len(models) for models in self.camera_models.values()
            )

        # The directory scan touches no shared state, so it runs after the
        # lock is released and does not block model lookups.
        cache_size = 0
        if os.path.exists(self.models_dir):
            for filename in os.listdir(self.models_dir):
                filepath = os.path.join(self.models_dir, filename)
                if os.path.isfile(filepath):
                    cache_size += os.path.getsize(filepath)

        return {
            "total_models": total_models,
            "total_cameras": total_cameras,
            "total_instances": total_instances,
            "cache_size_mb": round(cache_size / (1024 * 1024), 2),
            "models_dir": self.models_dir
        }
# Global model manager instance
_model_manager = None


def get_model_manager() -> ModelManager:
    """Return the shared ModelManager, building it on first use."""
    global _model_manager
    manager = _model_manager
    if manager is None:
        # First call: build the manager with its default arguments.
        manager = ModelManager()
        _model_manager = manager
    return manager
# Convenience functions for backward compatibility
def initialize_model_manager(models_dir: str = MODELS_DIR) -> ModelManager:
    """Replace the shared ModelManager with a new one for ``models_dir``.

    Any previously created shared instance is dropped without cleanup.
    """
    global _model_manager
    fresh_manager = ModelManager(models_dir)
    _model_manager = fresh_manager
    return fresh_manager

View file

@ -0,0 +1,486 @@
"""
Pipeline loader module.
This module handles loading and parsing of MPTA (Machine Learning Pipeline Archive)
files, which contain model configurations and pipeline definitions.
"""
import os
import json
import logging
import zipfile
import tempfile
import shutil
from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass, field
from pathlib import Path
from ..core.exceptions import ModelLoadError, PipelineError
# Setup logging
logger = logging.getLogger("detector_worker.pipeline_loader")
@dataclass
class PipelineNode:
    """One model stage in a pipeline tree, with its child branches."""

    # --- Identity -------------------------------------------------------
    model_id: str
    model_file: str
    model_path: Optional[str] = None
    model: Optional[Any] = None  # loaded model instance, set after loading

    # --- Detection behaviour --------------------------------------------
    multi_class: bool = False
    expected_classes: List[str] = field(default_factory=list)
    trigger_classes: List[str] = field(default_factory=list)
    min_confidence: float = 0.5
    max_detections: Optional[int] = None

    # --- Cropping -------------------------------------------------------
    crop: bool = False
    crop_class: Optional[str] = None
    crop_expand_ratio: float = 1.0

    # --- Actions --------------------------------------------------------
    actions: List[Dict[str, Any]] = field(default_factory=list)
    parallel_actions: List[Dict[str, Any]] = field(default_factory=list)

    # --- Child nodes ----------------------------------------------------
    branches: List["PipelineNode"] = field(default_factory=list)
    parallel: bool = False

    # --- Detector settings ----------------------------------------------
    yolo_settings: Dict[str, Any] = field(default_factory=dict)
    track_classes: Optional[List[str]] = None

    # --- Free-form metadata ---------------------------------------------
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class PipelineConfig:
    """Top-level settings read from pipeline.json, plus the root node."""

    # --- Identity -------------------------------------------------------
    pipeline_id: str
    version: str = "1.0"
    description: str = ""

    # --- Optional backing services --------------------------------------
    database_config: Optional[Dict[str, Any]] = None
    redis_config: Optional[Dict[str, Any]] = None

    # --- Settings shared by the whole pipeline --------------------------
    global_settings: Dict[str, Any] = field(default_factory=dict)

    # --- Entry point of the pipeline tree -------------------------------
    root: Optional[PipelineNode] = None
class PipelineLoader:
"""
Loads and manages ML pipeline configurations.
This class handles:
- MPTA file extraction and parsing
- Pipeline configuration validation
- Model file management
- Pipeline tree construction
- Resource cleanup
"""
def __init__(self, temp_dir: Optional[str] = None):
"""
Initialize the pipeline loader.
Args:
temp_dir: Temporary directory for extracting MPTA files
"""
self.temp_dir = temp_dir or tempfile.gettempdir()
self.extracted_paths: Dict[str, str] = {} # mpta_path -> extracted_dir
self.loaded_models: Dict[str, Any] = {} # model_path -> model_instance
async def load_pipeline(self, mpta_path: str) -> PipelineNode:
"""
Load a pipeline from an MPTA file.
Args:
mpta_path: Path to MPTA file
Returns:
Root pipeline node
Raises:
ModelLoadError: If loading fails
"""
try:
# Extract MPTA if not already extracted
extracted_dir = await self._extract_mpta(mpta_path)
# Load pipeline configuration
pipeline_json_path = os.path.join(extracted_dir, "pipeline.json")
if not os.path.exists(pipeline_json_path):
raise ModelLoadError(f"pipeline.json not found in {mpta_path}")
with open(pipeline_json_path, 'r') as f:
config_data = json.load(f)
# Parse pipeline configuration
pipeline_config = self._parse_pipeline_config(config_data, extracted_dir)
# Validate pipeline
self._validate_pipeline(pipeline_config)
# Load models for the pipeline
await self._load_pipeline_models(pipeline_config.root, extracted_dir)
logger.info(f"Successfully loaded pipeline from {mpta_path}")
return pipeline_config.root
except Exception as e:
logger.error(f"Failed to load pipeline from {mpta_path}: {e}")
raise ModelLoadError(f"Failed to load pipeline: {e}")
async def _extract_mpta(self, mpta_path: str) -> str:
"""
Extract MPTA file to temporary directory.
Args:
mpta_path: Path to MPTA file
Returns:
Path to extracted directory
"""
# Check if already extracted
if mpta_path in self.extracted_paths:
extracted_dir = self.extracted_paths[mpta_path]
if os.path.exists(extracted_dir):
return extracted_dir
# Create extraction directory
mpta_name = os.path.splitext(os.path.basename(mpta_path))[0]
extracted_dir = os.path.join(self.temp_dir, f"mpta_{mpta_name}")
# Extract MPTA
logger.info(f"Extracting MPTA file: {mpta_path}")
try:
with zipfile.ZipFile(mpta_path, 'r') as zip_ref:
# Clean existing directory if exists
if os.path.exists(extracted_dir):
shutil.rmtree(extracted_dir)
os.makedirs(extracted_dir)
zip_ref.extractall(extracted_dir)
self.extracted_paths[mpta_path] = extracted_dir
logger.info(f"Extracted to: {extracted_dir}")
return extracted_dir
except Exception as e:
raise ModelLoadError(f"Failed to extract MPTA: {e}")
def _parse_pipeline_config(
self,
config_data: Dict[str, Any],
base_dir: str
) -> PipelineConfig:
"""
Parse pipeline configuration from JSON.
Args:
config_data: Pipeline JSON data
base_dir: Base directory for model files
Returns:
Parsed pipeline configuration
"""
# Create pipeline config
pipeline_config = PipelineConfig(
pipeline_id=config_data.get("pipelineId", "unknown"),
version=config_data.get("version", "1.0"),
description=config_data.get("description", "")
)
# Parse database config
if "database" in config_data:
pipeline_config.database_config = config_data["database"]
# Parse Redis config
if "redis" in config_data:
pipeline_config.redis_config = config_data["redis"]
# Parse global settings
if "globalSettings" in config_data:
pipeline_config.global_settings = config_data["globalSettings"]
# Parse pipeline tree
if "pipeline" in config_data:
pipeline_config.root = self._parse_pipeline_node(
config_data["pipeline"], base_dir
)
elif "root" in config_data:
pipeline_config.root = self._parse_pipeline_node(
config_data["root"], base_dir
)
else:
raise PipelineError("No pipeline or root node found in configuration")
return pipeline_config
def _parse_pipeline_node(
self,
node_data: Dict[str, Any],
base_dir: str
) -> PipelineNode:
"""
Parse a pipeline node from configuration.
Args:
node_data: Node configuration data
base_dir: Base directory for model files
Returns:
Parsed pipeline node
"""
# Create node
node = PipelineNode(
model_id=node_data.get("modelId", ""),
model_file=node_data.get("modelFile", "")
)
# Set model path
if node.model_file:
node.model_path = os.path.join(base_dir, node.model_file)
# Parse configuration
node.multi_class = node_data.get("multiClass", False)
node.expected_classes = node_data.get("expectedClasses", [])
node.trigger_classes = node_data.get("triggerClasses", [])
node.min_confidence = node_data.get("minConfidence", 0.5)
node.max_detections = node_data.get("maxDetections")
# Parse cropping
node.crop = node_data.get("crop", False)
node.crop_class = node_data.get("cropClass")
node.crop_expand_ratio = node_data.get("cropExpandRatio", 1.0)
# Parse actions
node.actions = node_data.get("actions", [])
node.parallel_actions = node_data.get("parallelActions", [])
# Parse YOLO settings
if "yoloSettings" in node_data:
node.yolo_settings = node_data["yoloSettings"]
elif "detectionSettings" in node_data:
node.yolo_settings = node_data["detectionSettings"]
# Parse tracking
node.track_classes = node_data.get("trackClasses")
# Parse metadata
node.metadata = node_data.get("metadata", {})
# Parse branches
branches_data = node_data.get("branches", [])
node.parallel = node_data.get("parallel", False)
for branch_data in branches_data:
branch_node = self._parse_pipeline_node(branch_data, base_dir)
node.branches.append(branch_node)
return node
def _validate_pipeline(self, pipeline_config: PipelineConfig) -> None:
"""
Validate pipeline configuration.
Args:
pipeline_config: Pipeline configuration to validate
Raises:
PipelineError: If validation fails
"""
if not pipeline_config.root:
raise PipelineError("Pipeline has no root node")
# Validate root node
self._validate_node(pipeline_config.root)
def _validate_node(self, node: PipelineNode) -> None:
"""
Validate a pipeline node.
Args:
node: Node to validate
Raises:
PipelineError: If validation fails
"""
# Check required fields
if not node.model_id:
raise PipelineError("Node missing modelId")
if not node.model_file and not node.model:
raise PipelineError(f"Node {node.model_id} missing modelFile")
# Validate model path exists
if node.model_path and not os.path.exists(node.model_path):
raise PipelineError(f"Model file not found: {node.model_path}")
# Validate cropping configuration
if node.crop and not node.crop_class:
raise PipelineError(f"Node {node.model_id} has crop=true but no cropClass")
# Validate confidence
if not 0 <= node.min_confidence <= 1:
raise PipelineError(f"Invalid minConfidence: {node.min_confidence}")
# Validate branches
for branch in node.branches:
self._validate_node(branch)
async def _load_pipeline_models(
self,
node: PipelineNode,
base_dir: str
) -> None:
"""
Load models for a pipeline node and its branches.
Args:
node: Pipeline node
base_dir: Base directory for models
"""
# Load model for this node if path is specified
if node.model_path:
node.model = await self._load_model(node.model_path, node.model_id)
# Load models for branches
for branch in node.branches:
await self._load_pipeline_models(branch, base_dir)
async def _load_model(self, model_path: str, model_id: str) -> Any:
"""
Load a single model file.
Args:
model_path: Path to model file
model_id: Model identifier
Returns:
Loaded model instance
"""
# Check if already loaded
if model_path in self.loaded_models:
logger.info(f"Using cached model: {model_id}")
return self.loaded_models[model_path]
try:
# Import here to avoid circular dependency
from ultralytics import YOLO
logger.info(f"Loading model: {model_id} from {model_path}")
# Load YOLO model
model = YOLO(model_path)
# Cache the model
self.loaded_models[model_path] = model
return model
except Exception as e:
raise ModelLoadError(f"Failed to load model {model_id}: {e}")
def cleanup_model(self, model_id: str) -> None:
"""
Clean up resources for a specific model.
Args:
model_id: Model identifier to clean up
"""
# Clean up loaded models
models_to_remove = []
for path, model in self.loaded_models.items():
if model_id in path:
models_to_remove.append(path)
for path in models_to_remove:
self.loaded_models.pop(path, None)
logger.info(f"Cleaned up model: {path}")
def cleanup_all(self) -> None:
    """
    Release everything this loader is holding on to.

    Empties the model cache, removes every extracted directory that still
    exists on disk (failures are logged, not raised) and finally forgets
    all extracted paths.
    """
    # Drop cached models
    self.loaded_models.clear()

    # Remove extracted directories from disk
    for directory in self.extracted_paths.values():
        if not os.path.exists(directory):
            continue
        try:
            shutil.rmtree(directory)
            logger.info(f"Cleaned up extracted directory: {directory}")
        except Exception as exc:
            logger.error(f"Failed to clean up {directory}: {exc}")

    self.extracted_paths.clear()
def get_node_info(self, node: PipelineNode, level: int = 0) -> str:
    """
    Render a pipeline node (and, recursively, its branches) as text.

    Args:
        node: Pipeline node to describe
        level: Indentation level; each branch is rendered two levels deeper
            than its parent

    Returns:
        Multi-line string with one attribute per line
    """
    pad = " " * level

    # Attributes that are always reported
    lines = [
        f"{pad}Model: {node.model_id}",
        f"{pad} File: {node.model_file}",
        f"{pad} Multi-class: {node.multi_class}",
    ]

    # Optional class lists
    if node.expected_classes:
        lines.append(f"{pad} Expected: {', '.join(node.expected_classes)}")
    if node.trigger_classes:
        lines.append(f"{pad} Triggers: {', '.join(node.trigger_classes)}")

    lines.append(f"{pad} Confidence: {node.min_confidence}")

    # Optional cropping and action details
    if node.crop:
        lines.append(f"{pad} Crop: {node.crop_class} (ratio: {node.crop_expand_ratio})")
    if node.actions:
        lines.append(f"{pad} Actions: {len(node.actions)}")
    if node.parallel_actions:
        lines.append(f"{pad} Parallel Actions: {len(node.parallel_actions)}")

    # Child nodes are rendered by the same method, further indented
    if node.branches:
        lines.append(f"{pad} Branches ({len(node.branches)}):")
        lines.extend(self.get_node_info(child, level + 2) for child in node.branches)

    return "\n".join(lines)
# Module-wide PipelineLoader, created lazily by get_pipeline_loader()
_pipeline_loader = None


def get_pipeline_loader(temp_dir: Optional[str] = None) -> PipelineLoader:
    """
    Return the shared pipeline loader, creating it on first use.

    ``temp_dir`` is only passed to ``PipelineLoader`` when the instance is
    first created; later calls return the existing loader regardless of the
    argument.
    """
    global _pipeline_loader

    if _pipeline_loader is not None:
        return _pipeline_loader

    _pipeline_loader = PipelineLoader(temp_dir)
    return _pipeline_loader
# Module-level shortcuts around the shared loader
async def load_pipeline_from_mpta(mpta_path: str) -> PipelineNode:
    """Load the pipeline stored in an MPTA file via the shared loader."""
    return await get_pipeline_loader().load_pipeline(mpta_path)

View file

@ -0,0 +1,379 @@
"""
System monitoring utilities.
This module provides functions to monitor system resources including
CPU, memory, and GPU usage.
"""
import logging
import time
import psutil
from typing import Dict, Any, Optional, List
# pynvml is an optional dependency: remember whether it could be imported
try:
    import pynvml
except ImportError:
    NVIDIA_AVAILABLE = False
else:
    NVIDIA_AVAILABLE = True

# Module logger
logger = logging.getLogger("detector_worker.system_monitor")

# Becomes True once pynvml.nvmlInit() has succeeded
_nvidia_initialized = False

# Timestamp taken when this module is imported
_start_time = time.time()
def initialize_nvidia_monitoring() -> bool:
    """
    Bring up NVML so GPU metrics can be queried.

    Calling this again after a successful initialization is a no-op.

    Returns:
        True if NVML is (already or newly) initialized, False if pynvml is
        not available or initialization failed
    """
    global _nvidia_initialized

    if not NVIDIA_AVAILABLE:
        return False

    if not _nvidia_initialized:
        try:
            pynvml.nvmlInit()
            _nvidia_initialized = True
            logger.info("NVIDIA GPU monitoring initialized")
        except Exception as exc:
            # Failure is reported through the return value, not raised
            logger.warning(f"Failed to initialize NVIDIA monitoring: {exc}")
            return False

    return True
def get_gpu_info() -> List[Dict[str, Any]]:
    """
    Get information about available GPUs.

    Returns:
        List of GPU information dictionaries, one per device, with the keys
        "index", "name", "gpu_utilization", "memory_utilization",
        "total_memory_gb", "used_memory_gb", "free_memory_gb" and
        "temperature". Memory figures are bytes divided by 1024 ** 3 and
        rounded to two decimals. Utilization and temperature fall back to 0
        when they cannot be read. The list is empty when NVML is unavailable
        or not initialized; if another NVML call fails, the error is logged
        and the devices collected so far are returned.
    """
    if not NVIDIA_AVAILABLE or not _nvidia_initialized:
        return []

    gpu_info = []

    try:
        device_count = pynvml.nvmlDeviceGetCount()

        for i in range(device_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)

            # Get GPU information. Depending on the pynvml release the name
            # comes back as bytes or as str; calling .decode() on a str
            # would raise and discard the whole result, so only decode bytes.
            name = pynvml.nvmlDeviceGetName(handle)
            if isinstance(name, bytes):
                name = name.decode('utf-8')

            # Get memory info
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            total_memory = mem_info.total / (1024 ** 3)  # Convert to GB
            used_memory = mem_info.used / (1024 ** 3)
            free_memory = mem_info.free / (1024 ** 3)

            # Get utilization (best effort: report 0 when the query fails)
            try:
                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                gpu_util = util.gpu
                memory_util = util.memory
            except Exception:
                gpu_util = 0
                memory_util = 0

            # Get temperature (best effort: report 0 when the query fails)
            try:
                temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
            except Exception:
                temp = 0

            gpu_info.append({
                "index": i,
                "name": name,
                "gpu_utilization": gpu_util,
                "memory_utilization": memory_util,
                "total_memory_gb": round(total_memory, 2),
                "used_memory_gb": round(used_memory, 2),
                "free_memory_gb": round(free_memory, 2),
                "temperature": temp
            })

    except Exception as e:
        logger.error(f"Error getting GPU info: {e}")

    return gpu_info
def get_gpu_metrics() -> Dict[str, float]:
    """
    Report utilization and memory usage of the first GPU (index 0).

    Returns:
        Dictionary with "gpu_percent" and "gpu_memory_percent"; both are 0.0
        when NVML is unavailable, not initialized, or the query fails
    """
    if NVIDIA_AVAILABLE and _nvidia_initialized:
        try:
            # Only the device at index 0 is inspected
            first_gpu = pynvml.nvmlDeviceGetHandleByIndex(0)

            utilization = pynvml.nvmlDeviceGetUtilizationRates(first_gpu)

            memory = pynvml.nvmlDeviceGetMemoryInfo(first_gpu)
            used_share = (memory.used / memory.total) * 100

            return {
                "gpu_percent": float(utilization.gpu),
                "gpu_memory_percent": float(used_share)
            }
        except Exception as exc:
            logger.debug(f"Error getting GPU metrics: {exc}")

    # Fallback shared by the "no GPU" and "query failed" paths
    return {"gpu_percent": 0.0, "gpu_memory_percent": 0.0}
def get_cpu_metrics() -> Dict[str, Any]:
    """
    Get current CPU metrics.

    Note: this call blocks for roughly 0.2 seconds in total, because the
    overall and the per-core figures are each sampled over their own
    0.1 second interval.

    Returns:
        Dictionary with:
          - "cpu_percent": overall CPU usage percentage
          - "cpu_per_core": list with one usage percentage per core
          - "cpu_count": result of ``psutil.cpu_count()``
        On failure the error is logged and 0.0, an empty list and 0 are
        returned for these keys.
    """
    try:
        # Get CPU percent with 0.1 second interval
        cpu_percent = psutil.cpu_percent(interval=0.1)

        # Get per-core CPU usage (a second, separate 0.1 second sample)
        cpu_per_core = psutil.cpu_percent(interval=0.1, percpu=True)

        return {
            "cpu_percent": cpu_percent,
            "cpu_per_core": cpu_per_core,
            "cpu_count": psutil.cpu_count()
        }
    except Exception as e:
        logger.error(f"Error getting CPU metrics: {e}")
        return {"cpu_percent": 0.0, "cpu_per_core": [], "cpu_count": 0}
def get_memory_metrics() -> Dict[str, Any]:
    """
    Collect RAM and swap usage.

    Returns:
        Dictionary with usage percentages and sizes; sizes are bytes divided
        by 1024 ** 3 and rounded to two decimals. Every value is 0.0 when
        the query fails (the error is logged).
    """
    bytes_per_gb = 1024 ** 3

    try:
        ram = psutil.virtual_memory()
        swap = psutil.swap_memory()

        return {
            "memory_percent": ram.percent,
            "memory_used_gb": round(ram.used / bytes_per_gb, 2),
            "memory_available_gb": round(ram.available / bytes_per_gb, 2),
            "memory_total_gb": round(ram.total / bytes_per_gb, 2),
            "swap_percent": swap.percent,
            "swap_used_gb": round(swap.used / bytes_per_gb, 2)
        }
    except Exception as exc:
        logger.error(f"Error getting memory metrics: {exc}")
        # Same keys as the success path, all zeroed
        return dict.fromkeys(
            (
                "memory_percent",
                "memory_used_gb",
                "memory_available_gb",
                "memory_total_gb",
                "swap_percent",
                "swap_used_gb",
            ),
            0.0,
        )
def get_disk_metrics(path: str = "/") -> Dict[str, Any]:
    """
    Collect disk usage for the filesystem containing ``path``.

    Args:
        path: Path to check disk usage for

    Returns:
        Dictionary with usage percentage and used/free/total sizes; sizes
        are bytes divided by 1024 ** 3 and rounded to two decimals. Every
        value is 0.0 when the query fails (the error is logged).
    """
    bytes_per_gb = 1024 ** 3

    try:
        usage = psutil.disk_usage(path)

        return {
            "disk_percent": usage.percent,
            "disk_used_gb": round(usage.used / bytes_per_gb, 2),
            "disk_free_gb": round(usage.free / bytes_per_gb, 2),
            "disk_total_gb": round(usage.total / bytes_per_gb, 2)
        }
    except Exception as exc:
        logger.error(f"Error getting disk metrics: {exc}")
        # Same keys as the success path, all zeroed
        return dict.fromkeys(
            ("disk_percent", "disk_used_gb", "disk_free_gb", "disk_total_gb"),
            0.0,
        )
# Cached psutil handle for the current process, created on first use by
# get_process_metrics()
_process_handle = None


def get_process_metrics() -> Dict[str, Any]:
    """
    Get current process metrics.

    The psutil ``Process`` handle is kept between calls.
    ``Process.cpu_percent()`` without an interval reports CPU usage since the
    previous call on the same object, so a handle created fresh on every
    call would always report 0.0. The very first call after the handle is
    created still reports 0.0.

    Returns:
        Dictionary with process-specific metrics: "process_cpu_percent",
        "process_memory_mb" (RSS in bytes divided by 1024 ** 2),
        "process_memory_percent", "process_threads" and
        "process_open_files". On failure the error is logged and all values
        are zero.
    """
    global _process_handle

    import os  # local import: only needed for the PID check below

    try:
        # Re-create the handle if the PID changed (e.g. after a fork), since
        # the cached object would then describe the parent process
        if _process_handle is None or _process_handle.pid != os.getpid():
            _process_handle = psutil.Process()
        process = _process_handle

        # Get process info
        with process.oneshot():
            cpu_percent = process.cpu_percent()
            memory_info = process.memory_info()
            memory_percent = process.memory_percent()
            num_threads = process.num_threads()

            # Count the regular files currently opened by the process
            try:
                num_fds = len(process.open_files())
            except Exception:
                num_fds = 0

        return {
            "process_cpu_percent": cpu_percent,
            "process_memory_mb": round(memory_info.rss / (1024 ** 2), 2),
            "process_memory_percent": round(memory_percent, 2),
            "process_threads": num_threads,
            "process_open_files": num_fds
        }
    except Exception as e:
        logger.error(f"Error getting process metrics: {e}")
        return {
            "process_cpu_percent": 0.0,
            "process_memory_mb": 0.0,
            "process_memory_percent": 0.0,
            "process_threads": 0,
            "process_open_files": 0
        }
def get_network_metrics() -> Dict[str, Any]:
    """
    Collect system-wide network I/O counters.

    Returns:
        Dictionary with the byte, packet, error and drop counters reported
        by ``psutil.net_io_counters()``; every counter is 0 when the query
        fails (the error is logged)
    """
    # Counter attributes copied verbatim into the result
    counter_names = (
        "bytes_sent",
        "bytes_recv",
        "packets_sent",
        "packets_recv",
        "errin",
        "errout",
        "dropin",
        "dropout",
    )

    try:
        counters = psutil.net_io_counters()
        return {name: getattr(counters, name) for name in counter_names}
    except Exception as exc:
        logger.error(f"Error getting network metrics: {exc}")
        return dict.fromkeys(counter_names, 0)
def get_system_metrics() -> Dict[str, Any]:
    """
    Build one flat dictionary combining all metric groups.

    Returns:
        Dictionary with "timestamp", "uptime" and "start_time" plus the
        CPU, memory, GPU and process metrics merged in, in that order
    """
    # Make sure NVML is up before the GPU metrics are queried
    if NVIDIA_AVAILABLE and not _nvidia_initialized:
        initialize_nvidia_monitoring()

    metrics = {
        "timestamp": time.time(),
        "uptime": time.time() - _start_time,
        "start_time": _start_time
    }

    # Merge each group into the flat result, one after another
    collectors = (
        get_cpu_metrics,
        get_memory_metrics,
        get_gpu_metrics,
        get_process_metrics,
    )
    for collect in collectors:
        metrics.update(collect())

    return metrics
def get_resource_summary() -> str:
    """
    Produce a one-line, human-readable resource overview.

    GPU figures are only included when the reported GPU utilization is
    above zero.

    Returns:
        Summary parts joined with " | "
    """
    metrics = get_system_metrics()

    parts = [
        f"CPU: {metrics['cpu_percent']:.1f}%",
        f"Memory: {metrics['memory_percent']:.1f}% ({metrics['memory_used_gb']:.1f}GB/{metrics['memory_total_gb']:.1f}GB)",
    ]

    if metrics['gpu_percent'] > 0:
        parts += [
            f"GPU: {metrics['gpu_percent']:.1f}%",
            f"GPU Memory: {metrics['gpu_memory_percent']:.1f}%",
        ]

    parts += [
        f"Process Memory: {metrics['process_memory_mb']:.1f}MB",
        f"Threads: {metrics['process_threads']}",
    ]

    return " | ".join(parts)
def cleanup_nvidia_monitoring():
    """
    Shut NVML down again if it was initialized.

    Does nothing when pynvml is unavailable or NVML was never initialized.
    A failing shutdown is logged and leaves the initialized flag unchanged.
    """
    global _nvidia_initialized

    if not (NVIDIA_AVAILABLE and _nvidia_initialized):
        return

    try:
        pynvml.nvmlShutdown()
        _nvidia_initialized = False
        logger.info("NVIDIA GPU monitoring cleaned up")
    except Exception as exc:
        logger.error(f"Error cleaning up NVIDIA monitoring: {exc}")
# Initialize on import: when pynvml could be imported, NVML is brought up
# as a side effect of importing this module. A failed initialization is only
# logged by initialize_nvidia_monitoring() and does not raise here.
if NVIDIA_AVAILABLE:
    initialize_nvidia_monitoring()