"""FastAPI WebSocket service running YOLO detection on subscribed RTSP streams.

A client connects to ``/`` and sends JSON messages:

* ``{"type": "subscribe", "payload": {...}}`` -- optionally download and load a
  model (``modelUrl`` / ``modelId`` / ``modelName``) and start reading frames
  from ``rtspUrl`` for ``cameraIdentifier``.
* ``{"type": "unsubscribe", "payload": {"cameraIdentifier": ...}}``
* ``{"type": "requestState"}`` -- reply with a ``stateReport`` message.

The server pushes an ``imageDetection`` message for every processed frame and
a ``stateReport`` heartbeat every ``HEARTBEAT_INTERVAL`` seconds.
"""

import asyncio
import base64
import functools
import json
import logging
import os
import queue
import threading
import time
from urllib.parse import urlparse

import cv2
import numpy as np
import psutil
import requests
import torch
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect
from ultralytics import YOLO
from websockets.exceptions import ConnectionClosedError

app = FastAPI()

# Configure logging before anything is logged: the first module-level logging
# call on an unconfigured root logger installs a default handler, after which
# basicConfig() silently does nothing.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler(),
    ],
)

# Loaded detection models keyed by the "modelId" sent in subscribe messages.
# Shared by every WebSocket connection in this process.
models = {}

with open("config.json", "r") as f:
    config = json.load(f)

poll_interval = config.get("poll_interval_ms", 100)
reconnect_interval = config.get("reconnect_interval_sec", 5)
TARGET_FPS = config.get("target_fps", 10)
# target_fps drives the processing cadence; a non-positive value would divide
# by zero, so fall back to poll_interval_ms in that case.
if TARGET_FPS > 0:
    poll_interval = 1000 / TARGET_FPS
logging.info(f"Poll interval: {poll_interval}ms")
max_streams = config.get("max_streams", 5)
max_retries = config.get("max_retries", 3)  # -1 means retry forever
# Per-request connect/read timeout for model downloads, in seconds.
MODEL_DOWNLOAD_TIMEOUT_SEC = config.get("model_download_timeout_sec", 60)

# Downloaded model files are stored here.
os.makedirs("models", exist_ok=True)

# Heartbeat / worker constants
HEARTBEAT_INTERVAL = 2  # seconds
WORKER_TIMEOUT_MS = 10000  # upper bound for waiting on a reader thread at shutdown


@app.websocket("/")
async def detect(websocket: WebSocket):
    """Serve one client: manage its camera subscriptions and stream detections.

    Each subscribed camera gets a reader thread (``frame_reader``) that keeps
    only the newest frame in a one-slot queue. Three asyncio tasks share the
    socket: ``process_streams`` (inference and ``imageDetection`` messages),
    ``send_heartbeat`` (periodic ``stateReport``) and ``on_message`` (client
    commands). The connection lives as long as ``on_message`` does. When it
    ends (normally with a disconnect), all tasks are cancelled and every reader
    thread is stopped.
    """
    # camera_id -> {'buffer', 'thread', 'rtsp_url', 'stop_event', 'modelId', 'modelName'}
    streams = {}
    loop = asyncio.get_running_loop()

    def frame_reader(camera_id, rtsp_url, cap, buffer, stop_event):
        """Read frames from ``cap`` into ``buffer`` until ``stop_event`` is set.

        Runs on its own thread. Each new frame replaces the previous one in
        ``buffer``, so the consumer always gets the most recent frame. On a
        failed read or an OpenCV error, the capture is released and reopened
        from ``rtsp_url`` after ``reconnect_interval`` seconds. It gives up
        after more than ``max_retries`` consecutive failures (``-1`` = never).
        This thread owns the capture, including any reopened one, and
        releases it before returning.
        """
        retries = 0
        try:
            while not stop_event.is_set():
                opencv_error = False
                try:
                    ret, frame = cap.read()
                except cv2.error as e:
                    logging.error(f"OpenCV error for camera {camera_id}: {e}")
                    ret, frame, opencv_error = False, None, True

                if ret:
                    retries = 0  # Reset on success
                    try:
                        buffer.get_nowait()  # Discard the old frame
                    except queue.Empty:
                        pass
                    # Only this thread puts into the one-slot queue, so after
                    # the discard above this put cannot block.
                    buffer.put(frame)
                    continue

                if not opencv_error:
                    logging.warning(
                        f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}"
                    )
                cap.release()
                # Event.wait instead of time.sleep so a stop request
                # interrupts the back-off immediately.
                if stop_event.wait(reconnect_interval):
                    break
                retries += 1
                if max_retries != -1 and retries > max_retries:
                    if opencv_error:
                        logging.error(f"Max retries reached after OpenCV error for camera: {camera_id}")
                    else:
                        logging.error(f"Max retries reached for camera: {camera_id}")
                    break
                # Re-open the VideoCapture
                cap = cv2.VideoCapture(rtsp_url)
                if not cap.isOpened():
                    if opencv_error:
                        logging.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error")
                    else:
                        logging.error(f"Failed to reopen RTSP stream for camera: {camera_id}")
        except Exception as e:
            logging.error(f"Unexpected error for camera {camera_id}: {e}")
        finally:
            cap.release()

    def stop_stream(camera_id):
        """Remove ``camera_id`` from ``streams`` and ask its reader thread to exit.

        The capture is not released here: it is in use by the reader thread,
        and OpenCV does not promise that a VideoCapture may be used from two
        threads at once. The reader releases it itself when it exits. The
        thread is not joined either, so the event loop is not blocked.
        """
        stream = streams.pop(camera_id)
        stream['stop_event'].set()

    def build_state_report():
        """Return a ``stateReport`` message for host load and this connection's cameras."""
        cpu_usage = psutil.cpu_percent()
        memory_usage = psutil.virtual_memory().percent
        if torch.cuda.is_available():
            gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2)  # Convert to MB
            gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)  # Convert to MB
        else:
            gpu_usage = None
            gpu_memory_usage = None
        camera_connections = [
            {
                "cameraIdentifier": camera_id,
                "modelId": stream['modelId'],
                "modelName": stream['modelName'],
                # A reader that gave up after max_retries has exited.
                "online": stream['thread'].is_alive(),
            }
            for camera_id, stream in streams.items()
        ]
        return {
            "type": "stateReport",
            "cpuUsage": cpu_usage,
            "memoryUsage": memory_usage,
            "gpuUsage": gpu_usage,
            "gpuMemoryUsage": gpu_memory_usage,
            "cameraConnections": camera_connections,
        }

    def fetch_model(model_url):
        """Download ``model_url`` into ``models/`` and load it with YOLO.

        Blocking; intended to run in an executor. Returns the loaded model,
        or ``None`` (after logging) if the download or the load fails.
        """
        logging.info(f"Downloading model from {model_url}")
        filename = os.path.basename(urlparse(model_url).path)
        if filename in ("", ".", ".."):
            logging.error(f"Cannot derive a model filename from {model_url}")
            return None
        model_filename = os.path.join("models", filename)
        try:
            with requests.get(model_url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT_SEC) as response:
                if response.status_code != 200:
                    logging.error(f"Failed to download model from {model_url}")
                    return None
                with open(model_filename, 'wb') as out:
                    for chunk in response.iter_content(chunk_size=8192):
                        out.write(chunk)
        except (requests.RequestException, OSError) as e:
            logging.error(f"Failed to download model from {model_url}: {e}")
            return None
        logging.info(f"Downloaded model from {model_url} to {model_filename}")
        try:
            model = YOLO(model_filename)
            if torch.cuda.is_available():
                model.to('cuda')
        except Exception as e:
            logging.error(f"Failed to load model {model_filename}: {e}")
            return None
        return model

    async def process_streams():
        """Run detection on the newest frame of each stream, about every ``poll_interval`` ms.

        Streams are visited round-robin once per cycle. Each cycle is padded
        with a sleep so that it lasts at least ``poll_interval`` milliseconds.
        """
        logging.info("Started processing streams")
        try:
            while True:
                start_time = time.time()
                for camera_id, stream in list(streams.items()):
                    try:
                        frame = stream['buffer'].get_nowait()
                    except queue.Empty:
                        continue

                    model = models.get(stream['modelId'])
                    if model is None:
                        logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
                        continue

                    try:
                        # Inference blocks; run it in a worker thread so the
                        # heartbeat and message tasks keep being served.
                        results = await loop.run_in_executor(
                            None, functools.partial(model, frame, stream=False)
                        )
                    except Exception as e:
                        logging.error(f"Inference failed for camera {camera_id}: {e}")
                        continue

                    boxes = [
                        {
                            "class": model.names[int(box.cls[0])],
                            "confidence": float(box.conf[0]),
                        }
                        for r in results
                        for box in r.boxes
                    ]
                    detection_data = {
                        "type": "imageDetection",
                        "cameraIdentifier": camera_id,
                        "timestamp": time.time(),
                        "data": {
                            "detections": boxes,
                            "modelId": stream['modelId'],
                            "modelName": stream['modelName'],
                        },
                    }
                    logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
                    await websocket.send_json(detection_data)

                elapsed_time = (time.time() - start_time) * 1000  # in ms
                sleep_time = max(poll_interval - elapsed_time, 0)
                logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
                await asyncio.sleep(sleep_time / 1000.0)
        except asyncio.CancelledError:
            logging.info("Stream processing task cancelled")
            raise
        except Exception as e:
            logging.error(f"Error in process_streams: {e}")

    async def send_heartbeat():
        """Send a ``stateReport`` every ``HEARTBEAT_INTERVAL`` seconds until a send fails."""
        while True:
            try:
                await websocket.send_text(json.dumps(build_state_report()))
                logging.debug("Sent stateReport as heartbeat")
                await asyncio.sleep(HEARTBEAT_INTERVAL)
            except Exception as e:
                logging.error(f"Error sending stateReport heartbeat: {e}")
                break

    async def handle_subscribe(payload):
        """Handle a ``subscribe`` payload: load its model if needed and start its stream."""
        camera_id = payload.get("cameraIdentifier")
        rtsp_url = payload.get("rtspUrl")
        model_url = payload.get("modelUrl")
        modelId = payload.get("modelId")
        modelName = payload.get("modelName")

        if model_url and modelId not in models:
            model = await loop.run_in_executor(None, fetch_model, model_url)
            if model is None:
                return
            models[modelId] = model
            logging.info(f"Loaded model {modelId} for camera {camera_id}")

        if not (camera_id and rtsp_url):
            return

        if camera_id in streams:
            # NOTE(review): subscribing to an already-subscribed camera
            # unsubscribes it. This toggle is kept from the original code, but
            # an explicit "unsubscribe" message also exists; confirm it is intended.
            stop_stream(camera_id)
            logging.info(f"Unsubscribed from camera {camera_id}")
            return

        if len(streams) >= max_streams:
            logging.warning(f"Cannot subscribe to camera {camera_id}: max_streams ({max_streams}) reached")
            return

        # Opening an RTSP stream can take seconds; keep it off the event loop.
        cap = await loop.run_in_executor(None, cv2.VideoCapture, rtsp_url)
        if not cap.isOpened():
            logging.error(f"Failed to open RTSP stream for camera {camera_id}")
            cap.release()
            return

        buffer = queue.Queue(maxsize=1)
        stop_event = threading.Event()
        thread = threading.Thread(
            target=frame_reader,
            args=(camera_id, rtsp_url, cap, buffer, stop_event),
            daemon=True,
        )
        thread.start()
        streams[camera_id] = {
            'buffer': buffer,
            'thread': thread,
            'rtsp_url': rtsp_url,
            'stop_event': stop_event,
            'modelId': modelId,
            'modelName': modelName,
        }
        logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName} and URL {rtsp_url}")

    async def on_message():
        """Dispatch client messages forever; ends only by raising (e.g. on disconnect)."""
        while True:
            msg = await websocket.receive_text()
            logging.debug(f"Received message: {msg}")
            try:
                data = json.loads(msg)
            except json.JSONDecodeError:
                logging.error("Received invalid JSON message")
                continue
            if not isinstance(data, dict):
                logging.error("Received invalid JSON message")
                continue

            msg_type = data.get("type")
            payload = data.get("payload") or {}
            if msg_type == "subscribe":
                await handle_subscribe(payload)
            elif msg_type == "unsubscribe":
                camera_id = payload.get("cameraIdentifier")
                if camera_id and camera_id in streams:
                    stop_stream(camera_id)
                    logging.info(f"Unsubscribed from camera {camera_id}")
            elif msg_type == "requestState":
                await websocket.send_text(json.dumps(build_state_report()))
            else:
                logging.error(f"Unknown message type: {msg_type}")

    await websocket.accept()
    logging.info("WebSocket connection accepted")

    process_task = asyncio.create_task(process_streams())
    heartbeat_task = asyncio.create_task(send_heartbeat())
    message_task = asyncio.create_task(on_message())
    tasks = (process_task, heartbeat_task, message_task)
    try:
        # The connection's lifetime is on_message's lifetime; the other tasks
        # run until they are cancelled below or a send fails.
        await message_task
    except (WebSocketDisconnect, ConnectionClosedError) as e:
        logging.warning(f"WebSocket disconnected: {e}")
    except Exception as e:
        logging.error(f"Unexpected error in WebSocket connection: {e}")
    finally:
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

        # Signal every reader first so they all shut down in parallel.
        for stream in streams.values():
            stream['stop_event'].set()
        for camera_id, stream in list(streams.items()):
            # A reader may be blocked inside cap.read(); join off the event
            # loop and with a bound so shutdown cannot hang the server.
            await loop.run_in_executor(None, stream['thread'].join, WORKER_TIMEOUT_MS / 1000)
            if stream['thread'].is_alive():
                logging.warning(f"Reader thread for camera {camera_id} did not stop within {WORKER_TIMEOUT_MS} ms")
            stream['buffer'].queue.clear()
            logging.info(f"Released camera {camera_id} and cleaned up resources")
        streams.clear()
        # `models` is intentionally not cleared: it is shared by every
        # connection, and clearing it would break other connections that are
        # still using those models.
        logging.info("WebSocket connection closed")