dynamic model loading
This commit is contained in:
parent
af26c1477c
commit
2d361b19ff
2 changed files with 34 additions and 1 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
|||
|
||||
/__pycache__
|
||||
models
|
32
app.py
32
app.py
|
@ -10,6 +10,9 @@ import json
|
|||
import logging
|
||||
import threading
|
||||
import queue
|
||||
import os
|
||||
import requests
|
||||
from urllib.parse import urlparse # Added import
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
@ -41,6 +44,9 @@ logging.basicConfig(
|
|||
]
|
||||
)
|
||||
|
||||
# Ensure the models directory exists
|
||||
os.makedirs("models", exist_ok=True)
|
||||
|
||||
@app.websocket("/")
|
||||
async def detect(websocket: WebSocket):
|
||||
import asyncio
|
||||
|
@ -136,6 +142,9 @@ async def detect(websocket: WebSocket):
|
|||
await websocket.accept()
|
||||
task = asyncio.create_task(process_streams())
|
||||
|
||||
model = None
|
||||
model_path = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
|
@ -144,7 +153,27 @@ async def detect(websocket: WebSocket):
|
|||
data = json.loads(msg)
|
||||
camera_id = data.get("cameraIdentifier")
|
||||
rtsp_url = data.get("rtspUrl")
|
||||
model_url = data.get("modelUrl")
|
||||
|
||||
if model_url:
|
||||
print(f"Downloading model from {model_url}")
|
||||
parsed_url = urlparse(model_url)
|
||||
filename = os.path.basename(parsed_url.path)
|
||||
model_filename = os.path.join("models", filename)
|
||||
# Download the model
|
||||
response = requests.get(model_url, stream=True)
|
||||
if response.status_code == 200:
|
||||
with open(model_filename, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
logging.info(f"Downloaded model from {model_url} to {model_filename}")
|
||||
model = YOLO(model_filename)
|
||||
if torch.cuda.is_available():
|
||||
model.to('cuda')
|
||||
class_names = model.names
|
||||
else:
|
||||
logging.error(f"Failed to download model from {model_url}")
|
||||
continue
|
||||
if camera_id and rtsp_url:
|
||||
if camera_id not in streams and len(streams) < max_streams:
|
||||
cap = cv2.VideoCapture(rtsp_url)
|
||||
|
@ -191,4 +220,7 @@ async def detect(websocket: WebSocket):
|
|||
stream['buffer'].queue.clear()
|
||||
logging.info(f"Released camera {camera_id} and cleaned up resources")
|
||||
streams.clear()
|
||||
if model_path and os.path.exists(model_path):
|
||||
os.remove(model_path)
|
||||
logging.info(f"Deleted model file {model_path}")
|
||||
logging.info("WebSocket connection closed")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue