dynamic model loading
This commit is contained in:
parent
af26c1477c
commit
2d361b19ff
2 changed files with 34 additions and 1 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
||||||
|
|
||||||
/__pycache__
|
/__pycache__
|
||||||
|
models
|
32
app.py
32
app.py
|
@ -10,6 +10,9 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import queue
|
import queue
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
from urllib.parse import urlparse # Added import
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -41,6 +44,9 @@ logging.basicConfig(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ensure the models directory exists
|
||||||
|
os.makedirs("models", exist_ok=True)
|
||||||
|
|
||||||
@app.websocket("/")
|
@app.websocket("/")
|
||||||
async def detect(websocket: WebSocket):
|
async def detect(websocket: WebSocket):
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -136,6 +142,9 @@ async def detect(websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
task = asyncio.create_task(process_streams())
|
task = asyncio.create_task(process_streams())
|
||||||
|
|
||||||
|
model = None
|
||||||
|
model_path = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -144,7 +153,27 @@ async def detect(websocket: WebSocket):
|
||||||
data = json.loads(msg)
|
data = json.loads(msg)
|
||||||
camera_id = data.get("cameraIdentifier")
|
camera_id = data.get("cameraIdentifier")
|
||||||
rtsp_url = data.get("rtspUrl")
|
rtsp_url = data.get("rtspUrl")
|
||||||
|
model_url = data.get("modelUrl")
|
||||||
|
|
||||||
|
if model_url:
|
||||||
|
print(f"Downloading model from {model_url}")
|
||||||
|
parsed_url = urlparse(model_url)
|
||||||
|
filename = os.path.basename(parsed_url.path)
|
||||||
|
model_filename = os.path.join("models", filename)
|
||||||
|
# Download the model
|
||||||
|
response = requests.get(model_url, stream=True)
|
||||||
|
if response.status_code == 200:
|
||||||
|
with open(model_filename, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
logging.info(f"Downloaded model from {model_url} to {model_filename}")
|
||||||
|
model = YOLO(model_filename)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
model.to('cuda')
|
||||||
|
class_names = model.names
|
||||||
|
else:
|
||||||
|
logging.error(f"Failed to download model from {model_url}")
|
||||||
|
continue
|
||||||
if camera_id and rtsp_url:
|
if camera_id and rtsp_url:
|
||||||
if camera_id not in streams and len(streams) < max_streams:
|
if camera_id not in streams and len(streams) < max_streams:
|
||||||
cap = cv2.VideoCapture(rtsp_url)
|
cap = cv2.VideoCapture(rtsp_url)
|
||||||
|
@ -191,4 +220,7 @@ async def detect(websocket: WebSocket):
|
||||||
stream['buffer'].queue.clear()
|
stream['buffer'].queue.clear()
|
||||||
logging.info(f"Released camera {camera_id} and cleaned up resources")
|
logging.info(f"Released camera {camera_id} and cleaned up resources")
|
||||||
streams.clear()
|
streams.clear()
|
||||||
|
if model_path and os.path.exists(model_path):
|
||||||
|
os.remove(model_path)
|
||||||
|
logging.info(f"Deleted model file {model_path}")
|
||||||
logging.info("WebSocket connection closed")
|
logging.info("WebSocket connection closed")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue