From ce53d6070210928a1482c31b2e8072f5dad39a6d Mon Sep 17 00:00:00 2001 From: Siwat Sirichai Date: Tue, 14 Jan 2025 23:24:20 +0700 Subject: [PATCH] tracking --- app.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/app.py b/app.py index fccd42b..b6f9bee 100644 --- a/app.py +++ b/app.py @@ -1,3 +1,4 @@ +from typing import List from fastapi import FastAPI, WebSocket from fastapi.websockets import WebSocketDisconnect from websockets.exceptions import ConnectionClosedError @@ -61,11 +62,13 @@ async def detect(websocket: WebSocket): # Save data you want to persist across frames in the persistent_data dictionary async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data): boxes = [] - for r in model(frame, stream=False): + for r in model.track(frame, stream=False, persist=True): for box in r.boxes: + box_cpu = box.cpu() boxes.append({ - "class": model.names[int(box.cls[0])], - "confidence": float(box.conf[0]), + "class": model.names[int(box_cpu.cls[0])], + "confidence": float(box_cpu.conf[0]), + "id": box_cpu.id, }) # Broadcast to all subscribers of this URL detection_data = { @@ -139,7 +142,7 @@ async def detect(websocket: WebSocket): buffer = stream['buffer'] if not buffer.empty(): frame = buffer.get() - model = models.get(stream['modelId']) + model = models.get(camera_id, {}).get(stream['modelId']) key = (camera_id, stream['modelId']) persistent_data = persistent_data_dict.get(key, {}) updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data) @@ -207,7 +210,9 @@ async def detect(websocket: WebSocket): modelName = payload.get("modelName") if model_url: - if modelId not in models: + if camera_id not in models: + models[camera_id] = {} + if modelId not in models[camera_id]: print(f"Downloading model from {model_url}") parsed_url = urlparse(model_url) filename = os.path.basename(parsed_url.path) @@ -222,7 +227,7 @@ async def detect(websocket: WebSocket): model = YOLO(model_filename) if torch.cuda.is_available(): model.to('cuda') - models[modelId] = model + models[camera_id][modelId] = model logging.info(f"Loaded model {modelId} for camera {camera_id}") else: logging.error(f"Failed to download model from {model_url}") @@ -252,6 +257,10 @@ async def detect(websocket: WebSocket): stream = streams.pop(camera_id) stream['cap'].release() logging.info(f"Unsubscribed from camera {camera_id}") + if camera_id in models and modelId in models[camera_id]: + del models[camera_id][modelId] + if not models[camera_id]: + del models[camera_id] elif msg_type == "unsubscribe": payload = data.get("payload", {}) camera_id = payload.get("cameraIdentifier") @@ -259,6 +268,10 @@ async def detect(websocket: WebSocket): stream = streams.pop(camera_id) stream['cap'].release() logging.info(f"Unsubscribed from camera {camera_id}") + if camera_id in models and modelId in models[camera_id]: + del models[camera_id][modelId] + if not models[camera_id]: + del models[camera_id] elif msg_type == "requestState": # Handle state request cpu_usage = psutil.cpu_percent()