"""
Parallel Branch Processing Module.
Handles concurrent execution of classification branches and result synchronization.
"""

import logging
import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import cv2

from ..models.inference import YOLOWrapper

logger = logging.getLogger(__name__)


class BranchProcessor:
    """
    Handles parallel processing of classification branches.
    Manages branch synchronization and result collection.

    Branch configurations are read with ``getattr`` and may provide:
    ``model_id``, ``model_file``, ``parallel``, ``crop``, ``crop_class``,
    ``trigger_classes``, ``min_confidence`` and nested ``branches``.
    """

    def __init__(self, model_manager: Any):
        """
        Initialize branch processor.

        Args:
            model_manager: Model manager for loading models
        """
        self.model_manager = model_manager

        # Branch models cache, keyed by the branch's model_id
        self.branch_models: Dict[str, YOLOWrapper] = {}

        # Thread pool used to run the synchronous branch inference off the event loop
        self.executor = ThreadPoolExecutor(max_workers=4)

        # Storage managers (set during initialization)
        self.redis_manager = None
        self.db_manager = None

        # Statistics
        self.stats = {
            'branches_processed': 0,
            'parallel_executions': 0,
            'total_processing_time': 0.0,
            'models_loaded': 0
        }

        logger.info("BranchProcessor initialized")

    async def initialize(self, pipeline_config: Any, redis_manager: Any, db_manager: Any) -> bool:
        """
        Initialize branch processor with pipeline configuration.

        Args:
            pipeline_config: Pipeline configuration object
            redis_manager: Redis manager instance
            db_manager: Database manager instance

        Returns:
            True if successful, False otherwise
        """
        try:
            self.redis_manager = redis_manager
            self.db_manager = db_manager

            # Pre-load branch models if they exist
            branches = getattr(pipeline_config, 'branches', [])
            if branches:
                await self._preload_branch_models(branches)

            logger.info(f"BranchProcessor initialized with {len(self.branch_models)} models")
            return True

        except Exception as e:
            logger.error(f"Error initializing branch processor: {e}", exc_info=True)
            return False

    async def _preload_branch_models(self, branches: List[Any]) -> None:
        """
        Pre-load all branch models for faster execution.

        A failure on one branch is logged and does not stop the remaining
        branches from being loaded.

        Args:
            branches: List of branch configurations
        """
        for branch in branches:
            try:
                await self._load_branch_model(branch)

                # Recursively load nested branches
                nested_branches = getattr(branch, 'branches', [])
                if nested_branches:
                    await self._preload_branch_models(nested_branches)

            except Exception as e:
                logger.error(f"Error preloading branch model {getattr(branch, 'model_id', 'unknown')}: {e}")

    async def _load_branch_model(self, branch_config: Any) -> Optional[YOLOWrapper]:
        """
        Load a branch model if not already loaded.

        Args:
            branch_config: Branch configuration object

        Returns:
            Loaded YOLO model wrapper or None
        """
        try:
            model_id = getattr(branch_config, 'model_id', None)
            model_file = getattr(branch_config, 'model_file', None)

            if not model_id or not model_file:
                logger.warning(f"Invalid branch config: model_id={model_id}, model_file={model_file}")
                return None

            # Check if model is already loaded
            if model_id in self.branch_models:
                logger.debug(f"Branch model {model_id} already loaded")
                return self.branch_models[model_id]

            # Load model
            logger.info(f"Loading branch model: {model_id} ({model_file})")

            # NOTE(review): the first downloaded model id is always used as the
            # lookup key for the model manager - confirm this is intended when
            # more than one pipeline model has been downloaded.
            pipeline_models = list(self.model_manager.get_all_downloaded_models())
            if not pipeline_models:
                logger.error("No models available in ModelManager for branch loading")
                return None

            actual_model_id = pipeline_models[0]
            model = self.model_manager.get_yolo_model(actual_model_id, model_file)
            if not model:
                logger.error(f"Failed to load branch model {model_id}")
                return None

            self.branch_models[model_id] = model
            self.stats['models_loaded'] += 1
            logger.info(f"Branch model {model_id} loaded successfully")
            return model

        except Exception as e:
            logger.error(f"Error loading branch model {getattr(branch_config, 'model_id', 'unknown')}: {e}")
            return None

    async def execute_branches(self,
                               frame: np.ndarray,
                               branches: List[Any],
                               detected_regions: Dict[str, Any],
                               detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute all branches in parallel and collect results.

        Branches whose ``parallel`` attribute is truthy are run concurrently
        first; the remaining branches are then run one at a time. Any error
        is logged and whatever results were collected so far are returned.

        Args:
            frame: Input frame
            branches: List of branch configurations
            detected_regions: Dictionary of detected regions from main detection
            detection_context: Detection context data

        Returns:
            Dictionary with branch execution results
        """
        start_time = time.time()
        branch_results = {}

        try:
            # Separate parallel and sequential branches
            parallel_branches = [b for b in branches if getattr(b, 'parallel', False)]
            sequential_branches = [b for b in branches if not getattr(b, 'parallel', False)]

            # Execute parallel branches concurrently
            if parallel_branches:
                logger.info(f"Executing {len(parallel_branches)} branches in parallel")
                parallel_results = await self._execute_parallel_branches(
                    frame, parallel_branches, detected_regions, detection_context
                )
                branch_results.update(parallel_results)
                self.stats['parallel_executions'] += 1

            # Execute sequential branches one by one
            if sequential_branches:
                logger.info(f"Executing {len(sequential_branches)} branches sequentially")
                sequential_results = await self._execute_sequential_branches(
                    frame, sequential_branches, detected_regions, detection_context
                )
                branch_results.update(sequential_results)

            # Update statistics
            self.stats['branches_processed'] += len(branches)
            processing_time = time.time() - start_time
            self.stats['total_processing_time'] += processing_time

            logger.info(f"Branch execution completed in {processing_time:.3f}s with {len(branch_results)} results")

        except Exception as e:
            logger.error(f"Error in branch execution: {e}", exc_info=True)

        return branch_results

    @staticmethod
    def _error_result(branch_id: str, error: BaseException) -> Dict[str, Any]:
        """Build the result dictionary recorded for a branch that raised."""
        return {
            'status': 'error',
            'branch_id': branch_id,
            'message': str(error),
            'processing_time': 0.0
        }

    @staticmethod
    def _flatten_results(results: Dict[str, Any]) -> Dict[str, Any]:
        """
        Copy each branch's direct nested results to the top level.

        Only one level is lifted: the entries of a result's
        ``'nested_branches'`` dictionary are added next to the branch itself
        so they can be looked up by id. A nested id equal to an existing key
        overwrites that key.

        Args:
            results: Mapping of branch id to branch result

        Returns:
            New dictionary with branch results and their nested results
        """
        flattened_results = {}

        for branch_id, branch_result in results.items():
            # Add the branch result itself
            flattened_results[branch_id] = branch_result

            # If this branch has nested branches, add them to the top level too
            if isinstance(branch_result, dict) and 'nested_branches' in branch_result:
                for nested_branch_id, nested_result in branch_result['nested_branches'].items():
                    flattened_results[nested_branch_id] = nested_result
                    logger.info(f"[FLATTEN] Added nested branch {nested_branch_id} to top-level results")

        return flattened_results

    async def _execute_parallel_branches(self,
                                         frame: np.ndarray,
                                         branches: List[Any],
                                         detected_regions: Dict[str, Any],
                                         detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute branches in parallel using ThreadPoolExecutor.

        Each branch runs in the thread pool and is awaited through the event
        loop, so the loop stays responsive while inference is in progress.
        Results are keyed in submission order.

        Args:
            frame: Input frame
            branches: List of parallel branch configurations
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with parallel branch results
        """
        loop = asyncio.get_running_loop()

        async def run_branch(branch: Any) -> Tuple[str, Dict[str, Any]]:
            # Runs one branch in the pool and converts a failure into an
            # error result so one bad branch cannot fail the whole batch.
            branch_id = getattr(branch, 'model_id', 'unknown')
            logger.info(f"[PARALLEL SUBMIT] {branch_id}: Submitting branch to thread pool")
            try:
                result = await loop.run_in_executor(
                    self.executor,
                    self._execute_single_branch_sync,
                    frame, branch, detected_regions, detection_context
                )
            except Exception as e:
                logger.error(f"Error in parallel branch {branch_id}: {e}")
                return branch_id, self._error_result(branch_id, e)

            logger.info(f"[PARALLEL COMPLETE] {branch_id}: Branch completed successfully")
            return branch_id, result

        # gather() schedules every branch before the first await completes,
        # so all branches are in flight concurrently.
        completed = await asyncio.gather(*(run_branch(branch) for branch in branches))

        results = {}
        for branch_id, result in completed:
            results[branch_id] = result

        # Flatten nested branch results to top level for database access
        return self._flatten_results(results)

    async def _execute_sequential_branches(self,
                                           frame: np.ndarray,
                                           branches: List[Any],
                                           detected_regions: Dict[str, Any],
                                           detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute branches sequentially.

        Each branch still runs in the thread pool so the event loop is not
        blocked, but the next branch is only started once the previous one
        has finished.

        Args:
            frame: Input frame
            branches: List of sequential branch configurations
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with sequential branch results
        """
        loop = asyncio.get_running_loop()
        results = {}

        for branch in branches:
            branch_id = getattr(branch, 'model_id', 'unknown')

            try:
                result = await loop.run_in_executor(
                    self.executor,
                    self._execute_single_branch_sync,
                    frame, branch, detected_regions, detection_context
                )
                results[branch_id] = result
                logger.debug(f"Sequential branch {branch_id} completed successfully")

            except Exception as e:
                logger.error(f"Error in sequential branch {branch_id}: {e}")
                results[branch_id] = self._error_result(branch_id, e)

        # Flatten nested branch results to top level for database access
        return self._flatten_results(results)

    def _prepare_input_frame(self,
                             frame: np.ndarray,
                             branch_config: Any,
                             branch_id: str,
                             detected_regions: Dict[str, Any],
                             min_confidence: float) -> np.ndarray:
        """
        Return the frame a branch should run on, cropped when configured.

        When ``branch_config.crop`` is truthy, the region with the largest
        bbox area among ``crop_class`` entries whose confidence is at least
        ``min_confidence`` is cut out of ``frame``. The full frame is
        returned when cropping is disabled, no region qualifies, or the crop
        is empty.

        Args:
            frame: Input frame
            branch_config: Branch configuration object
            branch_id: Branch identifier used in log messages
            detected_regions: Dictionary of detected regions
            min_confidence: Minimum confidence a region needs to be used

        Returns:
            The cropped region, or ``frame`` itself when no crop is applied
        """
        if not getattr(branch_config, 'crop', False):
            return frame

        crop_classes = getattr(branch_config, 'crop_class', [])
        if isinstance(crop_classes, str):
            crop_classes = [crop_classes]

        # Find the biggest bbox that passes min_confidence threshold
        best_region = None
        best_class = None
        best_area = 0.0

        for crop_class in crop_classes:
            if crop_class not in detected_regions:
                continue

            region = detected_regions[crop_class]
            # Only use detections above min_confidence
            if region.get('confidence', 0.0) < min_confidence:
                continue

            bbox = region['bbox']
            area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])  # width * height

            # Choose biggest bbox among valid detections
            if area > best_area:
                best_region = region
                best_class = crop_class
                best_area = area

        if best_region is None:
            logger.warning(f"Branch {branch_id}: no valid crop regions found (min_conf={min_confidence})")
            return frame

        x1, y1, x2, y2 = [int(coord) for coord in best_region['bbox']]

        # Clamp to the frame: a negative coordinate would otherwise be read
        # as an index from the end and silently crop the wrong area.
        frame_height, frame_width = frame.shape[:2]
        x1 = min(max(x1, 0), frame_width)
        x2 = min(max(x2, 0), frame_width)
        y1 = min(max(y1, 0), frame_height)
        y2 = min(max(y2, 0), frame_height)

        cropped = frame[y1:y2, x1:x2]
        if cropped.size == 0:
            logger.warning(f"Branch {branch_id}: empty crop, using full frame")
            return frame

        confidence = best_region.get('confidence', 0.0)
        logger.info(f"[CROP SUCCESS] {branch_id}: cropped '{best_class}' region (conf={confidence:.3f}, area={int(best_area)}) -> shape={cropped.shape}")
        return cropped

    def _parse_predictions(self,
                           model: Any,
                           detection_results: Any,
                           branch_id: str,
                           min_confidence: float,
                           input_frame: np.ndarray) -> List[Dict[str, Any]]:
        """
        Convert the first prediction result into a list of detection dicts.

        Results exposing ``.boxes`` yield one entry per box. Results exposing
        ``.probs`` yield the top-5 classes whose confidence is at least
        ``min_confidence``, each with a bbox covering the whole input frame.

        Args:
            model: Model wrapper whose ``model.names`` maps class ids to names
            detection_results: Value returned by the model's ``predict`` call
            branch_id: Branch identifier used in log messages
            min_confidence: Threshold applied to classification results
            input_frame: Frame the prediction was run on

        Returns:
            List of dicts with ``class_name``, ``confidence`` and ``bbox``
        """
        branch_detections = []

        if not detection_results:
            return branch_detections

        result_obj = detection_results[0]
        names = model.model.names

        # Handle detection models (have .boxes attribute)
        if hasattr(result_obj, 'boxes') and result_obj.boxes is not None:
            logger.info(f"[RAW DETECTIONS] {branch_id}: Found {len(result_obj.boxes)} raw detections")

            for i, box in enumerate(result_obj.boxes):
                class_id = int(box.cls[0])
                confidence = float(box.conf[0])
                bbox = box.xyxy[0].cpu().numpy().tolist()  # [x1, y1, x2, y2]
                class_name = names[class_id]

                logger.debug(f"[RAW DETECTION {i+1}] {branch_id}: '{class_name}', conf={confidence:.3f}")

                # All detections are included - no filtering by trigger_classes here
                branch_detections.append({
                    'class_name': class_name,
                    'confidence': confidence,
                    'bbox': bbox
                })

        # Handle classification models (have .probs attribute)
        elif hasattr(result_obj, 'probs') and result_obj.probs is not None:
            logger.info(f"[RAW CLASSIFICATION] {branch_id}: Processing classification results")

            probs = result_obj.probs
            top_indices = probs.top5  # Get top 5 predictions
            top_conf = probs.top5conf.cpu().numpy()

            # For classification, use full input frame dimensions as bbox
            full_frame_bbox = [0, 0, input_frame.shape[1], input_frame.shape[0]]

            for idx, conf in zip(top_indices, top_conf):
                if conf >= min_confidence:
                    class_name = names[int(idx)]
                    logger.debug(f"[CLASSIFICATION RESULT {len(branch_detections)+1}] {branch_id}: '{class_name}', conf={conf:.3f}")

                    branch_detections.append({
                        'class_name': class_name,
                        'confidence': float(conf),
                        'bbox': list(full_frame_bbox)
                    })
        else:
            logger.warning(f"[UNKNOWN MODEL] {branch_id}: Model results have no .boxes or .probs")

        return branch_detections

    def _execute_single_branch_sync(self,
                                    frame: np.ndarray,
                                    branch_config: Any,
                                    detected_regions: Dict[str, Any],
                                    detection_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Synchronous execution of a single branch (for ThreadPoolExecutor).

        Steps: check ``trigger_classes`` against the parent detections,
        optionally crop the frame, run the branch model, collect detections,
        then run any nested branches on this branch's input frame using this
        branch's detections as their parent detections.

        Args:
            frame: Input frame
            branch_config: Branch configuration object
            detected_regions: Dictionary of detected regions
            detection_context: Detection context data

        Returns:
            Dictionary with branch execution result. ``status`` is one of
            ``'success'``, ``'skipped'`` or ``'error'``; nested results, when
            present, are stored under ``'nested_branches'``.
        """
        start_time = time.time()
        branch_id = getattr(branch_config, 'model_id', 'unknown')

        logger.info(f"[BRANCH START] {branch_id}: Starting branch execution")
        logger.debug(f"[BRANCH CONFIG] {branch_id}: crop={getattr(branch_config, 'crop', False)}, "
                     f"trigger_classes={getattr(branch_config, 'trigger_classes', [])}, "
                     f"min_confidence={getattr(branch_config, 'min_confidence', 0.6)}")

        # Check if branch should execute based on triggerClasses (execution conditions)
        trigger_classes = getattr(branch_config, 'trigger_classes', [])
        logger.info(f"[DETECTED REGIONS] {branch_id}: Available parent detections: {list(detected_regions.keys())}")

        # Skip the per-region dump entirely when debug output is disabled
        if logger.isEnabledFor(logging.DEBUG):
            for region_name, region_data in detected_regions.items():
                logger.debug(f"[REGION DATA] {branch_id}: '{region_name}' -> bbox={region_data.get('bbox')}, conf={region_data.get('confidence')}")

        if trigger_classes:
            # Check if any parent detection matches our trigger classes
            for trigger_class in trigger_classes:
                if trigger_class in detected_regions:
                    logger.info(f"[TRIGGER CHECK] {branch_id}: Found '{trigger_class}' in parent detections - branch will execute")
                    break
            else:
                logger.warning(f"[TRIGGER CHECK] {branch_id}: None of trigger classes {trigger_classes} found in parent detections {list(detected_regions.keys())} - skipping branch")
                return {
                    'status': 'skipped',
                    'branch_id': branch_id,
                    'message': f'No trigger classes {trigger_classes} found in parent detections',
                    'processing_time': time.time() - start_time
                }

        result = {
            'status': 'success',
            'branch_id': branch_id,
            'result': {},
            'processing_time': 0.0,
            'timestamp': time.time()
        }

        try:
            # Models are expected to be preloaded; loading is async and cannot
            # be done from this worker thread, so a missing model is an error.
            if branch_id not in self.branch_models:
                logger.warning(f"Branch model {branch_id} not preloaded - branch cannot run")
                return {
                    'status': 'error',
                    'branch_id': branch_id,
                    'message': f'Branch model {branch_id} not available',
                    'processing_time': time.time() - start_time
                }

            model = self.branch_models[branch_id]

            # Get configuration values first
            min_confidence = getattr(branch_config, 'min_confidence', 0.6)

            # Handle cropping if required - use biggest bbox that passes min_confidence
            input_frame = self._prepare_input_frame(
                frame, branch_config, branch_id, detected_regions, min_confidence
            )

            logger.info(f"[INFERENCE START] {branch_id}: Running inference on {'cropped' if input_frame is not frame else 'full'} frame "
                        f"({input_frame.shape[1]}x{input_frame.shape[0]}) with confidence={min_confidence}")

            # Use .predict() method for both detection and classification models
            inference_start = time.time()
            detection_results = model.model.predict(input_frame, conf=min_confidence, verbose=False)
            inference_time = time.time() - inference_start
            logger.info(f"[INFERENCE DONE] {branch_id}: Predict completed in {inference_time:.3f}s using .predict() method")

            branch_detections = self._parse_predictions(
                model, detection_results, branch_id, min_confidence, input_frame
            )

            result['result'] = {
                'detections': branch_detections,
                'detection_count': len(branch_detections)
            }

            logger.info(f"[FINAL RESULTS] {branch_id}: {len(branch_detections)} detections processed")

            # Extract best result for classification models
            if branch_detections:
                best_detection = max(branch_detections, key=lambda x: x['confidence'])
                logger.info(f"[BEST DETECTION] {branch_id}: '{best_detection['class_name']}' with confidence {best_detection['confidence']:.3f}")

                # Add classification-style results for database operations;
                # the field is chosen from substrings of the branch id.
                branch_id_lower = branch_id.lower()
                if 'brand' in branch_id_lower:
                    result['result']['brand'] = best_detection['class_name']
                elif 'body' in branch_id_lower:
                    result['result']['body_type'] = best_detection['class_name']
                elif 'front_rear' in branch_id_lower:
                    # NOTE(review): this stores the confidence, not the class
                    # name like the other fields - confirm this is intended.
                    result['result']['front_rear'] = best_detection['confidence']

                logger.info(f"[CLASSIFICATION RESULT] {branch_id}: Extracted classification fields")
            else:
                logger.warning(f"[NO RESULTS] {branch_id}: No detections found")

            # Handle nested branches ONLY if parent found valid detections
            nested_branches = getattr(branch_config, 'branches', [])
            if nested_branches:
                if not branch_detections:
                    logger.warning(f"[BRANCH SKIP] {branch_id}: Skipping {len(nested_branches)} nested branches - parent found no valid detections")
                else:
                    logger.debug(f"Branch {branch_id}: executing {len(nested_branches)} nested branches")

                    # Nested branches see their immediate parent's detections,
                    # not the root pipeline's. Later detections of the same
                    # class replace earlier ones.
                    nested_detected_regions = {}
                    for detection in branch_detections:
                        nested_detected_regions[detection['class_name']] = {
                            'bbox': detection['bbox'],
                            'confidence': detection['confidence']
                        }

                    logger.info(f"[NESTED REGIONS] {branch_id}: Passing {list(nested_detected_regions.keys())} to nested branches")

                    # Nested branches run sequentially in this worker thread
                    # and receive this branch's input frame, which is the
                    # frame their parent bboxes refer to.
                    nested_results = {}
                    for nested_branch in nested_branches:
                        nested_result = self._execute_single_branch_sync(
                            input_frame, nested_branch, nested_detected_regions, detection_context
                        )
                        nested_branch_id = getattr(nested_branch, 'model_id', 'unknown')
                        nested_results[nested_branch_id] = nested_result

                    result['nested_branches'] = nested_results

        except Exception as e:
            logger.error(f"[BRANCH ERROR] {branch_id}: Error in execution: {e}", exc_info=True)
            result['status'] = 'error'
            result['message'] = str(e)

        result['processing_time'] = time.time() - start_time

        # Summary log
        logger.info(f"[BRANCH COMPLETE] {branch_id}: status={result['status']}, "
                    f"processing_time={result['processing_time']:.3f}s, "
                    f"result_keys={list(result['result'].keys()) if result['result'] else 'none'}")

        return result

    def get_statistics(self) -> Dict[str, Any]:
        """Get branch processor statistics."""
        return {
            **self.stats,
            'loaded_models': list(self.branch_models.keys()),
            'model_count': len(self.branch_models)
        }

    def cleanup(self):
        """Cleanup resources."""
        # Do not wait for in-flight branches; just stop accepting new work
        if self.executor:
            self.executor.shutdown(wait=False)

        # Clear model cache
        self.branch_models.clear()

        logger.info("BranchProcessor cleaned up")