Refactor: PHASE 3: Action & Storage Extraction
This commit is contained in:
parent
4e9ae6bcc4
commit
cdeaaf4a4f
5 changed files with 3048 additions and 0 deletions
669
detector_worker/pipeline/action_executor.py
Normal file
669
detector_worker/pipeline/action_executor.py
Normal file
|
@ -0,0 +1,669 @@
|
|||
"""
|
||||
Action execution engine for pipeline workflows.
|
||||
|
||||
This module provides comprehensive action execution functionality including:
|
||||
- Redis-based actions (save image, publish messages)
|
||||
- Database actions (PostgreSQL updates with field mapping)
|
||||
- Parallel action coordination
|
||||
- Dynamic context resolution
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from ..core.constants import (
|
||||
REDIS_IMAGE_DEFAULT_QUALITY,
|
||||
REDIS_IMAGE_DEFAULT_FORMAT,
|
||||
DEFAULT_IMAGE_ENCODING_PARAMS
|
||||
)
|
||||
from ..core.exceptions import ActionExecutionError, create_pipeline_error
|
||||
from ..pipeline.field_mapper import resolve_field_mapping
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ActionContext:
    """Context information for action execution.

    Carries the template variables (timestamps, ids, filename) that action
    configurations interpolate via ``str.format``, plus detection metadata
    copied from the detection result.
    """
    # Auto-generated per-execution values.
    timestamp_ms: int = field(default_factory=lambda: int(time.time() * 1000))
    uuid: str = field(default_factory=lambda: str(uuid.uuid4()))
    timestamp: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
    filename: str = field(default_factory=lambda: f"{uuid.uuid4()}.jpg")
    image_key: Optional[str] = None

    # Detection context
    camera_id: str = "unknown"
    display_id: str = "unknown"
    session_id: Optional[str] = None
    backend_session_id: Optional[str] = None

    # Additional context from detection result
    regions: Optional[Dict[str, Any]] = None
    detections: Optional[List[Dict[str, Any]]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for template formatting.

        Always includes the base fields; optional identifiers are added
        only when truthy so templates fail fast on genuinely missing keys.
        """
        data: Dict[str, Any] = {
            "timestamp_ms": self.timestamp_ms,
            "uuid": self.uuid,
            "timestamp": self.timestamp,
            "filename": self.filename,
            "camera_id": self.camera_id,
            "display_id": self.display_id,
        }
        for key, value in (
            ("image_key", self.image_key),
            ("session_id", self.session_id),
            ("backend_session_id", self.backend_session_id),
        ):
            if value:
                data[key] = value
        return data

    def update_from_dict(self, data: Dict[str, Any]) -> None:
        """Update context from detection result dictionary.

        Only keys present in *data* are copied; absent keys leave the
        current attribute values untouched.
        """
        for attr in ("camera_id", "display_id", "session_id",
                     "backend_session_id", "regions", "detections"):
            if attr in data:
                setattr(self, attr, data[attr])
|
||||
|
||||
|
||||
@dataclass
class ActionResult:
    """Result from action execution."""
    success: bool
    action_type: str
    message: str = ""
    data: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format.

        ``data`` and ``error`` are included only when truthy, mirroring the
        shape expected by downstream logging/serialization.
        """
        payload: Dict[str, Any] = {
            "success": self.success,
            "action_type": self.action_type,
            "message": self.message,
        }
        if self.data:
            payload["data"] = self.data
        if self.error:
            payload["error"] = self.error
        return payload
|
||||
|
||||
|
||||
class RedisActionExecutor:
    """Executor for Redis-based actions (image snapshots and pub/sub messages)."""

    def __init__(self, redis_client):
        """
        Initialize Redis action executor.

        Args:
            redis_client: Redis client instance
        """
        self.redis_client = redis_client

    def _crop_region_by_class(self,
                              frame: np.ndarray,
                              regions_dict: Dict[str, Any],
                              class_name: str) -> Optional[np.ndarray]:
        """Crop a specific region from frame based on detected class.

        Bbox coordinates are cast to int and clamped to the frame bounds:
        detector outputs are frequently floats (which would raise on slicing)
        and may fall slightly outside the image (negative indices would wrap
        in numpy and silently produce a wrong crop).

        Args:
            frame: Full image array (H x W [x C]).
            regions_dict: Mapping of class name -> {"bbox": [x1, y1, x2, y2], ...}.
            class_name: Region to extract.

        Returns:
            The cropped array, or None if the class is missing or the bbox
            is degenerate/empty after clamping.
        """
        if class_name not in regions_dict:
            logger.warning(f"Class '{class_name}' not found in detected regions")
            return None

        bbox = regions_dict[class_name]["bbox"]
        height, width = frame.shape[:2]
        # Cast to int and clamp each coordinate into the valid frame range.
        x1 = max(0, min(int(bbox[0]), width))
        y1 = max(0, min(int(bbox[1]), height))
        x2 = max(0, min(int(bbox[2]), width))
        y2 = max(0, min(int(bbox[3]), height))

        # Validate bbox coordinates
        if x2 <= x1 or y2 <= y1:
            logger.warning(f"Invalid bbox for class {class_name}: {bbox}")
            return None

        try:
            cropped = frame[y1:y2, x1:x2]
            if cropped.size == 0:
                logger.warning(f"Empty crop for class {class_name}")
                return None
            return cropped
        except Exception as e:
            logger.error(f"Error cropping region for class {class_name}: {e}")
            return None

    def _encode_image(self,
                      image: np.ndarray,
                      img_format: str = REDIS_IMAGE_DEFAULT_FORMAT,
                      quality: int = REDIS_IMAGE_DEFAULT_QUALITY) -> Tuple[bool, Optional[bytes]]:
        """Encode image with specified format and quality.

        Args:
            image: Image array to encode.
            img_format: "jpeg"/"jpg", "png" or "webp" (case-insensitive);
                anything else silently falls back to JPEG.
            quality: Quality for JPEG/WebP; PNG uses the project-wide
                compression level instead.

        Returns:
            Tuple of (success flag, encoded bytes or None).
        """
        try:
            img_format = img_format.lower()

            if img_format in ("jpeg", "jpg"):
                encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
                success, buffer = cv2.imencode('.jpg', image, encode_params)
            elif img_format == "png":
                # PNG is lossless; quality is ignored in favor of compression.
                compression = DEFAULT_IMAGE_ENCODING_PARAMS["png"]["compression"]
                encode_params = [cv2.IMWRITE_PNG_COMPRESSION, compression]
                success, buffer = cv2.imencode('.png', image, encode_params)
            elif img_format == "webp":
                encode_params = [cv2.IMWRITE_WEBP_QUALITY, quality]
                success, buffer = cv2.imencode('.webp', image, encode_params)
            else:
                # Default to JPEG
                encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
                success, buffer = cv2.imencode('.jpg', image, encode_params)

            if success:
                return True, buffer.tobytes()
            return False, None

        except Exception as e:
            logger.error(f"Error encoding image as {img_format}: {e}")
            return False, None

    def execute_redis_save_image(self,
                                 action: Dict[str, Any],
                                 frame: np.ndarray,
                                 context: ActionContext,
                                 regions_dict: Optional[Dict[str, Any]] = None) -> ActionResult:
        """Execute redis_save_image action.

        Encodes the frame (optionally cropped to ``action["region"]``) and
        stores it in Redis under the templated ``action["key"]``. On success
        the resolved key is written back to ``context.image_key`` so later
        actions (e.g. redis_publish) can reference it.

        Args:
            action: Action config with "key" template and optional
                "region", "format", "quality", "expire_seconds".
            frame: Source image.
            context: Execution context supplying template variables.
            regions_dict: Detected regions available for cropping.

        Returns:
            ActionResult describing the outcome; never raises.
        """
        try:
            # Format the key template
            key = action["key"].format(**context.to_dict())

            # Check if we need to crop a specific region
            region_name = action.get("region")
            image_to_save = frame

            if region_name and regions_dict:
                cropped_image = self._crop_region_by_class(frame, regions_dict, region_name)
                if cropped_image is not None:
                    image_to_save = cropped_image
                    logger.debug(f"Cropped region '{region_name}' for redis_save_image")
                else:
                    # Fall back to the whole frame rather than failing the action.
                    logger.warning(f"Could not crop region '{region_name}', saving full frame instead")

            # Encode image with specified format and quality
            img_format = action.get("format", REDIS_IMAGE_DEFAULT_FORMAT)
            quality = action.get("quality", REDIS_IMAGE_DEFAULT_QUALITY)

            success, image_bytes = self._encode_image(image_to_save, img_format, quality)

            if not success or not image_bytes:
                return ActionResult(
                    success=False,
                    action_type="redis_save_image",
                    error=f"Failed to encode image as {img_format}"
                )

            # Save to Redis (with TTL when configured)
            expire_seconds = action.get("expire_seconds")
            if expire_seconds:
                self.redis_client.setex(key, expire_seconds, image_bytes)
                message = f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)"
            else:
                self.redis_client.set(key, image_bytes)
                message = f"Saved image to Redis with key: {key}"

            logger.info(message)

            # Update context with image key
            context.image_key = key

            return ActionResult(
                success=True,
                action_type="redis_save_image",
                message=message,
                data={"key": key, "format": img_format, "quality": quality}
            )

        except Exception as e:
            error_msg = f"Error executing redis_save_image: {e}"
            logger.error(error_msg)
            return ActionResult(
                success=False,
                action_type="redis_save_image",
                error=error_msg
            )

    def execute_redis_publish(self,
                              action: Dict[str, Any],
                              context: ActionContext) -> ActionResult:
        """Execute redis_publish action.

        If the configured message template looks like a JSON object, a fixed
        JSON payload is built programmatically (avoids ``str.format`` choking
        on literal braces); otherwise the template is formatted against the
        context.

        Args:
            action: Action config with "channel" and "message" keys.
            context: Execution context supplying template variables.

        Returns:
            ActionResult with the published message and subscriber count;
            never raises.
        """
        try:
            channel = action["channel"]
            message_template = action["message"]

            # Handle JSON message format by creating it programmatically
            if message_template.strip().startswith('{') and message_template.strip().endswith('}'):
                # Create JSON data programmatically to avoid formatting issues
                json_data = {}

                # Add common fields
                json_data["event"] = "frontal_detected"
                json_data["display_id"] = context.display_id
                json_data["session_id"] = context.session_id
                json_data["timestamp"] = context.timestamp
                json_data["image_key"] = context.image_key or ""

                # Convert to JSON string
                message = json.dumps(json_data)
            else:
                # Use regular string formatting for non-JSON messages
                message = message_template.format(**context.to_dict())

            # Test Redis connection
            if not self.redis_client:
                return ActionResult(
                    success=False,
                    action_type="redis_publish",
                    error="Redis client is None"
                )

            try:
                self.redis_client.ping()
                logger.debug("Redis connection is active")
            except Exception as ping_error:
                return ActionResult(
                    success=False,
                    action_type="redis_publish",
                    error=f"Redis connection test failed: {ping_error}"
                )

            # Publish to Redis; result is the number of receiving subscribers.
            result = self.redis_client.publish(channel, message)

            success_msg = f"Published message to Redis channel '{channel}': {message}"
            logger.info(success_msg)
            logger.info(f"Redis publish result (subscribers count): {result}")

            # Additional debug info
            if result == 0:
                logger.warning(f"No subscribers listening to channel '{channel}'")
            else:
                logger.info(f"Message delivered to {result} subscriber(s)")

            return ActionResult(
                success=True,
                action_type="redis_publish",
                message=success_msg,
                data={
                    "channel": channel,
                    "message": message,
                    "subscribers_count": result
                }
            )

        except KeyError as e:
            error_msg = f"Missing key in redis_publish message template: {e}"
            logger.error(error_msg)
            logger.debug(f"Available context keys: {list(context.to_dict().keys())}")
            return ActionResult(
                success=False,
                action_type="redis_publish",
                error=error_msg
            )
        except Exception as e:
            error_msg = f"Error in redis_publish action: {e}"
            logger.error(error_msg)
            logger.debug(f"Message template: {action['message']}")
            logger.debug(f"Available context keys: {list(context.to_dict().keys())}")
            return ActionResult(
                success=False,
                action_type="redis_publish",
                error=error_msg
            )
|
||||
|
||||
|
||||
class DatabaseActionExecutor:
    """Executor for database-based actions."""

    def __init__(self, db_manager):
        """
        Initialize database action executor.

        Args:
            db_manager: Database manager instance
        """
        self.db_manager = db_manager

    def execute_postgresql_update_combined(self,
                                          action: Dict[str, Any],
                                          detection_result: Dict[str, Any],
                                          branch_results: Dict[str, Any]) -> ActionResult:
        """Execute PostgreSQL update with combined branch results.

        Args:
            action: Action config; requires "table", "key_field",
                "key_value" (a str.format template) and "fields"
                (mapping of DB column -> value template).
            detection_result: Detection context used to format the key value
                and as fallback context for field templates.
            branch_results: Per-branch model outputs consumed by the
                field-mapping templates (e.g. "{model_id.field}").

        Returns:
            ActionResult describing the outcome; never raises — all errors
            (missing config keys, mapping failures, DB errors) are converted
            into a failed ActionResult.
        """
        try:
            table = action["table"]
            key_field = action["key_field"]
            key_value_template = action["key_value"]
            fields = action["fields"]

            # Create context for key value formatting
            # (shallow copy so formatting cannot mutate the caller's dict)
            action_context = {**detection_result}
            key_value = key_value_template.format(**action_context)

            logger.info(f"Executing database update: table={table}, {key_field}={key_value}")
            logger.debug(f"Available branch results: {list(branch_results.keys())}")

            # Process field mappings
            # Fields that fail to resolve are collected as errors but do not
            # abort the update — a partial update is still attempted.
            mapped_fields = {}
            mapping_errors = []

            for db_field, value_template in fields.items():
                try:
                    mapped_value = resolve_field_mapping(value_template, branch_results, action_context)
                    if mapped_value is not None:
                        mapped_fields[db_field] = mapped_value
                        logger.info(f"Mapped field: {db_field} = {mapped_value}")
                    else:
                        error_msg = f"Could not resolve field mapping for {db_field}: {value_template}"
                        logger.warning(error_msg)
                        mapping_errors.append(error_msg)
                        logger.debug(f"Available branch results: {branch_results}")
                except Exception as e:
                    error_msg = f"Error mapping field {db_field} with template '{value_template}': {e}"
                    logger.error(error_msg)
                    mapping_errors.append(error_msg)

            # If nothing resolved, skip the DB round-trip entirely.
            if not mapped_fields:
                error_msg = "No fields mapped successfully, skipping database update"
                logger.warning(error_msg)
                logger.debug(f"Branch results available: {branch_results}")
                logger.debug(f"Field templates: {fields}")
                return ActionResult(
                    success=False,
                    action_type="postgresql_update_combined",
                    error=error_msg,
                    data={"mapping_errors": mapping_errors}
                )

            # Add updated_at field automatically
            # NOTE(review): the literal "NOW()" is presumably interpreted as
            # a SQL expression by db_manager.execute_update — confirm it is
            # not bound as a plain string parameter.
            mapped_fields["updated_at"] = "NOW()"

            # Execute the database update
            logger.info(f"Attempting database update with fields: {mapped_fields}")
            success = self.db_manager.execute_update(table, key_field, key_value, mapped_fields)

            if success:
                success_msg = f"Successfully updated database: {table} with {len(mapped_fields)} fields"
                logger.info(f"✅ {success_msg}")
                logger.info(f"Updated fields: {mapped_fields}")
                return ActionResult(
                    success=True,
                    action_type="postgresql_update_combined",
                    message=success_msg,
                    data={
                        "table": table,
                        "key_field": key_field,
                        "key_value": key_value,
                        "updated_fields": mapped_fields
                    }
                )
            else:
                error_msg = f"Failed to update database: {table}"
                logger.error(f"❌ {error_msg}")
                logger.error(f"Attempted update with: {key_field}={key_value}, fields={mapped_fields}")
                return ActionResult(
                    success=False,
                    action_type="postgresql_update_combined",
                    error=error_msg,
                    data={
                        "table": table,
                        "key_field": key_field,
                        "key_value": key_value,
                        "attempted_fields": mapped_fields
                    }
                )

        except KeyError as e:
            # A required key ("table", "key_field", ...) missing from the
            # action config.
            error_msg = f"Missing required field in postgresql_update_combined action: {e}"
            logger.error(error_msg)
            logger.debug(f"Action config: {action}")
            return ActionResult(
                success=False,
                action_type="postgresql_update_combined",
                error=error_msg
            )
        except Exception as e:
            error_msg = f"Error in postgresql_update_combined action: {e}"
            logger.error(error_msg)
            return ActionResult(
                success=False,
                action_type="postgresql_update_combined",
                error=error_msg
            )
|
||||
|
||||
|
||||
class ActionExecutor:
    """
    Main action execution engine for pipeline workflows.

    This class coordinates the execution of various action types including
    Redis operations, database operations, and parallel action coordination.
    """

    def __init__(self):
        """Initialize action executor."""
        # Per-client executor caches, keyed by the client/manager object
        # itself so each connection gets exactly one executor.
        self._redis_executors: Dict[Any, RedisActionExecutor] = {}
        self._db_executors: Dict[Any, DatabaseActionExecutor] = {}

    def _get_redis_executor(self, redis_client) -> RedisActionExecutor:
        """Get or create Redis executor for a client."""
        # NOTE(review): assumes redis_client is hashable — true for
        # redis-py clients, but confirm for any custom wrapper.
        if redis_client not in self._redis_executors:
            self._redis_executors[redis_client] = RedisActionExecutor(redis_client)
        return self._redis_executors[redis_client]

    def _get_db_executor(self, db_manager) -> DatabaseActionExecutor:
        """Get or create database executor for a manager."""
        if db_manager not in self._db_executors:
            self._db_executors[db_manager] = DatabaseActionExecutor(db_manager)
        return self._db_executors[db_manager]

    def _create_action_context(self, detection_result: Dict[str, Any]) -> ActionContext:
        """Create action context from detection result."""
        context = ActionContext()
        context.update_from_dict(detection_result)
        return context

    def execute_actions(self,
                       node: Dict[str, Any],
                       frame: np.ndarray,
                       detection_result: Dict[str, Any],
                       regions_dict: Optional[Dict[str, Any]] = None) -> List[ActionResult]:
        """
        Execute all actions for a node.

        Args:
            node: Pipeline node configuration
            frame: Input frame
            detection_result: Detection result data
            regions_dict: Optional region dictionary for cropping

        Returns:
            List of action results
        """
        # Nothing to do without a Redis client or configured actions.
        if not node.get("redis_client") or not node.get("actions"):
            return []

        results = []
        redis_executor = self._get_redis_executor(node["redis_client"])
        # One shared context per node run: redis_save_image writes
        # context.image_key, which a later redis_publish can reference.
        context = self._create_action_context(detection_result)

        for action in node["actions"]:
            try:
                action_type = action.get("type")

                if action_type == "redis_save_image":
                    result = redis_executor.execute_redis_save_image(
                        action, frame, context, regions_dict
                    )
                elif action_type == "redis_publish":
                    result = redis_executor.execute_redis_publish(action, context)
                else:
                    result = ActionResult(
                        success=False,
                        action_type=action_type or "unknown",
                        error=f"Unknown action type: {action_type}"
                    )

                results.append(result)

                # A failed action is logged but does not stop later actions.
                if not result.success:
                    logger.error(f"Action failed: {result.error}")

            except Exception as e:
                error_msg = f"Error executing action {action.get('type', 'unknown')}: {e}"
                logger.error(error_msg)
                results.append(ActionResult(
                    success=False,
                    action_type=action.get('type', 'unknown'),
                    error=error_msg
                ))

        return results

    def execute_parallel_actions(self,
                                node: Dict[str, Any],
                                frame: np.ndarray,
                                detection_result: Dict[str, Any],
                                regions_dict: Dict[str, Any]) -> List[ActionResult]:
        """
        Execute parallel actions after all required branches have completed.

        Args:
            node: Pipeline node configuration
            frame: Input frame
            detection_result: Detection result data with branch results
            regions_dict: Region dictionary

        Returns:
            List of action results
        """
        if not node.get("parallelActions"):
            return []

        results = []
        branch_results = detection_result.get("branch_results", {})

        logger.debug("Executing parallel actions...")

        for action in node["parallelActions"]:
            try:
                action_type = action.get("type")
                logger.debug(f"Processing parallel action: {action_type}")

                if action_type == "postgresql_update_combined":
                    # Check if all required branches have completed
                    wait_for_branches = action.get("waitForBranches", [])
                    missing_branches = [branch for branch in wait_for_branches if branch not in branch_results]

                    # Missing branches make this action fail, but remaining
                    # parallel actions still run.
                    if missing_branches:
                        error_msg = f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}"
                        logger.warning(error_msg)
                        results.append(ActionResult(
                            success=False,
                            action_type=action_type,
                            error=error_msg,
                            data={"missing_branches": missing_branches}
                        ))
                        continue

                    logger.info(f"All required branches completed: {wait_for_branches}")

                    # Execute the database update
                    if node.get("db_manager"):
                        db_executor = self._get_db_executor(node["db_manager"])
                        result = db_executor.execute_postgresql_update_combined(
                            action, detection_result, branch_results
                        )
                        results.append(result)
                    else:
                        results.append(ActionResult(
                            success=False,
                            action_type=action_type,
                            error="No database manager available"
                        ))
                else:
                    error_msg = f"Unknown parallel action type: {action_type}"
                    logger.warning(error_msg)
                    results.append(ActionResult(
                        success=False,
                        action_type=action_type or "unknown",
                        error=error_msg
                    ))

            except Exception as e:
                error_msg = f"Error executing parallel action {action.get('type', 'unknown')}: {e}"
                logger.error(error_msg)
                results.append(ActionResult(
                    success=False,
                    action_type=action.get('type', 'unknown'),
                    error=error_msg
                ))

        return results
|
||||
|
||||
|
||||
# Global action executor instance
# Module-level singleton used by the convenience wrappers so that the
# per-client Redis/DB executor caches are shared across calls.
action_executor = ActionExecutor()
|
||||
|
||||
|
||||
# ===== CONVENIENCE FUNCTIONS =====
|
||||
# These provide the same interface as the original functions in pympta.py
|
||||
|
||||
def execute_actions(node: Dict[str, Any],
                   frame: np.ndarray,
                   detection_result: Dict[str, Any],
                   regions_dict: Optional[Dict[str, Any]] = None) -> None:
    """Execute all actions for a node using global executor.

    Mirrors the original pympta.py interface: results are not returned,
    failures are only logged.
    """
    outcomes = action_executor.execute_actions(node, frame, detection_result, regions_dict)

    # Log any failures (maintain original behavior of not returning results)
    for outcome in outcomes:
        if outcome.success:
            continue
        logger.error(f"Action execution failed: {outcome.error}")
|
||||
|
||||
|
||||
def execute_parallel_actions(node: Dict[str, Any],
                            frame: np.ndarray,
                            detection_result: Dict[str, Any],
                            regions_dict: Dict[str, Any]) -> None:
    """Execute parallel actions using global executor.

    Mirrors the original pympta.py interface: results are not returned,
    failures are only logged.
    """
    outcomes = action_executor.execute_parallel_actions(node, frame, detection_result, regions_dict)

    # Log any failures (maintain original behavior of not returning results)
    for outcome in outcomes:
        if outcome.success:
            continue
        logger.error(f"Parallel action execution failed: {outcome.error}")
|
||||
|
||||
|
||||
def execute_postgresql_update_combined(node: Dict[str, Any],
                                      action: Dict[str, Any],
                                      detection_result: Dict[str, Any],
                                      branch_results: Dict[str, Any]) -> None:
    """Execute PostgreSQL update with combined branch results using global executor."""
    db_manager = node.get("db_manager")
    # Guard clause: without a manager there is nothing to execute.
    if not db_manager:
        logger.error("No database manager available for postgresql_update_combined action")
        return

    db_executor = action_executor._get_db_executor(db_manager)
    outcome = db_executor.execute_postgresql_update_combined(action, detection_result, branch_results)

    if not outcome.success:
        logger.error(f"PostgreSQL update failed: {outcome.error}")
|
341
detector_worker/pipeline/field_mapper.py
Normal file
341
detector_worker/pipeline/field_mapper.py
Normal file
|
@ -0,0 +1,341 @@
|
|||
"""
|
||||
Field mapping and template resolution for dynamic database operations.
|
||||
|
||||
This module provides functionality for resolving field mapping templates
|
||||
that reference branch results and context variables for database operations.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
|
||||
from ..core.exceptions import FieldMappingError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FieldMapper:
|
||||
"""
|
||||
Field mapping resolver for dynamic template resolution.
|
||||
|
||||
This class handles the resolution of field mapping templates that can reference:
|
||||
- Branch results (e.g., {car_brand_cls_v1.brand})
|
||||
- Context variables (e.g., {session_id})
|
||||
- Nested field lookups with fallback strategies
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize field mapper."""
|
||||
pass
|
||||
|
||||
def _extract_branch_references(self, template: str) -> List[str]:
|
||||
"""Extract branch references from template string."""
|
||||
# Match patterns like {model_id.field_name}
|
||||
branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', template)
|
||||
return branch_refs
|
||||
|
||||
def _resolve_simple_template(self, template: str, context: Dict[str, Any]) -> str:
|
||||
"""Resolve simple template without branch references."""
|
||||
try:
|
||||
result = template.format(**context)
|
||||
logger.debug(f"Simple template resolved: '{template}' -> '{result}'")
|
||||
return result
|
||||
except KeyError as e:
|
||||
logger.warning(f"Could not resolve context variable in simple template: {e}")
|
||||
return template
|
||||
|
||||
def _find_fallback_value(self,
|
||||
branch_data: Dict[str, Any],
|
||||
field_name: str,
|
||||
model_id: str) -> Optional[str]:
|
||||
"""Find fallback value using various strategies."""
|
||||
if not isinstance(branch_data, dict):
|
||||
logger.error(f"Branch data for '{model_id}' is not a dictionary: {type(branch_data)}")
|
||||
return None
|
||||
|
||||
# First, try the exact field name
|
||||
if field_name in branch_data:
|
||||
return branch_data[field_name]
|
||||
|
||||
# Then try 'class' field as fallback
|
||||
if 'class' in branch_data:
|
||||
fallback_value = branch_data['class']
|
||||
logger.info(f"Using 'class' field as fallback for '{field_name}': '{fallback_value}'")
|
||||
return fallback_value
|
||||
|
||||
# For brand models, check if the class name exists as a key
|
||||
if field_name == 'brand' and branch_data.get('class') in branch_data:
|
||||
fallback_value = branch_data[branch_data['class']]
|
||||
logger.info(f"Found brand value using class name as key: '{fallback_value}'")
|
||||
return fallback_value
|
||||
|
||||
# For body_type models, check if the class name exists as a key
|
||||
if field_name == 'body_type' and branch_data.get('class') in branch_data:
|
||||
fallback_value = branch_data[branch_data['class']]
|
||||
logger.info(f"Found body_type value using class name as key: '{fallback_value}'")
|
||||
return fallback_value
|
||||
|
||||
# Additional fallback strategies for common field mappings
|
||||
field_mappings = {
|
||||
'brand': ['car_brand', 'brand_name', 'detected_brand'],
|
||||
'body_type': ['car_body_type', 'bodytype', 'body', 'car_type'],
|
||||
'model': ['car_model', 'model_name', 'detected_model'],
|
||||
'color': ['car_color', 'color_name', 'detected_color']
|
||||
}
|
||||
|
||||
if field_name in field_mappings:
|
||||
for alternative in field_mappings[field_name]:
|
||||
if alternative in branch_data:
|
||||
fallback_value = branch_data[alternative]
|
||||
logger.info(f"Found '{field_name}' value using alternative field '{alternative}': '{fallback_value}'")
|
||||
return fallback_value
|
||||
|
||||
return None
|
||||
|
||||
def _resolve_branch_reference(self,
|
||||
ref: str,
|
||||
branch_results: Dict[str, Any]) -> Optional[str]:
|
||||
"""Resolve a single branch reference."""
|
||||
try:
|
||||
model_id, field_name = ref.split('.', 1)
|
||||
logger.debug(f"Processing branch reference: model_id='{model_id}', field_name='{field_name}'")
|
||||
|
||||
if model_id not in branch_results:
|
||||
logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}")
|
||||
return None
|
||||
|
||||
branch_data = branch_results[model_id]
|
||||
logger.debug(f"Branch '{model_id}' data: {branch_data}")
|
||||
|
||||
if field_name in branch_data:
|
||||
field_value = branch_data[field_name]
|
||||
logger.info(f"✅ Resolved {ref} to '{field_value}'")
|
||||
return str(field_value)
|
||||
else:
|
||||
logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results.")
|
||||
logger.debug(f"Available fields in '{model_id}': {list(branch_data.keys()) if isinstance(branch_data, dict) else 'N/A'}")
|
||||
|
||||
# Try fallback strategies
|
||||
fallback_value = self._find_fallback_value(branch_data, field_name, model_id)
|
||||
|
||||
if fallback_value is not None:
|
||||
logger.info(f"✅ Resolved {ref} to '{fallback_value}' (using fallback)")
|
||||
return str(fallback_value)
|
||||
else:
|
||||
logger.error(f"No suitable field found for '{field_name}' in branch '{model_id}'")
|
||||
logger.debug(f"Branch data structure: {branch_data}")
|
||||
return None
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid branch reference format: {ref}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error resolving branch reference '{ref}': {e}")
|
||||
return None
|
||||
|
||||
def resolve_field_mapping(self,
|
||||
value_template: str,
|
||||
branch_results: Dict[str, Any],
|
||||
action_context: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
Resolve field mapping templates like {car_brand_cls_v1.brand}.
|
||||
|
||||
Args:
|
||||
value_template: Template string with placeholders
|
||||
branch_results: Dictionary of branch execution results
|
||||
action_context: Context variables for template resolution
|
||||
|
||||
Returns:
|
||||
Resolved string value or None if resolution failed
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Resolving field mapping: '{value_template}'")
|
||||
logger.debug(f"Available branch results: {list(branch_results.keys())}")
|
||||
|
||||
# Handle simple context variables first (non-branch references)
|
||||
if '.' not in value_template:
|
||||
result = self._resolve_simple_template(value_template, action_context)
|
||||
return result
|
||||
|
||||
# Handle branch result references like {model_id.field}
|
||||
branch_refs = self._extract_branch_references(value_template)
|
||||
logger.debug(f"Found branch references: {branch_refs}")
|
||||
|
||||
resolved_template = value_template
|
||||
|
||||
for ref in branch_refs:
|
||||
resolved_value = self._resolve_branch_reference(ref, branch_results)
|
||||
|
||||
if resolved_value is not None:
|
||||
resolved_template = resolved_template.replace(f'{{{ref}}}', resolved_value)
|
||||
else:
|
||||
logger.error(f"Failed to resolve branch reference: {ref}")
|
||||
return None
|
||||
|
||||
# Format any remaining simple variables
|
||||
try:
|
||||
final_value = resolved_template.format(**action_context)
|
||||
logger.debug(f"Final resolved value: '{final_value}'")
|
||||
return final_value
|
||||
except KeyError as e:
|
||||
logger.warning(f"Could not resolve context variable in template: {e}")
|
||||
return resolved_template
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error resolving field mapping '{value_template}': {e}")
|
||||
return None
|
||||
|
||||
def resolve_multiple_fields(self,
                            field_templates: Dict[str, str],
                            branch_results: Dict[str, Any],
                            action_context: Dict[str, Any]) -> Dict[str, str]:
    """
    Resolve multiple field mappings at once.

    Fields whose template cannot be resolved are logged and omitted
    from the result rather than aborting the batch.

    Args:
        field_templates: Dictionary mapping field names to templates
        branch_results: Dictionary of branch execution results
        action_context: Context variables for template resolution

    Returns:
        Dictionary mapping field names to resolved values
    """
    resolved: Dict[str, str] = {}

    for name, template in field_templates.items():
        try:
            value = self.resolve_field_mapping(template, branch_results, action_context)
            if value is None:
                logger.warning(f"Failed to resolve field '{name}' with template: {template}")
            else:
                resolved[name] = value
                logger.debug(f"Successfully resolved field '{name}': {value}")
        except Exception as e:
            # One bad field must not prevent the others from resolving.
            logger.error(f"Error resolving field '{name}' with template '{template}': {e}")

    return resolved
|
||||
|
||||
def validate_template(self, template: str) -> Dict[str, Any]:
    """
    Validate a field mapping template and return analysis.

    Placeholders containing a dot are treated as branch references
    ({model_id.field}); all others are context references.

    Args:
        template: Template string to validate

    Returns:
        Dictionary with validation results and analysis
    """
    report: Dict[str, Any] = {
        "valid": True,
        "errors": [],
        "warnings": [],
        "branch_references": [],
        "context_references": [],
        "has_branch_refs": False,
        "has_context_refs": False,
    }

    try:
        # Walk every {placeholder} and classify it.
        for placeholder in re.findall(r'\{([^}]+)\}', template):
            if '.' not in placeholder:
                report["context_references"].append(placeholder)
                report["has_context_refs"] = True
                continue

            report["branch_references"].append(placeholder)
            report["has_branch_refs"] = True

            # Branch references must be exactly model_id.field_name.
            pieces = placeholder.split('.')
            if len(pieces) != 2:
                report["errors"].append(f"Invalid branch reference format: {placeholder}")
                report["valid"] = False
            elif not (pieces[0] and pieces[1]):
                report["errors"].append(f"Empty model_id or field_name in reference: {placeholder}")
                report["valid"] = False

        # Advisory checks that do not affect validity.
        if report["has_branch_refs"] and not report["has_context_refs"]:
            report["warnings"].append("Template only uses branch references, consider adding context info")
        if not report["branch_references"] and not report["context_references"]:
            report["warnings"].append("Template has no placeholders - it's a static value")

    except Exception as e:
        report["valid"] = False
        report["errors"].append(f"Template analysis failed: {e}")

    return report
|
||||
|
||||
|
||||
# Global field mapper instance used by the module-level convenience
# functions; import this directly for access to the full FieldMapper API.
field_mapper = FieldMapper()
|
||||
|
||||
|
||||
# ===== CONVENIENCE FUNCTIONS =====
|
||||
# These provide the same interface as the original functions in pympta.py
|
||||
|
||||
def resolve_field_mapping(value_template: str,
                          branch_results: Dict[str, Any],
                          action_context: Dict[str, Any]) -> Optional[str]:
    """
    Resolve field mapping templates like {car_brand_cls_v1.brand}.

    Thin wrapper delegating to the module-level ``field_mapper`` instance,
    kept for interface compatibility with the original pympta.py functions.
    """
    return field_mapper.resolve_field_mapping(
        value_template, branch_results, action_context
    )
|
||||
|
||||
|
||||
def resolve_multiple_field_mappings(field_templates: Dict[str, str],
                                    branch_results: Dict[str, Any],
                                    action_context: Dict[str, Any]) -> Dict[str, str]:
    """
    Resolve multiple field mappings at once.

    Thin wrapper delegating to the module-level ``field_mapper`` instance,
    kept for interface compatibility with the original pympta.py functions.
    """
    return field_mapper.resolve_multiple_fields(
        field_templates, branch_results, action_context
    )
|
||||
|
||||
|
||||
def validate_field_mapping_template(template: str) -> Dict[str, Any]:
    """
    Validate a field mapping template and return analysis.

    Thin wrapper delegating to the module-level ``field_mapper`` instance,
    kept for interface compatibility with the original pympta.py functions.
    """
    return field_mapper.validate_template(template)
|
||||
|
||||
|
||||
def get_available_field_mappings(branch_results: Dict[str, Any]) -> Dict[str, List[str]]:
    """
    Get available field mappings from branch results.

    Args:
        branch_results: Dictionary of branch execution results

    Returns:
        Dictionary mapping model IDs to available field names
    """
    mappings: Dict[str, List[str]] = {}

    for model_id, data in branch_results.items():
        if not isinstance(data, dict):
            # Non-dict branch data yields no usable fields.
            logger.warning(f"Branch '{model_id}' data is not a dictionary: {type(data)}")
            mappings[model_id] = []
            continue
        mappings[model_id] = list(data)

    return mappings
|
||||
|
||||
|
||||
def create_field_mapping_examples(branch_results: Dict[str, Any]) -> List[str]:
    """
    Create example field mapping templates based on available branch results.

    Args:
        branch_results: Dictionary of branch execution results

    Returns:
        List of example template strings
    """
    # One "{model_id.field}" example per field of every dict-valued branch;
    # non-dict branch data is skipped silently.
    return [
        f"{{{model_id}.{field_name}}}"
        for model_id, branch_data in branch_results.items()
        if isinstance(branch_data, dict)
        for field_name in branch_data
    ]
|
Loading…
Add table
Add a link
Reference in a new issue