feat: integrate Redis support in pipeline execution; add actions for saving images and publishing messages
This commit is contained in:
		
							parent
							
								
									a1f797f564
								
							
						
					
					
						commit
						769371a1a3
					
				
					 3 changed files with 250 additions and 5 deletions
				
			
		
							
								
								
									
										200
									
								
								pympta.md
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										200
									
								
								pympta.md
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,200 @@
 | 
				
			||||||
 | 
					# pympta: Modular Pipeline Task Executor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					`pympta` is a Python module designed to load and execute modular, multi-stage AI pipelines defined in a special package format (`.mpta`). It is primarily used within the detector worker to run complex computer vision tasks where the output of one model can trigger a subsequent model on a specific region of interest.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Core Concepts
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 1. MPTA Package (`.mpta`)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					An `.mpta` file is a standard `.zip` archive with a different extension. It bundles all the necessary components for a pipeline to run.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					A typical `.mpta` file has the following structure:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					my_pipeline.mpta/
 | 
				
			||||||
 | 
					├── pipeline.json
 | 
				
			||||||
 | 
					├── model1.pt
 | 
				
			||||||
 | 
					├── model2.pt
 | 
				
			||||||
 | 
					└── ...
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- **`pipeline.json`**: (Required) The manifest file that defines the structure of the pipeline, the models to use, and the logic connecting them.
 | 
				
			||||||
 | 
					- **Model Files (`.pt`, etc.)**: The actual pre-trained model files (e.g., PyTorch, ONNX). The pipeline currently uses `ultralytics.YOLO` models.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 2. Pipeline Structure
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					A pipeline is a tree-like structure of "nodes," defined in `pipeline.json`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- **Root Node**: The entry point of the pipeline. It processes the initial, full-frame image.
 | 
				
			||||||
 | 
					- **Branch Nodes**: Child nodes that are triggered by specific detection results from their parent. For example, a root node might detect a "vehicle," which then triggers a branch node to detect a "license plate" within the vehicle's bounding box.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This modular structure allows for creating complex and efficient inference logic, avoiding the need to run every model on every frame.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## `pipeline.json` Specification
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This file defines the entire pipeline logic. The root object contains a `pipeline` key for the pipeline definition and an optional `redis` key for Redis configuration.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Top-Level Object Structure
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					| Key        | Type   | Required | Description                                             |
 | 
				
			||||||
 | 
					| ---------- | ------ | -------- | ------------------------------------------------------- |
 | 
				
			||||||
 | 
					| `pipeline` | Object | Yes      | The root node object of the pipeline.                   |
 | 
				
			||||||
 | 
					| `redis`    | Object | No       | Configuration for connecting to a Redis server.         |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Redis Configuration (`redis`)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					| Key        | Type   | Required | Description                                             |
 | 
				
			||||||
 | 
					| ---------- | ------ | -------- | ------------------------------------------------------- |
 | 
				
			||||||
 | 
					| `host`     | String | Yes      | The hostname or IP address of the Redis server.         |
 | 
				
			||||||
 | 
					| `port`     | Number | Yes      | The port number of the Redis server.                    |
 | 
				
			||||||
 | 
					| `password` | String | No       | The password for Redis authentication.                  |
 | 
				
			||||||
 | 
					| `db`       | Number | No       | The Redis database number to use. Defaults to `0`.      |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Node Object Structure
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					| Key                 | Type          | Required | Description                                                                                                                            |
 | 
				
			||||||
 | 
					| ------------------- | ------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------------- |
 | 
				
			||||||
 | 
					| `modelId`           | String        | Yes      | A unique identifier for this model node (e.g., "vehicle-detector").                                                                    |
 | 
				
			||||||
 | 
					| `modelFile`         | String        | Yes      | The path to the model file within the `.mpta` archive (e.g., "yolov8n.pt").                                                            |
 | 
				
			||||||
 | 
					| `minConfidence`     | Float         | Yes      | The minimum confidence score (0.0 to 1.0) required for a detection to be considered valid and potentially trigger a branch.              |
 | 
				
			||||||
 | 
					| `triggerClasses`    | Array<String> | Yes      | A list of class names that, when detected by the parent, can trigger this node. For the root node, this lists all classes of interest. |
 | 
				
			||||||
 | 
					| `crop`              | Boolean       | No       | If `true`, the image is cropped to the parent's detection bounding box before being passed to this node's model. Defaults to `false`.    |
 | 
				
			||||||
 | 
					| `branches`          | Array<Node>   | No       | A list of child node objects that can be triggered by this node's detections.                                                          |
 | 
				
			||||||
 | 
					| `actions`           | Array<Action> | No       | A list of actions to execute upon a successful detection in this node.                                                                 |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Action Object Structure
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Actions allow the pipeline to interact with Redis.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### `redis_save_image`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Saves the current image frame (or cropped sub-image) to a Redis key.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					| Key   | Type   | Required | Description                                                                                             |
 | 
				
			||||||
 | 
					| ----- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
 | 
				
			||||||
 | 
					| `type`| String | Yes      | Must be `"redis_save_image"`.                                                                           |
 | 
				
			||||||
 | 
					| `key` | String | Yes      | The Redis key to save the image to. Can contain placeholders like `{class}` or `{id}` to be formatted with detection results. |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### `redis_publish`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Publishes a message to a Redis channel.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					| Key       | Type   | Required | Description                                                                                             |
 | 
				
			||||||
 | 
					| --------- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
 | 
				
			||||||
 | 
					| `type`    | String | Yes      | Must be `"redis_publish"`.                                                                              |
 | 
				
			||||||
 | 
					| `channel` | String | Yes      | The Redis channel to publish the message to.                                                            |
 | 
				
			||||||
 | 
					| `message` | String | Yes      | The message to publish. Can contain placeholders like `{class}` or `{id}` to be formatted with detection results. |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Example `pipeline.json` with Redis
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "redis": {
 | 
				
			||||||
 | 
					    "host": "localhost",
 | 
				
			||||||
 | 
					    "port": 6379,
 | 
				
			||||||
 | 
					    "password": "your-password"
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  "pipeline": {
 | 
				
			||||||
 | 
					    "modelId": "vehicle-detector",
 | 
				
			||||||
 | 
					    "modelFile": "vehicle_model.pt",
 | 
				
			||||||
 | 
					    "minConfidence": 0.5,
 | 
				
			||||||
 | 
					    "triggerClasses": ["car", "truck"],
 | 
				
			||||||
 | 
					    "actions": [
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        "type": "redis_save_image",
 | 
				
			||||||
 | 
					        "key": "detection:image:{id}"
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        "type": "redis_publish",
 | 
				
			||||||
 | 
					        "channel": "detections",
 | 
				
			||||||
 | 
					        "message": "Detected a {class} with ID {id}"
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    "branches": [
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        "modelId": "lpr-us",
 | 
				
			||||||
 | 
					        "modelFile": "lpr_model.pt",
 | 
				
			||||||
 | 
					        "minConfidence": 0.7,
 | 
				
			||||||
 | 
					        "triggerClasses": ["car"],
 | 
				
			||||||
 | 
					        "crop": true,
 | 
				
			||||||
 | 
					        "branches": []
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## API Reference
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The `pympta` module exposes two main functions.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### `load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Loads, extracts, and parses an `.mpta` file to build a pipeline tree in memory. It also establishes a Redis connection if configured in `pipeline.json`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- **Parameters:**
 | 
				
			||||||
 | 
					  - `zip_source` (str): The file path to the local `.mpta` zip archive.
 | 
				
			||||||
 | 
					  - `target_dir` (str): A directory path where the archive's contents will be extracted.
 | 
				
			||||||
 | 
					- **Returns:**
 | 
				
			||||||
 | 
					  - A dictionary representing the root node of the pipeline, ready to be used with `run_pipeline`. Returns `None` if loading fails.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### `run_pipeline(frame, node: dict, return_bbox: bool = False)`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Executes the inference pipeline on a single image frame.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- **Parameters:**
 | 
				
			||||||
 | 
					  - `frame`: The input image frame (e.g., a NumPy array from OpenCV).
 | 
				
			||||||
 | 
					  - `node` (dict): The pipeline node to execute (typically the root node returned by `load_pipeline_from_zip`).
 | 
				
			||||||
 | 
					  - `return_bbox` (bool): If `True`, the function returns a tuple `(detection, bounding_box)`. Otherwise, it returns only the `detection`.
 | 
				
			||||||
 | 
					- **Returns:**
 | 
				
			||||||
 | 
					  - The final detection result from the last executed node in the chain. A detection is a dictionary like `{'class': 'car', 'confidence': 0.95, 'id': 1}`. If no detection meets the criteria, it returns `None` (or `(None, None)` if `return_bbox` is `True`).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Usage Example
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This snippet, inspired by `pipeline_webcam.py`, shows how to use `pympta` to load a pipeline and process an image from a webcam.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```python
 | 
				
			||||||
 | 
					import cv2
 | 
				
			||||||
 | 
					from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 1. Define paths
 | 
				
			||||||
 | 
					MPTA_FILE = "path/to/your/pipeline.mpta"
 | 
				
			||||||
 | 
					CACHE_DIR = ".mptacache"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 2. Load the pipeline from the .mpta file
 | 
				
			||||||
 | 
					# This reads pipeline.json and loads the YOLO models into memory.
 | 
				
			||||||
 | 
					model_tree = load_pipeline_from_zip(MPTA_FILE, CACHE_DIR)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if not model_tree:
 | 
				
			||||||
 | 
					    print("Failed to load pipeline.")
 | 
				
			||||||
 | 
					    exit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 3. Open a video source
 | 
				
			||||||
 | 
					cap = cv2.VideoCapture(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					while True:
 | 
				
			||||||
 | 
					    ret, frame = cap.read()
 | 
				
			||||||
 | 
					    if not ret:
 | 
				
			||||||
 | 
					        break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 4. Run the pipeline on the current frame
 | 
				
			||||||
 | 
					    # The function will handle the entire logic tree (e.g., find a car, then find its license plate).
 | 
				
			||||||
 | 
					    detection_result, bounding_box = run_pipeline(frame, model_tree, return_bbox=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 5. Display the results
 | 
				
			||||||
 | 
					    if detection_result:
 | 
				
			||||||
 | 
					        print(f"Detected: {detection_result['class']} with confidence {detection_result['confidence']:.2f}")
 | 
				
			||||||
 | 
					        if bounding_box:
 | 
				
			||||||
 | 
					            x1, y1, x2, y2 = bounding_box
 | 
				
			||||||
 | 
					            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
 | 
				
			||||||
 | 
					            cv2.putText(frame, detection_result['class'], (x1, y1 - 10),
 | 
				
			||||||
 | 
					                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cv2.imshow("Pipeline Output", frame)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if cv2.waitKey(1) & 0xFF == ord('q'):
 | 
				
			||||||
 | 
					        break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cap.release()
 | 
				
			||||||
 | 
					cv2.destroyAllWindows()
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
| 
						 | 
					@ -5,4 +5,5 @@ torchvision
 | 
				
			||||||
ultralytics
 | 
					ultralytics
 | 
				
			||||||
opencv-python
 | 
					opencv-python
 | 
				
			||||||
websockets
 | 
					websockets
 | 
				
			||||||
fastapi[standard]
 | 
					fastapi[standard]
 | 
				
			||||||
 | 
					redis
 | 
				
			||||||
| 
						 | 
					@ -7,13 +7,14 @@ import requests
 | 
				
			||||||
import zipfile
 | 
					import zipfile
 | 
				
			||||||
import shutil
 | 
					import shutil
 | 
				
			||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
 | 
					import redis
 | 
				
			||||||
from ultralytics import YOLO
 | 
					from ultralytics import YOLO
 | 
				
			||||||
from urllib.parse import urlparse
 | 
					from urllib.parse import urlparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Create a logger specifically for this module
 | 
					# Create a logger specifically for this module
 | 
				
			||||||
logger = logging.getLogger("detector_worker.pympta")
 | 
					logger = logging.getLogger("detector_worker.pympta")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict:
 | 
					def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client) -> dict:
 | 
				
			||||||
    # Recursively load a model node from configuration.
 | 
					    # Recursively load a model node from configuration.
 | 
				
			||||||
    model_path = os.path.join(mpta_dir, node_config["modelFile"])
 | 
					    model_path = os.path.join(mpta_dir, node_config["modelFile"])
 | 
				
			||||||
    if not os.path.exists(model_path):
 | 
					    if not os.path.exists(model_path):
 | 
				
			||||||
| 
						 | 
					@ -44,13 +45,15 @@ def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict:
 | 
				
			||||||
        "triggerClassIndices": trigger_class_indices,
 | 
					        "triggerClassIndices": trigger_class_indices,
 | 
				
			||||||
        "crop": node_config.get("crop", False),
 | 
					        "crop": node_config.get("crop", False),
 | 
				
			||||||
        "minConfidence": node_config.get("minConfidence", None),
 | 
					        "minConfidence": node_config.get("minConfidence", None),
 | 
				
			||||||
 | 
					        "actions": node_config.get("actions", []),
 | 
				
			||||||
        "model": model,
 | 
					        "model": model,
 | 
				
			||||||
        "branches": []
 | 
					        "branches": [],
 | 
				
			||||||
 | 
					        "redis_client": redis_client
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}")
 | 
					    logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}")
 | 
				
			||||||
    for child in node_config.get("branches", []):
 | 
					    for child in node_config.get("branches", []):
 | 
				
			||||||
        logger.debug(f"Loading branch for parent node {node_config['modelId']}")
 | 
					        logger.debug(f"Loading branch for parent node {node_config['modelId']}")
 | 
				
			||||||
        node["branches"].append(load_pipeline_node(child, mpta_dir))
 | 
					        node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client))
 | 
				
			||||||
    return node
 | 
					    return node
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
 | 
					def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
 | 
				
			||||||
| 
						 | 
					@ -158,7 +161,26 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
 | 
				
			||||||
            pipeline_config = json.load(f)
 | 
					            pipeline_config = json.load(f)
 | 
				
			||||||
        logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}")
 | 
					        logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}")
 | 
				
			||||||
        logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}")
 | 
					        logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}")
 | 
				
			||||||
        return load_pipeline_node(pipeline_config["pipeline"], mpta_dir)
 | 
					        
 | 
				
			||||||
 | 
					        # Establish Redis connection if configured
 | 
				
			||||||
 | 
					        redis_client = None
 | 
				
			||||||
 | 
					        if "redis" in pipeline_config:
 | 
				
			||||||
 | 
					            redis_config = pipeline_config["redis"]
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                redis_client = redis.Redis(
 | 
				
			||||||
 | 
					                    host=redis_config["host"],
 | 
				
			||||||
 | 
					                    port=redis_config["port"],
 | 
				
			||||||
 | 
					                    password=redis_config.get("password"),
 | 
				
			||||||
 | 
					                    db=redis_config.get("db", 0),
 | 
				
			||||||
 | 
					                    decode_responses=True
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                redis_client.ping()
 | 
				
			||||||
 | 
					                logger.info(f"Successfully connected to Redis at {redis_config['host']}:{redis_config['port']}")
 | 
				
			||||||
 | 
					            except redis.exceptions.ConnectionError as e:
 | 
				
			||||||
 | 
					                logger.error(f"Failed to connect to Redis: {e}")
 | 
				
			||||||
 | 
					                redis_client = None
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client)
 | 
				
			||||||
    except json.JSONDecodeError as e:
 | 
					    except json.JSONDecodeError as e:
 | 
				
			||||||
        logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True)
 | 
					        logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True)
 | 
				
			||||||
        return None
 | 
					        return None
 | 
				
			||||||
| 
						 | 
					@ -169,6 +191,25 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
 | 
				
			||||||
        logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True)
 | 
					        logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True)
 | 
				
			||||||
        return None
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def execute_actions(node, frame, detection_result):
 | 
				
			||||||
 | 
					    if not node["redis_client"] or not node["actions"]:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for action in node["actions"]:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            if action["type"] == "redis_save_image":
 | 
				
			||||||
 | 
					                key = action["key"].format(**detection_result)
 | 
				
			||||||
 | 
					                _, buffer = cv2.imencode('.jpg', frame)
 | 
				
			||||||
 | 
					                node["redis_client"].set(key, buffer.tobytes())
 | 
				
			||||||
 | 
					                logger.info(f"Saved image to Redis with key: {key}")
 | 
				
			||||||
 | 
					            elif action["type"] == "redis_publish":
 | 
				
			||||||
 | 
					                channel = action["channel"]
 | 
				
			||||||
 | 
					                message = action["message"].format(**detection_result)
 | 
				
			||||||
 | 
					                node["redis_client"].publish(channel, message)
 | 
				
			||||||
 | 
					                logger.info(f"Published message to Redis channel '{channel}': {message}")
 | 
				
			||||||
 | 
					        except Exception as e:
 | 
				
			||||||
 | 
					            logger.error(f"Error executing action {action['type']}: {e}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run_pipeline(frame, node: dict, return_bbox: bool=False):
 | 
					def run_pipeline(frame, node: dict, return_bbox: bool=False):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    - For detection nodes (task != 'classify'):
 | 
					    - For detection nodes (task != 'classify'):
 | 
				
			||||||
| 
						 | 
					@ -206,6 +247,7 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False):
 | 
				
			||||||
                "confidence": top1_conf,
 | 
					                "confidence": top1_conf,
 | 
				
			||||||
                "id": None
 | 
					                "id": None
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					            execute_actions(node, frame, det)
 | 
				
			||||||
            return (det, None) if return_bbox else det
 | 
					            return (det, None) if return_bbox else det
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -254,9 +296,11 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False):
 | 
				
			||||||
                det2, _ = run_pipeline(sub, br, return_bbox=True)
 | 
					                det2, _ = run_pipeline(sub, br, return_bbox=True)
 | 
				
			||||||
                if det2:
 | 
					                if det2:
 | 
				
			||||||
                    # return classification result + original bbox
 | 
					                    # return classification result + original bbox
 | 
				
			||||||
 | 
					                    execute_actions(br, sub, det2)
 | 
				
			||||||
                    return (det2, best_box) if return_bbox else det2
 | 
					                    return (det2, best_box) if return_bbox else det2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # ─── No branch matched → return this detection ─────────────
 | 
					        # ─── No branch matched → return this detection ─────────────
 | 
				
			||||||
 | 
					        execute_actions(node, frame, best_det)
 | 
				
			||||||
        return (best_det, best_box) if return_bbox else best_det
 | 
					        return (best_det, best_box) if return_bbox else best_det
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    except Exception as e:
 | 
					    except Exception as e:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue