182 lines
6.8 KiB
Python
182 lines
6.8 KiB
Python
import json
import logging
import os
import shutil
import zipfile
from urllib.parse import urlparse
from urllib.request import url2pathname

import cv2
import torch
from ultralytics import YOLO
|
||
|
||
def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict:
    """Build one runtime pipeline node, and recursively its branches, from config.

    The weights file named by ``node_config["modelFile"]`` is resolved relative
    to ``mpta_dir`` and loaded with ``YOLO``; the model is moved to CUDA when
    ``torch.cuda.is_available()`` reports a device.

    Args:
        node_config: Node description. ``"modelFile"`` and ``"modelId"`` are
            required; ``"triggerClasses"``, ``"crop"``, ``"minConfidence"`` and
            ``"branches"`` are optional.
        mpta_dir: Directory that contains the model files.

    Returns:
        A dict holding the node settings, the loaded ``"model"``, the class
        indices matching ``"triggerClasses"``, and the loaded child nodes
        under ``"branches"``.

    Raises:
        FileNotFoundError: If the weights file does not exist.
    """
    weights_path = os.path.join(mpta_dir, node_config["modelFile"])
    if not os.path.exists(weights_path):
        logging.error(f"Model file {weights_path} not found.")
        raise FileNotFoundError(f"Model file {weights_path} not found.")

    logging.info(f"Loading model {node_config['modelId']} from {weights_path}")
    detector = YOLO(weights_path)
    if torch.cuda.is_available():
        detector.to("cuda")

    # Translate the configured class names into the model's own class indices
    # (model.names maps index -> class name).
    wanted_names = node_config.get("triggerClasses", [])
    wanted_indices = []
    for class_idx, class_name in detector.names.items():
        if class_name in wanted_names:
            wanted_indices.append(class_idx)

    # Children are loaded after this node's model, depth-first.
    children = []
    for child_config in node_config.get("branches", []):
        children.append(load_pipeline_node(child_config, mpta_dir))

    node = {
        "modelId": node_config["modelId"],
        "modelFile": node_config["modelFile"],
        "triggerClasses": wanted_names,
        "triggerClassIndices": wanted_indices,
        "crop": node_config.get("crop", False),
        "minConfidence": node_config.get("minConfidence", 0.0),
        "model": detector,
        "branches": children,
    }
    return node
|
||
|
||
def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
    """Extract a local ``.mpta`` archive and load the pipeline it describes.

    The archive is copied to ``<target_dir>/pipeline.mpta``, extracted into
    ``target_dir`` and the temporary copy is removed again (also when the
    extraction fails). The archive is expected to contain a top-level
    directory named after the archive file (without extension) holding
    ``pipeline.json`` and the model files.

    Args:
        zip_source: Plain filesystem path or ``file://`` URL of the archive.
            Other URL schemes are not supported.
        target_dir: Directory to extract into; created if missing.

    Returns:
        The root node built by ``load_pipeline_node``, or ``None`` when the
        source is unsupported, the file does not exist, or ``pipeline.json``
        is missing from the archive.

    Raises:
        zipfile.BadZipFile: If the source is not a valid zip archive.
        FileNotFoundError: If a model file referenced by the config is missing.
    """
    os.makedirs(target_dir, exist_ok=True)
    zip_path = os.path.join(target_dir, "pipeline.mpta")

    parsed = urlparse(zip_source)
    scheme = parsed.scheme
    # urlparse reads a Windows drive letter ("C:\\x.mpta") as a one-character
    # scheme; treat that as a plain local path rather than a remote URL.
    is_drive_letter = len(scheme) == 1
    if scheme not in ("", "file") and not is_drive_letter:
        logging.error("HTTP download not supported; use local file.")
        return None

    # The path component of a file:// URL is percent-encoded; decode it into
    # a real filesystem path before touching the disk.
    local = url2pathname(parsed.path) if scheme == "file" else zip_source
    if not os.path.exists(local):
        logging.error(f"Local file {local} does not exist.")
        return None

    shutil.copy(local, zip_path)
    try:
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall(target_dir)
    finally:
        # Never leave the temporary copy behind, even for a corrupt archive.
        os.remove(zip_path)

    # Derive the directory name from the resolved local path so that an
    # encoded URL ("my%20pipe") still maps to the extracted folder ("my pipe").
    base = os.path.splitext(os.path.basename(local))[0]
    mpta_dir = os.path.join(target_dir, base)
    cfg = os.path.join(mpta_dir, "pipeline.json")
    if not os.path.exists(cfg):
        logging.error("pipeline.json not found in archive.")
        return None

    with open(cfg, encoding="utf-8") as f:
        pipeline_config = json.load(f)
    return load_pipeline_node(pipeline_config["pipeline"], mpta_dir)
|
||
|
||
|
||
def _classify_top1(frame, node: dict):
    """Run the node's classifier on *frame*; return its top-1 result dict or None."""
    results = node["model"].predict(frame, stream=False)
    if not results:
        return None

    probs = results[0].probs
    if probs is None:
        return None

    return {
        "class": node["model"].names[int(probs.top1)],
        "confidence": float(probs.top1conf),
        "id": None,
    }


def _detect_best(frame, node: dict):
    """Track on *frame*; return (det, bbox) for the most confident box or (None, None)."""
    trigger_inds = node["triggerClassIndices"]
    # Pass a class filter only when one is configured; an empty list means
    # "no filter", not "no classes".
    track_kwargs = {"classes": trigger_inds} if trigger_inds else {}
    res = node["model"].track(
        frame, stream=False, persist=True, **track_kwargs
    )[0]

    best_det, best_box = None, None
    for box in res.boxes:
        cpu_box = box.cpu()  # move to host once instead of per attribute
        conf = float(cpu_box.conf[0])
        if conf < node["minConfidence"]:
            continue
        # Strict comparison keeps the first box on ties.
        if best_det is not None and conf <= best_det["confidence"]:
            continue

        x1, y1, x2, y2 = map(int, cpu_box.xyxy[0])
        # The id attribute is None while the tracker has not assigned a track
        # ID, so test the value rather than the attribute's existence.
        track_id = getattr(box, "id", None)
        best_det = {
            "class": node["model"].names[int(cpu_box.cls[0])],
            "confidence": conf,
            "id": track_id.item() if track_id is not None else None,
        }
        best_box = (x1, y1, x2, y2)

    return best_det, best_box


def run_pipeline(frame, node: dict, return_bbox: bool = False):
    """Run *frame* through a pipeline node and, where triggered, its branches.

    Classification nodes (``model.task == "classify"``) run ``predict()`` and
    yield the top-1 class with its confidence and no bounding box.

    Detection nodes run ``track()`` restricted to the node's trigger classes
    and keep the most confident box at or above ``minConfidence``. The first
    branch whose ``triggerClasses`` contains that class, and whose
    ``minConfidence`` is met, is run on the frame (cropped to the box and
    resized to 224x224 when the branch has ``crop`` set); if it produces a
    result, that result is returned together with the parent's box. When no
    branch yields a result, the detection itself is returned.

    Args:
        frame: Image to process; presumably a NumPy array as used by OpenCV
            (it is sliced as ``frame[y1:y2, x1:x2]``) — confirm with callers.
        node: Node dict as produced by ``load_pipeline_node``.
        return_bbox: When true, return a ``(detection, bbox)`` tuple instead
            of only the detection.

    Returns:
        A dict with ``"class"``, ``"confidence"`` and ``"id"`` keys, or
        ``None`` when nothing was found or an error occurred. With
        ``return_bbox`` the result is ``(dict_or_None, bbox_or_None)`` where
        ``bbox`` is ``(x1, y1, x2, y2)`` in integer pixels.
    """
    try:
        if getattr(node["model"], "task", None) == "classify":
            det = _classify_top1(frame, node)
            return (det, None) if return_bbox else det

        best_det, best_box = _detect_best(frame, node)
        if best_det is None:
            return (None, None) if return_bbox else None

        for br in node["branches"]:
            if best_det["class"] not in br["triggerClasses"]:
                continue
            if best_det["confidence"] < br["minConfidence"]:
                continue

            sub = frame
            if br["crop"]:
                x1, y1, x2, y2 = best_box
                sub = frame[y1:y2, x1:x2]
                # Integer truncation can yield a zero-area box; resizing an
                # empty image raises, so skip the branch and fall back to the
                # parent detection instead of dropping the frame.
                if getattr(sub, "size", None) == 0:
                    continue
                sub = cv2.resize(sub, (224, 224))

            det2, _ = run_pipeline(sub, br, return_bbox=True)
            if det2:
                # Branch result is reported with the parent's bounding box.
                return (det2, best_box) if return_bbox else det2

        # No branch matched or produced a result: report this detection.
        return (best_det, best_box) if return_bbox else best_det

    except Exception as e:
        # Best-effort per-frame processing: log and report "nothing found".
        logging.error(f"Error in node {node.get('modelId')}: {e}")
        return (None, None) if return_bbox else None
|