rearrange tracker for user modifiability
This commit is contained in:
		
							parent
							
								
									1af0a3213a
								
							
						
					
					
						commit
						e71f63369d
					
				
					 1 changed files with 33 additions and 39 deletions
				
			
		
							
								
								
									
										72
									
								
								app.py
									
										
									
									
									
								
							
							
						
						
									
										72
									
								
								app.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -18,22 +18,15 @@ import psutil
 | 
			
		|||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
# Remove global model and class_names
 | 
			
		||||
# model = YOLO("yolov8n.pt")
 | 
			
		||||
# if torch.cuda.is_available():
 | 
			
		||||
#     model.to('cuda')
 | 
			
		||||
# class_names = model.names
 | 
			
		||||
 | 
			
		||||
# Introduce models dictionary
 | 
			
		||||
models = {}
 | 
			
		||||
 | 
			
		||||
with open("config.json", "r") as f:
 | 
			
		||||
    config = json.load(f)
 | 
			
		||||
 | 
			
		||||
poll_interval = config.get("poll_interval_ms", 100)
 | 
			
		||||
reconnect_interval = config.get("reconnect_interval_sec", 5)  # New setting
 | 
			
		||||
TARGET_FPS = config.get("target_fps", 10)  # Add TARGET_FPS
 | 
			
		||||
poll_interval = 1000 / TARGET_FPS  # Adjust poll_interval based on TARGET_FPS
 | 
			
		||||
reconnect_interval = config.get("reconnect_interval_sec", 5) 
 | 
			
		||||
TARGET_FPS = config.get("target_fps", 10)
 | 
			
		||||
poll_interval = 1000 / TARGET_FPS
 | 
			
		||||
logging.info(f"Poll interval: {poll_interval}ms")
 | 
			
		||||
max_streams = config.get("max_streams", 5)
 | 
			
		||||
max_retries = config.get("max_retries", 3)
 | 
			
		||||
| 
						 | 
				
			
			@ -64,6 +57,31 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
 | 
			
		||||
    streams = {}
 | 
			
		||||
 | 
			
		||||
    # This function is user-modifiable
 | 
			
		||||
    # Save data you want to persist across frames in the persistent_data dictionary
 | 
			
		||||
    async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
 | 
			
		||||
        boxes = []
 | 
			
		||||
        for r in model(frame, stream=False):
 | 
			
		||||
            for box in r.boxes:
 | 
			
		||||
                boxes.append({
 | 
			
		||||
                    "class": model.names[int(box.cls[0])],
 | 
			
		||||
                    "confidence": float(box.conf[0]),
 | 
			
		||||
                })
 | 
			
		||||
        # Broadcast to all subscribers of this URL
 | 
			
		||||
        detection_data = {
 | 
			
		||||
            "type": "imageDetection",
 | 
			
		||||
            "cameraIdentifier": camera_id,
 | 
			
		||||
            "timestamp": time.time(),
 | 
			
		||||
            "data": {
 | 
			
		||||
                "detections": boxes,
 | 
			
		||||
                "modelId": stream['modelId'],
 | 
			
		||||
                "modelName": stream['modelName']
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
 | 
			
		||||
        await websocket.send_json(detection_data)
 | 
			
		||||
        return persistent_data
 | 
			
		||||
 | 
			
		||||
    def frame_reader(camera_id, cap, buffer, stop_event):
 | 
			
		||||
        import time
 | 
			
		||||
        retries = 0
 | 
			
		||||
| 
						 | 
				
			
			@ -110,10 +128,9 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                break
 | 
			
		||||
 | 
			
		||||
    async def process_streams():
 | 
			
		||||
        # Remove global model and class_names
 | 
			
		||||
        # global model, class_names
 | 
			
		||||
        global models
 | 
			
		||||
        logging.info("Started processing streams")
 | 
			
		||||
        persistent_data_dict = {} 
 | 
			
		||||
        try:
 | 
			
		||||
            while True:
 | 
			
		||||
                start_time = time.time()
 | 
			
		||||
| 
						 | 
				
			
			@ -122,32 +139,11 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                    buffer = stream['buffer']
 | 
			
		||||
                    if not buffer.empty():
 | 
			
		||||
                        frame = buffer.get()
 | 
			
		||||
                        # Get the model for this stream's modelId
 | 
			
		||||
                        model = models.get(stream['modelId'])
 | 
			
		||||
                        if not model:
 | 
			
		||||
                            logging.error(f"Model {stream['modelId']} not loaded for camera {camera_id}")
 | 
			
		||||
                            continue
 | 
			
		||||
                        results = model(frame, stream=False)
 | 
			
		||||
                        boxes = []
 | 
			
		||||
                        for r in results:
 | 
			
		||||
                            for box in r.boxes:
 | 
			
		||||
                                boxes.append({
 | 
			
		||||
                                    "class": model.names[int(box.cls[0])],
 | 
			
		||||
                                    "confidence": float(box.conf[0]),
 | 
			
		||||
                                })
 | 
			
		||||
                        # Broadcast to all subscribers of this URL
 | 
			
		||||
                        detection_data = {
 | 
			
		||||
                            "type": "imageDetection",
 | 
			
		||||
                            "cameraIdentifier": camera_id,
 | 
			
		||||
                            "timestamp": time.time(),
 | 
			
		||||
                            "data": {
 | 
			
		||||
                                "detections": boxes,
 | 
			
		||||
                                "modelId": stream['modelId'],
 | 
			
		||||
                                "modelName": stream['modelName']
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                        logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
 | 
			
		||||
                        await websocket.send_json(detection_data)
 | 
			
		||||
                        key = (camera_id, stream['modelId'])
 | 
			
		||||
                        persistent_data = persistent_data_dict.get(key, {})
 | 
			
		||||
                        updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
 | 
			
		||||
                        persistent_data_dict[key] = updated_persistent_data
 | 
			
		||||
                elapsed_time = (time.time() - start_time) * 1000  # in ms
 | 
			
		||||
                sleep_time = max(poll_interval - elapsed_time, 0)
 | 
			
		||||
                logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
 | 
			
		||||
| 
						 | 
				
			
			@ -195,8 +191,6 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                break
 | 
			
		||||
 | 
			
		||||
    async def on_message():
 | 
			
		||||
        # Change from nonlocal to global
 | 
			
		||||
        # global model, class_names
 | 
			
		||||
        global models
 | 
			
		||||
        while True:
 | 
			
		||||
            msg = await websocket.receive_text()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue