import os import json import logging import torch import cv2 import zipfile import shutil import traceback import redis import time import uuid import concurrent.futures from ultralytics import YOLO from urllib.parse import urlparse from .database import DatabaseManager # Create a logger specifically for this module logger = logging.getLogger("detector_worker.pympta") def validate_redis_config(redis_config: dict) -> bool: """Validate Redis configuration parameters.""" required_fields = ["host", "port"] for field in required_fields: if field not in redis_config: logger.error(f"Missing required Redis config field: {field}") return False if not isinstance(redis_config["port"], int) or redis_config["port"] <= 0: logger.error(f"Invalid Redis port: {redis_config['port']}") return False return True def validate_postgresql_config(pg_config: dict) -> bool: """Validate PostgreSQL configuration parameters.""" required_fields = ["host", "port", "database", "username", "password"] for field in required_fields: if field not in pg_config: logger.error(f"Missing required PostgreSQL config field: {field}") return False if not isinstance(pg_config["port"], int) or pg_config["port"] <= 0: logger.error(f"Invalid PostgreSQL port: {pg_config['port']}") return False return True def crop_region_by_class(frame, regions_dict, class_name): """Crop a specific region from frame based on detected class.""" if class_name not in regions_dict: logger.warning(f"Class '{class_name}' not found in detected regions") return None bbox = regions_dict[class_name]['bbox'] x1, y1, x2, y2 = bbox cropped = frame[y1:y2, x1:x2] if cropped.size == 0: logger.warning(f"Empty crop for class '{class_name}' with bbox {bbox}") return None return cropped def format_action_context(base_context, additional_context=None): """Format action context with dynamic values.""" context = {**base_context} if additional_context: context.update(additional_context) return context def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manager=None) -> dict: # Recursively load a model node from configuration. model_path = os.path.join(mpta_dir, node_config["modelFile"]) if not os.path.exists(model_path): logger.error(f"Model file {model_path} not found. Current directory: {os.getcwd()}") logger.error(f"Directory content: {os.listdir(os.path.dirname(model_path))}") raise FileNotFoundError(f"Model file {model_path} not found.") logger.info(f"Loading model for node {node_config['modelId']} from {model_path}") model = YOLO(model_path) if torch.cuda.is_available(): logger.info(f"CUDA available. Moving model {node_config['modelId']} to GPU") model.to("cuda") else: logger.info(f"CUDA not available. Using CPU for model {node_config['modelId']}") # Prepare trigger class indices for optimization trigger_classes = node_config.get("triggerClasses", []) trigger_class_indices = None if trigger_classes and hasattr(model, "names"): # Convert class names to indices for the model trigger_class_indices = [i for i, name in model.names.items() if name in trigger_classes] logger.debug(f"Converted trigger classes to indices: {trigger_class_indices}") node = { "modelId": node_config["modelId"], "modelFile": node_config["modelFile"], "triggerClasses": trigger_classes, "triggerClassIndices": trigger_class_indices, "crop": node_config.get("crop", False), "cropClass": node_config.get("cropClass"), "minConfidence": node_config.get("minConfidence", None), "multiClass": node_config.get("multiClass", False), "expectedClasses": node_config.get("expectedClasses", []), "parallel": node_config.get("parallel", False), "actions": node_config.get("actions", []), "parallelActions": node_config.get("parallelActions", []), "model": model, "branches": [], "redis_client": redis_client, "db_manager": db_manager } logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}") for child in node_config.get("branches", []): logger.debug(f"Loading branch for parent node {node_config['modelId']}") node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client, db_manager)) return node def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict: logger.info(f"Attempting to load pipeline from {zip_source} to {target_dir}") os.makedirs(target_dir, exist_ok=True) zip_path = os.path.join(target_dir, "pipeline.mpta") # Parse the source; only local files are supported here. parsed = urlparse(zip_source) if parsed.scheme in ("", "file"): local_path = parsed.path if parsed.scheme == "file" else zip_source logger.debug(f"Checking if local file exists: {local_path}") if os.path.exists(local_path): try: shutil.copy(local_path, zip_path) logger.info(f"Copied local .mpta file from {local_path} to {zip_path}") except Exception as e: logger.error(f"Failed to copy local .mpta file from {local_path}: {str(e)}", exc_info=True) return None else: logger.error(f"Local file {local_path} does not exist. Current directory: {os.getcwd()}") # List all subdirectories of models directory to help debugging if os.path.exists("models"): logger.error(f"Content of models directory: {os.listdir('models')}") for root, dirs, files in os.walk("models"): logger.error(f"Directory {root} contains subdirs: {dirs} and files: {files}") else: logger.error("The models directory doesn't exist") return None else: logger.error(f"HTTP download functionality has been moved. Use a local file path here. Received: {zip_source}") return None try: if not os.path.exists(zip_path): logger.error(f"Zip file not found at expected location: {zip_path}") return None logger.debug(f"Extracting .mpta file from {zip_path} to {target_dir}") # Extract contents and track the directories created extracted_dirs = [] with zipfile.ZipFile(zip_path, "r") as zip_ref: file_list = zip_ref.namelist() logger.debug(f"Files in .mpta archive: {file_list}") # Extract and track the top-level directories for file_path in file_list: parts = file_path.split('/') if len(parts) > 1: top_dir = parts[0] if top_dir and top_dir not in extracted_dirs: extracted_dirs.append(top_dir) # Now extract the files zip_ref.extractall(target_dir) logger.info(f"Successfully extracted .mpta file to {target_dir}") logger.debug(f"Extracted directories: {extracted_dirs}") # Check what was actually created after extraction actual_dirs = [d for d in os.listdir(target_dir) if os.path.isdir(os.path.join(target_dir, d))] logger.debug(f"Actual directories created: {actual_dirs}") except zipfile.BadZipFile as e: logger.error(f"Bad zip file {zip_path}: {str(e)}", exc_info=True) return None except Exception as e: logger.error(f"Failed to extract .mpta file {zip_path}: {str(e)}", exc_info=True) return None finally: if os.path.exists(zip_path): os.remove(zip_path) logger.debug(f"Removed temporary zip file: {zip_path}") # Use the first extracted directory if it exists, otherwise use the expected name pipeline_name = os.path.basename(zip_source) pipeline_name = os.path.splitext(pipeline_name)[0] # Find the directory with pipeline.json mpta_dir = None # First try the expected directory name expected_dir = os.path.join(target_dir, pipeline_name) if os.path.exists(expected_dir) and os.path.exists(os.path.join(expected_dir, "pipeline.json")): mpta_dir = expected_dir logger.debug(f"Found pipeline.json in the expected directory: {mpta_dir}") else: # Look through all subdirectories for pipeline.json for subdir in actual_dirs: potential_dir = os.path.join(target_dir, subdir) if os.path.exists(os.path.join(potential_dir, "pipeline.json")): mpta_dir = potential_dir logger.info(f"Found pipeline.json in directory: {mpta_dir} (different from expected: {expected_dir})") break if not mpta_dir: logger.error(f"Could not find pipeline.json in any extracted directory. Directory content: {os.listdir(target_dir)}") return None pipeline_json_path = os.path.join(mpta_dir, "pipeline.json") if not os.path.exists(pipeline_json_path): logger.error(f"pipeline.json not found in the .mpta file. Files in directory: {os.listdir(mpta_dir)}") return None try: with open(pipeline_json_path, "r") as f: pipeline_config = json.load(f) logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}") logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}") # Establish Redis connection if configured redis_client = None if "redis" in pipeline_config: redis_config = pipeline_config["redis"] if not validate_redis_config(redis_config): logger.error("Invalid Redis configuration, skipping Redis connection") else: try: redis_client = redis.Redis( host=redis_config["host"], port=redis_config["port"], password=redis_config.get("password"), db=redis_config.get("db", 0), decode_responses=True ) redis_client.ping() logger.info(f"Successfully connected to Redis at {redis_config['host']}:{redis_config['port']}") except redis.exceptions.ConnectionError as e: logger.error(f"Failed to connect to Redis: {e}") redis_client = None # Establish PostgreSQL connection if configured db_manager = None if "postgresql" in pipeline_config: pg_config = pipeline_config["postgresql"] if not validate_postgresql_config(pg_config): logger.error("Invalid PostgreSQL configuration, skipping database connection") else: try: db_manager = DatabaseManager(pg_config) if db_manager.connect(): logger.info(f"Successfully connected to PostgreSQL at {pg_config['host']}:{pg_config['port']}") else: logger.error("Failed to connect to PostgreSQL") db_manager = None except Exception as e: logger.error(f"Error initializing PostgreSQL connection: {e}") db_manager = None return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client, db_manager) except json.JSONDecodeError as e: logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True) return None except KeyError as e: logger.error(f"Missing key in pipeline.json: {str(e)}", exc_info=True) return None except Exception as e: logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True) return None def execute_actions(node, frame, detection_result, regions_dict=None): if not node["redis_client"] or not node["actions"]: return # Create a dynamic context for this detection event from datetime import datetime action_context = { **detection_result, "timestamp_ms": int(time.time() * 1000), "uuid": str(uuid.uuid4()), "timestamp": datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), "filename": f"{uuid.uuid4()}.jpg" } for action in node["actions"]: try: if action["type"] == "redis_save_image": key = action["key"].format(**action_context) # Check if we need to crop a specific region region_name = action.get("region") image_to_save = frame if region_name and regions_dict: cropped_image = crop_region_by_class(frame, regions_dict, region_name) if cropped_image is not None: image_to_save = cropped_image logger.debug(f"Cropped region '{region_name}' for redis_save_image") else: logger.warning(f"Could not crop region '{region_name}', saving full frame instead") # Encode image with specified format and quality (default to JPEG) img_format = action.get("format", "jpeg").lower() quality = action.get("quality", 90) if img_format == "jpeg": encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] success, buffer = cv2.imencode('.jpg', image_to_save, encode_params) elif img_format == "png": success, buffer = cv2.imencode('.png', image_to_save) else: success, buffer = cv2.imencode('.jpg', image_to_save, [cv2.IMWRITE_JPEG_QUALITY, quality]) if not success: logger.error(f"Failed to encode image for redis_save_image") continue expire_seconds = action.get("expire_seconds") if expire_seconds: node["redis_client"].setex(key, expire_seconds, buffer.tobytes()) logger.info(f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)") else: node["redis_client"].set(key, buffer.tobytes()) logger.info(f"Saved image to Redis with key: {key}") action_context["image_key"] = key elif action["type"] == "redis_publish": channel = action["channel"] try: # Handle JSON message format by creating it programmatically message_template = action["message"] # Check if the message is JSON-like (starts and ends with braces) if message_template.strip().startswith('{') and message_template.strip().endswith('}'): # Create JSON data programmatically to avoid formatting issues json_data = {} # Add common fields json_data["event"] = "frontal_detected" json_data["display_id"] = action_context.get("display_id", "unknown") json_data["session_id"] = action_context.get("session_id") json_data["timestamp"] = action_context.get("timestamp", "") json_data["image_key"] = action_context.get("image_key", "") # Convert to JSON string message = json.dumps(json_data) else: # Use regular string formatting for non-JSON messages message = message_template.format(**action_context) # Publish to Redis if not node["redis_client"]: logger.error("Redis client is None, cannot publish message") continue # Test Redis connection try: node["redis_client"].ping() logger.debug("Redis connection is active") except Exception as ping_error: logger.error(f"Redis connection test failed: {ping_error}") continue result = node["redis_client"].publish(channel, message) logger.info(f"Published message to Redis channel '{channel}': {message}") logger.info(f"Redis publish result (subscribers count): {result}") # Additional debug info if result == 0: logger.warning(f"No subscribers listening to channel '{channel}'") else: logger.info(f"Message delivered to {result} subscriber(s)") except KeyError as e: logger.error(f"Missing key in redis_publish message template: {e}") logger.debug(f"Available context keys: {list(action_context.keys())}") except Exception as e: logger.error(f"Error in redis_publish action: {e}") logger.debug(f"Message template: {action['message']}") logger.debug(f"Available context keys: {list(action_context.keys())}") import traceback logger.debug(f"Full traceback: {traceback.format_exc()}") except Exception as e: logger.error(f"Error executing action {action['type']}: {e}") def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None): """ Enhanced pipeline that supports: - Multi-class detection (detecting multiple classes simultaneously) - Parallel branch processing - Region-based actions and cropping - Context passing for session/camera information """ try: task = getattr(node["model"], "task", None) # ─── Classification stage ─────────────────────────────────── if task == "classify": results = node["model"].predict(frame, stream=False) if not results: return (None, None) if return_bbox else None r = results[0] probs = r.probs if probs is None: return (None, None) if return_bbox else None top1_idx = int(probs.top1) top1_conf = float(probs.top1conf) det = { "class": node["model"].names[top1_idx], "confidence": top1_conf, "id": None, node["model"].names[top1_idx]: node["model"].names[top1_idx] # Add class name as key } execute_actions(node, frame, det) return (det, None) if return_bbox else det # ─── Detection stage - Multi-class support ────────────────── tk = node["triggerClassIndices"] res = node["model"].track( frame, stream=False, persist=True, **({"classes": tk} if tk else {}) )[0] # Collect all detections above confidence threshold all_detections = [] all_boxes = [] regions_dict = {} for box in res.boxes: conf = float(box.cpu().conf[0]) cid = int(box.cpu().cls[0]) name = node["model"].names[cid] if conf < node["minConfidence"]: continue xy = box.cpu().xyxy[0] x1, y1, x2, y2 = map(int, xy) bbox = (x1, y1, x2, y2) detection = { "class": name, "confidence": conf, "id": box.id.item() if hasattr(box, "id") else None, "bbox": bbox } all_detections.append(detection) all_boxes.append(bbox) # Store highest confidence detection for each class if name not in regions_dict or conf > regions_dict[name]["confidence"]: regions_dict[name] = { "bbox": bbox, "confidence": conf, "detection": detection } if not all_detections: return (None, None) if return_bbox else None # ─── Multi-class validation ───────────────────────────────── if node.get("multiClass", False) and node.get("expectedClasses"): expected_classes = node["expectedClasses"] detected_classes = list(regions_dict.keys()) # Check if all expected classes are detected missing_classes = [cls for cls in expected_classes if cls not in detected_classes] if missing_classes: logger.debug(f"Missing expected classes: {missing_classes}. Detected: {detected_classes}") return (None, None) if return_bbox else None logger.info(f"Multi-class detection success: {detected_classes}") # ─── Execute actions with region information ──────────────── detection_result = { "detections": all_detections, "regions": regions_dict, **(context or {}) } execute_actions(node, frame, detection_result, regions_dict) # ─── Parallel branch processing ───────────────────────────── if node["branches"]: branch_results = {} # Filter branches that should be triggered active_branches = [] for br in node["branches"]: trigger_classes = br.get("triggerClasses", []) min_conf = br.get("minConfidence", 0) # Check if any detected class matches branch trigger for det_class in regions_dict: if (det_class in trigger_classes and regions_dict[det_class]["confidence"] >= min_conf): active_branches.append(br) break if active_branches: if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches): # Run branches in parallel with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_branches)) as executor: futures = {} for br in active_branches: crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None) sub_frame = frame if br.get("crop", False) and crop_class: cropped = crop_region_by_class(frame, regions_dict, crop_class) if cropped is not None: sub_frame = cv2.resize(cropped, (224, 224)) else: continue future = executor.submit(run_pipeline, sub_frame, br, True, context) futures[future] = br # Collect results for future in concurrent.futures.as_completed(futures): br = futures[future] try: result, _ = future.result() if result: branch_results[br["modelId"]] = result logger.info(f"Branch {br['modelId']} completed: {result}") except Exception as e: logger.error(f"Branch {br['modelId']} failed: {e}") else: # Run branches sequentially for br in active_branches: crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None) sub_frame = frame if br.get("crop", False) and crop_class: cropped = crop_region_by_class(frame, regions_dict, crop_class) if cropped is not None: sub_frame = cv2.resize(cropped, (224, 224)) else: continue result, _ = run_pipeline(sub_frame, br, True, context) if result: branch_results[br["modelId"]] = result logger.info(f"Branch {br['modelId']} completed: {result}") # Store branch results in detection_result for parallel actions detection_result["branch_results"] = branch_results # ─── Return detection result ──────────────────────────────── primary_detection = max(all_detections, key=lambda x: x["confidence"]) primary_bbox = primary_detection["bbox"] # Add branch results to primary detection for compatibility if "branch_results" in detection_result: primary_detection["branch_results"] = detection_result["branch_results"] return (primary_detection, primary_bbox) if return_bbox else primary_detection except Exception as e: logger.error(f"Error in node {node.get('modelId')}: {e}") traceback.print_exc() return (None, None) if return_bbox else None