"""
|
|
Parallel Branch Processing Module.
|
|
Handles concurrent execution of classification branches and result synchronization.
|
|
"""
|
|
import logging
|
|
import asyncio
|
|
import time
|
|
from typing import Dict, List, Optional, Any, Tuple
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
import numpy as np
|
|
import cv2
|
|
|
|
from ..models.inference import YOLOWrapper
|
|
|
|
logger = logging.getLogger(__name__)
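
# Example usage sketch (illustrative only; ``model_manager``, ``pipeline_config``,
# ``redis_manager``, and ``db_manager`` are assumed to be constructed elsewhere,
# and the calls are made from within an async context):
#
#   processor = BranchProcessor(model_manager)
#   ok = await processor.initialize(pipeline_config, redis_manager, db_manager)
#   results = await processor.execute_branches(frame, pipeline_config.branches,
#                                              detected_regions, detection_context)
#   processor.cleanup()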


class BranchProcessor:
    """
    Handles parallel processing of classification branches.
    Manages branch synchronization and result collection.
    """

    def __init__(self, model_manager: Any):
        """
        Initialize branch processor.

        Args:
            model_manager: Model manager for loading models
        """
        self.model_manager = model_manager

        # Branch models cache
        self.branch_models: Dict[str, YOLOWrapper] = {}

        # Thread pool for parallel execution
        self.executor = ThreadPoolExecutor(max_workers=4)

        # Storage managers (set during initialization)
        self.redis_manager = None
        self.db_manager = None

        # Statistics
        self.stats = {
            'branches_processed': 0,
            'parallel_executions': 0,
            'total_processing_time': 0.0,
            'models_loaded': 0
        }

        logger.info("BranchProcessor initialized")

    async def initialize(self, pipeline_config: Any, redis_manager: Any, db_manager: Any) -> bool:
        """
        Initialize branch processor with pipeline configuration.

        Args:
            pipeline_config: Pipeline configuration object
            redis_manager: Redis manager instance
            db_manager: Database manager instance

        Returns:
            True if successful, False otherwise
        """
        try:
            self.redis_manager = redis_manager
            self.db_manager = db_manager

            # Pre-load branch models if any are configured
            branches = getattr(pipeline_config, 'branches', [])
            if branches:
                await self._preload_branch_models(branches)

            logger.info(f"BranchProcessor initialized with {len(self.branch_models)} models")
            return True

        except Exception as e:
            logger.error(f"Error initializing branch processor: {e}", exc_info=True)
            return False

    async def _preload_branch_models(self, branches: List[Any]) -> None:
        """
        Pre-load all branch models for faster execution.

        Args:
            branches: List of branch configurations
        """
        for branch in branches:
            try:
                await self._load_branch_model(branch)

                # Recursively load nested branches
                nested_branches = getattr(branch, 'branches', [])
                if nested_branches:
                    await self._preload_branch_models(nested_branches)

            except Exception as e:
                logger.error(f"Error preloading branch model {getattr(branch, 'model_id', 'unknown')}: {e}")

    async def _load_branch_model(self, branch_config: Any) -> Optional[YOLOWrapper]:
        """
        Load a branch model if not already loaded.

        Args:
            branch_config: Branch configuration object

        Returns:
            Loaded YOLO model wrapper or None
        """
        try:
            model_id = getattr(branch_config, 'model_id', None)
            model_file = getattr(branch_config, 'model_file', None)

            if not model_id or not model_file:
                logger.warning(f"Invalid branch config: model_id={model_id}, model_file={model_file}")
                return None

            # Check if model is already loaded
            if model_id in self.branch_models:
                logger.debug(f"Branch model {model_id} already loaded")
                return self.branch_models[model_id]

            logger.info(f"Loading branch model: {model_id} ({model_file})")

            # Get the first available model ID from the ModelManager; the branch's
            # model_file is loaded through that pipeline model
            pipeline_models = list(self.model_manager.get_all_downloaded_models())
            if pipeline_models:
                actual_model_id = pipeline_models[0]  # Use the first available model
                model = self.model_manager.get_yolo_model(actual_model_id, model_file)

                if model:
                    self.branch_models[model_id] = model
                    self.stats['models_loaded'] += 1
                    logger.info(f"Branch model {model_id} loaded successfully")
                    return model
                else:
                    logger.error(f"Failed to load branch model {model_id}")
                    return None
            else:
                logger.error("No models available in ModelManager for branch loading")
                return None

        except Exception as e:
            logger.error(f"Error loading branch model {getattr(branch_config, 'model_id', 'unknown')}: {e}")
            return None

    async def execute_branches(self,
                               frame: np.ndarray,
                               branches: List[Any],
                               detected_regions: Dict[str, Any],
                               detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute all branches in parallel and collect results.

        Args:
            frame: Input frame
            branches: List of branch configurations
            detected_regions: Dictionary of detected regions from main detection
            detection_context: Detection context data

        Returns:
            Dictionary with branch execution results
        """
        start_time = time.time()
        branch_results = {}

        try:
            # Separate parallel and sequential branches
            parallel_branches = []
            sequential_branches = []

            for branch in branches:
                if getattr(branch, 'parallel', False):
                    parallel_branches.append(branch)
                else:
                    sequential_branches.append(branch)

            # Execute parallel branches concurrently
            if parallel_branches:
                logger.info(f"Executing {len(parallel_branches)} branches in parallel")
                parallel_results = await self._execute_parallel_branches(
                    frame, parallel_branches, detected_regions, detection_context
                )
                branch_results.update(parallel_results)
                self.stats['parallel_executions'] += 1

            # Execute sequential branches one by one
            if sequential_branches:
                logger.info(f"Executing {len(sequential_branches)} branches sequentially")
                sequential_results = await self._execute_sequential_branches(
                    frame, sequential_branches, detected_regions, detection_context
                )
                branch_results.update(sequential_results)

            # Update statistics
            self.stats['branches_processed'] += len(branches)
            processing_time = time.time() - start_time
            self.stats['total_processing_time'] += processing_time

            logger.info(f"Branch execution completed in {processing_time:.3f}s with {len(branch_results)} results")

        except Exception as e:
            logger.error(f"Error in branch execution: {e}", exc_info=True)

        return branch_results

    async def _execute_parallel_branches(self,
                                         frame: np.ndarray,
                                         branches: List[Any],
                                         detected_regions: Dict[str, Any],
                                         detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute branches in parallel using the ThreadPoolExecutor.

        Args:
            frame: Input frame
            branches: List of parallel branch configurations
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with parallel branch results
        """
        results = {}

        # Submit all branches to the shared thread pool and await them without
        # blocking the event loop (concurrent.futures.as_completed would block it)
        loop = asyncio.get_running_loop()
        branch_futures = []

        for branch in branches:
            branch_id = getattr(branch, 'model_id', 'unknown')
            logger.info(f"[PARALLEL SUBMIT] {branch_id}: Submitting branch to thread pool")

            future = loop.run_in_executor(
                self.executor,
                self._execute_single_branch_sync,
                frame, branch, detected_regions, detection_context
            )
            branch_futures.append((branch_id, future))

        # Await each branch and collect its result
        for branch_id, future in branch_futures:
            try:
                result = await future
                results[branch_id] = result
                logger.info(f"[PARALLEL COMPLETE] {branch_id}: Branch completed successfully")
            except Exception as e:
                logger.error(f"Error in parallel branch {branch_id}: {e}")
                results[branch_id] = {
                    'status': 'error',
                    'message': str(e),
                    'processing_time': 0.0
                }

        # Flatten nested branch results to the top level for database access
        return self._flatten_branch_results(results)

    async def _execute_sequential_branches(self,
                                           frame: np.ndarray,
                                           branches: List[Any],
                                           detected_regions: Dict[str, Any],
                                           detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute branches sequentially.

        Args:
            frame: Input frame
            branches: List of sequential branch configurations
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with sequential branch results
        """
        results = {}

        for branch in branches:
            branch_id = getattr(branch, 'model_id', 'unknown')

            try:
                result = await asyncio.get_running_loop().run_in_executor(
                    self.executor,
                    self._execute_single_branch_sync,
                    frame, branch, detected_regions, detection_context
                )
                results[branch_id] = result
                logger.debug(f"Sequential branch {branch_id} completed successfully")
            except Exception as e:
                logger.error(f"Error in sequential branch {branch_id}: {e}")
                results[branch_id] = {
                    'status': 'error',
                    'message': str(e),
                    'processing_time': 0.0
                }

        # Flatten nested branch results to the top level for database access
        return self._flatten_branch_results(results)

    def _flatten_branch_results(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """
        Flatten nested branch results to the top level for database access.

        Args:
            results: Mapping of branch_id to branch result dictionaries

        Returns:
            Dictionary with nested branch results promoted to the top level
        """
        flattened_results = {}
        for branch_id, branch_result in results.items():
            # Add the branch result itself
            flattened_results[branch_id] = branch_result

            # If this branch has nested branches, add them to the top level too
            if isinstance(branch_result, dict) and 'nested_branches' in branch_result:
                for nested_branch_id, nested_result in branch_result['nested_branches'].items():
                    flattened_results[nested_branch_id] = nested_result
                    logger.info(f"[FLATTEN] Added nested branch {nested_branch_id} to top-level results")
        return flattened_results
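
    # Flattening sketch (hypothetical branch IDs): an input of
    #   {'car_brand_cls': {'status': 'success',
    #                      'nested_branches': {'car_bodytype_cls': {...}}}}
    # becomes
    #   {'car_brand_cls': {...}, 'car_bodytype_cls': {...}},
    # so database code can address every branch result by a single key.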

    def _execute_single_branch_sync(self,
                                    frame: np.ndarray,
                                    branch_config: Any,
                                    detected_regions: Dict[str, Any],
                                    detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Synchronous execution of a single branch (for ThreadPoolExecutor).

        Args:
            frame: Input frame
            branch_config: Branch configuration object
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with branch execution result
        """
        start_time = time.time()
        branch_id = getattr(branch_config, 'model_id', 'unknown')

        logger.info(f"[BRANCH START] {branch_id}: Starting branch execution")
        logger.debug(f"[BRANCH CONFIG] {branch_id}: crop={getattr(branch_config, 'crop', False)}, "
                     f"trigger_classes={getattr(branch_config, 'trigger_classes', [])}, "
                     f"min_confidence={getattr(branch_config, 'min_confidence', 0.6)}")

        # Check if this branch should execute, based on triggerClasses (execution conditions)
        trigger_classes = getattr(branch_config, 'trigger_classes', [])
        logger.info(f"[DETECTED REGIONS] {branch_id}: Available parent detections: {list(detected_regions.keys())}")
        for region_name, region_data in detected_regions.items():
            logger.debug(f"[REGION DATA] {branch_id}: '{region_name}' -> bbox={region_data.get('bbox')}, conf={region_data.get('confidence')}")

        if trigger_classes:
            # Check if any parent detection matches our trigger classes
            should_execute = False
            for trigger_class in trigger_classes:
                if trigger_class in detected_regions:
                    should_execute = True
                    logger.info(f"[TRIGGER CHECK] {branch_id}: Found '{trigger_class}' in parent detections - branch will execute")
                    break

            if not should_execute:
                logger.warning(f"[TRIGGER CHECK] {branch_id}: None of trigger classes {trigger_classes} found in parent detections {list(detected_regions.keys())} - skipping branch")
                return {
                    'status': 'skipped',
                    'branch_id': branch_id,
                    'message': f'No trigger classes {trigger_classes} found in parent detections',
                    'processing_time': time.time() - start_time
                }
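
        # Gating example (hypothetical config): a branch with trigger_classes=['car']
        # runs only when the parent stage produced a 'car' region; with no trigger
        # classes configured, the branch always runs.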

        result = {
            'status': 'success',
            'branch_id': branch_id,
            'result': {},
            'processing_time': 0.0,
            'timestamp': time.time()
        }

        try:
            # Models are preloaded during initialization, so a cache miss here is
            # rare and cannot be repaired from this synchronous worker thread
            if branch_id not in self.branch_models:
                logger.warning(f"Branch model {branch_id} was not preloaded and is unavailable")
                return {
                    'status': 'error',
                    'message': f'Branch model {branch_id} not available',
                    'processing_time': time.time() - start_time
                }

            model = self.branch_models[branch_id]

            # Get configuration values first
            min_confidence = getattr(branch_config, 'min_confidence', 0.6)

            # Prepare input frame for this branch
            input_frame = frame

            # Handle cropping if required - use the biggest bbox that passes min_confidence
            if getattr(branch_config, 'crop', False):
                crop_classes = getattr(branch_config, 'crop_class', [])
                if isinstance(crop_classes, str):
                    crop_classes = [crop_classes]

                # Find the biggest bbox that passes the min_confidence threshold
                best_region = None
                best_class = None
                best_area = 0.0

                for crop_class in crop_classes:
                    if crop_class in detected_regions:
                        region = detected_regions[crop_class]
                        confidence = region.get('confidence', 0.0)

                        # Only use detections above min_confidence
                        if confidence >= min_confidence:
                            bbox = region['bbox']
                            area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])  # width * height

                            # Choose the biggest bbox among valid detections
                            if area > best_area:
                                best_region = region
                                best_class = crop_class
                                best_area = area

                if best_region:
                    bbox = best_region['bbox']
                    x1, y1, x2, y2 = [int(coord) for coord in bbox]
                    cropped = frame[y1:y2, x1:x2]
                    if cropped.size > 0:
                        input_frame = cropped
                        confidence = best_region.get('confidence', 0.0)
                        logger.info(f"[CROP SUCCESS] {branch_id}: cropped '{best_class}' region (conf={confidence:.3f}, area={int(best_area)}) -> shape={cropped.shape}")
                    else:
                        logger.warning(f"Branch {branch_id}: empty crop, using full frame")
                else:
                    logger.warning(f"Branch {branch_id}: no valid crop regions found (min_conf={min_confidence})")

            logger.info(f"[INFERENCE START] {branch_id}: Running inference on {'cropped' if input_frame is not frame else 'full'} frame "
                        f"({input_frame.shape[1]}x{input_frame.shape[0]}) with confidence={min_confidence}")

            # Use .predict() for both detection and classification models
            inference_start = time.time()
            detection_results = model.model.predict(input_frame, conf=min_confidence, verbose=False)
            inference_time = time.time() - inference_start
            logger.info(f"[INFERENCE DONE] {branch_id}: .predict() completed in {inference_time:.3f}s")

            # Initialize branch_detections outside the conditional
            branch_detections = []

            # Process results using clean, unified logic
            if detection_results and len(detection_results) > 0:
                result_obj = detection_results[0]

                # Handle detection models (have a .boxes attribute)
                if hasattr(result_obj, 'boxes') and result_obj.boxes is not None:
                    logger.info(f"[RAW DETECTIONS] {branch_id}: Found {len(result_obj.boxes)} raw detections")

                    for i, box in enumerate(result_obj.boxes):
                        class_id = int(box.cls[0])
                        confidence = float(box.conf[0])
                        bbox = box.xyxy[0].cpu().numpy().tolist()  # [x1, y1, x2, y2]
                        class_name = model.model.names[class_id]

                        logger.debug(f"[RAW DETECTION {i+1}] {branch_id}: '{class_name}', conf={confidence:.3f}")

                        # All detections are included - no filtering by trigger_classes here
                        branch_detections.append({
                            'class_name': class_name,
                            'confidence': confidence,
                            'bbox': bbox
                        })

                # Handle classification models (have a .probs attribute)
                elif hasattr(result_obj, 'probs') and result_obj.probs is not None:
                    logger.info(f"[RAW CLASSIFICATION] {branch_id}: Processing classification results")

                    probs = result_obj.probs
                    top_indices = probs.top5  # Get top 5 predictions
                    top_conf = probs.top5conf.cpu().numpy()

                    for idx, conf in zip(top_indices, top_conf):
                        if conf >= min_confidence:
                            class_name = model.model.names[int(idx)]
                            logger.debug(f"[CLASSIFICATION RESULT {len(branch_detections)+1}] {branch_id}: '{class_name}', conf={conf:.3f}")

                            # For classification, use the full input frame dimensions as the bbox
                            branch_detections.append({
                                'class_name': class_name,
                                'confidence': float(conf),
                                'bbox': [0, 0, input_frame.shape[1], input_frame.shape[0]]
                            })
                else:
                    logger.warning(f"[UNKNOWN MODEL] {branch_id}: Model results have no .boxes or .probs")

            result['result'] = {
                'detections': branch_detections,
                'detection_count': len(branch_detections)
            }
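
            # Result shape sketch (values illustrative): detection models yield one
            # entry per box, e.g. {'class_name': 'car', 'confidence': 0.91,
            # 'bbox': [x1, y1, x2, y2]}; classification models yield the top-5
            # entries above min_confidence with the full input frame as the bbox.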

            logger.info(f"[FINAL RESULTS] {branch_id}: {len(branch_detections)} detections processed")

            # Extract the best result for classification models
            if branch_detections:
                best_detection = max(branch_detections, key=lambda x: x['confidence'])
                logger.info(f"[BEST DETECTION] {branch_id}: '{best_detection['class_name']}' with confidence {best_detection['confidence']:.3f}")

                # Add classification-style results for database operations
                if 'brand' in branch_id.lower():
                    result['result']['brand'] = best_detection['class_name']
                elif 'body' in branch_id.lower() or 'bodytype' in branch_id.lower():
                    result['result']['body_type'] = best_detection['class_name']
                elif 'front_rear' in branch_id.lower():
                    result['result']['front_rear'] = best_detection['confidence']

                logger.info(f"[CLASSIFICATION RESULT] {branch_id}: Extracted classification fields")
            else:
                logger.warning(f"[NO RESULTS] {branch_id}: No detections found")

            # Execute branch actions if this branch found valid detections
            actions_executed = []
            branch_actions = getattr(branch_config, 'actions', [])
            if branch_actions and branch_detections:
                logger.info(f"[BRANCH ACTIONS] {branch_id}: Executing {len(branch_actions)} actions")

                # Build detected_regions from THIS branch's detections for its actions
                branch_detected_regions = {}
                for detection in branch_detections:
                    branch_detected_regions[detection['class_name']] = {
                        'bbox': detection['bbox'],
                        'confidence': detection['confidence']
                    }

                for action in branch_actions:
                    try:
                        action_type = action.type.value  # Access the enum value
                        logger.info(f"[ACTION EXECUTE] {branch_id}: Executing action '{action_type}'")

                        if action_type == 'redis_save_image':
                            action_result = self._execute_redis_save_image_sync(
                                action, input_frame, branch_detected_regions, detection_context
                            )
                        elif action_type == 'redis_publish':
                            action_result = self._execute_redis_publish_sync(
                                action, detection_context
                            )
                        else:
                            logger.warning(f"[ACTION UNKNOWN] {branch_id}: Unknown action type '{action_type}'")
                            action_result = {'status': 'error', 'message': f'Unknown action type: {action_type}'}

                        actions_executed.append({
                            'action_type': action_type,
                            'result': action_result
                        })

                        logger.info(f"[ACTION COMPLETE] {branch_id}: Action '{action_type}' result: {action_result.get('status')}")

                    except Exception as e:
                        action_type = getattr(action, 'type', None)
                        if action_type:
                            action_type = action_type.value if hasattr(action_type, 'value') else str(action_type)
                        logger.error(f"[ACTION ERROR] {branch_id}: Error executing action '{action_type}': {e}", exc_info=True)
                        actions_executed.append({
                            'action_type': action_type,
                            'result': {'status': 'error', 'message': str(e)}
                        })

            # Record executed actions on the result
            if actions_executed:
                result['actions_executed'] = actions_executed

            # Handle nested branches ONLY if the parent found valid detections
            nested_branches = getattr(branch_config, 'branches', [])
            if nested_branches:
                if not branch_detections:
                    logger.warning(f"[BRANCH SKIP] {branch_id}: Skipping {len(nested_branches)} nested branches - parent found no valid detections")
                else:
                    logger.debug(f"Branch {branch_id}: executing {len(nested_branches)} nested branches")

                    # Build detected_regions from THIS branch's detections: nested branches
                    # should see their immediate parent's detections, not the root pipeline's
                    nested_detected_regions = {}
                    for detection in branch_detections:
                        nested_detected_regions[detection['class_name']] = {
                            'bbox': detection['bbox'],
                            'confidence': detection['confidence']
                        }

                    logger.info(f"[NESTED REGIONS] {branch_id}: Passing {list(nested_detected_regions.keys())} to nested branches")

                    # Note: for simplicity, nested branches run sequentially in this sync
                    # method; a fully async implementation could parallelize them as well
                    nested_results = {}
                    for nested_branch in nested_branches:
                        nested_result = self._execute_single_branch_sync(
                            input_frame, nested_branch, nested_detected_regions, detection_context
                        )
                        nested_branch_id = getattr(nested_branch, 'model_id', 'unknown')
                        nested_results[nested_branch_id] = nested_result

                    result['nested_branches'] = nested_results

        except Exception as e:
            logger.error(f"[BRANCH ERROR] {branch_id}: Error in execution: {e}", exc_info=True)
            result['status'] = 'error'
            result['message'] = str(e)

        result['processing_time'] = time.time() - start_time

        # Summary log
        logger.info(f"[BRANCH COMPLETE] {branch_id}: status={result['status']}, "
                    f"processing_time={result['processing_time']:.3f}s, "
                    f"result_keys={list(result['result'].keys()) if result['result'] else 'none'}")

        return result

    def _execute_redis_save_image_sync(self,
                                       action: Any,
                                       frame: np.ndarray,
                                       detected_regions: Dict[str, Any],
                                       context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a redis_save_image action synchronously."""
        if not self.redis_manager:
            return {'status': 'error', 'message': 'Redis not available'}

        try:
            # Pick the image to save (cropped region or full frame)
            image_to_save = frame
            region_name = action.params.get('region')

            bbox = None
            if region_name and region_name in detected_regions:
                # Crop the specified region
                bbox = detected_regions[region_name]['bbox']
            elif region_name and region_name.lower() == 'frontal' and 'front_rear' in detected_regions:
                # Special case: the "frontal" region maps to the "front_rear" detection
                bbox = detected_regions['front_rear']['bbox']

            if bbox is not None:
                x1, y1, x2, y2 = [int(coord) for coord in bbox]
                cropped = frame[y1:y2, x1:x2]
                if cropped.size > 0:
                    image_to_save = cropped
                    logger.debug(f"Cropped region '{region_name}' for redis_save_image")
                else:
                    logger.warning(f"Empty crop for region '{region_name}', using full frame")

            # Format the key with context
            key = action.params['key'].format(**context)

            # Encode the image once, in the requested format
            image_format = action.params.get('format', 'jpeg')
            quality = action.params.get('quality', 90)

            if image_format.lower() == 'jpeg':
                encode_param = [cv2.IMWRITE_JPEG_QUALITY, quality]
                success, encoded_image = cv2.imencode('.jpg', image_to_save, encode_param)
            else:
                success, encoded_image = cv2.imencode('.png', image_to_save)

            if not success:
                return {'status': 'error', 'message': 'Failed to encode image'}

            # Save to Redis from this worker thread using a synchronous client
            try:
                import redis

                # Create a synchronous Redis client with the same connection details
                sync_redis = redis.Redis(
                    host=self.redis_manager.host,
                    port=self.redis_manager.port,
                    password=self.redis_manager.password,
                    db=self.redis_manager.db,
                    decode_responses=False,  # We're storing binary data
                    socket_timeout=self.redis_manager.socket_timeout,
                    socket_connect_timeout=self.redis_manager.socket_connect_timeout
                )

                # Save to Redis with expiration
                expire_seconds = action.params.get('expire_seconds', 600)
                result = sync_redis.setex(key, expire_seconds, encoded_image.tobytes())

                sync_redis.close()  # Clean up the connection

                if result:
                    # Expose the image key to subsequent actions via the context
                    context['image_key'] = key
                    return {'status': 'success', 'key': key}
                else:
                    return {'status': 'error', 'message': 'Failed to save image to Redis'}

            except Exception as redis_error:
                logger.error(f"Error calling Redis from sync context: {redis_error}")
                return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'}

        except Exception as e:
            logger.error(f"Error in redis_save_image action: {e}", exc_info=True)
            return {'status': 'error', 'message': str(e)}
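
    # redis_save_image sketch (parameter values hypothetical): an action with
    #   params={'key': 'inprogress:{session_id}:front', 'region': 'front_rear',
    #           'format': 'jpeg', 'quality': 90, 'expire_seconds': 600}
    # crops the matched region, JPEG-encodes it, and stores it under the
    # formatted key with a TTL, exposing the key as context['image_key'].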

    def _execute_redis_publish_sync(self, action: Any, context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a redis_publish action synchronously."""
        if not self.redis_manager:
            return {'status': 'error', 'message': 'Redis not available'}

        try:
            channel = action.params['channel']
            message_template = action.params['message']

            # Debug the message template
            logger.debug(f"Message template: {repr(message_template)}")
            logger.debug(f"Context keys: {list(context.keys())}")

            # Format the message with context values. The template contains literal
            # JSON, so .format() would trip over the braces; use plain string
            # replacement instead.
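            # e.g. (illustrative template) '{"key": "{image_key}"}'.format(**context)
            # raises KeyError, because str.format treats the literal JSON braces
            # as replacement fields.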
            try:
                # Ensure image_key is available for message formatting
                if 'image_key' not in context:
                    context['image_key'] = ''  # Default empty value if redis_save_image failed

                # Use string replacement to avoid JSON brace conflicts
                message = message_template
                for key, value in context.items():
                    placeholder = '{' + key + '}'
                    message = message.replace(placeholder, str(value))

                logger.debug(f"Formatted message using replacement: {message}")
            except Exception as e:
                logger.error(f"Message formatting failed: {e}")
                logger.error(f"Template: {repr(message_template)}")
                logger.error(f"Context: {context}")
                return {'status': 'error', 'message': f'Message formatting failed: {e}'}

            # Publish the message from this worker thread using a synchronous client
            try:
                import redis

                # Create a synchronous Redis client with the same connection details
                sync_redis = redis.Redis(
                    host=self.redis_manager.host,
                    port=self.redis_manager.port,
                    password=self.redis_manager.password,
                    db=self.redis_manager.db,
                    decode_responses=True,  # For publishing text messages
                    socket_timeout=self.redis_manager.socket_timeout,
                    socket_connect_timeout=self.redis_manager.socket_connect_timeout
                )

                # publish() returns the number of subscribers that received the
                # message; zero subscribers is still a successful publish
                subscriber_count = sync_redis.publish(channel, message)
                sync_redis.close()  # Clean up the connection

                return {'status': 'success', 'subscribers': subscriber_count, 'channel': channel}

            except Exception as redis_error:
                logger.error(f"Error calling Redis from sync context: {redis_error}")
                return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'}

        except Exception as e:
            logger.error(f"Error in redis_publish action: {e}", exc_info=True)
            return {'status': 'error', 'message': str(e)}

    def get_statistics(self) -> Dict[str, Any]:
        """Get branch processor statistics."""
        return {
            **self.stats,
            'loaded_models': list(self.branch_models.keys()),
            'model_count': len(self.branch_models)
        }
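
    # get_statistics() sketch (all values illustrative): {'branches_processed': 12,
    # 'parallel_executions': 4, 'total_processing_time': 3.81, 'models_loaded': 2,
    # 'loaded_models': ['car_brand_cls'], 'model_count': 1}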

    def cleanup(self):
        """Cleanup resources."""
        if self.executor:
            self.executor.shutdown(wait=False)

        # Clear the model cache
        self.branch_models.clear()

        logger.info("BranchProcessor cleaned up")