diff --git a/app.py b/app.py index f0c8266..aace69c 100644 --- a/app.py +++ b/app.py @@ -1,31 +1,35 @@ -from typing import List +from typing import Any, Dict +import os +import json +import time +import queue +import torch +import cv2 +import base64 +import logging +import threading +import requests +import asyncio +import psutil +import zipfile +from urllib.parse import urlparse from fastapi import FastAPI, WebSocket from fastapi.websockets import WebSocketDisconnect from websockets.exceptions import ConnectionClosedError from ultralytics import YOLO -import torch -import cv2 -import base64 -import numpy as np -import json -import logging -import threading -import queue -import os -import requests -from urllib.parse import urlparse -import asyncio -import psutil app = FastAPI() -models = {} +# Global dictionaries to keep track of models and streams +# "models" now holds a nested dict: { camera_id: { modelId: model_tree } } +models: Dict[str, Dict[str, Any]] = {} +streams: Dict[str, Dict[str, Any]] = {} with open("config.json", "r") as f: config = json.load(f) poll_interval = config.get("poll_interval_ms", 100) -reconnect_interval = config.get("reconnect_interval_sec", 5) +reconnect_interval = config.get("reconnect_interval_sec", 5) TARGET_FPS = config.get("target_fps", 10) poll_interval = 1000 / TARGET_FPS logging.info(f"Poll interval: {poll_interval}ms") @@ -45,51 +49,188 @@ logging.basicConfig( # Ensure the models directory exists os.makedirs("models", exist_ok=True) -# Add constants for heartbeat +# Constants for heartbeat and timeouts HEARTBEAT_INTERVAL = 2 # seconds WORKER_TIMEOUT_MS = 10000 -# Add a lock for thread-safe operations on shared resources +# Locks for thread-safe operations streams_lock = threading.Lock() models_lock = threading.Lock() +#################################################### +# Pipeline (Model)-loading helper functions +#################################################### +def load_pipeline_node(node_config: dict, models_dir: str) -> dict: + """ + Recursively load a model node. + Expects node_config to have: + - modelId: a unique identifier + - modelFile: the .pt file in models_dir + - triggerClasses: list of class names that activate child branches + - crop: boolean; if True, we crop to the bounding box for the next model + - minConfidence: (optional) minimum confidence required to enter this branch + - branches: list of child node configurations + """ + model_path = os.path.join(models_dir, node_config["modelFile"]) + if not os.path.exists(model_path): + logging.error(f"Model file {model_path} not found.") + raise FileNotFoundError(f"Model file {model_path} not found.") + + logging.info(f"Loading model for node {node_config['modelId']} from {model_path}") + model = YOLO(model_path) + if torch.cuda.is_available(): + model.to("cuda") + + node = { + "modelId": node_config["modelId"], + "modelFile": node_config["modelFile"], + "triggerClasses": node_config.get("triggerClasses", []), + "crop": node_config.get("crop", False), + "minConfidence": node_config.get("minConfidence", None), # NEW FIELD + "model": model, + "branches": [] + } + for child_config in node_config.get("branches", []): + child_node = load_pipeline_node(child_config, models_dir) + node["branches"].append(child_node) + return node + +def load_pipeline_from_zip(zip_url: str, target_dir: str) -> dict: + """ + Download the .mpta file from zip_url, extract it to target_dir, + and load the pipeline configuration (pipeline.json). + Returns the model tree (root node) loaded with YOLO models. + """ + os.makedirs(target_dir, exist_ok=True) + zip_path = os.path.join(target_dir, "pipeline.mpta") + + try: + response = requests.get(zip_url, stream=True) + if response.status_code == 200: + with open(zip_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + logging.info(f"Downloaded .mpta file from {zip_url} to {zip_path}") + else: + logging.error(f"Failed to download .mpta file (status {response.status_code})") + return None + except Exception as e: + logging.error(f"Exception downloading .mpta file from {zip_url}: {e}") + return None + + # Extract the .mpta file + try: + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(target_dir) + logging.info(f"Extracted .mpta file to {target_dir}") + except Exception as e: + logging.error(f"Failed to extract .mpta file: {e}") + return None + finally: + if os.path.exists(zip_path): + os.remove(zip_path) + + # Load pipeline.json + pipeline_json_path = os.path.join(target_dir, "pipeline.json") + if not os.path.exists(pipeline_json_path): + logging.error("pipeline.json not found in the .mpta file") + return None + + try: + with open(pipeline_json_path, "r") as f: + pipeline_config = json.load(f) + # Build the model tree recursively + model_tree = load_pipeline_node(pipeline_config["pipeline"], target_dir) + return model_tree + except Exception as e: + logging.error(f"Error loading pipeline.json: {e}") + return None + +#################################################### +# Model execution function +#################################################### +def run_pipeline(frame, node: dict): + """ + Run the model at the current node. + - Select the highest-confidence detection (if any). + - If 'crop' is True, crop to the bounding box for the next stage. + - If the detected class matches a branch's triggerClasses, check the confidence. + If the detection's confidence is below branch["minConfidence"] (if specified), + do not enter the branch and return the current detection. + Returns the final detection result (dict) or None. + """ + try: + results = node["model"].track(frame, stream=False, persist=True) + detection = None + max_conf = -1 + best_box = None + + for r in results: + for box in r.boxes: + box_cpu = box.cpu() + conf = float(box_cpu.conf[0]) + if conf > max_conf and hasattr(box, "id") and box.id is not None: + max_conf = conf + detection = { + "class": node["model"].names[int(box_cpu.cls[0])], + "confidence": conf, + "id": box.id.item(), + } + best_box = box_cpu + + # If there's a detection and crop is True, crop frame to bounding box + if detection and node.get("crop", False) and best_box is not None: + coords = best_box.xyxy[0] # [x1, y1, x2, y2] + x1, y1, x2, y2 = map(int, coords) + h, w = frame.shape[:2] + x1 = max(0, x1) + y1 = max(0, y1) + x2 = min(w, x2) + y2 = min(h, y2) + + if x2 > x1 and y2 > y1: + frame = frame[y1:y2, x1:x2] # crop the frame + + if detection is not None: + # Check if any branch should be entered based on trigger classes + for branch in node["branches"]: + if detection["class"] in branch.get("triggerClasses", []): + # Check for a minimum confidence threshold for this branch + min_conf = branch.get("minConfidence") + if min_conf is not None and detection["confidence"] < min_conf: + logging.debug( + f"Detection confidence {detection['confidence']} below threshold " + f"{min_conf} for branch {branch['modelId']}. Ending pipeline at current node." + ) + return detection + branch_detection = run_pipeline(frame, branch) + if branch_detection is not None: + return branch_detection + return detection + return None + except Exception as e: + logging.error(f"Error running pipeline on node {node.get('modelId')}: {e}") + return None + +#################################################### +# Detection and frame processing functions +#################################################### @app.websocket("/") async def detect(websocket: WebSocket): - import asyncio - import time - logging.info("WebSocket connection accepted") + persistent_data_dict = {} - streams = {} - - # This function is user-modifiable - # Save data you want to persist across frames in the persistent_data dictionary - async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data): + async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data): try: - highest_conf_box = None - max_conf = -1 - - for r in model.track(frame, stream=False, persist=True): - for box in r.boxes: - box_cpu = box.cpu() - conf = float(box_cpu.conf[0]) - if conf > max_conf and hasattr(box, "id") and box.id is not None: - max_conf = conf - highest_conf_box = { - "class": model.names[int(box_cpu.cls[0])], - "confidence": conf, - "id": box.id.item(), - } - - # Broadcast to all subscribers of this URL + detection_result = run_pipeline(frame, model_tree) detection_data = { "type": "imageDetection", "cameraIdentifier": camera_id, "timestamp": time.time(), "data": { - "detections": highest_conf_box if highest_conf_box else None, - "modelId": stream['modelId'], - "modelName": stream['modelName'] + "detection": detection_result if detection_result else None, + "modelId": stream["modelId"], + "modelName": stream["modelName"] } } logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}") @@ -100,7 +241,6 @@ async def detect(websocket: WebSocket): return persistent_data def frame_reader(camera_id, cap, buffer, stop_event): - import time retries = 0 try: while not stop_event.is_set(): @@ -114,16 +254,17 @@ async def detect(websocket: WebSocket): if retries > max_retries and max_retries != -1: logging.error(f"Max retries reached for camera: {camera_id}") break - # Re-open the VideoCapture - cap = cv2.VideoCapture(streams[camera_id]['rtsp_url']) + # Re-open + cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"]) if not cap.isOpened(): logging.error(f"Failed to reopen RTSP stream for camera: {camera_id}") continue continue - retries = 0 # Reset on success + retries = 0 + # Overwrite old frame if buffer is full if not buffer.empty(): try: - buffer.get_nowait() # Discard the old frame + buffer.get_nowait() except queue.Empty: pass buffer.put(frame) @@ -133,10 +274,9 @@ async def detect(websocket: WebSocket): time.sleep(reconnect_interval) retries += 1 if retries > max_retries and max_retries != -1: - logging.error(f"Max retries reached after OpenCV error for camera: {camera_id}") + logging.error(f"Max retries reached after OpenCV error for camera {camera_id}") break - # Re-open the VideoCapture - cap = cv2.VideoCapture(streams[camera_id]['rtsp_url']) + cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"]) if not cap.isOpened(): logging.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error") continue @@ -148,26 +288,25 @@ async def detect(websocket: WebSocket): logging.error(f"Error in frame_reader thread for camera {camera_id}: {e}") async def process_streams(): - global models logging.info("Started processing streams") - persistent_data_dict = {} try: while True: start_time = time.time() - # Round-robin processing with streams_lock: current_streams = list(streams.items()) for camera_id, stream in current_streams: - buffer = stream['buffer'] + buffer = stream["buffer"] if not buffer.empty(): frame = buffer.get() with models_lock: - model = models.get(camera_id, {}).get(stream['modelId']) - key = (camera_id, stream['modelId']) + model_tree = models.get(camera_id, {}).get(stream["modelId"]) + key = (camera_id, stream["modelId"]) persistent_data = persistent_data_dict.get(key, {}) - updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data) + updated_persistent_data = await handle_detection( + camera_id, stream, frame, websocket, model_tree, persistent_data + ) persistent_data_dict[key] = updated_persistent_data - elapsed_time = (time.time() - start_time) * 1000 # in ms + elapsed_time = (time.time() - start_time) * 1000 # ms sleep_time = max(poll_interval - elapsed_time, 0) logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms") await asyncio.sleep(sleep_time / 1000.0) @@ -182,22 +321,22 @@ async def detect(websocket: WebSocket): cpu_usage = psutil.cpu_percent() memory_usage = psutil.virtual_memory().percent if torch.cuda.is_available(): - gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB - gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB + gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # MB + gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # MB else: gpu_usage = None gpu_memory_usage = None - + camera_connections = [ { "cameraIdentifier": camera_id, - "modelId": stream['modelId'], - "modelName": stream['modelName'], + "modelId": stream["modelId"], + "modelName": stream["modelName"], "online": True } for camera_id, stream in streams.items() ] - + state_report = { "type": "stateReport", "cpuUsage": cpu_usage, @@ -214,12 +353,10 @@ async def detect(websocket: WebSocket): break async def on_message(): - global models while True: try: msg = await websocket.receive_text() logging.debug(f"Received message: {msg}") - print(f"Received message: {msg}") data = json.loads(msg) msg_type = data.get("type") @@ -227,34 +364,25 @@ async def detect(websocket: WebSocket): payload = data.get("payload", {}) camera_id = payload.get("cameraIdentifier") rtsp_url = payload.get("rtspUrl") - model_url = payload.get("modelUrl") + model_url = payload.get("modelUrl") # ZIP file URL modelId = payload.get("modelId") modelName = payload.get("modelName") - + if model_url: with models_lock: if camera_id not in models: models[camera_id] = {} if modelId not in models[camera_id]: - print(f"Downloading model from {model_url}") - parsed_url = urlparse(model_url) - filename = os.path.basename(parsed_url.path) - model_filename = os.path.join("models", filename) - # Download the model - response = requests.get(model_url, stream=True) - if response.status_code == 200: - with open(model_filename, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - logging.info(f"Downloaded model from {model_url} to {model_filename}") - model = YOLO(model_filename) - if torch.cuda.is_available(): - model.to('cuda') - models[camera_id][modelId] = model - logging.info(f"Loaded model {modelId} for camera {camera_id}") - else: - logging.error(f"Failed to download model from {model_url}") + logging.info(f"Downloading model from {model_url}") + extraction_dir = os.path.join("models", camera_id, str(modelId)) + os.makedirs(extraction_dir, exist_ok=True) + model_tree = load_pipeline_from_zip(model_url, extraction_dir) + if model_tree is None: + logging.error("Failed to load model from ZIP file.") continue + models[camera_id][modelId] = model_tree + logging.info(f"Loaded model {modelId} for camera {camera_id}") + if camera_id and rtsp_url: with streams_lock: if camera_id not in streams and len(streams) < max_streams: @@ -268,23 +396,25 @@ async def detect(websocket: WebSocket): thread.daemon = True thread.start() streams[camera_id] = { - 'cap': cap, - 'buffer': buffer, - 'thread': thread, - 'rtsp_url': rtsp_url, - 'stop_event': stop_event, - 'modelId': modelId, - 'modelName': modelName + "cap": cap, + "buffer": buffer, + "thread": thread, + "rtsp_url": rtsp_url, + "stop_event": stop_event, + "modelId": modelId, + "modelName": modelName } - logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName} and URL {rtsp_url}") + logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName}, URL {rtsp_url}") elif camera_id and camera_id in streams: + # If already subscribed, unsubscribe stream = streams.pop(camera_id) - stream['cap'].release() + stream["cap"].release() logging.info(f"Unsubscribed from camera {camera_id}") - if camera_id in models and modelId in models[camera_id]: - del models[camera_id][modelId] - if not models[camera_id]: - del models[camera_id] + with models_lock: + if camera_id in models and modelId in models[camera_id]: + del models[camera_id][modelId] + if not models[camera_id]: + del models[camera_id] elif msg_type == "unsubscribe": payload = data.get("payload", {}) camera_id = payload.get("cameraIdentifier") @@ -292,35 +422,33 @@ async def detect(websocket: WebSocket): with streams_lock: if camera_id and camera_id in streams: stream = streams.pop(camera_id) - stream['stop_event'].set() - stream['thread'].join() - stream['cap'].release() + stream["stop_event"].set() + stream["thread"].join() + stream["cap"].release() logging.info(f"Unsubscribed from camera {camera_id}") - if camera_id in models and modelId in models[camera_id]: - del models[camera_id][modelId] - if not models[camera_id]: + with models_lock: + if camera_id in models: del models[camera_id] elif msg_type == "requestState": - # Handle state request cpu_usage = psutil.cpu_percent() memory_usage = psutil.virtual_memory().percent if torch.cuda.is_available(): - gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB - gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB + gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) + gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) else: gpu_usage = None gpu_memory_usage = None - + camera_connections = [ { "cameraIdentifier": camera_id, - "modelId": stream['modelId'], - "modelName": stream['modelName'], + "modelId": stream["modelId"], + "modelName": stream["modelName"], "online": True } for camera_id, stream in streams.items() ] - + state_report = { "type": "stateReport", "cpuUsage": cpu_usage, @@ -336,31 +464,34 @@ async def detect(websocket: WebSocket): logging.error("Received invalid JSON message") except (WebSocketDisconnect, ConnectionClosedError) as e: logging.warning(f"WebSocket disconnected: {e}") - break + break except Exception as e: logging.error(f"Error handling message: {e}") break try: await websocket.accept() - task = asyncio.create_task(process_streams()) + stream_task = asyncio.create_task(process_streams()) heartbeat_task = asyncio.create_task(send_heartbeat()) message_task = asyncio.create_task(on_message()) - await asyncio.gather(heartbeat_task, message_task) except Exception as e: logging.error(f"Error in detect websocket: {e}") finally: - task.cancel() - await task + stream_task.cancel() + await stream_task with streams_lock: for camera_id, stream in streams.items(): - stream['stop_event'].set() - stream['thread'].join() - stream['cap'].release() - stream['buffer'].queue.clear() + stream["stop_event"].set() + stream["thread"].join() + stream["cap"].release() + while not stream["buffer"].empty(): + try: + stream["buffer"].get_nowait() + except queue.Empty: + pass logging.info(f"Released camera {camera_id} and cleaned up resources") streams.clear() with models_lock: models.clear() - logging.info("WebSocket connection closed") \ No newline at end of file + logging.info("WebSocket connection closed") diff --git a/app_single.py b/app_single.py new file mode 100644 index 0000000..f0c8266 --- /dev/null +++ b/app_single.py @@ -0,0 +1,366 @@ +from typing import List +from fastapi import FastAPI, WebSocket +from fastapi.websockets import WebSocketDisconnect +from websockets.exceptions import ConnectionClosedError +from ultralytics import YOLO +import torch +import cv2 +import base64 +import numpy as np +import json +import logging +import threading +import queue +import os +import requests +from urllib.parse import urlparse +import asyncio +import psutil + +app = FastAPI() + +models = {} + +with open("config.json", "r") as f: + config = json.load(f) + +poll_interval = config.get("poll_interval_ms", 100) +reconnect_interval = config.get("reconnect_interval_sec", 5) +TARGET_FPS = config.get("target_fps", 10) +poll_interval = 1000 / TARGET_FPS +logging.info(f"Poll interval: {poll_interval}ms") +max_streams = config.get("max_streams", 5) +max_retries = config.get("max_retries", 3) + +# Configure logging +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.FileHandler("app.log"), + logging.StreamHandler() + ] +) + +# Ensure the models directory exists +os.makedirs("models", exist_ok=True) + +# Add constants for heartbeat +HEARTBEAT_INTERVAL = 2 # seconds +WORKER_TIMEOUT_MS = 10000 + +# Add a lock for thread-safe operations on shared resources +streams_lock = threading.Lock() +models_lock = threading.Lock() + +@app.websocket("/") +async def detect(websocket: WebSocket): + import asyncio + import time + + logging.info("WebSocket connection accepted") + + streams = {} + + # This function is user-modifiable + # Save data you want to persist across frames in the persistent_data dictionary + async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data): + try: + highest_conf_box = None + max_conf = -1 + + for r in model.track(frame, stream=False, persist=True): + for box in r.boxes: + box_cpu = box.cpu() + conf = float(box_cpu.conf[0]) + if conf > max_conf and hasattr(box, "id") and box.id is not None: + max_conf = conf + highest_conf_box = { + "class": model.names[int(box_cpu.cls[0])], + "confidence": conf, + "id": box.id.item(), + } + + # Broadcast to all subscribers of this URL + detection_data = { + "type": "imageDetection", + "cameraIdentifier": camera_id, + "timestamp": time.time(), + "data": { + "detections": highest_conf_box if highest_conf_box else None, + "modelId": stream['modelId'], + "modelName": stream['modelName'] + } + } + logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}") + await websocket.send_json(detection_data) + return persistent_data + except Exception as e: + logging.error(f"Error in handle_detection for camera {camera_id}: {e}") + return persistent_data + + def frame_reader(camera_id, cap, buffer, stop_event): + import time + retries = 0 + try: + while not stop_event.is_set(): + try: + ret, frame = cap.read() + if not ret: + logging.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}") + cap.release() + time.sleep(reconnect_interval) + retries += 1 + if retries > max_retries and max_retries != -1: + logging.error(f"Max retries reached for camera: {camera_id}") + break + # Re-open the VideoCapture + cap = cv2.VideoCapture(streams[camera_id]['rtsp_url']) + if not cap.isOpened(): + logging.error(f"Failed to reopen RTSP stream for camera: {camera_id}") + continue + continue + retries = 0 # Reset on success + if not buffer.empty(): + try: + buffer.get_nowait() # Discard the old frame + except queue.Empty: + pass + buffer.put(frame) + except cv2.error as e: + logging.error(f"OpenCV error for camera {camera_id}: {e}") + cap.release() + time.sleep(reconnect_interval) + retries += 1 + if retries > max_retries and max_retries != -1: + logging.error(f"Max retries reached after OpenCV error for camera: {camera_id}") + break + # Re-open the VideoCapture + cap = cv2.VideoCapture(streams[camera_id]['rtsp_url']) + if not cap.isOpened(): + logging.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error") + continue + except Exception as e: + logging.error(f"Unexpected error for camera {camera_id}: {e}") + cap.release() + break + except Exception as e: + logging.error(f"Error in frame_reader thread for camera {camera_id}: {e}") + + async def process_streams(): + global models + logging.info("Started processing streams") + persistent_data_dict = {} + try: + while True: + start_time = time.time() + # Round-robin processing + with streams_lock: + current_streams = list(streams.items()) + for camera_id, stream in current_streams: + buffer = stream['buffer'] + if not buffer.empty(): + frame = buffer.get() + with models_lock: + model = models.get(camera_id, {}).get(stream['modelId']) + key = (camera_id, stream['modelId']) + persistent_data = persistent_data_dict.get(key, {}) + updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data) + persistent_data_dict[key] = updated_persistent_data + elapsed_time = (time.time() - start_time) * 1000 # in ms + sleep_time = max(poll_interval - elapsed_time, 0) + logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms") + await asyncio.sleep(sleep_time / 1000.0) + except asyncio.CancelledError: + logging.info("Stream processing task cancelled") + except Exception as e: + logging.error(f"Error in process_streams: {e}") + + async def send_heartbeat(): + while True: + try: + cpu_usage = psutil.cpu_percent() + memory_usage = psutil.virtual_memory().percent + if torch.cuda.is_available(): + gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB + gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB + else: + gpu_usage = None + gpu_memory_usage = None + + camera_connections = [ + { + "cameraIdentifier": camera_id, + "modelId": stream['modelId'], + "modelName": stream['modelName'], + "online": True + } + for camera_id, stream in streams.items() + ] + + state_report = { + "type": "stateReport", + "cpuUsage": cpu_usage, + "memoryUsage": memory_usage, + "gpuUsage": gpu_usage, + "gpuMemoryUsage": gpu_memory_usage, + "cameraConnections": camera_connections + } + await websocket.send_text(json.dumps(state_report)) + logging.debug("Sent stateReport as heartbeat") + await asyncio.sleep(HEARTBEAT_INTERVAL) + except Exception as e: + logging.error(f"Error sending stateReport heartbeat: {e}") + break + + async def on_message(): + global models + while True: + try: + msg = await websocket.receive_text() + logging.debug(f"Received message: {msg}") + print(f"Received message: {msg}") + data = json.loads(msg) + msg_type = data.get("type") + + if msg_type == "subscribe": + payload = data.get("payload", {}) + camera_id = payload.get("cameraIdentifier") + rtsp_url = payload.get("rtspUrl") + model_url = payload.get("modelUrl") + modelId = payload.get("modelId") + modelName = payload.get("modelName") + + if model_url: + with models_lock: + if camera_id not in models: + models[camera_id] = {} + if modelId not in models[camera_id]: + print(f"Downloading model from {model_url}") + parsed_url = urlparse(model_url) + filename = os.path.basename(parsed_url.path) + model_filename = os.path.join("models", filename) + # Download the model + response = requests.get(model_url, stream=True) + if response.status_code == 200: + with open(model_filename, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + logging.info(f"Downloaded model from {model_url} to {model_filename}") + model = YOLO(model_filename) + if torch.cuda.is_available(): + model.to('cuda') + models[camera_id][modelId] = model + logging.info(f"Loaded model {modelId} for camera {camera_id}") + else: + logging.error(f"Failed to download model from {model_url}") + continue + if camera_id and rtsp_url: + with streams_lock: + if camera_id not in streams and len(streams) < max_streams: + cap = cv2.VideoCapture(rtsp_url) + if not cap.isOpened(): + logging.error(f"Failed to open RTSP stream for camera {camera_id}") + continue + buffer = queue.Queue(maxsize=1) + stop_event = threading.Event() + thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event)) + thread.daemon = True + thread.start() + streams[camera_id] = { + 'cap': cap, + 'buffer': buffer, + 'thread': thread, + 'rtsp_url': rtsp_url, + 'stop_event': stop_event, + 'modelId': modelId, + 'modelName': modelName + } + logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName} and URL {rtsp_url}") + elif camera_id and camera_id in streams: + stream = streams.pop(camera_id) + stream['cap'].release() + logging.info(f"Unsubscribed from camera {camera_id}") + if camera_id in models and modelId in models[camera_id]: + del models[camera_id][modelId] + if not models[camera_id]: + del models[camera_id] + elif msg_type == "unsubscribe": + payload = data.get("payload", {}) + camera_id = payload.get("cameraIdentifier") + logging.debug(f"Unsubscribing from camera {camera_id}") + with streams_lock: + if camera_id and camera_id in streams: + stream = streams.pop(camera_id) + stream['stop_event'].set() + stream['thread'].join() + stream['cap'].release() + logging.info(f"Unsubscribed from camera {camera_id}") + if camera_id in models and modelId in models[camera_id]: + del models[camera_id][modelId] + if not models[camera_id]: + del models[camera_id] + elif msg_type == "requestState": + # Handle state request + cpu_usage = psutil.cpu_percent() + memory_usage = psutil.virtual_memory().percent + if torch.cuda.is_available(): + gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB + gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB + else: + gpu_usage = None + gpu_memory_usage = None + + camera_connections = [ + { + "cameraIdentifier": camera_id, + "modelId": stream['modelId'], + "modelName": stream['modelName'], + "online": True + } + for camera_id, stream in streams.items() + ] + + state_report = { + "type": "stateReport", + "cpuUsage": cpu_usage, + "memoryUsage": memory_usage, + "gpuUsage": gpu_usage, + "gpuMemoryUsage": gpu_memory_usage, + "cameraConnections": camera_connections + } + await websocket.send_text(json.dumps(state_report)) + else: + logging.error(f"Unknown message type: {msg_type}") + except json.JSONDecodeError: + logging.error("Received invalid JSON message") + except (WebSocketDisconnect, ConnectionClosedError) as e: + logging.warning(f"WebSocket disconnected: {e}") + break + except Exception as e: + logging.error(f"Error handling message: {e}") + break + + try: + await websocket.accept() + task = asyncio.create_task(process_streams()) + heartbeat_task = asyncio.create_task(send_heartbeat()) + message_task = asyncio.create_task(on_message()) + + await asyncio.gather(heartbeat_task, message_task) + except Exception as e: + logging.error(f"Error in detect websocket: {e}") + finally: + task.cancel() + await task + with streams_lock: + for camera_id, stream in streams.items(): + stream['stop_event'].set() + stream['thread'].join() + stream['cap'].release() + stream['buffer'].queue.clear() + logging.info(f"Released camera {camera_id} and cleaned up resources") + streams.clear() + with models_lock: + models.clear() + logging.info("WebSocket connection closed") \ No newline at end of file