Refactor: PHASE 3: Action & Storage Extraction
parent 4e9ae6bcc4
commit cdeaaf4a4f
5 changed files with 3048 additions and 0 deletions
detector_worker/pipeline/action_executor.py  (new file, 669 lines)
@@ -0,0 +1,669 @@
"""
|
||||
Action execution engine for pipeline workflows.
|
||||
|
||||
This module provides comprehensive action execution functionality including:
|
||||
- Redis-based actions (save image, publish messages)
|
||||
- Database actions (PostgreSQL updates with field mapping)
|
||||
- Parallel action coordination
|
||||
- Dynamic context resolution
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from ..core.constants import (
|
||||
REDIS_IMAGE_DEFAULT_QUALITY,
|
||||
REDIS_IMAGE_DEFAULT_FORMAT,
|
||||
DEFAULT_IMAGE_ENCODING_PARAMS
|
||||
)
|
||||
from ..core.exceptions import ActionExecutionError, create_pipeline_error
|
||||
from ..pipeline.field_mapper import resolve_field_mapping
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionContext:
|
||||
"""Context information for action execution."""
|
||||
timestamp_ms: int = field(default_factory=lambda: int(time.time() * 1000))
|
||||
uuid: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
|
||||
filename: str = field(default_factory=lambda: f"{uuid.uuid4()}.jpg")
|
||||
image_key: Optional[str] = None
|
||||
|
||||
# Detection context
|
||||
camera_id: str = "unknown"
|
||||
display_id: str = "unknown"
|
||||
session_id: Optional[str] = None
|
||||
backend_session_id: Optional[str] = None
|
||||
|
||||
# Additional context from detection result
|
||||
regions: Optional[Dict[str, Any]] = None
|
||||
detections: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for template formatting."""
|
||||
result = {
|
||||
"timestamp_ms": self.timestamp_ms,
|
||||
"uuid": self.uuid,
|
||||
"timestamp": self.timestamp,
|
||||
"filename": self.filename,
|
||||
"camera_id": self.camera_id,
|
||||
"display_id": self.display_id,
|
||||
}
|
||||
|
||||
if self.image_key:
|
||||
result["image_key"] = self.image_key
|
||||
if self.session_id:
|
||||
result["session_id"] = self.session_id
|
||||
if self.backend_session_id:
|
||||
result["backend_session_id"] = self.backend_session_id
|
||||
|
||||
return result
|
||||
|
||||
def update_from_dict(self, data: Dict[str, Any]) -> None:
|
||||
"""Update context from detection result dictionary."""
|
||||
if "camera_id" in data:
|
||||
self.camera_id = data["camera_id"]
|
||||
if "display_id" in data:
|
||||
self.display_id = data["display_id"]
|
||||
if "session_id" in data:
|
||||
self.session_id = data["session_id"]
|
||||
if "backend_session_id" in data:
|
||||
self.backend_session_id = data["backend_session_id"]
|
||||
if "regions" in data:
|
||||
self.regions = data["regions"]
|
||||
if "detections" in data:
|
||||
self.detections = data["detections"]
|
||||
|
||||
|
||||
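# Illustrative note: to_dict() is what gets splatted into str.format() for key and
# message templates, so a template may reference any of the keys built above.
# The values below are made-up examples, not taken from a real run:
#
#   context = ActionContext(camera_id="cam-01", display_id="display-001")
#   context.to_dict()
#   # -> {"timestamp_ms": 1699999999000, "uuid": "e0c2...", "timestamp": "2023-11-14T21-33-19",
#   #     "filename": "e0c2....jpg", "camera_id": "cam-01", "display_id": "display-001"}
#   "inference:{display_id}:{timestamp}".format(**context.to_dict())
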
@dataclass
class ActionResult:
    """Result from action execution."""
    success: bool
    action_type: str
    message: str = ""
    data: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        result = {
            "success": self.success,
            "action_type": self.action_type,
            "message": self.message
        }
        if self.data:
            result["data"] = self.data
        if self.error:
            result["error"] = self.error
        return result


class RedisActionExecutor:
    """Executor for Redis-based actions."""

    def __init__(self, redis_client):
        """
        Initialize Redis action executor.

        Args:
            redis_client: Redis client instance
        """
        self.redis_client = redis_client

    def _crop_region_by_class(self,
                              frame: np.ndarray,
                              regions_dict: Dict[str, Any],
                              class_name: str) -> Optional[np.ndarray]:
        """Crop a specific region from frame based on detected class."""
        if class_name not in regions_dict:
            logger.warning(f"Class '{class_name}' not found in detected regions")
            return None

        bbox = regions_dict[class_name]["bbox"]
        x1, y1, x2, y2 = bbox

        # Validate bbox coordinates
        if x2 <= x1 or y2 <= y1:
            logger.warning(f"Invalid bbox for class {class_name}: {bbox}")
            return None

        try:
            cropped = frame[y1:y2, x1:x2]
            if cropped.size == 0:
                logger.warning(f"Empty crop for class {class_name}")
                return None
            return cropped
        except Exception as e:
            logger.error(f"Error cropping region for class {class_name}: {e}")
            return None

    def _encode_image(self,
                      image: np.ndarray,
                      img_format: str = REDIS_IMAGE_DEFAULT_FORMAT,
                      quality: int = REDIS_IMAGE_DEFAULT_QUALITY) -> Tuple[bool, Optional[bytes]]:
        """Encode image with specified format and quality."""
        try:
            img_format = img_format.lower()

            if img_format == "jpeg" or img_format == "jpg":
                encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
                success, buffer = cv2.imencode('.jpg', image, encode_params)
            elif img_format == "png":
                compression = DEFAULT_IMAGE_ENCODING_PARAMS["png"]["compression"]
                encode_params = [cv2.IMWRITE_PNG_COMPRESSION, compression]
                success, buffer = cv2.imencode('.png', image, encode_params)
            elif img_format == "webp":
                encode_params = [cv2.IMWRITE_WEBP_QUALITY, quality]
                success, buffer = cv2.imencode('.webp', image, encode_params)
            else:
                # Default to JPEG
                encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
                success, buffer = cv2.imencode('.jpg', image, encode_params)

            if success:
                return True, buffer.tobytes()
            else:
                return False, None

        except Exception as e:
            logger.error(f"Error encoding image as {img_format}: {e}")
            return False, None

    def execute_redis_save_image(self,
                                 action: Dict[str, Any],
                                 frame: np.ndarray,
                                 context: ActionContext,
                                 regions_dict: Optional[Dict[str, Any]] = None) -> ActionResult:
        """Execute redis_save_image action."""
        try:
            # Format the key template
            key = action["key"].format(**context.to_dict())

            # Check if we need to crop a specific region
            region_name = action.get("region")
            image_to_save = frame

            if region_name and regions_dict:
                cropped_image = self._crop_region_by_class(frame, regions_dict, region_name)
                if cropped_image is not None:
                    image_to_save = cropped_image
                    logger.debug(f"Cropped region '{region_name}' for redis_save_image")
                else:
                    logger.warning(f"Could not crop region '{region_name}', saving full frame instead")

            # Encode image with specified format and quality
            img_format = action.get("format", REDIS_IMAGE_DEFAULT_FORMAT)
            quality = action.get("quality", REDIS_IMAGE_DEFAULT_QUALITY)

            success, image_bytes = self._encode_image(image_to_save, img_format, quality)

            if not success or not image_bytes:
                return ActionResult(
                    success=False,
                    action_type="redis_save_image",
                    error=f"Failed to encode image as {img_format}"
                )

            # Save to Redis
            expire_seconds = action.get("expire_seconds")
            if expire_seconds:
                self.redis_client.setex(key, expire_seconds, image_bytes)
                message = f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)"
            else:
                self.redis_client.set(key, image_bytes)
                message = f"Saved image to Redis with key: {key}"

            logger.info(message)

            # Update context with image key
            context.image_key = key

            return ActionResult(
                success=True,
                action_type="redis_save_image",
                message=message,
                data={"key": key, "format": img_format, "quality": quality}
            )

        except Exception as e:
            error_msg = f"Error executing redis_save_image: {e}"
            logger.error(error_msg)
            return ActionResult(
                success=False,
                action_type="redis_save_image",
                error=error_msg
            )

    def execute_redis_publish(self,
                              action: Dict[str, Any],
                              context: ActionContext) -> ActionResult:
        """Execute redis_publish action."""
        try:
            channel = action["channel"]
            message_template = action["message"]

            # Handle JSON message format by creating it programmatically
            if message_template.strip().startswith('{') and message_template.strip().endswith('}'):
                # Create JSON data programmatically to avoid formatting issues
                json_data = {}

                # Add common fields
                json_data["event"] = "frontal_detected"
                json_data["display_id"] = context.display_id
                json_data["session_id"] = context.session_id
                json_data["timestamp"] = context.timestamp
                json_data["image_key"] = context.image_key or ""

                # Convert to JSON string
                message = json.dumps(json_data)
            else:
                # Use regular string formatting for non-JSON messages
                message = message_template.format(**context.to_dict())

            # Test Redis connection
            if not self.redis_client:
                return ActionResult(
                    success=False,
                    action_type="redis_publish",
                    error="Redis client is None"
                )

            try:
                self.redis_client.ping()
                logger.debug("Redis connection is active")
            except Exception as ping_error:
                return ActionResult(
                    success=False,
                    action_type="redis_publish",
                    error=f"Redis connection test failed: {ping_error}"
                )

            # Publish to Redis
            result = self.redis_client.publish(channel, message)

            success_msg = f"Published message to Redis channel '{channel}': {message}"
            logger.info(success_msg)
            logger.info(f"Redis publish result (subscribers count): {result}")

            # Additional debug info
            if result == 0:
                logger.warning(f"No subscribers listening to channel '{channel}'")
            else:
                logger.info(f"Message delivered to {result} subscriber(s)")

            return ActionResult(
                success=True,
                action_type="redis_publish",
                message=success_msg,
                data={
                    "channel": channel,
                    "message": message,
                    "subscribers_count": result
                }
            )

        except KeyError as e:
            error_msg = f"Missing key in redis_publish message template: {e}"
            logger.error(error_msg)
            logger.debug(f"Available context keys: {list(context.to_dict().keys())}")
            return ActionResult(
                success=False,
                action_type="redis_publish",
                error=error_msg
            )
        except Exception as e:
            error_msg = f"Error in redis_publish action: {e}"
            logger.error(error_msg)
            logger.debug(f"Message template: {action['message']}")
            logger.debug(f"Available context keys: {list(context.to_dict().keys())}")
            return ActionResult(
                success=False,
                action_type="redis_publish",
                error=error_msg
            )

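# Minimal sketch of the action dictionaries this executor consumes. The field
# names mirror the lookups above (action["key"], action["channel"], "region",
# "format", "quality", "expire_seconds"); the concrete key layout and channel
# name are assumptions for illustration only:
#
#   save_action = {
#       "type": "redis_save_image",
#       "key": "inference:{display_id}:{timestamp}:{uuid}:cropped_frontal",
#       "region": "frontal",
#       "format": "jpeg",
#       "quality": 90,
#       "expire_seconds": 600,
#   }
#   publish_action = {
#       "type": "redis_publish",
#       "channel": "car_detections",
#       "message": '{"event": "frontal_detected"}',  # JSON templates are rebuilt programmatically above
#   }
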
class DatabaseActionExecutor:
    """Executor for database-based actions."""

    def __init__(self, db_manager):
        """
        Initialize database action executor.

        Args:
            db_manager: Database manager instance
        """
        self.db_manager = db_manager

    def execute_postgresql_update_combined(self,
                                           action: Dict[str, Any],
                                           detection_result: Dict[str, Any],
                                           branch_results: Dict[str, Any]) -> ActionResult:
        """Execute PostgreSQL update with combined branch results."""
        try:
            table = action["table"]
            key_field = action["key_field"]
            key_value_template = action["key_value"]
            fields = action["fields"]

            # Create context for key value formatting
            action_context = {**detection_result}
            key_value = key_value_template.format(**action_context)

            logger.info(f"Executing database update: table={table}, {key_field}={key_value}")
            logger.debug(f"Available branch results: {list(branch_results.keys())}")

            # Process field mappings
            mapped_fields = {}
            mapping_errors = []

            for db_field, value_template in fields.items():
                try:
                    mapped_value = resolve_field_mapping(value_template, branch_results, action_context)
                    if mapped_value is not None:
                        mapped_fields[db_field] = mapped_value
                        logger.info(f"Mapped field: {db_field} = {mapped_value}")
                    else:
                        error_msg = f"Could not resolve field mapping for {db_field}: {value_template}"
                        logger.warning(error_msg)
                        mapping_errors.append(error_msg)
                        logger.debug(f"Available branch results: {branch_results}")
                except Exception as e:
                    error_msg = f"Error mapping field {db_field} with template '{value_template}': {e}"
                    logger.error(error_msg)
                    mapping_errors.append(error_msg)

            if not mapped_fields:
                error_msg = "No fields mapped successfully, skipping database update"
                logger.warning(error_msg)
                logger.debug(f"Branch results available: {branch_results}")
                logger.debug(f"Field templates: {fields}")
                return ActionResult(
                    success=False,
                    action_type="postgresql_update_combined",
                    error=error_msg,
                    data={"mapping_errors": mapping_errors}
                )

            # Add updated_at field automatically
            mapped_fields["updated_at"] = "NOW()"

            # Execute the database update
            logger.info(f"Attempting database update with fields: {mapped_fields}")
            success = self.db_manager.execute_update(table, key_field, key_value, mapped_fields)

            if success:
                success_msg = f"Successfully updated database: {table} with {len(mapped_fields)} fields"
                logger.info(f"✅ {success_msg}")
                logger.info(f"Updated fields: {mapped_fields}")
                return ActionResult(
                    success=True,
                    action_type="postgresql_update_combined",
                    message=success_msg,
                    data={
                        "table": table,
                        "key_field": key_field,
                        "key_value": key_value,
                        "updated_fields": mapped_fields
                    }
                )
            else:
                error_msg = f"Failed to update database: {table}"
                logger.error(f"❌ {error_msg}")
                logger.error(f"Attempted update with: {key_field}={key_value}, fields={mapped_fields}")
                return ActionResult(
                    success=False,
                    action_type="postgresql_update_combined",
                    error=error_msg,
                    data={
                        "table": table,
                        "key_field": key_field,
                        "key_value": key_value,
                        "attempted_fields": mapped_fields
                    }
                )

        except KeyError as e:
            error_msg = f"Missing required field in postgresql_update_combined action: {e}"
            logger.error(error_msg)
            logger.debug(f"Action config: {action}")
            return ActionResult(
                success=False,
                action_type="postgresql_update_combined",
                error=error_msg
            )
        except Exception as e:
            error_msg = f"Error in postgresql_update_combined action: {e}"
            logger.error(error_msg)
            return ActionResult(
                success=False,
                action_type="postgresql_update_combined",
                error=error_msg
            )

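# Minimal sketch of a postgresql_update_combined action, assuming a pipeline with
# a car_brand_cls_v1 branch. Only the key names are dictated by the code above
# ("table", "key_field", "key_value", "fields", plus "waitForBranches" read by the
# parallel-action path below); the values are illustrative placeholders:
#
#   db_action = {
#       "type": "postgresql_update_combined",
#       "table": "car_frontal_info",
#       "key_field": "session_id",
#       "key_value": "{session_id}",
#       "waitForBranches": ["car_brand_cls_v1"],
#       "fields": {
#           "car_brand": "{car_brand_cls_v1.brand}",
#       },
#   }
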
class ActionExecutor:
    """
    Main action execution engine for pipeline workflows.

    This class coordinates the execution of various action types including
    Redis operations, database operations, and parallel action coordination.
    """

    def __init__(self):
        """Initialize action executor."""
        self._redis_executors: Dict[Any, RedisActionExecutor] = {}
        self._db_executors: Dict[Any, DatabaseActionExecutor] = {}

    def _get_redis_executor(self, redis_client) -> RedisActionExecutor:
        """Get or create Redis executor for a client."""
        if redis_client not in self._redis_executors:
            self._redis_executors[redis_client] = RedisActionExecutor(redis_client)
        return self._redis_executors[redis_client]

    def _get_db_executor(self, db_manager) -> DatabaseActionExecutor:
        """Get or create database executor for a manager."""
        if db_manager not in self._db_executors:
            self._db_executors[db_manager] = DatabaseActionExecutor(db_manager)
        return self._db_executors[db_manager]

    def _create_action_context(self, detection_result: Dict[str, Any]) -> ActionContext:
        """Create action context from detection result."""
        context = ActionContext()
        context.update_from_dict(detection_result)
        return context

    def execute_actions(self,
                        node: Dict[str, Any],
                        frame: np.ndarray,
                        detection_result: Dict[str, Any],
                        regions_dict: Optional[Dict[str, Any]] = None) -> List[ActionResult]:
        """
        Execute all actions for a node.

        Args:
            node: Pipeline node configuration
            frame: Input frame
            detection_result: Detection result data
            regions_dict: Optional region dictionary for cropping

        Returns:
            List of action results
        """
        if not node.get("redis_client") or not node.get("actions"):
            return []

        results = []
        redis_executor = self._get_redis_executor(node["redis_client"])
        context = self._create_action_context(detection_result)

        for action in node["actions"]:
            try:
                action_type = action.get("type")

                if action_type == "redis_save_image":
                    result = redis_executor.execute_redis_save_image(
                        action, frame, context, regions_dict
                    )
                elif action_type == "redis_publish":
                    result = redis_executor.execute_redis_publish(action, context)
                else:
                    result = ActionResult(
                        success=False,
                        action_type=action_type or "unknown",
                        error=f"Unknown action type: {action_type}"
                    )

                results.append(result)

                if not result.success:
                    logger.error(f"Action failed: {result.error}")

            except Exception as e:
                error_msg = f"Error executing action {action.get('type', 'unknown')}: {e}"
                logger.error(error_msg)
                results.append(ActionResult(
                    success=False,
                    action_type=action.get('type', 'unknown'),
                    error=error_msg
                ))

        return results

    def execute_parallel_actions(self,
                                 node: Dict[str, Any],
                                 frame: np.ndarray,
                                 detection_result: Dict[str, Any],
                                 regions_dict: Dict[str, Any]) -> List[ActionResult]:
        """
        Execute parallel actions after all required branches have completed.

        Args:
            node: Pipeline node configuration
            frame: Input frame
            detection_result: Detection result data with branch results
            regions_dict: Region dictionary

        Returns:
            List of action results
        """
        if not node.get("parallelActions"):
            return []

        results = []
        branch_results = detection_result.get("branch_results", {})

        logger.debug("Executing parallel actions...")

        for action in node["parallelActions"]:
            try:
                action_type = action.get("type")
                logger.debug(f"Processing parallel action: {action_type}")

                if action_type == "postgresql_update_combined":
                    # Check if all required branches have completed
                    wait_for_branches = action.get("waitForBranches", [])
                    missing_branches = [branch for branch in wait_for_branches if branch not in branch_results]

                    if missing_branches:
                        error_msg = f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}"
                        logger.warning(error_msg)
                        results.append(ActionResult(
                            success=False,
                            action_type=action_type,
                            error=error_msg,
                            data={"missing_branches": missing_branches}
                        ))
                        continue

                    logger.info(f"All required branches completed: {wait_for_branches}")

                    # Execute the database update
                    if node.get("db_manager"):
                        db_executor = self._get_db_executor(node["db_manager"])
                        result = db_executor.execute_postgresql_update_combined(
                            action, detection_result, branch_results
                        )
                        results.append(result)
                    else:
                        results.append(ActionResult(
                            success=False,
                            action_type=action_type,
                            error="No database manager available"
                        ))
                else:
                    error_msg = f"Unknown parallel action type: {action_type}"
                    logger.warning(error_msg)
                    results.append(ActionResult(
                        success=False,
                        action_type=action_type or "unknown",
                        error=error_msg
                    ))

            except Exception as e:
                error_msg = f"Error executing parallel action {action.get('type', 'unknown')}: {e}"
                logger.error(error_msg)
                results.append(ActionResult(
                    success=False,
                    action_type=action.get('type', 'unknown'),
                    error=error_msg
                ))

        return results


# Global action executor instance
action_executor = ActionExecutor()


# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in pympta.py

def execute_actions(node: Dict[str, Any],
                    frame: np.ndarray,
                    detection_result: Dict[str, Any],
                    regions_dict: Optional[Dict[str, Any]] = None) -> None:
    """Execute all actions for a node using global executor."""
    results = action_executor.execute_actions(node, frame, detection_result, regions_dict)

    # Log any failures (maintain original behavior of not returning results)
    for result in results:
        if not result.success:
            logger.error(f"Action execution failed: {result.error}")


def execute_parallel_actions(node: Dict[str, Any],
                             frame: np.ndarray,
                             detection_result: Dict[str, Any],
                             regions_dict: Dict[str, Any]) -> None:
    """Execute parallel actions using global executor."""
    results = action_executor.execute_parallel_actions(node, frame, detection_result, regions_dict)

    # Log any failures (maintain original behavior of not returning results)
    for result in results:
        if not result.success:
            logger.error(f"Parallel action execution failed: {result.error}")


def execute_postgresql_update_combined(node: Dict[str, Any],
                                       action: Dict[str, Any],
                                       detection_result: Dict[str, Any],
                                       branch_results: Dict[str, Any]) -> None:
    """Execute PostgreSQL update with combined branch results using global executor."""
    if node.get("db_manager"):
        db_executor = action_executor._get_db_executor(node["db_manager"])
        result = db_executor.execute_postgresql_update_combined(action, detection_result, branch_results)

        if not result.success:
            logger.error(f"PostgreSQL update failed: {result.error}")
    else:
        logger.error("No database manager available for postgresql_update_combined action")
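# Usage sketch (assumes a pipeline node dict of the shape read above; the Redis
# client construction and the frame source are illustrative, not part of this module):
#
#   import redis
#   node = {
#       "redis_client": redis.Redis(host="localhost", port=6379),
#       "actions": [save_action, publish_action],   # see the example actions above
#       "parallelActions": [db_action],
#       "db_manager": db_manager,                   # a DatabaseManager instance
#   }
#   execute_actions(node, frame, detection_result, regions_dict)
#   execute_parallel_actions(node, frame, detection_result, regions_dict)
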
detector_worker/pipeline/field_mapper.py  (new file, 341 lines)
@@ -0,0 +1,341 @@
"""
|
||||
Field mapping and template resolution for dynamic database operations.
|
||||
|
||||
This module provides functionality for resolving field mapping templates
|
||||
that reference branch results and context variables for database operations.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
from ..core.exceptions import FieldMappingError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FieldMapper:
|
||||
"""
|
||||
Field mapping resolver for dynamic template resolution.
|
||||
|
||||
This class handles the resolution of field mapping templates that can reference:
|
||||
- Branch results (e.g., {car_brand_cls_v1.brand})
|
||||
- Context variables (e.g., {session_id})
|
||||
- Nested field lookups with fallback strategies
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize field mapper."""
|
||||
pass
|
||||
|
||||
def _extract_branch_references(self, template: str) -> List[str]:
|
||||
"""Extract branch references from template string."""
|
||||
# Match patterns like {model_id.field_name}
|
||||
branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', template)
|
||||
return branch_refs
|
||||
|
||||
def _resolve_simple_template(self, template: str, context: Dict[str, Any]) -> str:
|
||||
"""Resolve simple template without branch references."""
|
||||
try:
|
||||
result = template.format(**context)
|
||||
logger.debug(f"Simple template resolved: '{template}' -> '{result}'")
|
||||
return result
|
||||
except KeyError as e:
|
||||
logger.warning(f"Could not resolve context variable in simple template: {e}")
|
||||
return template
|
||||
|
||||
def _find_fallback_value(self,
|
||||
branch_data: Dict[str, Any],
|
||||
field_name: str,
|
||||
model_id: str) -> Optional[str]:
|
||||
"""Find fallback value using various strategies."""
|
||||
if not isinstance(branch_data, dict):
|
||||
logger.error(f"Branch data for '{model_id}' is not a dictionary: {type(branch_data)}")
|
||||
return None
|
||||
|
||||
# First, try the exact field name
|
||||
if field_name in branch_data:
|
||||
return branch_data[field_name]
|
||||
|
||||
# Then try 'class' field as fallback
|
||||
if 'class' in branch_data:
|
||||
fallback_value = branch_data['class']
|
||||
logger.info(f"Using 'class' field as fallback for '{field_name}': '{fallback_value}'")
|
||||
return fallback_value
|
||||
|
||||
# For brand models, check if the class name exists as a key
|
||||
if field_name == 'brand' and branch_data.get('class') in branch_data:
|
||||
fallback_value = branch_data[branch_data['class']]
|
||||
logger.info(f"Found brand value using class name as key: '{fallback_value}'")
|
||||
return fallback_value
|
||||
|
||||
# For body_type models, check if the class name exists as a key
|
||||
if field_name == 'body_type' and branch_data.get('class') in branch_data:
|
||||
fallback_value = branch_data[branch_data['class']]
|
||||
logger.info(f"Found body_type value using class name as key: '{fallback_value}'")
|
||||
return fallback_value
|
||||
|
||||
# Additional fallback strategies for common field mappings
|
||||
field_mappings = {
|
||||
'brand': ['car_brand', 'brand_name', 'detected_brand'],
|
||||
'body_type': ['car_body_type', 'bodytype', 'body', 'car_type'],
|
||||
'model': ['car_model', 'model_name', 'detected_model'],
|
||||
'color': ['car_color', 'color_name', 'detected_color']
|
||||
}
|
||||
|
||||
if field_name in field_mappings:
|
||||
for alternative in field_mappings[field_name]:
|
||||
if alternative in branch_data:
|
||||
fallback_value = branch_data[alternative]
|
||||
logger.info(f"Found '{field_name}' value using alternative field '{alternative}': '{fallback_value}'")
|
||||
return fallback_value
|
||||
|
||||
return None
|
||||
|
||||
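    # Illustrative fallback walkthrough (hypothetical classifier outputs): with
    # branch_data = {"class": "Toyota", "confidence": 0.91} and field_name "brand",
    # the exact key is missing, so the 'class' value "Toyota" is returned; with
    # branch_data = {"car_brand": "Toyota"} the 'brand' -> 'car_brand' alias table
    # above supplies the value instead.
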
    def _resolve_branch_reference(self,
                                  ref: str,
                                  branch_results: Dict[str, Any]) -> Optional[str]:
        """Resolve a single branch reference."""
        try:
            model_id, field_name = ref.split('.', 1)
            logger.debug(f"Processing branch reference: model_id='{model_id}', field_name='{field_name}'")

            if model_id not in branch_results:
                logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}")
                return None

            branch_data = branch_results[model_id]
            logger.debug(f"Branch '{model_id}' data: {branch_data}")

            if field_name in branch_data:
                field_value = branch_data[field_name]
                logger.info(f"✅ Resolved {ref} to '{field_value}'")
                return str(field_value)
            else:
                logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results.")
                logger.debug(f"Available fields in '{model_id}': {list(branch_data.keys()) if isinstance(branch_data, dict) else 'N/A'}")

                # Try fallback strategies
                fallback_value = self._find_fallback_value(branch_data, field_name, model_id)

                if fallback_value is not None:
                    logger.info(f"✅ Resolved {ref} to '{fallback_value}' (using fallback)")
                    return str(fallback_value)
                else:
                    logger.error(f"No suitable field found for '{field_name}' in branch '{model_id}'")
                    logger.debug(f"Branch data structure: {branch_data}")
                    return None

        except ValueError as e:
            logger.error(f"Invalid branch reference format: {ref}")
            return None
        except Exception as e:
            logger.error(f"Error resolving branch reference '{ref}': {e}")
            return None

    def resolve_field_mapping(self,
                              value_template: str,
                              branch_results: Dict[str, Any],
                              action_context: Dict[str, Any]) -> Optional[str]:
        """
        Resolve field mapping templates like {car_brand_cls_v1.brand}.

        Args:
            value_template: Template string with placeholders
            branch_results: Dictionary of branch execution results
            action_context: Context variables for template resolution

        Returns:
            Resolved string value or None if resolution failed
        """
        try:
            logger.debug(f"Resolving field mapping: '{value_template}'")
            logger.debug(f"Available branch results: {list(branch_results.keys())}")

            # Handle simple context variables first (non-branch references)
            if '.' not in value_template:
                result = self._resolve_simple_template(value_template, action_context)
                return result

            # Handle branch result references like {model_id.field}
            branch_refs = self._extract_branch_references(value_template)
            logger.debug(f"Found branch references: {branch_refs}")

            resolved_template = value_template

            for ref in branch_refs:
                resolved_value = self._resolve_branch_reference(ref, branch_results)

                if resolved_value is not None:
                    resolved_template = resolved_template.replace(f'{{{ref}}}', resolved_value)
                else:
                    logger.error(f"Failed to resolve branch reference: {ref}")
                    return None

            # Format any remaining simple variables
            try:
                final_value = resolved_template.format(**action_context)
                logger.debug(f"Final resolved value: '{final_value}'")
                return final_value
            except KeyError as e:
                logger.warning(f"Could not resolve context variable in template: {e}")
                return resolved_template

        except Exception as e:
            logger.error(f"Error resolving field mapping '{value_template}': {e}")
            return None

    def resolve_multiple_fields(self,
                                field_templates: Dict[str, str],
                                branch_results: Dict[str, Any],
                                action_context: Dict[str, Any]) -> Dict[str, str]:
        """
        Resolve multiple field mappings at once.

        Args:
            field_templates: Dictionary mapping field names to templates
            branch_results: Dictionary of branch execution results
            action_context: Context variables for template resolution

        Returns:
            Dictionary mapping field names to resolved values
        """
        resolved_fields = {}

        for field_name, template in field_templates.items():
            try:
                resolved_value = self.resolve_field_mapping(template, branch_results, action_context)
                if resolved_value is not None:
                    resolved_fields[field_name] = resolved_value
                    logger.debug(f"Successfully resolved field '{field_name}': {resolved_value}")
                else:
                    logger.warning(f"Failed to resolve field '{field_name}' with template: {template}")
            except Exception as e:
                logger.error(f"Error resolving field '{field_name}' with template '{template}': {e}")

        return resolved_fields

    def validate_template(self, template: str) -> Dict[str, Any]:
        """
        Validate a field mapping template and return analysis.

        Args:
            template: Template string to validate

        Returns:
            Dictionary with validation results and analysis
        """
        analysis = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "branch_references": [],
            "context_references": [],
            "has_branch_refs": False,
            "has_context_refs": False
        }

        try:
            # Extract all placeholders
            all_refs = re.findall(r'\{([^}]+)\}', template)

            for ref in all_refs:
                if '.' in ref:
                    # This is a branch reference
                    analysis["branch_references"].append(ref)
                    analysis["has_branch_refs"] = True

                    # Validate format
                    parts = ref.split('.')
                    if len(parts) != 2:
                        analysis["errors"].append(f"Invalid branch reference format: {ref}")
                        analysis["valid"] = False
                    elif not parts[0] or not parts[1]:
                        analysis["errors"].append(f"Empty model_id or field_name in reference: {ref}")
                        analysis["valid"] = False
                else:
                    # This is a context reference
                    analysis["context_references"].append(ref)
                    analysis["has_context_refs"] = True

            # Check for common issues
            if analysis["has_branch_refs"] and not analysis["has_context_refs"]:
                analysis["warnings"].append("Template only uses branch references, consider adding context info")

            if not analysis["branch_references"] and not analysis["context_references"]:
                analysis["warnings"].append("Template has no placeholders - it's a static value")

        except Exception as e:
            analysis["valid"] = False
            analysis["errors"].append(f"Template analysis failed: {e}")

        return analysis


# Global field mapper instance
field_mapper = FieldMapper()


# ===== CONVENIENCE FUNCTIONS =====
# These provide the same interface as the original functions in pympta.py

def resolve_field_mapping(value_template: str,
                          branch_results: Dict[str, Any],
                          action_context: Dict[str, Any]) -> Optional[str]:
    """Resolve field mapping templates like {car_brand_cls_v1.brand}."""
    return field_mapper.resolve_field_mapping(value_template, branch_results, action_context)


def resolve_multiple_field_mappings(field_templates: Dict[str, str],
                                    branch_results: Dict[str, Any],
                                    action_context: Dict[str, Any]) -> Dict[str, str]:
    """Resolve multiple field mappings at once."""
    return field_mapper.resolve_multiple_fields(field_templates, branch_results, action_context)


def validate_field_mapping_template(template: str) -> Dict[str, Any]:
    """Validate a field mapping template and return analysis."""
    return field_mapper.validate_template(template)


def get_available_field_mappings(branch_results: Dict[str, Any]) -> Dict[str, List[str]]:
    """
    Get available field mappings from branch results.

    Args:
        branch_results: Dictionary of branch execution results

    Returns:
        Dictionary mapping model IDs to available field names
    """
    available_mappings = {}

    for model_id, branch_data in branch_results.items():
        if isinstance(branch_data, dict):
            available_mappings[model_id] = list(branch_data.keys())
        else:
            logger.warning(f"Branch '{model_id}' data is not a dictionary: {type(branch_data)}")
            available_mappings[model_id] = []

    return available_mappings


def create_field_mapping_examples(branch_results: Dict[str, Any]) -> List[str]:
    """
    Create example field mapping templates based on available branch results.

    Args:
        branch_results: Dictionary of branch execution results

    Returns:
        List of example template strings
    """
    examples = []

    for model_id, branch_data in branch_results.items():
        if isinstance(branch_data, dict):
            for field_name in branch_data.keys():
                example = f"{{{model_id}.{field_name}}}"
                examples.append(example)

    return examples
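# Usage sketch (the branch results shown are hypothetical classifier outputs, not
# taken from a real pipeline run):
#
#   branch_results = {"car_brand_cls_v1": {"brand": "Toyota"},
#                     "car_bodytype_cls_v1": {"body_type": "Sedan"}}
#   action_context = {"session_id": "abc-123"}
#   resolve_field_mapping("{car_brand_cls_v1.brand}", branch_results, action_context)
#   # -> "Toyota"
#   resolve_field_mapping("{session_id}", branch_results, action_context)
#   # -> "abc-123"
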
detector_worker/storage/database_manager.py  (new file, 617 lines)
@@ -0,0 +1,617 @@
"""
|
||||
Database management and operations.
|
||||
|
||||
This module provides comprehensive database functionality including:
|
||||
- PostgreSQL connection management
|
||||
- Table schema management
|
||||
- Dynamic query execution
|
||||
- Transaction handling
|
||||
"""
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from ..core.constants import (
|
||||
DB_CONNECTION_TIMEOUT,
|
||||
DB_OPERATION_TIMEOUT,
|
||||
DB_RETRY_ATTEMPTS
|
||||
)
|
||||
from ..core.exceptions import DatabaseError, create_detection_error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig:
|
||||
"""Database connection configuration."""
|
||||
host: str
|
||||
port: int
|
||||
database: str
|
||||
username: str
|
||||
password: str
|
||||
schema: str = "gas_station_1"
|
||||
connection_timeout: int = DB_CONNECTION_TIMEOUT
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"database": self.database,
|
||||
"username": self.username,
|
||||
"password": self.password,
|
||||
"schema": self.schema,
|
||||
"connection_timeout": self.connection_timeout
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DatabaseConfig':
|
||||
"""Create from dictionary data."""
|
||||
return cls(
|
||||
host=data["host"],
|
||||
port=data["port"],
|
||||
database=data["database"],
|
||||
username=data.get("username", data.get("user", "")),
|
||||
password=data["password"],
|
||||
schema=data.get("schema", "gas_station_1"),
|
||||
connection_timeout=data.get("connection_timeout", DB_CONNECTION_TIMEOUT)
|
||||
)
|
||||
|
||||
|
||||
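# Minimal sketch of the configuration dict accepted by from_dict() and DatabaseManager;
# the keys follow the accessors above, the values are placeholders:
#
#   db_config = {
#       "host": "localhost",
#       "port": 5432,
#       "database": "detector",
#       "username": "worker",        # "user" is also accepted
#       "password": "...",
#       "schema": "gas_station_1",   # optional, defaults to "gas_station_1"
#   }
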
@dataclass
class QueryResult:
    """Result from database query execution."""
    success: bool
    affected_rows: int = 0
    data: Optional[List[Dict[str, Any]]] = None
    error: Optional[str] = None
    query: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        result = {
            "success": self.success,
            "affected_rows": self.affected_rows
        }
        if self.data:
            result["data"] = self.data
        if self.error:
            result["error"] = self.error
        if self.query:
            result["query"] = self.query
        return result


class DatabaseManager:
    """
    Comprehensive database manager for PostgreSQL operations.

    This class provides connection management, schema operations,
    and dynamic query execution with transaction support.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize database manager.

        Args:
            config: Database configuration dictionary
        """
        if isinstance(config, dict):
            self.config = DatabaseConfig.from_dict(config)
        else:
            self.config = config

        self.connection: Optional[psycopg2.extensions.connection] = None
        self._lock = None

    def _ensure_thread_safety(self):
        """Initialize thread safety if not already done."""
        if self._lock is None:
            import threading
            self._lock = threading.RLock()

    def connect(self) -> bool:
        """
        Establish database connection.

        Returns:
            True if connection successful, False otherwise
        """
        self._ensure_thread_safety()

        with self._lock:
            try:
                if self.connection and not self.connection.closed:
                    # Connection already exists and is open
                    return True

                self.connection = psycopg2.connect(
                    host=self.config.host,
                    port=self.config.port,
                    database=self.config.database,
                    user=self.config.username,
                    password=self.config.password,
                    connect_timeout=self.config.connection_timeout
                )

                # Set connection properties
                self.connection.set_client_encoding('UTF8')

                logger.info(f"PostgreSQL connection established successfully to {self.config.host}:{self.config.port}")
                return True

            except Exception as e:
                logger.error(f"Failed to connect to PostgreSQL: {e}")
                self.connection = None
                return False

    def disconnect(self) -> None:
        """Close database connection."""
        self._ensure_thread_safety()

        with self._lock:
            if self.connection:
                try:
                    self.connection.close()
                    logger.info("PostgreSQL connection closed")
                except Exception as e:
                    logger.error(f"Error closing PostgreSQL connection: {e}")
                finally:
                    self.connection = None

    def is_connected(self) -> bool:
        """
        Check if database connection is active.

        Returns:
            True if connected, False otherwise
        """
        self._ensure_thread_safety()

        with self._lock:
            try:
                if self.connection and not self.connection.closed:
                    cur = self.connection.cursor()
                    cur.execute("SELECT 1")
                    cur.fetchone()
                    cur.close()
                    return True
            except Exception as e:
                logger.debug(f"Connection check failed: {e}")

            return False

    def _ensure_connected(self) -> bool:
        """Ensure database connection is active."""
        if not self.is_connected():
            return self.connect()
        return True

    def execute_query(self,
                      query: str,
                      params: Optional[Tuple] = None,
                      fetch_results: bool = False) -> QueryResult:
        """
        Execute a database query.

        Args:
            query: SQL query string
            params: Query parameters
            fetch_results: Whether to fetch and return results

        Returns:
            QueryResult with execution details
        """
        self._ensure_thread_safety()

        with self._lock:
            if not self._ensure_connected():
                return QueryResult(
                    success=False,
                    error="Failed to establish database connection",
                    query=query
                )

            cursor = None
            try:
                cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

                logger.debug(f"Executing query: {query}")
                if params:
                    logger.debug(f"Query parameters: {params}")

                cursor.execute(query, params)
                affected_rows = cursor.rowcount

                data = None
                if fetch_results:
                    data = [dict(row) for row in cursor.fetchall()]

                self.connection.commit()

                logger.debug(f"Query executed successfully, affected rows: {affected_rows}")

                return QueryResult(
                    success=True,
                    affected_rows=affected_rows,
                    data=data,
                    query=query
                )

            except Exception as e:
                error_msg = f"Database query execution failed: {e}"
                logger.error(error_msg)
                logger.debug(f"Failed query: {query}")
                if params:
                    logger.debug(f"Query parameters: {params}")

                if self.connection:
                    try:
                        self.connection.rollback()
                    except Exception as rollback_error:
                        logger.error(f"Rollback failed: {rollback_error}")

                return QueryResult(
                    success=False,
                    error=error_msg,
                    query=query
                )
            finally:
                if cursor:
                    cursor.close()

    def create_schema_if_not_exists(self, schema_name: Optional[str] = None) -> bool:
        """
        Create database schema if it doesn't exist.

        Args:
            schema_name: Schema name (uses config schema if not provided)

        Returns:
            True if successful, False otherwise
        """
        schema_name = schema_name or self.config.schema
        query = f"CREATE SCHEMA IF NOT EXISTS {schema_name}"
        result = self.execute_query(query)

        if result.success:
            logger.info(f"Schema '{schema_name}' created or verified successfully")

        return result.success

    def create_car_frontal_info_table(self) -> bool:
        """
        Create the car_frontal_info table in the configured schema if it doesn't exist.

        Returns:
            True if successful, False otherwise
        """
        schema_name = self.config.schema

        # Ensure schema exists
        if not self.create_schema_if_not_exists(schema_name):
            return False

        try:
            # Create table if it doesn't exist
            create_table_query = f"""
                CREATE TABLE IF NOT EXISTS {schema_name}.car_frontal_info (
                    display_id VARCHAR(255),
                    captured_timestamp VARCHAR(255),
                    session_id VARCHAR(255) PRIMARY KEY,
                    license_character VARCHAR(255) DEFAULT NULL,
                    license_type VARCHAR(255) DEFAULT 'No model available',
                    car_brand VARCHAR(255) DEFAULT NULL,
                    car_model VARCHAR(255) DEFAULT NULL,
                    car_body_type VARCHAR(255) DEFAULT NULL,
                    created_at TIMESTAMP DEFAULT NOW(),
                    updated_at TIMESTAMP DEFAULT NOW()
                )
            """

            result = self.execute_query(create_table_query)
            if not result.success:
                return False

            # Add columns if they don't exist (for existing tables)
            alter_queries = [
                f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_brand VARCHAR(255) DEFAULT NULL",
                f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_model VARCHAR(255) DEFAULT NULL",
                f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_body_type VARCHAR(255) DEFAULT NULL",
                f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT NOW()",
                f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT NOW()"
            ]

            for alter_query in alter_queries:
                try:
                    alter_result = self.execute_query(alter_query)
                    if alter_result.success:
                        logger.debug(f"Executed: {alter_query}")
                    else:
                        # Check if it's just a "column already exists" error
                        if "already exists" not in (alter_result.error or "").lower():
                            logger.warning(f"ALTER TABLE failed: {alter_result.error}")
                except Exception as e:
                    # Ignore errors if column already exists (for older PostgreSQL versions)
                    if "already exists" in str(e).lower():
                        logger.debug(f"Column already exists, skipping: {alter_query}")
                    else:
                        logger.warning(f"Error in ALTER TABLE: {e}")

            logger.info("Successfully created/verified car_frontal_info table with all required columns")
            return True

        except Exception as e:
            logger.error(f"Failed to create car_frontal_info table: {e}")
            return False

    def insert_initial_detection(self,
                                 display_id: str,
                                 captured_timestamp: str,
                                 session_id: Optional[str] = None) -> Optional[str]:
        """
        Insert initial detection record and return the session_id.

        Args:
            display_id: Display identifier
            captured_timestamp: Timestamp of capture
            session_id: Optional session ID (generated if not provided)

        Returns:
            Session ID if successful, None otherwise
        """
        # Generate session_id if not provided
        if not session_id:
            session_id = str(uuid.uuid4())

        try:
            # Ensure table exists
            if not self.create_car_frontal_info_table():
                logger.error("Failed to create/verify table before insertion")
                return None

            schema_name = self.config.schema
            insert_query = f"""
                INSERT INTO {schema_name}.car_frontal_info
                (display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type, created_at)
                VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL, NOW())
                ON CONFLICT (session_id) DO NOTHING
            """

            result = self.execute_query(insert_query, (display_id, captured_timestamp, session_id))

            if result.success:
                logger.info(f"Inserted initial detection record with session_id: {session_id}")
                return session_id
            else:
                logger.error(f"Failed to insert initial detection record: {result.error}")
                return None

        except Exception as e:
            logger.error(f"Failed to insert initial detection record: {e}")
            return None

    def execute_update(self,
                       table: str,
                       key_field: str,
                       key_value: str,
                       fields: Dict[str, Any]) -> bool:
        """
        Execute dynamic update/insert operation.

        Args:
            table: Table name
            key_field: Primary key field name
            key_value: Primary key value
            fields: Dictionary of fields to update

        Returns:
            True if successful, False otherwise
        """
        try:
            # Add schema prefix if table doesn't already have it
            if '.' not in table:
                table = f"{self.config.schema}.{table}"

            # Build the INSERT and UPDATE query dynamically
            insert_placeholders = []
            insert_values = [key_value]  # Start with key_value

            set_clauses = []
            update_values = []

            for field, value in fields.items():
                if value == "NOW()":
                    # Special handling for NOW()
                    insert_placeholders.append("NOW()")
                    set_clauses.append(f"{field} = NOW()")
                else:
                    insert_placeholders.append("%s")
                    insert_values.append(value)
                    set_clauses.append(f"{field} = %s")
                    update_values.append(value)

            # Build the complete query
            query = f"""
                INSERT INTO {table} ({key_field}, {', '.join(fields.keys())})
                VALUES (%s, {', '.join(insert_placeholders)})
                ON CONFLICT ({key_field})
                DO UPDATE SET {', '.join(set_clauses)}
            """

            # Combine values for the query: insert_values + update_values
            all_values = tuple(insert_values + update_values)

            result = self.execute_query(query, all_values)

            if result.success:
                logger.info(f"✅ Updated {table} for {key_field}={key_value} with {len(fields)} fields")
                logger.debug(f"Updated fields: {fields}")
                return True
            else:
                logger.error(f"❌ Failed to update {table}: {result.error}")
                return False

        except Exception as e:
            logger.error(f"❌ Failed to execute update on {table}: {e}")
            return False

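    # Illustrative expansion: execute_update("car_frontal_info", "session_id", "abc-123",
    # {"car_brand": "Toyota", "updated_at": "NOW()"}) builds roughly
    #   INSERT INTO gas_station_1.car_frontal_info (session_id, car_brand, updated_at)
    #   VALUES (%s, %s, NOW())
    #   ON CONFLICT (session_id) DO UPDATE SET car_brand = %s, updated_at = NOW()
    # with parameters ("abc-123", "Toyota", "Toyota"); the sample values are made up.
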
    def update_car_info(self,
                        session_id: str,
                        brand: str,
                        model: str,
                        body_type: str) -> bool:
        """
        Update car information for a session.

        Args:
            session_id: Session identifier
            brand: Car brand
            model: Car model
            body_type: Car body type

        Returns:
            True if successful, False otherwise
        """
        schema_name = self.config.schema
        query = f"""
            INSERT INTO {schema_name}.car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
            VALUES (%s, %s, %s, %s, NOW())
            ON CONFLICT (session_id)
            DO UPDATE SET
                car_brand = EXCLUDED.car_brand,
                car_model = EXCLUDED.car_model,
                car_body_type = EXCLUDED.car_body_type,
                updated_at = NOW()
        """

        result = self.execute_query(query, (session_id, brand, model, body_type))

        if result.success:
            logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")

        return result.success

    def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get session data by session ID.

        Args:
            session_id: Session identifier

        Returns:
            Session data dictionary or None if not found
        """
        schema_name = self.config.schema
        query = f"SELECT * FROM {schema_name}.car_frontal_info WHERE session_id = %s"

        result = self.execute_query(query, (session_id,), fetch_results=True)

        if result.success and result.data:
            return result.data[0]

        return None

    def get_connection_stats(self) -> Dict[str, Any]:
        """
        Get database connection statistics.

        Returns:
            Dictionary with connection statistics
        """
        stats = {
            "connected": self.is_connected(),
            "config": self.config.to_dict(),
            "connection_closed": self.connection.closed if self.connection else True
        }

        if self.is_connected():
            try:
                # Get database version and basic stats
                version_result = self.execute_query("SELECT version()", fetch_results=True)
                if version_result.success and version_result.data:
                    stats["database_version"] = version_result.data[0]["version"]

                # Get current database name
                db_result = self.execute_query("SELECT current_database()", fetch_results=True)
                if db_result.success and db_result.data:
                    stats["current_database"] = db_result.data[0]["current_database"]

            except Exception as e:
                stats["stats_error"] = str(e)

        return stats


# ===== CONVENIENCE FUNCTIONS =====
# These provide compatibility with the original database.py interface

def validate_postgresql_config(config: Dict[str, Any]) -> bool:
    """
    Validate PostgreSQL configuration.

    Args:
        config: Configuration dictionary

    Returns:
        True if configuration is valid
    """
    required_fields = ["host", "port", "database", "password"]

    for field in required_fields:
        if field not in config:
            logger.error(f"Missing required PostgreSQL config field: {field}")
            return False

        if not config[field]:
            logger.error(f"Empty PostgreSQL config field: {field}")
            return False

    # Check for username (could be 'username' or 'user')
    if not config.get("username") and not config.get("user"):
        logger.error("Missing PostgreSQL username (provide either 'username' or 'user')")
        return False

    # Validate port is numeric
    try:
        port = int(config["port"])
        if port <= 0 or port > 65535:
            logger.error(f"Invalid PostgreSQL port: {port}")
            return False
    except (ValueError, TypeError):
        logger.error(f"PostgreSQL port must be numeric: {config['port']}")
        return False

    return True


def create_database_manager(config: Dict[str, Any]) -> Optional[DatabaseManager]:
    """
    Create database manager with configuration validation.

    Args:
        config: Database configuration

    Returns:
        DatabaseManager instance or None if invalid config
    """
    if not validate_postgresql_config(config):
        return None

    try:
        db_manager = DatabaseManager(config)
        if db_manager.connect():
            logger.info(f"Successfully created database manager for {config['host']}:{config['port']}")
            return db_manager
        else:
            logger.error("Failed to establish initial database connection")
            return None
    except Exception as e:
        logger.error(f"Error creating database manager: {e}")
        return None
733
detector_worker/storage/redis_client.py
Normal file
733
detector_worker/storage/redis_client.py
Normal file
|
@ -0,0 +1,733 @@
|
|||
"""
|
||||
Redis client management and operations.
|
||||
|
||||
This module provides comprehensive Redis functionality including:
|
||||
- Connection management with retries
|
||||
- Key-value operations
|
||||
- Pub/Sub messaging
|
||||
- Image storage with compression
|
||||
- Pipeline operations
|
||||
"""
|
||||
|
||||
import redis
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Union, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from ..core.constants import (
|
||||
REDIS_CONNECTION_TIMEOUT,
|
||||
REDIS_SOCKET_TIMEOUT,
|
||||
REDIS_IMAGE_DEFAULT_QUALITY,
|
||||
REDIS_IMAGE_DEFAULT_FORMAT
|
||||
)
|
||||
from ..core.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RedisConfig:
|
||||
"""Redis connection configuration."""
|
||||
host: str
|
||||
port: int
|
||||
password: Optional[str] = None
|
||||
db: int = 0
|
||||
connection_timeout: int = REDIS_CONNECTION_TIMEOUT
|
||||
socket_timeout: int = REDIS_SOCKET_TIMEOUT
|
||||
retry_on_timeout: bool = True
|
||||
health_check_interval: int = 30
|
||||
max_connections: int = 10
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"password": self.password,
|
||||
"db": self.db,
|
||||
"connection_timeout": self.connection_timeout,
|
||||
"socket_timeout": self.socket_timeout,
|
||||
"retry_on_timeout": self.retry_on_timeout,
|
||||
"health_check_interval": self.health_check_interval,
|
||||
"max_connections": self.max_connections
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig':
|
||||
"""Create from dictionary data."""
|
||||
return cls(
|
||||
host=data["host"],
|
||||
port=data["port"],
|
||||
password=data.get("password"),
|
||||
db=data.get("db", 0),
|
||||
connection_timeout=data.get("connection_timeout", REDIS_CONNECTION_TIMEOUT),
|
||||
socket_timeout=data.get("socket_timeout", REDIS_SOCKET_TIMEOUT),
|
||||
retry_on_timeout=data.get("retry_on_timeout", True),
|
||||
health_check_interval=data.get("health_check_interval", 30),
|
||||
max_connections=data.get("max_connections", 10)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RedisOperationResult:
|
||||
"""Result from Redis operation."""
|
||||
success: bool
|
||||
data: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
operation: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
execution_time: Optional[float] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
result = {
|
||||
"success": self.success,
|
||||
"operation": self.operation
|
||||
}
|
||||
if self.data is not None:
|
||||
result["data"] = self.data
|
||||
if self.error:
|
||||
result["error"] = self.error
|
||||
if self.key:
|
||||
result["key"] = self.key
|
||||
if self.execution_time:
|
||||
result["execution_time"] = self.execution_time
|
||||
return result
|
||||
|
||||
|
||||
class RedisClientManager:
|
||||
"""
|
||||
Comprehensive Redis client manager with connection pooling and retry logic.
|
||||
|
||||
This class provides high-level Redis operations with automatic reconnection,
|
||||
connection pooling, and comprehensive error handling.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Union[Dict[str, Any], RedisConfig]):
|
||||
"""
|
||||
Initialize Redis client manager.
|
||||
|
||||
Args:
|
||||
config: Redis configuration dictionary or RedisConfig object
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
self.config = RedisConfig.from_dict(config)
|
||||
else:
|
||||
self.config = config
|
||||
|
||||
self.client: Optional[redis.Redis] = None
|
||||
self.connection_pool: Optional[redis.ConnectionPool] = None
|
||||
self._lock = None
|
||||
self._last_health_check = 0.0
|
||||
|
||||
def _ensure_thread_safety(self):
|
||||
"""Initialize thread safety if not already done."""
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _create_connection_pool(self) -> redis.ConnectionPool:
|
||||
"""Create Redis connection pool."""
|
||||
pool_kwargs = {
|
||||
"host": self.config.host,
|
||||
"port": self.config.port,
|
||||
"db": self.config.db,
|
||||
"socket_timeout": self.config.socket_timeout,
|
||||
"socket_connect_timeout": self.config.connection_timeout,
|
||||
"retry_on_timeout": self.config.retry_on_timeout,
|
||||
"health_check_interval": self.config.health_check_interval,
|
||||
"max_connections": self.config.max_connections
|
||||
}
|
||||
|
||||
if self.config.password:
|
||||
pool_kwargs["password"] = self.config.password
|
||||
|
||||
return redis.ConnectionPool(**pool_kwargs)
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
Establish Redis connection.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
try:
|
||||
# Create connection pool
|
||||
self.connection_pool = self._create_connection_pool()
|
||||
|
||||
# Create Redis client
|
||||
self.client = redis.Redis(connection_pool=self.connection_pool)
|
||||
|
||||
# Test connection
|
||||
self.client.ping()
|
||||
self._last_health_check = time.time()
|
||||
|
||||
logger.info(f"Redis connection established successfully to {self.config.host}:{self.config.port}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}")
|
||||
self.client = None
|
||||
self.connection_pool = None
|
||||
return False
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Close Redis connection and cleanup resources."""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
if self.connection_pool:
|
||||
try:
|
||||
self.connection_pool.disconnect()
|
||||
logger.info("Redis connection pool disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting Redis pool: {e}")
|
||||
finally:
|
||||
self.connection_pool = None
|
||||
|
||||
self.client = None
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""
|
||||
Check if Redis connection is active.
|
||||
|
||||
Returns:
|
||||
True if connected, False otherwise
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
try:
|
||||
if self.client:
|
||||
# Perform periodic health check
|
||||
current_time = time.time()
|
||||
if current_time - self._last_health_check > self.config.health_check_interval:
|
||||
self.client.ping()
|
||||
self._last_health_check = current_time
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Redis connection check failed: {e}")
|
||||
|
||||
return False
|
||||
|
||||
def _ensure_connected(self) -> bool:
|
||||
"""Ensure Redis connection is active."""
|
||||
if not self.is_connected():
|
||||
return self.connect()
|
||||
return True
|
||||
|
||||
def _execute_operation(self,
|
||||
operation_name: str,
|
||||
operation_func,
|
||||
key: Optional[str] = None) -> RedisOperationResult:
|
||||
"""Execute Redis operation with error handling and timing."""
|
||||
if not self._ensure_connected():
|
||||
return RedisOperationResult(
|
||||
success=False,
|
||||
error="Failed to establish Redis connection",
|
||||
operation=operation_name,
|
||||
key=key
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = operation_func()
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return RedisOperationResult(
|
||||
success=True,
|
||||
data=result,
|
||||
operation=operation_name,
|
||||
key=key,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
error_msg = f"Redis {operation_name} operation failed: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
return RedisOperationResult(
|
||||
success=False,
|
||||
error=error_msg,
|
||||
operation=operation_name,
|
||||
key=key,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
# ===== KEY-VALUE OPERATIONS =====
|
||||
|
||||
def set(self, key: str, value: Any, ex: Optional[int] = None) -> RedisOperationResult:
|
||||
"""
|
||||
Set key-value pair with optional expiration.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
value: Value to store
|
||||
ex: Expiration time in seconds
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with operation status
|
||||
"""
|
||||
def operation():
|
||||
return self.client.set(key, value, ex=ex)
|
||||
|
||||
result = self._execute_operation("SET", operation, key)
|
||||
|
||||
if result.success:
|
||||
expire_msg = f" (expires in {ex}s)" if ex else ""
|
||||
logger.debug(f"Set Redis key '{key}'{expire_msg}")
|
||||
|
||||
return result
|
||||
|
||||
def setex(self, key: str, time: int, value: Any) -> RedisOperationResult:
|
||||
"""
|
||||
Set key-value pair with expiration time.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
time: Expiration time in seconds
|
||||
value: Value to store
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with operation status
|
||||
"""
|
||||
def operation():
|
||||
return self.client.setex(key, time, value)
|
||||
|
||||
result = self._execute_operation("SETEX", operation, key)
|
||||
|
||||
if result.success:
|
||||
logger.debug(f"Set Redis key '{key}' with {time}s expiration")
|
||||
|
||||
return result
|
||||
|
||||
def get(self, key: str) -> RedisOperationResult:
|
||||
"""
|
||||
Get value by key.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with value data
|
||||
"""
|
||||
def operation():
|
||||
return self.client.get(key)
|
||||
|
||||
return self._execute_operation("GET", operation, key)
|
||||
|
||||
def delete(self, *keys: str) -> RedisOperationResult:
|
||||
"""
|
||||
Delete one or more keys.
|
||||
|
||||
Args:
|
||||
*keys: Redis keys to delete
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with number of deleted keys
|
||||
"""
|
||||
def operation():
|
||||
return self.client.delete(*keys)
|
||||
|
||||
result = self._execute_operation("DELETE", operation)
|
||||
|
||||
if result.success:
|
||||
logger.debug(f"Deleted {result.data} Redis key(s): {keys}")
|
||||
|
||||
return result
|
||||
|
||||
def exists(self, *keys: str) -> RedisOperationResult:
|
||||
"""
|
||||
Check if keys exist.
|
||||
|
||||
Args:
|
||||
*keys: Redis keys to check
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with count of existing keys
|
||||
"""
|
||||
def operation():
|
||||
return self.client.exists(*keys)
|
||||
|
||||
return self._execute_operation("EXISTS", operation)
|
||||
|
||||
def expire(self, key: str, time: int) -> RedisOperationResult:
|
||||
"""
|
||||
Set expiration time for a key.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
time: Expiration time in seconds
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with operation status
|
||||
"""
|
||||
def operation():
|
||||
return self.client.expire(key, time)
|
||||
|
||||
return self._execute_operation("EXPIRE", operation, key)
|
||||
|
||||
def ttl(self, key: str) -> RedisOperationResult:
|
||||
"""
|
||||
Get time to live for a key.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with TTL in seconds
|
||||
"""
|
||||
def operation():
|
||||
return self.client.ttl(key)
|
||||
|
||||
return self._execute_operation("TTL", operation, key)
|
||||
|
||||
# ===== PUB/SUB OPERATIONS =====
|
||||
|
||||
def publish(self, channel: str, message: Union[str, Dict[str, Any]]) -> RedisOperationResult:
|
||||
"""
|
||||
Publish message to channel.
|
||||
|
||||
Args:
|
||||
channel: Channel name
|
||||
message: Message to publish (string or dict)
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with number of subscribers
|
||||
"""
|
||||
def operation():
|
||||
# Convert dict to JSON string
|
||||
if isinstance(message, dict):
|
||||
message_str = json.dumps(message)
|
||||
else:
|
||||
message_str = str(message)
|
||||
|
||||
return self.client.publish(channel, message_str)
|
||||
|
||||
result = self._execute_operation("PUBLISH", operation)
|
||||
|
||||
if result.success:
|
||||
logger.debug(f"Published to channel '{channel}', subscribers: {result.data}")
|
||||
if result.data == 0:
|
||||
logger.warning(f"No subscribers listening to channel '{channel}'")
|
||||
|
||||
return result
|
||||
|
||||
def subscribe(self, *channels: str):
|
||||
"""
|
||||
Subscribe to channels (returns pubsub object).
|
||||
|
||||
Args:
|
||||
*channels: Channel names to subscribe to
|
||||
|
||||
Returns:
|
||||
PubSub object for listening to messages
|
||||
"""
|
||||
if not self._ensure_connected():
|
||||
raise RedisError("Failed to establish Redis connection for subscription")
|
||||
|
||||
pubsub = self.client.pubsub()
|
||||
pubsub.subscribe(*channels)
|
||||
|
||||
logger.info(f"Subscribed to channels: {channels}")
|
||||
return pubsub
|
||||
|
||||
# ===== HASH OPERATIONS =====
|
||||
|
||||
def hset(self, key: str, field: str, value: Any) -> RedisOperationResult:
|
||||
"""
|
||||
Set hash field.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
field: Hash field name
|
||||
value: Field value
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with operation status
|
||||
"""
|
||||
def operation():
|
||||
return self.client.hset(key, field, value)
|
||||
|
||||
return self._execute_operation("HSET", operation, key)
|
||||
|
||||
def hget(self, key: str, field: str) -> RedisOperationResult:
|
||||
"""
|
||||
Get hash field value.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
field: Hash field name
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with field value
|
||||
"""
|
||||
def operation():
|
||||
return self.client.hget(key, field)
|
||||
|
||||
return self._execute_operation("HGET", operation, key)
|
||||
|
||||
def hgetall(self, key: str) -> RedisOperationResult:
|
||||
"""
|
||||
Get all hash fields and values.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with hash dictionary
|
||||
"""
|
||||
def operation():
|
||||
return self.client.hgetall(key)
|
||||
|
||||
return self._execute_operation("HGETALL", operation, key)
|
||||
|
||||
# ===== LIST OPERATIONS =====
|
||||
|
||||
def lpush(self, key: str, *values: Any) -> RedisOperationResult:
|
||||
"""
|
||||
Push values to left of list.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
*values: Values to push
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with list length
|
||||
"""
|
||||
def operation():
|
||||
return self.client.lpush(key, *values)
|
||||
|
||||
return self._execute_operation("LPUSH", operation, key)
|
||||
|
||||
def rpush(self, key: str, *values: Any) -> RedisOperationResult:
|
||||
"""
|
||||
Push values to right of list.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
*values: Values to push
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with list length
|
||||
"""
|
||||
def operation():
|
||||
return self.client.rpush(key, *values)
|
||||
|
||||
return self._execute_operation("RPUSH", operation, key)
|
||||
|
||||
def lrange(self, key: str, start: int, end: int) -> RedisOperationResult:
|
||||
"""
|
||||
Get range of list elements.
|
||||
|
||||
Args:
|
||||
key: Redis key
|
||||
start: Start index
|
||||
end: End index
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with list elements
|
||||
"""
|
||||
def operation():
|
||||
return self.client.lrange(key, start, end)
|
||||
|
||||
return self._execute_operation("LRANGE", operation, key)
|
||||
|
||||
# ===== UTILITY OPERATIONS =====
|
||||
|
||||
def ping(self) -> RedisOperationResult:
|
||||
"""
|
||||
Ping Redis server.
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with ping response
|
||||
"""
|
||||
def operation():
|
||||
return self.client.ping()
|
||||
|
||||
return self._execute_operation("PING", operation)
|
||||
|
||||
def info(self, section: Optional[str] = None) -> RedisOperationResult:
|
||||
"""
|
||||
Get Redis server information.
|
||||
|
||||
Args:
|
||||
section: Optional info section
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with server info
|
||||
"""
|
||||
def operation():
|
||||
return self.client.info(section)
|
||||
|
||||
return self._execute_operation("INFO", operation)
|
||||
|
||||
def flushdb(self) -> RedisOperationResult:
|
||||
"""
|
||||
Flush current database.
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with operation status
|
||||
"""
|
||||
def operation():
|
||||
return self.client.flushdb()
|
||||
|
||||
result = self._execute_operation("FLUSHDB", operation)
|
||||
|
||||
if result.success:
|
||||
logger.warning(f"Flushed Redis database {self.config.db}")
|
||||
|
||||
return result
|
||||
|
||||
def keys(self, pattern: str = "*") -> RedisOperationResult:
|
||||
"""
|
||||
Get keys matching pattern.
|
||||
|
||||
Args:
|
||||
pattern: Key pattern (default: all keys)
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with matching keys
|
||||
"""
|
||||
def operation():
|
||||
return self.client.keys(pattern)
|
||||
|
||||
return self._execute_operation("KEYS", operation)
|
||||
|
||||
# ===== BATCH OPERATIONS =====
|
||||
|
||||
def pipeline(self):
|
||||
"""
|
||||
Create Redis pipeline for batch operations.
|
||||
|
||||
Returns:
|
||||
Redis pipeline object
|
||||
"""
|
||||
if not self._ensure_connected():
|
||||
raise RedisError("Failed to establish Redis connection for pipeline")
|
||||
|
||||
return self.client.pipeline()
|
||||
|
||||
def execute_pipeline(self, pipeline) -> RedisOperationResult:
|
||||
"""
|
||||
Execute Redis pipeline.
|
||||
|
||||
Args:
|
||||
pipeline: Redis pipeline object
|
||||
|
||||
Returns:
|
||||
RedisOperationResult with pipeline results
|
||||
"""
|
||||
def operation():
|
||||
return pipeline.execute()
|
||||
|
||||
return self._execute_operation("PIPELINE", operation)
|
||||
|
||||
# ===== CONNECTION MANAGEMENT =====
|
||||
|
||||
def get_connection_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get Redis connection statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with connection statistics
|
||||
"""
|
||||
stats = {
|
||||
"connected": self.is_connected(),
|
||||
"config": self.config.to_dict(),
|
||||
"last_health_check": self._last_health_check,
|
||||
"connection_pool_created": self.connection_pool is not None
|
||||
}
|
||||
|
||||
if self.connection_pool:
|
||||
stats["connection_pool_stats"] = {
|
||||
"created_connections": self.connection_pool.created_connections,
|
||||
"available_connections": len(self.connection_pool._available_connections),
|
||||
"in_use_connections": len(self.connection_pool._in_use_connections)
|
||||
}
|
||||
|
||||
# Get Redis server info if connected
|
||||
if self.is_connected():
|
||||
try:
|
||||
info_result = self.info()
|
||||
if info_result.success:
|
||||
redis_info = info_result.data
|
||||
stats["server_info"] = {
|
||||
"redis_version": redis_info.get("redis_version"),
|
||||
"connected_clients": redis_info.get("connected_clients"),
|
||||
"used_memory": redis_info.get("used_memory"),
|
||||
"used_memory_human": redis_info.get("used_memory_human"),
|
||||
"total_commands_processed": redis_info.get("total_commands_processed"),
|
||||
"uptime_in_seconds": redis_info.get("uptime_in_seconds")
|
||||
}
|
||||
except Exception as e:
|
||||
stats["server_info_error"] = str(e)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# ===== CONVENIENCE FUNCTIONS =====
|
||||
# These provide compatibility and simplified access
|
||||
|
||||
def create_redis_client(config: Dict[str, Any]) -> Optional[RedisClientManager]:
|
||||
"""
|
||||
Create Redis client with configuration validation.
|
||||
|
||||
Args:
|
||||
config: Redis configuration
|
||||
|
||||
Returns:
|
||||
RedisClientManager instance or None if connection failed
|
||||
"""
|
||||
try:
|
||||
client_manager = RedisClientManager(config)
|
||||
if client_manager.connect():
|
||||
logger.info(f"Successfully created Redis client for {config['host']}:{config['port']}")
|
||||
return client_manager
|
||||
else:
|
||||
logger.error("Failed to establish initial Redis connection")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Redis client: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def validate_redis_config(config: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate Redis configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
True if configuration is valid
|
||||
"""
|
||||
required_fields = ["host", "port"]
|
||||
|
||||
for field in required_fields:
|
||||
if field not in config:
|
||||
logger.error(f"Missing required Redis config field: {field}")
|
||||
return False
|
||||
|
||||
if not config[field]:
|
||||
logger.error(f"Empty Redis config field: {field}")
|
||||
return False
|
||||
|
||||
# Validate port is numeric
|
||||
try:
|
||||
port = int(config["port"])
|
||||
if port <= 0 or port > 65535:
|
||||
logger.error(f"Invalid Redis port: {port}")
|
||||
return False
|
||||
except (ValueError, TypeError):
|
||||
logger.error(f"Redis port must be numeric: {config['port']}")
|
||||
return False
|
||||
|
||||
return True
|
688
detector_worker/storage/session_cache.py
Normal file
688
detector_worker/storage/session_cache.py
Normal file
|
@ -0,0 +1,688 @@
|
|||
"""
|
||||
Session and cache management for detection workflows.
|
||||
|
||||
This module provides comprehensive session and cache management including:
|
||||
- Session state tracking and lifecycle management
|
||||
- Detection result caching with TTL
|
||||
- Pipeline mode state management
|
||||
- Session cleanup and garbage collection
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
|
||||
from ..core.constants import (
|
||||
SESSION_CACHE_TTL_MINUTES,
|
||||
DETECTION_CACHE_CLEANUP_INTERVAL,
|
||||
SESSION_TIMEOUT_SECONDS
|
||||
)
|
||||
from ..core.exceptions import SessionError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineMode(Enum):
|
||||
"""Pipeline execution modes."""
|
||||
VALIDATION_DETECTING = "validation_detecting"
|
||||
SEND_DETECTIONS = "send_detections"
|
||||
WAITING_FOR_SESSION_ID = "waiting_for_session_id"
|
||||
FULL_PIPELINE = "full_pipeline"
|
||||
LIGHTWEIGHT = "lightweight"
|
||||
CAR_GONE_WAITING = "car_gone_waiting"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
"""Session state information."""
|
||||
session_id: Optional[str] = None
|
||||
backend_session_id: Optional[str] = None
|
||||
mode: PipelineMode = PipelineMode.VALIDATION_DETECTING
|
||||
session_id_received: bool = False
|
||||
created_at: float = field(default_factory=time.time)
|
||||
last_updated: float = field(default_factory=time.time)
|
||||
last_detection: Optional[float] = None
|
||||
detection_count: int = 0
|
||||
|
||||
# Mode-specific state
|
||||
validation_frames_processed: int = 0
|
||||
stable_track_achieved: bool = False
|
||||
waiting_start_time: Optional[float] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"backend_session_id": self.backend_session_id,
|
||||
"mode": self.mode.value,
|
||||
"session_id_received": self.session_id_received,
|
||||
"created_at": self.created_at,
|
||||
"last_updated": self.last_updated,
|
||||
"last_detection": self.last_detection,
|
||||
"detection_count": self.detection_count,
|
||||
"validation_frames_processed": self.validation_frames_processed,
|
||||
"stable_track_achieved": self.stable_track_achieved,
|
||||
"waiting_start_time": self.waiting_start_time
|
||||
}
|
||||
|
||||
def update_activity(self) -> None:
|
||||
"""Update last activity timestamp."""
|
||||
self.last_updated = time.time()
|
||||
|
||||
def record_detection(self) -> None:
|
||||
"""Record a detection occurrence."""
|
||||
current_time = time.time()
|
||||
self.last_detection = current_time
|
||||
self.detection_count += 1
|
||||
self.update_activity()
|
||||
|
||||
def is_expired(self, ttl_seconds: int) -> bool:
|
||||
"""Check if session has expired based on TTL."""
|
||||
return time.time() - self.last_updated > ttl_seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedDetection:
|
||||
"""Cached detection result."""
|
||||
detection_data: Dict[str, Any]
|
||||
camera_id: str
|
||||
session_id: Optional[str] = None
|
||||
created_at: float = field(default_factory=time.time)
|
||||
access_count: int = 0
|
||||
last_accessed: float = field(default_factory=time.time)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"detection_data": self.detection_data,
|
||||
"camera_id": self.camera_id,
|
||||
"session_id": self.session_id,
|
||||
"created_at": self.created_at,
|
||||
"access_count": self.access_count,
|
||||
"last_accessed": self.last_accessed
|
||||
}
|
||||
|
||||
def access(self) -> None:
|
||||
"""Record access to this cached detection."""
|
||||
self.access_count += 1
|
||||
self.last_accessed = time.time()
|
||||
|
||||
def is_expired(self, ttl_seconds: int) -> bool:
|
||||
"""Check if cached detection has expired."""
|
||||
return time.time() - self.created_at > ttl_seconds
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Session lifecycle and state management.
|
||||
|
||||
This class provides comprehensive session management including:
|
||||
- Session creation and cleanup
|
||||
- Pipeline mode transitions
|
||||
- Session timeout handling
|
||||
"""
|
||||
|
||||
def __init__(self, session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS):
|
||||
"""
|
||||
Initialize session manager.
|
||||
|
||||
Args:
|
||||
session_timeout_seconds: Default session timeout
|
||||
"""
|
||||
self.session_timeout_seconds = session_timeout_seconds
|
||||
self._sessions: Dict[str, SessionState] = {}
|
||||
self._session_ids: Dict[str, str] = {} # display_id -> session_id mapping
|
||||
self._lock = None
|
||||
|
||||
def _ensure_thread_safety(self):
|
||||
"""Initialize thread safety if not already done."""
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def create_session(self, display_id: str, session_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Create a new session or get existing one.
|
||||
|
||||
Args:
|
||||
display_id: Display identifier
|
||||
session_id: Optional session ID
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
# Check if session already exists for this display
|
||||
if display_id in self._session_ids:
|
||||
existing_session_id = self._session_ids[display_id]
|
||||
if existing_session_id in self._sessions:
|
||||
session_state = self._sessions[existing_session_id]
|
||||
session_state.update_activity()
|
||||
logger.debug(f"Using existing session for display {display_id}: {existing_session_id}")
|
||||
return existing_session_id
|
||||
|
||||
# Create new session
|
||||
if not session_id:
|
||||
import uuid
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
session_state = SessionState(session_id=session_id)
|
||||
self._sessions[session_id] = session_state
|
||||
self._session_ids[display_id] = session_id
|
||||
|
||||
logger.info(f"Created new session for display {display_id}: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[SessionState]:
|
||||
"""
|
||||
Get session state by session ID.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
SessionState or None if not found
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
session = self._sessions.get(session_id)
|
||||
if session:
|
||||
session.update_activity()
|
||||
return session
|
||||
|
||||
def get_session_by_display(self, display_id: str) -> Optional[SessionState]:
|
||||
"""
|
||||
Get session state by display ID.
|
||||
|
||||
Args:
|
||||
display_id: Display identifier
|
||||
|
||||
Returns:
|
||||
SessionState or None if not found
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
session_id = self._session_ids.get(display_id)
|
||||
if session_id:
|
||||
return self.get_session(session_id)
|
||||
return None
|
||||
|
||||
def update_session_mode(self,
|
||||
session_id: str,
|
||||
new_mode: PipelineMode,
|
||||
backend_session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Update session pipeline mode.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
new_mode: New pipeline mode
|
||||
backend_session_id: Optional backend session ID
|
||||
|
||||
Returns:
|
||||
True if updated successfully
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
logger.warning(f"Session not found for mode update: {session_id}")
|
||||
return False
|
||||
|
||||
old_mode = session.mode
|
||||
session.mode = new_mode
|
||||
|
||||
if backend_session_id:
|
||||
session.backend_session_id = backend_session_id
|
||||
session.session_id_received = True
|
||||
|
||||
# Handle mode-specific state changes
|
||||
if new_mode == PipelineMode.WAITING_FOR_SESSION_ID:
|
||||
session.waiting_start_time = time.time()
|
||||
elif old_mode == PipelineMode.WAITING_FOR_SESSION_ID:
|
||||
session.waiting_start_time = None
|
||||
|
||||
session.update_activity()
|
||||
|
||||
logger.info(f"Session {session_id}: Mode changed from {old_mode.value} to {new_mode.value}")
|
||||
return True
|
||||
|
||||
def record_detection(self, session_id: str) -> bool:
|
||||
"""
|
||||
Record a detection for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
True if recorded successfully
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
session = self.get_session(session_id)
|
||||
if session:
|
||||
session.record_detection()
|
||||
return True
|
||||
return False
|
||||
|
||||
def cleanup_expired_sessions(self, ttl_seconds: Optional[int] = None) -> int:
|
||||
"""
|
||||
Clean up expired sessions.
|
||||
|
||||
Args:
|
||||
ttl_seconds: TTL in seconds (uses default if not provided)
|
||||
|
||||
Returns:
|
||||
Number of sessions cleaned up
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
if ttl_seconds is None:
|
||||
ttl_seconds = SESSION_CACHE_TTL_MINUTES * 60
|
||||
|
||||
with self._lock:
|
||||
expired_sessions = []
|
||||
expired_displays = []
|
||||
|
||||
# Find expired sessions
|
||||
for session_id, session in self._sessions.items():
|
||||
if session.is_expired(ttl_seconds):
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
# Find displays pointing to expired sessions
|
||||
for display_id, session_id in self._session_ids.items():
|
||||
if session_id in expired_sessions:
|
||||
expired_displays.append(display_id)
|
||||
|
||||
# Remove expired sessions and mappings
|
||||
for session_id in expired_sessions:
|
||||
del self._sessions[session_id]
|
||||
|
||||
for display_id in expired_displays:
|
||||
del self._session_ids[display_id]
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"Cleaned up {len(expired_sessions)} expired sessions")
|
||||
|
||||
return len(expired_sessions)
|
||||
|
||||
def get_session_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get session management statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with session statistics
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
current_time = time.time()
|
||||
mode_counts = {}
|
||||
total_detections = 0
|
||||
|
||||
for session in self._sessions.values():
|
||||
mode = session.mode.value
|
||||
mode_counts[mode] = mode_counts.get(mode, 0) + 1
|
||||
total_detections += session.detection_count
|
||||
|
||||
return {
|
||||
"total_sessions": len(self._sessions),
|
||||
"total_display_mappings": len(self._session_ids),
|
||||
"mode_distribution": mode_counts,
|
||||
"total_detections_processed": total_detections,
|
||||
"session_timeout_seconds": self.session_timeout_seconds
|
||||
}
|
||||
|
||||
|
||||
class DetectionCache:
|
||||
"""
|
||||
Detection result caching with TTL and access tracking.
|
||||
|
||||
This class provides caching for detection results with automatic
|
||||
expiration and access pattern tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, ttl_minutes: int = SESSION_CACHE_TTL_MINUTES):
|
||||
"""
|
||||
Initialize detection cache.
|
||||
|
||||
Args:
|
||||
ttl_minutes: Time to live for cached detections in minutes
|
||||
"""
|
||||
self.ttl_seconds = ttl_minutes * 60
|
||||
self._cache: Dict[str, CachedDetection] = {}
|
||||
self._camera_index: Dict[str, Set[str]] = {} # camera_id -> set of cache keys
|
||||
self._session_index: Dict[str, Set[str]] = {} # session_id -> set of cache keys
|
||||
self._lock = None
|
||||
self._last_cleanup = time.time()
|
||||
|
||||
def _ensure_thread_safety(self):
|
||||
"""Initialize thread safety if not already done."""
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _generate_cache_key(self, camera_id: str, detection_type: str = "default") -> str:
|
||||
"""Generate cache key for detection."""
|
||||
return f"{camera_id}:{detection_type}:{time.time()}"
|
||||
|
||||
def store_detection(self,
|
||||
camera_id: str,
|
||||
detection_data: Dict[str, Any],
|
||||
session_id: Optional[str] = None,
|
||||
detection_type: str = "default") -> str:
|
||||
"""
|
||||
Store detection in cache.
|
||||
|
||||
Args:
|
||||
camera_id: Camera identifier
|
||||
detection_data: Detection result data
|
||||
session_id: Optional session identifier
|
||||
detection_type: Type of detection for categorization
|
||||
|
||||
Returns:
|
||||
Cache key for the stored detection
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
cache_key = self._generate_cache_key(camera_id, detection_type)
|
||||
|
||||
cached_detection = CachedDetection(
|
||||
detection_data=detection_data,
|
||||
camera_id=camera_id,
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
self._cache[cache_key] = cached_detection
|
||||
|
||||
# Update indexes
|
||||
if camera_id not in self._camera_index:
|
||||
self._camera_index[camera_id] = set()
|
||||
self._camera_index[camera_id].add(cache_key)
|
||||
|
||||
if session_id:
|
||||
if session_id not in self._session_index:
|
||||
self._session_index[session_id] = set()
|
||||
self._session_index[session_id].add(cache_key)
|
||||
|
||||
logger.debug(f"Stored detection in cache: {cache_key}")
|
||||
return cache_key
|
||||
|
||||
def get_detection(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get detection from cache by key.
|
||||
|
||||
Args:
|
||||
cache_key: Cache key
|
||||
|
||||
Returns:
|
||||
Detection data or None if not found/expired
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
cached_detection = self._cache.get(cache_key)
|
||||
if not cached_detection:
|
||||
return None
|
||||
|
||||
if cached_detection.is_expired(self.ttl_seconds):
|
||||
self._remove_from_cache(cache_key)
|
||||
return None
|
||||
|
||||
cached_detection.access()
|
||||
return cached_detection.detection_data
|
||||
|
||||
def get_latest_detection(self, camera_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get latest detection for a camera.
|
||||
|
||||
Args:
|
||||
camera_id: Camera identifier
|
||||
|
||||
Returns:
|
||||
Latest detection data or None if not found
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
camera_keys = self._camera_index.get(camera_id, set())
|
||||
if not camera_keys:
|
||||
return None
|
||||
|
||||
# Find the most recent non-expired detection
|
||||
latest_detection = None
|
||||
latest_time = 0
|
||||
|
||||
for cache_key in camera_keys:
|
||||
cached_detection = self._cache.get(cache_key)
|
||||
if cached_detection and not cached_detection.is_expired(self.ttl_seconds):
|
||||
if cached_detection.created_at > latest_time:
|
||||
latest_time = cached_detection.created_at
|
||||
latest_detection = cached_detection
|
||||
|
||||
if latest_detection:
|
||||
latest_detection.access()
|
||||
return latest_detection.detection_data
|
||||
|
||||
return None
|
||||
|
||||
def get_session_detections(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all detections for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session identifier
|
||||
|
||||
Returns:
|
||||
List of detection data dictionaries
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
session_keys = self._session_index.get(session_id, set())
|
||||
detections = []
|
||||
|
||||
for cache_key in session_keys:
|
||||
cached_detection = self._cache.get(cache_key)
|
||||
if cached_detection and not cached_detection.is_expired(self.ttl_seconds):
|
||||
cached_detection.access()
|
||||
detections.append(cached_detection.detection_data)
|
||||
|
||||
# Sort by creation time (newest first)
|
||||
detections.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
|
||||
return detections
|
||||
|
||||
def _remove_from_cache(self, cache_key: str) -> None:
|
||||
"""Remove detection from cache and indexes."""
|
||||
cached_detection = self._cache.get(cache_key)
|
||||
if cached_detection:
|
||||
# Remove from indexes
|
||||
camera_id = cached_detection.camera_id
|
||||
if camera_id in self._camera_index:
|
||||
self._camera_index[camera_id].discard(cache_key)
|
||||
if not self._camera_index[camera_id]:
|
||||
del self._camera_index[camera_id]
|
||||
|
||||
session_id = cached_detection.session_id
|
||||
if session_id and session_id in self._session_index:
|
||||
self._session_index[session_id].discard(cache_key)
|
||||
if not self._session_index[session_id]:
|
||||
del self._session_index[session_id]
|
||||
|
||||
# Remove from main cache
|
||||
if cache_key in self._cache:
|
||||
del self._cache[cache_key]
|
||||
|
||||
def cleanup_expired_detections(self) -> int:
|
||||
"""
|
||||
Clean up expired cached detections.
|
||||
|
||||
Returns:
|
||||
Number of detections cleaned up
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
expired_keys = []
|
||||
|
||||
# Find expired detections
|
||||
for cache_key, cached_detection in self._cache.items():
|
||||
if cached_detection.is_expired(self.ttl_seconds):
|
||||
expired_keys.append(cache_key)
|
||||
|
||||
# Remove expired detections
|
||||
for cache_key in expired_keys:
|
||||
self._remove_from_cache(cache_key)
|
||||
|
||||
if expired_keys:
|
||||
logger.info(f"Cleaned up {len(expired_keys)} expired cached detections")
|
||||
|
||||
self._last_cleanup = time.time()
|
||||
return len(expired_keys)
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics
|
||||
"""
|
||||
self._ensure_thread_safety()
|
||||
|
||||
with self._lock:
|
||||
total_access_count = sum(detection.access_count for detection in self._cache.values())
|
||||
|
||||
return {
|
||||
"total_cached_detections": len(self._cache),
|
||||
"cameras_with_cache": len(self._camera_index),
|
||||
"sessions_with_cache": len(self._session_index),
|
||||
"total_access_count": total_access_count,
|
||||
"ttl_seconds": self.ttl_seconds,
|
||||
"last_cleanup": self._last_cleanup
|
||||
}
|
||||
|
||||
|
||||
class SessionCacheManager:
|
||||
"""
|
||||
Combined session and cache management.
|
||||
|
||||
This class provides unified management of sessions and detection caching
|
||||
with automatic cleanup and comprehensive statistics.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS,
|
||||
cache_ttl_minutes: int = SESSION_CACHE_TTL_MINUTES):
|
||||
"""
|
||||
Initialize session cache manager.
|
||||
|
||||
Args:
|
||||
session_timeout_seconds: Session timeout in seconds
|
||||
cache_ttl_minutes: Cache TTL in minutes
|
||||
"""
|
||||
self.session_manager = SessionManager(session_timeout_seconds)
|
||||
self.detection_cache = DetectionCache(cache_ttl_minutes)
|
||||
self._last_cleanup = time.time()
|
||||
|
||||
def cleanup_expired_data(self, force: bool = False) -> Dict[str, int]:
|
||||
"""
|
||||
Clean up expired sessions and cached detections.
|
||||
|
||||
Args:
|
||||
force: Force cleanup regardless of interval
|
||||
|
||||
Returns:
|
||||
Dictionary with cleanup statistics
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Check if cleanup is needed
|
||||
if not force and current_time - self._last_cleanup < DETECTION_CACHE_CLEANUP_INTERVAL:
|
||||
return {"sessions_cleaned": 0, "detections_cleaned": 0, "cleanup_skipped": True}
|
||||
|
||||
sessions_cleaned = self.session_manager.cleanup_expired_sessions()
|
||||
detections_cleaned = self.detection_cache.cleanup_expired_detections()
|
||||
|
||||
self._last_cleanup = current_time
|
||||
|
||||
return {
|
||||
"sessions_cleaned": sessions_cleaned,
|
||||
"detections_cleaned": detections_cleaned,
|
||||
"cleanup_skipped": False
|
||||
}
|
||||
|
||||
def get_comprehensive_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get comprehensive statistics for sessions and cache.
|
||||
|
||||
Returns:
|
||||
Dictionary with all statistics
|
||||
"""
|
||||
return {
|
||||
"session_stats": self.session_manager.get_session_stats(),
|
||||
"cache_stats": self.detection_cache.get_cache_stats(),
|
||||
"last_cleanup": self._last_cleanup,
|
||||
"cleanup_interval_seconds": DETECTION_CACHE_CLEANUP_INTERVAL
|
||||
}
|
||||
|
||||
|
||||
# Global session cache manager instance
|
||||
session_cache_manager = SessionCacheManager()
|
||||
|
||||
|
||||
# ===== CONVENIENCE FUNCTIONS =====
|
||||
# These provide simplified access to session and cache functionality
|
||||
|
||||
def create_session(display_id: str, session_id: Optional[str] = None) -> str:
|
||||
"""Create a new session using global manager."""
|
||||
return session_cache_manager.session_manager.create_session(display_id, session_id)
|
||||
|
||||
|
||||
def get_session_state(session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get session state by session ID."""
|
||||
session = session_cache_manager.session_manager.get_session(session_id)
|
||||
return session.to_dict() if session else None
|
||||
|
||||
|
||||
def update_pipeline_mode(session_id: str,
|
||||
new_mode: str,
|
||||
backend_session_id: Optional[str] = None) -> bool:
|
||||
"""Update session pipeline mode."""
|
||||
try:
|
||||
mode = PipelineMode(new_mode)
|
||||
return session_cache_manager.session_manager.update_session_mode(session_id, mode, backend_session_id)
|
||||
except ValueError:
|
||||
logger.error(f"Invalid pipeline mode: {new_mode}")
|
||||
return False
|
||||
|
||||
|
||||
def cache_detection(camera_id: str,
|
||||
detection_data: Dict[str, Any],
|
||||
session_id: Optional[str] = None) -> str:
|
||||
"""Cache detection data using global manager."""
|
||||
return session_cache_manager.detection_cache.store_detection(camera_id, detection_data, session_id)
|
||||
|
||||
|
||||
def get_cached_detection(camera_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get latest cached detection for a camera."""
|
||||
return session_cache_manager.detection_cache.get_latest_detection(camera_id)
|
||||
|
||||
|
||||
def cleanup_expired_sessions_and_cache() -> Dict[str, int]:
|
||||
"""Clean up expired sessions and cached data."""
|
||||
return session_cache_manager.cleanup_expired_data()
|
||||
|
||||
|
||||
def get_session_and_cache_stats() -> Dict[str, Any]:
|
||||
"""Get comprehensive session and cache statistics."""
|
||||
return session_cache_manager.get_comprehensive_stats()
|
Loading…
Add table
Add a link
Reference in a new issue