From 0bcf572242b5d72cd97a5e40bcce51d29cafe275 Mon Sep 17 00:00:00 2001 From: Siwat Sirichai Date: Wed, 13 Aug 2025 00:06:27 +0700 Subject: [PATCH] Implement subscription reconciliation logic for improved management of camera streams --- app.py | 248 ++++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 228 insertions(+), 20 deletions(-) diff --git a/app.py b/app.py index c9727e8..1b163b7 100644 --- a/app.py +++ b/app.py @@ -500,6 +500,199 @@ async def detect(websocket: WebSocket): finally: logger.info(f"Snapshot reader thread for camera {camera_id} is exiting") + async def reconcile_subscriptions(desired_subscriptions, websocket): + """ + Declarative reconciliation: Compare desired vs current subscriptions and make changes + """ + logger.info(f"Reconciling subscriptions: {len(desired_subscriptions)} desired") + + with streams_lock: + # Get current subscriptions + current_subscription_ids = set(streams.keys()) + desired_subscription_ids = set(sub["subscriptionIdentifier"] for sub in desired_subscriptions) + + # Find what to add and remove + to_add = desired_subscription_ids - current_subscription_ids + to_remove = current_subscription_ids - desired_subscription_ids + to_check_for_changes = current_subscription_ids & desired_subscription_ids + + logger.info(f"Reconciliation: {len(to_add)} to add, {len(to_remove)} to remove, {len(to_check_for_changes)} to check for changes") + + # Remove subscriptions that are no longer wanted + for subscription_id in to_remove: + await unsubscribe_internal(subscription_id) + + # Check existing subscriptions for parameter changes + for subscription_id in to_check_for_changes: + desired_sub = next(sub for sub in desired_subscriptions if sub["subscriptionIdentifier"] == subscription_id) + current_stream = streams[subscription_id] + + # Check if parameters changed + if has_subscription_changed(desired_sub, current_stream): + logger.info(f"Parameters changed for {subscription_id}, resubscribing") + await unsubscribe_internal(subscription_id) + await subscribe_internal(desired_sub, websocket) + + # Add new subscriptions + for subscription_id in to_add: + desired_sub = next(sub for sub in desired_subscriptions if sub["subscriptionIdentifier"] == subscription_id) + await subscribe_internal(desired_sub, websocket) + + def has_subscription_changed(desired_sub, current_stream): + """Check if subscription parameters have changed""" + return ( + desired_sub.get("rtspUrl") != current_stream.get("rtsp_url") or + desired_sub.get("snapshotUrl") != current_stream.get("snapshot_url") or + desired_sub.get("snapshotInterval") != current_stream.get("snapshot_interval") or + desired_sub.get("cropX1") != current_stream.get("cropX1") or + desired_sub.get("cropY1") != current_stream.get("cropY1") or + desired_sub.get("cropX2") != current_stream.get("cropX2") or + desired_sub.get("cropY2") != current_stream.get("cropY2") or + desired_sub.get("modelId") != current_stream.get("modelId") or + desired_sub.get("modelName") != current_stream.get("modelName") + ) + + async def subscribe_internal(subscription, websocket): + """Internal subscription logic extracted from original subscribe handler""" + subscriptionIdentifier = subscription.get("subscriptionIdentifier") + rtsp_url = subscription.get("rtspUrl") + snapshot_url = subscription.get("snapshotUrl") + snapshot_interval = subscription.get("snapshotInterval") + model_url = subscription.get("modelUrl") + modelId = subscription.get("modelId") + modelName = subscription.get("modelName") + cropX1 = subscription.get("cropX1") + cropY1 = subscription.get("cropY1") + cropX2 = subscription.get("cropX2") + cropY2 = subscription.get("cropY2") + + # Extract camera_id from subscriptionIdentifier + parts = subscriptionIdentifier.split(';') + if len(parts) != 2: + logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}") + return + + display_identifier, camera_identifier = parts + camera_id = subscriptionIdentifier + + # Load model if needed + if model_url: + with models_lock: + if (camera_id not in models) or (modelId not in models[camera_id]): + logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}") + extraction_dir = os.path.join("models", camera_identifier, str(modelId)) + os.makedirs(extraction_dir, exist_ok=True) + + # Handle model loading (same as original) + parsed = urlparse(model_url) + if parsed.scheme in ("http", "https"): + filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta" + local_mpta = os.path.join(extraction_dir, filename) + local_path = download_mpta(model_url, local_mpta) + if not local_path: + logger.error(f"Failed to download model from {model_url}") + return + model_tree = load_pipeline_from_zip(local_path, extraction_dir) + else: + if not os.path.exists(model_url): + logger.error(f"Model file not found: {model_url}") + return + model_tree = load_pipeline_from_zip(model_url, extraction_dir) + + if model_tree is None: + logger.error(f"Failed to load model {modelId}") + return + + if camera_id not in models: + models[camera_id] = {} + models[camera_id][modelId] = model_tree + + # Create stream (same logic as original) + if camera_id and (rtsp_url or snapshot_url) and len(streams) < max_streams: + camera_url = snapshot_url if snapshot_url else rtsp_url + + # Check if we already have a stream for this camera URL + shared_stream = camera_streams.get(camera_url) + + if shared_stream: + # Reuse existing stream + buffer = shared_stream["buffer"] + stop_event = shared_stream["stop_event"] + thread = shared_stream["thread"] + mode = shared_stream["mode"] + shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1 + else: + # Create new stream + buffer = queue.Queue(maxsize=1) + stop_event = threading.Event() + + if snapshot_url and snapshot_interval: + thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event)) + thread.daemon = True + thread.start() + mode = "snapshot" + shared_stream = { + "buffer": buffer, "thread": thread, "stop_event": stop_event, + "mode": mode, "url": snapshot_url, "snapshot_interval": snapshot_interval, "ref_count": 1 + } + camera_streams[camera_url] = shared_stream + elif rtsp_url: + cap = cv2.VideoCapture(rtsp_url) + if not cap.isOpened(): + logger.error(f"Failed to open RTSP stream for camera {camera_id}") + return + thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event)) + thread.daemon = True + thread.start() + mode = "rtsp" + shared_stream = { + "buffer": buffer, "thread": thread, "stop_event": stop_event, + "mode": mode, "url": rtsp_url, "cap": cap, "ref_count": 1 + } + camera_streams[camera_url] = shared_stream + else: + logger.error(f"No valid URL provided for camera {camera_id}") + return + + # Create stream info + stream_info = { + "buffer": buffer, "thread": thread, "stop_event": stop_event, + "modelId": modelId, "modelName": modelName, "subscriptionIdentifier": subscriptionIdentifier, + "cropX1": cropX1, "cropY1": cropY1, "cropX2": cropX2, "cropY2": cropY2, + "mode": mode, "camera_url": camera_url, "modelUrl": model_url + } + + if mode == "snapshot": + stream_info["snapshot_url"] = snapshot_url + stream_info["snapshot_interval"] = snapshot_interval + elif mode == "rtsp": + stream_info["rtsp_url"] = rtsp_url + stream_info["cap"] = shared_stream["cap"] + + streams[camera_id] = stream_info + subscription_to_camera[camera_id] = camera_url + logger.info(f"Subscribed to camera {camera_id}") + + async def unsubscribe_internal(subscription_id): + """Internal unsubscription logic""" + if subscription_id in streams: + stream = streams.pop(subscription_id) + camera_url = subscription_to_camera.pop(subscription_id, None) + + if camera_url and camera_url in camera_streams: + shared_stream = camera_streams[camera_url] + shared_stream["ref_count"] -= 1 + + if shared_stream["ref_count"] <= 0: + shared_stream["stop_event"].set() + shared_stream["thread"].join() + if "cap" in shared_stream: + shared_stream["cap"].release() + del camera_streams[camera_url] + + latest_frames.pop(subscription_id, None) + logger.info(f"Unsubscribed from camera {subscription_id}") + async def process_streams(): logger.info("Started processing streams") try: @@ -599,29 +792,44 @@ async def detect(websocket: WebSocket): data = json.loads(msg) msg_type = data.get("type") - if msg_type == "subscribe": + if msg_type == "setSubscriptionList": + # Declarative approach: Backend sends list of subscriptions this worker should have + desired_subscriptions = data.get("subscriptions", []) + logger.info(f"Received subscription list with {len(desired_subscriptions)} subscriptions") + + await reconcile_subscriptions(desired_subscriptions, websocket) + + elif msg_type == "subscribe": + # Legacy support - convert single subscription to list + payload = data.get("payload", {}) + await reconcile_subscriptions([payload], websocket) + + elif msg_type == "unsubscribe": + # Legacy support - remove subscription payload = data.get("payload", {}) subscriptionIdentifier = payload.get("subscriptionIdentifier") - rtsp_url = payload.get("rtspUrl") - snapshot_url = payload.get("snapshotUrl") - snapshot_interval = payload.get("snapshotInterval") - model_url = payload.get("modelUrl") - modelId = payload.get("modelId") - modelName = payload.get("modelName") - cropX1 = payload.get("cropX1") - cropY1 = payload.get("cropY1") - cropX2 = payload.get("cropX2") - cropY2 = payload.get("cropY2") - - # Extract camera_id from subscriptionIdentifier (format: displayIdentifier;cameraIdentifier) - parts = subscriptionIdentifier.split(';') - if len(parts) != 2: - logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}") - continue + # Remove from current subscriptions and reconcile + current_subs = [] + with streams_lock: + for camera_id, stream in streams.items(): + if stream["subscriptionIdentifier"] != subscriptionIdentifier: + # Convert stream back to subscription format + current_subs.append({ + "subscriptionIdentifier": stream["subscriptionIdentifier"], + "rtspUrl": stream.get("rtsp_url"), + "snapshotUrl": stream.get("snapshot_url"), + "snapshotInterval": stream.get("snapshot_interval"), + "modelId": stream["modelId"], + "modelName": stream["modelName"], + "modelUrl": stream.get("modelUrl", ""), + "cropX1": stream.get("cropX1"), + "cropY1": stream.get("cropY1"), + "cropX2": stream.get("cropX2"), + "cropY2": stream.get("cropY2") + }) + await reconcile_subscriptions(current_subs, websocket) - display_identifier, camera_identifier = parts - camera_id = subscriptionIdentifier # Use full subscriptionIdentifier as camera_id for mapping - + elif msg_type == "old_subscribe_logic_removed": if model_url: with models_lock: if (camera_id not in models) or (modelId not in models[camera_id]):