from typing import Any, Dict import os import json import time import queue import torch import cv2 import base64 import logging import threading import requests import asyncio import psutil import zipfile from urllib.parse import urlparse from fastapi import FastAPI, WebSocket from fastapi.websockets import WebSocketDisconnect from websockets.exceptions import ConnectionClosedError from ultralytics import YOLO app = FastAPI() # Global dictionaries to keep track of models and streams # "models" now holds a nested dict: { camera_id: { modelId: model_tree } } models: Dict[str, Dict[str, Any]] = {} streams: Dict[str, Dict[str, Any]] = {} with open("config.json", "r") as f: config = json.load(f) poll_interval = config.get("poll_interval_ms", 100) reconnect_interval = config.get("reconnect_interval_sec", 5) TARGET_FPS = config.get("target_fps", 10) poll_interval = 1000 / TARGET_FPS logging.info(f"Poll interval: {poll_interval}ms") max_streams = config.get("max_streams", 5) max_retries = config.get("max_retries", 3) # Configure logging logging.basicConfig( level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[ logging.FileHandler("app.log"), logging.StreamHandler() ] ) # Ensure the models directory exists os.makedirs("models", exist_ok=True) # Constants for heartbeat and timeouts HEARTBEAT_INTERVAL = 2 # seconds WORKER_TIMEOUT_MS = 10000 # Locks for thread-safe operations streams_lock = threading.Lock() models_lock = threading.Lock() #################################################### # Pipeline (Model)-loading helper functions #################################################### def load_pipeline_node(node_config: dict, models_dir: str) -> dict: """ Recursively load a model node. Expects node_config to have: - modelId: a unique identifier - modelFile: the .pt file in models_dir - triggerClasses: list of class names that activate child branches - crop: boolean; if True, we crop to the bounding box for the next model - minConfidence: (optional) minimum confidence required to enter this branch - branches: list of child node configurations """ model_path = os.path.join(models_dir, node_config["modelFile"]) if not os.path.exists(model_path): logging.error(f"Model file {model_path} not found.") raise FileNotFoundError(f"Model file {model_path} not found.") logging.info(f"Loading model for node {node_config['modelId']} from {model_path}") model = YOLO(model_path) if torch.cuda.is_available(): model.to("cuda") node = { "modelId": node_config["modelId"], "modelFile": node_config["modelFile"], "triggerClasses": node_config.get("triggerClasses", []), "crop": node_config.get("crop", False), "minConfidence": node_config.get("minConfidence", None), # NEW FIELD "model": model, "branches": [] } for child_config in node_config.get("branches", []): child_node = load_pipeline_node(child_config, models_dir) node["branches"].append(child_node) return node def load_pipeline_from_zip(zip_url: str, target_dir: str) -> dict: """ Download the .mpta file from zip_url, extract it to target_dir, and load the pipeline configuration (pipeline.json). Returns the model tree (root node) loaded with YOLO models. """ os.makedirs(target_dir, exist_ok=True) zip_path = os.path.join(target_dir, "pipeline.mpta") try: response = requests.get(zip_url, stream=True) if response.status_code == 200: with open(zip_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) logging.info(f"Downloaded .mpta file from {zip_url} to {zip_path}") else: logging.error(f"Failed to download .mpta file (status {response.status_code})") return None except Exception as e: logging.error(f"Exception downloading .mpta file from {zip_url}: {e}") return None # Extract the .mpta file try: with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(target_dir) logging.info(f"Extracted .mpta file to {target_dir}") except Exception as e: logging.error(f"Failed to extract .mpta file: {e}") return None finally: if os.path.exists(zip_path): os.remove(zip_path) # Load pipeline.json pipeline_json_path = os.path.join(target_dir, "pipeline.json") if not os.path.exists(pipeline_json_path): logging.error("pipeline.json not found in the .mpta file") return None try: with open(pipeline_json_path, "r") as f: pipeline_config = json.load(f) # Build the model tree recursively model_tree = load_pipeline_node(pipeline_config["pipeline"], target_dir) return model_tree except Exception as e: logging.error(f"Error loading pipeline.json: {e}") return None #################################################### # Model execution function #################################################### def run_pipeline(frame, node: dict): """ Run the model at the current node. - Select the highest-confidence detection (if any). - If 'crop' is True, crop to the bounding box for the next stage. - If the detected class matches a branch's triggerClasses, check the confidence. If the detection's confidence is below branch["minConfidence"] (if specified), do not enter the branch and return the current detection. Returns the final detection result (dict) or None. """ try: results = node["model"].track(frame, stream=False, persist=True) detection = None max_conf = -1 best_box = None for r in results: for box in r.boxes: box_cpu = box.cpu() conf = float(box_cpu.conf[0]) if conf > max_conf and hasattr(box, "id") and box.id is not None: max_conf = conf detection = { "class": node["model"].names[int(box_cpu.cls[0])], "confidence": conf, "id": box.id.item(), } best_box = box_cpu # If there's a detection and crop is True, crop frame to bounding box if detection and node.get("crop", False) and best_box is not None: coords = best_box.xyxy[0] # [x1, y1, x2, y2] x1, y1, x2, y2 = map(int, coords) h, w = frame.shape[:2] x1 = max(0, x1) y1 = max(0, y1) x2 = min(w, x2) y2 = min(h, y2) if x2 > x1 and y2 > y1: frame = frame[y1:y2, x1:x2] # crop the frame if detection is not None: # Check if any branch should be entered based on trigger classes for branch in node["branches"]: if detection["class"] in branch.get("triggerClasses", []): # Check for a minimum confidence threshold for this branch min_conf = branch.get("minConfidence") if min_conf is not None and detection["confidence"] < min_conf: logging.debug( f"Detection confidence {detection['confidence']} below threshold " f"{min_conf} for branch {branch['modelId']}. Ending pipeline at current node." ) return detection branch_detection = run_pipeline(frame, branch) if branch_detection is not None: return branch_detection return detection return None except Exception as e: logging.error(f"Error running pipeline on node {node.get('modelId')}: {e}") return None #################################################### # Detection and frame processing functions #################################################### @app.websocket("/") async def detect(websocket: WebSocket): logging.info("WebSocket connection accepted") persistent_data_dict = {} async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data): try: detection_result = run_pipeline(frame, model_tree) detection_data = { "type": "imageDetection", "cameraIdentifier": camera_id, "timestamp": time.time(), "data": { "detection": detection_result if detection_result else None, "modelId": stream["modelId"], "modelName": stream["modelName"] } } logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}") await websocket.send_json(detection_data) return persistent_data except Exception as e: logging.error(f"Error in handle_detection for camera {camera_id}: {e}") return persistent_data def frame_reader(camera_id, cap, buffer, stop_event): retries = 0 try: while not stop_event.is_set(): try: ret, frame = cap.read() if not ret: logging.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}") cap.release() time.sleep(reconnect_interval) retries += 1 if retries > max_retries and max_retries != -1: logging.error(f"Max retries reached for camera: {camera_id}") break # Re-open cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"]) if not cap.isOpened(): logging.error(f"Failed to reopen RTSP stream for camera: {camera_id}") continue continue retries = 0 # Overwrite old frame if buffer is full if not buffer.empty(): try: buffer.get_nowait() except queue.Empty: pass buffer.put(frame) except cv2.error as e: logging.error(f"OpenCV error for camera {camera_id}: {e}") cap.release() time.sleep(reconnect_interval) retries += 1 if retries > max_retries and max_retries != -1: logging.error(f"Max retries reached after OpenCV error for camera {camera_id}") break cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"]) if not cap.isOpened(): logging.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error") continue except Exception as e: logging.error(f"Unexpected error for camera {camera_id}: {e}") cap.release() break except Exception as e: logging.error(f"Error in frame_reader thread for camera {camera_id}: {e}") async def process_streams(): logging.info("Started processing streams") try: while True: start_time = time.time() with streams_lock: current_streams = list(streams.items()) for camera_id, stream in current_streams: buffer = stream["buffer"] if not buffer.empty(): frame = buffer.get() with models_lock: model_tree = models.get(camera_id, {}).get(stream["modelId"]) key = (camera_id, stream["modelId"]) persistent_data = persistent_data_dict.get(key, {}) updated_persistent_data = await handle_detection( camera_id, stream, frame, websocket, model_tree, persistent_data ) persistent_data_dict[key] = updated_persistent_data elapsed_time = (time.time() - start_time) * 1000 # ms sleep_time = max(poll_interval - elapsed_time, 0) logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms") await asyncio.sleep(sleep_time / 1000.0) except asyncio.CancelledError: logging.info("Stream processing task cancelled") except Exception as e: logging.error(f"Error in process_streams: {e}") async def send_heartbeat(): while True: try: cpu_usage = psutil.cpu_percent() memory_usage = psutil.virtual_memory().percent if torch.cuda.is_available(): gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # MB gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # MB else: gpu_usage = None gpu_memory_usage = None camera_connections = [ { "cameraIdentifier": camera_id, "modelId": stream["modelId"], "modelName": stream["modelName"], "online": True } for camera_id, stream in streams.items() ] state_report = { "type": "stateReport", "cpuUsage": cpu_usage, "memoryUsage": memory_usage, "gpuUsage": gpu_usage, "gpuMemoryUsage": gpu_memory_usage, "cameraConnections": camera_connections } await websocket.send_text(json.dumps(state_report)) logging.debug("Sent stateReport as heartbeat") await asyncio.sleep(HEARTBEAT_INTERVAL) except Exception as e: logging.error(f"Error sending stateReport heartbeat: {e}") break async def on_message(): while True: try: msg = await websocket.receive_text() logging.debug(f"Received message: {msg}") data = json.loads(msg) msg_type = data.get("type") if msg_type == "subscribe": payload = data.get("payload", {}) camera_id = payload.get("cameraIdentifier") rtsp_url = payload.get("rtspUrl") model_url = payload.get("modelUrl") # ZIP file URL modelId = payload.get("modelId") modelName = payload.get("modelName") if model_url: with models_lock: if camera_id not in models: models[camera_id] = {} if modelId not in models[camera_id]: logging.info(f"Downloading model from {model_url}") extraction_dir = os.path.join("models", camera_id, str(modelId)) os.makedirs(extraction_dir, exist_ok=True) model_tree = load_pipeline_from_zip(model_url, extraction_dir) if model_tree is None: logging.error("Failed to load model from ZIP file.") continue models[camera_id][modelId] = model_tree logging.info(f"Loaded model {modelId} for camera {camera_id}") if camera_id and rtsp_url: with streams_lock: if camera_id not in streams and len(streams) < max_streams: cap = cv2.VideoCapture(rtsp_url) if not cap.isOpened(): logging.error(f"Failed to open RTSP stream for camera {camera_id}") continue buffer = queue.Queue(maxsize=1) stop_event = threading.Event() thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event)) thread.daemon = True thread.start() streams[camera_id] = { "cap": cap, "buffer": buffer, "thread": thread, "rtsp_url": rtsp_url, "stop_event": stop_event, "modelId": modelId, "modelName": modelName } logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName}, URL {rtsp_url}") elif camera_id and camera_id in streams: # If already subscribed, unsubscribe stream = streams.pop(camera_id) stream["cap"].release() logging.info(f"Unsubscribed from camera {camera_id}") with models_lock: if camera_id in models and modelId in models[camera_id]: del models[camera_id][modelId] if not models[camera_id]: del models[camera_id] elif msg_type == "unsubscribe": payload = data.get("payload", {}) camera_id = payload.get("cameraIdentifier") logging.debug(f"Unsubscribing from camera {camera_id}") with streams_lock: if camera_id and camera_id in streams: stream = streams.pop(camera_id) stream["stop_event"].set() stream["thread"].join() stream["cap"].release() logging.info(f"Unsubscribed from camera {camera_id}") with models_lock: if camera_id in models: del models[camera_id] elif msg_type == "requestState": cpu_usage = psutil.cpu_percent() memory_usage = psutil.virtual_memory().percent if torch.cuda.is_available(): gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) else: gpu_usage = None gpu_memory_usage = None camera_connections = [ { "cameraIdentifier": camera_id, "modelId": stream["modelId"], "modelName": stream["modelName"], "online": True } for camera_id, stream in streams.items() ] state_report = { "type": "stateReport", "cpuUsage": cpu_usage, "memoryUsage": memory_usage, "gpuUsage": gpu_usage, "gpuMemoryUsage": gpu_memory_usage, "cameraConnections": camera_connections } await websocket.send_text(json.dumps(state_report)) else: logging.error(f"Unknown message type: {msg_type}") except json.JSONDecodeError: logging.error("Received invalid JSON message") except (WebSocketDisconnect, ConnectionClosedError) as e: logging.warning(f"WebSocket disconnected: {e}") break except Exception as e: logging.error(f"Error handling message: {e}") break try: await websocket.accept() stream_task = asyncio.create_task(process_streams()) heartbeat_task = asyncio.create_task(send_heartbeat()) message_task = asyncio.create_task(on_message()) await asyncio.gather(heartbeat_task, message_task) except Exception as e: logging.error(f"Error in detect websocket: {e}") finally: stream_task.cancel() await stream_task with streams_lock: for camera_id, stream in streams.items(): stream["stop_event"].set() stream["thread"].join() stream["cap"].release() while not stream["buffer"].empty(): try: stream["buffer"].get_nowait() except queue.Empty: pass logging.info(f"Released camera {camera_id} and cleaned up resources") streams.clear() with models_lock: models.clear() logging.info("WebSocket connection closed")