feat: enhance session management in worker communication protocol; implement session ID handling and crop frame processing
This commit is contained in:
		
							parent
							
								
									c7bb46e1e3
								
							
						
					
					
						commit
						428f7a9671
					
				
					 2 changed files with 334 additions and 82 deletions
				
			
		
							
								
								
									
										285
									
								
								app.py
									
										
									
									
									
								
							
							
						
						
									
										285
									
								
								app.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -29,6 +29,12 @@ app = FastAPI()
 | 
			
		|||
# "models" now holds a nested dict: { camera_id: { modelId: model_tree } }
 | 
			
		||||
models: Dict[str, Dict[str, Any]] = {}
 | 
			
		||||
streams: Dict[str, Dict[str, Any]] = {}
 | 
			
		||||
# Store session IDs per display
 | 
			
		||||
session_ids: Dict[str, int] = {}
 | 
			
		||||
# Track shared camera streams by camera URL
 | 
			
		||||
camera_streams: Dict[str, Dict[str, Any]] = {}
 | 
			
		||||
# Map subscriptions to their camera URL
 | 
			
		||||
subscription_to_camera: Dict[str, str] = {}
 | 
			
		||||
 | 
			
		||||
with open("config.json", "r") as f:
 | 
			
		||||
    config = json.load(f)
 | 
			
		||||
| 
						 | 
				
			
			@ -184,9 +190,16 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
 | 
			
		||||
    async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data):
 | 
			
		||||
        try:
 | 
			
		||||
            # Apply crop if specified
 | 
			
		||||
            cropped_frame = frame
 | 
			
		||||
            if all(coord is not None for coord in [stream.get("cropX1"), stream.get("cropY1"), stream.get("cropX2"), stream.get("cropY2")]):
 | 
			
		||||
                cropX1, cropY1, cropX2, cropY2 = stream["cropX1"], stream["cropY1"], stream["cropX2"], stream["cropY2"]
 | 
			
		||||
                cropped_frame = frame[cropY1:cropY2, cropX1:cropX2]
 | 
			
		||||
                logger.debug(f"Applied crop coordinates ({cropX1}, {cropY1}, {cropX2}, {cropY2}) to frame for camera {camera_id}")
 | 
			
		||||
            
 | 
			
		||||
            logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}")
 | 
			
		||||
            start_time = time.time()
 | 
			
		||||
            detection_result = run_pipeline(frame, model_tree)
 | 
			
		||||
            detection_result = run_pipeline(cropped_frame, model_tree)
 | 
			
		||||
            process_time = (time.time() - start_time) * 1000
 | 
			
		||||
            logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms")
 | 
			
		||||
            
 | 
			
		||||
| 
						 | 
				
			
			@ -235,22 +248,48 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                    "box": [0, 0, 0, 0]
 | 
			
		||||
                }
 | 
			
		||||
            
 | 
			
		||||
            # Convert detection format to match protocol - flatten detection attributes
 | 
			
		||||
            detection_dict = {}
 | 
			
		||||
            
 | 
			
		||||
            # Handle different detection result formats
 | 
			
		||||
            if isinstance(highest_confidence_detection, dict):
 | 
			
		||||
                # Copy all fields from the detection result
 | 
			
		||||
                for key, value in highest_confidence_detection.items():
 | 
			
		||||
                    if key not in ["box", "id"]:  # Skip internal fields
 | 
			
		||||
                        detection_dict[key] = value
 | 
			
		||||
            
 | 
			
		||||
            # Extract display identifier for session ID lookup
 | 
			
		||||
            subscription_parts = stream["subscriptionIdentifier"].split(';')
 | 
			
		||||
            display_identifier = subscription_parts[0] if subscription_parts else None
 | 
			
		||||
            session_id = session_ids.get(display_identifier) if display_identifier else None
 | 
			
		||||
            
 | 
			
		||||
            detection_data = {
 | 
			
		||||
                "type": "imageDetection",
 | 
			
		||||
                "subscriptionIdentifier": stream["subscriptionIdentifier"],
 | 
			
		||||
                "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()),
 | 
			
		||||
                "data": {
 | 
			
		||||
                    "detection": highest_confidence_detection,  # Send only the highest confidence detection
 | 
			
		||||
                    "detection": detection_dict,
 | 
			
		||||
                    "modelId": stream["modelId"],
 | 
			
		||||
                    "modelName": stream["modelName"]
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            
 | 
			
		||||
            # Add session ID if available
 | 
			
		||||
            if session_id is not None:
 | 
			
		||||
                detection_data["sessionId"] = session_id
 | 
			
		||||
            
 | 
			
		||||
            if highest_confidence_detection["class"] != "none":
 | 
			
		||||
                logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}")
 | 
			
		||||
                
 | 
			
		||||
                # Log session ID if available
 | 
			
		||||
                subscription_parts = stream["subscriptionIdentifier"].split(';')
 | 
			
		||||
                display_identifier = subscription_parts[0] if subscription_parts else None
 | 
			
		||||
                session_id = session_ids.get(display_identifier) if display_identifier else None
 | 
			
		||||
                if session_id:
 | 
			
		||||
                    logger.debug(f"Detection associated with session ID: {session_id}")
 | 
			
		||||
            
 | 
			
		||||
            await websocket.send_json(detection_data)
 | 
			
		||||
            logger.debug(f"Sent detection data to client for camera {camera_id}:\n{json.dumps(detection_data, indent=2)}")
 | 
			
		||||
            logger.debug(f"Sent detection data to client for camera {camera_id}")
 | 
			
		||||
            return persistent_data
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True)
 | 
			
		||||
| 
						 | 
				
			
			@ -521,50 +560,58 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                    cropX2 = payload.get("cropX2")
 | 
			
		||||
                    cropY2 = payload.get("cropY2")
 | 
			
		||||
 | 
			
		||||
                    camera_id = subscriptionIdentifier  # Use subscriptionIdentifier as camera_id for mapping
 | 
			
		||||
                    # Extract camera_id from subscriptionIdentifier (format: displayIdentifier;cameraIdentifier)
 | 
			
		||||
                    parts = subscriptionIdentifier.split(';')
 | 
			
		||||
                    if len(parts) != 2:
 | 
			
		||||
                        logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}")
 | 
			
		||||
                        continue
 | 
			
		||||
                    
 | 
			
		||||
                    display_identifier, camera_identifier = parts
 | 
			
		||||
                    camera_id = subscriptionIdentifier  # Use full subscriptionIdentifier as camera_id for mapping
 | 
			
		||||
 | 
			
		||||
                    if model_url:
 | 
			
		||||
                        with models_lock:
 | 
			
		||||
                            if (camera_id not in models) or (modelId not in models[camera_id]):
 | 
			
		||||
                                logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
 | 
			
		||||
                                extraction_dir = os.path.join("models", camera_id, str(modelId))
 | 
			
		||||
                                extraction_dir = os.path.join("models", camera_identifier, str(modelId))
 | 
			
		||||
                                os.makedirs(extraction_dir, exist_ok=True)
 | 
			
		||||
                                # If model_url is remote, download it first.
 | 
			
		||||
                                parsed = urlparse(model_url)
 | 
			
		||||
                                if parsed.scheme in ("http", "https"):
 | 
			
		||||
                                    logger.info(f"Downloading remote model from {model_url}")
 | 
			
		||||
                                    local_mpta = os.path.join(extraction_dir, os.path.basename(parsed.path))
 | 
			
		||||
                                    logger.info(f"Downloading remote .mpta file from {model_url}")
 | 
			
		||||
                                    filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
 | 
			
		||||
                                    local_mpta = os.path.join(extraction_dir, filename)
 | 
			
		||||
                                    logger.debug(f"Download destination: {local_mpta}")
 | 
			
		||||
                                    local_path = download_mpta(model_url, local_mpta)
 | 
			
		||||
                                    if not local_path:
 | 
			
		||||
                                        logger.error(f"Failed to download the remote mpta file from {model_url}")
 | 
			
		||||
                                        logger.error(f"Failed to download the remote .mpta file from {model_url}")
 | 
			
		||||
                                        error_response = {
 | 
			
		||||
                                            "type": "error",
 | 
			
		||||
                                            "cameraIdentifier": camera_id,
 | 
			
		||||
                                            "subscriptionIdentifier": subscriptionIdentifier,
 | 
			
		||||
                                            "error": f"Failed to download model from {model_url}"
 | 
			
		||||
                                        }
 | 
			
		||||
                                        await websocket.send_json(error_response)
 | 
			
		||||
                                        continue
 | 
			
		||||
                                    model_tree = load_pipeline_from_zip(local_path, extraction_dir)
 | 
			
		||||
                                else:
 | 
			
		||||
                                    logger.info(f"Loading local model from {model_url}")
 | 
			
		||||
                                    logger.info(f"Loading local .mpta file from {model_url}")
 | 
			
		||||
                                    # Check if file exists before attempting to load
 | 
			
		||||
                                    if not os.path.exists(model_url):
 | 
			
		||||
                                        logger.error(f"Local model file not found: {model_url}")
 | 
			
		||||
                                        logger.error(f"Local .mpta file not found: {model_url}")
 | 
			
		||||
                                        logger.debug(f"Current working directory: {os.getcwd()}")
 | 
			
		||||
                                        error_response = {
 | 
			
		||||
                                            "type": "error",
 | 
			
		||||
                                            "cameraIdentifier": camera_id,
 | 
			
		||||
                                            "subscriptionIdentifier": subscriptionIdentifier,
 | 
			
		||||
                                            "error": f"Model file not found: {model_url}"
 | 
			
		||||
                                        }
 | 
			
		||||
                                        await websocket.send_json(error_response)
 | 
			
		||||
                                        continue
 | 
			
		||||
                                    model_tree = load_pipeline_from_zip(model_url, extraction_dir)
 | 
			
		||||
                                if model_tree is None:
 | 
			
		||||
                                    logger.error(f"Failed to load model {modelId} from mpta file for camera {camera_id}")
 | 
			
		||||
                                    logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}")
 | 
			
		||||
                                    error_response = {
 | 
			
		||||
                                        "type": "error",
 | 
			
		||||
                                        "cameraIdentifier": camera_id,
 | 
			
		||||
                                        "subscriptionIdentifier": subscriptionIdentifier,
 | 
			
		||||
                                        "error": f"Failed to load model {modelId}"
 | 
			
		||||
                                    }
 | 
			
		||||
                                    await websocket.send_json(error_response)
 | 
			
		||||
