diff --git a/app.py b/app.py index e4adce5..180583b 100644 --- a/app.py +++ b/app.py @@ -18,9 +18,14 @@ import psutil app = FastAPI() +# Initialize the YOLO model with tracking model = YOLO("yolov8n.pt") if torch.cuda.is_available(): model.to('cuda') +model.track( + persist=True, + tracker="bytetrack.yaml" # You can choose a different tracker if desired +) # Retrieve class names from the model class_names = model.names @@ -118,13 +123,18 @@ async def detect(websocket: WebSocket): buffer = stream['buffer'] if not buffer.empty(): frame = buffer.get() - results = model(frame, stream=False) + results = model.track(frame, stream=False) # Updated for tracking boxes = [] for r in results: - for box in r.boxes: + for track in r.tracks: + if not track.is_confirmed(): + continue + track_id = track.track_id + cls = int(track.cls) boxes.append({ - "class": class_names[int(box.cls[0])], - "confidence": float(box.conf[0]), + "class": class_names[cls], + "confidence": float(track.conf), + "track_id": track_id # Added track ID }) # Broadcast to all subscribers of this URL detection_data = { diff --git a/requirements.txt b/requirements.txt index 46a2624..5555f62 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ uvicorn torch torchvision ultralytics -opencv-python \ No newline at end of file +opencv-python +lapx \ No newline at end of file