fix dynamic model bug
This commit is contained in:
parent
c40cea786e
commit
1af0a3213a
1 changed files with 40 additions and 29 deletions
35
app.py
35
app.py
|
@ -18,12 +18,14 @@ import psutil
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
model = YOLO("yolov8n.pt")
|
# Remove global model and class_names
|
||||||
if torch.cuda.is_available():
|
# model = YOLO("yolov8n.pt")
|
||||||
model.to('cuda')
|
# if torch.cuda.is_available():
|
||||||
|
# model.to('cuda')
|
||||||
|
# class_names = model.names
|
||||||
|
|
||||||
# Retrieve class names from the model
|
# Introduce models dictionary
|
||||||
class_names = model.names
|
models = {}
|
||||||
|
|
||||||
with open("config.json", "r") as f:
|
with open("config.json", "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
@ -108,7 +110,9 @@ async def detect(websocket: WebSocket):
|
||||||
break
|
break
|
||||||
|
|
||||||
async def process_streams():
|
async def process_streams():
|
||||||
global model, class_names # Added line
|
# Remove global model and class_names
|
||||||
|
# global model, class_names
|
||||||
|
global models
|
||||||
logging.info("Started processing streams")
|
logging.info("Started processing streams")
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
@ -118,12 +122,17 @@ async def detect(websocket: WebSocket):
|
||||||
buffer = stream['buffer']
|
buffer = stream['buffer']
|
||||||
if not buffer.empty():
|
if not buffer.empty():
|
||||||
frame = buffer.get()
|
frame = buffer.get()
|
||||||
|
# Get the model for this stream's modelId
|
||||||
|
model = models.get(stream['modelId'])
|
||||||
|
if not model:
|
||||||
|
logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
|
||||||
|
continue
|
||||||
results = model(frame, stream=False)
|
results = model(frame, stream=False)
|
||||||
boxes = []
|
boxes = []
|
||||||
for r in results:
|
for r in results:
|
||||||
for box in r.boxes:
|
for box in r.boxes:
|
||||||
boxes.append({
|
boxes.append({
|
||||||
"class": class_names[int(box.cls[0])],
|
"class": model.names[int(box.cls[0])],
|
||||||
"confidence": float(box.conf[0]),
|
"confidence": float(box.conf[0]),
|
||||||
})
|
})
|
||||||
# Broadcast to all subscribers of this URL
|
# Broadcast to all subscribers of this URL
|
||||||
|
@ -186,7 +195,9 @@ async def detect(websocket: WebSocket):
|
||||||
break
|
break
|
||||||
|
|
||||||
async def on_message():
|
async def on_message():
|
||||||
global model, class_names # Changed from nonlocal to global
|
# Change from nonlocal to global
|
||||||
|
# global model, class_names
|
||||||
|
global models
|
||||||
while True:
|
while True:
|
||||||
msg = await websocket.receive_text()
|
msg = await websocket.receive_text()
|
||||||
logging.debug(f"Received message: {msg}")
|
logging.debug(f"Received message: {msg}")
|
||||||
|
@ -202,6 +213,7 @@ async def detect(websocket: WebSocket):
|
||||||
modelName = payload.get("modelName")
|
modelName = payload.get("modelName")
|
||||||
|
|
||||||
if model_url:
|
if model_url:
|
||||||
|
if modelId not in models:
|
||||||
print(f"Downloading model from {model_url}")
|
print(f"Downloading model from {model_url}")
|
||||||
parsed_url = urlparse(model_url)
|
parsed_url = urlparse(model_url)
|
||||||
filename = os.path.basename(parsed_url.path)
|
filename = os.path.basename(parsed_url.path)
|
||||||
|
@ -216,7 +228,8 @@ async def detect(websocket: WebSocket):
|
||||||
model = YOLO(model_filename)
|
model = YOLO(model_filename)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
model.to('cuda')
|
model.to('cuda')
|
||||||
class_names = model.names
|
models[modelId] = model
|
||||||
|
logging.info(f"Loaded model {modelId} for camera {camera_id}")
|
||||||
else:
|
else:
|
||||||
logging.error(f"Failed to download model from {model_url}")
|
logging.error(f"Failed to download model from {model_url}")
|
||||||
continue
|
continue
|
||||||
|
@ -374,7 +387,5 @@ async def detect(websocket: WebSocket):
|
||||||
stream['buffer'].queue.clear()
|
stream['buffer'].queue.clear()
|
||||||
logging.info(f"Released camera {camera_id} and cleaned up resources")
|
logging.info(f"Released camera {camera_id} and cleaned up resources")
|
||||||
streams.clear()
|
streams.clear()
|
||||||
if model_path and os.path.exists(model_path):
|
models.clear()
|
||||||
os.remove(model_path)
|
|
||||||
logging.info(f"Deleted model file {model_path}")
|
|
||||||
logging.info("WebSocket connection closed")
|
logging.info("WebSocket connection closed")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue