From 2d361b19ff9c636f79cc1db5a893c02842f6c9ca Mon Sep 17 00:00:00 2001 From: Siwat Sirichai Date: Sat, 11 Jan 2025 02:17:29 +0700 Subject: [PATCH] dynamic model loading --- .gitignore | 1 + app.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0ce25c6..2c881e8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /__pycache__ +models \ No newline at end of file diff --git a/app.py b/app.py index 1d5c515..4bbf336 100644 --- a/app.py +++ b/app.py @@ -10,6 +10,9 @@ import json import logging import threading import queue +import os +import requests +from urllib.parse import urlparse # Added import app = FastAPI() @@ -41,6 +44,9 @@ logging.basicConfig( ] ) +# Ensure the models directory exists +os.makedirs("models", exist_ok=True) + @app.websocket("/") async def detect(websocket: WebSocket): import asyncio @@ -136,6 +142,9 @@ async def detect(websocket: WebSocket): await websocket.accept() task = asyncio.create_task(process_streams()) + model = None + model_path = None + try: while True: try: @@ -144,7 +153,27 @@ async def detect(websocket: WebSocket): data = json.loads(msg) camera_id = data.get("cameraIdentifier") rtsp_url = data.get("rtspUrl") - + model_url = data.get("modelUrl") + + if model_url: + print(f"Downloading model from {model_url}") + parsed_url = urlparse(model_url) + filename = os.path.basename(parsed_url.path) + model_filename = os.path.join("models", filename) + # Download the model + response = requests.get(model_url, stream=True) + if response.status_code == 200: + with open(model_filename, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + logging.info(f"Downloaded model from {model_url} to {model_filename}") + model = YOLO(model_filename) + if torch.cuda.is_available(): + model.to('cuda') + class_names = model.names + else: + logging.error(f"Failed to download model from {model_url}") + continue if camera_id and rtsp_url: if camera_id not in streams and len(streams) < max_streams: cap = cv2.VideoCapture(rtsp_url) @@ -191,4 +220,7 @@ async def detect(websocket: WebSocket): stream['buffer'].queue.clear() logging.info(f"Released camera {camera_id} and cleaned up resources") streams.clear() + if model_path and os.path.exists(model_path): + os.remove(model_path) + logging.info(f"Deleted model file {model_path}") logging.info("WebSocket connection closed")