tracking test
This commit is contained in:
		
							parent
							
								
									c40cea786e
								
							
						
					
					
						commit
						9893358022
					
				
					 2 changed files with 17 additions and 5 deletions
				
			
		
							
								
								
									
										19
									
								
								app.py
									
										
									
									
									
								
							
							
						
						
									
										19
									
								
								app.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -18,9 +18,15 @@ import psutil
 | 
			
		|||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
# Initialize the YOLO model with tracking
 | 
			
		||||
model = YOLO("yolov8n.pt")
 | 
			
		||||
if torch.cuda.is_available():
 | 
			
		||||
    model.to('cuda')
 | 
			
		||||
model.track(
 | 
			
		||||
    persist=True,
 | 
			
		||||
    tracker="bytetrack.yaml",  # You can choose a different tracker if desired
 | 
			
		||||
    track_kps=False
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Retrieve class names from the model
 | 
			
		||||
class_names = model.names
 | 
			
		||||
| 
						 | 
				
			
			@ -118,13 +124,18 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                    buffer = stream['buffer']
 | 
			
		||||
                    if not buffer.empty():
 | 
			
		||||
                        frame = buffer.get()
 | 
			
		||||
                        results = model(frame, stream=False)
 | 
			
		||||
                        results = model.track(frame, stream=False)
 | 
			
		||||
                        boxes = []
 | 
			
		||||
                        for r in results:
 | 
			
		||||
                            for box in r.boxes:
 | 
			
		||||
                            for track in r.tracks:
 | 
			
		||||
                                if not track.is_confirmed():
 | 
			
		||||
                                    continue
 | 
			
		||||
                                track_id = track.track_id
 | 
			
		||||
                                cls = int(track.cls)
 | 
			
		||||
                                boxes.append({
 | 
			
		||||
                                    "class": class_names[int(box.cls[0])],
 | 
			
		||||
                                    "confidence": float(box.conf[0]),
 | 
			
		||||
                                    "class": class_names[cls],
 | 
			
		||||
                                    "confidence": float(track.conf),
 | 
			
		||||
                                    "track_id": track_id
 | 
			
		||||
                                })
 | 
			
		||||
                        # Broadcast to all subscribers of this URL
 | 
			
		||||
                        detection_data = {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,4 +3,5 @@ uvicorn
 | 
			
		|||
torch
 | 
			
		||||
torchvision
 | 
			
		||||
ultralytics
 | 
			
		||||
opencv-python
 | 
			
		||||
opencv-python
 | 
			
		||||
lapx
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue