Add confidence check on model
This commit is contained in:
		
							parent
							
								
									8c429cc8f6
								
							
						
					
					
						commit
						81547311d8
					
				
					 1 changed files with 58 additions and 24 deletions
				
			
		| 
						 | 
				
			
			@ -561,6 +561,9 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
 | 
			
		||||
        # ─── Detection stage - Multi-class support ──────────────────
 | 
			
		||||
        tk = node["triggerClassIndices"]
 | 
			
		||||
        logger.debug(f"Running detection for node {node['modelId']} with trigger classes: {node.get('triggerClasses', [])} (indices: {tk})")
 | 
			
		||||
        logger.debug(f"Node configuration: minConfidence={node['minConfidence']}, multiClass={node.get('multiClass', False)}")
 | 
			
		||||
        
 | 
			
		||||
        res = node["model"].track(
 | 
			
		||||
            frame,
 | 
			
		||||
            stream=False,
 | 
			
		||||
| 
						 | 
				
			
			@ -573,12 +576,17 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
        all_boxes = []
 | 
			
		||||
        regions_dict = {}
 | 
			
		||||
        
 | 
			
		||||
        for box in res.boxes:
 | 
			
		||||
        logger.debug(f"Raw detection results from model: {len(res.boxes) if res.boxes is not None else 0} detections")
 | 
			
		||||
        
 | 
			
		||||
        for i, box in enumerate(res.boxes):
 | 
			
		||||
            conf = float(box.cpu().conf[0])
 | 
			
		||||
            cid = int(box.cpu().cls[0])
 | 
			
		||||
            name = node["model"].names[cid]
 | 
			
		||||
            
 | 
			
		||||
            logger.debug(f"Detection {i}: class='{name}' (id={cid}), confidence={conf:.3f}, threshold={node['minConfidence']}")
 | 
			
		||||
            
 | 
			
		||||
            if conf < node["minConfidence"]:
 | 
			
		||||
                logger.debug(f"  -> REJECTED: confidence {conf:.3f} < threshold {node['minConfidence']}")
 | 
			
		||||
                continue
 | 
			
		||||
                
 | 
			
		||||
            xy = box.cpu().xyxy[0]
 | 
			
		||||
| 
						 | 
				
			
			@ -595,6 +603,8 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
            all_detections.append(detection)
 | 
			
		||||
            all_boxes.append(bbox)
 | 
			
		||||
            
 | 
			
		||||
            logger.debug(f"  -> ACCEPTED: {name} with confidence {conf:.3f}, bbox={bbox}")
 | 
			
		||||
            
 | 
			
		||||
            # Store highest confidence detection for each class
 | 
			
		||||
            if name not in regions_dict or conf > regions_dict[name]["confidence"]:
 | 
			
		||||
                regions_dict[name] = {
 | 
			
		||||
| 
						 | 
				
			
			@ -602,8 +612,13 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
                    "confidence": conf,
 | 
			
		||||
                    "detection": detection
 | 
			
		||||
                }
 | 
			
		||||
                logger.debug(f"  -> Updated regions_dict['{name}'] with confidence {conf:.3f}")
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Detection summary: {len(all_detections)} accepted detections from {len(res.boxes) if res.boxes is not None else 0} total")
 | 
			
		||||
        logger.info(f"Detected classes: {list(regions_dict.keys())}")
 | 
			
		||||
 | 
			
		||||
        if not all_detections:
 | 
			
		||||
            logger.warning("No detections above confidence threshold - returning null")
 | 
			
		||||
            return (None, None) if return_bbox else None
 | 
			
		||||
 | 
			
		||||
        # ─── Multi-class validation ─────────────────────────────────
 | 
			
		||||
| 
						 | 
				
			
			@ -611,13 +626,25 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
            expected_classes = node["expectedClasses"]
 | 
			
		||||
            detected_classes = list(regions_dict.keys())
 | 
			
		||||
            
 | 
			
		||||
            # Check if all expected classes are detected
 | 
			
		||||
            logger.info(f"Multi-class validation: expected={expected_classes}, detected={detected_classes}")
 | 
			
		||||
            
 | 
			
		||||
            # Check if at least one expected class is detected (flexible mode)
 | 
			
		||||
            matching_classes = [cls for cls in expected_classes if cls in detected_classes]
 | 
			
		||||
            missing_classes = [cls for cls in expected_classes if cls not in detected_classes]
 | 
			
		||||
            if missing_classes:
 | 
			
		||||
                logger.debug(f"Missing expected classes: {missing_classes}. Detected: {detected_classes}")
 | 
			
		||||
            
 | 
			
		||||
            logger.debug(f"Matching classes: {matching_classes}, Missing classes: {missing_classes}")
 | 
			
		||||
            
 | 
			
		||||
            if not matching_classes:
 | 
			
		||||
                # No expected classes found at all
 | 
			
		||||
                logger.warning(f"PIPELINE REJECTED: No expected classes detected. Expected: {expected_classes}, Detected: {detected_classes}")
 | 
			
		||||
                return (None, None) if return_bbox else None
 | 
			
		||||
            
 | 
			
		||||
            logger.info(f"Multi-class detection success: {detected_classes}")
 | 
			
		||||
            if missing_classes:
 | 
			
		||||
                logger.info(f"Partial multi-class detection: {matching_classes} found, {missing_classes} missing")
 | 
			
		||||
            else:
 | 
			
		||||
                logger.info(f"Complete multi-class detection success: {detected_classes}")
 | 
			
		||||
        else:
 | 
			
		||||
            logger.debug("No multi-class validation - proceeding with all detections")
 | 
			
		||||
 | 
			
		||||
        # ─── Execute actions with region information ────────────────
 | 
			
		||||
        detection_result = {
 | 
			
		||||
| 
						 | 
				
			
			@ -628,26 +655,33 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 | 
			
		|||
        
 | 
			
		||||
        # ─── Create initial database record when Car+Frontal detected ────
 | 
			
		||||
        if node.get("db_manager") and node.get("multiClass", False):
 | 
			
		||||
            # Generate UUID session_id since client session is None for now
 | 
			
		||||
            import uuid as uuid_lib
 | 
			
		||||
            from datetime import datetime
 | 
			
		||||
            generated_session_id = str(uuid_lib.uuid4())
 | 
			
		||||
            # Only create database record if we have both Car and Frontal
 | 
			
		||||
            has_car = "Car" in regions_dict
 | 
			
		||||
            has_frontal = "Frontal" in regions_dict
 | 
			
		||||
            
 | 
			
		||||
            # Insert initial detection record
 | 
			
		||||
            camera_id = detection_result.get("camera_id", "unknown")
 | 
			
		||||
            timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
 | 
			
		||||
            
 | 
			
		||||
            inserted_session_id = node["db_manager"].insert_initial_detection(
 | 
			
		||||
                camera_id=camera_id,
 | 
			
		||||
                captured_timestamp=timestamp,
 | 
			
		||||
                session_id=generated_session_id
 | 
			
		||||
            )
 | 
			
		||||
            
 | 
			
		||||
            if inserted_session_id:
 | 
			
		||||
                # Update detection_result with the generated session_id for actions and branches
 | 
			
		||||
                detection_result["session_id"] = inserted_session_id
 | 
			
		||||
                detection_result["timestamp"] = timestamp  # Update with proper timestamp
 | 
			
		||||
                logger.info(f"Created initial database record with session_id: {inserted_session_id}")
 | 
			
		||||
            if has_car and has_frontal:
 | 
			
		||||
                # Generate UUID session_id since client session is None for now
 | 
			
		||||
                import uuid as uuid_lib
 | 
			
		||||
                from datetime import datetime
 | 
			
		||||
                generated_session_id = str(uuid_lib.uuid4())
 | 
			
		||||
                
 | 
			
		||||
                # Insert initial detection record
 | 
			
		||||
                camera_id = detection_result.get("camera_id", "unknown")
 | 
			
		||||
                timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
 | 
			
		||||
                
 | 
			
		||||
                inserted_session_id = node["db_manager"].insert_initial_detection(
 | 
			
		||||
                    camera_id=camera_id,
 | 
			
		||||
                    captured_timestamp=timestamp,
 | 
			
		||||
                    session_id=generated_session_id
 | 
			
		||||
                )
 | 
			
		||||
                
 | 
			
		||||
                if inserted_session_id:
 | 
			
		||||
                    # Update detection_result with the generated session_id for actions and branches
 | 
			
		||||
                    detection_result["session_id"] = inserted_session_id
 | 
			
		||||
                    detection_result["timestamp"] = timestamp  # Update with proper timestamp
 | 
			
		||||
                    logger.info(f"Created initial database record with session_id: {inserted_session_id}")
 | 
			
		||||
            else:
 | 
			
		||||
                logger.debug(f"Database record not created - missing required classes. Has Car: {has_car}, Has Frontal: {has_frontal}")
 | 
			
		||||
        
 | 
			
		||||
        execute_actions(node, frame, detection_result, regions_dict)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue