tracking
This commit is contained in:
		
							parent
							
								
									e71f63369d
								
							
						
					
					
						commit
						ce53d60702
					
				
					 1 changed files with 19 additions and 6 deletions
				
			
		
							
								
								
									
										25
									
								
								app.py
									
										
									
									
									
								
							
							
						
						
									
										25
									
								
								app.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -1,3 +1,4 @@
 | 
			
		|||
from typing import List
 | 
			
		||||
from fastapi import FastAPI, WebSocket
 | 
			
		||||
from fastapi.websockets import WebSocketDisconnect
 | 
			
		||||
from websockets.exceptions import ConnectionClosedError
 | 
			
		||||
| 
						 | 
				
			
			@ -61,11 +62,13 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
    # Save data you want to persist across frames in the persistent_data dictionary
 | 
			
		||||
    async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
 | 
			
		||||
        boxes = []
 | 
			
		||||
        for r in model(frame, stream=False):
 | 
			
		||||
        for r in model.track(frame, stream=False, persist=True):
 | 
			
		||||
            for box in r.boxes:
 | 
			
		||||
                box_cpu = box.cpu()
 | 
			
		||||
                boxes.append({
 | 
			
		||||
                    "class": model.names[int(box.cls[0])],
 | 
			
		||||
                    "confidence": float(box.conf[0]),
 | 
			
		||||
                    "class": model.names[int(box_cpu.cls[0])],
 | 
			
		||||
                    "confidence": float(box_cpu.conf[0]),
 | 
			
		||||
                    "id": box_cpu.id,
 | 
			
		||||
                })
 | 
			
		||||
        # Broadcast to all subscribers of this URL
 | 
			
		||||
        detection_data = {
 | 
			
		||||
| 
						 | 
				
			
			@ -139,7 +142,7 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                    buffer = stream['buffer']
 | 
			
		||||
                    if not buffer.empty():
 | 
			
		||||
                        frame = buffer.get()
 | 
			
		||||
                        model = models.get(stream['modelId'])
 | 
			
		||||
                        model = models.get(camera_id, {}).get(stream['modelId'])
 | 
			
		||||
                        key = (camera_id, stream['modelId'])
 | 
			
		||||
                        persistent_data = persistent_data_dict.get(key, {})
 | 
			
		||||
                        updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
 | 
			
		||||
| 
						 | 
				
			
			@ -207,7 +210,9 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                modelName = payload.get("modelName")
 | 
			
		||||
    
 | 
			
		||||
                if model_url:
 | 
			
		||||
                    if modelId not in models:
 | 
			
		||||
                    if camera_id not in models:
 | 
			
		||||
                        models[camera_id] = {}
 | 
			
		||||
                    if modelId not in models[camera_id]:
 | 
			
		||||
                        print(f"Downloading model from {model_url}")
 | 
			
		||||
                        parsed_url = urlparse(model_url)
 | 
			
		||||
                        filename = os.path.basename(parsed_url.path)    
 | 
			
		||||
| 
						 | 
				
			
			@ -222,7 +227,7 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                            model = YOLO(model_filename)
 | 
			
		||||
                            if torch.cuda.is_available():
 | 
			
		||||
                                model.to('cuda')
 | 
			
		||||
                            models[modelId] = model
 | 
			
		||||
                            models[camera_id][modelId] = model
 | 
			
		||||
                            logging.info(f"Loaded model {modelId} for camera {camera_id}")
 | 
			
		||||
                        else:
 | 
			
		||||
                            logging.error(f"Failed to download model from {model_url}")
 | 
			
		||||
| 
						 | 
				
			
			@ -252,6 +257,10 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                        stream = streams.pop(camera_id)
 | 
			
		||||
                        stream['cap'].release()
 | 
			
		||||
                        logging.info(f"Unsubscribed from camera {camera_id}")
 | 
			
		||||
                        if camera_id in models and modelId in models[camera_id]:
 | 
			
		||||
                            del models[camera_id][modelId]
 | 
			
		||||
                            if not models[camera_id]:
 | 
			
		||||
                                del models[camera_id]
 | 
			
		||||
            elif msg_type == "unsubscribe":
 | 
			
		||||
                payload = data.get("payload", {})
 | 
			
		||||
                camera_id = payload.get("cameraIdentifier")
 | 
			
		||||
| 
						 | 
				
			
			@ -259,6 +268,10 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                    stream = streams.pop(camera_id)
 | 
			
		||||
                    stream['cap'].release()
 | 
			
		||||
                    logging.info(f"Unsubscribed from camera {camera_id}")
 | 
			
		||||
                    if camera_id in models and modelId in models[camera_id]:
 | 
			
		||||
                        del models[camera_id][modelId]
 | 
			
		||||
                        if not models[camera_id]:
 | 
			
		||||
                            del models[camera_id]
 | 
			
		||||
            elif msg_type == "requestState":
 | 
			
		||||
                # Handle state request
 | 
			
		||||
                cpu_usage = psutil.cpu_percent()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue