python-detector-worker/core/detection/branches.py
ziesorx 2e5316ca01
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 8s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m44s
Build Worker Base and Application Images / deploy-stack (push) Successful in 9s
fix: model calling method
2025-09-25 15:06:41 +07:00

796 lines
No EOL
35 KiB
Python

"""
Parallel Branch Processing Module.
Handles concurrent execution of classification branches and result synchronization.
"""
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import cv2
from ..models.inference import YOLOWrapper
logger = logging.getLogger(__name__)
class BranchProcessor:
    """
    Handles parallel processing of classification branches.

    Each branch is a secondary model that runs on the frame (or on a crop of a
    parent detection), may execute Redis actions, and may own nested branches.
    Results are collected into one flat dictionary keyed by branch ``model_id``.
    """

    def __init__(self, model_manager: Any, model_id: int, max_workers: int = 4):
        """
        Initialize branch processor.

        Args:
            model_manager: Model manager for loading models
            model_id: The model ID to use for loading models
            max_workers: Size of the thread pool that runs branch inference
        """
        self.model_manager = model_manager
        self.model_id = model_id

        # Branch models cache, keyed by each branch's own model_id
        self.branch_models: Dict[str, "YOLOWrapper"] = {}

        # Thread pool for the blocking inference / Redis calls
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

        # Storage managers (set during initialization)
        self.redis_manager = None
        self.db_manager = None

        # Statistics
        self.stats = {
            'branches_processed': 0,
            'parallel_executions': 0,
            'total_processing_time': 0.0,
            'models_loaded': 0
        }

        logger.info("BranchProcessor initialized")

    async def initialize(self, pipeline_config: Any, redis_manager: Any, db_manager: Any) -> bool:
        """
        Initialize branch processor with pipeline configuration.

        Args:
            pipeline_config: Pipeline configuration object
            redis_manager: Redis manager instance
            db_manager: Database manager instance

        Returns:
            True if successful, False otherwise
        """
        try:
            self.redis_manager = redis_manager
            self.db_manager = db_manager

            # Pre-load branch models if they exist
            branches = getattr(pipeline_config, 'branches', [])
            if branches:
                await self._preload_branch_models(branches)

            logger.info(f"BranchProcessor initialized with {len(self.branch_models)} models")
            return True
        except Exception as e:
            logger.error(f"Error initializing branch processor: {e}", exc_info=True)
            return False

    async def _preload_branch_models(self, branches: List[Any]) -> None:
        """
        Pre-load all branch models for faster execution.

        A failure on one branch is logged and does not stop the remaining
        branches from being loaded.

        Args:
            branches: List of branch configurations
        """
        for branch in branches:
            try:
                await self._load_branch_model(branch)

                # Recursively load nested branches
                nested_branches = getattr(branch, 'branches', [])
                if nested_branches:
                    await self._preload_branch_models(nested_branches)
            except Exception as e:
                logger.error(f"Error preloading branch model {getattr(branch, 'model_id', 'unknown')}: {e}")

    async def _load_branch_model(self, branch_config: Any) -> Optional["YOLOWrapper"]:
        """
        Load a branch model if not already loaded.

        Args:
            branch_config: Branch configuration object

        Returns:
            Loaded YOLO model wrapper or None
        """
        try:
            model_id = getattr(branch_config, 'model_id', None)
            model_file = getattr(branch_config, 'model_file', None)

            if not model_id or not model_file:
                logger.warning(f"Invalid branch config: model_id={model_id}, model_file={model_file}")
                return None

            # Check if model is already loaded
            if model_id in self.branch_models:
                logger.debug(f"Branch model {model_id} already loaded")
                return self.branch_models[model_id]

            logger.info(f"Loading branch model: {model_id} ({model_file})")

            # The manager is addressed with the pipeline's model ID; the cache
            # is keyed by the branch's own model_id.
            model = self.model_manager.get_yolo_model(self.model_id, model_file)
            if model:
                self.branch_models[model_id] = model
                self.stats['models_loaded'] += 1
                logger.info(f"Branch model {model_id} loaded successfully")
                return model

            logger.error(f"Failed to load branch model {model_id}")
            return None
        except Exception as e:
            logger.error(f"Error loading branch model {getattr(branch_config, 'model_id', 'unknown')}: {e}")
            return None

    async def execute_branches(self,
                               frame: np.ndarray,
                               branches: List[Any],
                               detected_regions: Dict[str, Any],
                               detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute all branches and collect results.

        Branches whose ``parallel`` attribute is truthy run concurrently; the
        others run one after another afterwards. On an unexpected error the
        results gathered so far are returned.

        Args:
            frame: Input frame
            branches: List of branch configurations
            detected_regions: Dictionary of detected regions from main detection
            detection_context: Detection context data

        Returns:
            Dictionary with branch execution results
        """
        start_time = time.time()
        branch_results: Dict[str, Any] = {}

        try:
            # Separate parallel and sequential branches
            parallel_branches = []
            sequential_branches = []
            for branch in branches:
                if getattr(branch, 'parallel', False):
                    parallel_branches.append(branch)
                else:
                    sequential_branches.append(branch)

            # Execute parallel branches concurrently
            if parallel_branches:
                logger.info(f"Executing {len(parallel_branches)} branches in parallel")
                parallel_results = await self._execute_parallel_branches(
                    frame, parallel_branches, detected_regions, detection_context
                )
                branch_results.update(parallel_results)
                self.stats['parallel_executions'] += 1

            # Execute sequential branches one by one
            if sequential_branches:
                logger.info(f"Executing {len(sequential_branches)} branches sequentially")
                sequential_results = await self._execute_sequential_branches(
                    frame, sequential_branches, detected_regions, detection_context
                )
                branch_results.update(sequential_results)

            # Update statistics
            self.stats['branches_processed'] += len(branches)
            processing_time = time.time() - start_time
            self.stats['total_processing_time'] += processing_time

            logger.info(f"Branch execution completed in {processing_time:.3f}s with {len(branch_results)} results")
        except Exception as e:
            logger.error(f"Error in branch execution: {e}", exc_info=True)

        return branch_results

    async def _execute_parallel_branches(self,
                                         frame: np.ndarray,
                                         branches: List[Any],
                                         detected_regions: Dict[str, Any],
                                         detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute branches in parallel using ThreadPoolExecutor.

        The pool futures are awaited through the event loop, so other
        coroutines keep running while inference is in progress.

        Args:
            frame: Input frame
            branches: List of parallel branch configurations
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with parallel branch results (nested branches flattened)
        """
        loop = asyncio.get_running_loop()

        # Submit all branches for parallel execution
        pending = []
        for branch in branches:
            branch_id = getattr(branch, 'model_id', 'unknown')
            logger.info(f"[PARALLEL SUBMIT] {branch_id}: Submitting branch to thread pool")
            pending.append(loop.run_in_executor(
                self.executor,
                self._execute_single_branch_sync,
                frame, branch, detected_regions, detection_context
            ))

        # return_exceptions=True keeps one failing branch from hiding the others
        outcomes = await asyncio.gather(*pending, return_exceptions=True)

        results: Dict[str, Any] = {}
        for branch, outcome in zip(branches, outcomes):
            branch_id = getattr(branch, 'model_id', 'unknown')
            if isinstance(outcome, BaseException):
                logger.error(f"Error in parallel branch {branch_id}: {outcome}")
                results[branch_id] = self._error_result(branch_id, outcome)
            else:
                results[branch_id] = outcome
                logger.info(f"[PARALLEL COMPLETE] {branch_id}: Branch completed successfully")

        return self._flatten_branch_results(results)

    async def _execute_sequential_branches(self,
                                           frame: np.ndarray,
                                           branches: List[Any],
                                           detected_regions: Dict[str, Any],
                                           detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute branches sequentially.

        Each branch still runs in the thread pool so the event loop is not
        blocked, but the next branch starts only after the previous finished.

        Args:
            frame: Input frame
            branches: List of sequential branch configurations
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with sequential branch results (nested branches flattened)
        """
        loop = asyncio.get_running_loop()
        results: Dict[str, Any] = {}

        for branch in branches:
            branch_id = getattr(branch, 'model_id', 'unknown')
            try:
                results[branch_id] = await loop.run_in_executor(
                    self.executor,
                    self._execute_single_branch_sync,
                    frame, branch, detected_regions, detection_context
                )
                logger.debug(f"Sequential branch {branch_id} completed successfully")
            except Exception as e:
                logger.error(f"Error in sequential branch {branch_id}: {e}")
                results[branch_id] = self._error_result(branch_id, e)

        return self._flatten_branch_results(results)

    @staticmethod
    def _error_result(branch_id: Any, error: BaseException) -> Dict[str, Any]:
        """Build the result entry recorded for a branch that raised."""
        return {
            'status': 'error',
            'branch_id': branch_id,
            'message': str(error),
            'processing_time': 0.0
        }

    @staticmethod
    def _flatten_branch_results(results: Dict[str, Any]) -> Dict[str, Any]:
        """
        Lift nested branch results to the top level for database access.

        Every branch result is kept under its own id; results found under a
        ``nested_branches`` key are additionally copied to the top level, at
        any nesting depth.

        Args:
            results: Branch results keyed by branch id

        Returns:
            Flat dictionary containing all branch results
        """
        flattened: Dict[str, Any] = {}

        def collect(level: Dict[str, Any], is_nested: bool) -> None:
            for branch_id, branch_result in level.items():
                flattened[branch_id] = branch_result
                if is_nested:
                    logger.info(f"[FLATTEN] Added nested branch {branch_id} to top-level results")
                if isinstance(branch_result, dict):
                    children = branch_result.get('nested_branches')
                    if isinstance(children, dict):
                        collect(children, True)

        collect(results, False)
        return flattened

    @staticmethod
    def _crop_region(frame: np.ndarray, bbox: Any) -> Optional[np.ndarray]:
        """
        Crop ``bbox`` ([x1, y1, x2, y2]) out of ``frame``.

        Coordinates are clamped to the frame first; a negative coordinate
        would otherwise be interpreted by NumPy as an index from the end.

        Returns:
            The cropped view, or None when the clamped box is empty
        """
        height, width = frame.shape[:2]
        x1, y1, x2, y2 = (int(coord) for coord in bbox)
        x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
        y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
        cropped = frame[y1:y2, x1:x2]
        return cropped if cropped.size > 0 else None

    @staticmethod
    def _regions_from_detections(detections: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Index detections by class name (a later detection of a class wins)."""
        return {
            detection['class_name']: {
                'bbox': detection['bbox'],
                'confidence': detection['confidence']
            }
            for detection in detections
        }

    def _select_branch_input(self,
                             frame: np.ndarray,
                             branch_config: Any,
                             detected_regions: Dict[str, Any],
                             min_confidence: float,
                             branch_id: Any) -> np.ndarray:
        """
        Pick the image a cropping branch should run on.

        Among the configured ``crop_class`` regions whose confidence reaches
        ``min_confidence``, the one with the largest bbox area is cropped.
        Falls back to the full frame when nothing qualifies or the crop is empty.
        """
        crop_classes = getattr(branch_config, 'crop_class', []) or []
        if isinstance(crop_classes, str):
            crop_classes = [crop_classes]

        best_region = None
        best_class = None
        best_area = 0.0
        for crop_class in crop_classes:
            if crop_class not in detected_regions:
                continue
            region = detected_regions[crop_class]
            # Only use detections above min_confidence
            if region.get('confidence', 0.0) < min_confidence:
                continue
            bbox = region['bbox']
            area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])  # width * height
            if area > best_area:
                best_region, best_class, best_area = region, crop_class, area

        if best_region is None:
            logger.warning(f"Branch {branch_id}: no valid crop regions found (min_conf={min_confidence})")
            return frame

        cropped = self._crop_region(frame, best_region['bbox'])
        if cropped is None:
            logger.warning(f"Branch {branch_id}: empty crop, using full frame")
            return frame

        confidence = best_region.get('confidence', 0.0)
        logger.info(f"[CROP SUCCESS] {branch_id}: cropped '{best_class}' region (conf={confidence:.3f}, area={int(best_area)}) -> shape={cropped.shape}")
        return cropped

    @staticmethod
    def _collect_detections(result_obj: Any,
                            model: Any,
                            min_confidence: float,
                            input_frame: np.ndarray,
                            branch_id: Any) -> List[Dict[str, Any]]:
        """
        Convert one model result object into a list of detection dicts.

        Results exposing ``.boxes`` yield one entry per box. Results exposing
        ``.probs`` yield the top-5 classes that reach ``min_confidence``, each
        with the full input frame as its bbox.
        """
        detections: List[Dict[str, Any]] = []

        # Handle detection models (have .boxes attribute)
        if hasattr(result_obj, 'boxes') and result_obj.boxes is not None:
            logger.info(f"[RAW DETECTIONS] {branch_id}: Found {len(result_obj.boxes)} raw detections")
            for i, box in enumerate(result_obj.boxes):
                class_id = int(box.cls[0])
                confidence = float(box.conf[0])
                bbox = box.xyxy[0].cpu().numpy().tolist()  # [x1, y1, x2, y2]
                class_name = model.model.names[class_id]
                logger.debug(f"[RAW DETECTION {i+1}] {branch_id}: '{class_name}', conf={confidence:.3f}")
                # All detections are included - no filtering by trigger_classes here
                detections.append({
                    'class_name': class_name,
                    'confidence': confidence,
                    'bbox': bbox
                })

        # Handle classification models (have .probs attribute)
        elif hasattr(result_obj, 'probs') and result_obj.probs is not None:
            logger.info(f"[RAW CLASSIFICATION] {branch_id}: Processing classification results")
            probs = result_obj.probs
            top_indices = probs.top5  # Get top 5 predictions
            top_conf = probs.top5conf.cpu().numpy()
            for idx, conf in zip(top_indices, top_conf):
                if conf >= min_confidence:
                    class_name = model.model.names[int(idx)]
                    logger.debug(f"[CLASSIFICATION RESULT {len(detections)+1}] {branch_id}: '{class_name}', conf={conf:.3f}")
                    # For classification, use full input frame dimensions as bbox
                    detections.append({
                        'class_name': class_name,
                        'confidence': float(conf),
                        'bbox': [0, 0, input_frame.shape[1], input_frame.shape[0]]
                    })
        else:
            logger.warning(f"[UNKNOWN MODEL] {branch_id}: Model results have no .boxes or .probs")

        return detections

    def _run_branch_actions(self,
                            branch_id: Any,
                            actions: List[Any],
                            frame: np.ndarray,
                            regions: Dict[str, Any],
                            detection_context: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Execute the configured actions of a branch, one by one.

        A failing action is recorded with an error status and does not prevent
        the following actions from running.

        Returns:
            One ``{'action_type', 'result'}`` entry per action
        """
        executed: List[Dict[str, Any]] = []
        for action in actions:
            try:
                action_type = action.type.value  # Access the enum value
                logger.info(f"[ACTION EXECUTE] {branch_id}: Executing action '{action_type}'")

                if action_type == 'redis_save_image':
                    action_result = self._execute_redis_save_image_sync(
                        action, frame, regions, detection_context
                    )
                elif action_type == 'redis_publish':
                    action_result = self._execute_redis_publish_sync(
                        action, detection_context
                    )
                else:
                    logger.warning(f"[ACTION UNKNOWN] {branch_id}: Unknown action type '{action_type}'")
                    action_result = {'status': 'error', 'message': f'Unknown action type: {action_type}'}

                executed.append({
                    'action_type': action_type,
                    'result': action_result
                })
                logger.info(f"[ACTION COMPLETE] {branch_id}: Action '{action_type}' result: {action_result.get('status')}")
            except Exception as e:
                action_type = getattr(action, 'type', None)
                if action_type:
                    action_type = action_type.value if hasattr(action_type, 'value') else str(action_type)
                logger.error(f"[ACTION ERROR] {branch_id}: Error executing action '{action_type}': {e}", exc_info=True)
                executed.append({
                    'action_type': action_type,
                    'result': {'status': 'error', 'message': str(e)}
                })
        return executed

    def _execute_single_branch_sync(self,
                                    frame: np.ndarray,
                                    branch_config: Any,
                                    detected_regions: Dict[str, Any],
                                    detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Synchronous execution of a single branch (for ThreadPoolExecutor).

        Steps: trigger-class check, optional crop, inference, result parsing,
        branch actions, then nested branches (run inline, sequentially, and
        only when this branch produced detections).

        Args:
            frame: Input frame
            branch_config: Branch configuration object
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with branch execution result; ``status`` is one of
            'success', 'skipped' or 'error'
        """
        start_time = time.time()
        branch_id = getattr(branch_config, 'model_id', 'unknown')
        # Name-based heuristics below work on the lower-cased id
        branch_key = str(branch_id).lower()

        logger.info(f"[BRANCH START] {branch_id}: Starting branch execution")
        logger.debug(f"[BRANCH CONFIG] {branch_id}: crop={getattr(branch_config, 'crop', False)}, "
                     f"trigger_classes={getattr(branch_config, 'trigger_classes', [])}, "
                     f"min_confidence={getattr(branch_config, 'min_confidence', 0.6)}")

        # Check if branch should execute based on triggerClasses (execution conditions)
        trigger_classes = getattr(branch_config, 'trigger_classes', [])
        logger.info(f"[DETECTED REGIONS] {branch_id}: Available parent detections: {list(detected_regions.keys())}")
        for region_name, region_data in detected_regions.items():
            logger.debug(f"[REGION DATA] {branch_id}: '{region_name}' -> bbox={region_data.get('bbox')}, conf={region_data.get('confidence')}")

        if trigger_classes:
            for trigger_class in trigger_classes:
                if trigger_class in detected_regions:
                    logger.info(f"[TRIGGER CHECK] {branch_id}: Found '{trigger_class}' in parent detections - branch will execute")
                    break
            else:
                logger.warning(f"[TRIGGER CHECK] {branch_id}: None of trigger classes {trigger_classes} found in parent detections {list(detected_regions.keys())} - skipping branch")
                return {
                    'status': 'skipped',
                    'branch_id': branch_id,
                    'message': f'No trigger classes {trigger_classes} found in parent detections',
                    'processing_time': time.time() - start_time
                }

        result: Dict[str, Any] = {
            'status': 'success',
            'branch_id': branch_id,
            'result': {},
            'processing_time': 0.0,
            'timestamp': time.time()
        }

        try:
            # Models are expected to be preloaded; a missing one is an error
            if branch_id not in self.branch_models:
                logger.warning(f"Branch model {branch_id} not preloaded - branch cannot be executed")
                return {
                    'status': 'error',
                    'branch_id': branch_id,
                    'message': f'Branch model {branch_id} not available',
                    'processing_time': time.time() - start_time
                }
            model = self.branch_models[branch_id]

            min_confidence = getattr(branch_config, 'min_confidence', 0.6)

            # Prepare input frame for this branch (biggest qualifying crop, if requested)
            input_frame = frame
            if getattr(branch_config, 'crop', False):
                input_frame = self._select_branch_input(
                    frame, branch_config, detected_regions, min_confidence, branch_id
                )

            logger.info(f"[INFERENCE START] {branch_id}: Running inference on {'cropped' if input_frame is not frame else 'full'} frame "
                        f"({input_frame.shape[1]}x{input_frame.shape[0]}) with confidence={min_confidence}")

            inference_start = time.time()

            # The model type is inferred from the branch id
            is_classification = 'cls' in branch_key or 'classify' in branch_key
            if is_classification:
                # Use .predict() method for classification models
                detection_results = model.model.predict(source=input_frame, verbose=False)
                logger.info(f"[INFERENCE DONE] {branch_id}: Classification completed in {time.time() - inference_start:.3f}s using .predict()")
            else:
                # Use direct model call for detection models
                detection_results = model.model(input_frame, conf=min_confidence, verbose=False)
                logger.info(f"[INFERENCE DONE] {branch_id}: Detection completed in {time.time() - inference_start:.3f}s using direct call")

            branch_detections: List[Dict[str, Any]] = []
            if detection_results:
                branch_detections = self._collect_detections(
                    detection_results[0], model, min_confidence, input_frame, branch_id
                )

            result['result'] = {
                'detections': branch_detections,
                'detection_count': len(branch_detections)
            }
            logger.info(f"[FINAL RESULTS] {branch_id}: {len(branch_detections)} detections processed")

            # Extract best result for classification models
            if branch_detections:
                best_detection = max(branch_detections, key=lambda x: x['confidence'])
                logger.info(f"[BEST DETECTION] {branch_id}: '{best_detection['class_name']}' with confidence {best_detection['confidence']:.3f}")

                # Add classification-style results for database operations
                if 'brand' in branch_key:
                    result['result']['brand'] = best_detection['class_name']
                elif 'body' in branch_key:  # also matches 'bodytype'
                    result['result']['body_type'] = best_detection['class_name']
                elif 'front_rear' in branch_key:
                    # Note: this field carries the confidence, not the class name
                    result['result']['front_rear'] = best_detection['confidence']

                logger.info(f"[CLASSIFICATION RESULT] {branch_id}: Extracted classification fields")
            else:
                logger.warning(f"[NO RESULTS] {branch_id}: No detections found")

            # Execute branch actions if this branch found valid detections
            branch_actions = getattr(branch_config, 'actions', [])
            if branch_actions and branch_detections:
                logger.info(f"[BRANCH ACTIONS] {branch_id}: Executing {len(branch_actions)} actions")
                # Actions see THIS branch's detections (coordinates relative to input_frame)
                actions_executed = self._run_branch_actions(
                    branch_id,
                    branch_actions,
                    input_frame,
                    self._regions_from_detections(branch_detections),
                    detection_context
                )
                if actions_executed:
                    result['actions_executed'] = actions_executed

            # Handle nested branches ONLY if parent found valid detections
            nested_branches = getattr(branch_config, 'branches', [])
            if nested_branches:
                if not branch_detections:
                    logger.warning(f"[BRANCH SKIP] {branch_id}: Skipping {len(nested_branches)} nested branches - parent found no valid detections")
                else:
                    logger.debug(f"Branch {branch_id}: executing {len(nested_branches)} nested branches")

                    # Nested branches see their immediate parent's detections, not the root pipeline's
                    nested_detected_regions = self._regions_from_detections(branch_detections)
                    logger.info(f"[NESTED REGIONS] {branch_id}: Passing {list(nested_detected_regions.keys())} to nested branches")

                    # Nested branches run inline in this worker thread, one after another
                    nested_results = {}
                    for nested_branch in nested_branches:
                        nested_result = self._execute_single_branch_sync(
                            input_frame, nested_branch, nested_detected_regions, detection_context
                        )
                        nested_branch_id = getattr(nested_branch, 'model_id', 'unknown')
                        nested_results[nested_branch_id] = nested_result
                    result['nested_branches'] = nested_results

        except Exception as e:
            logger.error(f"[BRANCH ERROR] {branch_id}: Error in execution: {e}", exc_info=True)
            result['status'] = 'error'
            result['message'] = str(e)

        result['processing_time'] = time.time() - start_time

        # Summary log
        logger.info(f"[BRANCH COMPLETE] {branch_id}: status={result['status']}, "
                    f"processing_time={result['processing_time']:.3f}s, "
                    f"result_keys={list(result['result'].keys()) if result['result'] else 'none'}")
        return result

    def _create_sync_redis(self, decode_responses: bool) -> Any:
        """
        Create a synchronous Redis client with the manager's connection details.

        The caller is responsible for closing the returned client.
        """
        import redis  # imported lazily, as the rest of the module does not need it

        return redis.Redis(
            host=self.redis_manager.host,
            port=self.redis_manager.port,
            password=self.redis_manager.password,
            db=self.redis_manager.db,
            decode_responses=decode_responses,
            socket_timeout=self.redis_manager.socket_timeout,
            socket_connect_timeout=self.redis_manager.socket_connect_timeout
        )

    def _execute_redis_save_image_sync(self,
                                       action: Any,
                                       frame: np.ndarray,
                                       detected_regions: Dict[str, Any],
                                       context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute redis_save_image action synchronously.

        Encodes the frame (or the configured region of it) and stores it under
        the formatted key with an expiry. On success the key is written to
        ``context['image_key']`` so later actions can reference it.

        Args:
            action: Action object exposing a ``params`` mapping
            frame: Image the region coordinates refer to
            detected_regions: Regions available for cropping, keyed by name
            context: Detection context used to format the key

        Returns:
            Dictionary with ``status`` and either ``key`` or ``message``
        """
        if not self.redis_manager:
            return {'status': 'error', 'message': 'Redis not available'}

        try:
            params = action.params

            # Get image to save (cropped or full frame)
            region_name = params.get('region')
            bbox = None
            if region_name and region_name in detected_regions:
                bbox = detected_regions[region_name]['bbox']
            elif region_name and region_name.lower() == 'frontal' and 'front_rear' in detected_regions:
                # Special case: "frontal" region maps to "front_rear" detection
                bbox = detected_regions['front_rear']['bbox']

            image_to_save = frame
            if bbox is not None:
                cropped = self._crop_region(frame, bbox)
                if cropped is not None:
                    image_to_save = cropped
                    logger.debug(f"Cropped region '{region_name}' for redis_save_image")
                else:
                    logger.warning(f"Empty crop for region '{region_name}', using full frame")

            # Format key with context
            key = params['key'].format(**context)

            image_format = params.get('format', 'jpeg')
            quality = params.get('quality', 90)
            expire_seconds = params.get('expire_seconds', 600)

            # Encode before opening a connection so an encoding failure leaks nothing
            if image_format.lower() == 'jpeg':
                encode_param = [cv2.IMWRITE_JPEG_QUALITY, quality]
                success, encoded_image = cv2.imencode('.jpg', image_to_save, encode_param)
            else:
                success, encoded_image = cv2.imencode('.png', image_to_save)
            if not success:
                return {'status': 'error', 'message': 'Failed to encode image'}

            try:
                # Binary payload, so responses must not be decoded
                sync_redis = self._create_sync_redis(decode_responses=False)
                try:
                    stored = sync_redis.setex(key, expire_seconds, encoded_image.tobytes())
                finally:
                    sync_redis.close()  # Clean up connection even when setex raises
            except Exception as redis_error:
                logger.error(f"Error calling Redis from sync context: {redis_error}")
                return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'}

            if stored:
                # Add image_key to context for subsequent actions
                context['image_key'] = key
                return {'status': 'success', 'key': key}
            return {'status': 'error', 'message': 'Failed to save image to Redis'}

        except Exception as e:
            logger.error(f"Error in redis_save_image action: {e}", exc_info=True)
            return {'status': 'error', 'message': str(e)}

    def _execute_redis_publish_sync(self, action: Any, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute redis_publish action synchronously.

        ``{name}`` placeholders in the message template are substituted with
        plain string replacement, because the template contains JSON braces
        that ``str.format`` would misinterpret. ``{image_key}`` is replaced by
        an empty string when no image was saved. ``context`` is not modified.

        Args:
            action: Action object exposing a ``params`` mapping
            context: Detection context supplying the placeholder values

        Returns:
            Dictionary with ``status`` and either subscriber info or ``message``
        """
        if not self.redis_manager:
            return {'status': 'error', 'message': 'Redis not available'}

        try:
            channel = action.params['channel']
            message_template = action.params['message']

            logger.debug(f"Message template: {repr(message_template)}")
            logger.debug(f"Context keys: {list(context.keys())}")

            try:
                # Work on a copy: the context is shared with other branch threads
                fields = dict(context)
                fields.setdefault('image_key', '')

                message = message_template
                for key, value in fields.items():
                    message = message.replace('{' + key + '}', str(value))
                logger.debug(f"Formatted message using replacement: {message}")
            except Exception as e:
                logger.error(f"Message formatting failed: {e}")
                logger.error(f"Template: {repr(message_template)}")
                logger.error(f"Context: {context}")
                return {'status': 'error', 'message': f'Message formatting failed: {e}'}

            try:
                # Text payload, so decoded responses are fine
                sync_redis = self._create_sync_redis(decode_responses=True)
                try:
                    subscribers = sync_redis.publish(channel, message)
                finally:
                    sync_redis.close()  # Clean up connection even when publish raises
            except Exception as redis_error:
                logger.error(f"Error calling Redis from sync context: {redis_error}")
                return {'status': 'error', 'message': f'Redis operation failed: {redis_error}'}

            if subscribers >= 0:  # Redis publish returns number of subscribers
                return {'status': 'success', 'subscribers': subscribers, 'channel': channel}
            return {'status': 'error', 'message': 'Failed to publish message to Redis'}

        except Exception as e:
            logger.error(f"Error in redis_publish action: {e}", exc_info=True)
            return {'status': 'error', 'message': str(e)}

    def get_statistics(self) -> Dict[str, Any]:
        """Get branch processor statistics, including the ids of loaded models."""
        return {
            **self.stats,
            'loaded_models': list(self.branch_models.keys()),
            'model_count': len(self.branch_models)
        }

    def cleanup(self):
        """Cleanup resources: stop the thread pool (without waiting) and drop cached models."""
        if self.executor:
            self.executor.shutdown(wait=False)

        # Clear model cache
        self.branch_models.clear()
        logger.info("BranchProcessor cleaned up")