Feat: pre-evaluation confidence level
This commit is contained in:
		
							parent
							
								
									c4ab4d6cde
								
							
						
					
					
						commit
						c281ca6c6d
					
				
					 1 changed files with 67 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@ -514,6 +514,65 @@ def resolve_field_mapping(value_template, branch_results, action_context):
 | 
			
		|||
        logger.error(f"Error resolving field mapping '{value_template}': {e}")
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
def validate_pipeline_execution(node, regions_dict):
 | 
			
		||||
    """
 | 
			
		||||
    Pre-validate that all required branches will execute successfully before 
 | 
			
		||||
    committing to Redis actions and database records.
 | 
			
		||||
    
 | 
			
		||||
    Returns:
 | 
			
		||||
        - (True, []) if pipeline can execute completely
 | 
			
		||||
        - (False, missing_branches) if some required branches won't execute
 | 
			
		||||
    """
 | 
			
		||||
    # Get all branches that parallel actions are waiting for
 | 
			
		||||
    required_branches = set()
 | 
			
		||||
    
 | 
			
		||||
    for action in node.get("parallelActions", []):
 | 
			
		||||
        if action.get("type") == "postgresql_update_combined":
 | 
			
		||||
            wait_for_branches = action.get("waitForBranches", [])
 | 
			
		||||
            required_branches.update(wait_for_branches)
 | 
			
		||||
    
 | 
			
		||||
    if not required_branches:
 | 
			
		||||
        # No parallel actions requiring specific branches
 | 
			
		||||
        logger.debug("No parallel actions with waitForBranches - validation passes")
 | 
			
		||||
        return True, []
 | 
			
		||||
    
 | 
			
		||||
    logger.debug(f"Pre-validation: checking if required branches {list(required_branches)} will execute")
 | 
			
		||||
    
 | 
			
		||||
    # Check each required branch
 | 
			
		||||
    missing_branches = []
 | 
			
		||||
    
 | 
			
		||||
    for branch in node.get("branches", []):
 | 
			
		||||
        branch_id = branch["modelId"]
 | 
			
		||||
        
 | 
			
		||||
        if branch_id not in required_branches:
 | 
			
		||||
            continue  # This branch is not required by parallel actions
 | 
			
		||||
            
 | 
			
		||||
        # Check if this branch would be triggered
 | 
			
		||||
        trigger_classes = branch.get("triggerClasses", [])
 | 
			
		||||
        min_conf = branch.get("minConfidence", 0)
 | 
			
		||||
        
 | 
			
		||||
        branch_triggered = False
 | 
			
		||||
        for det_class in regions_dict:
 | 
			
		||||
            det_confidence = regions_dict[det_class]["confidence"]
 | 
			
		||||
            
 | 
			
		||||
            if (det_class in trigger_classes and det_confidence >= min_conf):
 | 
			
		||||
                branch_triggered = True
 | 
			
		||||
                logger.debug(f"Pre-validation: branch {branch_id} WILL be triggered by {det_class} (conf={det_confidence:.3f} >= {min_conf})")
 | 
			
		||||
                break
 | 
			
		||||
        
 | 
			
		||||
        if not branch_triggered:
 | 
			
		||||
            missing_branches.append(branch_id)
 | 
			
		||||
            logger.warning(f"Pre-validation: branch {branch_id} will NOT be triggered - no matching classes or insufficient confidence")
 | 
			
		||||
            logger.debug(f"  Required: {trigger_classes} with min_conf={min_conf}")
 | 
			
		||||
            logger.debug(f"  Available: {[(cls, regions_dict[cls]['confidence']) for cls in regions_dict]}")
 | 
			
		||||
    
 | 
			
		||||
    if missing_branches:
 | 
			
		||||
        logger.error(f"Pipeline pre-validation FAILED: required branches {missing_branches} will not execute")
 | 
			
		||||
        return False, missing_branches
 | 
			
		||||
    else:
 | 
			
		||||
        logger.info(f"Pipeline pre-validation PASSED: all required branches {list(required_branches)} will execute")
 | 
			
		||||
        return True, []
 | 
			
		||||
 | 
			
		||||
def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		||||
    """
 | 
			
		||||
    Enhanced pipeline that supports:
 | 
			
		||||
| 
						 | 
				
			
			@ -646,6 +705,14 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
        else:
 | 
			
		||||
            logger.debug("No multi-class validation - proceeding with all detections")
 | 
			
		||||
 | 
			
		||||
        # ─── Pre-validate pipeline execution ────────────────────────
 | 
			
		||||
        pipeline_valid, missing_branches = validate_pipeline_execution(node, regions_dict)
 | 
			
		||||
        
 | 
			
		||||
        if not pipeline_valid:
 | 
			
		||||
            logger.error(f"Pipeline execution validation FAILED - required branches {missing_branches} cannot execute")
 | 
			
		||||
            logger.error("Aborting pipeline: no Redis actions or database records will be created")
 | 
			
		||||
            return (None, None) if return_bbox else None
 | 
			
		||||
 | 
			
		||||
        # ─── Execute actions with region information ────────────────
 | 
			
		||||
        detection_result = {
 | 
			
		||||
            "detections": all_detections,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue