# python-detector-worker/siwatsystem/pympta.py

import json
import logging
import os
import shutil
import zipfile
from urllib.parse import urlparse

import cv2
import torch
from ultralytics import YOLO


def load_pipeline_node(node_config: dict, mpta_dir: str) -> dict:
    """Recursively load one pipeline node: its YOLO model, trigger metadata, and child branches."""
    model_path = os.path.join(mpta_dir, node_config["modelFile"])
    if not os.path.exists(model_path):
        logging.error(f"Model file {model_path} not found.")
        raise FileNotFoundError(f"Model file {model_path} not found.")
    logging.info(f"Loading model {node_config['modelId']} from {model_path}")
    model = YOLO(model_path)
    if torch.cuda.is_available():
        model.to("cuda")
    # Map triggerClasses names → class indices expected by YOLO.
    names = model.names  # idx -> class name
    trigger_names = node_config.get("triggerClasses", [])
    trigger_inds = [i for i, nm in names.items() if nm in trigger_names]
    return {
        "modelId": node_config["modelId"],
        "modelFile": node_config["modelFile"],
        "triggerClasses": trigger_names,
        "triggerClassIndices": trigger_inds,
        "crop": node_config.get("crop", False),
        "minConfidence": node_config.get("minConfidence", 0.0),
        "model": model,
        "branches": [
            load_pipeline_node(child, mpta_dir)
            for child in node_config.get("branches", [])
        ],
    }
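
# Illustrative shape of pipeline.json, inferred from the keys read above.
# This is a sketch only; the model file names, class names, and thresholds
# below are hypothetical examples, not taken from an actual archive.
#
# {
#   "pipeline": {
#     "modelId": "detector",
#     "modelFile": "detector.pt",
#     "triggerClasses": ["car"],
#     "crop": true,
#     "minConfidence": 0.5,
#     "branches": [
#       {
#         "modelId": "car_classifier",
#         "modelFile": "classifier.pt",
#         "triggerClasses": ["car"],
#         "crop": true,
#         "minConfidence": 0.6,
#         "branches": []
#       }
#     ]
#   }
# }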


def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
    """Copy a local .mpta archive into target_dir, extract it, and load the pipeline it describes.

    Returns the root pipeline node dict, or None on failure.
    """
    os.makedirs(target_dir, exist_ok=True)
    zip_path = os.path.join(target_dir, "pipeline.mpta")
    parsed = urlparse(zip_source)
    if parsed.scheme in ("", "file"):
        local = parsed.path if parsed.scheme == "file" else zip_source
        if not os.path.exists(local):
            logging.error(f"Local file {local} does not exist.")
            return None
        shutil.copy(local, zip_path)
    else:
        logging.error("HTTP download not supported; use a local file.")
        return None
    with zipfile.ZipFile(zip_path, "r") as z:
        z.extractall(target_dir)
    os.remove(zip_path)
    base = os.path.splitext(os.path.basename(zip_source))[0]
    mpta_dir = os.path.join(target_dir, base)
    cfg = os.path.join(mpta_dir, "pipeline.json")
    if not os.path.exists(cfg):
        logging.error("pipeline.json not found in archive.")
        return None
    with open(cfg) as f:
        pipeline_config = json.load(f)
    return load_pipeline_node(pipeline_config["pipeline"], mpta_dir)
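
# Expected archive layout, inferred from the paths above (file names are
# hypothetical examples): the .mpta zip must contain a top-level directory
# named after the archive, holding pipeline.json and the model weights that
# pipeline.json references.
#
#   pipeline.mpta
#   └── pipeline/
#       ├── pipeline.json
#       ├── detector.pt
#       └── classifier.pt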


def run_pipeline(frame, node: dict, return_bbox: bool = False):
    """
    - For detection nodes (task != 'classify'):
        • runs `track(..., classes=triggerClassIndices)`
        • picks the top box ≥ minConfidence
        • optionally crops & resizes → recurses into the child node
        • otherwise returns (det_dict, bbox)
    - For classify nodes:
        • runs `predict()`
        • returns the top (class, confidence) and no bbox
    """
    try:
        task = getattr(node["model"], "task", None)

        # ─── Classification stage ───────────────────────────────────
        # if task == "classify":
        #     results = node["model"].predict(frame, stream=False)
        #     dets = []
        #     for r in results:
        #         probs = r.probs
        #         if probs is not None:
        #             # sort descending
        #             idxs = probs.argsort(descending=True)
        #             for cid in idxs:
        #                 dets.append({
        #                     "class": node["model"].names[int(cid)],
        #                     "confidence": float(probs[int(cid)]),
        #                     "id": None
        #                 })
        #     if not dets:
        #         return (None, None) if return_bbox else None
        #     best = dets[0]
        #     return (best, None) if return_bbox else best
        if task == "classify":
            # Run the classifier and grab its top-1 directly via the Probs API.
            results = node["model"].predict(frame, stream=False)
            # Nothing returned?
            if not results:
                return (None, None) if return_bbox else None
            # Take the first result's probs object.
            r = results[0]
            probs = r.probs
            if probs is None:
                return (None, None) if return_bbox else None
            # Get the top-1 class index and its confidence.
            top1_idx = int(probs.top1)
            top1_conf = float(probs.top1conf)
            det = {
                "class": node["model"].names[top1_idx],
                "confidence": top1_conf,
                "id": None
            }
            return (det, None) if return_bbox else det

        # ─── Detection stage ────────────────────────────────────────
        # Only look for this node's triggerClasses.
        tk = node["triggerClassIndices"]
        res = node["model"].track(
            frame,
            stream=False,
            persist=True,
            **({"classes": tk} if tk else {})
        )[0]
        dets, boxes = [], []
        for box in res.boxes:
            conf = float(box.conf[0])
            cid = int(box.cls[0])
            name = node["model"].names[cid]
            if conf < node["minConfidence"]:
                continue
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            dets.append({
                "class": name,
                "confidence": conf,
                # box.id exists but is None when the tracker has not assigned an ID,
                # so guard on the value rather than with hasattr().
                "id": box.id.item() if box.id is not None else None,
            })
            boxes.append((x1, y1, x2, y2))
        if not dets:
            return (None, None) if return_bbox else None

        # Take the highest-confidence detection.
        best_idx = max(range(len(dets)), key=lambda i: dets[i]["confidence"])
        best_det = dets[best_idx]
        best_box = boxes[best_idx]

        # ─── Branch (classification) ────────────────────────────────
        for br in node["branches"]:
            if (best_det["class"] in br["triggerClasses"]
                    and best_det["confidence"] >= br["minConfidence"]):
                # Crop if requested.
                sub = frame
                if br["crop"]:
                    x1, y1, x2, y2 = best_box
                    sub = frame[y1:y2, x1:x2]
                    sub = cv2.resize(sub, (224, 224))
                det2, _ = run_pipeline(sub, br, return_bbox=True)
                if det2:
                    # Return the classification result plus the original bbox.
                    return (det2, best_box) if return_bbox else det2

        # ─── No branch matched → return this detection ──────────────
        return (best_det, best_box) if return_bbox else best_det
    except Exception as e:
        logging.error(f"Error in node {node.get('modelId')}: {e}")
        return (None, None) if return_bbox else None
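

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the archive path, extraction directory,
# and test image below are hypothetical; only load_pipeline_from_zip() and
# run_pipeline() come from this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    pipeline = load_pipeline_from_zip("models/pipeline.mpta", "/tmp/pipeline")
    if pipeline is None:
        raise SystemExit("Failed to load pipeline.")
    frame = cv2.imread("sample.jpg")  # hypothetical test frame
    if frame is None:
        raise SystemExit("Could not read test image.")
    det, bbox = run_pipeline(frame, pipeline, return_bbox=True)
    print("detection:", det, "bbox:", bbox)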