Compare commits
2 commits
Author | SHA1 | Date | |
---|---|---|---|
b22b547fdc | |||
9893358022 |
2 changed files with 16 additions and 5 deletions
18
app.py
18
app.py
|
@ -18,9 +18,14 @@ import psutil
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize the YOLO model with tracking
|
||||
model = YOLO("yolov8n.pt")
|
||||
if torch.cuda.is_available():
|
||||
model.to('cuda')
|
||||
model.track(
|
||||
persist=True,
|
||||
tracker="bytetrack.yaml" # You can choose a different tracker if desired
|
||||
)
|
||||
|
||||
# Retrieve class names from the model
|
||||
class_names = model.names
|
||||
|
@ -118,13 +123,18 @@ async def detect(websocket: WebSocket):
|
|||
buffer = stream['buffer']
|
||||
if not buffer.empty():
|
||||
frame = buffer.get()
|
||||
results = model(frame, stream=False)
|
||||
results = model.track(frame, stream=False) # Updated for tracking
|
||||
boxes = []
|
||||
for r in results:
|
||||
for box in r.boxes:
|
||||
for track in r.tracks:
|
||||
if not track.is_confirmed():
|
||||
continue
|
||||
track_id = track.track_id
|
||||
cls = int(track.cls)
|
||||
boxes.append({
|
||||
"class": class_names[int(box.cls[0])],
|
||||
"confidence": float(box.conf[0]),
|
||||
"class": class_names[cls],
|
||||
"confidence": float(track.conf),
|
||||
"track_id": track_id # Added track ID
|
||||
})
|
||||
# Broadcast to all subscribers of this URL
|
||||
detection_data = {
|
||||
|
|
|
@ -3,4 +3,5 @@ uvicorn
|
|||
torch
|
||||
torchvision
|
||||
ultralytics
|
||||
opencv-python
|
||||
opencv-python
|
||||
lapx
|
Loading…
Add table
Add a link
Reference in a new issue