tracking
This commit is contained in:
parent
e71f63369d
commit
ce53d60702
1 changed files with 19 additions and 6 deletions
25
app.py
25
app.py
|
@ -1,3 +1,4 @@
|
|||
from typing import List
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.websockets import WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
@ -61,11 +62,13 @@ async def detect(websocket: WebSocket):
|
|||
# Save data you want to persist across frames in the persistent_data dictionary
|
||||
async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
|
||||
boxes = []
|
||||
for r in model(frame, stream=False):
|
||||
for r in model.track(frame, stream=False, persist=True):
|
||||
for box in r.boxes:
|
||||
box_cpu = box.cpu()
|
||||
boxes.append({
|
||||
"class": model.names[int(box.cls[0])],
|
||||
"confidence": float(box.conf[0]),
|
||||
"class": model.names[int(box_cpu.cls[0])],
|
||||
"confidence": float(box_cpu.conf[0]),
|
||||
"id": box_cpu.id,
|
||||
})
|
||||
# Broadcast to all subscribers of this URL
|
||||
detection_data = {
|
||||
|
@ -139,7 +142,7 @@ async def detect(websocket: WebSocket):
|
|||
buffer = stream['buffer']
|
||||
if not buffer.empty():
|
||||
frame = buffer.get()
|
||||
model = models.get(stream['modelId'])
|
||||
model = models.get(camera_id, {}).get(stream['modelId'])
|
||||
key = (camera_id, stream['modelId'])
|
||||
persistent_data = persistent_data_dict.get(key, {})
|
||||
updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
|
||||
|
@ -207,7 +210,9 @@ async def detect(websocket: WebSocket):
|
|||
modelName = payload.get("modelName")
|
||||
|
||||
if model_url:
|
||||
if modelId not in models:
|
||||
if camera_id not in models:
|
||||
models[camera_id] = {}
|
||||
if modelId not in models[camera_id]:
|
||||
print(f"Downloading model from {model_url}")
|
||||
parsed_url = urlparse(model_url)
|
||||
filename = os.path.basename(parsed_url.path)
|
||||
|
@ -222,7 +227,7 @@ async def detect(websocket: WebSocket):
|
|||
model = YOLO(model_filename)
|
||||
if torch.cuda.is_available():
|
||||
model.to('cuda')
|
||||
models[modelId] = model
|
||||
models[camera_id][modelId] = model
|
||||
logging.info(f"Loaded model {modelId} for camera {camera_id}")
|
||||
else:
|
||||
logging.error(f"Failed to download model from {model_url}")
|
||||
|
@ -252,6 +257,10 @@ async def detect(websocket: WebSocket):
|
|||
stream = streams.pop(camera_id)
|
||||
stream['cap'].release()
|
||||
logging.info(f"Unsubscribed from camera {camera_id}")
|
||||
if camera_id in models and modelId in models[camera_id]:
|
||||
del models[camera_id][modelId]
|
||||
if not models[camera_id]:
|
||||
del models[camera_id]
|
||||
elif msg_type == "unsubscribe":
|
||||
payload = data.get("payload", {})
|
||||
camera_id = payload.get("cameraIdentifier")
|
||||
|
@ -259,6 +268,10 @@ async def detect(websocket: WebSocket):
|
|||
stream = streams.pop(camera_id)
|
||||
stream['cap'].release()
|
||||
logging.info(f"Unsubscribed from camera {camera_id}")
|
||||
if camera_id in models and modelId in models[camera_id]:
|
||||
del models[camera_id][modelId]
|
||||
if not models[camera_id]:
|
||||
del models[camera_id]
|
||||
elif msg_type == "requestState":
|
||||
# Handle state request
|
||||
cpu_usage = psutil.cpu_percent()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue