feat: integrate Redis support in pipeline execution; add actions for saving images and publishing messages

This commit is contained in:
Siwat Sirichai 2025-07-15 00:30:09 +07:00
parent a1f797f564
commit 769371a1a3
3 changed files with 250 additions and 5 deletions

200
pympta.md Normal file
View file

@ -0,0 +1,200 @@
# pympta: Modular Pipeline Task Executor
`pympta` is a Python module designed to load and execute modular, multi-stage AI pipelines defined in a special package format (`.mpta`). It is primarily used within the detector worker to run complex computer vision tasks where the output of one model can trigger a subsequent model on a specific region of interest.
## Core Concepts
### 1. MPTA Package (`.mpta`)
An `.mpta` file is a standard `.zip` archive with a different extension. It bundles all the necessary components for a pipeline to run.
A typical `.mpta` file has the following structure:
```
my_pipeline.mpta/
├── pipeline.json
├── model1.pt
├── model2.pt
└── ...
```
- **`pipeline.json`**: (Required) The manifest file that defines the structure of the pipeline, the models to use, and the logic connecting them.
- **Model Files (`.pt`, etc.)**: The actual pre-trained model files (e.g., PyTorch, ONNX). The pipeline currently uses `ultralytics.YOLO` models.
### 2. Pipeline Structure
A pipeline is a tree-like structure of "nodes," defined in `pipeline.json`.
- **Root Node**: The entry point of the pipeline. It processes the initial, full-frame image.
- **Branch Nodes**: Child nodes that are triggered by specific detection results from their parent. For example, a root node might detect a "vehicle," which then triggers a branch node to detect a "license plate" within the vehicle's bounding box.
This modular structure allows for creating complex and efficient inference logic, avoiding the need to run every model on every frame.
## `pipeline.json` Specification
This file defines the entire pipeline logic. The root object contains a `pipeline` key for the pipeline definition and an optional `redis` key for Redis configuration.
### Top-Level Object Structure
| Key | Type | Required | Description |
| ---------- | ------ | -------- | ------------------------------------------------------- |
| `pipeline` | Object | Yes | The root node object of the pipeline. |
| `redis` | Object | No | Configuration for connecting to a Redis server. |
### Redis Configuration (`redis`)
| Key | Type | Required | Description |
| ---------- | ------ | -------- | ------------------------------------------------------- |
| `host` | String | Yes | The hostname or IP address of the Redis server. |
| `port` | Number | Yes | The port number of the Redis server. |
| `password` | String | No | The password for Redis authentication. |
| `db` | Number | No | The Redis database number to use. Defaults to `0`. |
### Node Object Structure
| Key | Type | Required | Description |
| ------------------- | ------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------------- |
| `modelId` | String | Yes | A unique identifier for this model node (e.g., "vehicle-detector"). |
| `modelFile` | String | Yes | The path to the model file within the `.mpta` archive (e.g., "yolov8n.pt"). |
| `minConfidence` | Float | Yes | The minimum confidence score (0.0 to 1.0) required for a detection to be considered valid and potentially trigger a branch. |
| `triggerClasses` | Array<String> | Yes | A list of class names that, when detected by the parent, can trigger this node. For the root node, this lists all classes of interest. |
| `crop` | Boolean | No | If `true`, the image is cropped to the parent's detection bounding box before being passed to this node's model. Defaults to `false`. |
| `branches` | Array<Node> | No | A list of child node objects that can be triggered by this node's detections. |
| `actions` | Array<Action> | No | A list of actions to execute upon a successful detection in this node. |
### Action Object Structure
Actions allow the pipeline to interact with Redis.
#### `redis_save_image`
Saves the current image frame (or cropped sub-image) to a Redis key.
| Key | Type | Required | Description |
| ----- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
| `type`| String | Yes | Must be `"redis_save_image"`. |
| `key` | String | Yes | The Redis key to save the image to. Can contain placeholders like `{class}` or `{id}` to be formatted with detection results. |
#### `redis_publish`
Publishes a message to a Redis channel.
| Key | Type | Required | Description |
| --------- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
| `type` | String | Yes | Must be `"redis_publish"`. |
| `channel` | String | Yes | The Redis channel to publish the message to. |
| `message` | String | Yes | The message to publish. Can contain placeholders like `{class}` or `{id}` to be formatted with detection results. |
### Example `pipeline.json` with Redis
```json
{
"redis": {
"host": "localhost",
"port": 6379,
"password": "your-password"
},
"pipeline": {
"modelId": "vehicle-detector",
"modelFile": "vehicle_model.pt",
"minConfidence": 0.5,
"triggerClasses": ["car", "truck"],
"actions": [
{
"type": "redis_save_image",
"key": "detection:image:{id}"
},
{
"type": "redis_publish",
"channel": "detections",
"message": "Detected a {class} with ID {id}"
}
],
"branches": [
{
"modelId": "lpr-us",
"modelFile": "lpr_model.pt",
"minConfidence": 0.7,
"triggerClasses": ["car"],
"crop": true,
"branches": []
}
]
}
}
```
## API Reference
The `pympta` module exposes two main functions.
### `load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict`
Loads, extracts, and parses an `.mpta` file to build a pipeline tree in memory. It also establishes a Redis connection if configured in `pipeline.json`.
- **Parameters:**
- `zip_source` (str): The file path to the local `.mpta` zip archive.
- `target_dir` (str): A directory path where the archive's contents will be extracted.
- **Returns:**
- A dictionary representing the root node of the pipeline, ready to be used with `run_pipeline`. Returns `None` if loading fails.
### `run_pipeline(frame, node: dict, return_bbox: bool = False)`
Executes the inference pipeline on a single image frame.
- **Parameters:**
- `frame`: The input image frame (e.g., a NumPy array from OpenCV).
- `node` (dict): The pipeline node to execute (typically the root node returned by `load_pipeline_from_zip`).
- `return_bbox` (bool): If `True`, the function returns a tuple `(detection, bounding_box)`. Otherwise, it returns only the `detection`.
- **Returns:**
- The final detection result from the last executed node in the chain. A detection is a dictionary like `{'class': 'car', 'confidence': 0.95, 'id': 1}`. If no detection meets the criteria, it returns `None` (or `(None, None)` if `return_bbox` is `True`).
## Usage Example
This snippet, inspired by `pipeline_webcam.py`, shows how to use `pympta` to load a pipeline and process an image from a webcam.
```python
import cv2
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
# 1. Define paths
MPTA_FILE = "path/to/your/pipeline.mpta"
CACHE_DIR = ".mptacache"
# 2. Load the pipeline from the .mpta file
# This reads pipeline.json and loads the YOLO models into memory.
model_tree = load_pipeline_from_zip(MPTA_FILE, CACHE_DIR)
if not model_tree:
print("Failed to load pipeline.")
exit()
# 3. Open a video source
cap = cv2.VideoCapture(0)
while True:
ret, frame = cap.read()
if not ret:
break
# 4. Run the pipeline on the current frame
# The function will handle the entire logic tree (e.g., find a car, then find its license plate).
detection_result, bounding_box = run_pipeline(frame, model_tree, return_bbox=True)
# 5. Display the results
if detection_result:
print(f"Detected: {detection_result['class']} with confidence {detection_result['confidence']:.2f}")
if bounding_box:
x1, y1, x2, y2 = bounding_box
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(frame, detection_result['class'], (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
cv2.imshow("Pipeline Output", frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
```

View file

@ -5,4 +5,5 @@ torchvision
ultralytics ultralytics
opencv-python opencv-python
websockets websockets
fastapi[standard] fastapi[standard]
redis

View file

@ -7,13 +7,14 @@ import requests
import zipfile import zipfile
import shutil import shutil
import traceback import traceback
import redis
from ultralytics import YOLO from ultralytics import YOLO
from urllib.parse import urlparse from urllib.parse import urlparse
# Create a logger specifically for this module # Create a logger specifically for this module
logger = logging.getLogger("detector_worker.pympta") logger = logging.getLogger("detector_worker.pympta")
def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict: def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client) -> dict:
# Recursively load a model node from configuration. # Recursively load a model node from configuration.
model_path = os.path.join(mpta_dir, node_config["modelFile"]) model_path = os.path.join(mpta_dir, node_config["modelFile"])
if not os.path.exists(model_path): if not os.path.exists(model_path):
@ -44,13 +45,15 @@ def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict:
"triggerClassIndices": trigger_class_indices, "triggerClassIndices": trigger_class_indices,
"crop": node_config.get("crop", False), "crop": node_config.get("crop", False),
"minConfidence": node_config.get("minConfidence", None), "minConfidence": node_config.get("minConfidence", None),
"actions": node_config.get("actions", []),
"model": model, "model": model,
"branches": [] "branches": [],
"redis_client": redis_client
} }
logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}") logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}")
for child in node_config.get("branches", []): for child in node_config.get("branches", []):
logger.debug(f"Loading branch for parent node {node_config['modelId']}") logger.debug(f"Loading branch for parent node {node_config['modelId']}")
node["branches"].append(load_pipeline_node(child, mpta_dir)) node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client))
return node return node
def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict: def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
@ -158,7 +161,26 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
pipeline_config = json.load(f) pipeline_config = json.load(f)
logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}") logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}")
logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}") logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}")
return load_pipeline_node(pipeline_config["pipeline"], mpta_dir)
# Establish Redis connection if configured
redis_client = None
if "redis" in pipeline_config:
redis_config = pipeline_config["redis"]
try:
redis_client = redis.Redis(
host=redis_config["host"],
port=redis_config["port"],
password=redis_config.get("password"),
db=redis_config.get("db", 0),
decode_responses=True
)
redis_client.ping()
logger.info(f"Successfully connected to Redis at {redis_config['host']}:{redis_config['port']}")
except redis.exceptions.ConnectionError as e:
logger.error(f"Failed to connect to Redis: {e}")
redis_client = None
return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True) logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True)
return None return None
@ -169,6 +191,25 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True) logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True)
return None return None
def execute_actions(node, frame, detection_result):
if not node["redis_client"] or not node["actions"]:
return
for action in node["actions"]:
try:
if action["type"] == "redis_save_image":
key = action["key"].format(**detection_result)
_, buffer = cv2.imencode('.jpg', frame)
node["redis_client"].set(key, buffer.tobytes())
logger.info(f"Saved image to Redis with key: {key}")
elif action["type"] == "redis_publish":
channel = action["channel"]
message = action["message"].format(**detection_result)
node["redis_client"].publish(channel, message)
logger.info(f"Published message to Redis channel '{channel}': {message}")
except Exception as e:
logger.error(f"Error executing action {action['type']}: {e}")
def run_pipeline(frame, node: dict, return_bbox: bool=False): def run_pipeline(frame, node: dict, return_bbox: bool=False):
""" """
- For detection nodes (task != 'classify'): - For detection nodes (task != 'classify'):
@ -206,6 +247,7 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False):
"confidence": top1_conf, "confidence": top1_conf,
"id": None "id": None
} }
execute_actions(node, frame, det)
return (det, None) if return_bbox else det return (det, None) if return_bbox else det
@ -254,9 +296,11 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False):
det2, _ = run_pipeline(sub, br, return_bbox=True) det2, _ = run_pipeline(sub, br, return_bbox=True)
if det2: if det2:
# return classification result + original bbox # return classification result + original bbox
execute_actions(br, sub, det2)
return (det2, best_box) if return_bbox else det2 return (det2, best_box) if return_bbox else det2
# ─── No branch matched → return this detection ───────────── # ─── No branch matched → return this detection ─────────────
execute_actions(node, frame, best_det)
return (best_det, best_box) if return_bbox else best_det return (best_det, best_box) if return_bbox else best_det
except Exception as e: except Exception as e: