Compare commits

..

4 commits

6 changed files with 646 additions and 87 deletions

285
app.py
View file

@ -29,6 +29,12 @@ app = FastAPI()
# "models" now holds a nested dict: { camera_id: { modelId: model_tree } } # "models" now holds a nested dict: { camera_id: { modelId: model_tree } }
models: Dict[str, Dict[str, Any]] = {} models: Dict[str, Dict[str, Any]] = {}
streams: Dict[str, Dict[str, Any]] = {} streams: Dict[str, Dict[str, Any]] = {}
# Store session IDs per display
session_ids: Dict[str, int] = {}
# Track shared camera streams by camera URL
camera_streams: Dict[str, Dict[str, Any]] = {}
# Map subscriptions to their camera URL
subscription_to_camera: Dict[str, str] = {}
with open("config.json", "r") as f: with open("config.json", "r") as f:
config = json.load(f) config = json.load(f)
@ -184,9 +190,16 @@ async def detect(websocket: WebSocket):
async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data): async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data):
try: try:
# Apply crop if specified
cropped_frame = frame
if all(coord is not None for coord in [stream.get("cropX1"), stream.get("cropY1"), stream.get("cropX2"), stream.get("cropY2")]):
cropX1, cropY1, cropX2, cropY2 = stream["cropX1"], stream["cropY1"], stream["cropX2"], stream["cropY2"]
cropped_frame = frame[cropY1:cropY2, cropX1:cropX2]
logger.debug(f"Applied crop coordinates ({cropX1}, {cropY1}, {cropX2}, {cropY2}) to frame for camera {camera_id}")
logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}") logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}")
start_time = time.time() start_time = time.time()
detection_result = run_pipeline(frame, model_tree) detection_result = run_pipeline(cropped_frame, model_tree)
process_time = (time.time() - start_time) * 1000 process_time = (time.time() - start_time) * 1000
logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms") logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms")
@ -235,22 +248,48 @@ async def detect(websocket: WebSocket):
"box": [0, 0, 0, 0] "box": [0, 0, 0, 0]
} }
# Convert detection format to match protocol - flatten detection attributes
detection_dict = {}
# Handle different detection result formats
if isinstance(highest_confidence_detection, dict):
# Copy all fields from the detection result
for key, value in highest_confidence_detection.items():
if key not in ["box", "id"]: # Skip internal fields
detection_dict[key] = value
# Extract display identifier for session ID lookup
subscription_parts = stream["subscriptionIdentifier"].split(';')
display_identifier = subscription_parts[0] if subscription_parts else None
session_id = session_ids.get(display_identifier) if display_identifier else None
detection_data = { detection_data = {
"type": "imageDetection", "type": "imageDetection",
"subscriptionIdentifier": stream["subscriptionIdentifier"], "subscriptionIdentifier": stream["subscriptionIdentifier"],
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()),
"data": { "data": {
"detection": highest_confidence_detection, # Send only the highest confidence detection "detection": detection_dict,
"modelId": stream["modelId"], "modelId": stream["modelId"],
"modelName": stream["modelName"] "modelName": stream["modelName"]
} }
} }
# Add session ID if available
if session_id is not None:
detection_data["sessionId"] = session_id
if highest_confidence_detection["class"] != "none": if highest_confidence_detection["class"] != "none":
logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}") logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}")
# Log session ID if available
subscription_parts = stream["subscriptionIdentifier"].split(';')
display_identifier = subscription_parts[0] if subscription_parts else None
session_id = session_ids.get(display_identifier) if display_identifier else None
if session_id:
logger.debug(f"Detection associated with session ID: {session_id}")
await websocket.send_json(detection_data) await websocket.send_json(detection_data)
logger.debug(f"Sent detection data to client for camera {camera_id}:\n{json.dumps(detection_data, indent=2)}") logger.debug(f"Sent detection data to client for camera {camera_id}")
return persistent_data return persistent_data
except Exception as e: except Exception as e:
logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True) logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True)
@ -521,50 +560,58 @@ async def detect(websocket: WebSocket):
cropX2 = payload.get("cropX2") cropX2 = payload.get("cropX2")
cropY2 = payload.get("cropY2") cropY2 = payload.get("cropY2")
camera_id = subscriptionIdentifier # Use subscriptionIdentifier as camera_id for mapping # Extract camera_id from subscriptionIdentifier (format: displayIdentifier;cameraIdentifier)
parts = subscriptionIdentifier.split(';')
if len(parts) != 2:
logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}")
continue
display_identifier, camera_identifier = parts
camera_id = subscriptionIdentifier # Use full subscriptionIdentifier as camera_id for mapping
if model_url: if model_url:
with models_lock: with models_lock:
if (camera_id not in models) or (modelId not in models[camera_id]): if (camera_id not in models) or (modelId not in models[camera_id]):
logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}") logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
extraction_dir = os.path.join("models", camera_id, str(modelId)) extraction_dir = os.path.join("models", camera_identifier, str(modelId))
os.makedirs(extraction_dir, exist_ok=True) os.makedirs(extraction_dir, exist_ok=True)
# If model_url is remote, download it first. # If model_url is remote, download it first.
parsed = urlparse(model_url) parsed = urlparse(model_url)
if parsed.scheme in ("http", "https"): if parsed.scheme in ("http", "https"):
logger.info(f"Downloading remote model from {model_url}") logger.info(f"Downloading remote .mpta file from {model_url}")
local_mpta = os.path.join(extraction_dir, os.path.basename(parsed.path)) filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
local_mpta = os.path.join(extraction_dir, filename)
logger.debug(f"Download destination: {local_mpta}") logger.debug(f"Download destination: {local_mpta}")
local_path = download_mpta(model_url, local_mpta) local_path = download_mpta(model_url, local_mpta)
if not local_path: if not local_path:
logger.error(f"Failed to download the remote mpta file from {model_url}") logger.error(f"Failed to download the remote .mpta file from {model_url}")
error_response = { error_response = {
"type": "error", "type": "error",
"cameraIdentifier": camera_id, "subscriptionIdentifier": subscriptionIdentifier,
"error": f"Failed to download model from {model_url}" "error": f"Failed to download model from {model_url}"
} }
await websocket.send_json(error_response) await websocket.send_json(error_response)
continue continue
model_tree = load_pipeline_from_zip(local_path, extraction_dir) model_tree = load_pipeline_from_zip(local_path, extraction_dir)
else: else:
logger.info(f"Loading local model from {model_url}") logger.info(f"Loading local .mpta file from {model_url}")
# Check if file exists before attempting to load # Check if file exists before attempting to load
if not os.path.exists(model_url): if not os.path.exists(model_url):
logger.error(f"Local model file not found: {model_url}") logger.error(f"Local .mpta file not found: {model_url}")
logger.debug(f"Current working directory: {os.getcwd()}") logger.debug(f"Current working directory: {os.getcwd()}")
error_response = { error_response = {
"type": "error", "type": "error",
"cameraIdentifier": camera_id, "subscriptionIdentifier": subscriptionIdentifier,
"error": f"Model file not found: {model_url}" "error": f"Model file not found: {model_url}"
} }
await websocket.send_json(error_response) await websocket.send_json(error_response)
continue continue
model_tree = load_pipeline_from_zip(model_url, extraction_dir) model_tree = load_pipeline_from_zip(model_url, extraction_dir)
if model_tree is None: if model_tree is None:
logger.error(f"Failed to load model {modelId} from mpta file for camera {camera_id}") logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}")
error_response = { error_response = {
"type": "error", "type": "error",
"cameraIdentifier": camera_id, "subscriptionIdentifier": subscriptionIdentifier,
"error": f"Failed to load model {modelId}" "error": f"Failed to load model {modelId}"
} }
await websocket.send_json(error_response) await websocket.send_json(error_response)
@ -573,20 +620,80 @@ async def detect(websocket: WebSocket):
models[camera_id] = {} models[camera_id] = {}
models[camera_id][modelId] = model_tree models[camera_id][modelId] = model_tree
logger.info(f"Successfully loaded model {modelId} for camera {camera_id}") logger.info(f"Successfully loaded model {modelId} for camera {camera_id}")
success_response = { logger.debug(f"Model extraction directory: {extraction_dir}")
"type": "modelLoaded",
"cameraIdentifier": camera_id,
"modelId": modelId
}
await websocket.send_json(success_response)
if camera_id and (rtsp_url or snapshot_url): if camera_id and (rtsp_url or snapshot_url):
with streams_lock: with streams_lock:
# Determine camera URL for shared stream management
camera_url = snapshot_url if snapshot_url else rtsp_url
if camera_id not in streams and len(streams) < max_streams: if camera_id not in streams and len(streams) < max_streams:
# Check if we already have a stream for this camera URL
shared_stream = camera_streams.get(camera_url)
if shared_stream:
# Reuse existing stream
logger.info(f"Reusing existing stream for camera URL: {camera_url}")
buffer = shared_stream["buffer"]
stop_event = shared_stream["stop_event"]
thread = shared_stream["thread"]
mode = shared_stream["mode"]
# Increment reference count
shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
else:
# Create new stream
buffer = queue.Queue(maxsize=1) buffer = queue.Queue(maxsize=1)
stop_event = threading.Event() stop_event = threading.Event()
if snapshot_url and snapshot_interval:
logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
thread = threading.Thread(target=snapshot_reader, args=(camera_identifier, snapshot_url, snapshot_interval, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "snapshot"
# Store shared stream info
shared_stream = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"mode": mode,
"url": snapshot_url,
"snapshot_interval": snapshot_interval,
"ref_count": 1
}
camera_streams[camera_url] = shared_stream
elif rtsp_url:
logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}")
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logger.error(f"Failed to open RTSP stream for camera {camera_id}")
continue
thread = threading.Thread(target=frame_reader, args=(camera_identifier, cap, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "rtsp"
# Store shared stream info
shared_stream = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"mode": mode,
"url": rtsp_url,
"cap": cap,
"ref_count": 1
}
camera_streams[camera_url] = shared_stream
else:
logger.error(f"No valid URL provided for camera {camera_id}")
continue
# Create stream info for this subscription
stream_info = { stream_info = {
"buffer": buffer, "buffer": buffer,
"thread": None, "thread": thread,
"stop_event": stop_event, "stop_event": stop_event,
"modelId": modelId, "modelId": modelId,
"modelName": modelName, "modelName": modelName,
@ -594,52 +701,25 @@ async def detect(websocket: WebSocket):
"cropX1": cropX1, "cropX1": cropX1,
"cropY1": cropY1, "cropY1": cropY1,
"cropX2": cropX2, "cropX2": cropX2,
"cropY2": cropY2 "cropY2": cropY2,
"mode": mode,
"camera_url": camera_url
} }
if snapshot_url and snapshot_interval:
logger.info(f"Using snapshot mode for camera {camera_id}: {snapshot_url}") if mode == "snapshot":
thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event)) stream_info["snapshot_url"] = snapshot_url
thread.daemon = True stream_info["snapshot_interval"] = snapshot_interval
thread.start() elif mode == "rtsp":
stream_info.update({ stream_info["rtsp_url"] = rtsp_url
"snapshot_url": snapshot_url, stream_info["cap"] = shared_stream["cap"]
"snapshot_interval": snapshot_interval,
"mode": "snapshot"
})
stream_info["thread"] = thread
streams[camera_id] = stream_info streams[camera_id] = stream_info
elif rtsp_url: subscription_to_camera[camera_id] = camera_url
logger.info(f"Using RTSP mode for camera {camera_id}: {rtsp_url}")
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logger.error(f"Failed to open RTSP stream for camera {camera_id}")
continue
thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
thread.daemon = True
thread.start()
stream_info.update({
"cap": cap,
"rtsp_url": rtsp_url,
"mode": "rtsp"
})
stream_info["thread"] = thread
streams[camera_id] = stream_info
else:
logger.error(f"No valid URL provided for camera {camera_id}")
continue
elif camera_id and camera_id in streams: elif camera_id and camera_id in streams:
# If already subscribed, unsubscribe first # If already subscribed, unsubscribe first
stream = streams.pop(camera_id) logger.info(f"Resubscribing to camera {camera_id}")
stream["stop_event"].set() # Note: Keep models in memory for reuse across subscriptions
stream["thread"].join()
if "cap" in stream:
stream["cap"].release()
logger.info(f"Unsubscribed from camera {camera_id} for resubscription")
with models_lock:
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
elif msg_type == "unsubscribe": elif msg_type == "unsubscribe":
payload = data.get("payload", {}) payload = data.get("payload", {})
subscriptionIdentifier = payload.get("subscriptionIdentifier") subscriptionIdentifier = payload.get("subscriptionIdentifier")
@ -647,13 +727,25 @@ async def detect(websocket: WebSocket):
with streams_lock: with streams_lock:
if camera_id and camera_id in streams: if camera_id and camera_id in streams:
stream = streams.pop(camera_id) stream = streams.pop(camera_id)
stream["stop_event"].set() camera_url = subscription_to_camera.pop(camera_id, None)
stream["thread"].join()
if "cap" in stream: if camera_url and camera_url in camera_streams:
stream["cap"].release() shared_stream = camera_streams[camera_url]
with models_lock: shared_stream["ref_count"] -= 1
if camera_id in models:
del models[camera_id] # If no more references, stop the shared stream
if shared_stream["ref_count"] <= 0:
logger.info(f"Stopping shared stream for camera URL: {camera_url}")
shared_stream["stop_event"].set()
shared_stream["thread"].join()
if "cap" in shared_stream:
shared_stream["cap"].release()
del camera_streams[camera_url]
else:
logger.info(f"Shared stream for {camera_url} still has {shared_stream['ref_count']} references")
logger.info(f"Unsubscribed from camera {camera_id}")
# Note: Keep models in memory for potential reuse
elif msg_type == "requestState": elif msg_type == "requestState":
cpu_usage = psutil.cpu_percent() cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent memory_usage = psutil.virtual_memory().percent
@ -684,6 +776,37 @@ async def detect(websocket: WebSocket):
"cameraConnections": camera_connections "cameraConnections": camera_connections
} }
await websocket.send_text(json.dumps(state_report)) await websocket.send_text(json.dumps(state_report))
elif msg_type == "setSessionId":
payload = data.get("payload", {})
display_identifier = payload.get("displayIdentifier")
session_id = payload.get("sessionId")
if display_identifier:
# Store session ID for this display
if session_id is None:
session_ids.pop(display_identifier, None)
logger.info(f"Cleared session ID for display {display_identifier}")
else:
session_ids[display_identifier] = session_id
logger.info(f"Set session ID {session_id} for display {display_identifier}")
elif msg_type == "patchSession":
session_id = data.get("sessionId")
patch_data = data.get("data", {})
# For now, just acknowledge the patch - actual implementation depends on backend requirements
response = {
"type": "patchSessionResult",
"payload": {
"sessionId": session_id,
"success": True,
"message": "Session patch acknowledged"
}
}
await websocket.send_json(response)
logger.info(f"Acknowledged patch for session {session_id}")
else: else:
logger.error(f"Unknown message type: {msg_type}") logger.error(f"Unknown message type: {msg_type}")
except json.JSONDecodeError: except json.JSONDecodeError:
@ -706,19 +829,23 @@ async def detect(websocket: WebSocket):
stream_task.cancel() stream_task.cancel()
await stream_task await stream_task
with streams_lock: with streams_lock:
for camera_id, stream in streams.items(): # Clean up shared camera streams
stream["stop_event"].set() for camera_url, shared_stream in camera_streams.items():
stream["thread"].join() shared_stream["stop_event"].set()
# Only release cap if it exists (RTSP mode) shared_stream["thread"].join()
if "cap" in stream: if "cap" in shared_stream:
stream["cap"].release() shared_stream["cap"].release()
while not stream["buffer"].empty(): while not shared_stream["buffer"].empty():
try: try:
stream["buffer"].get_nowait() shared_stream["buffer"].get_nowait()
except queue.Empty: except queue.Empty:
pass pass
logger.info(f"Released camera {camera_id} and cleaned up resources") logger.info(f"Released shared camera stream for {camera_url}")
streams.clear() streams.clear()
camera_streams.clear()
subscription_to_camera.clear()
with models_lock: with models_lock:
models.clear() models.clear()
session_ids.clear()
logger.info("WebSocket connection closed") logger.info("WebSocket connection closed")

204
pympta.md Normal file
View file

@ -0,0 +1,204 @@
# pympta: Modular Pipeline Task Executor
`pympta` is a Python module designed to load and execute modular, multi-stage AI pipelines defined in a special package format (`.mpta`). It is primarily used within the detector worker to run complex computer vision tasks where the output of one model can trigger a subsequent model on a specific region of interest.
## Core Concepts
### 1. MPTA Package (`.mpta`)
An `.mpta` file is a standard `.zip` archive with a different extension. It bundles all the necessary components for a pipeline to run.
A typical `.mpta` file has the following structure:
```
my_pipeline.mpta/
├── pipeline.json
├── model1.pt
├── model2.pt
└── ...
```
- **`pipeline.json`**: (Required) The manifest file that defines the structure of the pipeline, the models to use, and the logic connecting them.
- **Model Files (`.pt`, etc.)**: The actual pre-trained model files (e.g., PyTorch, ONNX). The pipeline currently uses `ultralytics.YOLO` models.
### 2. Pipeline Structure
A pipeline is a tree-like structure of "nodes," defined in `pipeline.json`.
- **Root Node**: The entry point of the pipeline. It processes the initial, full-frame image.
- **Branch Nodes**: Child nodes that are triggered by specific detection results from their parent. For example, a root node might detect a "vehicle," which then triggers a branch node to detect a "license plate" within the vehicle's bounding box.
This modular structure allows for creating complex and efficient inference logic, avoiding the need to run every model on every frame.
## `pipeline.json` Specification
This file defines the entire pipeline logic. The root object contains a `pipeline` key for the pipeline definition and an optional `redis` key for Redis configuration.
### Top-Level Object Structure
| Key | Type | Required | Description |
| ---------- | ------ | -------- | ------------------------------------------------------- |
| `pipeline` | Object | Yes | The root node object of the pipeline. |
| `redis` | Object | No | Configuration for connecting to a Redis server. |
### Redis Configuration (`redis`)
| Key | Type | Required | Description |
| ---------- | ------ | -------- | ------------------------------------------------------- |
| `host` | String | Yes | The hostname or IP address of the Redis server. |
| `port` | Number | Yes | The port number of the Redis server. |
| `password` | String | No | The password for Redis authentication. |
| `db` | Number | No | The Redis database number to use. Defaults to `0`. |
### Node Object Structure
| Key | Type | Required | Description |
| ------------------- | ------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------------- |
| `modelId` | String | Yes | A unique identifier for this model node (e.g., "vehicle-detector"). |
| `modelFile` | String | Yes | The path to the model file within the `.mpta` archive (e.g., "yolov8n.pt"). |
| `minConfidence` | Float | Yes | The minimum confidence score (0.0 to 1.0) required for a detection to be considered valid and potentially trigger a branch. |
| `triggerClasses` | Array<String> | Yes | A list of class names that, when detected by the parent, can trigger this node. For the root node, this lists all classes of interest. |
| `crop` | Boolean | No | If `true`, the image is cropped to the parent's detection bounding box before being passed to this node's model. Defaults to `false`. |
| `branches` | Array<Node> | No | A list of child node objects that can be triggered by this node's detections. |
| `actions` | Array<Action> | No | A list of actions to execute upon a successful detection in this node. |
### Action Object Structure
Actions allow the pipeline to interact with Redis. They are executed sequentially for a given detection.
#### Action Context & Dynamic Keys
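All actions have access to a dynamic context for formatting keys and messages. Placeholders are expanded with Python's `str.format`, so any literal `{` or `}` in a key or message (for example, the braces of a JSON payload) must be doubled as `{{` and `}}`. The context is created for each detection event and includes: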
- All key-value pairs from the detection result (e.g., `class`, `confidence`, `id`).
- `{timestamp_ms}`: The current Unix timestamp in milliseconds.
- `{uuid}`: A unique identifier (UUID4) for the detection event.
- `{image_key}`: If a `redis_save_image` action has already been executed for this event, this placeholder will be replaced with the key where the image was stored.
#### `redis_save_image`
Saves the current image frame (or cropped sub-image) to a Redis key.
| Key | Type | Required | Description |
| ---------------- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
| `type` | String | Yes | Must be `"redis_save_image"`. |
| `key` | String | Yes | The Redis key to save the image to. Can contain any of the dynamic placeholders. |
| `expire_seconds` | Number | No | If provided, sets an expiration time (in seconds) for the Redis key. |
#### `redis_publish`
Publishes a message to a Redis channel.
| Key | Type | Required | Description |
| --------- | ------ | -------- | ------------------------------------------------------------------------------------------------------- |
| `type` | String | Yes | Must be `"redis_publish"`. |
| `channel` | String | Yes | The Redis channel to publish the message to. |
| `message` | String | Yes | The message to publish. Can contain any of the dynamic placeholders, including `{image_key}`. |
### Example `pipeline.json` with Redis
This example demonstrates a pipeline that detects vehicles, saves a uniquely named image of each detection that expires in one hour, and then publishes a notification with the image key.
```json
{
"redis": {
"host": "redis.local",
"port": 6379,
"password": "your-super-secret-password"
},
"pipeline": {
"modelId": "vehicle-detector",
"modelFile": "vehicle_model.pt",
"minConfidence": 0.6,
"triggerClasses": ["car", "truck"],
"actions": [
{
"type": "redis_save_image",
"key": "detections:{class}:{timestamp_ms}:{uuid}",
"expire_seconds": 3600
},
{
"type": "redis_publish",
"channel": "vehicle_events",
"message": "{{\"event\":\"new_detection\",\"class\":\"{class}\",\"confidence\":{confidence},\"image_key\":\"{image_key}\"}}"
}
],
"branches": []
}
}
```
## API Reference
The `pympta` module exposes two main functions.
### `load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict`
Loads, extracts, and parses an `.mpta` file to build a pipeline tree in memory. It also establishes a Redis connection if configured in `pipeline.json`.
- **Parameters:**
- `zip_source` (str): The file path to the local `.mpta` zip archive.
- `target_dir` (str): A directory path where the archive's contents will be extracted.
- **Returns:**
- A dictionary representing the root node of the pipeline, ready to be used with `run_pipeline`. Returns `None` if loading fails.
### `run_pipeline(frame, node: dict, return_bbox: bool = False)`
Executes the inference pipeline on a single image frame.
- **Parameters:**
- `frame`: The input image frame (e.g., a NumPy array from OpenCV).
- `node` (dict): The pipeline node to execute (typically the root node returned by `load_pipeline_from_zip`).
- `return_bbox` (bool): If `True`, the function returns a tuple `(detection, bounding_box)`. Otherwise, it returns only the `detection`.
- **Returns:**
- The final detection result from the last executed node in the chain. A detection is a dictionary like `{'class': 'car', 'confidence': 0.95, 'id': 1}`. If no detection meets the criteria, it returns `None` (or `(None, None)` if `return_bbox` is `True`).
## Usage Example
This snippet, inspired by `pipeline_webcam.py`, shows how to use `pympta` to load a pipeline and process an image from a webcam.
```python
import cv2
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
# 1. Define paths
MPTA_FILE = "path/to/your/pipeline.mpta"
CACHE_DIR = ".mptacache"
# 2. Load the pipeline from the .mpta file
# This reads pipeline.json and loads the YOLO models into memory.
model_tree = load_pipeline_from_zip(MPTA_FILE, CACHE_DIR)
if not model_tree:
print("Failed to load pipeline.")
exit()
# 3. Open a video source
cap = cv2.VideoCapture(0)
while True:
ret, frame = cap.read()
if not ret:
break
# 4. Run the pipeline on the current frame
# The function will handle the entire logic tree (e.g., find a car, then find its license plate).
detection_result, bounding_box = run_pipeline(frame, model_tree, return_bbox=True)
# 5. Display the results
if detection_result:
print(f"Detected: {detection_result['class']} with confidence {detection_result['confidence']:.2f}")
if bounding_box:
x1, y1, x2, y2 = bounding_box
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(frame, detection_result['class'], (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
cv2.imshow("Pipeline Output", frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
```

View file

@ -6,3 +6,4 @@ ultralytics
opencv-python opencv-python
websockets websockets
fastapi[standard] fastapi[standard]
redis

View file

@ -7,13 +7,16 @@ import requests
import zipfile import zipfile
import shutil import shutil
import traceback import traceback
import redis
import time
import uuid
from ultralytics import YOLO from ultralytics import YOLO
from urllib.parse import urlparse from urllib.parse import urlparse
# Create a logger specifically for this module # Create a logger specifically for this module
logger = logging.getLogger("detector_worker.pympta") logger = logging.getLogger("detector_worker.pympta")
def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict: def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client) -> dict:
# Recursively load a model node from configuration. # Recursively load a model node from configuration.
model_path = os.path.join(mpta_dir, node_config["modelFile"]) model_path = os.path.join(mpta_dir, node_config["modelFile"])
if not os.path.exists(model_path): if not os.path.exists(model_path):
@ -44,13 +47,15 @@ def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict:
"triggerClassIndices": trigger_class_indices, "triggerClassIndices": trigger_class_indices,
"crop": node_config.get("crop", False), "crop": node_config.get("crop", False),
"minConfidence": node_config.get("minConfidence", None), "minConfidence": node_config.get("minConfidence", None),
"actions": node_config.get("actions", []),
"model": model, "model": model,
"branches": [] "branches": [],
"redis_client": redis_client
} }
logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}") logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}")
for child in node_config.get("branches", []): for child in node_config.get("branches", []):
logger.debug(f"Loading branch for parent node {node_config['modelId']}") logger.debug(f"Loading branch for parent node {node_config['modelId']}")
node["branches"].append(load_pipeline_node(child, mpta_dir)) node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client))
return node return node
def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict: def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
@ -158,7 +163,26 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
pipeline_config = json.load(f) pipeline_config = json.load(f)
logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}") logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}")
logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}") logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}")
return load_pipeline_node(pipeline_config["pipeline"], mpta_dir)
# Establish Redis connection if configured
redis_client = None
if "redis" in pipeline_config:
redis_config = pipeline_config["redis"]
try:
redis_client = redis.Redis(
host=redis_config["host"],
port=redis_config["port"],
password=redis_config.get("password"),
db=redis_config.get("db", 0),
decode_responses=True
)
redis_client.ping()
logger.info(f"Successfully connected to Redis at {redis_config['host']}:{redis_config['port']}")
except redis.exceptions.ConnectionError as e:
logger.error(f"Failed to connect to Redis: {e}")
redis_client = None
return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True) logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True)
return None return None
@ -169,6 +193,39 @@ def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True) logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True)
return None return None
def execute_actions(node, frame, detection_result):
    """Execute the Redis actions configured on a pipeline node.

    Builds a per-detection context (the detection's own fields plus
    ``timestamp_ms`` and ``uuid``) and runs each configured action in order.
    After a ``redis_save_image`` action succeeds, the stored key is added to
    the context as ``image_key`` so later actions can reference it.

    Args:
        node: Pipeline node dict; its ``redis_client`` and ``actions``
            entries are read.
        frame: Image handed to ``cv2.imencode`` for ``redis_save_image``.
        detection_result: Mapping of detection fields used to fill the
            ``str.format`` placeholders in keys and messages.

    Returns:
        None. A failure in one action is logged and does not stop the
        remaining actions.
    """
    redis_client = node.get("redis_client")
    actions = node.get("actions")
    if not redis_client or not actions:
        return

    # Create a dynamic context for this detection event
    action_context = {
        **detection_result,
        "timestamp_ms": int(time.time() * 1000),
        "uuid": str(uuid.uuid4()),
    }

    for action in actions:
        # Resolve the type up front so the error handler below never has to
        # index the action again (a missing "type" used to raise a second
        # KeyError from inside the except block and abort the whole loop).
        action_type = action.get("type") if isinstance(action, dict) else None
        try:
            if action_type == "redis_save_image":
                key = action["key"].format(**action_context)
                encoded_ok, buffer = cv2.imencode('.jpg', frame)
                if not encoded_ok:
                    # Do not store an invalid buffer under a valid-looking key.
                    raise ValueError("failed to encode frame as JPEG")
                image_bytes = buffer.tobytes()
                expire_seconds = action.get("expire_seconds")
                if expire_seconds:
                    redis_client.setex(key, expire_seconds, image_bytes)
                    logger.info(f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)")
                else:
                    redis_client.set(key, image_bytes)
                    logger.info(f"Saved image to Redis with key: {key}")
                # Add the generated key to the context for subsequent actions
                action_context["image_key"] = key
            elif action_type == "redis_publish":
                channel = action["channel"]
                message = action["message"].format(**action_context)
                redis_client.publish(channel, message)
                logger.info(f"Published message to Redis channel '{channel}': {message}")
            else:
                logger.warning(f"Skipping action with unsupported type: {action_type!r}")
        except Exception as e:
            # Best-effort: one failing action must not block the others.
            logger.error(f"Error executing action {action_type}: {e}")
def run_pipeline(frame, node: dict, return_bbox: bool=False): def run_pipeline(frame, node: dict, return_bbox: bool=False):
""" """
- For detection nodes (task != 'classify'): - For detection nodes (task != 'classify'):
@ -206,6 +263,7 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False):
"confidence": top1_conf, "confidence": top1_conf,
"id": None "id": None
} }
execute_actions(node, frame, det)
return (det, None) if return_bbox else det return (det, None) if return_bbox else det
@ -254,9 +312,11 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False):
det2, _ = run_pipeline(sub, br, return_bbox=True) det2, _ = run_pipeline(sub, br, return_bbox=True)
if det2: if det2:
# return classification result + original bbox # return classification result + original bbox
execute_actions(br, sub, det2)
return (det2, best_box) if return_bbox else det2 return (det2, best_box) if return_bbox else det2
# ─── No branch matched → return this detection ───────────── # ─── No branch matched → return this detection ─────────────
execute_actions(node, frame, best_det)
return (best_det, best_box) if return_bbox else best_det return (best_det, best_box) if return_bbox else best_det
except Exception as e: except Exception as e:

125
test_protocol.py Normal file
View file

@ -0,0 +1,125 @@
#!/usr/bin/env python3
"""
Test script to verify the worker implementation follows the protocol
"""
import json
import asyncio
import websockets
import time
async def _recv_json(websocket, timeout):
    """Wait up to *timeout* seconds for one message and decode it as JSON.

    Raises ``asyncio.TimeoutError`` when nothing arrives in time.
    """
    message = await asyncio.wait_for(websocket.recv(), timeout=timeout)
    return json.loads(message)


async def test_protocol():
    """Test the worker protocol implementation.

    Connects to the worker at ``ws://localhost:8000`` and walks through the
    heartbeat, ``requestState``, ``setSessionId``, ``patchSession`` and
    ``subscribe`` messages, printing a ✓/✗ line for each check. Failed checks
    are counted so the closing summary reflects what actually happened
    instead of always claiming success.
    """
    uri = "ws://localhost:8000"
    failed_checks = 0

    try:
        async with websockets.connect(uri) as websocket:
            print("✓ Connected to worker")

            # Test 1: Check if we receive heartbeat (stateReport)
            print("\n1. Testing heartbeat...")
            try:
                data = await _recv_json(websocket, timeout=5)
                if data.get("type") == "stateReport":
                    print("✓ Received stateReport heartbeat")
                    print(f" - CPU Usage: {data.get('cpuUsage', 'N/A')}%")
                    print(f" - Memory Usage: {data.get('memoryUsage', 'N/A')}%")
                    print(f" - Camera Connections: {len(data.get('cameraConnections', []))}")
                else:
                    failed_checks += 1
                    print(f"✗ Expected stateReport, got {data.get('type')}")
            except asyncio.TimeoutError:
                failed_checks += 1
                print("✗ No heartbeat received within 5 seconds")

            # Test 2: Request state
            print("\n2. Testing requestState...")
            await websocket.send(json.dumps({"type": "requestState"}))
            try:
                data = await _recv_json(websocket, timeout=5)
                if data.get("type") == "stateReport":
                    print("✓ Received stateReport response")
                else:
                    failed_checks += 1
                    print(f"✗ Expected stateReport, got {data.get('type')}")
            except asyncio.TimeoutError:
                failed_checks += 1
                print("✗ No response to requestState within 5 seconds")

            # Test 3: Set session ID (fire-and-forget, no reply is awaited)
            print("\n3. Testing setSessionId...")
            session_message = {
                "type": "setSessionId",
                "payload": {
                    "displayIdentifier": "display-001",
                    "sessionId": 12345
                }
            }
            await websocket.send(json.dumps(session_message))
            print("✓ Sent setSessionId message")

            # Test 4: Test patchSession
            print("\n4. Testing patchSession...")
            patch_message = {
                "type": "patchSession",
                "sessionId": 12345,
                "data": {
                    "currentCar": {
                        "carModel": "Civic",
                        "carBrand": "Honda"
                    }
                }
            }
            await websocket.send(json.dumps(patch_message))

            # Wait for patchSessionResult
            try:
                data = await _recv_json(websocket, timeout=5)
                if data.get("type") == "patchSessionResult":
                    print("✓ Received patchSessionResult")
                    print(f" - Success: {data.get('payload', {}).get('success')}")
                    print(f" - Message: {data.get('payload', {}).get('message')}")
                else:
                    failed_checks += 1
                    print(f"✗ Expected patchSessionResult, got {data.get('type')}")
            except asyncio.TimeoutError:
                failed_checks += 1
                print("✗ No patchSessionResult received within 5 seconds")

            # Test 5: Test subscribe message format (without actual camera)
            print("\n5. Testing subscribe message format...")
            subscribe_message = {
                "type": "subscribe",
                "payload": {
                    "subscriptionIdentifier": "display-001;cam-001",
                    "snapshotUrl": "http://example.com/snapshot.jpg",
                    "snapshotInterval": 5000,
                    "modelUrl": "http://example.com/model.mpta",
                    "modelName": "Test Model",
                    "modelId": 101,
                    "cropX1": 100,
                    "cropY1": 200,
                    "cropX2": 300,
                    "cropY2": 400
                }
            }
            await websocket.send(json.dumps(subscribe_message))
            print("✓ Sent subscribe message (will fail without actual camera/model)")

            # Listen for a few more messages to catch any errors. Errors seen
            # here are only reported, not counted: the subscribe above is
            # expected to fail without a real camera/model.
            print("\n6. Listening for additional messages...")
            for _ in range(3):
                try:
                    data = await _recv_json(websocket, timeout=2)
                    msg_type = data.get("type")
                    print(f" - Received {msg_type}")
                    if msg_type == "error":
                        print(f" Error: {data.get('error')}")
                except asyncio.TimeoutError:
                    break

            if failed_checks:
                print(f"\n✗ Protocol test completed with {failed_checks} failed check(s)")
            else:
                print("\n✓ Protocol test completed successfully!")

    except Exception as e:
        print(f"✗ Connection failed: {e}")
        print("Make sure the worker is running on localhost:8000")
# Run the protocol check only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(test_protocol())

View file

@ -439,3 +439,45 @@ This section shows a typical sequence of messages between the backend and the wo
"cameraConnections": [] "cameraConnections": []
} }
``` ```
## 7. HTTP API: Image Retrieval
In addition to the WebSocket protocol, the worker exposes an HTTP endpoint for retrieving the latest image frame from a camera.
### Endpoint
```
GET /camera/{camera_id}/image
```
- **`camera_id`**: The full `subscriptionIdentifier` (e.g., `display-001;cam-001`).
### Response
- **Success (200):** Returns the latest JPEG image from the camera stream.
- `Content-Type: image/jpeg`
- Binary JPEG data.
- **Error (404):** Returned if the camera is not found or no frame is available.
- JSON error response.
- **Error (500):** Internal server error.
### Example Request
```
GET /camera/display-001;cam-001/image
```
### Example Response
- **Headers:**
```
Content-Type: image/jpeg
```
- **Body:** Binary JPEG image.
### Notes
- The endpoint returns the most recent frame available for the specified camera subscription.
- If multiple displays share the same camera, each subscription has its own buffer; the endpoint uses the buffer for the given `camera_id`.
- This API is useful for debugging, monitoring, or integrating with external systems that require direct image access.