diff --git a/app.py b/app.py
index da4a073..666730f 100644
--- a/app.py
+++ b/app.py
@@ -13,6 +13,8 @@ import queue
 import os
 import requests
 from urllib.parse import urlparse # Added import
+import asyncio
+import psutil  # CPU/memory metrics for stateReport messages
 
 app = FastAPI()
 
@@ -47,6 +49,10 @@ logging.basicConfig(
 # Ensure the models directory exists
 os.makedirs("models", exist_ok=True)
 
+# Heartbeat constants
+HEARTBEAT_INTERVAL = 2  # seconds between stateReport heartbeats
+WORKER_TIMEOUT_MS = 10000  # milliseconds
+
 @app.websocket("/")
 async def detect(websocket: WebSocket):
     import asyncio
@@ -102,6 +108,7 @@ async def detect(websocket: WebSocket):
                 break
 
     async def process_streams():
+        global model, class_names  # rebound by on_message when a new model is loaded
         logging.info("Started processing streams")
         try:
             while True:
@@ -141,8 +148,152 @@
         except Exception as e:
             logging.error(f"Error in process_streams: {e}")
 
+    async def send_heartbeat():
+        while True:
+            try:
+                cpu_usage = psutil.cpu_percent()
+                memory_usage = psutil.virtual_memory().percent
+                if torch.cuda.is_available():
+                    gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2)  # allocated, in MB
+                    gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)  # reserved, in MB
+                else:
+                    gpu_usage = None
+                    gpu_memory_usage = None
+
+                camera_connections = [
+                    {
+                        "cameraIdentifier": camera_id,
+                        "modelId": stream['modelId'],
+                        "modelName": stream['modelName'],
+                        "online": True
+                    }
+                    for camera_id, stream in streams.items()
+                ]
+
+                state_report = {
+                    "type": "stateReport",
+                    "cpuUsage": cpu_usage,
+                    "memoryUsage": memory_usage,
+                    "gpuUsage": gpu_usage,
+                    "gpuMemoryUsage": gpu_memory_usage,
+                    "cameraConnections": camera_connections
+                }
+                await websocket.send_text(json.dumps(state_report))
+                logging.debug("Sent stateReport as heartbeat")
+                await asyncio.sleep(HEARTBEAT_INTERVAL)
+            except Exception as e:
+                logging.error(f"Error sending stateReport heartbeat: {e}")
+                break
+
+    async def on_message():
+        global model, class_names
+        while True:
+            msg = await websocket.receive_text()
+            logging.debug(f"Received message: {msg}")
+            data = json.loads(msg)
+            msg_type = data.get("type")
+
+            if msg_type == "subscribe":
+                payload = data.get("payload", {})
+                camera_id = payload.get("cameraIdentifier")
+                rtsp_url = payload.get("rtspUrl")
+                model_url = payload.get("modelUrl")
+                modelId = payload.get("modelId")
+                modelName = payload.get("modelName")
+
+                if model_url:
+                    logging.info(f"Downloading model from {model_url}")
+                    parsed_url = urlparse(model_url)
+                    filename = os.path.basename(parsed_url.path)
+                    model_filename = os.path.join("models", filename)
+                    # Download the model
+                    response = requests.get(model_url, stream=True, timeout=30)
+                    if response.status_code == 200:
+                        with open(model_filename, 'wb') as f:
+                            for chunk in response.iter_content(chunk_size=8192):
+                                f.write(chunk)
+                        logging.info(f"Downloaded model from {model_url} to {model_filename}")
+                        model = YOLO(model_filename)
+                        if torch.cuda.is_available():
+                            model.to('cuda')
+                        class_names = model.names
+                    else:
+                        logging.error(f"Failed to download model from {model_url}")
+                        continue
+                if camera_id and rtsp_url:
+                    if camera_id not in streams and len(streams) < max_streams:
+                        cap = cv2.VideoCapture(rtsp_url)
+                        if not cap.isOpened():
+                            logging.error(f"Failed to open RTSP stream for camera {camera_id}")
+                            continue
+                        buffer = queue.Queue(maxsize=1)
+                        stop_event = threading.Event()
+                        thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
+                        thread.daemon = True
+                        thread.start()
+                        streams[camera_id] = {
+                            'cap': cap,
+                            'buffer': buffer,
+                            'thread': thread,
+                            'rtsp_url': rtsp_url,
+                            'stop_event': stop_event,
+                            'modelId': modelId,
+                            'modelName': modelName
+                        }
+                        logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName} and URL {rtsp_url}")
+                elif camera_id and camera_id in streams:
+                    # A subscribe with no rtspUrl for a known camera tears the stream down
+                    stream = streams.pop(camera_id)
+                    stream['stop_event'].set()  # stop the frame_reader thread
+                    stream['cap'].release()
+                    logging.info(f"Unsubscribed from camera {camera_id}")
+            elif msg_type == "unsubscribe":
+                payload = data.get("payload", {})
+                camera_id = payload.get("cameraIdentifier")
+                if camera_id and camera_id in streams:
+                    stream = streams.pop(camera_id)
+                    stream['stop_event'].set()  # stop the frame_reader thread
+                    stream['cap'].release()
+                    logging.info(f"Unsubscribed from camera {camera_id}")
+            elif msg_type == "requestState":
+                # Handle an explicit state request; mirrors the heartbeat payload
+                cpu_usage = psutil.cpu_percent()
+                memory_usage = psutil.virtual_memory().percent
+                if torch.cuda.is_available():
+                    gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2)  # allocated, in MB
+                    gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)  # reserved, in MB
+                else:
+                    gpu_usage = None
+                    gpu_memory_usage = None
+
+                camera_connections = [
+                    {
+                        "cameraIdentifier": camera_id,
+                        "modelId": stream['modelId'],
+                        "modelName": stream['modelName'],
+                        "online": True
+                    }
+                    for camera_id, stream in streams.items()
+                ]
+
+                state_report = {
+                    "type": "stateReport",
+                    "cpuUsage": cpu_usage,
+                    "memoryUsage": memory_usage,
+                    "gpuUsage": gpu_usage,
+                    "gpuMemoryUsage": gpu_memory_usage,
+                    "cameraConnections": camera_connections
+                }
+                await websocket.send_text(json.dumps(state_report))
+            else:
+                logging.error(f"Unknown message type: {msg_type}")
+
     await websocket.accept()
     task = asyncio.create_task(process_streams())
+    heartbeat_task = asyncio.create_task(send_heartbeat())
+    message_task = asyncio.create_task(on_message())
+
+    await asyncio.gather(heartbeat_task, message_task)
 
 model = None
 model_path = None
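For reference, a small client script is a quick way to exercise the new message protocol end to end. This is a minimal sketch, not part of the change: it assumes the app is served at ws://localhost:8000/ and that the third-party `websockets` package is installed; every camera and model value below is a placeholder.

import asyncio
import json

import websockets


async def main():
    uri = "ws://localhost:8000/"  # assumption: where app.py is being served
    async with websockets.connect(uri) as ws:
        # Subscribe to one camera; keys mirror the subscribe handler in on_message().
        await ws.send(json.dumps({
            "type": "subscribe",
            "payload": {
                "cameraIdentifier": "cam-01",                      # placeholder
                "rtspUrl": "rtsp://example.com/stream",            # placeholder
                "modelUrl": "https://example.com/models/demo.pt",  # placeholder
                "modelId": "model-01",                             # placeholder
                "modelName": "demo-model",                         # placeholder
            },
        }))

        # Ask for an immediate snapshot; the server also pushes a stateReport
        # on its own every HEARTBEAT_INTERVAL seconds.
        await ws.send(json.dumps({"type": "requestState"}))

        # Read a few messages, printing only the stateReport heartbeats.
        for _ in range(5):
            message = json.loads(await ws.recv())
            if message.get("type") == "stateReport":
                print(message["cpuUsage"], message["memoryUsage"],
                      message["cameraConnections"])

        # Tear the stream down again.
        await ws.send(json.dumps({
            "type": "unsubscribe",
            "payload": {"cameraIdentifier": "cam-01"},
        }))


asyncio.run(main())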