rearrange tracker for user modifiability
parent 1af0a3213a
commit e71f63369d
1 changed file with 33 additions and 39 deletions
app.py
@@ -18,22 +18,15 @@ import psutil

 app = FastAPI()

-# Remove global model and class_names
-# model = YOLO("yolov8n.pt")
-# if torch.cuda.is_available():
-# model.to('cuda')
-# class_names = model.names

-# Introduce models dictionary
 models = {}

 with open("config.json", "r") as f:
     config = json.load(f)

-poll_interval = config.get("poll_interval_ms", 100)
-reconnect_interval = config.get("reconnect_interval_sec", 5) # New setting
-TARGET_FPS = config.get("target_fps", 10) # Add TARGET_FPS
-poll_interval = 1000 / TARGET_FPS # Adjust poll_interval based on TARGET_FPS
+reconnect_interval = config.get("reconnect_interval_sec", 5)
+TARGET_FPS = config.get("target_fps", 10)
+poll_interval = 1000 / TARGET_FPS
 logging.info(f"Poll interval: {poll_interval}ms")
 max_streams = config.get("max_streams", 5)
 max_retries = config.get("max_retries", 3)
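After this hunk, poll_interval is derived from target_fps and the poll_interval_ms key is no longer consulted. A minimal sketch of the relationship, using illustrative values that mirror the keys read above via config.get() (writing config.json from Python here is only for demonstration, not something the commit does):

    import json

    # Illustrative values only; they mirror the keys the settings block above reads.
    example_config = {
        "poll_interval_ms": 100,       # no longer read after this change
        "reconnect_interval_sec": 5,
        "target_fps": 10,
        "max_streams": 5,
        "max_retries": 3,
    }

    with open("config.json", "w") as f:
        json.dump(example_config, f, indent=2)

    # The effective poll interval, exactly as computed above: 1000 / target_fps.
    poll_interval = 1000 / example_config["target_fps"]
    print(f"Poll interval: {poll_interval}ms")  # Poll interval: 100.0ms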
@@ -64,6 +57,31 @@ async def detect(websocket: WebSocket):

 streams = {}

+# This function is user-modifiable
+# Save data you want to persist across frames in the persistent_data dictionary
+async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
+    boxes = []
+    for r in model(frame, stream=False):
+        for box in r.boxes:
+            boxes.append({
+                "class": model.names[int(box.cls[0])],
+                "confidence": float(box.conf[0]),
+            })
+    # Broadcast to all subscribers of this URL
+    detection_data = {
+        "type": "imageDetection",
+        "cameraIdentifier": camera_id,
+        "timestamp": time.time(),
+        "data": {
+            "detections": boxes,
+            "modelId": stream['modelId'],
+            "modelName": stream['modelName']
+        }
+    }
+    logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
+    await websocket.send_json(detection_data)
+    return persistent_data
+
 def frame_reader(camera_id, cap, buffer, stop_event):
     import time
     retries = 0
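handle_detection is the extension point this change carves out: whatever it stores in persistent_data is handed back on the next frame for the same camera and model. As a sketch of a user customization (the per-class counter and the extra classCounts field are illustrative assumptions, not part of the commit), one might write:

    import time

    # Sketch of a user-modified handle_detection: same signature as above, plus a
    # rolling per-class detection count kept in persistent_data between frames.
    async def handle_detection(camera_id, stream, frame, websocket, model, persistent_data):
        boxes = []
        for r in model(frame, stream=False):
            for box in r.boxes:
                boxes.append({
                    "class": model.names[int(box.cls[0])],
                    "confidence": float(box.conf[0]),
                })

        # persistent_data is returned and passed back in on the next call,
        # so it can hold counters, trackers, or any other cross-frame state.
        counts = persistent_data.setdefault("class_counts", {})
        for b in boxes:
            counts[b["class"]] = counts.get(b["class"], 0) + 1

        await websocket.send_json({
            "type": "imageDetection",
            "cameraIdentifier": camera_id,
            "timestamp": time.time(),
            "data": {
                "detections": boxes,
                "classCounts": counts,  # illustrative extra field
                "modelId": stream['modelId'],
                "modelName": stream['modelName'],
            },
        })
        return persistent_data

The signature and the send_json payload mirror the version added above; only the counting lines and the extra field differ.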
@@ -110,10 +128,9 @@ async def detect(websocket: WebSocket):
             break

 async def process_streams():
-    # Remove global model and class_names
-    # global model, class_names
     global models
     logging.info("Started processing streams")
+    persistent_data_dict = {}
     try:
         while True:
             start_time = time.time()
@@ -122,32 +139,11 @@ async def detect(websocket: WebSocket):
                 buffer = stream['buffer']
                 if not buffer.empty():
                     frame = buffer.get()
-                    # Get the model for this stream's modelId
                     model = models.get(stream['modelId'])
-                    if not model:
-                        logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
-                        continue
-                    results = model(frame, stream=False)
-                    boxes = []
-                    for r in results:
-                        for box in r.boxes:
-                            boxes.append({
-                                "class": model.names[int(box.cls[0])],
-                                "confidence": float(box.conf[0]),
-                            })
-                    # Broadcast to all subscribers of this URL
-                    detection_data = {
-                        "type": "imageDetection",
-                        "cameraIdentifier": camera_id,
-                        "timestamp": time.time(),
-                        "data": {
-                            "detections": boxes,
-                            "modelId": stream['modelId'],
-                            "modelName": stream['modelName']
-                        }
-                    }
-                    logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
-                    await websocket.send_json(detection_data)
+                    key = (camera_id, stream['modelId'])
+                    persistent_data = persistent_data_dict.get(key, {})
+                    updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
+                    persistent_data_dict[key] = updated_persistent_data
             elapsed_time = (time.time() - start_time) * 1000 # in ms
             sleep_time = max(poll_interval - elapsed_time, 0)
             logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
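The trailing context lines compute how long the iteration took and how much of the poll interval is left; the sleep call itself falls outside this hunk. A standalone sketch of that pacing pattern, assuming asyncio.sleep is used (which takes seconds, hence the division by 1000):

    import asyncio
    import time

    async def paced_loop(poll_interval_ms, do_work):
        # Measure how long one round of work took, then sleep only the remainder
        # of the poll interval so the loop tracks the target frame rate.
        while True:
            start_time = time.time()
            await do_work()
            elapsed_ms = (time.time() - start_time) * 1000
            sleep_ms = max(poll_interval_ms - elapsed_ms, 0)
            await asyncio.sleep(sleep_ms / 1000)

With target_fps = 10 the interval is 100 ms, so a round of work that takes 30 ms is followed by roughly a 70 ms sleep, and an overrunning round clamps the sleep to zero rather than going negative.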
@@ -195,8 +191,6 @@ async def detect(websocket: WebSocket):
             break

 async def on_message():
-    # Change from nonlocal to global
-    # global model, class_names
     global models
     while True:
         msg = await websocket.receive_text()
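on_message works against the same global models dict that process_streams reads from; how entries get into it is outside this diff. A hedged sketch of one way to cache a YOLO instance per modelId, reusing the device handling visible in the commented-out lines removed in the first hunk (the function name, the model_path parameter, and the ultralytics import path are assumptions):

    import torch
    from ultralytics import YOLO

    models = {}

    def get_or_load_model(model_id, model_path):
        # Cache one YOLO instance per modelId so every stream that references
        # that modelId shares it (mirrors models.get(stream['modelId']) above).
        if model_id not in models:
            model = YOLO(model_path)
            if torch.cuda.is_available():
                model.to('cuda')
            models[model_id] = model
        return models[model_id]

    # e.g. get_or_load_model("yolov8n", "yolov8n.pt") before any stream that
    # references that modelId is processed.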