rearrange tracker for user modifiability

This commit is contained in:
Siwat Sirichai 2025-01-14 23:00:35 +07:00
parent 1af0a3213a
commit e71f63369d

72
app.py
View file

@ -18,22 +18,15 @@ import psutil
app = FastAPI() app = FastAPI()
# Remove global model and class_names
# model = YOLO("yolov8n.pt")
# if torch.cuda.is_available():
# model.to('cuda')
# class_names = model.names
# Introduce models dictionary
models = {} models = {}
with open("config.json", "r") as f: with open("config.json", "r") as f:
config = json.load(f) config = json.load(f)
poll_interval = config.get("poll_interval_ms", 100) poll_interval = config.get("poll_interval_ms", 100)
reconnect_interval = config.get("reconnect_interval_sec", 5) # New setting reconnect_interval = config.get("reconnect_interval_sec", 5)
TARGET_FPS = config.get("target_fps", 10) # Add TARGET_FPS TARGET_FPS = config.get("target_fps", 10)
poll_interval = 1000 / TARGET_FPS # Adjust poll_interval based on TARGET_FPS poll_interval = 1000 / TARGET_FPS
logging.info(f"Poll interval: {poll_interval}ms") logging.info(f"Poll interval: {poll_interval}ms")
max_streams = config.get("max_streams", 5) max_streams = config.get("max_streams", 5)
max_retries = config.get("max_retries", 3) max_retries = config.get("max_retries", 3)
@ -64,6 +57,31 @@ async def detect(websocket: WebSocket):
streams = {} streams = {}
# This function is user-modifiable
# Save data you want to persist across frames in the persistent_data dictionary
async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
boxes = []
for r in model(frame, stream=False):
for box in r.boxes:
boxes.append({
"class": model.names[int(box.cls[0])],
"confidence": float(box.conf[0]),
})
# Broadcast to all subscribers of this URL
detection_data = {
"type": "imageDetection",
"cameraIdentifier": camera_id,
"timestamp": time.time(),
"data": {
"detections": boxes,
"modelId": stream['modelId'],
"modelName": stream['modelName']
}
}
logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
await websocket.send_json(detection_data)
return persistent_data
def frame_reader(camera_id, cap, buffer, stop_event): def frame_reader(camera_id, cap, buffer, stop_event):
import time import time
retries = 0 retries = 0
@ -110,10 +128,9 @@ async def detect(websocket: WebSocket):
break break
async def process_streams(): async def process_streams():
# Remove global model and class_names
# global model, class_names
global models global models
logging.info("Started processing streams") logging.info("Started processing streams")
persistent_data_dict = {}
try: try:
while True: while True:
start_time = time.time() start_time = time.time()
@ -122,32 +139,11 @@ async def detect(websocket: WebSocket):
buffer = stream['buffer'] buffer = stream['buffer']
if not buffer.empty(): if not buffer.empty():
frame = buffer.get() frame = buffer.get()
# Get the model for this stream's modelId
model = models.get(stream['modelId']) model = models.get(stream['modelId'])
if not model: key = (camera_id, stream['modelId'])
logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}") persistent_data = persistent_data_dict.get(key, {})
continue updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
results = model(frame, stream=False) persistent_data_dict[key] = updated_persistent_data
boxes = []
for r in results:
for box in r.boxes:
boxes.append({
"class": model.names[int(box.cls[0])],
"confidence": float(box.conf[0]),
})
# Broadcast to all subscribers of this URL
detection_data = {
"type": "imageDetection",
"cameraIdentifier": camera_id,
"timestamp": time.time(),
"data": {
"detections": boxes,
"modelId": stream['modelId'],
"modelName": stream['modelName']
}
}
logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
await websocket.send_json(detection_data)
elapsed_time = (time.time() - start_time) * 1000 # in ms elapsed_time = (time.time() - start_time) * 1000 # in ms
sleep_time = max(poll_interval - elapsed_time, 0) sleep_time = max(poll_interval - elapsed_time, 0)
logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms") logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
@ -195,8 +191,6 @@ async def detect(websocket: WebSocket):
break break
async def on_message(): async def on_message():
# Change from nonlocal to global
# global model, class_names
global models global models
while True: while True:
msg = await websocket.receive_text() msg = await websocket.receive_text()