fix dynamic model bug
This commit is contained in:
parent
c40cea786e
commit
1af0a3213a
1 changed files with 40 additions and 29 deletions
35
app.py
35
app.py
|
@ -18,12 +18,14 @@ import psutil
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
model = YOLO("yolov8n.pt")
|
||||
if torch.cuda.is_available():
|
||||
model.to('cuda')
|
||||
# Remove global model and class_names
|
||||
# model = YOLO("yolov8n.pt")
|
||||
# if torch.cuda.is_available():
|
||||
# model.to('cuda')
|
||||
# class_names = model.names
|
||||
|
||||
# Retrieve class names from the model
|
||||
class_names = model.names
|
||||
# Introduce models dictionary
|
||||
models = {}
|
||||
|
||||
with open("config.json", "r") as f:
|
||||
config = json.load(f)
|
||||
|
@ -108,7 +110,9 @@ async def detect(websocket: WebSocket):
|
|||
break
|
||||
|
||||
async def process_streams():
|
||||
global model, class_names # Added line
|
||||
# Remove global model and class_names
|
||||
# global model, class_names
|
||||
global models
|
||||
logging.info("Started processing streams")
|
||||
try:
|
||||
while True:
|
||||
|
@ -118,12 +122,17 @@ async def detect(websocket: WebSocket):
|
|||
buffer = stream['buffer']
|
||||
if not buffer.empty():
|
||||
frame = buffer.get()
|
||||
# Get the model for this stream's modelId
|
||||
model = models.get(stream['modelId'])
|
||||
if not model:
|
||||
logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
|
||||
continue
|
||||
results = model(frame, stream=False)
|
||||
boxes = []
|
||||
for r in results:
|
||||
for box in r.boxes:
|
||||
boxes.append({
|
||||
"class": class_names[int(box.cls[0])],
|
||||
"class": model.names[int(box.cls[0])],
|
||||
"confidence": float(box.conf[0]),
|
||||
})
|
||||
# Broadcast to all subscribers of this URL
|
||||
|
@ -186,7 +195,9 @@ async def detect(websocket: WebSocket):
|
|||
break
|
||||
|
||||
async def on_message():
|
||||
global model, class_names # Changed from nonlocal to global
|
||||
# Change from nonlocal to global
|
||||
# global model, class_names
|
||||
global models
|
||||
while True:
|
||||
msg = await websocket.receive_text()
|
||||
logging.debug(f"Received message: {msg}")
|
||||
|
@ -202,6 +213,7 @@ async def detect(websocket: WebSocket):
|
|||
modelName = payload.get("modelName")
|
||||
|
||||
if model_url:
|
||||
if modelId not in models:
|
||||
print(f"Downloading model from {model_url}")
|
||||
parsed_url = urlparse(model_url)
|
||||
filename = os.path.basename(parsed_url.path)
|
||||
|
@ -216,7 +228,8 @@ async def detect(websocket: WebSocket):
|
|||
model = YOLO(model_filename)
|
||||
if torch.cuda.is_available():
|
||||
model.to('cuda')
|
||||
class_names = model.names
|
||||
models[modelId] = model
|
||||
logging.info(f"Loaded model {modelId} for camera {camera_id}")
|
||||
else:
|
||||
logging.error(f"Failed to download model from {model_url}")
|
||||
continue
|
||||
|
@ -374,7 +387,5 @@ async def detect(websocket: WebSocket):
|
|||
stream['buffer'].queue.clear()
|
||||
logging.info(f"Released camera {camera_id} and cleaned up resources")
|
||||
streams.clear()
|
||||
if model_path and os.path.exists(model_path):
|
||||
os.remove(model_path)
|
||||
logging.info(f"Deleted model file {model_path}")
|
||||
models.clear()
|
||||
logging.info("WebSocket connection closed")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue