Compare commits

...
Sign in to create a new pull request.

2 commits

Author SHA1 Message Date
b22b547fdc change tracking parameter 2025-01-14 22:35:22 +07:00
9893358022 tracking test 2025-01-14 22:29:50 +07:00
2 changed files with 16 additions and 5 deletions

18
app.py
View file

@ -18,9 +18,14 @@ import psutil
app = FastAPI() app = FastAPI()
# Initialize the YOLO model with tracking
model = YOLO("yolov8n.pt") model = YOLO("yolov8n.pt")
if torch.cuda.is_available(): if torch.cuda.is_available():
model.to('cuda') model.to('cuda')
model.track(
persist=True,
tracker="bytetrack.yaml" # You can choose a different tracker if desired
)
# Retrieve class names from the model # Retrieve class names from the model
class_names = model.names class_names = model.names
@ -118,13 +123,18 @@ async def detect(websocket: WebSocket):
buffer = stream['buffer'] buffer = stream['buffer']
if not buffer.empty(): if not buffer.empty():
frame = buffer.get() frame = buffer.get()
results = model(frame, stream=False) results = model.track(frame, stream=False) # Updated for tracking
boxes = [] boxes = []
for r in results: for r in results:
for box in r.boxes: for track in r.tracks:
if not track.is_confirmed():
continue
track_id = track.track_id
cls = int(track.cls)
boxes.append({ boxes.append({
"class": class_names[int(box.cls[0])], "class": class_names[cls],
"confidence": float(box.conf[0]), "confidence": float(track.conf),
"track_id": track_id # Added track ID
}) })
# Broadcast to all subscribers of this URL # Broadcast to all subscribers of this URL
detection_data = { detection_data = {

View file

@ -3,4 +3,5 @@ uvicorn
torch torch
torchvision torchvision
ultralytics ultralytics
opencv-python opencv-python
lapx