diff --git a/app.py b/app.py
index aace69c..ce80974 100644
--- a/app.py
+++ b/app.py
@@ -18,6 +18,9 @@ from fastapi.websockets import WebSocketDisconnect
 from websockets.exceptions import ConnectionClosedError
 from ultralytics import YOLO
 
+# Import shared pipeline functions
+from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
+
 app = FastAPI()
 
 # Global dictionaries to keep track of models and streams
@@ -57,161 +60,6 @@ WORKER_TIMEOUT_MS = 10000
 streams_lock = threading.Lock()
 models_lock = threading.Lock()
 
-####################################################
-# Pipeline (Model)-loading helper functions
-####################################################
-def load_pipeline_node(node_config: dict, models_dir: str) -> dict:
-    """
-    Recursively load a model node.
-    Expects node_config to have:
-      - modelId: a unique identifier
-      - modelFile: the .pt file in models_dir
-      - triggerClasses: list of class names that activate child branches
-      - crop: boolean; if True, we crop to the bounding box for the next model
-      - minConfidence: (optional) minimum confidence required to enter this branch
-      - branches: list of child node configurations
-    """
-    model_path = os.path.join(models_dir, node_config["modelFile"])
-    if not os.path.exists(model_path):
-        logging.error(f"Model file {model_path} not found.")
-        raise FileNotFoundError(f"Model file {model_path} not found.")
-
-    logging.info(f"Loading model for node {node_config['modelId']} from {model_path}")
-    model = YOLO(model_path)
-    if torch.cuda.is_available():
-        model.to("cuda")
-
-    node = {
-        "modelId": node_config["modelId"],
-        "modelFile": node_config["modelFile"],
-        "triggerClasses": node_config.get("triggerClasses", []),
-        "crop": node_config.get("crop", False),
-        "minConfidence": node_config.get("minConfidence", None),  # NEW FIELD
-        "model": model,
-        "branches": []
-    }
-    for child_config in node_config.get("branches", []):
-        child_node = load_pipeline_node(child_config, models_dir)
-        node["branches"].append(child_node)
-    return node
-
-def load_pipeline_from_zip(zip_url: str, target_dir: str) -> dict:
-    """
-    Download the .mpta file from zip_url, extract it to target_dir,
-    and load the pipeline configuration (pipeline.json).
-    Returns the model tree (root node) loaded with YOLO models.
- """ - os.makedirs(target_dir, exist_ok=True) - zip_path = os.path.join(target_dir, "pipeline.mpta") - - try: - response = requests.get(zip_url, stream=True) - if response.status_code == 200: - with open(zip_path, "wb") as f: - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - logging.info(f"Downloaded .mpta file from {zip_url} to {zip_path}") - else: - logging.error(f"Failed to download .mpta file (status {response.status_code})") - return None - except Exception as e: - logging.error(f"Exception downloading .mpta file from {zip_url}: {e}") - return None - - # Extract the .mpta file - try: - with zipfile.ZipFile(zip_path, "r") as zip_ref: - zip_ref.extractall(target_dir) - logging.info(f"Extracted .mpta file to {target_dir}") - except Exception as e: - logging.error(f"Failed to extract .mpta file: {e}") - return None - finally: - if os.path.exists(zip_path): - os.remove(zip_path) - - # Load pipeline.json - pipeline_json_path = os.path.join(target_dir, "pipeline.json") - if not os.path.exists(pipeline_json_path): - logging.error("pipeline.json not found in the .mpta file") - return None - - try: - with open(pipeline_json_path, "r") as f: - pipeline_config = json.load(f) - # Build the model tree recursively - model_tree = load_pipeline_node(pipeline_config["pipeline"], target_dir) - return model_tree - except Exception as e: - logging.error(f"Error loading pipeline.json: {e}") - return None - -#################################################### -# Model execution function -#################################################### -def run_pipeline(frame, node: dict): - """ - Run the model at the current node. - - Select the highest-confidence detection (if any). - - If 'crop' is True, crop to the bounding box for the next stage. - - If the detected class matches a branch's triggerClasses, check the confidence. - If the detection's confidence is below branch["minConfidence"] (if specified), - do not enter the branch and return the current detection. - Returns the final detection result (dict) or None. - """ - try: - results = node["model"].track(frame, stream=False, persist=True) - detection = None - max_conf = -1 - best_box = None - - for r in results: - for box in r.boxes: - box_cpu = box.cpu() - conf = float(box_cpu.conf[0]) - if conf > max_conf and hasattr(box, "id") and box.id is not None: - max_conf = conf - detection = { - "class": node["model"].names[int(box_cpu.cls[0])], - "confidence": conf, - "id": box.id.item(), - } - best_box = box_cpu - - # If there's a detection and crop is True, crop frame to bounding box - if detection and node.get("crop", False) and best_box is not None: - coords = best_box.xyxy[0] # [x1, y1, x2, y2] - x1, y1, x2, y2 = map(int, coords) - h, w = frame.shape[:2] - x1 = max(0, x1) - y1 = max(0, y1) - x2 = min(w, x2) - y2 = min(h, y2) - - if x2 > x1 and y2 > y1: - frame = frame[y1:y2, x1:x2] # crop the frame - - if detection is not None: - # Check if any branch should be entered based on trigger classes - for branch in node["branches"]: - if detection["class"] in branch.get("triggerClasses", []): - # Check for a minimum confidence threshold for this branch - min_conf = branch.get("minConfidence") - if min_conf is not None and detection["confidence"] < min_conf: - logging.debug( - f"Detection confidence {detection['confidence']} below threshold " - f"{min_conf} for branch {branch['modelId']}. Ending pipeline at current node." 
-                        )
-                        return detection
-                    branch_detection = run_pipeline(frame, branch)
-                    if branch_detection is not None:
-                        return branch_detection
-            return detection
-        return None
-    except Exception as e:
-        logging.error(f"Error running pipeline on node {node.get('modelId')}: {e}")
-        return None
-
 ####################################################
 # Detection and frame processing functions
 ####################################################
diff --git a/pipeline_webcam.py b/pipeline_webcam.py
new file mode 100644
index 0000000..d13af2a
--- /dev/null
+++ b/pipeline_webcam.py
@@ -0,0 +1,52 @@
+import argparse
+import os
+import cv2
+import time
+import logging
+
+from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
+
+logging.basicConfig(level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(message)s")
+
+def main(mpta_url: str, video_source: str):
+    extraction_dir = os.path.join("models", "webcam_pipeline")
+    logging.info(f"Loading pipeline from {mpta_url}")
+    model_tree = load_pipeline_from_zip(mpta_url, extraction_dir)
+    if model_tree is None:
+        logging.error("Failed to load pipeline.")
+        return
+
+    cap = cv2.VideoCapture(video_source)
+    if not cap.isOpened():
+        logging.error(f"Cannot open video source {video_source}")
+        return
+
+    logging.info("Press 'q' to exit.")
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            logging.error("Failed to capture frame.")
+            break
+
+        detection, bbox = run_pipeline(frame, model_tree, return_bbox=True)
+        if bbox:
+            x1, y1, x2, y2 = bbox
+            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            label = detection["class"] if detection else "Detection"
+            cv2.putText(frame, label, (x1, y1 - 10),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
+
+        cv2.imshow("Pipeline Webcam", frame)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    cap.release()
+    cv2.destroyAllWindows()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run pipeline webcam utility.")
+    parser.add_argument("--mpta-url", type=str, required=True, help="URL to the pipeline mpta (ZIP) file.")
+    parser.add_argument("--video", type=str, default="0", help="Video source (default webcam index 0).")
+    args = parser.parse_args()
+    video_source = int(args.video) if args.video.isdigit() else args.video
+    main(args.mpta_url, video_source)
diff --git a/siwatsystem/pympta.py b/siwatsystem/pympta.py
new file mode 100644
index 0000000..b5f05d1
--- /dev/null
+++ b/siwatsystem/pympta.py
@@ -0,0 +1,135 @@
+import os
+import json
+import logging
+import torch
+import cv2
+import requests
+import zipfile
+from ultralytics import YOLO
+
+def load_pipeline_node(node_config: dict, models_dir: str) -> dict:
+    # Recursively load a model node from configuration.
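+    # Expected keys in node_config (mirroring the docstring removed from app.py):
+    #   - modelId: a unique identifier for this node
+    #   - modelFile: the .pt weights file, resolved relative to models_dir
+    #   - triggerClasses: class names that activate a child branch
+    #   - crop: if True, crop to the best detection's bounding box before the next model
+    #   - minConfidence: optional confidence floor required to enter a branch
+    #   - branches: list of child node configurations, loaded recursively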
+    model_path = os.path.join(models_dir, node_config["modelFile"])
+    if not os.path.exists(model_path):
+        logging.error(f"Model file {model_path} not found.")
+        raise FileNotFoundError(f"Model file {model_path} not found.")
+    logging.info(f"Loading model for node {node_config['modelId']} from {model_path}")
+    model = YOLO(model_path)
+    if torch.cuda.is_available():
+        model.to("cuda")
+    node = {
+        "modelId": node_config["modelId"],
+        "modelFile": node_config["modelFile"],
+        "triggerClasses": node_config.get("triggerClasses", []),
+        "crop": node_config.get("crop", False),
+        "minConfidence": node_config.get("minConfidence", None),
+        "model": model,
+        "branches": []
+    }
+    for child in node_config.get("branches", []):
+        node["branches"].append(load_pipeline_node(child, models_dir))
+    return node
+
+def load_pipeline_from_zip(zip_url: str, target_dir: str) -> dict:
+    # Download, extract, and load a pipeline configuration from a zip (.mpta) file.
+    os.makedirs(target_dir, exist_ok=True)
+    zip_path = os.path.join(target_dir, "pipeline.mpta")
+    try:
+        response = requests.get(zip_url, stream=True)
+        if response.status_code == 200:
+            with open(zip_path, "wb") as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+            logging.info(f"Downloaded .mpta file from {zip_url} to {zip_path}")
+        else:
+            logging.error(f"Failed to download .mpta file (status {response.status_code})")
+            return None
+    except Exception as e:
+        logging.error(f"Exception downloading .mpta file from {zip_url}: {e}")
+        return None
+
+    try:
+        with zipfile.ZipFile(zip_path, "r") as zip_ref:
+            zip_ref.extractall(target_dir)
+        logging.info(f"Extracted .mpta file to {target_dir}")
+    except Exception as e:
+        logging.error(f"Failed to extract .mpta file: {e}")
+        return None
+    finally:
+        if os.path.exists(zip_path):
+            os.remove(zip_path)
+
+    pipeline_json_path = os.path.join(target_dir, "pipeline.json")
+    if not os.path.exists(pipeline_json_path):
+        logging.error("pipeline.json not found in the .mpta file")
+        return None
+
+    try:
+        with open(pipeline_json_path, "r") as f:
+            pipeline_config = json.load(f)
+        return load_pipeline_node(pipeline_config["pipeline"], target_dir)
+    except Exception as e:
+        logging.error(f"Error loading pipeline.json: {e}")
+        return None
+
+def run_pipeline(frame, node: dict, return_bbox: bool = False):
+    """
+    Process the frame with the given pipeline node. When return_bbox is True,
+    return a tuple (detection, bbox), where bbox is (x1, y1, x2, y2) for
+    drawing; bbox is only populated when a node with crop=True matches.
+    Otherwise, return only the detection.
+ """ + try: + results = node["model"].track(frame, stream=False, persist=True) + detection = None + best_box = None + max_conf = -1 + + for r in results: + for box in r.boxes: + box_cpu = box.cpu() + conf = float(box_cpu.conf[0]) + if conf > max_conf and hasattr(box, "id") and box.id is not None: + max_conf = conf + detection = { + "class": node["model"].names[int(box_cpu.cls[0])], + "confidence": conf, + "id": box.id.item() + } + best_box = box_cpu + + bbox = None + if detection and node.get("crop", False) and best_box is not None: + coords = best_box.xyxy[0] + x1, y1, x2, y2 = map(int, coords) + h, w = frame.shape[:2] + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x2), min(h, y2) + if x2 > x1 and y2 > y1: + bbox = (x1, y1, x2, y2) + frame = frame[y1:y2, x1:x2] + + if detection is not None: + for branch in node["branches"]: + if detection["class"] in branch.get("triggerClasses", []): + min_conf = branch.get("minConfidence") + if min_conf is not None and detection["confidence"] < min_conf: + logging.debug(f"Confidence {detection['confidence']} below threshold {min_conf} for branch {branch['modelId']}.") + if return_bbox: + return detection, bbox + return detection + res = run_pipeline(frame, branch, return_bbox) + if res is not None: + if return_bbox: + return res + return res + if return_bbox: + return detection, bbox + return detection + if return_bbox: + return None, None + return None + except Exception as e: + logging.error(f"Error running pipeline on node {node.get('modelId')}: {e}") + if return_bbox: + return None, None + return None