fix dynamic model bug
parent c40cea786e
commit 1af0a3213a

1 changed file with 40 additions and 29 deletions

app.py
@@ -18,12 +18,14 @@ import psutil
 
 app = FastAPI()
 
-model = YOLO("yolov8n.pt")
-if torch.cuda.is_available():
-    model.to('cuda')
+# Remove global model and class_names
+# model = YOLO("yolov8n.pt")
+# if torch.cuda.is_available():
+#     model.to('cuda')
+# class_names = model.names
 
-# Retrieve class names from the model
-class_names = model.names
+# Introduce models dictionary
+models = {}
 
 with open("config.json", "r") as f:
     config = json.load(f)
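
The hunk above swaps the single module-level YOLO instance for a `models` dictionary keyed by model ID, so different streams can run different weights within one process. A minimal sketch of the pattern, assuming the ultralytics `YOLO` and `torch` imports already present in app.py (`get_or_load_model` is a hypothetical helper, not part of this commit):

import torch
from ultralytics import YOLO

models = {}  # modelId -> loaded YOLO instance

def get_or_load_model(model_id, weights_path):
    """Load weights once per modelId and reuse the cached instance."""
    if model_id not in models:
        model = YOLO(weights_path)
        if torch.cuda.is_available():
            model.to('cuda')  # move to GPU when one is present
        models[model_id] = model
    return models[model_id]
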
@@ -108,7 +110,9 @@ async def detect(websocket: WebSocket):
                 break
 
     async def process_streams():
-        global model, class_names  # Added line
+        # Remove global model and class_names
+        # global model, class_names
+        global models
         logging.info("Started processing streams")
         try:
             while True:
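
One subtlety behind the `global models` declaration here (and in `on_message` below): Python only requires `global` when a function rebinds the name; mutating the dictionary in place works without it. Declaring it anyway makes the shared state explicit. A short illustration of the distinction:

models = {}

def mutate():
    models["a"] = 1    # in-place mutation: no `global` needed

def rebind():
    global models      # required: without it, the assignment below
    models = {"b": 2}  # would create a local shadowing the module dict
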
@@ -118,12 +122,17 @@ async def detect(websocket: WebSocket):
                     buffer = stream['buffer']
                     if not buffer.empty():
                         frame = buffer.get()
+                        # Get the model for this stream's modelId
+                        model = models.get(stream['modelId'])
+                        if not model:
+                            logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
+                            continue
                         results = model(frame, stream=False)
                         boxes = []
                         for r in results:
                             for box in r.boxes:
                                 boxes.append({
-                                    "class": class_names[int(box.cls[0])],
+                                    "class": model.names[int(box.cls[0])],
                                     "confidence": float(box.conf[0]),
                                 })
                         # Broadcast to all subscribers of this URL
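
This hunk is the actual bug fix: with several models loaded, a single global `class_names` reflects whichever model was loaded last, so boxes from one stream could be labeled with another model's classes. Reading `model.names` keeps labels tied to the model that produced the detections. A hypothetical helper capturing the idea:

def label_boxes(model, results):
    """Build box dicts using the producing model's own class names."""
    return [
        {
            "class": model.names[int(box.cls[0])],  # names travel with the model
            "confidence": float(box.conf[0]),
        }
        for r in results
        for box in r.boxes
    ]
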
@@ -186,7 +195,9 @@ async def detect(websocket: WebSocket):
                 break
 
     async def on_message():
-        global model, class_names  # Changed from nonlocal to global
+        # Change from nonlocal to global
+        # global model, class_names
+        global models
         while True:
             msg = await websocket.receive_text()
             logging.debug(f"Received message: {msg}")
@@ -202,24 +213,26 @@ async def detect(websocket: WebSocket):
                 modelName = payload.get("modelName")
 
                 if model_url:
-                    print(f"Downloading model from {model_url}")
-                    parsed_url = urlparse(model_url)
-                    filename = os.path.basename(parsed_url.path)
-                    model_filename = os.path.join("models", filename)
-                    # Download the model
-                    response = requests.get(model_url, stream=True)
-                    if response.status_code == 200:
-                        with open(model_filename, 'wb') as f:
-                            for chunk in response.iter_content(chunk_size=8192):
-                                f.write(chunk)
-                        logging.info(f"Downloaded model from {model_url} to {model_filename}")
-                        model = YOLO(model_filename)
-                        if torch.cuda.is_available():
-                            model.to('cuda')
-                        class_names = model.names
-                    else:
-                        logging.error(f"Failed to download model from {model_url}")
-                        continue
+                    if modelId not in models:
+                        print(f"Downloading model from {model_url}")
+                        parsed_url = urlparse(model_url)
+                        filename = os.path.basename(parsed_url.path)
+                        model_filename = os.path.join("models", filename)
+                        # Download the model
+                        response = requests.get(model_url, stream=True)
+                        if response.status_code == 200:
+                            with open(model_filename, 'wb') as f:
+                                for chunk in response.iter_content(chunk_size=8192):
+                                    f.write(chunk)
+                            logging.info(f"Downloaded model from {model_url} to {model_filename}")
+                            model = YOLO(model_filename)
+                            if torch.cuda.is_available():
+                                model.to('cuda')
+                            models[modelId] = model
+                            logging.info(f"Loaded model {modelId} for camera {camera_id}")
+                        else:
+                            logging.error(f"Failed to download model from {model_url}")
+                            continue
                 if camera_id and rtsp_url:
                     if camera_id not in streams and len(streams) < max_streams:
                         cap = cv2.VideoCapture(rtsp_url)
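
The new `if modelId not in models` guard means each model is downloaded and loaded at most once per connection; later subscriptions with the same `modelId` reuse the cached instance. A standalone sketch of the same logic (function name invented; assumes `requests`, `ultralytics`, and an existing models/ directory):

import os
import requests
import torch
from urllib.parse import urlparse
from ultralytics import YOLO

models = {}

def download_and_cache(model_id, model_url):
    """Fetch weights once per model_id; return the cached model or None."""
    if model_id in models:
        return models[model_id]
    filename = os.path.basename(urlparse(model_url).path)
    model_filename = os.path.join("models", filename)
    response = requests.get(model_url, stream=True)
    if response.status_code != 200:
        return None  # caller logs the failure and skips this subscription
    with open(model_filename, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    model = YOLO(model_filename)
    if torch.cuda.is_available():
        model.to('cuda')
    models[model_id] = model
    return model
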
@@ -374,7 +387,5 @@ async def detect(websocket: WebSocket):
             stream['buffer'].queue.clear()
             logging.info(f"Released camera {camera_id} and cleaned up resources")
         streams.clear()
-        if model_path and os.path.exists(model_path):
-            os.remove(model_path)
-            logging.info(f"Deleted model file {model_path}")
+        models.clear()
         logging.info("WebSocket connection closed")
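
On disconnect, the old code deleted the single downloaded weight file via `model_path`; the new code keeps weight files on disk (so a reconnect can reuse them) and only clears the in-memory cache. Note that `models.clear()` just drops Python references; GPU memory is released once the YOLO objects are garbage-collected. A sketch of an optional, more eager teardown (the explicit `empty_cache()` call is an assumption, not part of the commit):

import gc
import torch

def release_models(models):
    """Drop cached models and nudge CUDA to return freed memory."""
    models.clear()  # remove references so instances can be GC'd
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # optional: return cached blocks sooner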