| 
						 | 
				
			
			@ -573,20 +620,80 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                                    models[camera_id] = {}
 | 
			
		||||
                                models[camera_id][modelId] = model_tree
 | 
			
		||||
                                logger.info(f"Successfully loaded model {modelId} for camera {camera_id}")
 | 
			
		||||
                                success_response = {
 | 
			
		||||
                                    "type": "modelLoaded",
 | 
			
		||||
                                    "cameraIdentifier": camera_id,
 | 
			
		||||
                                    "modelId": modelId
 | 
			
		||||
                                }
 | 
			
		||||
                                await websocket.send_json(success_response)
 | 
			
		||||
                                logger.debug(f"Model extraction directory: {extraction_dir}")
 | 
			
		||||
                    if camera_id and (rtsp_url or snapshot_url):
 | 
			
		||||
                        with streams_lock:
 | 
			
		||||
                            # Determine camera URL for shared stream management
 | 
			
		||||
                            camera_url = snapshot_url if snapshot_url else rtsp_url
 | 
			
		||||
                            
 | 
			
		||||
                            if camera_id not in streams and len(streams) < max_streams:
 | 
			
		||||
                                # Check if we already have a stream for this camera URL
 | 
			
		||||
                                shared_stream = camera_streams.get(camera_url)
 | 
			
		||||
                                
 | 
			
		||||
                                if shared_stream:
 | 
			
		||||
                                    # Reuse existing stream
 | 
			
		||||
                                    logger.info(f"Reusing existing stream for camera URL: {camera_url}")
 | 
			
		||||
                                    buffer = shared_stream["buffer"]
 | 
			
		||||
                                    stop_event = shared_stream["stop_event"]
 | 
			
		||||
                                    thread = shared_stream["thread"]
 | 
			
		||||
                                    mode = shared_stream["mode"]
 | 
			
		||||
                                    
 | 
			
		||||
                                    # Increment reference count
 | 
			
		||||
                                    shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
 | 
			
		||||
                                else:
 | 
			
		||||
                                    # Create new stream
 | 
			
		||||
                                    buffer = queue.Queue(maxsize=1)
 | 
			
		||||
                                    stop_event = threading.Event()
 | 
			
		||||
                                    
 | 
			
		||||
                                    if snapshot_url and snapshot_interval:
 | 
			
		||||
                                        logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
 | 
			
		||||
                                        thread = threading.Thread(target=snapshot_reader, args=(camera_identifier, snapshot_url, snapshot_interval, buffer, stop_event))
 | 
			
		||||
                                        thread.daemon = True
 | 
			
		||||
                                        thread.start()
 | 
			
		||||
                                        mode = "snapshot"
 | 
			
		||||
                                        
 | 
			
		||||
                                        # Store shared stream info
 | 
			
		||||
                                        shared_stream = {
 | 
			
		||||
                                            "buffer": buffer,
 | 
			
		||||
                                            "thread": thread,
 | 
			
		||||
                                            "stop_event": stop_event,
 | 
			
		||||
                                            "mode": mode,
 | 
			
		||||
                                            "url": snapshot_url,
 | 
			
		||||
                                            "snapshot_interval": snapshot_interval,
 | 
			
		||||
                                            "ref_count": 1
 | 
			
		||||
                                        }
 | 
			
		||||
                                        camera_streams[camera_url] = shared_stream
 | 
			
		||||
                                        
 | 
			
		||||
                                    elif rtsp_url:
 | 
			
		||||
                                        logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}")
 | 
			
		||||
                                        cap = cv2.VideoCapture(rtsp_url)
 | 
			
		||||
                                        if not cap.isOpened():
 | 
			
		||||
                                            logger.error(f"Failed to open RTSP stream for camera {camera_id}")
 | 
			
		||||
                                            continue
 | 
			
		||||
                                        thread = threading.Thread(target=frame_reader, args=(camera_identifier, cap, buffer, stop_event))
 | 
			
		||||
                                        thread.daemon = True
 | 
			
		||||
                                        thread.start()
 | 
			
		||||
                                        mode = "rtsp"
 | 
			
		||||
                                        
 | 
			
		||||
                                        # Store shared stream info
 | 
			
		||||
                                        shared_stream = {
 | 
			
		||||
                                            "buffer": buffer,
 | 
			
		||||
                                            "thread": thread,
 | 
			
		||||
                                            "stop_event": stop_event,
 | 
			
		||||
                                            "mode": mode,
 | 
			
		||||
                                            "url": rtsp_url,
 | 
			
		||||
                                            "cap": cap,
 | 
			
		||||
                                            "ref_count": 1
 | 
			
		||||
                                        }
 | 
			
		||||
                                        camera_streams[camera_url] = shared_stream
 | 
			
		||||
                                    else:
 | 
			
		||||
                                        logger.error(f"No valid URL provided for camera {camera_id}")
 | 
			
		||||
                                        continue
 | 
			
		||||
                                
 | 
			
		||||
                                # Create stream info for this subscription
 | 
			
		||||
                                stream_info = {
 | 
			
		||||
                                    "buffer": buffer,
 | 
			
		||||
                                    "thread": None,
 | 
			
		||||
                                    "thread": thread,
 | 
			
		||||
                                    "stop_event": stop_event,
 | 
			
		||||
                                    "modelId": modelId,
 | 
			
		||||
                                    "modelName": modelName,
 | 
			
		||||
| 
						 | 
				
			
			@ -594,52 +701,25 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                                    "cropX1": cropX1,
 | 
			
		||||
                                    "cropY1": cropY1,
 | 
			
		||||
                                    "cropX2": cropX2,
 | 
			
		||||
                                    "cropY2": cropY2
 | 
			
		||||
                                    "cropY2": cropY2,
 | 
			
		||||
                                    "mode": mode,
 | 
			
		||||
                                    "camera_url": camera_url
 | 
			
		||||
                                }
 | 
			
		||||
                                if snapshot_url and snapshot_interval:
 | 
			
		||||
                                    logger.info(f"Using snapshot mode for camera {camera_id}: {snapshot_url}")
 | 
			
		||||
                                    thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
 | 
			
		||||
                                    thread.daemon = True
 | 
			
		||||
                                    thread.start()
 | 
			
		||||
                                    stream_info.update({
 | 
			
		||||
                                        "snapshot_url": snapshot_url,
 | 
			
		||||
                                        "snapshot_interval": snapshot_interval,
 | 
			
		||||
                                        "mode": "snapshot"
 | 
			
		||||
                                    })
 | 
			
		||||
                                    stream_info["thread"] = thread
 | 
			
		||||
                                
 | 
			
		||||
                                if mode == "snapshot":
 | 
			
		||||
                                    stream_info["snapshot_url"] = snapshot_url
 | 
			
		||||
                                    stream_info["snapshot_interval"] = snapshot_interval
 | 
			
		||||
                                elif mode == "rtsp":
 | 
			
		||||
                                    stream_info["rtsp_url"] = rtsp_url
 | 
			
		||||
                                    stream_info["cap"] = shared_stream["cap"]
 | 
			
		||||
                                
 | 
			
		||||
                                streams[camera_id] = stream_info
 | 
			
		||||
                                elif rtsp_url:
 | 
			
		||||
                                    logger.info(f"Using RTSP mode for camera {camera_id}: {rtsp_url}")
 | 
			
		||||
                                    cap = cv2.VideoCapture(rtsp_url)
 | 
			
		||||
                                    if not cap.isOpened():
 | 
			
		||||
                                        logger.error(f"Failed to open RTSP stream for camera {camera_id}")
 | 
			
		||||
                                        continue
 | 
			
		||||
                                    thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
 | 
			
		||||
                                    thread.daemon = True
 | 
			
		||||
                                    thread.start()
 | 
			
		||||
                                    stream_info.update({
 | 
			
		||||
                                        "cap": cap,
 | 
			
		||||
                                        "rtsp_url": rtsp_url,
 | 
			
		||||
                                        "mode": "rtsp"
 | 
			
		||||
                                    })
 | 
			
		||||
                                    stream_info["thread"] = thread
 | 
			
		||||
                                    streams[camera_id] = stream_info
 | 
			
		||||
                                else:
 | 
			
		||||
                                    logger.error(f"No valid URL provided for camera {camera_id}")
 | 
			
		||||
                                    continue
 | 
			
		||||
                                subscription_to_camera[camera_id] = camera_url
 | 
			
		||||
                                
 | 
			
		||||
                            elif camera_id and camera_id in streams:
 | 
			
		||||
                                # If already subscribed, unsubscribe first
 | 
			
		||||
                                stream = streams.pop(camera_id)
 | 
			
		||||
                                stream["stop_event"].set()
 | 
			
		||||
                                stream["thread"].join()
 | 
			
		||||
                                if "cap" in stream:
 | 
			
		||||
                                    stream["cap"].release()
 | 
			
		||||
                                logger.info(f"Unsubscribed from camera {camera_id} for resubscription")
 | 
			
		||||
                                with models_lock:
 | 
			
		||||
                                    if camera_id in models and modelId in models[camera_id]:
 | 
			
		||||
                                        del models[camera_id][modelId]
 | 
			
		||||
                                        if not models[camera_id]:
 | 
			
		||||
                                            del models[camera_id]
 | 
			
		||||
                                logger.info(f"Resubscribing to camera {camera_id}")
 | 
			
		||||
                                # Note: Keep models in memory for reuse across subscriptions
 | 
			
		||||
                elif msg_type == "unsubscribe":
 | 
			
		||||
                    payload = data.get("payload", {})
 | 
			
		||||
                    subscriptionIdentifier = payload.get("subscriptionIdentifier")
 | 
			
		||||
| 
						 | 
				
			
			@ -647,13 +727,25 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                    with streams_lock:
 | 
			
		||||
                        if camera_id and camera_id in streams:
 | 
			
		||||
                            stream = streams.pop(camera_id)
 | 
			
		||||
                            stream["stop_event"].set()
 | 
			
		||||
                            stream["thread"].join()
 | 
			
		||||
                            if "cap" in stream:
 | 
			
		||||
                                stream["cap"].release()
 | 
			
		||||
                            with models_lock:
 | 
			
		||||
                                if camera_id in models:
 | 
			
		||||
                                    del models[camera_id]
 | 
			
		||||
                            camera_url = subscription_to_camera.pop(camera_id, None)
 | 
			
		||||
                            
 | 
			
		||||
                            if camera_url and camera_url in camera_streams:
 | 
			
		||||
                                shared_stream = camera_streams[camera_url]
 | 
			
		||||
                                shared_stream["ref_count"] -= 1
 | 
			
		||||
                                
 | 
			
		||||
                                # If no more references, stop the shared stream
 | 
			
		||||
                                if shared_stream["ref_count"] <= 0:
 | 
			
		||||
                                    logger.info(f"Stopping shared stream for camera URL: {camera_url}")
 | 
			
		||||
                                    shared_stream["stop_event"].set()
 | 
			
		||||
                                    shared_stream["thread"].join()
 | 
			
		||||
                                    if "cap" in shared_stream:
 | 
			
		||||
                                        shared_stream["cap"].release()
 | 
			
		||||
                                    del camera_streams[camera_url]
 | 
			
		||||
                                else:
 | 
			
		||||
                                    logger.info(f"Shared stream for {camera_url} still has {shared_stream['ref_count']} references")
 | 
			
		||||
                            
 | 
			
		||||
                            logger.info(f"Unsubscribed from camera {camera_id}")
 | 
			
		||||
                            # Note: Keep models in memory for potential reuse
 | 
			
		||||
                elif msg_type == "requestState":
 | 
			
		||||
                    cpu_usage = psutil.cpu_percent()
 | 
			
		||||
                    memory_usage = psutil.virtual_memory().percent
 | 
			
		||||
| 
						 | 
				
			
			@ -684,6 +776,37 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                        "cameraConnections": camera_connections
 | 
			
		||||
                    }
 | 
			
		||||
                    await websocket.send_text(json.dumps(state_report))
 | 
			
		||||
                
 | 
			
		||||
                elif msg_type == "setSessionId":
 | 
			
		||||
                    payload = data.get("payload", {})
 | 
			
		||||
                    display_identifier = payload.get("displayIdentifier")
 | 
			
		||||
                    session_id = payload.get("sessionId")
 | 
			
		||||
                    
 | 
			
		||||
                    if display_identifier:
 | 
			
		||||
                        # Store session ID for this display
 | 
			
		||||
                        if session_id is None:
 | 
			
		||||
                            session_ids.pop(display_identifier, None)
 | 
			
		||||
                            logger.info(f"Cleared session ID for display {display_identifier}")
 | 
			
		||||
                        else:
 | 
			
		||||
                            session_ids[display_identifier] = session_id
 | 
			
		||||
                            logger.info(f"Set session ID {session_id} for display {display_identifier}")
 | 
			
		||||
                
 | 
			
		||||
                elif msg_type == "patchSession":
 | 
			
		||||
                    session_id = data.get("sessionId")
 | 
			
		||||
                    patch_data = data.get("data", {})
 | 
			
		||||
                    
 | 
			
		||||
                    # For now, just acknowledge the patch - actual implementation depends on backend requirements
 | 
			
		||||
                    response = {
 | 
			
		||||
                        "type": "patchSessionResult",
 | 
			
		||||
                        "payload": {
 | 
			
		||||
                            "sessionId": session_id,
 | 
			
		||||
                            "success": True,
 | 
			
		||||
                            "message": "Session patch acknowledged"
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                    await websocket.send_json(response)
 | 
			
		||||
                    logger.info(f"Acknowledged patch for session {session_id}")
 | 
			
		||||
                
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.error(f"Unknown message type: {msg_type}")
 | 
			
		||||
            except json.JSONDecodeError:
 | 
			
		||||
| 
						 | 
				
			
			@ -706,19 +829,23 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
        stream_task.cancel()
 | 
			
		||||
        await stream_task
 | 
			
		||||
        with streams_lock:
 | 
			
		||||
            for camera_id, stream in streams.items():
 | 
			
		||||
                stream["stop_event"].set()
 | 
			
		||||
                stream["thread"].join()
 | 
			
		||||
                # Only release cap if it exists (RTSP mode)
 | 
			
		||||
                if "cap" in stream:
 | 
			
		||||
                    stream["cap"].release()
 | 
			
		||||
                while not stream["buffer"].empty():
 | 
			
		||||
            # Clean up shared camera streams
 | 
			
		||||
            for camera_url, shared_stream in camera_streams.items():
 | 
			
		||||
                shared_stream["stop_event"].set()
 | 
			
		||||
                shared_stream["thread"].join()
 | 
			
		||||
                if "cap" in shared_stream:
 | 
			
		||||
                    shared_stream["cap"].release()
 | 
			
		||||
                while not shared_stream["buffer"].empty():
 | 
			
		||||
                    try:
 | 
			
		||||
                        stream["buffer"].get_nowait()
 | 
			
		||||
                        shared_stream["buffer"].get_nowait()
 | 
			
		||||
                    except queue.Empty:
 | 
			
		||||
                        pass
 | 
			
		||||
                logger.info(f"Released camera {camera_id} and cleaned up resources")
 | 
			
		||||
                logger.info(f"Released shared camera stream for {camera_url}")
 | 
			
		||||
            
 | 
			
		||||
            streams.clear()
 | 
			
		||||
            camera_streams.clear()
 | 
			
		||||
            subscription_to_camera.clear()
 | 
			
		||||
        with models_lock:
 | 
			
		||||
            models.clear()
 | 
			
		||||
        session_ids.clear()
 | 
			
		||||
        logger.info("WebSocket connection closed")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										125
									
								
								test_protocol.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								test_protocol.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,125 @@
 | 
			
		|||
#!/usr/bin/env python3
 | 
			
		||||
"""
 | 
			
		||||
Test script to verify the worker implementation follows the protocol
 | 
			
		||||
"""
 | 
			
		||||
import json
 | 
			
		||||
import asyncio
 | 
			
		||||
import websockets
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
async def test_protocol():
 | 
			
		||||
    """Test the worker protocol implementation"""
 | 
			
		||||
    uri = "ws://localhost:8000"
 | 
			
		||||
    
 | 
			
		||||
    try:
 | 
			
		||||
        async with websockets.connect(uri) as websocket:
 | 
			
		||||
            print("✓ Connected to worker")
 | 
			
		||||
            
 | 
			
		||||
            # Test 1: Check if we receive heartbeat (stateReport)
 | 
			
		||||
            print("\n1. Testing heartbeat...")
 | 
			
		||||
            try:
 | 
			
		||||
                message = await asyncio.wait_for(websocket.recv(), timeout=5)
 | 
			
		||||
                data = json.loads(message)
 | 
			
		||||
                if data.get("type") == "stateReport":
 | 
			
		||||
                    print("✓ Received stateReport heartbeat")
 | 
			
		||||
                    print(f"  - CPU Usage: {data.get('cpuUsage', 'N/A')}%")
 | 
			
		||||
                    print(f"  - Memory Usage: {data.get('memoryUsage', 'N/A')}%")
 | 
			
		||||
                    print(f"  - Camera Connections: {len(data.get('cameraConnections', []))}")
 | 
			
		||||
                else:
 | 
			
		||||
                    print(f"✗ Expected stateReport, got {data.get('type')}")
 | 
			
		||||
            except asyncio.TimeoutError:
 | 
			
		||||
                print("✗ No heartbeat received within 5 seconds")
 | 
			
		||||
            
 | 
			
		||||
            # Test 2: Request state
 | 
			
		||||
            print("\n2. Testing requestState...")
 | 
			
		||||
            await websocket.send(json.dumps({"type": "requestState"}))
 | 
			
		||||
            try:
 | 
			
		||||
                message = await asyncio.wait_for(websocket.recv(), timeout=5)
 | 
			
		||||
                data = json.loads(message)
 | 
			
		||||
                if data.get("type") == "stateReport":
 | 
			
		||||
                    print("✓ Received stateReport response")
 | 
			
		||||
                else:
 | 
			
		||||
                    print(f"✗ Expected stateReport, got {data.get('type')}")
 | 
			
		||||
            except asyncio.TimeoutError:
 | 
			
		||||
                print("✗ No response to requestState within 5 seconds")
 | 
			
		||||
            
 | 
			
		||||
            # Test 3: Set session ID
 | 
			
		||||
            print("\n3. Testing setSessionId...")
 | 
			
		||||
            session_message = {
 | 
			
		||||
                "type": "setSessionId",
 | 
			
		||||
                "payload": {
 | 
			
		||||
                    "displayIdentifier": "display-001",
 | 
			
		||||
                    "sessionId": 12345
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            await websocket.send(json.dumps(session_message))
 | 
			
		||||
            print("✓ Sent setSessionId message")
 | 
			
		||||
            
 | 
			
		||||
            # Test 4: Test patchSession
 | 
			
		||||
            print("\n4. Testing patchSession...")
 | 
			
		||||
            patch_message = {
 | 
			
		||||
                "type": "patchSession",
 | 
			
		||||
                "sessionId": 12345,
 | 
			
		||||
                "data": {
 | 
			
		||||
                    "currentCar": {
 | 
			
		||||
                        "carModel": "Civic",
 | 
			
		||||
                        "carBrand": "Honda"
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            await websocket.send(json.dumps(patch_message))
 | 
			
		||||
            
 | 
			
		||||
            # Wait for patchSessionResult
 | 
			
		||||
            try:
 | 
			
		||||
                message = await asyncio.wait_for(websocket.recv(), timeout=5)
 | 
			
		||||
                data = json.loads(message)
 | 
			
		||||
                if data.get("type") == "patchSessionResult":
 | 
			
		||||
                    print("✓ Received patchSessionResult")
 | 
			
		||||
                    print(f"  - Success: {data.get('payload', {}).get('success')}")
 | 
			
		||||
                    print(f"  - Message: {data.get('payload', {}).get('message')}")
 | 
			
		||||
                else:
 | 
			
		||||
                    print(f"✗ Expected patchSessionResult, got {data.get('type')}")
 | 
			
		||||
            except asyncio.TimeoutError:
 | 
			
		||||
                print("✗ No patchSessionResult received within 5 seconds")
 | 
			
		||||
            
 | 
			
		||||
            # Test 5: Test subscribe message format (without actual camera)
 | 
			
		||||
            print("\n5. Testing subscribe message format...")
 | 
			
		||||
            subscribe_message = {
 | 
			
		||||
                "type": "subscribe",
 | 
			
		||||
                "payload": {
 | 
			
		||||
                    "subscriptionIdentifier": "display-001;cam-001",
 | 
			
		||||
                    "snapshotUrl": "http://example.com/snapshot.jpg",
 | 
			
		||||
                    "snapshotInterval": 5000,
 | 
			
		||||
                    "modelUrl": "http://example.com/model.mpta",
 | 
			
		||||
                    "modelName": "Test Model",
 | 
			
		||||
                    "modelId": 101,
 | 
			
		||||
                    "cropX1": 100,
 | 
			
		||||
                    "cropY1": 200,
 | 
			
		||||
                    "cropX2": 300,
 | 
			
		||||
                    "cropY2": 400
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            await websocket.send(json.dumps(subscribe_message))
 | 
			
		||||
            print("✓ Sent subscribe message (will fail without actual camera/model)")
 | 
			
		||||
            
 | 
			
		||||
            # Listen for a few more messages to catch any errors
 | 
			
		||||
            print("\n6. Listening for additional messages...")
 | 
			
		||||
            for i in range(3):
 | 
			
		||||
                try:
 | 
			
		||||
                    message = await asyncio.wait_for(websocket.recv(), timeout=2)
 | 
			
		||||
                    data = json.loads(message)
 | 
			
		||||
                    msg_type = data.get("type")
 | 
			
		||||
                    print(f"  - Received {msg_type}")
 | 
			
		||||
                    if msg_type == "error":
 | 
			
		||||
                        print(f"    Error: {data.get('error')}")
 | 
			
		||||
                except asyncio.TimeoutError:
 | 
			
		||||
                    break
 | 
			
		||||
            
 | 
			
		||||
            print("\n✓ Protocol test completed successfully!")
 | 
			
		||||
            
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"✗ Connection failed: {e}")
 | 
			
		||||
        print("Make sure the worker is running on localhost:8000")
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    asyncio.run(test_protocol())
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue