rearrange tracker for user modifiability
This commit is contained in:
parent
1af0a3213a
commit
e71f63369d
1 changed files with 33 additions and 39 deletions
72
app.py
72
app.py
|
@ -18,22 +18,15 @@ import psutil
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# Remove global model and class_names
|
|
||||||
# model = YOLO("yolov8n.pt")
|
|
||||||
# if torch.cuda.is_available():
|
|
||||||
# model.to('cuda')
|
|
||||||
# class_names = model.names
|
|
||||||
|
|
||||||
# Introduce models dictionary
|
|
||||||
models = {}
|
models = {}
|
||||||
|
|
||||||
with open("config.json", "r") as f:
|
with open("config.json", "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
poll_interval = config.get("poll_interval_ms", 100)
|
poll_interval = config.get("poll_interval_ms", 100)
|
||||||
reconnect_interval = config.get("reconnect_interval_sec", 5) # New setting
|
reconnect_interval = config.get("reconnect_interval_sec", 5)
|
||||||
TARGET_FPS = config.get("target_fps", 10) # Add TARGET_FPS
|
TARGET_FPS = config.get("target_fps", 10)
|
||||||
poll_interval = 1000 / TARGET_FPS # Adjust poll_interval based on TARGET_FPS
|
poll_interval = 1000 / TARGET_FPS
|
||||||
logging.info(f"Poll interval: {poll_interval}ms")
|
logging.info(f"Poll interval: {poll_interval}ms")
|
||||||
max_streams = config.get("max_streams", 5)
|
max_streams = config.get("max_streams", 5)
|
||||||
max_retries = config.get("max_retries", 3)
|
max_retries = config.get("max_retries", 3)
|
||||||
|
@ -64,6 +57,31 @@ async def detect(websocket: WebSocket):
|
||||||
|
|
||||||
streams = {}
|
streams = {}
|
||||||
|
|
||||||
|
# This function is user-modifiable
|
||||||
|
# Save data you want to persist across frames in the persistent_data dictionary
|
||||||
|
async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
|
||||||
|
boxes = []
|
||||||
|
for r in model(frame, stream=False):
|
||||||
|
for box in r.boxes:
|
||||||
|
boxes.append({
|
||||||
|
"class": model.names[int(box.cls[0])],
|
||||||
|
"confidence": float(box.conf[0]),
|
||||||
|
})
|
||||||
|
# Broadcast to all subscribers of this URL
|
||||||
|
detection_data = {
|
||||||
|
"type": "imageDetection",
|
||||||
|
"cameraIdentifier": camera_id,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"data": {
|
||||||
|
"detections": boxes,
|
||||||
|
"modelId": stream['modelId'],
|
||||||
|
"modelName": stream['modelName']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
|
||||||
|
await websocket.send_json(detection_data)
|
||||||
|
return persistent_data
|
||||||
|
|
||||||
def frame_reader(camera_id, cap, buffer, stop_event):
|
def frame_reader(camera_id, cap, buffer, stop_event):
|
||||||
import time
|
import time
|
||||||
retries = 0
|
retries = 0
|
||||||
|
@ -110,10 +128,9 @@ async def detect(websocket: WebSocket):
|
||||||
break
|
break
|
||||||
|
|
||||||
async def process_streams():
|
async def process_streams():
|
||||||
# Remove global model and class_names
|
|
||||||
# global model, class_names
|
|
||||||
global models
|
global models
|
||||||
logging.info("Started processing streams")
|
logging.info("Started processing streams")
|
||||||
|
persistent_data_dict = {}
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -122,32 +139,11 @@ async def detect(websocket: WebSocket):
|
||||||
buffer = stream['buffer']
|
buffer = stream['buffer']
|
||||||
if not buffer.empty():
|
if not buffer.empty():
|
||||||
frame = buffer.get()
|
frame = buffer.get()
|
||||||
# Get the model for this stream's modelId
|
|
||||||
model = models.get(stream['modelId'])
|
model = models.get(stream['modelId'])
|
||||||
if not model:
|
key = (camera_id, stream['modelId'])
|
||||||
logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
|
persistent_data = persistent_data_dict.get(key, {})
|
||||||
continue
|
updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
|
||||||
results = model(frame, stream=False)
|
persistent_data_dict[key] = updated_persistent_data
|
||||||
boxes = []
|
|
||||||
for r in results:
|
|
||||||
for box in r.boxes:
|
|
||||||
boxes.append({
|
|
||||||
"class": model.names[int(box.cls[0])],
|
|
||||||
"confidence": float(box.conf[0]),
|
|
||||||
})
|
|
||||||
# Broadcast to all subscribers of this URL
|
|
||||||
detection_data = {
|
|
||||||
"type": "imageDetection",
|
|
||||||
"cameraIdentifier": camera_id,
|
|
||||||
"timestamp": time.time(),
|
|
||||||
"data": {
|
|
||||||
"detections": boxes,
|
|
||||||
"modelId": stream['modelId'],
|
|
||||||
"modelName": stream['modelName']
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
|
|
||||||
await websocket.send_json(detection_data)
|
|
||||||
elapsed_time = (time.time() - start_time) * 1000 # in ms
|
elapsed_time = (time.time() - start_time) * 1000 # in ms
|
||||||
sleep_time = max(poll_interval - elapsed_time, 0)
|
sleep_time = max(poll_interval - elapsed_time, 0)
|
||||||
logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
|
logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
|
||||||
|
@ -195,8 +191,6 @@ async def detect(websocket: WebSocket):
|
||||||
break
|
break
|
||||||
|
|
||||||
async def on_message():
|
async def on_message():
|
||||||
# Change from nonlocal to global
|
|
||||||
# global model, class_names
|
|
||||||
global models
|
global models
|
||||||
while True:
|
while True:
|
||||||
msg = await websocket.receive_text()
|
msg = await websocket.receive_text()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue