tracking
This commit is contained in:
parent
e71f63369d
commit
ce53d60702
1 changed files with 19 additions and 6 deletions
25
app.py
25
app.py
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import List
|
||||||
from fastapi import FastAPI, WebSocket
|
from fastapi import FastAPI, WebSocket
|
||||||
from fastapi.websockets import WebSocketDisconnect
|
from fastapi.websockets import WebSocketDisconnect
|
||||||
from websockets.exceptions import ConnectionClosedError
|
from websockets.exceptions import ConnectionClosedError
|
||||||
|
@ -61,11 +62,13 @@ async def detect(websocket: WebSocket):
|
||||||
# Save data you want to persist across frames in the persistent_data dictionary
|
# Save data you want to persist across frames in the persistent_data dictionary
|
||||||
async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
|
async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
|
||||||
boxes = []
|
boxes = []
|
||||||
for r in model(frame, stream=False):
|
for r in model.track(frame, stream=False, persist=True):
|
||||||
for box in r.boxes:
|
for box in r.boxes:
|
||||||
|
box_cpu = box.cpu()
|
||||||
boxes.append({
|
boxes.append({
|
||||||
"class": model.names[int(box.cls[0])],
|
"class": model.names[int(box_cpu.cls[0])],
|
||||||
"confidence": float(box.conf[0]),
|
"confidence": float(box_cpu.conf[0]),
|
||||||
|
"id": box_cpu.id,
|
||||||
})
|
})
|
||||||
# Broadcast to all subscribers of this URL
|
# Broadcast to all subscribers of this URL
|
||||||
detection_data = {
|
detection_data = {
|
||||||
|
@ -139,7 +142,7 @@ async def detect(websocket: WebSocket):
|
||||||
buffer = stream['buffer']
|
buffer = stream['buffer']
|
||||||
if not buffer.empty():
|
if not buffer.empty():
|
||||||
frame = buffer.get()
|
frame = buffer.get()
|
||||||
model = models.get(stream['modelId'])
|
model = models.get(camera_id, {}).get(stream['modelId'])
|
||||||
key = (camera_id, stream['modelId'])
|
key = (camera_id, stream['modelId'])
|
||||||
persistent_data = persistent_data_dict.get(key, {})
|
persistent_data = persistent_data_dict.get(key, {})
|
||||||
updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
|
updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
|
||||||
|
@ -207,7 +210,9 @@ async def detect(websocket: WebSocket):
|
||||||
modelName = payload.get("modelName")
|
modelName = payload.get("modelName")
|
||||||
|
|
||||||
if model_url:
|
if model_url:
|
||||||
if modelId not in models:
|
if camera_id not in models:
|
||||||
|
models[camera_id] = {}
|
||||||
|
if modelId not in models[camera_id]:
|
||||||
print(f"Downloading model from {model_url}")
|
print(f"Downloading model from {model_url}")
|
||||||
parsed_url = urlparse(model_url)
|
parsed_url = urlparse(model_url)
|
||||||
filename = os.path.basename(parsed_url.path)
|
filename = os.path.basename(parsed_url.path)
|
||||||
|
@ -222,7 +227,7 @@ async def detect(websocket: WebSocket):
|
||||||
model = YOLO(model_filename)
|
model = YOLO(model_filename)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
model.to('cuda')
|
model.to('cuda')
|
||||||
models[modelId] = model
|
models[camera_id][modelId] = model
|
||||||
logging.info(f"Loaded model {modelId} for camera {camera_id}")
|
logging.info(f"Loaded model {modelId} for camera {camera_id}")
|
||||||
else:
|
else:
|
||||||
logging.error(f"Failed to download model from {model_url}")
|
logging.error(f"Failed to download model from {model_url}")
|
||||||
|
@ -252,6 +257,10 @@ async def detect(websocket: WebSocket):
|
||||||
stream = streams.pop(camera_id)
|
stream = streams.pop(camera_id)
|
||||||
stream['cap'].release()
|
stream['cap'].release()
|
||||||
logging.info(f"Unsubscribed from camera {camera_id}")
|
logging.info(f"Unsubscribed from camera {camera_id}")
|
||||||
|
if camera_id in models and modelId in models[camera_id]:
|
||||||
|
del models[camera_id][modelId]
|
||||||
|
if not models[camera_id]:
|
||||||
|
del models[camera_id]
|
||||||
elif msg_type == "unsubscribe":
|
elif msg_type == "unsubscribe":
|
||||||
payload = data.get("payload", {})
|
payload = data.get("payload", {})
|
||||||
camera_id = payload.get("cameraIdentifier")
|
camera_id = payload.get("cameraIdentifier")
|
||||||
|
@ -259,6 +268,10 @@ async def detect(websocket: WebSocket):
|
||||||
stream = streams.pop(camera_id)
|
stream = streams.pop(camera_id)
|
||||||
stream['cap'].release()
|
stream['cap'].release()
|
||||||
logging.info(f"Unsubscribed from camera {camera_id}")
|
logging.info(f"Unsubscribed from camera {camera_id}")
|
||||||
|
if camera_id in models and modelId in models[camera_id]:
|
||||||
|
del models[camera_id][modelId]
|
||||||
|
if not models[camera_id]:
|
||||||
|
del models[camera_id]
|
||||||
elif msg_type == "requestState":
|
elif msg_type == "requestState":
|
||||||
# Handle state request
|
# Handle state request
|
||||||
cpu_usage = psutil.cpu_percent()
|
cpu_usage = psutil.cpu_percent()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue