import os
import json
import logging
import torch
import cv2
import zipfile
import shutil
import redis
from ultralytics import YOLO
from urllib.parse import urlparse

# Create a logger specifically for this module
logger = logging.getLogger("detector_worker.pympta")

def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client) -> dict:
    """Recursively load a model node (and its branches) from configuration."""
    model_path = os.path.join(mpta_dir, node_config["modelFile"])
    if not os.path.exists(model_path):
        logger.error(f"Model file {model_path} not found. Current directory: {os.getcwd()}")
        logger.error(f"Directory content: {os.listdir(os.path.dirname(model_path))}")
        raise FileNotFoundError(f"Model file {model_path} not found.")
    logger.info(f"Loading model for node {node_config['modelId']} from {model_path}")
    model = YOLO(model_path)
    if torch.cuda.is_available():
        logger.info(f"CUDA available. Moving model {node_config['modelId']} to GPU")
        model.to("cuda")
    else:
        logger.info(f"CUDA not available. Using CPU for model {node_config['modelId']}")

    # Prepare trigger class indices so inference can filter classes up front
    trigger_classes = node_config.get("triggerClasses", [])
    trigger_class_indices = None
    if trigger_classes and hasattr(model, "names"):
        # Convert class names to the model's class indices
        trigger_class_indices = [i for i, name in model.names.items()
                                 if name in trigger_classes]
        logger.debug(f"Converted trigger classes to indices: {trigger_class_indices}")

    node = {
        "modelId": node_config["modelId"],
        "modelFile": node_config["modelFile"],
        "triggerClasses": trigger_classes,
        "triggerClassIndices": trigger_class_indices,
        "crop": node_config.get("crop", False),
        # Default to 0.0 so confidence comparisons never run against None
        "minConfidence": node_config.get("minConfidence", 0.0),
        "actions": node_config.get("actions", []),
        "model": model,
        "branches": [],
        "redis_client": redis_client
    }
    logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}")
    for child in node_config.get("branches", []):
        logger.debug(f"Loading branch for parent node {node_config['modelId']}")
        node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client))
    return node
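
# For illustration only: a hypothetical node configuration of the shape
# load_pipeline_node() consumes. Key names mirror what the loader reads
# above; the model file and class names are placeholders.
#
# example_node_config = {
#     "modelId": "car-detector",
#     "modelFile": "car_detector.pt",
#     "triggerClasses": ["car"],
#     "crop": False,
#     "minConfidence": 0.5,
#     "actions": [],
#     "branches": []
# }
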
def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
    """Extract a local .mpta archive into target_dir and load its pipeline.

    Returns the root pipeline node as a dict, or None on any failure.
    """
    logger.info(f"Attempting to load pipeline from {zip_source} to {target_dir}")
    os.makedirs(target_dir, exist_ok=True)
    zip_path = os.path.join(target_dir, "pipeline.mpta")

    # Parse the source; only local files are supported here.
    parsed = urlparse(zip_source)
    if parsed.scheme in ("", "file"):
        local_path = parsed.path if parsed.scheme == "file" else zip_source
        logger.debug(f"Checking if local file exists: {local_path}")
        if os.path.exists(local_path):
            try:
                shutil.copy(local_path, zip_path)
                logger.info(f"Copied local .mpta file from {local_path} to {zip_path}")
            except Exception as e:
                logger.error(f"Failed to copy local .mpta file from {local_path}: {str(e)}", exc_info=True)
                return None
        else:
            logger.error(f"Local file {local_path} does not exist. Current directory: {os.getcwd()}")
            # List the models directory tree to help debugging
            if os.path.exists("models"):
                logger.error(f"Content of models directory: {os.listdir('models')}")
                for root, dirs, files in os.walk("models"):
                    logger.error(f"Directory {root} contains subdirs: {dirs} and files: {files}")
            else:
                logger.error("The models directory doesn't exist")
            return None
    else:
        logger.error(f"HTTP download functionality has been moved. Use a local file path here. Received: {zip_source}")
        return None

    try:
        if not os.path.exists(zip_path):
            logger.error(f"Zip file not found at expected location: {zip_path}")
            return None

        logger.debug(f"Extracting .mpta file from {zip_path} to {target_dir}")
        # Extract contents and track the top-level directories created
        extracted_dirs = []
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            file_list = zip_ref.namelist()
            logger.debug(f"Files in .mpta archive: {file_list}")

            # Record the top-level directories before extracting
            for file_path in file_list:
                parts = file_path.split('/')
                if len(parts) > 1:
                    top_dir = parts[0]
                    if top_dir and top_dir not in extracted_dirs:
                        extracted_dirs.append(top_dir)

            # Now extract the files
            zip_ref.extractall(target_dir)

        logger.info(f"Successfully extracted .mpta file to {target_dir}")
        logger.debug(f"Extracted directories: {extracted_dirs}")

        # Check what was actually created after extraction
        actual_dirs = [d for d in os.listdir(target_dir) if os.path.isdir(os.path.join(target_dir, d))]
        logger.debug(f"Actual directories created: {actual_dirs}")
    except zipfile.BadZipFile as e:
        logger.error(f"Bad zip file {zip_path}: {str(e)}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"Failed to extract .mpta file {zip_path}: {str(e)}", exc_info=True)
        return None
    finally:
        if os.path.exists(zip_path):
            os.remove(zip_path)
            logger.debug(f"Removed temporary zip file: {zip_path}")

    # Locate the directory containing pipeline.json: prefer the directory
    # named after the archive, otherwise search every extracted directory.
    pipeline_name = os.path.basename(zip_source)
    pipeline_name = os.path.splitext(pipeline_name)[0]

    mpta_dir = None
    # First try the expected directory name
    expected_dir = os.path.join(target_dir, pipeline_name)
    if os.path.exists(expected_dir) and os.path.exists(os.path.join(expected_dir, "pipeline.json")):
        mpta_dir = expected_dir
        logger.debug(f"Found pipeline.json in the expected directory: {mpta_dir}")
    else:
        # Look through all subdirectories for pipeline.json
        for subdir in actual_dirs:
            potential_dir = os.path.join(target_dir, subdir)
            if os.path.exists(os.path.join(potential_dir, "pipeline.json")):
                mpta_dir = potential_dir
                logger.info(f"Found pipeline.json in directory: {mpta_dir} (different from expected: {expected_dir})")
                break

    if not mpta_dir:
        logger.error(f"Could not find pipeline.json in any extracted directory. Directory content: {os.listdir(target_dir)}")
        return None

    pipeline_json_path = os.path.join(mpta_dir, "pipeline.json")
    if not os.path.exists(pipeline_json_path):
        logger.error(f"pipeline.json not found in the .mpta file. Files in directory: {os.listdir(mpta_dir)}")
        return None

    try:
        with open(pipeline_json_path, "r") as f:
            pipeline_config = json.load(f)
            logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}")
            logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}")

            # Establish Redis connection if configured
            redis_client = None
            if "redis" in pipeline_config:
                redis_config = pipeline_config["redis"]
                try:
                    redis_client = redis.Redis(
                        host=redis_config["host"],
                        port=redis_config["port"],
                        password=redis_config.get("password"),
                        db=redis_config.get("db", 0),
                        decode_responses=True
                    )
                    redis_client.ping()
                    logger.info(f"Successfully connected to Redis at {redis_config['host']}:{redis_config['port']}")
                except redis.exceptions.ConnectionError as e:
                    logger.error(f"Failed to connect to Redis: {e}")
                    redis_client = None

            return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client)
    except json.JSONDecodeError as e:
        logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True)
        return None
    except KeyError as e:
        logger.error(f"Missing key in pipeline.json: {str(e)}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True)
        return None
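
# For illustration only: a hypothetical pipeline.json matching the keys this
# loader reads ("redis", "pipeline", node fields, "actions", "branches").
# Hosts, model files, and class names are placeholders.
#
# {
#   "redis": {"host": "localhost", "port": 6379, "db": 0},
#   "pipeline": {
#     "modelId": "detector",
#     "modelFile": "detector.pt",
#     "triggerClasses": ["car"],
#     "minConfidence": 0.5,
#     "actions": [
#       {"type": "redis_save_image", "key": "detection:{class}:{id}"}
#     ],
#     "branches": [
#       {
#         "modelId": "classifier",
#         "modelFile": "classifier.pt",
#         "triggerClasses": ["car"],
#         "minConfidence": 0.6,
#         "crop": true,
#         "actions": [],
#         "branches": []
#       }
#     ]
#   }
# }
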
def execute_actions(node, frame, detection_result):
    """Run the node's configured Redis actions for a detection result."""
    if not node["redis_client"] or not node["actions"]:
        return

    for action in node["actions"]:
        try:
            if action["type"] == "redis_save_image":
                # Keys may contain placeholders filled from the detection dict
                key = action["key"].format(**detection_result)
                _, buffer = cv2.imencode('.jpg', frame)
                node["redis_client"].set(key, buffer.tobytes())
                logger.info(f"Saved image to Redis with key: {key}")
            elif action["type"] == "redis_publish":
                channel = action["channel"]
                message = action["message"].format(**detection_result)
                node["redis_client"].publish(channel, message)
                logger.info(f"Published message to Redis channel '{channel}': {message}")
        except Exception as e:
            logger.error(f"Error executing action {action['type']}: {e}")
def run_pipeline(frame, node: dict, return_bbox: bool = False):
    """
    - For detection nodes (task != 'classify'):
      • runs `track(..., classes=triggerClassIndices)`
      • picks top box ≥ minConfidence
      • optionally crops & resizes → recurse into child
      • else returns (det_dict, bbox)
    - For classify nodes:
      • runs `predict()`
      • returns top (class, confidence) and no bbox
    """
    try:
        task = getattr(node["model"], "task", None)

        # ─── Classification stage ───────────────────────────────────
        if task == "classify":
            # Run the classifier and read its top-1 directly via the Probs API
            results = node["model"].predict(frame, stream=False)
            # Nothing returned?
            if not results:
                return (None, None) if return_bbox else None

            # Take the first result's probs object
            r = results[0]
            probs = r.probs
            if probs is None:
                return (None, None) if return_bbox else None

            # Get the top-1 class index and its confidence
            top1_idx = int(probs.top1)
            top1_conf = float(probs.top1conf)

            det = {
                "class": node["model"].names[top1_idx],
                "confidence": top1_conf,
                "id": None
            }
            execute_actions(node, frame, det)
            return (det, None) if return_bbox else det

        # ─── Detection stage ────────────────────────────────────────
        # Only look for the configured trigger classes
        tk = node["triggerClassIndices"]
        res = node["model"].track(
            frame,
            stream=False,
            persist=True,
            **({"classes": tk} if tk else {})
        )[0]

        dets, boxes = [], []
        for box in res.boxes:
            b = box.cpu()
            conf = float(b.conf[0])
            cid = int(b.cls[0])
            name = node["model"].names[cid]
            if conf < node["minConfidence"]:
                continue
            x1, y1, x2, y2 = map(int, b.xyxy[0])
            # box.id is None when the tracker has not assigned an ID
            dets.append({"class": name, "confidence": conf,
                         "id": box.id.item() if box.id is not None else None})
            boxes.append((x1, y1, x2, y2))

        if not dets:
            return (None, None) if return_bbox else None

        # Take the highest-confidence detection
        best_idx = max(range(len(dets)), key=lambda i: dets[i]["confidence"])
        best_det = dets[best_idx]
        best_box = boxes[best_idx]

        # ─── Branch (classification) ───────────────────────────────
        for br in node["branches"]:
            if (best_det["class"] in br["triggerClasses"]
                    and best_det["confidence"] >= br["minConfidence"]):
                # Crop if requested
                sub = frame
                if br["crop"]:
                    x1, y1, x2, y2 = best_box
                    sub = frame[y1:y2, x1:x2]
                    sub = cv2.resize(sub, (224, 224))

                # The recursive call already executes the branch node's
                # actions, so they are not re-run here.
                det2, _ = run_pipeline(sub, br, return_bbox=True)
                if det2:
                    # Return the classification result with the original bbox
                    return (det2, best_box) if return_bbox else det2

        # ─── No branch matched → return this detection ─────────────
        execute_actions(node, frame, best_det)
        return (best_det, best_box) if return_bbox else best_det

    except Exception as e:
        logger.error(f"Error in node {node.get('modelId')}: {e}", exc_info=True)
        return (None, None) if return_bbox else None
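

# A minimal usage sketch, for illustration only: the .mpta path and image
# path below are placeholders, not files shipped with this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    pipeline = load_pipeline_from_zip("models/pipeline.mpta", "models/extracted")
    if pipeline:
        frame = cv2.imread("sample.jpg")
        if frame is not None:
            detection, bbox = run_pipeline(frame, pipeline, return_bbox=True)
            print(f"Detection: {detection}, bbox: {bbox}")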