fix dynamic model bug

Replace the single global YOLO model and class_names with a models dictionary keyed by modelId, so each stream runs inference with the model it subscribed with.

Siwat Sirichai 2025-01-14 22:41:55 +07:00
parent c40cea786e
commit 1af0a3213a

app.py

@@ -18,12 +18,14 @@ import psutil
 app = FastAPI()
-model = YOLO("yolov8n.pt")
-if torch.cuda.is_available():
-    model.to('cuda')
+# Remove global model and class_names
+# model = YOLO("yolov8n.pt")
+# if torch.cuda.is_available():
+#     model.to('cuda')
+# class_names = model.names
-# Retrieve class names from the model
-class_names = model.names
+# Introduce models dictionary
+models = {}
 with open("config.json", "r") as f:
     config = json.load(f)
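
For context on the new registry: a minimal sketch of a modelId-keyed cache like the models dict introduced above. The get_or_load_model helper and the yolov8n.pt default are illustrative assumptions, not code from this commit.

# Illustrative sketch only: cache one loaded YOLO model per modelId.
import torch
from ultralytics import YOLO

models = {}  # modelId -> loaded YOLO model

def get_or_load_model(model_id, weights_path="yolov8n.pt"):
    """Return the cached model for model_id, loading and caching it on first use."""
    if model_id not in models:
        model = YOLO(weights_path)
        if torch.cuda.is_available():
            model.to('cuda')  # move to GPU when available, as app.py does
        models[model_id] = model
    return models[model_id]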
@@ -108,7 +110,9 @@ async def detect(websocket: WebSocket):
             break
     async def process_streams():
-        global model, class_names  # Added line
+        # Remove global model and class_names
+        # global model, class_names
+        global models
         logging.info("Started processing streams")
         try:
             while True:
@@ -118,12 +122,17 @@ async def detect(websocket: WebSocket):
                     buffer = stream['buffer']
                     if not buffer.empty():
                         frame = buffer.get()
+                        # Get the model for this stream's modelId
+                        model = models.get(stream['modelId'])
+                        if not model:
+                            logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
+                            continue
                         results = model(frame, stream=False)
                         boxes = []
                         for r in results:
                             for box in r.boxes:
                                 boxes.append({
-                                    "class": class_names[int(box.cls[0])],
+                                    "class": model.names[int(box.cls[0])],
                                     "confidence": float(box.conf[0]),
                                 })
                         # Broadcast to all subscribers of this URL
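
A usage note on the loop above: each frame is matched against the model registered for its stream's modelId, and labels come from that model's own names mapping rather than a shared class_names global. A minimal sketch of the same idea as a standalone function; the function name and signature are assumptions for illustration.

# Sketch only: run one frame through the model registered for a stream's modelId.
import logging

def detect_boxes(models, stream, frame, camera_id):
    model = models.get(stream['modelId'])
    if model is None:
        # No model loaded for this stream yet; skip instead of using a global fallback.
        logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
        return []
    results = model(frame, stream=False)
    return [
        {
            "class": model.names[int(box.cls[0])],  # per-model label map
            "confidence": float(box.conf[0]),
        }
        for r in results
        for box in r.boxes
    ]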
@@ -186,7 +195,9 @@ async def detect(websocket: WebSocket):
             break
     async def on_message():
-        global model, class_names  # Changed from nonlocal to global
+        # Change from nonlocal to global
+        # global model, class_names
+        global models
         while True:
             msg = await websocket.receive_text()
             logging.debug(f"Received message: {msg}")
@@ -202,6 +213,7 @@ async def detect(websocket: WebSocket):
                 modelName = payload.get("modelName")
                 if model_url:
+                    if modelId not in models:
                         print(f"Downloading model from {model_url}")
                         parsed_url = urlparse(model_url)
                         filename = os.path.basename(parsed_url.path)
@@ -216,7 +228,8 @@ async def detect(websocket: WebSocket):
                         model = YOLO(model_filename)
                         if torch.cuda.is_available():
                             model.to('cuda')
-                        class_names = model.names
+                        models[modelId] = model
+                        logging.info(f"Loaded model {modelId} for camera {camera_id}")
                     else:
                         logging.error(f"Failed to download model from {model_url}")
                         continue
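
For the download-and-load path above: a hedged sketch of fetching a weights file once and caching the loaded model by modelId. The use of requests, the models/ directory, and the helper name are assumptions; only urlparse/os.path.basename and the YOLO/CUDA handling mirror the diff.

# Sketch only: download weights, load them once, and cache the model by modelId.
import os
from urllib.parse import urlparse

import requests
import torch
from ultralytics import YOLO

def load_remote_model(models, model_id, model_url, dest_dir="models"):
    if model_id in models:  # already downloaded and loaded
        return models[model_id]
    os.makedirs(dest_dir, exist_ok=True)
    filename = os.path.basename(urlparse(model_url).path)
    model_path = os.path.join(dest_dir, filename)
    if not os.path.exists(model_path):
        resp = requests.get(model_url, timeout=60)
        resp.raise_for_status()
        with open(model_path, "wb") as f:
            f.write(resp.content)
    model = YOLO(model_path)
    if torch.cuda.is_available():
        model.to('cuda')
    models[model_id] = model
    return model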
@@ -374,7 +387,5 @@ async def detect(websocket: WebSocket):
             stream['buffer'].queue.clear()
             logging.info(f"Released camera {camera_id} and cleaned up resources")
         streams.clear()
-        if model_path and os.path.exists(model_path):
-            os.remove(model_path)
-            logging.info(f"Deleted model file {model_path}")
+        models.clear()
         logging.info("WebSocket connection closed")