This commit is contained in:
Siwat Sirichai 2025-01-14 23:24:20 +07:00
parent e71f63369d
commit ce53d60702

25
app.py
View file

@ -1,3 +1,4 @@
from typing import List
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
@ -61,11 +62,13 @@ async def detect(websocket: WebSocket):
# Save data you want to persist across frames in the persistent_data dictionary
async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
boxes = []
for r in model(frame, stream=False):
for r in model.track(frame, stream=False, persist=True):
for box in r.boxes:
box_cpu = box.cpu()
boxes.append({
"class": model.names[int(box.cls[0])],
"confidence": float(box.conf[0]),
"class": model.names[int(box_cpu.cls[0])],
"confidence": float(box_cpu.conf[0]),
"id": box_cpu.id,
})
# Broadcast to all subscribers of this URL
detection_data = {
@ -139,7 +142,7 @@ async def detect(websocket: WebSocket):
buffer = stream['buffer']
if not buffer.empty():
frame = buffer.get()
model = models.get(stream['modelId'])
model = models.get(camera_id, {}).get(stream['modelId'])
key = (camera_id, stream['modelId'])
persistent_data = persistent_data_dict.get(key, {})
updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
@ -207,7 +210,9 @@ async def detect(websocket: WebSocket):
modelName = payload.get("modelName")
if model_url:
if modelId not in models:
if camera_id not in models:
models[camera_id] = {}
if modelId not in models[camera_id]:
print(f"Downloading model from {model_url}")
parsed_url = urlparse(model_url)
filename = os.path.basename(parsed_url.path)
@ -222,7 +227,7 @@ async def detect(websocket: WebSocket):
model = YOLO(model_filename)
if torch.cuda.is_available():
model.to('cuda')
models[modelId] = model
models[camera_id][modelId] = model
logging.info(f"Loaded model {modelId} for camera {camera_id}")
else:
logging.error(f"Failed to download model from {model_url}")
@ -252,6 +257,10 @@ async def detect(websocket: WebSocket):
stream = streams.pop(camera_id)
stream['cap'].release()
logging.info(f"Unsubscribed from camera {camera_id}")
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
elif msg_type == "unsubscribe":
payload = data.get("payload", {})
camera_id = payload.get("cameraIdentifier")
@ -259,6 +268,10 @@ async def detect(websocket: WebSocket):
stream = streams.pop(camera_id)
stream['cap'].release()
logging.info(f"Unsubscribed from camera {camera_id}")
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
elif msg_type == "requestState":
# Handle state request
cpu_usage = psutil.cpu_percent()