fix dynamic model bug

Author: Siwat Sirichai
Date:   2025-01-14 22:41:55 +07:00
Parent: c40cea786e
Commit: 1af0a3213a

app.py (69 lines changed)

@@ -18,12 +18,14 @@ import psutil
 app = FastAPI()

-model = YOLO("yolov8n.pt")
-if torch.cuda.is_available():
-    model.to('cuda')
-
-# Retrieve class names from the model
-class_names = model.names
+# Remove global model and class_names
+# model = YOLO("yolov8n.pt")
+# if torch.cuda.is_available():
+#     model.to('cuda')
+# class_names = model.names
+
+# Introduce models dictionary
+models = {}

 with open("config.json", "r") as f:
     config = json.load(f)
@@ -108,7 +110,9 @@ async def detect(websocket: WebSocket):
             break

 async def process_streams():
-    global model, class_names  # Added line
+    # Remove global model and class_names
+    # global model, class_names
+    global models
     logging.info("Started processing streams")
     try:
         while True:
@@ -118,12 +122,17 @@ async def detect(websocket: WebSocket):
                 buffer = stream['buffer']
                 if not buffer.empty():
                     frame = buffer.get()
+                    # Get the model for this stream's modelId
+                    model = models.get(stream['modelId'])
+                    if not model:
+                        logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
+                        continue
                     results = model(frame, stream=False)
                     boxes = []
                     for r in results:
                         for box in r.boxes:
                             boxes.append({
-                                "class": class_names[int(box.cls[0])],
+                                "class": model.names[int(box.cls[0])],
                                 "confidence": float(box.conf[0]),
                             })
                     # Broadcast to all subscribers of this URL
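Note that models.get returns None for an unknown key, so a missing model only skips that camera's frame instead of raising KeyError and killing the shared loop. The switch to model.names works because each ultralytics YOLO object carries its own index-to-label mapping, which is why the old global class_names can go away entirely. A quick hedged check of that property:

    from ultralytics import YOLO

    # Assumes a local yolov8n.pt, as in the code this commit removes.
    model = YOLO("yolov8n.pt")
    print(model.names[0])  # e.g. 'person' for the COCO-trained yolov8n weights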
@@ -186,7 +195,9 @@ async def detect(websocket: WebSocket):
                 break

     async def on_message():
-        global model, class_names  # Changed from nonlocal to global
+        # Change from nonlocal to global
+        # global model, class_names
+        global models
         while True:
             msg = await websocket.receive_text()
             logging.debug(f"Received message: {msg}")
@@ -202,24 +213,26 @@ async def detect(websocket: WebSocket):
                 modelName = payload.get("modelName")
                 if model_url:
-                    print(f"Downloading model from {model_url}")
-                    parsed_url = urlparse(model_url)
-                    filename = os.path.basename(parsed_url.path)
-                    model_filename = os.path.join("models", filename)
-                    # Download the model
-                    response = requests.get(model_url, stream=True)
-                    if response.status_code == 200:
-                        with open(model_filename, 'wb') as f:
-                            for chunk in response.iter_content(chunk_size=8192):
-                                f.write(chunk)
-                        logging.info(f"Downloaded model from {model_url} to {model_filename}")
-                        model = YOLO(model_filename)
-                        if torch.cuda.is_available():
-                            model.to('cuda')
-                        class_names = model.names
-                    else:
-                        logging.error(f"Failed to download model from {model_url}")
-                        continue
+                    if modelId not in models:
+                        print(f"Downloading model from {model_url}")
+                        parsed_url = urlparse(model_url)
+                        filename = os.path.basename(parsed_url.path)
+                        model_filename = os.path.join("models", filename)
+                        # Download the model
+                        response = requests.get(model_url, stream=True)
+                        if response.status_code == 200:
+                            with open(model_filename, 'wb') as f:
+                                for chunk in response.iter_content(chunk_size=8192):
+                                    f.write(chunk)
+                            logging.info(f"Downloaded model from {model_url} to {model_filename}")
+                            model = YOLO(model_filename)
+                            if torch.cuda.is_available():
+                                model.to('cuda')
+                            models[modelId] = model
+                            logging.info(f"Loaded model {modelId} for camera {camera_id}")
+                        else:
+                            logging.error(f"Failed to download model from {model_url}")
+                            continue
                 if camera_id and rtsp_url:
                     if camera_id not in streams and len(streams) < max_streams:
                         cap = cv2.VideoCapture(rtsp_url)
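The download path now runs only when modelId is absent from the registry, so re-subscribing to an already-loaded model is effectively idempotent. A compact sketch of that flow as a standalone helper (the function name and raise_for_status error handling are my own; the commit logs the failure and continues instead):

    import os
    import requests
    import torch
    from urllib.parse import urlparse
    from ultralytics import YOLO

    def ensure_model(models, model_id, model_url):
        # Download and load a model only on first use; later calls hit the cache.
        if model_id in models:
            return models[model_id]
        local_path = os.path.join("models", os.path.basename(urlparse(model_url).path))
        response = requests.get(model_url, stream=True)
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        model = YOLO(local_path)
        if torch.cuda.is_available():
            model.to('cuda')
        models[model_id] = model
        return model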
@@ -374,7 +387,5 @@ async def detect(websocket: WebSocket):
             stream['buffer'].queue.clear()
             logging.info(f"Released camera {camera_id} and cleaned up resources")
         streams.clear()
-        if model_path and os.path.exists(model_path):
-            os.remove(model_path)
-            logging.info(f"Deleted model file {model_path}")
+        models.clear()
         logging.info("WebSocket connection closed")