dynamic model loading

This commit is contained in:
Siwat Sirichai 2025-01-11 02:17:29 +07:00
parent af26c1477c
commit 2d361b19ff
2 changed files with 34 additions and 1 deletions

1
.gitignore vendored
View file

@ -1,2 +1,3 @@
/__pycache__
models

34
app.py
View file

@ -10,6 +10,9 @@ import json
import logging
import threading
import queue
import os
import requests
from urllib.parse import urlparse # Added import
app = FastAPI()
@ -41,6 +44,9 @@ logging.basicConfig(
]
)
# Ensure the models directory exists
os.makedirs("models", exist_ok=True)
@app.websocket("/")
async def detect(websocket: WebSocket):
import asyncio
@ -136,6 +142,9 @@ async def detect(websocket: WebSocket):
await websocket.accept()
task = asyncio.create_task(process_streams())
model = None
model_path = None
try:
while True:
try:
@ -144,7 +153,27 @@ async def detect(websocket: WebSocket):
data = json.loads(msg)
camera_id = data.get("cameraIdentifier")
rtsp_url = data.get("rtspUrl")
model_url = data.get("modelUrl")
if model_url:
print(f"Downloading model from {model_url}")
parsed_url = urlparse(model_url)
filename = os.path.basename(parsed_url.path)
model_filename = os.path.join("models", filename)
# Download the model
response = requests.get(model_url, stream=True)
if response.status_code == 200:
with open(model_filename, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logging.info(f"Downloaded model from {model_url} to {model_filename}")
model = YOLO(model_filename)
if torch.cuda.is_available():
model.to('cuda')
class_names = model.names
else:
logging.error(f"Failed to download model from {model_url}")
continue
if camera_id and rtsp_url:
if camera_id not in streams and len(streams) < max_streams:
cap = cv2.VideoCapture(rtsp_url)
@ -191,4 +220,7 @@ async def detect(websocket: WebSocket):
stream['buffer'].queue.clear()
logging.info(f"Released camera {camera_id} and cleaned up resources")
streams.clear()
if model_path and os.path.exists(model_path):
os.remove(model_path)
logging.info(f"Deleted model file {model_path}")
logging.info("WebSocket connection closed")