From 9893358022842ba751dad2aefe496b5590a856c2 Mon Sep 17 00:00:00 2001 From: Siwat Sirichai Date: Tue, 14 Jan 2025 22:29:50 +0700 Subject: [PATCH 1/2] tracking test --- app.py | 19 +++++++++++++++---- requirements.txt | 3 ++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index e4adce5..595a40e 100644 --- a/app.py +++ b/app.py @@ -18,9 +18,15 @@ import psutil app = FastAPI() +# Initialize the YOLO model with tracking model = YOLO("yolov8n.pt") if torch.cuda.is_available(): model.to('cuda') +model.track( + persist=True, + tracker="bytetrack.yaml", # You can choose a different tracker if desired + track_kps=False +) # Retrieve class names from the model class_names = model.names @@ -118,13 +124,18 @@ async def detect(websocket: WebSocket): buffer = stream['buffer'] if not buffer.empty(): frame = buffer.get() - results = model(frame, stream=False) + results = model.track(frame, stream=False) boxes = [] for r in results: - for box in r.boxes: + for track in r.tracks: + if not track.is_confirmed(): + continue + track_id = track.track_id + cls = int(track.cls) boxes.append({ - "class": class_names[int(box.cls[0])], - "confidence": float(box.conf[0]), + "class": class_names[cls], + "confidence": float(track.conf), + "track_id": track_id }) # Broadcast to all subscribers of this URL detection_data = { diff --git a/requirements.txt b/requirements.txt index 46a2624..5555f62 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ uvicorn torch torchvision ultralytics -opencv-python \ No newline at end of file +opencv-python +lapx \ No newline at end of file From b22b547fdcb92432df6a1be5680512d78dfcb79d Mon Sep 17 00:00:00 2001 From: Siwat Sirichai Date: Tue, 14 Jan 2025 22:35:22 +0700 Subject: [PATCH 2/2] change tracking parameter --- app.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/app.py b/app.py index 595a40e..180583b 100644 --- a/app.py +++ b/app.py @@ -24,8 +24,7 @@ if torch.cuda.is_available(): model.to('cuda') model.track( persist=True, - tracker="bytetrack.yaml", # You can choose a different tracker if desired - track_kps=False + tracker="bytetrack.yaml" # You can choose a different tracker if desired ) # Retrieve class names from the model @@ -124,7 +123,7 @@ async def detect(websocket: WebSocket): buffer = stream['buffer'] if not buffer.empty(): frame = buffer.get() - results = model.track(frame, stream=False) + results = model.track(frame, stream=False) # Updated for tracking boxes = [] for r in results: for track in r.tracks: @@ -135,7 +134,7 @@ async def detect(websocket: WebSocket): boxes.append({ "class": class_names[cls], "confidence": float(track.conf), - "track_id": track_id + "track_id": track_id # Added track ID }) # Broadcast to all subscribers of this URL detection_data = {