Refactor: PHASE 3: Action & Storage Extraction
This commit is contained in:
parent
4e9ae6bcc4
commit
cdeaaf4a4f
5 changed files with 3048 additions and 0 deletions
669
detector_worker/pipeline/action_executor.py
Normal file
669
detector_worker/pipeline/action_executor.py
Normal file
|
@ -0,0 +1,669 @@
|
||||||
|
"""
|
||||||
|
Action execution engine for pipeline workflows.
|
||||||
|
|
||||||
|
This module provides comprehensive action execution functionality including:
|
||||||
|
- Redis-based actions (save image, publish messages)
|
||||||
|
- Database actions (PostgreSQL updates with field mapping)
|
||||||
|
- Parallel action coordination
|
||||||
|
- Dynamic context resolution
|
||||||
|
"""
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Any, Optional, Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..core.constants import (
|
||||||
|
REDIS_IMAGE_DEFAULT_QUALITY,
|
||||||
|
REDIS_IMAGE_DEFAULT_FORMAT,
|
||||||
|
DEFAULT_IMAGE_ENCODING_PARAMS
|
||||||
|
)
|
||||||
|
from ..core.exceptions import ActionExecutionError, create_pipeline_error
|
||||||
|
from ..pipeline.field_mapper import resolve_field_mapping
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ActionContext:
|
||||||
|
"""Context information for action execution."""
|
||||||
|
timestamp_ms: int = field(default_factory=lambda: int(time.time() * 1000))
|
||||||
|
uuid: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
timestamp: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
|
||||||
|
filename: str = field(default_factory=lambda: f"{uuid.uuid4()}.jpg")
|
||||||
|
image_key: Optional[str] = None
|
||||||
|
|
||||||
|
# Detection context
|
||||||
|
camera_id: str = "unknown"
|
||||||
|
display_id: str = "unknown"
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
backend_session_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Additional context from detection result
|
||||||
|
regions: Optional[Dict[str, Any]] = None
|
||||||
|
detections: Optional[List[Dict[str, Any]]] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for template formatting."""
|
||||||
|
result = {
|
||||||
|
"timestamp_ms": self.timestamp_ms,
|
||||||
|
"uuid": self.uuid,
|
||||||
|
"timestamp": self.timestamp,
|
||||||
|
"filename": self.filename,
|
||||||
|
"camera_id": self.camera_id,
|
||||||
|
"display_id": self.display_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.image_key:
|
||||||
|
result["image_key"] = self.image_key
|
||||||
|
if self.session_id:
|
||||||
|
result["session_id"] = self.session_id
|
||||||
|
if self.backend_session_id:
|
||||||
|
result["backend_session_id"] = self.backend_session_id
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def update_from_dict(self, data: Dict[str, Any]) -> None:
|
||||||
|
"""Update context from detection result dictionary."""
|
||||||
|
if "camera_id" in data:
|
||||||
|
self.camera_id = data["camera_id"]
|
||||||
|
if "display_id" in data:
|
||||||
|
self.display_id = data["display_id"]
|
||||||
|
if "session_id" in data:
|
||||||
|
self.session_id = data["session_id"]
|
||||||
|
if "backend_session_id" in data:
|
||||||
|
self.backend_session_id = data["backend_session_id"]
|
||||||
|
if "regions" in data:
|
||||||
|
self.regions = data["regions"]
|
||||||
|
if "detections" in data:
|
||||||
|
self.detections = data["detections"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ActionResult:
|
||||||
|
"""Result from action execution."""
|
||||||
|
success: bool
|
||||||
|
action_type: str
|
||||||
|
message: str = ""
|
||||||
|
data: Optional[Dict[str, Any]] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
result = {
|
||||||
|
"success": self.success,
|
||||||
|
"action_type": self.action_type,
|
||||||
|
"message": self.message
|
||||||
|
}
|
||||||
|
if self.data:
|
||||||
|
result["data"] = self.data
|
||||||
|
if self.error:
|
||||||
|
result["error"] = self.error
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class RedisActionExecutor:
|
||||||
|
"""Executor for Redis-based actions."""
|
||||||
|
|
||||||
|
def __init__(self, redis_client):
|
||||||
|
"""
|
||||||
|
Initialize Redis action executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis_client: Redis client instance
|
||||||
|
"""
|
||||||
|
self.redis_client = redis_client
|
||||||
|
|
||||||
|
def _crop_region_by_class(self,
|
||||||
|
frame: np.ndarray,
|
||||||
|
regions_dict: Dict[str, Any],
|
||||||
|
class_name: str) -> Optional[np.ndarray]:
|
||||||
|
"""Crop a specific region from frame based on detected class."""
|
||||||
|
if class_name not in regions_dict:
|
||||||
|
logger.warning(f"Class '{class_name}' not found in detected regions")
|
||||||
|
return None
|
||||||
|
|
||||||
|
bbox = regions_dict[class_name]["bbox"]
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
|
||||||
|
# Validate bbox coordinates
|
||||||
|
if x2 <= x1 or y2 <= y1:
|
||||||
|
logger.warning(f"Invalid bbox for class {class_name}: {bbox}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
cropped = frame[y1:y2, x1:x2]
|
||||||
|
if cropped.size == 0:
|
||||||
|
logger.warning(f"Empty crop for class {class_name}")
|
||||||
|
return None
|
||||||
|
return cropped
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cropping region for class {class_name}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _encode_image(self,
|
||||||
|
image: np.ndarray,
|
||||||
|
img_format: str = REDIS_IMAGE_DEFAULT_FORMAT,
|
||||||
|
quality: int = REDIS_IMAGE_DEFAULT_QUALITY) -> Tuple[bool, Optional[bytes]]:
|
||||||
|
"""Encode image with specified format and quality."""
|
||||||
|
try:
|
||||||
|
img_format = img_format.lower()
|
||||||
|
|
||||||
|
if img_format == "jpeg" or img_format == "jpg":
|
||||||
|
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
|
||||||
|
success, buffer = cv2.imencode('.jpg', image, encode_params)
|
||||||
|
elif img_format == "png":
|
||||||
|
compression = DEFAULT_IMAGE_ENCODING_PARAMS["png"]["compression"]
|
||||||
|
encode_params = [cv2.IMWRITE_PNG_COMPRESSION, compression]
|
||||||
|
success, buffer = cv2.imencode('.png', image, encode_params)
|
||||||
|
elif img_format == "webp":
|
||||||
|
encode_params = [cv2.IMWRITE_WEBP_QUALITY, quality]
|
||||||
|
success, buffer = cv2.imencode('.webp', image, encode_params)
|
||||||
|
else:
|
||||||
|
# Default to JPEG
|
||||||
|
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
|
||||||
|
success, buffer = cv2.imencode('.jpg', image, encode_params)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return True, buffer.tobytes()
|
||||||
|
else:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error encoding image as {img_format}: {e}")
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def execute_redis_save_image(self,
|
||||||
|
action: Dict[str, Any],
|
||||||
|
frame: np.ndarray,
|
||||||
|
context: ActionContext,
|
||||||
|
regions_dict: Optional[Dict[str, Any]] = None) -> ActionResult:
|
||||||
|
"""Execute redis_save_image action."""
|
||||||
|
try:
|
||||||
|
# Format the key template
|
||||||
|
key = action["key"].format(**context.to_dict())
|
||||||
|
|
||||||
|
# Check if we need to crop a specific region
|
||||||
|
region_name = action.get("region")
|
||||||
|
image_to_save = frame
|
||||||
|
|
||||||
|
if region_name and regions_dict:
|
||||||
|
cropped_image = self._crop_region_by_class(frame, regions_dict, region_name)
|
||||||
|
if cropped_image is not None:
|
||||||
|
image_to_save = cropped_image
|
||||||
|
logger.debug(f"Cropped region '{region_name}' for redis_save_image")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Could not crop region '{region_name}', saving full frame instead")
|
||||||
|
|
||||||
|
# Encode image with specified format and quality
|
||||||
|
img_format = action.get("format", REDIS_IMAGE_DEFAULT_FORMAT)
|
||||||
|
quality = action.get("quality", REDIS_IMAGE_DEFAULT_QUALITY)
|
||||||
|
|
||||||
|
success, image_bytes = self._encode_image(image_to_save, img_format, quality)
|
||||||
|
|
||||||
|
if not success or not image_bytes:
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="redis_save_image",
|
||||||
|
error=f"Failed to encode image as {img_format}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to Redis
|
||||||
|
expire_seconds = action.get("expire_seconds")
|
||||||
|
if expire_seconds:
|
||||||
|
self.redis_client.setex(key, expire_seconds, image_bytes)
|
||||||
|
message = f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)"
|
||||||
|
else:
|
||||||
|
self.redis_client.set(key, image_bytes)
|
||||||
|
message = f"Saved image to Redis with key: {key}"
|
||||||
|
|
||||||
|
logger.info(message)
|
||||||
|
|
||||||
|
# Update context with image key
|
||||||
|
context.image_key = key
|
||||||
|
|
||||||
|
return ActionResult(
|
||||||
|
success=True,
|
||||||
|
action_type="redis_save_image",
|
||||||
|
message=message,
|
||||||
|
data={"key": key, "format": img_format, "quality": quality}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error executing redis_save_image: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="redis_save_image",
|
||||||
|
error=error_msg
|
||||||
|
)
|
||||||
|
|
||||||
|
def execute_redis_publish(self,
|
||||||
|
action: Dict[str, Any],
|
||||||
|
context: ActionContext) -> ActionResult:
|
||||||
|
"""Execute redis_publish action."""
|
||||||
|
try:
|
||||||
|
channel = action["channel"]
|
||||||
|
message_template = action["message"]
|
||||||
|
|
||||||
|
# Handle JSON message format by creating it programmatically
|
||||||
|
if message_template.strip().startswith('{') and message_template.strip().endswith('}'):
|
||||||
|
# Create JSON data programmatically to avoid formatting issues
|
||||||
|
json_data = {}
|
||||||
|
|
||||||
|
# Add common fields
|
||||||
|
json_data["event"] = "frontal_detected"
|
||||||
|
json_data["display_id"] = context.display_id
|
||||||
|
json_data["session_id"] = context.session_id
|
||||||
|
json_data["timestamp"] = context.timestamp
|
||||||
|
json_data["image_key"] = context.image_key or ""
|
||||||
|
|
||||||
|
# Convert to JSON string
|
||||||
|
message = json.dumps(json_data)
|
||||||
|
else:
|
||||||
|
# Use regular string formatting for non-JSON messages
|
||||||
|
message = message_template.format(**context.to_dict())
|
||||||
|
|
||||||
|
# Test Redis connection
|
||||||
|
if not self.redis_client:
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="redis_publish",
|
||||||
|
error="Redis client is None"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.redis_client.ping()
|
||||||
|
logger.debug("Redis connection is active")
|
||||||
|
except Exception as ping_error:
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="redis_publish",
|
||||||
|
error=f"Redis connection test failed: {ping_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Publish to Redis
|
||||||
|
result = self.redis_client.publish(channel, message)
|
||||||
|
|
||||||
|
success_msg = f"Published message to Redis channel '{channel}': {message}"
|
||||||
|
logger.info(success_msg)
|
||||||
|
logger.info(f"Redis publish result (subscribers count): {result}")
|
||||||
|
|
||||||
|
# Additional debug info
|
||||||
|
if result == 0:
|
||||||
|
logger.warning(f"No subscribers listening to channel '{channel}'")
|
||||||
|
else:
|
||||||
|
logger.info(f"Message delivered to {result} subscriber(s)")
|
||||||
|
|
||||||
|
return ActionResult(
|
||||||
|
success=True,
|
||||||
|
action_type="redis_publish",
|
||||||
|
message=success_msg,
|
||||||
|
data={
|
||||||
|
"channel": channel,
|
||||||
|
"message": message,
|
||||||
|
"subscribers_count": result
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except KeyError as e:
|
||||||
|
error_msg = f"Missing key in redis_publish message template: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.debug(f"Available context keys: {list(context.to_dict().keys())}")
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="redis_publish",
|
||||||
|
error=error_msg
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error in redis_publish action: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.debug(f"Message template: {action['message']}")
|
||||||
|
logger.debug(f"Available context keys: {list(context.to_dict().keys())}")
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="redis_publish",
|
||||||
|
error=error_msg
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseActionExecutor:
|
||||||
|
"""Executor for database-based actions."""
|
||||||
|
|
||||||
|
def __init__(self, db_manager):
|
||||||
|
"""
|
||||||
|
Initialize database action executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_manager: Database manager instance
|
||||||
|
"""
|
||||||
|
self.db_manager = db_manager
|
||||||
|
|
||||||
|
def execute_postgresql_update_combined(self,
|
||||||
|
action: Dict[str, Any],
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
branch_results: Dict[str, Any]) -> ActionResult:
|
||||||
|
"""Execute PostgreSQL update with combined branch results."""
|
||||||
|
try:
|
||||||
|
table = action["table"]
|
||||||
|
key_field = action["key_field"]
|
||||||
|
key_value_template = action["key_value"]
|
||||||
|
fields = action["fields"]
|
||||||
|
|
||||||
|
# Create context for key value formatting
|
||||||
|
action_context = {**detection_result}
|
||||||
|
key_value = key_value_template.format(**action_context)
|
||||||
|
|
||||||
|
logger.info(f"Executing database update: table={table}, {key_field}={key_value}")
|
||||||
|
logger.debug(f"Available branch results: {list(branch_results.keys())}")
|
||||||
|
|
||||||
|
# Process field mappings
|
||||||
|
mapped_fields = {}
|
||||||
|
mapping_errors = []
|
||||||
|
|
||||||
|
for db_field, value_template in fields.items():
|
||||||
|
try:
|
||||||
|
mapped_value = resolve_field_mapping(value_template, branch_results, action_context)
|
||||||
|
if mapped_value is not None:
|
||||||
|
mapped_fields[db_field] = mapped_value
|
||||||
|
logger.info(f"Mapped field: {db_field} = {mapped_value}")
|
||||||
|
else:
|
||||||
|
error_msg = f"Could not resolve field mapping for {db_field}: {value_template}"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
mapping_errors.append(error_msg)
|
||||||
|
logger.debug(f"Available branch results: {branch_results}")
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error mapping field {db_field} with template '{value_template}': {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
mapping_errors.append(error_msg)
|
||||||
|
|
||||||
|
if not mapped_fields:
|
||||||
|
error_msg = "No fields mapped successfully, skipping database update"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
logger.debug(f"Branch results available: {branch_results}")
|
||||||
|
logger.debug(f"Field templates: {fields}")
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="postgresql_update_combined",
|
||||||
|
error=error_msg,
|
||||||
|
data={"mapping_errors": mapping_errors}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add updated_at field automatically
|
||||||
|
mapped_fields["updated_at"] = "NOW()"
|
||||||
|
|
||||||
|
# Execute the database update
|
||||||
|
logger.info(f"Attempting database update with fields: {mapped_fields}")
|
||||||
|
success = self.db_manager.execute_update(table, key_field, key_value, mapped_fields)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
success_msg = f"Successfully updated database: {table} with {len(mapped_fields)} fields"
|
||||||
|
logger.info(f"✅ {success_msg}")
|
||||||
|
logger.info(f"Updated fields: {mapped_fields}")
|
||||||
|
return ActionResult(
|
||||||
|
success=True,
|
||||||
|
action_type="postgresql_update_combined",
|
||||||
|
message=success_msg,
|
||||||
|
data={
|
||||||
|
"table": table,
|
||||||
|
"key_field": key_field,
|
||||||
|
"key_value": key_value,
|
||||||
|
"updated_fields": mapped_fields
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_msg = f"Failed to update database: {table}"
|
||||||
|
logger.error(f"❌ {error_msg}")
|
||||||
|
logger.error(f"Attempted update with: {key_field}={key_value}, fields={mapped_fields}")
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="postgresql_update_combined",
|
||||||
|
error=error_msg,
|
||||||
|
data={
|
||||||
|
"table": table,
|
||||||
|
"key_field": key_field,
|
||||||
|
"key_value": key_value,
|
||||||
|
"attempted_fields": mapped_fields
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except KeyError as e:
|
||||||
|
error_msg = f"Missing required field in postgresql_update_combined action: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.debug(f"Action config: {action}")
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="postgresql_update_combined",
|
||||||
|
error=error_msg
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error in postgresql_update_combined action: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type="postgresql_update_combined",
|
||||||
|
error=error_msg
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionExecutor:
|
||||||
|
"""
|
||||||
|
Main action execution engine for pipeline workflows.
|
||||||
|
|
||||||
|
This class coordinates the execution of various action types including
|
||||||
|
Redis operations, database operations, and parallel action coordination.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize action executor."""
|
||||||
|
self._redis_executors: Dict[Any, RedisActionExecutor] = {}
|
||||||
|
self._db_executors: Dict[Any, DatabaseActionExecutor] = {}
|
||||||
|
|
||||||
|
def _get_redis_executor(self, redis_client) -> RedisActionExecutor:
|
||||||
|
"""Get or create Redis executor for a client."""
|
||||||
|
if redis_client not in self._redis_executors:
|
||||||
|
self._redis_executors[redis_client] = RedisActionExecutor(redis_client)
|
||||||
|
return self._redis_executors[redis_client]
|
||||||
|
|
||||||
|
def _get_db_executor(self, db_manager) -> DatabaseActionExecutor:
|
||||||
|
"""Get or create database executor for a manager."""
|
||||||
|
if db_manager not in self._db_executors:
|
||||||
|
self._db_executors[db_manager] = DatabaseActionExecutor(db_manager)
|
||||||
|
return self._db_executors[db_manager]
|
||||||
|
|
||||||
|
def _create_action_context(self, detection_result: Dict[str, Any]) -> ActionContext:
|
||||||
|
"""Create action context from detection result."""
|
||||||
|
context = ActionContext()
|
||||||
|
context.update_from_dict(detection_result)
|
||||||
|
return context
|
||||||
|
|
||||||
|
def execute_actions(self,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
frame: np.ndarray,
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
regions_dict: Optional[Dict[str, Any]] = None) -> List[ActionResult]:
|
||||||
|
"""
|
||||||
|
Execute all actions for a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Pipeline node configuration
|
||||||
|
frame: Input frame
|
||||||
|
detection_result: Detection result data
|
||||||
|
regions_dict: Optional region dictionary for cropping
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of action results
|
||||||
|
"""
|
||||||
|
if not node.get("redis_client") or not node.get("actions"):
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = []
|
||||||
|
redis_executor = self._get_redis_executor(node["redis_client"])
|
||||||
|
context = self._create_action_context(detection_result)
|
||||||
|
|
||||||
|
for action in node["actions"]:
|
||||||
|
try:
|
||||||
|
action_type = action.get("type")
|
||||||
|
|
||||||
|
if action_type == "redis_save_image":
|
||||||
|
result = redis_executor.execute_redis_save_image(
|
||||||
|
action, frame, context, regions_dict
|
||||||
|
)
|
||||||
|
elif action_type == "redis_publish":
|
||||||
|
result = redis_executor.execute_redis_publish(action, context)
|
||||||
|
else:
|
||||||
|
result = ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type=action_type or "unknown",
|
||||||
|
error=f"Unknown action type: {action_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
logger.error(f"Action failed: {result.error}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error executing action {action.get('type', 'unknown')}: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
results.append(ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type=action.get('type', 'unknown'),
|
||||||
|
error=error_msg
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def execute_parallel_actions(self,
|
||||||
|
node: Dict[str, Any],
|
||||||
|
frame: np.ndarray,
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> List[ActionResult]:
|
||||||
|
"""
|
||||||
|
Execute parallel actions after all required branches have completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Pipeline node configuration
|
||||||
|
frame: Input frame
|
||||||
|
detection_result: Detection result data with branch results
|
||||||
|
regions_dict: Region dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of action results
|
||||||
|
"""
|
||||||
|
if not node.get("parallelActions"):
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = []
|
||||||
|
branch_results = detection_result.get("branch_results", {})
|
||||||
|
|
||||||
|
logger.debug("Executing parallel actions...")
|
||||||
|
|
||||||
|
for action in node["parallelActions"]:
|
||||||
|
try:
|
||||||
|
action_type = action.get("type")
|
||||||
|
logger.debug(f"Processing parallel action: {action_type}")
|
||||||
|
|
||||||
|
if action_type == "postgresql_update_combined":
|
||||||
|
# Check if all required branches have completed
|
||||||
|
wait_for_branches = action.get("waitForBranches", [])
|
||||||
|
missing_branches = [branch for branch in wait_for_branches if branch not in branch_results]
|
||||||
|
|
||||||
|
if missing_branches:
|
||||||
|
error_msg = f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
results.append(ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type=action_type,
|
||||||
|
error=error_msg,
|
||||||
|
data={"missing_branches": missing_branches}
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"All required branches completed: {wait_for_branches}")
|
||||||
|
|
||||||
|
# Execute the database update
|
||||||
|
if node.get("db_manager"):
|
||||||
|
db_executor = self._get_db_executor(node["db_manager"])
|
||||||
|
result = db_executor.execute_postgresql_update_combined(
|
||||||
|
action, detection_result, branch_results
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
else:
|
||||||
|
results.append(ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type=action_type,
|
||||||
|
error="No database manager available"
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
error_msg = f"Unknown parallel action type: {action_type}"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
results.append(ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type=action_type or "unknown",
|
||||||
|
error=error_msg
|
||||||
|
))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error executing parallel action {action.get('type', 'unknown')}: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
results.append(ActionResult(
|
||||||
|
success=False,
|
||||||
|
action_type=action.get('type', 'unknown'),
|
||||||
|
error=error_msg
|
||||||
|
))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# Global action executor instance
|
||||||
|
action_executor = ActionExecutor()
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide the same interface as the original functions in pympta.py
|
||||||
|
|
||||||
|
def execute_actions(node: Dict[str, Any],
|
||||||
|
frame: np.ndarray,
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
regions_dict: Optional[Dict[str, Any]] = None) -> None:
|
||||||
|
"""Execute all actions for a node using global executor."""
|
||||||
|
results = action_executor.execute_actions(node, frame, detection_result, regions_dict)
|
||||||
|
|
||||||
|
# Log any failures (maintain original behavior of not returning results)
|
||||||
|
for result in results:
|
||||||
|
if not result.success:
|
||||||
|
logger.error(f"Action execution failed: {result.error}")
|
||||||
|
|
||||||
|
|
||||||
|
def execute_parallel_actions(node: Dict[str, Any],
|
||||||
|
frame: np.ndarray,
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
regions_dict: Dict[str, Any]) -> None:
|
||||||
|
"""Execute parallel actions using global executor."""
|
||||||
|
results = action_executor.execute_parallel_actions(node, frame, detection_result, regions_dict)
|
||||||
|
|
||||||
|
# Log any failures (maintain original behavior of not returning results)
|
||||||
|
for result in results:
|
||||||
|
if not result.success:
|
||||||
|
logger.error(f"Parallel action execution failed: {result.error}")
|
||||||
|
|
||||||
|
|
||||||
|
def execute_postgresql_update_combined(node: Dict[str, Any],
|
||||||
|
action: Dict[str, Any],
|
||||||
|
detection_result: Dict[str, Any],
|
||||||
|
branch_results: Dict[str, Any]) -> None:
|
||||||
|
"""Execute PostgreSQL update with combined branch results using global executor."""
|
||||||
|
if node.get("db_manager"):
|
||||||
|
db_executor = action_executor._get_db_executor(node["db_manager"])
|
||||||
|
result = db_executor.execute_postgresql_update_combined(action, detection_result, branch_results)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
logger.error(f"PostgreSQL update failed: {result.error}")
|
||||||
|
else:
|
||||||
|
logger.error("No database manager available for postgresql_update_combined action")
|
341
detector_worker/pipeline/field_mapper.py
Normal file
341
detector_worker/pipeline/field_mapper.py
Normal file
|
@ -0,0 +1,341 @@
|
||||||
|
"""
|
||||||
|
Field mapping and template resolution for dynamic database operations.
|
||||||
|
|
||||||
|
This module provides functionality for resolving field mapping templates
|
||||||
|
that reference branch results and context variables for database operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Optional, List, Union
|
||||||
|
|
||||||
|
from ..core.exceptions import FieldMappingError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FieldMapper:
|
||||||
|
"""
|
||||||
|
Field mapping resolver for dynamic template resolution.
|
||||||
|
|
||||||
|
This class handles the resolution of field mapping templates that can reference:
|
||||||
|
- Branch results (e.g., {car_brand_cls_v1.brand})
|
||||||
|
- Context variables (e.g., {session_id})
|
||||||
|
- Nested field lookups with fallback strategies
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize field mapper."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _extract_branch_references(self, template: str) -> List[str]:
|
||||||
|
"""Extract branch references from template string."""
|
||||||
|
# Match patterns like {model_id.field_name}
|
||||||
|
branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', template)
|
||||||
|
return branch_refs
|
||||||
|
|
||||||
|
def _resolve_simple_template(self, template: str, context: Dict[str, Any]) -> str:
|
||||||
|
"""Resolve simple template without branch references."""
|
||||||
|
try:
|
||||||
|
result = template.format(**context)
|
||||||
|
logger.debug(f"Simple template resolved: '{template}' -> '{result}'")
|
||||||
|
return result
|
||||||
|
except KeyError as e:
|
||||||
|
logger.warning(f"Could not resolve context variable in simple template: {e}")
|
||||||
|
return template
|
||||||
|
|
||||||
|
def _find_fallback_value(self,
|
||||||
|
branch_data: Dict[str, Any],
|
||||||
|
field_name: str,
|
||||||
|
model_id: str) -> Optional[str]:
|
||||||
|
"""Find fallback value using various strategies."""
|
||||||
|
if not isinstance(branch_data, dict):
|
||||||
|
logger.error(f"Branch data for '{model_id}' is not a dictionary: {type(branch_data)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# First, try the exact field name
|
||||||
|
if field_name in branch_data:
|
||||||
|
return branch_data[field_name]
|
||||||
|
|
||||||
|
# Then try 'class' field as fallback
|
||||||
|
if 'class' in branch_data:
|
||||||
|
fallback_value = branch_data['class']
|
||||||
|
logger.info(f"Using 'class' field as fallback for '{field_name}': '{fallback_value}'")
|
||||||
|
return fallback_value
|
||||||
|
|
||||||
|
# For brand models, check if the class name exists as a key
|
||||||
|
if field_name == 'brand' and branch_data.get('class') in branch_data:
|
||||||
|
fallback_value = branch_data[branch_data['class']]
|
||||||
|
logger.info(f"Found brand value using class name as key: '{fallback_value}'")
|
||||||
|
return fallback_value
|
||||||
|
|
||||||
|
# For body_type models, check if the class name exists as a key
|
||||||
|
if field_name == 'body_type' and branch_data.get('class') in branch_data:
|
||||||
|
fallback_value = branch_data[branch_data['class']]
|
||||||
|
logger.info(f"Found body_type value using class name as key: '{fallback_value}'")
|
||||||
|
return fallback_value
|
||||||
|
|
||||||
|
# Additional fallback strategies for common field mappings
|
||||||
|
field_mappings = {
|
||||||
|
'brand': ['car_brand', 'brand_name', 'detected_brand'],
|
||||||
|
'body_type': ['car_body_type', 'bodytype', 'body', 'car_type'],
|
||||||
|
'model': ['car_model', 'model_name', 'detected_model'],
|
||||||
|
'color': ['car_color', 'color_name', 'detected_color']
|
||||||
|
}
|
||||||
|
|
||||||
|
if field_name in field_mappings:
|
||||||
|
for alternative in field_mappings[field_name]:
|
||||||
|
if alternative in branch_data:
|
||||||
|
fallback_value = branch_data[alternative]
|
||||||
|
logger.info(f"Found '{field_name}' value using alternative field '{alternative}': '{fallback_value}'")
|
||||||
|
return fallback_value
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _resolve_branch_reference(self,
|
||||||
|
ref: str,
|
||||||
|
branch_results: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""Resolve a single branch reference."""
|
||||||
|
try:
|
||||||
|
model_id, field_name = ref.split('.', 1)
|
||||||
|
logger.debug(f"Processing branch reference: model_id='{model_id}', field_name='{field_name}'")
|
||||||
|
|
||||||
|
if model_id not in branch_results:
|
||||||
|
logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
branch_data = branch_results[model_id]
|
||||||
|
logger.debug(f"Branch '{model_id}' data: {branch_data}")
|
||||||
|
|
||||||
|
if field_name in branch_data:
|
||||||
|
field_value = branch_data[field_name]
|
||||||
|
logger.info(f"✅ Resolved {ref} to '{field_value}'")
|
||||||
|
return str(field_value)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results.")
|
||||||
|
logger.debug(f"Available fields in '{model_id}': {list(branch_data.keys()) if isinstance(branch_data, dict) else 'N/A'}")
|
||||||
|
|
||||||
|
# Try fallback strategies
|
||||||
|
fallback_value = self._find_fallback_value(branch_data, field_name, model_id)
|
||||||
|
|
||||||
|
if fallback_value is not None:
|
||||||
|
logger.info(f"✅ Resolved {ref} to '{fallback_value}' (using fallback)")
|
||||||
|
return str(fallback_value)
|
||||||
|
else:
|
||||||
|
logger.error(f"No suitable field found for '{field_name}' in branch '{model_id}'")
|
||||||
|
logger.debug(f"Branch data structure: {branch_data}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Invalid branch reference format: {ref}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error resolving branch reference '{ref}': {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def resolve_field_mapping(self,
|
||||||
|
value_template: str,
|
||||||
|
branch_results: Dict[str, Any],
|
||||||
|
action_context: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Resolve field mapping templates like {car_brand_cls_v1.brand}.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value_template: Template string with placeholders
|
||||||
|
branch_results: Dictionary of branch execution results
|
||||||
|
action_context: Context variables for template resolution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved string value or None if resolution failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.debug(f"Resolving field mapping: '{value_template}'")
|
||||||
|
logger.debug(f"Available branch results: {list(branch_results.keys())}")
|
||||||
|
|
||||||
|
# Handle simple context variables first (non-branch references)
|
||||||
|
if '.' not in value_template:
|
||||||
|
result = self._resolve_simple_template(value_template, action_context)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Handle branch result references like {model_id.field}
|
||||||
|
branch_refs = self._extract_branch_references(value_template)
|
||||||
|
logger.debug(f"Found branch references: {branch_refs}")
|
||||||
|
|
||||||
|
resolved_template = value_template
|
||||||
|
|
||||||
|
for ref in branch_refs:
|
||||||
|
resolved_value = self._resolve_branch_reference(ref, branch_results)
|
||||||
|
|
||||||
|
if resolved_value is not None:
|
||||||
|
resolved_template = resolved_template.replace(f'{{{ref}}}', resolved_value)
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to resolve branch reference: {ref}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Format any remaining simple variables
|
||||||
|
try:
|
||||||
|
final_value = resolved_template.format(**action_context)
|
||||||
|
logger.debug(f"Final resolved value: '{final_value}'")
|
||||||
|
return final_value
|
||||||
|
except KeyError as e:
|
||||||
|
logger.warning(f"Could not resolve context variable in template: {e}")
|
||||||
|
return resolved_template
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error resolving field mapping '{value_template}': {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def resolve_multiple_fields(self,
|
||||||
|
field_templates: Dict[str, str],
|
||||||
|
branch_results: Dict[str, Any],
|
||||||
|
action_context: Dict[str, Any]) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Resolve multiple field mappings at once.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_templates: Dictionary mapping field names to templates
|
||||||
|
branch_results: Dictionary of branch execution results
|
||||||
|
action_context: Context variables for template resolution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping field names to resolved values
|
||||||
|
"""
|
||||||
|
resolved_fields = {}
|
||||||
|
|
||||||
|
for field_name, template in field_templates.items():
|
||||||
|
try:
|
||||||
|
resolved_value = self.resolve_field_mapping(template, branch_results, action_context)
|
||||||
|
if resolved_value is not None:
|
||||||
|
resolved_fields[field_name] = resolved_value
|
||||||
|
logger.debug(f"Successfully resolved field '{field_name}': {resolved_value}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to resolve field '{field_name}' with template: {template}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error resolving field '{field_name}' with template '{template}': {e}")
|
||||||
|
|
||||||
|
return resolved_fields
|
||||||
|
|
||||||
|
def validate_template(self, template: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Validate a field mapping template and return analysis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: Template string to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with validation results and analysis
|
||||||
|
"""
|
||||||
|
analysis = {
|
||||||
|
"valid": True,
|
||||||
|
"errors": [],
|
||||||
|
"warnings": [],
|
||||||
|
"branch_references": [],
|
||||||
|
"context_references": [],
|
||||||
|
"has_branch_refs": False,
|
||||||
|
"has_context_refs": False
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract all placeholders
|
||||||
|
all_refs = re.findall(r'\{([^}]+)\}', template)
|
||||||
|
|
||||||
|
for ref in all_refs:
|
||||||
|
if '.' in ref:
|
||||||
|
# This is a branch reference
|
||||||
|
analysis["branch_references"].append(ref)
|
||||||
|
analysis["has_branch_refs"] = True
|
||||||
|
|
||||||
|
# Validate format
|
||||||
|
parts = ref.split('.')
|
||||||
|
if len(parts) != 2:
|
||||||
|
analysis["errors"].append(f"Invalid branch reference format: {ref}")
|
||||||
|
analysis["valid"] = False
|
||||||
|
elif not parts[0] or not parts[1]:
|
||||||
|
analysis["errors"].append(f"Empty model_id or field_name in reference: {ref}")
|
||||||
|
analysis["valid"] = False
|
||||||
|
else:
|
||||||
|
# This is a context reference
|
||||||
|
analysis["context_references"].append(ref)
|
||||||
|
analysis["has_context_refs"] = True
|
||||||
|
|
||||||
|
# Check for common issues
|
||||||
|
if analysis["has_branch_refs"] and not analysis["has_context_refs"]:
|
||||||
|
analysis["warnings"].append("Template only uses branch references, consider adding context info")
|
||||||
|
|
||||||
|
if not analysis["branch_references"] and not analysis["context_references"]:
|
||||||
|
analysis["warnings"].append("Template has no placeholders - it's a static value")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
analysis["valid"] = False
|
||||||
|
analysis["errors"].append(f"Template analysis failed: {e}")
|
||||||
|
|
||||||
|
return analysis
|
||||||
|
|
||||||
|
|
||||||
|
# Global field mapper instance
|
||||||
|
field_mapper = FieldMapper()
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide the same interface as the original functions in pympta.py
|
||||||
|
|
||||||
|
def resolve_field_mapping(value_template: str,
|
||||||
|
branch_results: Dict[str, Any],
|
||||||
|
action_context: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""Resolve field mapping templates like {car_brand_cls_v1.brand}."""
|
||||||
|
return field_mapper.resolve_field_mapping(value_template, branch_results, action_context)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_multiple_field_mappings(field_templates: Dict[str, str],
|
||||||
|
branch_results: Dict[str, Any],
|
||||||
|
action_context: Dict[str, Any]) -> Dict[str, str]:
|
||||||
|
"""Resolve multiple field mappings at once."""
|
||||||
|
return field_mapper.resolve_multiple_fields(field_templates, branch_results, action_context)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_field_mapping_template(template: str) -> Dict[str, Any]:
|
||||||
|
"""Validate a field mapping template and return analysis."""
|
||||||
|
return field_mapper.validate_template(template)
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_field_mappings(branch_results: Dict[str, Any]) -> Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Get available field mappings from branch results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
branch_results: Dictionary of branch execution results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping model IDs to available field names
|
||||||
|
"""
|
||||||
|
available_mappings = {}
|
||||||
|
|
||||||
|
for model_id, branch_data in branch_results.items():
|
||||||
|
if isinstance(branch_data, dict):
|
||||||
|
available_mappings[model_id] = list(branch_data.keys())
|
||||||
|
else:
|
||||||
|
logger.warning(f"Branch '{model_id}' data is not a dictionary: {type(branch_data)}")
|
||||||
|
available_mappings[model_id] = []
|
||||||
|
|
||||||
|
return available_mappings
|
||||||
|
|
||||||
|
|
||||||
|
def create_field_mapping_examples(branch_results: Dict[str, Any]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Create example field mapping templates based on available branch results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
branch_results: Dictionary of branch execution results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of example template strings
|
||||||
|
"""
|
||||||
|
examples = []
|
||||||
|
|
||||||
|
for model_id, branch_data in branch_results.items():
|
||||||
|
if isinstance(branch_data, dict):
|
||||||
|
for field_name in branch_data.keys():
|
||||||
|
example = f"{{{model_id}.{field_name}}}"
|
||||||
|
examples.append(example)
|
||||||
|
|
||||||
|
return examples
|
617
detector_worker/storage/database_manager.py
Normal file
617
detector_worker/storage/database_manager.py
Normal file
|
@ -0,0 +1,617 @@
|
||||||
|
"""
|
||||||
|
Database management and operations.
|
||||||
|
|
||||||
|
This module provides comprehensive database functionality including:
|
||||||
|
- PostgreSQL connection management
|
||||||
|
- Table schema management
|
||||||
|
- Dynamic query execution
|
||||||
|
- Transaction handling
|
||||||
|
"""
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
import psycopg2.extras
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any, List, Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ..core.constants import (
|
||||||
|
DB_CONNECTION_TIMEOUT,
|
||||||
|
DB_OPERATION_TIMEOUT,
|
||||||
|
DB_RETRY_ATTEMPTS
|
||||||
|
)
|
||||||
|
from ..core.exceptions import DatabaseError, create_detection_error
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatabaseConfig:
|
||||||
|
"""Database connection configuration."""
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
database: str
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
schema: str = "gas_station_1"
|
||||||
|
connection_timeout: int = DB_CONNECTION_TIMEOUT
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"host": self.host,
|
||||||
|
"port": self.port,
|
||||||
|
"database": self.database,
|
||||||
|
"username": self.username,
|
||||||
|
"password": self.password,
|
||||||
|
"schema": self.schema,
|
||||||
|
"connection_timeout": self.connection_timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> 'DatabaseConfig':
|
||||||
|
"""Create from dictionary data."""
|
||||||
|
return cls(
|
||||||
|
host=data["host"],
|
||||||
|
port=data["port"],
|
||||||
|
database=data["database"],
|
||||||
|
username=data.get("username", data.get("user", "")),
|
||||||
|
password=data["password"],
|
||||||
|
schema=data.get("schema", "gas_station_1"),
|
||||||
|
connection_timeout=data.get("connection_timeout", DB_CONNECTION_TIMEOUT)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QueryResult:
|
||||||
|
"""Result from database query execution."""
|
||||||
|
success: bool
|
||||||
|
affected_rows: int = 0
|
||||||
|
data: Optional[List[Dict[str, Any]]] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
query: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
result = {
|
||||||
|
"success": self.success,
|
||||||
|
"affected_rows": self.affected_rows
|
||||||
|
}
|
||||||
|
if self.data:
|
||||||
|
result["data"] = self.data
|
||||||
|
if self.error:
|
||||||
|
result["error"] = self.error
|
||||||
|
if self.query:
|
||||||
|
result["query"] = self.query
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseManager:
|
||||||
|
"""
|
||||||
|
Comprehensive database manager for PostgreSQL operations.
|
||||||
|
|
||||||
|
This class provides connection management, schema operations,
|
||||||
|
and dynamic query execution with transaction support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Initialize database manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Database configuration dictionary
|
||||||
|
"""
|
||||||
|
if isinstance(config, dict):
|
||||||
|
self.config = DatabaseConfig.from_dict(config)
|
||||||
|
else:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.connection: Optional[psycopg2.extensions.connection] = None
|
||||||
|
self._lock = None
|
||||||
|
|
||||||
|
def _ensure_thread_safety(self):
|
||||||
|
"""Initialize thread safety if not already done."""
|
||||||
|
if self._lock is None:
|
||||||
|
import threading
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def connect(self) -> bool:
|
||||||
|
"""
|
||||||
|
Establish database connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection successful, False otherwise
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
if self.connection and not self.connection.closed:
|
||||||
|
# Connection already exists and is open
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.connection = psycopg2.connect(
|
||||||
|
host=self.config.host,
|
||||||
|
port=self.config.port,
|
||||||
|
database=self.config.database,
|
||||||
|
user=self.config.username,
|
||||||
|
password=self.config.password,
|
||||||
|
connect_timeout=self.config.connection_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set connection properties
|
||||||
|
self.connection.set_client_encoding('UTF8')
|
||||||
|
|
||||||
|
logger.info(f"PostgreSQL connection established successfully to {self.config.host}:{self.config.port}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to PostgreSQL: {e}")
|
||||||
|
self.connection = None
|
||||||
|
return False
|
||||||
|
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
"""Close database connection."""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if self.connection:
|
||||||
|
try:
|
||||||
|
self.connection.close()
|
||||||
|
logger.info("PostgreSQL connection closed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing PostgreSQL connection: {e}")
|
||||||
|
finally:
|
||||||
|
self.connection = None
|
||||||
|
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if database connection is active.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connected, False otherwise
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
if self.connection and not self.connection.closed:
|
||||||
|
cur = self.connection.cursor()
|
||||||
|
cur.execute("SELECT 1")
|
||||||
|
cur.fetchone()
|
||||||
|
cur.close()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Connection check failed: {e}")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _ensure_connected(self) -> bool:
|
||||||
|
"""Ensure database connection is active."""
|
||||||
|
if not self.is_connected():
|
||||||
|
return self.connect()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def execute_query(self,
|
||||||
|
query: str,
|
||||||
|
params: Optional[Tuple] = None,
|
||||||
|
fetch_results: bool = False) -> QueryResult:
|
||||||
|
"""
|
||||||
|
Execute a database query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: SQL query string
|
||||||
|
params: Query parameters
|
||||||
|
fetch_results: Whether to fetch and return results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QueryResult with execution details
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if not self._ensure_connected():
|
||||||
|
return QueryResult(
|
||||||
|
success=False,
|
||||||
|
error="Failed to establish database connection",
|
||||||
|
query=query
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor = None
|
||||||
|
try:
|
||||||
|
cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||||
|
|
||||||
|
logger.debug(f"Executing query: {query}")
|
||||||
|
if params:
|
||||||
|
logger.debug(f"Query parameters: {params}")
|
||||||
|
|
||||||
|
cursor.execute(query, params)
|
||||||
|
affected_rows = cursor.rowcount
|
||||||
|
|
||||||
|
data = None
|
||||||
|
if fetch_results:
|
||||||
|
data = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
self.connection.commit()
|
||||||
|
|
||||||
|
logger.debug(f"Query executed successfully, affected rows: {affected_rows}")
|
||||||
|
|
||||||
|
return QueryResult(
|
||||||
|
success=True,
|
||||||
|
affected_rows=affected_rows,
|
||||||
|
data=data,
|
||||||
|
query=query
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Database query execution failed: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.debug(f"Failed query: {query}")
|
||||||
|
if params:
|
||||||
|
logger.debug(f"Query parameters: {params}")
|
||||||
|
|
||||||
|
if self.connection:
|
||||||
|
try:
|
||||||
|
self.connection.rollback()
|
||||||
|
except Exception as rollback_error:
|
||||||
|
logger.error(f"Rollback failed: {rollback_error}")
|
||||||
|
|
||||||
|
return QueryResult(
|
||||||
|
success=False,
|
||||||
|
error=error_msg,
|
||||||
|
query=query
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if cursor:
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
def create_schema_if_not_exists(self, schema_name: Optional[str] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Create database schema if it doesn't exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema_name: Schema name (uses config schema if not provided)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
schema_name = schema_name or self.config.schema
|
||||||
|
query = f"CREATE SCHEMA IF NOT EXISTS {schema_name}"
|
||||||
|
result = self.execute_query(query)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
logger.info(f"Schema '{schema_name}' created or verified successfully")
|
||||||
|
|
||||||
|
return result.success
|
||||||
|
|
||||||
|
def create_car_frontal_info_table(self) -> bool:
|
||||||
|
"""
|
||||||
|
Create the car_frontal_info table in the configured schema if it doesn't exist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
schema_name = self.config.schema
|
||||||
|
|
||||||
|
# Ensure schema exists
|
||||||
|
if not self.create_schema_if_not_exists(schema_name):
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create table if it doesn't exist
|
||||||
|
create_table_query = f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {schema_name}.car_frontal_info (
|
||||||
|
display_id VARCHAR(255),
|
||||||
|
captured_timestamp VARCHAR(255),
|
||||||
|
session_id VARCHAR(255) PRIMARY KEY,
|
||||||
|
license_character VARCHAR(255) DEFAULT NULL,
|
||||||
|
license_type VARCHAR(255) DEFAULT 'No model available',
|
||||||
|
car_brand VARCHAR(255) DEFAULT NULL,
|
||||||
|
car_model VARCHAR(255) DEFAULT NULL,
|
||||||
|
car_body_type VARCHAR(255) DEFAULT NULL,
|
||||||
|
created_at TIMESTAMP DEFAULT NOW(),
|
||||||
|
updated_at TIMESTAMP DEFAULT NOW()
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = self.execute_query(create_table_query)
|
||||||
|
if not result.success:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Add columns if they don't exist (for existing tables)
|
||||||
|
alter_queries = [
|
||||||
|
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_brand VARCHAR(255) DEFAULT NULL",
|
||||||
|
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_model VARCHAR(255) DEFAULT NULL",
|
||||||
|
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_body_type VARCHAR(255) DEFAULT NULL",
|
||||||
|
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT NOW()",
|
||||||
|
f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT NOW()"
|
||||||
|
]
|
||||||
|
|
||||||
|
for alter_query in alter_queries:
|
||||||
|
try:
|
||||||
|
alter_result = self.execute_query(alter_query)
|
||||||
|
if alter_result.success:
|
||||||
|
logger.debug(f"Executed: {alter_query}")
|
||||||
|
else:
|
||||||
|
# Check if it's just a "column already exists" error
|
||||||
|
if "already exists" not in (alter_result.error or "").lower():
|
||||||
|
logger.warning(f"ALTER TABLE failed: {alter_result.error}")
|
||||||
|
except Exception as e:
|
||||||
|
# Ignore errors if column already exists (for older PostgreSQL versions)
|
||||||
|
if "already exists" in str(e).lower():
|
||||||
|
logger.debug(f"Column already exists, skipping: {alter_query}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Error in ALTER TABLE: {e}")
|
||||||
|
|
||||||
|
logger.info("Successfully created/verified car_frontal_info table with all required columns")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create car_frontal_info table: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def insert_initial_detection(self,
|
||||||
|
display_id: str,
|
||||||
|
captured_timestamp: str,
|
||||||
|
session_id: Optional[str] = None) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Insert initial detection record and return the session_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
display_id: Display identifier
|
||||||
|
captured_timestamp: Timestamp of capture
|
||||||
|
session_id: Optional session ID (generated if not provided)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Session ID if successful, None otherwise
|
||||||
|
"""
|
||||||
|
# Generate session_id if not provided
|
||||||
|
if not session_id:
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Ensure table exists
|
||||||
|
if not self.create_car_frontal_info_table():
|
||||||
|
logger.error("Failed to create/verify table before insertion")
|
||||||
|
return None
|
||||||
|
|
||||||
|
schema_name = self.config.schema
|
||||||
|
insert_query = f"""
|
||||||
|
INSERT INTO {schema_name}.car_frontal_info
|
||||||
|
(display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type, created_at)
|
||||||
|
VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL, NOW())
|
||||||
|
ON CONFLICT (session_id) DO NOTHING
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = self.execute_query(insert_query, (display_id, captured_timestamp, session_id))
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
logger.info(f"Inserted initial detection record with session_id: {session_id}")
|
||||||
|
return session_id
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to insert initial detection record: {result.error}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to insert initial detection record: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def execute_update(self,
|
||||||
|
table: str,
|
||||||
|
key_field: str,
|
||||||
|
key_value: str,
|
||||||
|
fields: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Execute dynamic update/insert operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table: Table name
|
||||||
|
key_field: Primary key field name
|
||||||
|
key_value: Primary key value
|
||||||
|
fields: Dictionary of fields to update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Add schema prefix if table doesn't already have it
|
||||||
|
if '.' not in table:
|
||||||
|
table = f"{self.config.schema}.{table}"
|
||||||
|
|
||||||
|
# Build the INSERT and UPDATE query dynamically
|
||||||
|
insert_placeholders = []
|
||||||
|
insert_values = [key_value] # Start with key_value
|
||||||
|
|
||||||
|
set_clauses = []
|
||||||
|
update_values = []
|
||||||
|
|
||||||
|
for field, value in fields.items():
|
||||||
|
if value == "NOW()":
|
||||||
|
# Special handling for NOW()
|
||||||
|
insert_placeholders.append("NOW()")
|
||||||
|
set_clauses.append(f"{field} = NOW()")
|
||||||
|
else:
|
||||||
|
insert_placeholders.append("%s")
|
||||||
|
insert_values.append(value)
|
||||||
|
set_clauses.append(f"{field} = %s")
|
||||||
|
update_values.append(value)
|
||||||
|
|
||||||
|
# Build the complete query
|
||||||
|
query = f"""
|
||||||
|
INSERT INTO {table} ({key_field}, {', '.join(fields.keys())})
|
||||||
|
VALUES (%s, {', '.join(insert_placeholders)})
|
||||||
|
ON CONFLICT ({key_field})
|
||||||
|
DO UPDATE SET {', '.join(set_clauses)}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Combine values for the query: insert_values + update_values
|
||||||
|
all_values = tuple(insert_values + update_values)
|
||||||
|
|
||||||
|
result = self.execute_query(query, all_values)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
logger.info(f"✅ Updated {table} for {key_field}={key_value} with {len(fields)} fields")
|
||||||
|
logger.debug(f"Updated fields: {fields}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"❌ Failed to update {table}: {result.error}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Failed to execute update on {table}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def update_car_info(self,
|
||||||
|
session_id: str,
|
||||||
|
brand: str,
|
||||||
|
model: str,
|
||||||
|
body_type: str) -> bool:
|
||||||
|
"""
|
||||||
|
Update car information for a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Session identifier
|
||||||
|
brand: Car brand
|
||||||
|
model: Car model
|
||||||
|
body_type: Car body type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
schema_name = self.config.schema
|
||||||
|
query = f"""
|
||||||
|
INSERT INTO {schema_name}.car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
|
||||||
|
VALUES (%s, %s, %s, %s, NOW())
|
||||||
|
ON CONFLICT (session_id)
|
||||||
|
DO UPDATE SET
|
||||||
|
car_brand = EXCLUDED.car_brand,
|
||||||
|
car_model = EXCLUDED.car_model,
|
||||||
|
car_body_type = EXCLUDED.car_body_type,
|
||||||
|
updated_at = NOW()
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = self.execute_query(query, (session_id, brand, model, body_type))
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")
|
||||||
|
|
||||||
|
return result.success
|
||||||
|
|
||||||
|
def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get session data by session ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Session identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Session data dictionary or None if not found
|
||||||
|
"""
|
||||||
|
schema_name = self.config.schema
|
||||||
|
query = f"SELECT * FROM {schema_name}.car_frontal_info WHERE session_id = %s"
|
||||||
|
|
||||||
|
result = self.execute_query(query, (session_id,), fetch_results=True)
|
||||||
|
|
||||||
|
if result.success and result.data:
|
||||||
|
return result.data[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_connection_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get database connection statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with connection statistics
|
||||||
|
"""
|
||||||
|
stats = {
|
||||||
|
"connected": self.is_connected(),
|
||||||
|
"config": self.config.to_dict(),
|
||||||
|
"connection_closed": self.connection.closed if self.connection else True
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.is_connected():
|
||||||
|
try:
|
||||||
|
# Get database version and basic stats
|
||||||
|
version_result = self.execute_query("SELECT version()", fetch_results=True)
|
||||||
|
if version_result.success and version_result.data:
|
||||||
|
stats["database_version"] = version_result.data[0]["version"]
|
||||||
|
|
||||||
|
# Get current database name
|
||||||
|
db_result = self.execute_query("SELECT current_database()", fetch_results=True)
|
||||||
|
if db_result.success and db_result.data:
|
||||||
|
stats["current_database"] = db_result.data[0]["current_database"]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
stats["stats_error"] = str(e)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide compatibility with the original database.py interface
|
||||||
|
|
||||||
|
def validate_postgresql_config(config: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Validate PostgreSQL configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if configuration is valid
|
||||||
|
"""
|
||||||
|
required_fields = ["host", "port", "database", "password"]
|
||||||
|
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in config:
|
||||||
|
logger.error(f"Missing required PostgreSQL config field: {field}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not config[field]:
|
||||||
|
logger.error(f"Empty PostgreSQL config field: {field}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for username (could be 'username' or 'user')
|
||||||
|
if not config.get("username") and not config.get("user"):
|
||||||
|
logger.error("Missing PostgreSQL username (provide either 'username' or 'user')")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Validate port is numeric
|
||||||
|
try:
|
||||||
|
port = int(config["port"])
|
||||||
|
if port <= 0 or port > 65535:
|
||||||
|
logger.error(f"Invalid PostgreSQL port: {port}")
|
||||||
|
return False
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.error(f"PostgreSQL port must be numeric: {config['port']}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def create_database_manager(config: Dict[str, Any]) -> Optional[DatabaseManager]:
|
||||||
|
"""
|
||||||
|
Create database manager with configuration validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Database configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseManager instance or None if invalid config
|
||||||
|
"""
|
||||||
|
if not validate_postgresql_config(config):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
db_manager = DatabaseManager(config)
|
||||||
|
if db_manager.connect():
|
||||||
|
logger.info(f"Successfully created database manager for {config['host']}:{config['port']}")
|
||||||
|
return db_manager
|
||||||
|
else:
|
||||||
|
logger.error("Failed to establish initial database connection")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating database manager: {e}")
|
||||||
|
return None
|
733
detector_worker/storage/redis_client.py
Normal file
733
detector_worker/storage/redis_client.py
Normal file
|
@ -0,0 +1,733 @@
|
||||||
|
"""
|
||||||
|
Redis client management and operations.
|
||||||
|
|
||||||
|
This module provides comprehensive Redis functionality including:
|
||||||
|
- Connection management with retries
|
||||||
|
- Key-value operations
|
||||||
|
- Pub/Sub messaging
|
||||||
|
- Image storage with compression
|
||||||
|
- Pipeline operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
import redis
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Optional, List, Union, Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from ..core.constants import (
|
||||||
|
REDIS_CONNECTION_TIMEOUT,
|
||||||
|
REDIS_SOCKET_TIMEOUT,
|
||||||
|
REDIS_IMAGE_DEFAULT_QUALITY,
|
||||||
|
REDIS_IMAGE_DEFAULT_FORMAT
|
||||||
|
)
|
||||||
|
from ..core.exceptions import RedisError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RedisConfig:
|
||||||
|
"""Redis connection configuration."""
|
||||||
|
host: str
|
||||||
|
port: int
|
||||||
|
password: Optional[str] = None
|
||||||
|
db: int = 0
|
||||||
|
connection_timeout: int = REDIS_CONNECTION_TIMEOUT
|
||||||
|
socket_timeout: int = REDIS_SOCKET_TIMEOUT
|
||||||
|
retry_on_timeout: bool = True
|
||||||
|
health_check_interval: int = 30
|
||||||
|
max_connections: int = 10
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"host": self.host,
|
||||||
|
"port": self.port,
|
||||||
|
"password": self.password,
|
||||||
|
"db": self.db,
|
||||||
|
"connection_timeout": self.connection_timeout,
|
||||||
|
"socket_timeout": self.socket_timeout,
|
||||||
|
"retry_on_timeout": self.retry_on_timeout,
|
||||||
|
"health_check_interval": self.health_check_interval,
|
||||||
|
"max_connections": self.max_connections
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
|
||||||
|
"""Create from dictionary data."""
|
||||||
|
return cls(
|
||||||
|
host=data["host"],
|
||||||
|
port=data["port"],
|
||||||
|
password=data.get("password"),
|
||||||
|
db=data.get("db", 0),
|
||||||
|
connection_timeout=data.get("connection_timeout", REDIS_CONNECTION_TIMEOUT),
|
||||||
|
socket_timeout=data.get("socket_timeout", REDIS_SOCKET_TIMEOUT),
|
||||||
|
retry_on_timeout=data.get("retry_on_timeout", True),
|
||||||
|
health_check_interval=data.get("health_check_interval", 30),
|
||||||
|
max_connections=data.get("max_connections", 10)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RedisOperationResult:
|
||||||
|
"""Result from Redis operation."""
|
||||||
|
success: bool
|
||||||
|
data: Optional[Any] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
operation: Optional[str] = None
|
||||||
|
key: Optional[str] = None
|
||||||
|
execution_time: Optional[float] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
result = {
|
||||||
|
"success": self.success,
|
||||||
|
"operation": self.operation
|
||||||
|
}
|
||||||
|
if self.data is not None:
|
||||||
|
result["data"] = self.data
|
||||||
|
if self.error:
|
||||||
|
result["error"] = self.error
|
||||||
|
if self.key:
|
||||||
|
result["key"] = self.key
|
||||||
|
if self.execution_time:
|
||||||
|
result["execution_time"] = self.execution_time
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class RedisClientManager:
|
||||||
|
"""
|
||||||
|
Comprehensive Redis client manager with connection pooling and retry logic.
|
||||||
|
|
||||||
|
This class provides high-level Redis operations with automatic reconnection,
|
||||||
|
connection pooling, and comprehensive error handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Union[Dict[str, Any], RedisConfig]):
|
||||||
|
"""
|
||||||
|
Initialize Redis client manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Redis configuration dictionary or RedisConfig object
|
||||||
|
"""
|
||||||
|
if isinstance(config, dict):
|
||||||
|
self.config = RedisConfig.from_dict(config)
|
||||||
|
else:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.client: Optional[redis.Redis] = None
|
||||||
|
self.connection_pool: Optional[redis.ConnectionPool] = None
|
||||||
|
self._lock = None
|
||||||
|
self._last_health_check = 0.0
|
||||||
|
|
||||||
|
def _ensure_thread_safety(self):
|
||||||
|
"""Initialize thread safety if not already done."""
|
||||||
|
if self._lock is None:
|
||||||
|
import threading
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def _create_connection_pool(self) -> redis.ConnectionPool:
|
||||||
|
"""Create Redis connection pool."""
|
||||||
|
pool_kwargs = {
|
||||||
|
"host": self.config.host,
|
||||||
|
"port": self.config.port,
|
||||||
|
"db": self.config.db,
|
||||||
|
"socket_timeout": self.config.socket_timeout,
|
||||||
|
"socket_connect_timeout": self.config.connection_timeout,
|
||||||
|
"retry_on_timeout": self.config.retry_on_timeout,
|
||||||
|
"health_check_interval": self.config.health_check_interval,
|
||||||
|
"max_connections": self.config.max_connections
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.config.password:
|
||||||
|
pool_kwargs["password"] = self.config.password
|
||||||
|
|
||||||
|
return redis.ConnectionPool(**pool_kwargs)
|
||||||
|
|
||||||
|
def connect(self) -> bool:
|
||||||
|
"""
|
||||||
|
Establish Redis connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection successful, False otherwise
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
# Create connection pool
|
||||||
|
self.connection_pool = self._create_connection_pool()
|
||||||
|
|
||||||
|
# Create Redis client
|
||||||
|
self.client = redis.Redis(connection_pool=self.connection_pool)
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
self.client.ping()
|
||||||
|
self._last_health_check = time.time()
|
||||||
|
|
||||||
|
logger.info(f"Redis connection established successfully to {self.config.host}:{self.config.port}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to Redis: {e}")
|
||||||
|
self.client = None
|
||||||
|
self.connection_pool = None
|
||||||
|
return False
|
||||||
|
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
"""Close Redis connection and cleanup resources."""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if self.connection_pool:
|
||||||
|
try:
|
||||||
|
self.connection_pool.disconnect()
|
||||||
|
logger.info("Redis connection pool disconnected")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error disconnecting Redis pool: {e}")
|
||||||
|
finally:
|
||||||
|
self.connection_pool = None
|
||||||
|
|
||||||
|
self.client = None
|
||||||
|
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if Redis connection is active.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connected, False otherwise
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
if self.client:
|
||||||
|
# Perform periodic health check
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - self._last_health_check > self.config.health_check_interval:
|
||||||
|
self.client.ping()
|
||||||
|
self._last_health_check = current_time
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Redis connection check failed: {e}")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _ensure_connected(self) -> bool:
|
||||||
|
"""Ensure Redis connection is active."""
|
||||||
|
if not self.is_connected():
|
||||||
|
return self.connect()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _execute_operation(self,
|
||||||
|
operation_name: str,
|
||||||
|
operation_func,
|
||||||
|
key: Optional[str] = None) -> RedisOperationResult:
|
||||||
|
"""Execute Redis operation with error handling and timing."""
|
||||||
|
if not self._ensure_connected():
|
||||||
|
return RedisOperationResult(
|
||||||
|
success=False,
|
||||||
|
error="Failed to establish Redis connection",
|
||||||
|
operation=operation_name,
|
||||||
|
key=key
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
result = operation_func()
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
|
||||||
|
return RedisOperationResult(
|
||||||
|
success=True,
|
||||||
|
data=result,
|
||||||
|
operation=operation_name,
|
||||||
|
key=key,
|
||||||
|
execution_time=execution_time
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
error_msg = f"Redis {operation_name} operation failed: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
|
return RedisOperationResult(
|
||||||
|
success=False,
|
||||||
|
error=error_msg,
|
||||||
|
operation=operation_name,
|
||||||
|
key=key,
|
||||||
|
execution_time=execution_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# ===== KEY-VALUE OPERATIONS =====
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any, ex: Optional[int] = None) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Set key-value pair with optional expiration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
value: Value to store
|
||||||
|
ex: Expiration time in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with operation status
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.set(key, value, ex=ex)
|
||||||
|
|
||||||
|
result = self._execute_operation("SET", operation, key)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
expire_msg = f" (expires in {ex}s)" if ex else ""
|
||||||
|
logger.debug(f"Set Redis key '{key}'{expire_msg}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def setex(self, key: str, time: int, value: Any) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Set key-value pair with expiration time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
time: Expiration time in seconds
|
||||||
|
value: Value to store
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with operation status
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.setex(key, time, value)
|
||||||
|
|
||||||
|
result = self._execute_operation("SETEX", operation, key)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
logger.debug(f"Set Redis key '{key}' with {time}s expiration")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get(self, key: str) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Get value by key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with value data
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.get(key)
|
||||||
|
|
||||||
|
return self._execute_operation("GET", operation, key)
|
||||||
|
|
||||||
|
def delete(self, *keys: str) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Delete one or more keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*keys: Redis keys to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with number of deleted keys
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.delete(*keys)
|
||||||
|
|
||||||
|
result = self._execute_operation("DELETE", operation)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
logger.debug(f"Deleted {result.data} Redis key(s): {keys}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def exists(self, *keys: str) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Check if keys exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*keys: Redis keys to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with count of existing keys
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.exists(*keys)
|
||||||
|
|
||||||
|
return self._execute_operation("EXISTS", operation)
|
||||||
|
|
||||||
|
def expire(self, key: str, time: int) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Set expiration time for a key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
time: Expiration time in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with operation status
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.expire(key, time)
|
||||||
|
|
||||||
|
return self._execute_operation("EXPIRE", operation, key)
|
||||||
|
|
||||||
|
def ttl(self, key: str) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Get time to live for a key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with TTL in seconds
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.ttl(key)
|
||||||
|
|
||||||
|
return self._execute_operation("TTL", operation, key)
|
||||||
|
|
||||||
|
# ===== PUB/SUB OPERATIONS =====
|
||||||
|
|
||||||
|
def publish(self, channel: str, message: Union[str, Dict[str, Any]]) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Publish message to channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: Channel name
|
||||||
|
message: Message to publish (string or dict)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with number of subscribers
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
# Convert dict to JSON string
|
||||||
|
if isinstance(message, dict):
|
||||||
|
message_str = json.dumps(message)
|
||||||
|
else:
|
||||||
|
message_str = str(message)
|
||||||
|
|
||||||
|
return self.client.publish(channel, message_str)
|
||||||
|
|
||||||
|
result = self._execute_operation("PUBLISH", operation)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
logger.debug(f"Published to channel '{channel}', subscribers: {result.data}")
|
||||||
|
if result.data == 0:
|
||||||
|
logger.warning(f"No subscribers listening to channel '{channel}'")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def subscribe(self, *channels: str):
|
||||||
|
"""
|
||||||
|
Subscribe to channels (returns pubsub object).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*channels: Channel names to subscribe to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PubSub object for listening to messages
|
||||||
|
"""
|
||||||
|
if not self._ensure_connected():
|
||||||
|
raise RedisError("Failed to establish Redis connection for subscription")
|
||||||
|
|
||||||
|
pubsub = self.client.pubsub()
|
||||||
|
pubsub.subscribe(*channels)
|
||||||
|
|
||||||
|
logger.info(f"Subscribed to channels: {channels}")
|
||||||
|
return pubsub
|
||||||
|
|
||||||
|
# ===== HASH OPERATIONS =====
|
||||||
|
|
||||||
|
def hset(self, key: str, field: str, value: Any) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Set hash field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
field: Hash field name
|
||||||
|
value: Field value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with operation status
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.hset(key, field, value)
|
||||||
|
|
||||||
|
return self._execute_operation("HSET", operation, key)
|
||||||
|
|
||||||
|
def hget(self, key: str, field: str) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Get hash field value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
field: Hash field name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with field value
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.hget(key, field)
|
||||||
|
|
||||||
|
return self._execute_operation("HGET", operation, key)
|
||||||
|
|
||||||
|
def hgetall(self, key: str) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Get all hash fields and values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with hash dictionary
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.hgetall(key)
|
||||||
|
|
||||||
|
return self._execute_operation("HGETALL", operation, key)
|
||||||
|
|
||||||
|
# ===== LIST OPERATIONS =====
|
||||||
|
|
||||||
|
def lpush(self, key: str, *values: Any) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Push values to left of list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
*values: Values to push
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with list length
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.lpush(key, *values)
|
||||||
|
|
||||||
|
return self._execute_operation("LPUSH", operation, key)
|
||||||
|
|
||||||
|
def rpush(self, key: str, *values: Any) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Push values to right of list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
*values: Values to push
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with list length
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.rpush(key, *values)
|
||||||
|
|
||||||
|
return self._execute_operation("RPUSH", operation, key)
|
||||||
|
|
||||||
|
def lrange(self, key: str, start: int, end: int) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Get range of list elements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Redis key
|
||||||
|
start: Start index
|
||||||
|
end: End index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with list elements
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.lrange(key, start, end)
|
||||||
|
|
||||||
|
return self._execute_operation("LRANGE", operation, key)
|
||||||
|
|
||||||
|
# ===== UTILITY OPERATIONS =====
|
||||||
|
|
||||||
|
def ping(self) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Ping Redis server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with ping response
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.ping()
|
||||||
|
|
||||||
|
return self._execute_operation("PING", operation)
|
||||||
|
|
||||||
|
def info(self, section: Optional[str] = None) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Get Redis server information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
section: Optional info section
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with server info
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.info(section)
|
||||||
|
|
||||||
|
return self._execute_operation("INFO", operation)
|
||||||
|
|
||||||
|
def flushdb(self) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Flush current database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with operation status
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.flushdb()
|
||||||
|
|
||||||
|
result = self._execute_operation("FLUSHDB", operation)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
logger.warning(f"Flushed Redis database {self.config.db}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def keys(self, pattern: str = "*") -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Get keys matching pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Key pattern (default: all keys)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with matching keys
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return self.client.keys(pattern)
|
||||||
|
|
||||||
|
return self._execute_operation("KEYS", operation)
|
||||||
|
|
||||||
|
# ===== BATCH OPERATIONS =====
|
||||||
|
|
||||||
|
def pipeline(self):
|
||||||
|
"""
|
||||||
|
Create Redis pipeline for batch operations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redis pipeline object
|
||||||
|
"""
|
||||||
|
if not self._ensure_connected():
|
||||||
|
raise RedisError("Failed to establish Redis connection for pipeline")
|
||||||
|
|
||||||
|
return self.client.pipeline()
|
||||||
|
|
||||||
|
def execute_pipeline(self, pipeline) -> RedisOperationResult:
|
||||||
|
"""
|
||||||
|
Execute Redis pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipeline: Redis pipeline object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisOperationResult with pipeline results
|
||||||
|
"""
|
||||||
|
def operation():
|
||||||
|
return pipeline.execute()
|
||||||
|
|
||||||
|
return self._execute_operation("PIPELINE", operation)
|
||||||
|
|
||||||
|
# ===== CONNECTION MANAGEMENT =====
|
||||||
|
|
||||||
|
def get_connection_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get Redis connection statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with connection statistics
|
||||||
|
"""
|
||||||
|
stats = {
|
||||||
|
"connected": self.is_connected(),
|
||||||
|
"config": self.config.to_dict(),
|
||||||
|
"last_health_check": self._last_health_check,
|
||||||
|
"connection_pool_created": self.connection_pool is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.connection_pool:
|
||||||
|
stats["connection_pool_stats"] = {
|
||||||
|
"created_connections": self.connection_pool.created_connections,
|
||||||
|
"available_connections": len(self.connection_pool._available_connections),
|
||||||
|
"in_use_connections": len(self.connection_pool._in_use_connections)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get Redis server info if connected
|
||||||
|
if self.is_connected():
|
||||||
|
try:
|
||||||
|
info_result = self.info()
|
||||||
|
if info_result.success:
|
||||||
|
redis_info = info_result.data
|
||||||
|
stats["server_info"] = {
|
||||||
|
"redis_version": redis_info.get("redis_version"),
|
||||||
|
"connected_clients": redis_info.get("connected_clients"),
|
||||||
|
"used_memory": redis_info.get("used_memory"),
|
||||||
|
"used_memory_human": redis_info.get("used_memory_human"),
|
||||||
|
"total_commands_processed": redis_info.get("total_commands_processed"),
|
||||||
|
"uptime_in_seconds": redis_info.get("uptime_in_seconds")
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
stats["server_info_error"] = str(e)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide compatibility and simplified access
|
||||||
|
|
||||||
|
def create_redis_client(config: Dict[str, Any]) -> Optional[RedisClientManager]:
|
||||||
|
"""
|
||||||
|
Create Redis client with configuration validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Redis configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RedisClientManager instance or None if connection failed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client_manager = RedisClientManager(config)
|
||||||
|
if client_manager.connect():
|
||||||
|
logger.info(f"Successfully created Redis client for {config['host']}:{config['port']}")
|
||||||
|
return client_manager
|
||||||
|
else:
|
||||||
|
logger.error("Failed to establish initial Redis connection")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating Redis client: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_redis_config(config: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Validate Redis configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if configuration is valid
|
||||||
|
"""
|
||||||
|
required_fields = ["host", "port"]
|
||||||
|
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in config:
|
||||||
|
logger.error(f"Missing required Redis config field: {field}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not config[field]:
|
||||||
|
logger.error(f"Empty Redis config field: {field}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Validate port is numeric
|
||||||
|
try:
|
||||||
|
port = int(config["port"])
|
||||||
|
if port <= 0 or port > 65535:
|
||||||
|
logger.error(f"Invalid Redis port: {port}")
|
||||||
|
return False
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.error(f"Redis port must be numeric: {config['port']}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
688
detector_worker/storage/session_cache.py
Normal file
688
detector_worker/storage/session_cache.py
Normal file
|
@ -0,0 +1,688 @@
|
||||||
|
"""
|
||||||
|
Session and cache management for detection workflows.
|
||||||
|
|
||||||
|
This module provides comprehensive session and cache management including:
|
||||||
|
- Session state tracking and lifecycle management
|
||||||
|
- Detection result caching with TTL
|
||||||
|
- Pipeline mode state management
|
||||||
|
- Session cleanup and garbage collection
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Any, Optional, Set, Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from ..core.constants import (
|
||||||
|
SESSION_CACHE_TTL_MINUTES,
|
||||||
|
DETECTION_CACHE_CLEANUP_INTERVAL,
|
||||||
|
SESSION_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
from ..core.exceptions import SessionError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMode(Enum):
|
||||||
|
"""Pipeline execution modes."""
|
||||||
|
VALIDATION_DETECTING = "validation_detecting"
|
||||||
|
SEND_DETECTIONS = "send_detections"
|
||||||
|
WAITING_FOR_SESSION_ID = "waiting_for_session_id"
|
||||||
|
FULL_PIPELINE = "full_pipeline"
|
||||||
|
LIGHTWEIGHT = "lightweight"
|
||||||
|
CAR_GONE_WAITING = "car_gone_waiting"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionState:
|
||||||
|
"""Session state information."""
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
backend_session_id: Optional[str] = None
|
||||||
|
mode: PipelineMode = PipelineMode.VALIDATION_DETECTING
|
||||||
|
session_id_received: bool = False
|
||||||
|
created_at: float = field(default_factory=time.time)
|
||||||
|
last_updated: float = field(default_factory=time.time)
|
||||||
|
last_detection: Optional[float] = None
|
||||||
|
detection_count: int = 0
|
||||||
|
|
||||||
|
# Mode-specific state
|
||||||
|
validation_frames_processed: int = 0
|
||||||
|
stable_track_achieved: bool = False
|
||||||
|
waiting_start_time: Optional[float] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"backend_session_id": self.backend_session_id,
|
||||||
|
"mode": self.mode.value,
|
||||||
|
"session_id_received": self.session_id_received,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
"last_updated": self.last_updated,
|
||||||
|
"last_detection": self.last_detection,
|
||||||
|
"detection_count": self.detection_count,
|
||||||
|
"validation_frames_processed": self.validation_frames_processed,
|
||||||
|
"stable_track_achieved": self.stable_track_achieved,
|
||||||
|
"waiting_start_time": self.waiting_start_time
|
||||||
|
}
|
||||||
|
|
||||||
|
def update_activity(self) -> None:
|
||||||
|
"""Update last activity timestamp."""
|
||||||
|
self.last_updated = time.time()
|
||||||
|
|
||||||
|
def record_detection(self) -> None:
|
||||||
|
"""Record a detection occurrence."""
|
||||||
|
current_time = time.time()
|
||||||
|
self.last_detection = current_time
|
||||||
|
self.detection_count += 1
|
||||||
|
self.update_activity()
|
||||||
|
|
||||||
|
def is_expired(self, ttl_seconds: int) -> bool:
|
||||||
|
"""Check if session has expired based on TTL."""
|
||||||
|
return time.time() - self.last_updated > ttl_seconds
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CachedDetection:
|
||||||
|
"""Cached detection result."""
|
||||||
|
detection_data: Dict[str, Any]
|
||||||
|
camera_id: str
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
created_at: float = field(default_factory=time.time)
|
||||||
|
access_count: int = 0
|
||||||
|
last_accessed: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"detection_data": self.detection_data,
|
||||||
|
"camera_id": self.camera_id,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
"access_count": self.access_count,
|
||||||
|
"last_accessed": self.last_accessed
|
||||||
|
}
|
||||||
|
|
||||||
|
def access(self) -> None:
|
||||||
|
"""Record access to this cached detection."""
|
||||||
|
self.access_count += 1
|
||||||
|
self.last_accessed = time.time()
|
||||||
|
|
||||||
|
def is_expired(self, ttl_seconds: int) -> bool:
|
||||||
|
"""Check if cached detection has expired."""
|
||||||
|
return time.time() - self.created_at > ttl_seconds
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
"""
|
||||||
|
Session lifecycle and state management.
|
||||||
|
|
||||||
|
This class provides comprehensive session management including:
|
||||||
|
- Session creation and cleanup
|
||||||
|
- Pipeline mode transitions
|
||||||
|
- Session timeout handling
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS):
|
||||||
|
"""
|
||||||
|
Initialize session manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_timeout_seconds: Default session timeout
|
||||||
|
"""
|
||||||
|
self.session_timeout_seconds = session_timeout_seconds
|
||||||
|
self._sessions: Dict[str, SessionState] = {}
|
||||||
|
self._session_ids: Dict[str, str] = {} # display_id -> session_id mapping
|
||||||
|
self._lock = None
|
||||||
|
|
||||||
|
def _ensure_thread_safety(self):
|
||||||
|
"""Initialize thread safety if not already done."""
|
||||||
|
if self._lock is None:
|
||||||
|
import threading
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def create_session(self, display_id: str, session_id: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
Create a new session or get existing one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
display_id: Display identifier
|
||||||
|
session_id: Optional session ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Session ID
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
# Check if session already exists for this display
|
||||||
|
if display_id in self._session_ids:
|
||||||
|
existing_session_id = self._session_ids[display_id]
|
||||||
|
if existing_session_id in self._sessions:
|
||||||
|
session_state = self._sessions[existing_session_id]
|
||||||
|
session_state.update_activity()
|
||||||
|
logger.debug(f"Using existing session for display {display_id}: {existing_session_id}")
|
||||||
|
return existing_session_id
|
||||||
|
|
||||||
|
# Create new session
|
||||||
|
if not session_id:
|
||||||
|
import uuid
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
session_state = SessionState(session_id=session_id)
|
||||||
|
self._sessions[session_id] = session_state
|
||||||
|
self._session_ids[display_id] = session_id
|
||||||
|
|
||||||
|
logger.info(f"Created new session for display {display_id}: {session_id}")
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
def get_session(self, session_id: str) -> Optional[SessionState]:
|
||||||
|
"""
|
||||||
|
Get session state by session ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Session identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionState or None if not found
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
session = self._sessions.get(session_id)
|
||||||
|
if session:
|
||||||
|
session.update_activity()
|
||||||
|
return session
|
||||||
|
|
||||||
|
def get_session_by_display(self, display_id: str) -> Optional[SessionState]:
|
||||||
|
"""
|
||||||
|
Get session state by display ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
display_id: Display identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionState or None if not found
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
session_id = self._session_ids.get(display_id)
|
||||||
|
if session_id:
|
||||||
|
return self.get_session(session_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update_session_mode(self,
|
||||||
|
session_id: str,
|
||||||
|
new_mode: PipelineMode,
|
||||||
|
backend_session_id: Optional[str] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Update session pipeline mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Session identifier
|
||||||
|
new_mode: New pipeline mode
|
||||||
|
backend_session_id: Optional backend session ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if updated successfully
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
session = self.get_session(session_id)
|
||||||
|
if not session:
|
||||||
|
logger.warning(f"Session not found for mode update: {session_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
old_mode = session.mode
|
||||||
|
session.mode = new_mode
|
||||||
|
|
||||||
|
if backend_session_id:
|
||||||
|
session.backend_session_id = backend_session_id
|
||||||
|
session.session_id_received = True
|
||||||
|
|
||||||
|
# Handle mode-specific state changes
|
||||||
|
if new_mode == PipelineMode.WAITING_FOR_SESSION_ID:
|
||||||
|
session.waiting_start_time = time.time()
|
||||||
|
elif old_mode == PipelineMode.WAITING_FOR_SESSION_ID:
|
||||||
|
session.waiting_start_time = None
|
||||||
|
|
||||||
|
session.update_activity()
|
||||||
|
|
||||||
|
logger.info(f"Session {session_id}: Mode changed from {old_mode.value} to {new_mode.value}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def record_detection(self, session_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Record a detection for a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Session identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if recorded successfully
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
session = self.get_session(session_id)
|
||||||
|
if session:
|
||||||
|
session.record_detection()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def cleanup_expired_sessions(self, ttl_seconds: Optional[int] = None) -> int:
|
||||||
|
"""
|
||||||
|
Clean up expired sessions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ttl_seconds: TTL in seconds (uses default if not provided)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of sessions cleaned up
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
if ttl_seconds is None:
|
||||||
|
ttl_seconds = SESSION_CACHE_TTL_MINUTES * 60
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
expired_sessions = []
|
||||||
|
expired_displays = []
|
||||||
|
|
||||||
|
# Find expired sessions
|
||||||
|
for session_id, session in self._sessions.items():
|
||||||
|
if session.is_expired(ttl_seconds):
|
||||||
|
expired_sessions.append(session_id)
|
||||||
|
|
||||||
|
# Find displays pointing to expired sessions
|
||||||
|
for display_id, session_id in self._session_ids.items():
|
||||||
|
if session_id in expired_sessions:
|
||||||
|
expired_displays.append(display_id)
|
||||||
|
|
||||||
|
# Remove expired sessions and mappings
|
||||||
|
for session_id in expired_sessions:
|
||||||
|
del self._sessions[session_id]
|
||||||
|
|
||||||
|
for display_id in expired_displays:
|
||||||
|
del self._session_ids[display_id]
|
||||||
|
|
||||||
|
if expired_sessions:
|
||||||
|
logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")
|
||||||
|
|
||||||
|
return len(expired_sessions)
|
||||||
|
|
||||||
|
def get_session_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get session management statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with session statistics
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
current_time = time.time()
|
||||||
|
mode_counts = {}
|
||||||
|
total_detections = 0
|
||||||
|
|
||||||
|
for session in self._sessions.values():
|
||||||
|
mode = session.mode.value
|
||||||
|
mode_counts[mode] = mode_counts.get(mode, 0) + 1
|
||||||
|
total_detections += session.detection_count
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_sessions": len(self._sessions),
|
||||||
|
"total_display_mappings": len(self._session_ids),
|
||||||
|
"mode_distribution": mode_counts,
|
||||||
|
"total_detections_processed": total_detections,
|
||||||
|
"session_timeout_seconds": self.session_timeout_seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionCache:
|
||||||
|
"""
|
||||||
|
Detection result caching with TTL and access tracking.
|
||||||
|
|
||||||
|
This class provides caching for detection results with automatic
|
||||||
|
expiration and access pattern tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ttl_minutes: int = SESSION_CACHE_TTL_MINUTES):
|
||||||
|
"""
|
||||||
|
Initialize detection cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ttl_minutes: Time to live for cached detections in minutes
|
||||||
|
"""
|
||||||
|
self.ttl_seconds = ttl_minutes * 60
|
||||||
|
self._cache: Dict[str, CachedDetection] = {}
|
||||||
|
self._camera_index: Dict[str, Set[str]] = {} # camera_id -> set of cache keys
|
||||||
|
self._session_index: Dict[str, Set[str]] = {} # session_id -> set of cache keys
|
||||||
|
self._lock = None
|
||||||
|
self._last_cleanup = time.time()
|
||||||
|
|
||||||
|
def _ensure_thread_safety(self):
|
||||||
|
"""Initialize thread safety if not already done."""
|
||||||
|
if self._lock is None:
|
||||||
|
import threading
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def _generate_cache_key(self, camera_id: str, detection_type: str = "default") -> str:
|
||||||
|
"""Generate cache key for detection."""
|
||||||
|
return f"{camera_id}:{detection_type}:{time.time()}"
|
||||||
|
|
||||||
|
def store_detection(self,
|
||||||
|
camera_id: str,
|
||||||
|
detection_data: Dict[str, Any],
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
detection_type: str = "default") -> str:
|
||||||
|
"""
|
||||||
|
Store detection in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Camera identifier
|
||||||
|
detection_data: Detection result data
|
||||||
|
session_id: Optional session identifier
|
||||||
|
detection_type: Type of detection for categorization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cache key for the stored detection
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
cache_key = self._generate_cache_key(camera_id, detection_type)
|
||||||
|
|
||||||
|
cached_detection = CachedDetection(
|
||||||
|
detection_data=detection_data,
|
||||||
|
camera_id=camera_id,
|
||||||
|
session_id=session_id
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cache[cache_key] = cached_detection
|
||||||
|
|
||||||
|
# Update indexes
|
||||||
|
if camera_id not in self._camera_index:
|
||||||
|
self._camera_index[camera_id] = set()
|
||||||
|
self._camera_index[camera_id].add(cache_key)
|
||||||
|
|
||||||
|
if session_id:
|
||||||
|
if session_id not in self._session_index:
|
||||||
|
self._session_index[session_id] = set()
|
||||||
|
self._session_index[session_id].add(cache_key)
|
||||||
|
|
||||||
|
logger.debug(f"Stored detection in cache: {cache_key}")
|
||||||
|
return cache_key
|
||||||
|
|
||||||
|
def get_detection(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get detection from cache by key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_key: Cache key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Detection data or None if not found/expired
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
cached_detection = self._cache.get(cache_key)
|
||||||
|
if not cached_detection:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cached_detection.is_expired(self.ttl_seconds):
|
||||||
|
self._remove_from_cache(cache_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
cached_detection.access()
|
||||||
|
return cached_detection.detection_data
|
||||||
|
|
||||||
|
def get_latest_detection(self, camera_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get latest detection for a camera.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: Camera identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Latest detection data or None if not found
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
camera_keys = self._camera_index.get(camera_id, set())
|
||||||
|
if not camera_keys:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find the most recent non-expired detection
|
||||||
|
latest_detection = None
|
||||||
|
latest_time = 0
|
||||||
|
|
||||||
|
for cache_key in camera_keys:
|
||||||
|
cached_detection = self._cache.get(cache_key)
|
||||||
|
if cached_detection and not cached_detection.is_expired(self.ttl_seconds):
|
||||||
|
if cached_detection.created_at > latest_time:
|
||||||
|
latest_time = cached_detection.created_at
|
||||||
|
latest_detection = cached_detection
|
||||||
|
|
||||||
|
if latest_detection:
|
||||||
|
latest_detection.access()
|
||||||
|
return latest_detection.detection_data
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_session_detections(self, session_id: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get all detections for a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Session identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of detection data dictionaries
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
session_keys = self._session_index.get(session_id, set())
|
||||||
|
detections = []
|
||||||
|
|
||||||
|
for cache_key in session_keys:
|
||||||
|
cached_detection = self._cache.get(cache_key)
|
||||||
|
if cached_detection and not cached_detection.is_expired(self.ttl_seconds):
|
||||||
|
cached_detection.access()
|
||||||
|
detections.append(cached_detection.detection_data)
|
||||||
|
|
||||||
|
# Sort by creation time (newest first)
|
||||||
|
detections.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
|
||||||
|
return detections
|
||||||
|
|
||||||
|
def _remove_from_cache(self, cache_key: str) -> None:
|
||||||
|
"""Remove detection from cache and indexes."""
|
||||||
|
cached_detection = self._cache.get(cache_key)
|
||||||
|
if cached_detection:
|
||||||
|
# Remove from indexes
|
||||||
|
camera_id = cached_detection.camera_id
|
||||||
|
if camera_id in self._camera_index:
|
||||||
|
self._camera_index[camera_id].discard(cache_key)
|
||||||
|
if not self._camera_index[camera_id]:
|
||||||
|
del self._camera_index[camera_id]
|
||||||
|
|
||||||
|
session_id = cached_detection.session_id
|
||||||
|
if session_id and session_id in self._session_index:
|
||||||
|
self._session_index[session_id].discard(cache_key)
|
||||||
|
if not self._session_index[session_id]:
|
||||||
|
del self._session_index[session_id]
|
||||||
|
|
||||||
|
# Remove from main cache
|
||||||
|
if cache_key in self._cache:
|
||||||
|
del self._cache[cache_key]
|
||||||
|
|
||||||
|
def cleanup_expired_detections(self) -> int:
|
||||||
|
"""
|
||||||
|
Clean up expired cached detections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of detections cleaned up
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
expired_keys = []
|
||||||
|
|
||||||
|
# Find expired detections
|
||||||
|
for cache_key, cached_detection in self._cache.items():
|
||||||
|
if cached_detection.is_expired(self.ttl_seconds):
|
||||||
|
expired_keys.append(cache_key)
|
||||||
|
|
||||||
|
# Remove expired detections
|
||||||
|
for cache_key in expired_keys:
|
||||||
|
self._remove_from_cache(cache_key)
|
||||||
|
|
||||||
|
if expired_keys:
|
||||||
|
logger.info(f"Cleaned up {len(expired_keys)} expired cached detections")
|
||||||
|
|
||||||
|
self._last_cleanup = time.time()
|
||||||
|
return len(expired_keys)
|
||||||
|
|
||||||
|
def get_cache_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get cache statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with cache statistics
|
||||||
|
"""
|
||||||
|
self._ensure_thread_safety()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
total_access_count = sum(detection.access_count for detection in self._cache.values())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_cached_detections": len(self._cache),
|
||||||
|
"cameras_with_cache": len(self._camera_index),
|
||||||
|
"sessions_with_cache": len(self._session_index),
|
||||||
|
"total_access_count": total_access_count,
|
||||||
|
"ttl_seconds": self.ttl_seconds,
|
||||||
|
"last_cleanup": self._last_cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCacheManager:
|
||||||
|
"""
|
||||||
|
Combined session and cache management.
|
||||||
|
|
||||||
|
This class provides unified management of sessions and detection caching
|
||||||
|
with automatic cleanup and comprehensive statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS,
|
||||||
|
cache_ttl_minutes: int = SESSION_CACHE_TTL_MINUTES):
|
||||||
|
"""
|
||||||
|
Initialize session cache manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_timeout_seconds: Session timeout in seconds
|
||||||
|
cache_ttl_minutes: Cache TTL in minutes
|
||||||
|
"""
|
||||||
|
self.session_manager = SessionManager(session_timeout_seconds)
|
||||||
|
self.detection_cache = DetectionCache(cache_ttl_minutes)
|
||||||
|
self._last_cleanup = time.time()
|
||||||
|
|
||||||
|
def cleanup_expired_data(self, force: bool = False) -> Dict[str, int]:
|
||||||
|
"""
|
||||||
|
Clean up expired sessions and cached detections.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force: Force cleanup regardless of interval
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with cleanup statistics
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Check if cleanup is needed
|
||||||
|
if not force and current_time - self._last_cleanup < DETECTION_CACHE_CLEANUP_INTERVAL:
|
||||||
|
return {"sessions_cleaned": 0, "detections_cleaned": 0, "cleanup_skipped": True}
|
||||||
|
|
||||||
|
sessions_cleaned = self.session_manager.cleanup_expired_sessions()
|
||||||
|
detections_cleaned = self.detection_cache.cleanup_expired_detections()
|
||||||
|
|
||||||
|
self._last_cleanup = current_time
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sessions_cleaned": sessions_cleaned,
|
||||||
|
"detections_cleaned": detections_cleaned,
|
||||||
|
"cleanup_skipped": False
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_comprehensive_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get comprehensive statistics for sessions and cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with all statistics
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"session_stats": self.session_manager.get_session_stats(),
|
||||||
|
"cache_stats": self.detection_cache.get_cache_stats(),
|
||||||
|
"last_cleanup": self._last_cleanup,
|
||||||
|
"cleanup_interval_seconds": DETECTION_CACHE_CLEANUP_INTERVAL
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Global session cache manager instance
|
||||||
|
session_cache_manager = SessionCacheManager()
|
||||||
|
|
||||||
|
|
||||||
|
# ===== CONVENIENCE FUNCTIONS =====
|
||||||
|
# These provide simplified access to session and cache functionality
|
||||||
|
|
||||||
|
def create_session(display_id: str, session_id: Optional[str] = None) -> str:
|
||||||
|
"""Create a new session using global manager."""
|
||||||
|
return session_cache_manager.session_manager.create_session(display_id, session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_state(session_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get session state by session ID."""
|
||||||
|
session = session_cache_manager.session_manager.get_session(session_id)
|
||||||
|
return session.to_dict() if session else None
|
||||||
|
|
||||||
|
|
||||||
|
def update_pipeline_mode(session_id: str,
|
||||||
|
new_mode: str,
|
||||||
|
backend_session_id: Optional[str] = None) -> bool:
|
||||||
|
"""Update session pipeline mode."""
|
||||||
|
try:
|
||||||
|
mode = PipelineMode(new_mode)
|
||||||
|
return session_cache_manager.session_manager.update_session_mode(session_id, mode, backend_session_id)
|
||||||
|
except ValueError:
|
||||||
|
logger.error(f"Invalid pipeline mode: {new_mode}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def cache_detection(camera_id: str,
|
||||||
|
detection_data: Dict[str, Any],
|
||||||
|
session_id: Optional[str] = None) -> str:
|
||||||
|
"""Cache detection data using global manager."""
|
||||||
|
return session_cache_manager.detection_cache.store_detection(camera_id, detection_data, session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cached_detection(camera_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get latest cached detection for a camera."""
|
||||||
|
return session_cache_manager.detection_cache.get_latest_detection(camera_id)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_expired_sessions_and_cache() -> Dict[str, int]:
|
||||||
|
"""Clean up expired sessions and cached data."""
|
||||||
|
return session_cache_manager.cleanup_expired_data()
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_and_cache_stats() -> Dict[str, Any]:
|
||||||
|
"""Get comprehensive session and cache statistics."""
|
||||||
|
return session_cache_manager.get_comprehensive_stats()
|
Loading…
Add table
Add a link
Reference in a new issue