diff --git a/detector_worker/pipeline/action_executor.py b/detector_worker/pipeline/action_executor.py new file mode 100644 index 0000000..1490ad9 --- /dev/null +++ b/detector_worker/pipeline/action_executor.py @@ -0,0 +1,669 @@ +""" +Action execution engine for pipeline workflows. + +This module provides comprehensive action execution functionality including: +- Redis-based actions (save image, publish messages) +- Database actions (PostgreSQL updates with field mapping) +- Parallel action coordination +- Dynamic context resolution +""" + +import cv2 +import json +import time +import uuid +import logging +from typing import Dict, List, Any, Optional, Tuple +from dataclasses import dataclass, field +from datetime import datetime +import numpy as np + +from ..core.constants import ( + REDIS_IMAGE_DEFAULT_QUALITY, + REDIS_IMAGE_DEFAULT_FORMAT, + DEFAULT_IMAGE_ENCODING_PARAMS +) +from ..core.exceptions import ActionExecutionError, create_pipeline_error +from ..pipeline.field_mapper import resolve_field_mapping + +logger = logging.getLogger(__name__) + + +@dataclass +class ActionContext: + """Context information for action execution.""" + timestamp_ms: int = field(default_factory=lambda: int(time.time() * 1000)) + uuid: str = field(default_factory=lambda: str(uuid.uuid4())) + timestamp: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + filename: str = field(default_factory=lambda: f"{uuid.uuid4()}.jpg") + image_key: Optional[str] = None + + # Detection context + camera_id: str = "unknown" + display_id: str = "unknown" + session_id: Optional[str] = None + backend_session_id: Optional[str] = None + + # Additional context from detection result + regions: Optional[Dict[str, Any]] = None + detections: Optional[List[Dict[str, Any]]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for template formatting.""" + result = { + "timestamp_ms": self.timestamp_ms, + "uuid": self.uuid, + "timestamp": self.timestamp, + "filename": self.filename, + "camera_id": self.camera_id, + "display_id": self.display_id, + } + + if self.image_key: + result["image_key"] = self.image_key + if self.session_id: + result["session_id"] = self.session_id + if self.backend_session_id: + result["backend_session_id"] = self.backend_session_id + + return result + + def update_from_dict(self, data: Dict[str, Any]) -> None: + """Update context from detection result dictionary.""" + if "camera_id" in data: + self.camera_id = data["camera_id"] + if "display_id" in data: + self.display_id = data["display_id"] + if "session_id" in data: + self.session_id = data["session_id"] + if "backend_session_id" in data: + self.backend_session_id = data["backend_session_id"] + if "regions" in data: + self.regions = data["regions"] + if "detections" in data: + self.detections = data["detections"] + + +@dataclass +class ActionResult: + """Result from action execution.""" + success: bool + action_type: str + message: str = "" + data: Optional[Dict[str, Any]] = None + error: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + result = { + "success": self.success, + "action_type": self.action_type, + "message": self.message + } + if self.data: + result["data"] = self.data + if self.error: + result["error"] = self.error + return result + + +class RedisActionExecutor: + """Executor for Redis-based actions.""" + + def __init__(self, redis_client): + """ + Initialize Redis action executor. 
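+
+        Illustrative sketch (the client, region name and key template are assumptions
+        drawn from how execute_redis_save_image reads its action dict, not a fixed
+        contract):
+
+            executor = RedisActionExecutor(redis_client)  # any connected redis-py client
+            action = {
+                "type": "redis_save_image",
+                "key": "inference:{display_id}:{timestamp}:{uuid}",
+                "region": "frontal",          # crop this detected class, if present
+                "format": "jpeg",
+                "quality": 85,
+                "expire_seconds": 600,
+            }
+            # executor.execute_redis_save_image(action, frame, context, regions_dict)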
+ + Args: + redis_client: Redis client instance + """ + self.redis_client = redis_client + + def _crop_region_by_class(self, + frame: np.ndarray, + regions_dict: Dict[str, Any], + class_name: str) -> Optional[np.ndarray]: + """Crop a specific region from frame based on detected class.""" + if class_name not in regions_dict: + logger.warning(f"Class '{class_name}' not found in detected regions") + return None + + bbox = regions_dict[class_name]["bbox"] + x1, y1, x2, y2 = bbox + + # Validate bbox coordinates + if x2 <= x1 or y2 <= y1: + logger.warning(f"Invalid bbox for class {class_name}: {bbox}") + return None + + try: + cropped = frame[y1:y2, x1:x2] + if cropped.size == 0: + logger.warning(f"Empty crop for class {class_name}") + return None + return cropped + except Exception as e: + logger.error(f"Error cropping region for class {class_name}: {e}") + return None + + def _encode_image(self, + image: np.ndarray, + img_format: str = REDIS_IMAGE_DEFAULT_FORMAT, + quality: int = REDIS_IMAGE_DEFAULT_QUALITY) -> Tuple[bool, Optional[bytes]]: + """Encode image with specified format and quality.""" + try: + img_format = img_format.lower() + + if img_format == "jpeg" or img_format == "jpg": + encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] + success, buffer = cv2.imencode('.jpg', image, encode_params) + elif img_format == "png": + compression = DEFAULT_IMAGE_ENCODING_PARAMS["png"]["compression"] + encode_params = [cv2.IMWRITE_PNG_COMPRESSION, compression] + success, buffer = cv2.imencode('.png', image, encode_params) + elif img_format == "webp": + encode_params = [cv2.IMWRITE_WEBP_QUALITY, quality] + success, buffer = cv2.imencode('.webp', image, encode_params) + else: + # Default to JPEG + encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] + success, buffer = cv2.imencode('.jpg', image, encode_params) + + if success: + return True, buffer.tobytes() + else: + return False, None + + except Exception as e: + logger.error(f"Error encoding image as {img_format}: {e}") + return False, None + + def execute_redis_save_image(self, + action: Dict[str, Any], + frame: np.ndarray, + context: ActionContext, + regions_dict: Optional[Dict[str, Any]] = None) -> ActionResult: + """Execute redis_save_image action.""" + try: + # Format the key template + key = action["key"].format(**context.to_dict()) + + # Check if we need to crop a specific region + region_name = action.get("region") + image_to_save = frame + + if region_name and regions_dict: + cropped_image = self._crop_region_by_class(frame, regions_dict, region_name) + if cropped_image is not None: + image_to_save = cropped_image + logger.debug(f"Cropped region '{region_name}' for redis_save_image") + else: + logger.warning(f"Could not crop region '{region_name}', saving full frame instead") + + # Encode image with specified format and quality + img_format = action.get("format", REDIS_IMAGE_DEFAULT_FORMAT) + quality = action.get("quality", REDIS_IMAGE_DEFAULT_QUALITY) + + success, image_bytes = self._encode_image(image_to_save, img_format, quality) + + if not success or not image_bytes: + return ActionResult( + success=False, + action_type="redis_save_image", + error=f"Failed to encode image as {img_format}" + ) + + # Save to Redis + expire_seconds = action.get("expire_seconds") + if expire_seconds: + self.redis_client.setex(key, expire_seconds, image_bytes) + message = f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)" + else: + self.redis_client.set(key, image_bytes) + message = f"Saved image to Redis with key: {key}" + + 
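+            # Note: the key is also recorded on the context below (context.image_key),
+            # so follow-up actions such as redis_publish can reference the stored image.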
logger.info(message) + + # Update context with image key + context.image_key = key + + return ActionResult( + success=True, + action_type="redis_save_image", + message=message, + data={"key": key, "format": img_format, "quality": quality} + ) + + except Exception as e: + error_msg = f"Error executing redis_save_image: {e}" + logger.error(error_msg) + return ActionResult( + success=False, + action_type="redis_save_image", + error=error_msg + ) + + def execute_redis_publish(self, + action: Dict[str, Any], + context: ActionContext) -> ActionResult: + """Execute redis_publish action.""" + try: + channel = action["channel"] + message_template = action["message"] + + # Handle JSON message format by creating it programmatically + if message_template.strip().startswith('{') and message_template.strip().endswith('}'): + # Create JSON data programmatically to avoid formatting issues + json_data = {} + + # Add common fields + json_data["event"] = "frontal_detected" + json_data["display_id"] = context.display_id + json_data["session_id"] = context.session_id + json_data["timestamp"] = context.timestamp + json_data["image_key"] = context.image_key or "" + + # Convert to JSON string + message = json.dumps(json_data) + else: + # Use regular string formatting for non-JSON messages + message = message_template.format(**context.to_dict()) + + # Test Redis connection + if not self.redis_client: + return ActionResult( + success=False, + action_type="redis_publish", + error="Redis client is None" + ) + + try: + self.redis_client.ping() + logger.debug("Redis connection is active") + except Exception as ping_error: + return ActionResult( + success=False, + action_type="redis_publish", + error=f"Redis connection test failed: {ping_error}" + ) + + # Publish to Redis + result = self.redis_client.publish(channel, message) + + success_msg = f"Published message to Redis channel '{channel}': {message}" + logger.info(success_msg) + logger.info(f"Redis publish result (subscribers count): {result}") + + # Additional debug info + if result == 0: + logger.warning(f"No subscribers listening to channel '{channel}'") + else: + logger.info(f"Message delivered to {result} subscriber(s)") + + return ActionResult( + success=True, + action_type="redis_publish", + message=success_msg, + data={ + "channel": channel, + "message": message, + "subscribers_count": result + } + ) + + except KeyError as e: + error_msg = f"Missing key in redis_publish message template: {e}" + logger.error(error_msg) + logger.debug(f"Available context keys: {list(context.to_dict().keys())}") + return ActionResult( + success=False, + action_type="redis_publish", + error=error_msg + ) + except Exception as e: + error_msg = f"Error in redis_publish action: {e}" + logger.error(error_msg) + logger.debug(f"Message template: {action['message']}") + logger.debug(f"Available context keys: {list(context.to_dict().keys())}") + return ActionResult( + success=False, + action_type="redis_publish", + error=error_msg + ) + + +class DatabaseActionExecutor: + """Executor for database-based actions.""" + + def __init__(self, db_manager): + """ + Initialize database action executor. 
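+
+        The manager only needs to expose an execute_update method; the call shape used
+        by execute_postgresql_update_combined is (a sketch, not a formal interface):
+
+            db_manager.execute_update(table, key_field, key_value, mapped_fields)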
+ + Args: + db_manager: Database manager instance + """ + self.db_manager = db_manager + + def execute_postgresql_update_combined(self, + action: Dict[str, Any], + detection_result: Dict[str, Any], + branch_results: Dict[str, Any]) -> ActionResult: + """Execute PostgreSQL update with combined branch results.""" + try: + table = action["table"] + key_field = action["key_field"] + key_value_template = action["key_value"] + fields = action["fields"] + + # Create context for key value formatting + action_context = {**detection_result} + key_value = key_value_template.format(**action_context) + + logger.info(f"Executing database update: table={table}, {key_field}={key_value}") + logger.debug(f"Available branch results: {list(branch_results.keys())}") + + # Process field mappings + mapped_fields = {} + mapping_errors = [] + + for db_field, value_template in fields.items(): + try: + mapped_value = resolve_field_mapping(value_template, branch_results, action_context) + if mapped_value is not None: + mapped_fields[db_field] = mapped_value + logger.info(f"Mapped field: {db_field} = {mapped_value}") + else: + error_msg = f"Could not resolve field mapping for {db_field}: {value_template}" + logger.warning(error_msg) + mapping_errors.append(error_msg) + logger.debug(f"Available branch results: {branch_results}") + except Exception as e: + error_msg = f"Error mapping field {db_field} with template '{value_template}': {e}" + logger.error(error_msg) + mapping_errors.append(error_msg) + + if not mapped_fields: + error_msg = "No fields mapped successfully, skipping database update" + logger.warning(error_msg) + logger.debug(f"Branch results available: {branch_results}") + logger.debug(f"Field templates: {fields}") + return ActionResult( + success=False, + action_type="postgresql_update_combined", + error=error_msg, + data={"mapping_errors": mapping_errors} + ) + + # Add updated_at field automatically + mapped_fields["updated_at"] = "NOW()" + + # Execute the database update + logger.info(f"Attempting database update with fields: {mapped_fields}") + success = self.db_manager.execute_update(table, key_field, key_value, mapped_fields) + + if success: + success_msg = f"Successfully updated database: {table} with {len(mapped_fields)} fields" + logger.info(f"✅ {success_msg}") + logger.info(f"Updated fields: {mapped_fields}") + return ActionResult( + success=True, + action_type="postgresql_update_combined", + message=success_msg, + data={ + "table": table, + "key_field": key_field, + "key_value": key_value, + "updated_fields": mapped_fields + } + ) + else: + error_msg = f"Failed to update database: {table}" + logger.error(f"❌ {error_msg}") + logger.error(f"Attempted update with: {key_field}={key_value}, fields={mapped_fields}") + return ActionResult( + success=False, + action_type="postgresql_update_combined", + error=error_msg, + data={ + "table": table, + "key_field": key_field, + "key_value": key_value, + "attempted_fields": mapped_fields + } + ) + + except KeyError as e: + error_msg = f"Missing required field in postgresql_update_combined action: {e}" + logger.error(error_msg) + logger.debug(f"Action config: {action}") + return ActionResult( + success=False, + action_type="postgresql_update_combined", + error=error_msg + ) + except Exception as e: + error_msg = f"Error in postgresql_update_combined action: {e}" + logger.error(error_msg) + return ActionResult( + success=False, + action_type="postgresql_update_combined", + error=error_msg + ) + + +class ActionExecutor: + """ + Main action execution engine for 
pipeline workflows. + + This class coordinates the execution of various action types including + Redis operations, database operations, and parallel action coordination. + """ + + def __init__(self): + """Initialize action executor.""" + self._redis_executors: Dict[Any, RedisActionExecutor] = {} + self._db_executors: Dict[Any, DatabaseActionExecutor] = {} + + def _get_redis_executor(self, redis_client) -> RedisActionExecutor: + """Get or create Redis executor for a client.""" + if redis_client not in self._redis_executors: + self._redis_executors[redis_client] = RedisActionExecutor(redis_client) + return self._redis_executors[redis_client] + + def _get_db_executor(self, db_manager) -> DatabaseActionExecutor: + """Get or create database executor for a manager.""" + if db_manager not in self._db_executors: + self._db_executors[db_manager] = DatabaseActionExecutor(db_manager) + return self._db_executors[db_manager] + + def _create_action_context(self, detection_result: Dict[str, Any]) -> ActionContext: + """Create action context from detection result.""" + context = ActionContext() + context.update_from_dict(detection_result) + return context + + def execute_actions(self, + node: Dict[str, Any], + frame: np.ndarray, + detection_result: Dict[str, Any], + regions_dict: Optional[Dict[str, Any]] = None) -> List[ActionResult]: + """ + Execute all actions for a node. + + Args: + node: Pipeline node configuration + frame: Input frame + detection_result: Detection result data + regions_dict: Optional region dictionary for cropping + + Returns: + List of action results + """ + if not node.get("redis_client") or not node.get("actions"): + return [] + + results = [] + redis_executor = self._get_redis_executor(node["redis_client"]) + context = self._create_action_context(detection_result) + + for action in node["actions"]: + try: + action_type = action.get("type") + + if action_type == "redis_save_image": + result = redis_executor.execute_redis_save_image( + action, frame, context, regions_dict + ) + elif action_type == "redis_publish": + result = redis_executor.execute_redis_publish(action, context) + else: + result = ActionResult( + success=False, + action_type=action_type or "unknown", + error=f"Unknown action type: {action_type}" + ) + + results.append(result) + + if not result.success: + logger.error(f"Action failed: {result.error}") + + except Exception as e: + error_msg = f"Error executing action {action.get('type', 'unknown')}: {e}" + logger.error(error_msg) + results.append(ActionResult( + success=False, + action_type=action.get('type', 'unknown'), + error=error_msg + )) + + return results + + def execute_parallel_actions(self, + node: Dict[str, Any], + frame: np.ndarray, + detection_result: Dict[str, Any], + regions_dict: Dict[str, Any]) -> List[ActionResult]: + """ + Execute parallel actions after all required branches have completed. 
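+
+        Illustrative node configuration (branch/model names are made up; the keys mirror
+        what this method and execute_postgresql_update_combined read):
+
+            node["parallelActions"] = [{
+                "type": "postgresql_update_combined",
+                "waitForBranches": ["car_brand_cls_v1", "car_bodytype_cls_v1"],
+                "table": "car_frontal_info",
+                "key_field": "session_id",
+                "key_value": "{session_id}",
+                "fields": {
+                    "car_brand": "{car_brand_cls_v1.brand}",
+                    "car_body_type": "{car_bodytype_cls_v1.body_type}"
+                }
+            }]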
+ + Args: + node: Pipeline node configuration + frame: Input frame + detection_result: Detection result data with branch results + regions_dict: Region dictionary + + Returns: + List of action results + """ + if not node.get("parallelActions"): + return [] + + results = [] + branch_results = detection_result.get("branch_results", {}) + + logger.debug("Executing parallel actions...") + + for action in node["parallelActions"]: + try: + action_type = action.get("type") + logger.debug(f"Processing parallel action: {action_type}") + + if action_type == "postgresql_update_combined": + # Check if all required branches have completed + wait_for_branches = action.get("waitForBranches", []) + missing_branches = [branch for branch in wait_for_branches if branch not in branch_results] + + if missing_branches: + error_msg = f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}" + logger.warning(error_msg) + results.append(ActionResult( + success=False, + action_type=action_type, + error=error_msg, + data={"missing_branches": missing_branches} + )) + continue + + logger.info(f"All required branches completed: {wait_for_branches}") + + # Execute the database update + if node.get("db_manager"): + db_executor = self._get_db_executor(node["db_manager"]) + result = db_executor.execute_postgresql_update_combined( + action, detection_result, branch_results + ) + results.append(result) + else: + results.append(ActionResult( + success=False, + action_type=action_type, + error="No database manager available" + )) + else: + error_msg = f"Unknown parallel action type: {action_type}" + logger.warning(error_msg) + results.append(ActionResult( + success=False, + action_type=action_type or "unknown", + error=error_msg + )) + + except Exception as e: + error_msg = f"Error executing parallel action {action.get('type', 'unknown')}: {e}" + logger.error(error_msg) + results.append(ActionResult( + success=False, + action_type=action.get('type', 'unknown'), + error=error_msg + )) + + return results + + +# Global action executor instance +action_executor = ActionExecutor() + + +# ===== CONVENIENCE FUNCTIONS ===== +# These provide the same interface as the original functions in pympta.py + +def execute_actions(node: Dict[str, Any], + frame: np.ndarray, + detection_result: Dict[str, Any], + regions_dict: Optional[Dict[str, Any]] = None) -> None: + """Execute all actions for a node using global executor.""" + results = action_executor.execute_actions(node, frame, detection_result, regions_dict) + + # Log any failures (maintain original behavior of not returning results) + for result in results: + if not result.success: + logger.error(f"Action execution failed: {result.error}") + + +def execute_parallel_actions(node: Dict[str, Any], + frame: np.ndarray, + detection_result: Dict[str, Any], + regions_dict: Dict[str, Any]) -> None: + """Execute parallel actions using global executor.""" + results = action_executor.execute_parallel_actions(node, frame, detection_result, regions_dict) + + # Log any failures (maintain original behavior of not returning results) + for result in results: + if not result.success: + logger.error(f"Parallel action execution failed: {result.error}") + + +def execute_postgresql_update_combined(node: Dict[str, Any], + action: Dict[str, Any], + detection_result: Dict[str, Any], + branch_results: Dict[str, Any]) -> None: + """Execute PostgreSQL update with combined branch results using global executor.""" + if node.get("db_manager"): + db_executor = 
action_executor._get_db_executor(node["db_manager"]) + result = db_executor.execute_postgresql_update_combined(action, detection_result, branch_results) + + if not result.success: + logger.error(f"PostgreSQL update failed: {result.error}") + else: + logger.error("No database manager available for postgresql_update_combined action") \ No newline at end of file diff --git a/detector_worker/pipeline/field_mapper.py b/detector_worker/pipeline/field_mapper.py new file mode 100644 index 0000000..903eeef --- /dev/null +++ b/detector_worker/pipeline/field_mapper.py @@ -0,0 +1,341 @@ +""" +Field mapping and template resolution for dynamic database operations. + +This module provides functionality for resolving field mapping templates +that reference branch results and context variables for database operations. +""" + +import re +import logging +from typing import Dict, Any, Optional, List, Union + +from ..core.exceptions import FieldMappingError + +logger = logging.getLogger(__name__) + + +class FieldMapper: + """ + Field mapping resolver for dynamic template resolution. + + This class handles the resolution of field mapping templates that can reference: + - Branch results (e.g., {car_brand_cls_v1.brand}) + - Context variables (e.g., {session_id}) + - Nested field lookups with fallback strategies + """ + + def __init__(self): + """Initialize field mapper.""" + pass + + def _extract_branch_references(self, template: str) -> List[str]: + """Extract branch references from template string.""" + # Match patterns like {model_id.field_name} + branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', template) + return branch_refs + + def _resolve_simple_template(self, template: str, context: Dict[str, Any]) -> str: + """Resolve simple template without branch references.""" + try: + result = template.format(**context) + logger.debug(f"Simple template resolved: '{template}' -> '{result}'") + return result + except KeyError as e: + logger.warning(f"Could not resolve context variable in simple template: {e}") + return template + + def _find_fallback_value(self, + branch_data: Dict[str, Any], + field_name: str, + model_id: str) -> Optional[str]: + """Find fallback value using various strategies.""" + if not isinstance(branch_data, dict): + logger.error(f"Branch data for '{model_id}' is not a dictionary: {type(branch_data)}") + return None + + # First, try the exact field name + if field_name in branch_data: + return branch_data[field_name] + + # Then try 'class' field as fallback + if 'class' in branch_data: + fallback_value = branch_data['class'] + logger.info(f"Using 'class' field as fallback for '{field_name}': '{fallback_value}'") + return fallback_value + + # For brand models, check if the class name exists as a key + if field_name == 'brand' and branch_data.get('class') in branch_data: + fallback_value = branch_data[branch_data['class']] + logger.info(f"Found brand value using class name as key: '{fallback_value}'") + return fallback_value + + # For body_type models, check if the class name exists as a key + if field_name == 'body_type' and branch_data.get('class') in branch_data: + fallback_value = branch_data[branch_data['class']] + logger.info(f"Found body_type value using class name as key: '{fallback_value}'") + return fallback_value + + # Additional fallback strategies for common field mappings + field_mappings = { + 'brand': ['car_brand', 'brand_name', 'detected_brand'], + 'body_type': ['car_body_type', 'bodytype', 'body', 'car_type'], + 'model': ['car_model', 'model_name', 'detected_model'], + 'color': 
['car_color', 'color_name', 'detected_color'] + } + + if field_name in field_mappings: + for alternative in field_mappings[field_name]: + if alternative in branch_data: + fallback_value = branch_data[alternative] + logger.info(f"Found '{field_name}' value using alternative field '{alternative}': '{fallback_value}'") + return fallback_value + + return None + + def _resolve_branch_reference(self, + ref: str, + branch_results: Dict[str, Any]) -> Optional[str]: + """Resolve a single branch reference.""" + try: + model_id, field_name = ref.split('.', 1) + logger.debug(f"Processing branch reference: model_id='{model_id}', field_name='{field_name}'") + + if model_id not in branch_results: + logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}") + return None + + branch_data = branch_results[model_id] + logger.debug(f"Branch '{model_id}' data: {branch_data}") + + if field_name in branch_data: + field_value = branch_data[field_name] + logger.info(f"✅ Resolved {ref} to '{field_value}'") + return str(field_value) + else: + logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results.") + logger.debug(f"Available fields in '{model_id}': {list(branch_data.keys()) if isinstance(branch_data, dict) else 'N/A'}") + + # Try fallback strategies + fallback_value = self._find_fallback_value(branch_data, field_name, model_id) + + if fallback_value is not None: + logger.info(f"✅ Resolved {ref} to '{fallback_value}' (using fallback)") + return str(fallback_value) + else: + logger.error(f"No suitable field found for '{field_name}' in branch '{model_id}'") + logger.debug(f"Branch data structure: {branch_data}") + return None + + except ValueError as e: + logger.error(f"Invalid branch reference format: {ref}") + return None + except Exception as e: + logger.error(f"Error resolving branch reference '{ref}': {e}") + return None + + def resolve_field_mapping(self, + value_template: str, + branch_results: Dict[str, Any], + action_context: Dict[str, Any]) -> Optional[str]: + """ + Resolve field mapping templates like {car_brand_cls_v1.brand}. + + Args: + value_template: Template string with placeholders + branch_results: Dictionary of branch execution results + action_context: Context variables for template resolution + + Returns: + Resolved string value or None if resolution failed + """ + try: + logger.debug(f"Resolving field mapping: '{value_template}'") + logger.debug(f"Available branch results: {list(branch_results.keys())}") + + # Handle simple context variables first (non-branch references) + if '.' 
not in value_template: + result = self._resolve_simple_template(value_template, action_context) + return result + + # Handle branch result references like {model_id.field} + branch_refs = self._extract_branch_references(value_template) + logger.debug(f"Found branch references: {branch_refs}") + + resolved_template = value_template + + for ref in branch_refs: + resolved_value = self._resolve_branch_reference(ref, branch_results) + + if resolved_value is not None: + resolved_template = resolved_template.replace(f'{{{ref}}}', resolved_value) + else: + logger.error(f"Failed to resolve branch reference: {ref}") + return None + + # Format any remaining simple variables + try: + final_value = resolved_template.format(**action_context) + logger.debug(f"Final resolved value: '{final_value}'") + return final_value + except KeyError as e: + logger.warning(f"Could not resolve context variable in template: {e}") + return resolved_template + + except Exception as e: + logger.error(f"Error resolving field mapping '{value_template}': {e}") + return None + + def resolve_multiple_fields(self, + field_templates: Dict[str, str], + branch_results: Dict[str, Any], + action_context: Dict[str, Any]) -> Dict[str, str]: + """ + Resolve multiple field mappings at once. + + Args: + field_templates: Dictionary mapping field names to templates + branch_results: Dictionary of branch execution results + action_context: Context variables for template resolution + + Returns: + Dictionary mapping field names to resolved values + """ + resolved_fields = {} + + for field_name, template in field_templates.items(): + try: + resolved_value = self.resolve_field_mapping(template, branch_results, action_context) + if resolved_value is not None: + resolved_fields[field_name] = resolved_value + logger.debug(f"Successfully resolved field '{field_name}': {resolved_value}") + else: + logger.warning(f"Failed to resolve field '{field_name}' with template: {template}") + except Exception as e: + logger.error(f"Error resolving field '{field_name}' with template '{template}': {e}") + + return resolved_fields + + def validate_template(self, template: str) -> Dict[str, Any]: + """ + Validate a field mapping template and return analysis. + + Args: + template: Template string to validate + + Returns: + Dictionary with validation results and analysis + """ + analysis = { + "valid": True, + "errors": [], + "warnings": [], + "branch_references": [], + "context_references": [], + "has_branch_refs": False, + "has_context_refs": False + } + + try: + # Extract all placeholders + all_refs = re.findall(r'\{([^}]+)\}', template) + + for ref in all_refs: + if '.' 
in ref: + # This is a branch reference + analysis["branch_references"].append(ref) + analysis["has_branch_refs"] = True + + # Validate format + parts = ref.split('.') + if len(parts) != 2: + analysis["errors"].append(f"Invalid branch reference format: {ref}") + analysis["valid"] = False + elif not parts[0] or not parts[1]: + analysis["errors"].append(f"Empty model_id or field_name in reference: {ref}") + analysis["valid"] = False + else: + # This is a context reference + analysis["context_references"].append(ref) + analysis["has_context_refs"] = True + + # Check for common issues + if analysis["has_branch_refs"] and not analysis["has_context_refs"]: + analysis["warnings"].append("Template only uses branch references, consider adding context info") + + if not analysis["branch_references"] and not analysis["context_references"]: + analysis["warnings"].append("Template has no placeholders - it's a static value") + + except Exception as e: + analysis["valid"] = False + analysis["errors"].append(f"Template analysis failed: {e}") + + return analysis + + +# Global field mapper instance +field_mapper = FieldMapper() + + +# ===== CONVENIENCE FUNCTIONS ===== +# These provide the same interface as the original functions in pympta.py + +def resolve_field_mapping(value_template: str, + branch_results: Dict[str, Any], + action_context: Dict[str, Any]) -> Optional[str]: + """Resolve field mapping templates like {car_brand_cls_v1.brand}.""" + return field_mapper.resolve_field_mapping(value_template, branch_results, action_context) + + +def resolve_multiple_field_mappings(field_templates: Dict[str, str], + branch_results: Dict[str, Any], + action_context: Dict[str, Any]) -> Dict[str, str]: + """Resolve multiple field mappings at once.""" + return field_mapper.resolve_multiple_fields(field_templates, branch_results, action_context) + + +def validate_field_mapping_template(template: str) -> Dict[str, Any]: + """Validate a field mapping template and return analysis.""" + return field_mapper.validate_template(template) + + +def get_available_field_mappings(branch_results: Dict[str, Any]) -> Dict[str, List[str]]: + """ + Get available field mappings from branch results. + + Args: + branch_results: Dictionary of branch execution results + + Returns: + Dictionary mapping model IDs to available field names + """ + available_mappings = {} + + for model_id, branch_data in branch_results.items(): + if isinstance(branch_data, dict): + available_mappings[model_id] = list(branch_data.keys()) + else: + logger.warning(f"Branch '{model_id}' data is not a dictionary: {type(branch_data)}") + available_mappings[model_id] = [] + + return available_mappings + + +def create_field_mapping_examples(branch_results: Dict[str, Any]) -> List[str]: + """ + Create example field mapping templates based on available branch results. + + Args: + branch_results: Dictionary of branch execution results + + Returns: + List of example template strings + """ + examples = [] + + for model_id, branch_data in branch_results.items(): + if isinstance(branch_data, dict): + for field_name in branch_data.keys(): + example = f"{{{model_id}.{field_name}}}" + examples.append(example) + + return examples \ No newline at end of file diff --git a/detector_worker/storage/database_manager.py b/detector_worker/storage/database_manager.py new file mode 100644 index 0000000..6e02b4b --- /dev/null +++ b/detector_worker/storage/database_manager.py @@ -0,0 +1,617 @@ +""" +Database management and operations. 
+ +This module provides comprehensive database functionality including: +- PostgreSQL connection management +- Table schema management +- Dynamic query execution +- Transaction handling +""" + +import psycopg2 +import psycopg2.extras +import uuid +import logging +from typing import Optional, Dict, Any, List, Tuple +from dataclasses import dataclass, field +from datetime import datetime + +from ..core.constants import ( + DB_CONNECTION_TIMEOUT, + DB_OPERATION_TIMEOUT, + DB_RETRY_ATTEMPTS +) +from ..core.exceptions import DatabaseError, create_detection_error + +logger = logging.getLogger(__name__) + + +@dataclass +class DatabaseConfig: + """Database connection configuration.""" + host: str + port: int + database: str + username: str + password: str + schema: str = "gas_station_1" + connection_timeout: int = DB_CONNECTION_TIMEOUT + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return { + "host": self.host, + "port": self.port, + "database": self.database, + "username": self.username, + "password": self.password, + "schema": self.schema, + "connection_timeout": self.connection_timeout + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'DatabaseConfig': + """Create from dictionary data.""" + return cls( + host=data["host"], + port=data["port"], + database=data["database"], + username=data.get("username", data.get("user", "")), + password=data["password"], + schema=data.get("schema", "gas_station_1"), + connection_timeout=data.get("connection_timeout", DB_CONNECTION_TIMEOUT) + ) + + +@dataclass +class QueryResult: + """Result from database query execution.""" + success: bool + affected_rows: int = 0 + data: Optional[List[Dict[str, Any]]] = None + error: Optional[str] = None + query: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + result = { + "success": self.success, + "affected_rows": self.affected_rows + } + if self.data: + result["data"] = self.data + if self.error: + result["error"] = self.error + if self.query: + result["query"] = self.query + return result + + +class DatabaseManager: + """ + Comprehensive database manager for PostgreSQL operations. + + This class provides connection management, schema operations, + and dynamic query execution with transaction support. + """ + + def __init__(self, config: Dict[str, Any]): + """ + Initialize database manager. + + Args: + config: Database configuration dictionary + """ + if isinstance(config, dict): + self.config = DatabaseConfig.from_dict(config) + else: + self.config = config + + self.connection: Optional[psycopg2.extensions.connection] = None + self._lock = None + + def _ensure_thread_safety(self): + """Initialize thread safety if not already done.""" + if self._lock is None: + import threading + self._lock = threading.RLock() + + def connect(self) -> bool: + """ + Establish database connection. 
+ + Returns: + True if connection successful, False otherwise + """ + self._ensure_thread_safety() + + with self._lock: + try: + if self.connection and not self.connection.closed: + # Connection already exists and is open + return True + + self.connection = psycopg2.connect( + host=self.config.host, + port=self.config.port, + database=self.config.database, + user=self.config.username, + password=self.config.password, + connect_timeout=self.config.connection_timeout + ) + + # Set connection properties + self.connection.set_client_encoding('UTF8') + + logger.info(f"PostgreSQL connection established successfully to {self.config.host}:{self.config.port}") + return True + + except Exception as e: + logger.error(f"Failed to connect to PostgreSQL: {e}") + self.connection = None + return False + + def disconnect(self) -> None: + """Close database connection.""" + self._ensure_thread_safety() + + with self._lock: + if self.connection: + try: + self.connection.close() + logger.info("PostgreSQL connection closed") + except Exception as e: + logger.error(f"Error closing PostgreSQL connection: {e}") + finally: + self.connection = None + + def is_connected(self) -> bool: + """ + Check if database connection is active. + + Returns: + True if connected, False otherwise + """ + self._ensure_thread_safety() + + with self._lock: + try: + if self.connection and not self.connection.closed: + cur = self.connection.cursor() + cur.execute("SELECT 1") + cur.fetchone() + cur.close() + return True + except Exception as e: + logger.debug(f"Connection check failed: {e}") + + return False + + def _ensure_connected(self) -> bool: + """Ensure database connection is active.""" + if not self.is_connected(): + return self.connect() + return True + + def execute_query(self, + query: str, + params: Optional[Tuple] = None, + fetch_results: bool = False) -> QueryResult: + """ + Execute a database query. + + Args: + query: SQL query string + params: Query parameters + fetch_results: Whether to fetch and return results + + Returns: + QueryResult with execution details + """ + self._ensure_thread_safety() + + with self._lock: + if not self._ensure_connected(): + return QueryResult( + success=False, + error="Failed to establish database connection", + query=query + ) + + cursor = None + try: + cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) + + logger.debug(f"Executing query: {query}") + if params: + logger.debug(f"Query parameters: {params}") + + cursor.execute(query, params) + affected_rows = cursor.rowcount + + data = None + if fetch_results: + data = [dict(row) for row in cursor.fetchall()] + + self.connection.commit() + + logger.debug(f"Query executed successfully, affected rows: {affected_rows}") + + return QueryResult( + success=True, + affected_rows=affected_rows, + data=data, + query=query + ) + + except Exception as e: + error_msg = f"Database query execution failed: {e}" + logger.error(error_msg) + logger.debug(f"Failed query: {query}") + if params: + logger.debug(f"Query parameters: {params}") + + if self.connection: + try: + self.connection.rollback() + except Exception as rollback_error: + logger.error(f"Rollback failed: {rollback_error}") + + return QueryResult( + success=False, + error=error_msg, + query=query + ) + finally: + if cursor: + cursor.close() + + def create_schema_if_not_exists(self, schema_name: Optional[str] = None) -> bool: + """ + Create database schema if it doesn't exist. 
+ + Args: + schema_name: Schema name (uses config schema if not provided) + + Returns: + True if successful, False otherwise + """ + schema_name = schema_name or self.config.schema + query = f"CREATE SCHEMA IF NOT EXISTS {schema_name}" + result = self.execute_query(query) + + if result.success: + logger.info(f"Schema '{schema_name}' created or verified successfully") + + return result.success + + def create_car_frontal_info_table(self) -> bool: + """ + Create the car_frontal_info table in the configured schema if it doesn't exist. + + Returns: + True if successful, False otherwise + """ + schema_name = self.config.schema + + # Ensure schema exists + if not self.create_schema_if_not_exists(schema_name): + return False + + try: + # Create table if it doesn't exist + create_table_query = f""" + CREATE TABLE IF NOT EXISTS {schema_name}.car_frontal_info ( + display_id VARCHAR(255), + captured_timestamp VARCHAR(255), + session_id VARCHAR(255) PRIMARY KEY, + license_character VARCHAR(255) DEFAULT NULL, + license_type VARCHAR(255) DEFAULT 'No model available', + car_brand VARCHAR(255) DEFAULT NULL, + car_model VARCHAR(255) DEFAULT NULL, + car_body_type VARCHAR(255) DEFAULT NULL, + created_at TIMESTAMP DEFAULT NOW(), + updated_at TIMESTAMP DEFAULT NOW() + ) + """ + + result = self.execute_query(create_table_query) + if not result.success: + return False + + # Add columns if they don't exist (for existing tables) + alter_queries = [ + f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_brand VARCHAR(255) DEFAULT NULL", + f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_model VARCHAR(255) DEFAULT NULL", + f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS car_body_type VARCHAR(255) DEFAULT NULL", + f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS created_at TIMESTAMP DEFAULT NOW()", + f"ALTER TABLE {schema_name}.car_frontal_info ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT NOW()" + ] + + for alter_query in alter_queries: + try: + alter_result = self.execute_query(alter_query) + if alter_result.success: + logger.debug(f"Executed: {alter_query}") + else: + # Check if it's just a "column already exists" error + if "already exists" not in (alter_result.error or "").lower(): + logger.warning(f"ALTER TABLE failed: {alter_result.error}") + except Exception as e: + # Ignore errors if column already exists (for older PostgreSQL versions) + if "already exists" in str(e).lower(): + logger.debug(f"Column already exists, skipping: {alter_query}") + else: + logger.warning(f"Error in ALTER TABLE: {e}") + + logger.info("Successfully created/verified car_frontal_info table with all required columns") + return True + + except Exception as e: + logger.error(f"Failed to create car_frontal_info table: {e}") + return False + + def insert_initial_detection(self, + display_id: str, + captured_timestamp: str, + session_id: Optional[str] = None) -> Optional[str]: + """ + Insert initial detection record and return the session_id. 
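+
+        Example call (values are illustrative):
+
+            session_id = db_manager.insert_initial_detection("display-001", "2024-01-01T12-00-00")
+            # session_id is a freshly generated UUID string, or None if the insert failed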
+ + Args: + display_id: Display identifier + captured_timestamp: Timestamp of capture + session_id: Optional session ID (generated if not provided) + + Returns: + Session ID if successful, None otherwise + """ + # Generate session_id if not provided + if not session_id: + session_id = str(uuid.uuid4()) + + try: + # Ensure table exists + if not self.create_car_frontal_info_table(): + logger.error("Failed to create/verify table before insertion") + return None + + schema_name = self.config.schema + insert_query = f""" + INSERT INTO {schema_name}.car_frontal_info + (display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type, created_at) + VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL, NOW()) + ON CONFLICT (session_id) DO NOTHING + """ + + result = self.execute_query(insert_query, (display_id, captured_timestamp, session_id)) + + if result.success: + logger.info(f"Inserted initial detection record with session_id: {session_id}") + return session_id + else: + logger.error(f"Failed to insert initial detection record: {result.error}") + return None + + except Exception as e: + logger.error(f"Failed to insert initial detection record: {e}") + return None + + def execute_update(self, + table: str, + key_field: str, + key_value: str, + fields: Dict[str, Any]) -> bool: + """ + Execute dynamic update/insert operation. + + Args: + table: Table name + key_field: Primary key field name + key_value: Primary key value + fields: Dictionary of fields to update + + Returns: + True if successful, False otherwise + """ + try: + # Add schema prefix if table doesn't already have it + if '.' not in table: + table = f"{self.config.schema}.{table}" + + # Build the INSERT and UPDATE query dynamically + insert_placeholders = [] + insert_values = [key_value] # Start with key_value + + set_clauses = [] + update_values = [] + + for field, value in fields.items(): + if value == "NOW()": + # Special handling for NOW() + insert_placeholders.append("NOW()") + set_clauses.append(f"{field} = NOW()") + else: + insert_placeholders.append("%s") + insert_values.append(value) + set_clauses.append(f"{field} = %s") + update_values.append(value) + + # Build the complete query + query = f""" + INSERT INTO {table} ({key_field}, {', '.join(fields.keys())}) + VALUES (%s, {', '.join(insert_placeholders)}) + ON CONFLICT ({key_field}) + DO UPDATE SET {', '.join(set_clauses)} + """ + + # Combine values for the query: insert_values + update_values + all_values = tuple(insert_values + update_values) + + result = self.execute_query(query, all_values) + + if result.success: + logger.info(f"✅ Updated {table} for {key_field}={key_value} with {len(fields)} fields") + logger.debug(f"Updated fields: {fields}") + return True + else: + logger.error(f"❌ Failed to update {table}: {result.error}") + return False + + except Exception as e: + logger.error(f"❌ Failed to execute update on {table}: {e}") + return False + + def update_car_info(self, + session_id: str, + brand: str, + model: str, + body_type: str) -> bool: + """ + Update car information for a session. 
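+
+        The statement upserts on session_id, so it works whether or not an initial
+        record exists. Example call (values are illustrative):
+
+            db_manager.update_car_info(session_id, "Toyota", "Camry", "Sedan")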
+ + Args: + session_id: Session identifier + brand: Car brand + model: Car model + body_type: Car body type + + Returns: + True if successful, False otherwise + """ + schema_name = self.config.schema + query = f""" + INSERT INTO {schema_name}.car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at) + VALUES (%s, %s, %s, %s, NOW()) + ON CONFLICT (session_id) + DO UPDATE SET + car_brand = EXCLUDED.car_brand, + car_model = EXCLUDED.car_model, + car_body_type = EXCLUDED.car_body_type, + updated_at = NOW() + """ + + result = self.execute_query(query, (session_id, brand, model, body_type)) + + if result.success: + logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})") + + return result.success + + def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]: + """ + Get session data by session ID. + + Args: + session_id: Session identifier + + Returns: + Session data dictionary or None if not found + """ + schema_name = self.config.schema + query = f"SELECT * FROM {schema_name}.car_frontal_info WHERE session_id = %s" + + result = self.execute_query(query, (session_id,), fetch_results=True) + + if result.success and result.data: + return result.data[0] + + return None + + def get_connection_stats(self) -> Dict[str, Any]: + """ + Get database connection statistics. + + Returns: + Dictionary with connection statistics + """ + stats = { + "connected": self.is_connected(), + "config": self.config.to_dict(), + "connection_closed": self.connection.closed if self.connection else True + } + + if self.is_connected(): + try: + # Get database version and basic stats + version_result = self.execute_query("SELECT version()", fetch_results=True) + if version_result.success and version_result.data: + stats["database_version"] = version_result.data[0]["version"] + + # Get current database name + db_result = self.execute_query("SELECT current_database()", fetch_results=True) + if db_result.success and db_result.data: + stats["current_database"] = db_result.data[0]["current_database"] + + except Exception as e: + stats["stats_error"] = str(e) + + return stats + + +# ===== CONVENIENCE FUNCTIONS ===== +# These provide compatibility with the original database.py interface + +def validate_postgresql_config(config: Dict[str, Any]) -> bool: + """ + Validate PostgreSQL configuration. + + Args: + config: Configuration dictionary + + Returns: + True if configuration is valid + """ + required_fields = ["host", "port", "database", "password"] + + for field in required_fields: + if field not in config: + logger.error(f"Missing required PostgreSQL config field: {field}") + return False + + if not config[field]: + logger.error(f"Empty PostgreSQL config field: {field}") + return False + + # Check for username (could be 'username' or 'user') + if not config.get("username") and not config.get("user"): + logger.error("Missing PostgreSQL username (provide either 'username' or 'user')") + return False + + # Validate port is numeric + try: + port = int(config["port"]) + if port <= 0 or port > 65535: + logger.error(f"Invalid PostgreSQL port: {port}") + return False + except (ValueError, TypeError): + logger.error(f"PostgreSQL port must be numeric: {config['port']}") + return False + + return True + + +def create_database_manager(config: Dict[str, Any]) -> Optional[DatabaseManager]: + """ + Create database manager with configuration validation. 
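+
+    Example configuration (values are illustrative; "user" is accepted as an
+    alternative to "username"):
+
+        config = {
+            "host": "localhost",
+            "port": 5432,
+            "database": "detector",
+            "username": "worker",
+            "password": "secret",
+            "schema": "gas_station_1",
+        }
+        db_manager = create_database_manager(config)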
+ + Args: + config: Database configuration + + Returns: + DatabaseManager instance or None if invalid config + """ + if not validate_postgresql_config(config): + return None + + try: + db_manager = DatabaseManager(config) + if db_manager.connect(): + logger.info(f"Successfully created database manager for {config['host']}:{config['port']}") + return db_manager + else: + logger.error("Failed to establish initial database connection") + return None + except Exception as e: + logger.error(f"Error creating database manager: {e}") + return None \ No newline at end of file diff --git a/detector_worker/storage/redis_client.py b/detector_worker/storage/redis_client.py new file mode 100644 index 0000000..8bd9be8 --- /dev/null +++ b/detector_worker/storage/redis_client.py @@ -0,0 +1,733 @@ +""" +Redis client management and operations. + +This module provides comprehensive Redis functionality including: +- Connection management with retries +- Key-value operations +- Pub/Sub messaging +- Image storage with compression +- Pipeline operations +""" + +import redis +import json +import time +import logging +from typing import Dict, Any, Optional, List, Union, Tuple +from dataclasses import dataclass, field +from datetime import datetime, timedelta + +from ..core.constants import ( + REDIS_CONNECTION_TIMEOUT, + REDIS_SOCKET_TIMEOUT, + REDIS_IMAGE_DEFAULT_QUALITY, + REDIS_IMAGE_DEFAULT_FORMAT +) +from ..core.exceptions import RedisError + +logger = logging.getLogger(__name__) + + +@dataclass +class RedisConfig: + """Redis connection configuration.""" + host: str + port: int + password: Optional[str] = None + db: int = 0 + connection_timeout: int = REDIS_CONNECTION_TIMEOUT + socket_timeout: int = REDIS_SOCKET_TIMEOUT + retry_on_timeout: bool = True + health_check_interval: int = 30 + max_connections: int = 10 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return { + "host": self.host, + "port": self.port, + "password": self.password, + "db": self.db, + "connection_timeout": self.connection_timeout, + "socket_timeout": self.socket_timeout, + "retry_on_timeout": self.retry_on_timeout, + "health_check_interval": self.health_check_interval, + "max_connections": self.max_connections + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'RedisConfig': + """Create from dictionary data.""" + return cls( + host=data["host"], + port=data["port"], + password=data.get("password"), + db=data.get("db", 0), + connection_timeout=data.get("connection_timeout", REDIS_CONNECTION_TIMEOUT), + socket_timeout=data.get("socket_timeout", REDIS_SOCKET_TIMEOUT), + retry_on_timeout=data.get("retry_on_timeout", True), + health_check_interval=data.get("health_check_interval", 30), + max_connections=data.get("max_connections", 10) + ) + + +@dataclass +class RedisOperationResult: + """Result from Redis operation.""" + success: bool + data: Optional[Any] = None + error: Optional[str] = None + operation: Optional[str] = None + key: Optional[str] = None + execution_time: Optional[float] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + result = { + "success": self.success, + "operation": self.operation + } + if self.data is not None: + result["data"] = self.data + if self.error: + result["error"] = self.error + if self.key: + result["key"] = self.key + if self.execution_time: + result["execution_time"] = self.execution_time + return result + + +class RedisClientManager: + """ + Comprehensive Redis client manager with connection pooling and retry logic. 
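+
+    Minimal usage sketch (connection details and keys are illustrative):
+
+        manager = RedisClientManager({"host": "localhost", "port": 6379})
+        if manager.connect():
+            manager.set("detector:last_seen", "camera-01", ex=60)
+            result = manager.get("detector:last_seen")  # result.data holds the stored value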
+ + This class provides high-level Redis operations with automatic reconnection, + connection pooling, and comprehensive error handling. + """ + + def __init__(self, config: Union[Dict[str, Any], RedisConfig]): + """ + Initialize Redis client manager. + + Args: + config: Redis configuration dictionary or RedisConfig object + """ + if isinstance(config, dict): + self.config = RedisConfig.from_dict(config) + else: + self.config = config + + self.client: Optional[redis.Redis] = None + self.connection_pool: Optional[redis.ConnectionPool] = None + self._lock = None + self._last_health_check = 0.0 + + def _ensure_thread_safety(self): + """Initialize thread safety if not already done.""" + if self._lock is None: + import threading + self._lock = threading.RLock() + + def _create_connection_pool(self) -> redis.ConnectionPool: + """Create Redis connection pool.""" + pool_kwargs = { + "host": self.config.host, + "port": self.config.port, + "db": self.config.db, + "socket_timeout": self.config.socket_timeout, + "socket_connect_timeout": self.config.connection_timeout, + "retry_on_timeout": self.config.retry_on_timeout, + "health_check_interval": self.config.health_check_interval, + "max_connections": self.config.max_connections + } + + if self.config.password: + pool_kwargs["password"] = self.config.password + + return redis.ConnectionPool(**pool_kwargs) + + def connect(self) -> bool: + """ + Establish Redis connection. + + Returns: + True if connection successful, False otherwise + """ + self._ensure_thread_safety() + + with self._lock: + try: + # Create connection pool + self.connection_pool = self._create_connection_pool() + + # Create Redis client + self.client = redis.Redis(connection_pool=self.connection_pool) + + # Test connection + self.client.ping() + self._last_health_check = time.time() + + logger.info(f"Redis connection established successfully to {self.config.host}:{self.config.port}") + return True + + except Exception as e: + logger.error(f"Failed to connect to Redis: {e}") + self.client = None + self.connection_pool = None + return False + + def disconnect(self) -> None: + """Close Redis connection and cleanup resources.""" + self._ensure_thread_safety() + + with self._lock: + if self.connection_pool: + try: + self.connection_pool.disconnect() + logger.info("Redis connection pool disconnected") + except Exception as e: + logger.error(f"Error disconnecting Redis pool: {e}") + finally: + self.connection_pool = None + + self.client = None + + def is_connected(self) -> bool: + """ + Check if Redis connection is active. 
+ + Returns: + True if connected, False otherwise + """ + self._ensure_thread_safety() + + with self._lock: + try: + if self.client: + # Perform periodic health check + current_time = time.time() + if current_time - self._last_health_check > self.config.health_check_interval: + self.client.ping() + self._last_health_check = current_time + return True + except Exception as e: + logger.debug(f"Redis connection check failed: {e}") + + return False + + def _ensure_connected(self) -> bool: + """Ensure Redis connection is active.""" + if not self.is_connected(): + return self.connect() + return True + + def _execute_operation(self, + operation_name: str, + operation_func, + key: Optional[str] = None) -> RedisOperationResult: + """Execute Redis operation with error handling and timing.""" + if not self._ensure_connected(): + return RedisOperationResult( + success=False, + error="Failed to establish Redis connection", + operation=operation_name, + key=key + ) + + start_time = time.time() + try: + result = operation_func() + execution_time = time.time() - start_time + + return RedisOperationResult( + success=True, + data=result, + operation=operation_name, + key=key, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + error_msg = f"Redis {operation_name} operation failed: {e}" + logger.error(error_msg) + + return RedisOperationResult( + success=False, + error=error_msg, + operation=operation_name, + key=key, + execution_time=execution_time + ) + + # ===== KEY-VALUE OPERATIONS ===== + + def set(self, key: str, value: Any, ex: Optional[int] = None) -> RedisOperationResult: + """ + Set key-value pair with optional expiration. + + Args: + key: Redis key + value: Value to store + ex: Expiration time in seconds + + Returns: + RedisOperationResult with operation status + """ + def operation(): + return self.client.set(key, value, ex=ex) + + result = self._execute_operation("SET", operation, key) + + if result.success: + expire_msg = f" (expires in {ex}s)" if ex else "" + logger.debug(f"Set Redis key '{key}'{expire_msg}") + + return result + + def setex(self, key: str, time: int, value: Any) -> RedisOperationResult: + """ + Set key-value pair with expiration time. + + Args: + key: Redis key + time: Expiration time in seconds + value: Value to store + + Returns: + RedisOperationResult with operation status + """ + def operation(): + return self.client.setex(key, time, value) + + result = self._execute_operation("SETEX", operation, key) + + if result.success: + logger.debug(f"Set Redis key '{key}' with {time}s expiration") + + return result + + def get(self, key: str) -> RedisOperationResult: + """ + Get value by key. + + Args: + key: Redis key + + Returns: + RedisOperationResult with value data + """ + def operation(): + return self.client.get(key) + + return self._execute_operation("GET", operation, key) + + def delete(self, *keys: str) -> RedisOperationResult: + """ + Delete one or more keys. + + Args: + *keys: Redis keys to delete + + Returns: + RedisOperationResult with number of deleted keys + """ + def operation(): + return self.client.delete(*keys) + + result = self._execute_operation("DELETE", operation) + + if result.success: + logger.debug(f"Deleted {result.data} Redis key(s): {keys}") + + return result + + def exists(self, *keys: str) -> RedisOperationResult: + """ + Check if keys exist. 
+ + Args: + *keys: Redis keys to check + + Returns: + RedisOperationResult with count of existing keys + """ + def operation(): + return self.client.exists(*keys) + + return self._execute_operation("EXISTS", operation) + + def expire(self, key: str, time: int) -> RedisOperationResult: + """ + Set expiration time for a key. + + Args: + key: Redis key + time: Expiration time in seconds + + Returns: + RedisOperationResult with operation status + """ + def operation(): + return self.client.expire(key, time) + + return self._execute_operation("EXPIRE", operation, key) + + def ttl(self, key: str) -> RedisOperationResult: + """ + Get time to live for a key. + + Args: + key: Redis key + + Returns: + RedisOperationResult with TTL in seconds + """ + def operation(): + return self.client.ttl(key) + + return self._execute_operation("TTL", operation, key) + + # ===== PUB/SUB OPERATIONS ===== + + def publish(self, channel: str, message: Union[str, Dict[str, Any]]) -> RedisOperationResult: + """ + Publish message to channel. + + Args: + channel: Channel name + message: Message to publish (string or dict) + + Returns: + RedisOperationResult with number of subscribers + """ + def operation(): + # Convert dict to JSON string + if isinstance(message, dict): + message_str = json.dumps(message) + else: + message_str = str(message) + + return self.client.publish(channel, message_str) + + result = self._execute_operation("PUBLISH", operation) + + if result.success: + logger.debug(f"Published to channel '{channel}', subscribers: {result.data}") + if result.data == 0: + logger.warning(f"No subscribers listening to channel '{channel}'") + + return result + + def subscribe(self, *channels: str): + """ + Subscribe to channels (returns pubsub object). + + Args: + *channels: Channel names to subscribe to + + Returns: + PubSub object for listening to messages + """ + if not self._ensure_connected(): + raise RedisError("Failed to establish Redis connection for subscription") + + pubsub = self.client.pubsub() + pubsub.subscribe(*channels) + + logger.info(f"Subscribed to channels: {channels}") + return pubsub + + # ===== HASH OPERATIONS ===== + + def hset(self, key: str, field: str, value: Any) -> RedisOperationResult: + """ + Set hash field. + + Args: + key: Redis key + field: Hash field name + value: Field value + + Returns: + RedisOperationResult with operation status + """ + def operation(): + return self.client.hset(key, field, value) + + return self._execute_operation("HSET", operation, key) + + def hget(self, key: str, field: str) -> RedisOperationResult: + """ + Get hash field value. + + Args: + key: Redis key + field: Hash field name + + Returns: + RedisOperationResult with field value + """ + def operation(): + return self.client.hget(key, field) + + return self._execute_operation("HGET", operation, key) + + def hgetall(self, key: str) -> RedisOperationResult: + """ + Get all hash fields and values. + + Args: + key: Redis key + + Returns: + RedisOperationResult with hash dictionary + """ + def operation(): + return self.client.hgetall(key) + + return self._execute_operation("HGETALL", operation, key) + + # ===== LIST OPERATIONS ===== + + def lpush(self, key: str, *values: Any) -> RedisOperationResult: + """ + Push values to left of list. 
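Example (illustrative sketch; the list key and values are placeholders):

    manager.lpush("camera:001:events", "evt-1", "evt-2")
    recent = manager.lrange("camera:001:events", 0, 9)
    if recent.success:
        print(recent.data)  # up to 10 entries, most recently pushed first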
+ + Args: + key: Redis key + *values: Values to push + + Returns: + RedisOperationResult with list length + """ + def operation(): + return self.client.lpush(key, *values) + + return self._execute_operation("LPUSH", operation, key) + + def rpush(self, key: str, *values: Any) -> RedisOperationResult: + """ + Push values to right of list. + + Args: + key: Redis key + *values: Values to push + + Returns: + RedisOperationResult with list length + """ + def operation(): + return self.client.rpush(key, *values) + + return self._execute_operation("RPUSH", operation, key) + + def lrange(self, key: str, start: int, end: int) -> RedisOperationResult: + """ + Get range of list elements. + + Args: + key: Redis key + start: Start index + end: End index + + Returns: + RedisOperationResult with list elements + """ + def operation(): + return self.client.lrange(key, start, end) + + return self._execute_operation("LRANGE", operation, key) + + # ===== UTILITY OPERATIONS ===== + + def ping(self) -> RedisOperationResult: + """ + Ping Redis server. + + Returns: + RedisOperationResult with ping response + """ + def operation(): + return self.client.ping() + + return self._execute_operation("PING", operation) + + def info(self, section: Optional[str] = None) -> RedisOperationResult: + """ + Get Redis server information. + + Args: + section: Optional info section + + Returns: + RedisOperationResult with server info + """ + def operation(): + return self.client.info(section) + + return self._execute_operation("INFO", operation) + + def flushdb(self) -> RedisOperationResult: + """ + Flush current database. + + Returns: + RedisOperationResult with operation status + """ + def operation(): + return self.client.flushdb() + + result = self._execute_operation("FLUSHDB", operation) + + if result.success: + logger.warning(f"Flushed Redis database {self.config.db}") + + return result + + def keys(self, pattern: str = "*") -> RedisOperationResult: + """ + Get keys matching pattern. + + Args: + pattern: Key pattern (default: all keys) + + Returns: + RedisOperationResult with matching keys + """ + def operation(): + return self.client.keys(pattern) + + return self._execute_operation("KEYS", operation) + + # ===== BATCH OPERATIONS ===== + + def pipeline(self): + """ + Create Redis pipeline for batch operations. + + Returns: + Redis pipeline object + """ + if not self._ensure_connected(): + raise RedisError("Failed to establish Redis connection for pipeline") + + return self.client.pipeline() + + def execute_pipeline(self, pipeline) -> RedisOperationResult: + """ + Execute Redis pipeline. + + Args: + pipeline: Redis pipeline object + + Returns: + RedisOperationResult with pipeline results + """ + def operation(): + return pipeline.execute() + + return self._execute_operation("PIPELINE", operation) + + # ===== CONNECTION MANAGEMENT ===== + + def get_connection_stats(self) -> Dict[str, Any]: + """ + Get Redis connection statistics. 
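Example (illustrative sketch; assumes RedisConfig.to_dict() exposes a "host" field):

    stats = manager.get_connection_stats()
    print(stats["connected"], stats["config"]["host"])
    if "server_info" in stats:
        print(stats["server_info"]["used_memory_human"])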
+ + Returns: + Dictionary with connection statistics + """ + stats = { + "connected": self.is_connected(), + "config": self.config.to_dict(), + "last_health_check": self._last_health_check, + "connection_pool_created": self.connection_pool is not None + } + + if self.connection_pool: + stats["connection_pool_stats"] = { + "created_connections": self.connection_pool.created_connections, + "available_connections": len(self.connection_pool._available_connections), + "in_use_connections": len(self.connection_pool._in_use_connections) + } + + # Get Redis server info if connected + if self.is_connected(): + try: + info_result = self.info() + if info_result.success: + redis_info = info_result.data + stats["server_info"] = { + "redis_version": redis_info.get("redis_version"), + "connected_clients": redis_info.get("connected_clients"), + "used_memory": redis_info.get("used_memory"), + "used_memory_human": redis_info.get("used_memory_human"), + "total_commands_processed": redis_info.get("total_commands_processed"), + "uptime_in_seconds": redis_info.get("uptime_in_seconds") + } + except Exception as e: + stats["server_info_error"] = str(e) + + return stats + + +# ===== CONVENIENCE FUNCTIONS ===== +# These provide compatibility and simplified access + +def create_redis_client(config: Dict[str, Any]) -> Optional[RedisClientManager]: + """ + Create Redis client with configuration validation. + + Args: + config: Redis configuration + + Returns: + RedisClientManager instance or None if connection failed + """ + try: + client_manager = RedisClientManager(config) + if client_manager.connect(): + logger.info(f"Successfully created Redis client for {config['host']}:{config['port']}") + return client_manager + else: + logger.error("Failed to establish initial Redis connection") + return None + except Exception as e: + logger.error(f"Error creating Redis client: {e}") + return None + + +def validate_redis_config(config: Dict[str, Any]) -> bool: + """ + Validate Redis configuration. + + Args: + config: Configuration dictionary + + Returns: + True if configuration is valid + """ + required_fields = ["host", "port"] + + for field in required_fields: + if field not in config: + logger.error(f"Missing required Redis config field: {field}") + return False + + if not config[field]: + logger.error(f"Empty Redis config field: {field}") + return False + + # Validate port is numeric + try: + port = int(config["port"]) + if port <= 0 or port > 65535: + logger.error(f"Invalid Redis port: {port}") + return False + except (ValueError, TypeError): + logger.error(f"Redis port must be numeric: {config['port']}") + return False + + return True \ No newline at end of file diff --git a/detector_worker/storage/session_cache.py b/detector_worker/storage/session_cache.py new file mode 100644 index 0000000..09d16a0 --- /dev/null +++ b/detector_worker/storage/session_cache.py @@ -0,0 +1,688 @@ +""" +Session and cache management for detection workflows. 
+ +This module provides comprehensive session and cache management including: +- Session state tracking and lifecycle management +- Detection result caching with TTL +- Pipeline mode state management +- Session cleanup and garbage collection +""" + +import time +import logging +from typing import Dict, List, Any, Optional, Set, Tuple +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from enum import Enum + +from ..core.constants import ( + SESSION_CACHE_TTL_MINUTES, + DETECTION_CACHE_CLEANUP_INTERVAL, + SESSION_TIMEOUT_SECONDS +) +from ..core.exceptions import SessionError + +logger = logging.getLogger(__name__) + + +class PipelineMode(Enum): + """Pipeline execution modes.""" + VALIDATION_DETECTING = "validation_detecting" + SEND_DETECTIONS = "send_detections" + WAITING_FOR_SESSION_ID = "waiting_for_session_id" + FULL_PIPELINE = "full_pipeline" + LIGHTWEIGHT = "lightweight" + CAR_GONE_WAITING = "car_gone_waiting" + + +@dataclass +class SessionState: + """Session state information.""" + session_id: Optional[str] = None + backend_session_id: Optional[str] = None + mode: PipelineMode = PipelineMode.VALIDATION_DETECTING + session_id_received: bool = False + created_at: float = field(default_factory=time.time) + last_updated: float = field(default_factory=time.time) + last_detection: Optional[float] = None + detection_count: int = 0 + + # Mode-specific state + validation_frames_processed: int = 0 + stable_track_achieved: bool = False + waiting_start_time: Optional[float] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return { + "session_id": self.session_id, + "backend_session_id": self.backend_session_id, + "mode": self.mode.value, + "session_id_received": self.session_id_received, + "created_at": self.created_at, + "last_updated": self.last_updated, + "last_detection": self.last_detection, + "detection_count": self.detection_count, + "validation_frames_processed": self.validation_frames_processed, + "stable_track_achieved": self.stable_track_achieved, + "waiting_start_time": self.waiting_start_time + } + + def update_activity(self) -> None: + """Update last activity timestamp.""" + self.last_updated = time.time() + + def record_detection(self) -> None: + """Record a detection occurrence.""" + current_time = time.time() + self.last_detection = current_time + self.detection_count += 1 + self.update_activity() + + def is_expired(self, ttl_seconds: int) -> bool: + """Check if session has expired based on TTL.""" + return time.time() - self.last_updated > ttl_seconds + + +@dataclass +class CachedDetection: + """Cached detection result.""" + detection_data: Dict[str, Any] + camera_id: str + session_id: Optional[str] = None + created_at: float = field(default_factory=time.time) + access_count: int = 0 + last_accessed: float = field(default_factory=time.time) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return { + "detection_data": self.detection_data, + "camera_id": self.camera_id, + "session_id": self.session_id, + "created_at": self.created_at, + "access_count": self.access_count, + "last_accessed": self.last_accessed + } + + def access(self) -> None: + """Record access to this cached detection.""" + self.access_count += 1 + self.last_accessed = time.time() + + def is_expired(self, ttl_seconds: int) -> bool: + """Check if cached detection has expired.""" + return time.time() - self.created_at > ttl_seconds + + +class SessionManager: + """ + Session lifecycle and state management. 
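A typical flow looks like the following sketch (the display id is a placeholder):

    manager = SessionManager()
    session_id = manager.create_session("display-entry-01")
    manager.record_detection(session_id)
    state = manager.get_session_by_display("display-entry-01")
    # state.mode starts as PipelineMode.VALIDATION_DETECTING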
+ + This class provides comprehensive session management including: + - Session creation and cleanup + - Pipeline mode transitions + - Session timeout handling + """ + + def __init__(self, session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS): + """ + Initialize session manager. + + Args: + session_timeout_seconds: Default session timeout + """ + self.session_timeout_seconds = session_timeout_seconds + self._sessions: Dict[str, SessionState] = {} + self._session_ids: Dict[str, str] = {} # display_id -> session_id mapping + self._lock = None + + def _ensure_thread_safety(self): + """Initialize thread safety if not already done.""" + if self._lock is None: + import threading + self._lock = threading.RLock() + + def create_session(self, display_id: str, session_id: Optional[str] = None) -> str: + """ + Create a new session or get existing one. + + Args: + display_id: Display identifier + session_id: Optional session ID + + Returns: + Session ID + """ + self._ensure_thread_safety() + + with self._lock: + # Check if session already exists for this display + if display_id in self._session_ids: + existing_session_id = self._session_ids[display_id] + if existing_session_id in self._sessions: + session_state = self._sessions[existing_session_id] + session_state.update_activity() + logger.debug(f"Using existing session for display {display_id}: {existing_session_id}") + return existing_session_id + + # Create new session + if not session_id: + import uuid + session_id = str(uuid.uuid4()) + + session_state = SessionState(session_id=session_id) + self._sessions[session_id] = session_state + self._session_ids[display_id] = session_id + + logger.info(f"Created new session for display {display_id}: {session_id}") + return session_id + + def get_session(self, session_id: str) -> Optional[SessionState]: + """ + Get session state by session ID. + + Args: + session_id: Session identifier + + Returns: + SessionState or None if not found + """ + self._ensure_thread_safety() + + with self._lock: + session = self._sessions.get(session_id) + if session: + session.update_activity() + return session + + def get_session_by_display(self, display_id: str) -> Optional[SessionState]: + """ + Get session state by display ID. + + Args: + display_id: Display identifier + + Returns: + SessionState or None if not found + """ + self._ensure_thread_safety() + + with self._lock: + session_id = self._session_ids.get(display_id) + if session_id: + return self.get_session(session_id) + return None + + def update_session_mode(self, + session_id: str, + new_mode: PipelineMode, + backend_session_id: Optional[str] = None) -> bool: + """ + Update session pipeline mode. 
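Example (illustrative sketch; identifiers are placeholders):

    manager.update_session_mode(session_id, PipelineMode.WAITING_FOR_SESSION_ID)
    # Supplying the backend id also flips session_id_received to True:
    manager.update_session_mode(session_id, PipelineMode.FULL_PIPELINE,
                                backend_session_id="backend-42")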
+ + Args: + session_id: Session identifier + new_mode: New pipeline mode + backend_session_id: Optional backend session ID + + Returns: + True if updated successfully + """ + self._ensure_thread_safety() + + with self._lock: + session = self.get_session(session_id) + if not session: + logger.warning(f"Session not found for mode update: {session_id}") + return False + + old_mode = session.mode + session.mode = new_mode + + if backend_session_id: + session.backend_session_id = backend_session_id + session.session_id_received = True + + # Handle mode-specific state changes + if new_mode == PipelineMode.WAITING_FOR_SESSION_ID: + session.waiting_start_time = time.time() + elif old_mode == PipelineMode.WAITING_FOR_SESSION_ID: + session.waiting_start_time = None + + session.update_activity() + + logger.info(f"Session {session_id}: Mode changed from {old_mode.value} to {new_mode.value}") + return True + + def record_detection(self, session_id: str) -> bool: + """ + Record a detection for a session. + + Args: + session_id: Session identifier + + Returns: + True if recorded successfully + """ + self._ensure_thread_safety() + + with self._lock: + session = self.get_session(session_id) + if session: + session.record_detection() + return True + return False + + def cleanup_expired_sessions(self, ttl_seconds: Optional[int] = None) -> int: + """ + Clean up expired sessions. + + Args: + ttl_seconds: TTL in seconds (uses default if not provided) + + Returns: + Number of sessions cleaned up + """ + self._ensure_thread_safety() + + if ttl_seconds is None: + ttl_seconds = SESSION_CACHE_TTL_MINUTES * 60 + + with self._lock: + expired_sessions = [] + expired_displays = [] + + # Find expired sessions + for session_id, session in self._sessions.items(): + if session.is_expired(ttl_seconds): + expired_sessions.append(session_id) + + # Find displays pointing to expired sessions + for display_id, session_id in self._session_ids.items(): + if session_id in expired_sessions: + expired_displays.append(display_id) + + # Remove expired sessions and mappings + for session_id in expired_sessions: + del self._sessions[session_id] + + for display_id in expired_displays: + del self._session_ids[display_id] + + if expired_sessions: + logger.info(f"Cleaned up {len(expired_sessions)} expired sessions") + + return len(expired_sessions) + + def get_session_stats(self) -> Dict[str, Any]: + """ + Get session management statistics. + + Returns: + Dictionary with session statistics + """ + self._ensure_thread_safety() + + with self._lock: + current_time = time.time() + mode_counts = {} + total_detections = 0 + + for session in self._sessions.values(): + mode = session.mode.value + mode_counts[mode] = mode_counts.get(mode, 0) + 1 + total_detections += session.detection_count + + return { + "total_sessions": len(self._sessions), + "total_display_mappings": len(self._session_ids), + "mode_distribution": mode_counts, + "total_detections_processed": total_detections, + "session_timeout_seconds": self.session_timeout_seconds + } + + +class DetectionCache: + """ + Detection result caching with TTL and access tracking. + + This class provides caching for detection results with automatic + expiration and access pattern tracking. + """ + + def __init__(self, ttl_minutes: int = SESSION_CACHE_TTL_MINUTES): + """ + Initialize detection cache. 
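Example (illustrative sketch; the detection payload is a placeholder):

    cache = DetectionCache(ttl_minutes=10)
    key = cache.store_detection("camera-001", {"class": "car", "confidence": 0.92})
    data = cache.get_detection(key)  # returns None once the 10-minute TTL expires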
+ + Args: + ttl_minutes: Time to live for cached detections in minutes + """ + self.ttl_seconds = ttl_minutes * 60 + self._cache: Dict[str, CachedDetection] = {} + self._camera_index: Dict[str, Set[str]] = {} # camera_id -> set of cache keys + self._session_index: Dict[str, Set[str]] = {} # session_id -> set of cache keys + self._lock = None + self._last_cleanup = time.time() + + def _ensure_thread_safety(self): + """Initialize thread safety if not already done.""" + if self._lock is None: + import threading + self._lock = threading.RLock() + + def _generate_cache_key(self, camera_id: str, detection_type: str = "default") -> str: + """Generate cache key for detection.""" + return f"{camera_id}:{detection_type}:{time.time()}" + + def store_detection(self, + camera_id: str, + detection_data: Dict[str, Any], + session_id: Optional[str] = None, + detection_type: str = "default") -> str: + """ + Store detection in cache. + + Args: + camera_id: Camera identifier + detection_data: Detection result data + session_id: Optional session identifier + detection_type: Type of detection for categorization + + Returns: + Cache key for the stored detection + """ + self._ensure_thread_safety() + + with self._lock: + cache_key = self._generate_cache_key(camera_id, detection_type) + + cached_detection = CachedDetection( + detection_data=detection_data, + camera_id=camera_id, + session_id=session_id + ) + + self._cache[cache_key] = cached_detection + + # Update indexes + if camera_id not in self._camera_index: + self._camera_index[camera_id] = set() + self._camera_index[camera_id].add(cache_key) + + if session_id: + if session_id not in self._session_index: + self._session_index[session_id] = set() + self._session_index[session_id].add(cache_key) + + logger.debug(f"Stored detection in cache: {cache_key}") + return cache_key + + def get_detection(self, cache_key: str) -> Optional[Dict[str, Any]]: + """ + Get detection from cache by key. + + Args: + cache_key: Cache key + + Returns: + Detection data or None if not found/expired + """ + self._ensure_thread_safety() + + with self._lock: + cached_detection = self._cache.get(cache_key) + if not cached_detection: + return None + + if cached_detection.is_expired(self.ttl_seconds): + self._remove_from_cache(cache_key) + return None + + cached_detection.access() + return cached_detection.detection_data + + def get_latest_detection(self, camera_id: str) -> Optional[Dict[str, Any]]: + """ + Get latest detection for a camera. + + Args: + camera_id: Camera identifier + + Returns: + Latest detection data or None if not found + """ + self._ensure_thread_safety() + + with self._lock: + camera_keys = self._camera_index.get(camera_id, set()) + if not camera_keys: + return None + + # Find the most recent non-expired detection + latest_detection = None + latest_time = 0 + + for cache_key in camera_keys: + cached_detection = self._cache.get(cache_key) + if cached_detection and not cached_detection.is_expired(self.ttl_seconds): + if cached_detection.created_at > latest_time: + latest_time = cached_detection.created_at + latest_detection = cached_detection + + if latest_detection: + latest_detection.access() + return latest_detection.detection_data + + return None + + def get_session_detections(self, session_id: str) -> List[Dict[str, Any]]: + """ + Get all detections for a session. 
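Example (illustrative sketch; payloads and ids are placeholders):

    cache.store_detection("camera-001", {"timestamp": 1001, "class": "car"},
                          session_id="backend-42")
    cache.store_detection("camera-001", {"timestamp": 1002, "class": "truck"},
                          session_id="backend-42")
    detections = cache.get_session_detections("backend-42")
    # Ordered by each payload's "timestamp" field, newest first.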
+ + Args: + session_id: Session identifier + + Returns: + List of detection data dictionaries + """ + self._ensure_thread_safety() + + with self._lock: + session_keys = self._session_index.get(session_id, set()) + detections = [] + + for cache_key in session_keys: + cached_detection = self._cache.get(cache_key) + if cached_detection and not cached_detection.is_expired(self.ttl_seconds): + cached_detection.access() + detections.append(cached_detection.detection_data) + + # Sort by creation time (newest first) + detections.sort(key=lambda x: x.get("timestamp", 0), reverse=True) + return detections + + def _remove_from_cache(self, cache_key: str) -> None: + """Remove detection from cache and indexes.""" + cached_detection = self._cache.get(cache_key) + if cached_detection: + # Remove from indexes + camera_id = cached_detection.camera_id + if camera_id in self._camera_index: + self._camera_index[camera_id].discard(cache_key) + if not self._camera_index[camera_id]: + del self._camera_index[camera_id] + + session_id = cached_detection.session_id + if session_id and session_id in self._session_index: + self._session_index[session_id].discard(cache_key) + if not self._session_index[session_id]: + del self._session_index[session_id] + + # Remove from main cache + if cache_key in self._cache: + del self._cache[cache_key] + + def cleanup_expired_detections(self) -> int: + """ + Clean up expired cached detections. + + Returns: + Number of detections cleaned up + """ + self._ensure_thread_safety() + + with self._lock: + expired_keys = [] + + # Find expired detections + for cache_key, cached_detection in self._cache.items(): + if cached_detection.is_expired(self.ttl_seconds): + expired_keys.append(cache_key) + + # Remove expired detections + for cache_key in expired_keys: + self._remove_from_cache(cache_key) + + if expired_keys: + logger.info(f"Cleaned up {len(expired_keys)} expired cached detections") + + self._last_cleanup = time.time() + return len(expired_keys) + + def get_cache_stats(self) -> Dict[str, Any]: + """ + Get cache statistics. + + Returns: + Dictionary with cache statistics + """ + self._ensure_thread_safety() + + with self._lock: + total_access_count = sum(detection.access_count for detection in self._cache.values()) + + return { + "total_cached_detections": len(self._cache), + "cameras_with_cache": len(self._camera_index), + "sessions_with_cache": len(self._session_index), + "total_access_count": total_access_count, + "ttl_seconds": self.ttl_seconds, + "last_cleanup": self._last_cleanup + } + + +class SessionCacheManager: + """ + Combined session and cache management. + + This class provides unified management of sessions and detection caching + with automatic cleanup and comprehensive statistics. + """ + + def __init__(self, + session_timeout_seconds: int = SESSION_TIMEOUT_SECONDS, + cache_ttl_minutes: int = SESSION_CACHE_TTL_MINUTES): + """ + Initialize session cache manager. + + Args: + session_timeout_seconds: Session timeout in seconds + cache_ttl_minutes: Cache TTL in minutes + """ + self.session_manager = SessionManager(session_timeout_seconds) + self.detection_cache = DetectionCache(cache_ttl_minutes) + self._last_cleanup = time.time() + + def cleanup_expired_data(self, force: bool = False) -> Dict[str, int]: + """ + Clean up expired sessions and cached detections. 
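Example (illustrative sketch):

    manager = SessionCacheManager()
    result = manager.cleanup_expired_data()           # may be skipped by interval
    result = manager.cleanup_expired_data(force=True)
    print(result["sessions_cleaned"], result["detections_cleaned"])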
+ + Args: + force: Force cleanup regardless of interval + + Returns: + Dictionary with cleanup statistics + """ + current_time = time.time() + + # Check if cleanup is needed + if not force and current_time - self._last_cleanup < DETECTION_CACHE_CLEANUP_INTERVAL: + return {"sessions_cleaned": 0, "detections_cleaned": 0, "cleanup_skipped": True} + + sessions_cleaned = self.session_manager.cleanup_expired_sessions() + detections_cleaned = self.detection_cache.cleanup_expired_detections() + + self._last_cleanup = current_time + + return { + "sessions_cleaned": sessions_cleaned, + "detections_cleaned": detections_cleaned, + "cleanup_skipped": False + } + + def get_comprehensive_stats(self) -> Dict[str, Any]: + """ + Get comprehensive statistics for sessions and cache. + + Returns: + Dictionary with all statistics + """ + return { + "session_stats": self.session_manager.get_session_stats(), + "cache_stats": self.detection_cache.get_cache_stats(), + "last_cleanup": self._last_cleanup, + "cleanup_interval_seconds": DETECTION_CACHE_CLEANUP_INTERVAL + } + + +# Global session cache manager instance +session_cache_manager = SessionCacheManager() + + +# ===== CONVENIENCE FUNCTIONS ===== +# These provide simplified access to session and cache functionality + +def create_session(display_id: str, session_id: Optional[str] = None) -> str: + """Create a new session using global manager.""" + return session_cache_manager.session_manager.create_session(display_id, session_id) + + +def get_session_state(session_id: str) -> Optional[Dict[str, Any]]: + """Get session state by session ID.""" + session = session_cache_manager.session_manager.get_session(session_id) + return session.to_dict() if session else None + + +def update_pipeline_mode(session_id: str, + new_mode: str, + backend_session_id: Optional[str] = None) -> bool: + """Update session pipeline mode.""" + try: + mode = PipelineMode(new_mode) + return session_cache_manager.session_manager.update_session_mode(session_id, mode, backend_session_id) + except ValueError: + logger.error(f"Invalid pipeline mode: {new_mode}") + return False + + +def cache_detection(camera_id: str, + detection_data: Dict[str, Any], + session_id: Optional[str] = None) -> str: + """Cache detection data using global manager.""" + return session_cache_manager.detection_cache.store_detection(camera_id, detection_data, session_id) + + +def get_cached_detection(camera_id: str) -> Optional[Dict[str, Any]]: + """Get latest cached detection for a camera.""" + return session_cache_manager.detection_cache.get_latest_detection(camera_id) + + +def cleanup_expired_sessions_and_cache() -> Dict[str, int]: + """Clean up expired sessions and cached data.""" + return session_cache_manager.cleanup_expired_data() + + +def get_session_and_cache_stats() -> Dict[str, Any]: + """Get comprehensive session and cache statistics.""" + return session_cache_manager.get_comprehensive_stats() \ No newline at end of file
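Illustrative end-to-end sketch of the convenience API defined at the bottom of session_cache.py; the display, camera, and backend identifiers below are placeholders, not values taken from the patch:

    session_id = create_session("display-entry-01")
    update_pipeline_mode(session_id, "full_pipeline", backend_session_id="backend-42")
    cache_detection("camera-001", {"class": "car", "confidence": 0.92}, session_id=session_id)
    latest = get_cached_detection("camera-001")
    cleanup_expired_sessions_and_cache()
    stats = get_session_and_cache_stats()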