From f6014abb7aa2b99b1255a65e4ac4d55590b3b2ea Mon Sep 17 00:00:00 2001 From: Siwat Sirichai Date: Wed, 28 May 2025 19:31:22 +0700 Subject: [PATCH] refactor run_pipeline function for improved clarity and efficiency; add trigger class index handling and streamline detection logic --- siwatsystem/pympta.py | 215 +++++++++++++++++++----------------------- 1 file changed, 96 insertions(+), 119 deletions(-) diff --git a/siwatsystem/pympta.py b/siwatsystem/pympta.py index 2ebc6a6..5e32596 100644 --- a/siwatsystem/pympta.py +++ b/siwatsystem/pympta.py @@ -27,10 +27,21 @@ def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict: model.to("cuda") else: logger.info(f"CUDA not available. Using CPU for model {node_config['modelId']}") + + # Prepare trigger class indices for optimization + trigger_classes = node_config.get("triggerClasses", []) + trigger_class_indices = None + if trigger_classes and hasattr(model, "names"): + # Convert class names to indices for the model + trigger_class_indices = [i for i, name in model.names.items() + if name in trigger_classes] + logger.debug(f"Converted trigger classes to indices: {trigger_class_indices}") + node = { "modelId": node_config["modelId"], "modelFile": node_config["modelFile"], - "triggerClasses": node_config.get("triggerClasses", []), + "triggerClasses": trigger_classes, + "triggerClassIndices": trigger_class_indices, "crop": node_config.get("crop", False), "minConfidence": node_config.get("minConfidence", None), "model": model, @@ -158,130 +169,96 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict: logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True) return None -def run_pipeline(frame, node: dict, return_bbox: bool = False, is_last_stage: bool = True): +def run_pipeline(frame, node: dict, return_bbox: bool=False): """ - Processes the frame with the given pipeline node. When return_bbox is True, - the function returns a tuple (detection, bbox) where bbox is (x1,y1,x2,y2) - for drawing. Otherwise, returns only the detection. - - The is_last_stage parameter controls whether this node is considered the last - in the pipeline chain. Only the last stage will return detection results. + - For detection nodes (task != 'classify'): + • runs `track(..., classes=triggerClassIndices)` + • picks top box ≥ minConfidence + • optionally crops & resizes → recurse into child + • else returns (det_dict, bbox) + - For classify nodes: + • runs `predict()` + • returns top (class,confidence) and no bbox """ try: - # Check model type and use appropriate method - model_task = getattr(node["model"], "task", None) - - if model_task == "classify": - # Classification models need to use predict() instead of track() - logger.debug(f"Running classification model: {node.get('modelId')}") + task = getattr(node["model"], "task", None) + + # ─── Classification stage ─────────────────────────────────── + if task == "classify": + # run the classifier and grab its top-1 directly via the Probs API results = node["model"].predict(frame, stream=False) - detection = None - best_box = None - - # Process classification results - for r in results: - probs = r.probs - if probs is not None and len(probs) > 0: - # Get the most confident class - class_id = int(probs.top1) - conf = float(probs.top1conf) - detection = { - "class": node["model"].names[class_id], - "confidence": conf, - "id": None # Classification doesn't have tracking IDs - } - logger.debug(f"Classification detection: {detection}") - else: - logger.debug(f"Empty classification results for model {node.get('modelId')}") - - # Classification doesn't produce bounding boxes - bbox = None - - else: - # Detection/segmentation models use tracking - logger.debug(f"Running detection/tracking model: {node.get('modelId')}") - results = node["model"].track(frame, stream=False, persist=True) - detection = None - best_box = None - max_conf = -1 + # nothing returned? + if not results: + return (None, None) if return_bbox else None - # Log raw detection count - detection_count = 0 - for r in results: - if hasattr(r.boxes, 'cpu') and len(r.boxes.cpu()) > 0: - detection_count += len(r.boxes.cpu()) - - if detection_count == 0: - logger.debug(f"Empty detection results (no objects found) for model {node.get('modelId')}") - else: - logger.debug(f"Detection model {node.get('modelId')} found {detection_count} objects") + # take the first result's probs object + r = results[0] + probs = r.probs + if probs is None: + return (None, None) if return_bbox else None - for r in results: - for box in r.boxes: - box_cpu = box.cpu() - conf = float(box_cpu.conf[0]) - if conf > max_conf and hasattr(box, "id") and box.id is not None: - max_conf = conf - detection = { - "class": node["model"].names[int(box_cpu.cls[0])], - "confidence": conf, - "id": box.id.item() - } - best_box = box_cpu - - if detection: - logger.debug(f"Best detection: {detection}") - else: - logger.debug(f"No valid detection with tracking ID for model {node.get('modelId')}") + # get the top-1 class index and its confidence + top1_idx = int(probs.top1) + top1_conf = float(probs.top1conf) - bbox = None - # Calculate bbox if best_box exists - if detection and best_box is not None: - coords = best_box.xyxy[0] - x1, y1, x2, y2 = map(int, coords) - h, w = frame.shape[:2] - x1, y1 = max(0, x1), max(0, y1) - x2, y2 = min(w, x2), min(h, y2) - if x2 > x1 and y2 > y1: - bbox = (x1, y1, x2, y2) - logger.debug(f"Detection bounding box: {bbox}") - if node.get("crop", False): - frame = frame[y1:y2, x1:x2] - logger.debug(f"Cropped frame to {frame.shape}") + det = { + "class": node["model"].names[top1_idx], + "confidence": top1_conf, + "id": None + } + return (det, None) if return_bbox else det + + + # ─── Detection stage ──────────────────────────────────────── + # only look for your triggerClasses + tk = node["triggerClassIndices"] + res = node["model"].track( + frame, + stream=False, + persist=True, + **({"classes": tk} if tk else {}) + )[0] + + dets, boxes = [], [] + for box in res.boxes: + conf = float(box.cpu().conf[0]) + cid = int(box.cpu().cls[0]) + name = node["model"].names[cid] + if conf < node["minConfidence"]: + continue + xy = box.cpu().xyxy[0] + x1,y1,x2,y2 = map(int, xy) + dets.append({"class": name, "confidence": conf, + "id": box.id.item() if hasattr(box, "id") else None}) + boxes.append((x1, y1, x2, y2)) + + if not dets: + return (None, None) if return_bbox else None + + # take highest‐confidence + best_idx = max(range(len(dets)), key=lambda i: dets[i]["confidence"]) + best_det = dets[best_idx] + best_box = boxes[best_idx] + + # ─── Branch (classification) ─────────────────────────────── + for br in node["branches"]: + if (best_det["class"] in br["triggerClasses"] + and best_det["confidence"] >= br["minConfidence"]): + # crop if requested + sub = frame + if br["crop"]: + x1,y1,x2,y2 = best_box + sub = frame[y1:y2, x1:x2] + sub = cv2.resize(sub, (224, 224)) + + det2, _ = run_pipeline(sub, br, return_bbox=True) + if det2: + # return classification result + original bbox + return (det2, best_box) if return_bbox else det2 + + # ─── No branch matched → return this detection ───────────── + return (best_det, best_box) if return_bbox else best_det - # Check if we should process branches - if detection is not None: - for branch in node["branches"]: - if detection["class"] in branch.get("triggerClasses", []): - min_conf = branch.get("minConfidence") - if min_conf is not None and detection["confidence"] < min_conf: - logger.debug(f"Confidence {detection['confidence']} below threshold {min_conf} for branch {branch['modelId']}.") - break - - # If we have branches, this is not the last stage - branch_result = run_pipeline(frame, branch, return_bbox, is_last_stage=True) - - # This node is no longer the last stage, so its results shouldn't be returned - is_last_stage = False - - if branch_result is not None: - if return_bbox: - return branch_result - return branch_result - break - - # Return this node's detection only if it's considered the last stage - if is_last_stage: - if return_bbox: - return detection, bbox - return detection - - # No detection or not the last stage - if return_bbox: - return None, None - return None except Exception as e: - logger.error(f"Error running pipeline on node {node.get('modelId')}: {e}") - if return_bbox: - return None, None - return None + logging.error(f"Error in node {node.get('modelId')}: {e}") + return (None, None) if return_bbox else None