dynamic model loading
This commit is contained in:
		
							parent
							
								
									af26c1477c
								
							
						
					
					
						commit
						2d361b19ff
					
				
					 2 changed files with 34 additions and 1 deletions
				
			
		
							
								
								
									
										1
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
										
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -1,2 +1,3 @@
 | 
			
		|||
 | 
			
		||||
/__pycache__
 | 
			
		||||
models
 | 
			
		||||
							
								
								
									
										34
									
								
								app.py
									
										
									
									
									
								
							
							
						
						
									
										34
									
								
								app.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -10,6 +10,9 @@ import json
 | 
			
		|||
import logging
 | 
			
		||||
import threading
 | 
			
		||||
import queue
 | 
			
		||||
import os
 | 
			
		||||
import requests
 | 
			
		||||
from urllib.parse import urlparse  # Added import
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -41,6 +44,9 @@ logging.basicConfig(
 | 
			
		|||
    ]
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Ensure the models directory exists
 | 
			
		||||
os.makedirs("models", exist_ok=True)
 | 
			
		||||
 | 
			
		||||
@app.websocket("/")
 | 
			
		||||
async def detect(websocket: WebSocket):
 | 
			
		||||
    import asyncio
 | 
			
		||||
| 
						 | 
				
			
			@ -136,6 +142,9 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
    await websocket.accept()
 | 
			
		||||
    task = asyncio.create_task(process_streams())
 | 
			
		||||
 | 
			
		||||
    model = None
 | 
			
		||||
    model_path = None
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
| 
						 | 
				
			
			@ -144,7 +153,27 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                data = json.loads(msg)
 | 
			
		||||
                camera_id = data.get("cameraIdentifier")
 | 
			
		||||
                rtsp_url = data.get("rtspUrl")
 | 
			
		||||
 | 
			
		||||
                model_url = data.get("modelUrl")
 | 
			
		||||
    
 | 
			
		||||
                if model_url:
 | 
			
		||||
                    print(f"Downloading model from {model_url}")
 | 
			
		||||
                    parsed_url = urlparse(model_url)
 | 
			
		||||
                    filename = os.path.basename(parsed_url.path)    
 | 
			
		||||
                    model_filename = os.path.join("models", filename)
 | 
			
		||||
                    # Download the model
 | 
			
		||||
                    response = requests.get(model_url, stream=True)
 | 
			
		||||
                    if response.status_code == 200:
 | 
			
		||||
                        with open(model_filename, 'wb') as f:
 | 
			
		||||
                            for chunk in response.iter_content(chunk_size=8192):
 | 
			
		||||
                                f.write(chunk)
 | 
			
		||||
                        logging.info(f"Downloaded model from {model_url} to {model_filename}")
 | 
			
		||||
                        model = YOLO(model_filename)
 | 
			
		||||
                        if torch.cuda.is_available():
 | 
			
		||||
                            model.to('cuda')
 | 
			
		||||
                        class_names = model.names
 | 
			
		||||
                    else:
 | 
			
		||||
                        logging.error(f"Failed to download model from {model_url}")
 | 
			
		||||
                        continue
 | 
			
		||||
                if camera_id and rtsp_url:
 | 
			
		||||
                    if camera_id not in streams and len(streams) < max_streams:
 | 
			
		||||
                        cap = cv2.VideoCapture(rtsp_url)
 | 
			
		||||
| 
						 | 
				
			
			@ -191,4 +220,7 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
            stream['buffer'].queue.clear()
 | 
			
		||||
            logging.info(f"Released camera {camera_id} and cleaned up resources")
 | 
			
		||||
        streams.clear()
 | 
			
		||||
        if model_path and os.path.exists(model_path):
 | 
			
		||||
            os.remove(model_path)
 | 
			
		||||
            logging.info(f"Deleted model file {model_path}")
 | 
			
		||||
        logging.info("WebSocket connection closed")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue