Compare commits


1 commit

Author: Pongsatorn Kanjanasantisak
SHA1: 527dc748b5
Message: update reqs
Date: 2025-08-12 00:41:58 +07:00
6 changed files with 88 additions and 3328 deletions

app.py (373 changed lines)

@@ -13,13 +13,7 @@ import requests
 import asyncio
 import psutil
 import zipfile
-import ssl
-import urllib3
-import subprocess
-import tempfile
 from urllib.parse import urlparse
-from requests.adapters import HTTPAdapter
-from urllib3.util.ssl_ import create_urllib3_context
 from fastapi import FastAPI, WebSocket, HTTPException
 from fastapi.websockets import WebSocketDisconnect
 from fastapi.responses import Response
@@ -246,14 +240,16 @@ async def detect(websocket: WebSocket):
 logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}")
 start_time = time.time()
 
-# Extract display identifier for pipeline context
+# Extract display identifier for session ID lookup
 subscription_parts = stream["subscriptionIdentifier"].split(';')
 display_identifier = subscription_parts[0] if subscription_parts else None
+session_id = session_ids.get(display_identifier) if display_identifier else None
 
-# Create context for pipeline execution (session_id will be generated by pipeline)
+# Create context for pipeline execution
 pipeline_context = {
     "camera_id": camera_id,
-    "display_id": display_identifier
+    "display_id": display_identifier,
+    "session_id": session_id
 }
 
 detection_result = run_pipeline(cropped_frame, model_tree, context=pipeline_context)
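As context for the restored lookup above, a minimal sketch (hypothetical values; it assumes app.py keeps a module-level session_ids dict keyed by display identifier, as the code implies) of how the display identifier and session ID are derived from a subscriptionIdentifier of the form displayIdentifier;cameraIdentifier:

    # Hypothetical illustration; session_ids is assumed to map display IDs to backend session IDs.
    session_ids = {"display-001": 42}

    subscription_identifier = "display-001;camera-001"
    subscription_parts = subscription_identifier.split(';')
    display_identifier = subscription_parts[0] if subscription_parts else None
    session_id = session_ids.get(display_identifier) if display_identifier else None
    print(session_id)  # -> 42 if the display has an active session, otherwise None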
@@ -263,63 +259,57 @@ async def detect(websocket: WebSocket):
 # Log the raw detection result for debugging
 logger.debug(f"Raw detection result for camera {camera_id}:\n{json.dumps(detection_result, indent=2, default=str)}")
 
-# Extract session_id from pipeline result (generated during database record creation)
-session_id = None
-if detection_result and isinstance(detection_result, dict):
-    # Check if pipeline generated a session_id (happens when Car+Frontal detected together)
-    if "session_id" in detection_result:
-        session_id = detection_result["session_id"]
-        logger.debug(f"Extracted session_id from pipeline result: {session_id}")
-
-# Process detection result - run_pipeline returns the primary detection directly
-if detection_result and isinstance(detection_result, dict) and "class" in detection_result:
-    highest_confidence_detection = detection_result
-else:
-    # No detection found
-    highest_confidence_detection = {
-        "class": "none",
-        "confidence": 1.0,
-        "bbox": [0, 0, 0, 0],
-        "branch_results": {}
-    }
-
-# Convert detection format to match backend expectations exactly as in worker.md section 4.2
-detection_dict = {
-    "carModel": None,
-    "carBrand": None,
-    "carYear": None,
-    "bodyType": None,
-    "licensePlateText": None,
-    "licensePlateConfidence": None
-}
-
-# Extract and process branch results from parallel classification
-branch_results = highest_confidence_detection.get("branch_results", {})
-if branch_results:
-    logger.debug(f"Processing branch results: {branch_results}")
-    # Transform branch results into backend-expected detection attributes
-    for branch_id, branch_data in branch_results.items():
-        if isinstance(branch_data, dict):
-            logger.debug(f"Processing branch {branch_id}: {branch_data}")
-            # Map common classification fields to backend-expected names
-            if "brand" in branch_data:
-                detection_dict["carBrand"] = branch_data["brand"]
-            if "body_type" in branch_data:
-                detection_dict["bodyType"] = branch_data["body_type"]
-            if "class" in branch_data:
-                class_name = branch_data["class"]
-                # Map based on branch/model type
-                if "brand" in branch_id.lower():
-                    detection_dict["carBrand"] = class_name
-                elif "bodytype" in branch_id.lower() or "body" in branch_id.lower():
-                    detection_dict["bodyType"] = class_name
-    logger.info(f"Detection payload after branch processing: {detection_dict}")
-else:
-    logger.debug("No branch results found in detection result")
+# Direct class result (no detections/classifications structure)
+if detection_result and isinstance(detection_result, dict) and "class" in detection_result and "confidence" in detection_result:
+    highest_confidence_detection = {
+        "class": detection_result.get("class", "none"),
+        "confidence": detection_result.get("confidence", 1.0),
+        "box": [0, 0, 0, 0]  # Empty bounding box for classifications
+    }
+# Handle case when no detections found or result is empty
+elif not detection_result or not detection_result.get("detections"):
+    # Check if we have classification results
+    if detection_result and detection_result.get("classifications"):
+        # Get the highest confidence classification
+        classifications = detection_result.get("classifications", [])
+        highest_confidence_class = max(classifications, key=lambda x: x.get("confidence", 0)) if classifications else None
+
+        if highest_confidence_class:
+            highest_confidence_detection = {
+                "class": highest_confidence_class.get("class", "none"),
+                "confidence": highest_confidence_class.get("confidence", 1.0),
+                "box": [0, 0, 0, 0]  # Empty bounding box for classifications
+            }
+        else:
+            highest_confidence_detection = {
+                "class": "none",
+                "confidence": 1.0,
+                "box": [0, 0, 0, 0]
+            }
+    else:
+        highest_confidence_detection = {
+            "class": "none",
+            "confidence": 1.0,
+            "box": [0, 0, 0, 0]
+        }
+else:
+    # Find detection with highest confidence
+    detections = detection_result.get("detections", [])
+    highest_confidence_detection = max(detections, key=lambda x: x.get("confidence", 0)) if detections else {
+        "class": "none",
+        "confidence": 1.0,
+        "box": [0, 0, 0, 0]
+    }
+
+# Convert detection format to match protocol - flatten detection attributes
+detection_dict = {}
+
+# Handle different detection result formats
+if isinstance(highest_confidence_detection, dict):
+    # Copy all fields from the detection result
+    for key, value in highest_confidence_detection.items():
+        if key not in ["box", "id"]:  # Skip internal fields
+            detection_dict[key] = value
 
 detection_data = {
     "type": "imageDetection",
@@ -332,14 +322,12 @@ async def detect(websocket: WebSocket):
         }
     }
 
-    # Add session ID if available (generated by pipeline when Car+Frontal detected)
+    # Add session ID if available
     if session_id is not None:
         detection_data["sessionId"] = session_id
-        logger.debug(f"Added session_id to WebSocket response: {session_id}")
 
-    if highest_confidence_detection.get("class") != "none":
-        confidence = highest_confidence_detection.get("confidence", 0.0)
-        logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {confidence:.2f} using model {stream['modelName']}")
+    if highest_confidence_detection["class"] != "none":
+        logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}")
 
     # Log session ID if available
     if session_id:
@@ -347,7 +335,6 @@ async def detect(websocket: WebSocket):
     await websocket.send_json(detection_data)
     logger.debug(f"Sent detection data to client for camera {camera_id}")
-    logger.debug(f"Sent this detection data: {detection_data}")
 
     return persistent_data
 except Exception as e:
     logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True)
@@ -513,199 +500,6 @@ async def detect(websocket: WebSocket):
finally:
    logger.info(f"Snapshot reader thread for camera {camera_id} is exiting")
async def reconcile_subscriptions(desired_subscriptions, websocket):
"""
Declarative reconciliation: Compare desired vs current subscriptions and make changes
"""
logger.info(f"Reconciling subscriptions: {len(desired_subscriptions)} desired")
with streams_lock:
# Get current subscriptions
current_subscription_ids = set(streams.keys())
desired_subscription_ids = set(sub["subscriptionIdentifier"] for sub in desired_subscriptions)
# Find what to add and remove
to_add = desired_subscription_ids - current_subscription_ids
to_remove = current_subscription_ids - desired_subscription_ids
to_check_for_changes = current_subscription_ids & desired_subscription_ids
logger.info(f"Reconciliation: {len(to_add)} to add, {len(to_remove)} to remove, {len(to_check_for_changes)} to check for changes")
# Remove subscriptions that are no longer wanted
for subscription_id in to_remove:
await unsubscribe_internal(subscription_id)
# Check existing subscriptions for parameter changes
for subscription_id in to_check_for_changes:
desired_sub = next(sub for sub in desired_subscriptions if sub["subscriptionIdentifier"] == subscription_id)
current_stream = streams[subscription_id]
# Check if parameters changed
if has_subscription_changed(desired_sub, current_stream):
logger.info(f"Parameters changed for {subscription_id}, resubscribing")
await unsubscribe_internal(subscription_id)
await subscribe_internal(desired_sub, websocket)
# Add new subscriptions
for subscription_id in to_add:
desired_sub = next(sub for sub in desired_subscriptions if sub["subscriptionIdentifier"] == subscription_id)
await subscribe_internal(desired_sub, websocket)
def has_subscription_changed(desired_sub, current_stream):
"""Check if subscription parameters have changed"""
return (
desired_sub.get("rtspUrl") != current_stream.get("rtsp_url") or
desired_sub.get("snapshotUrl") != current_stream.get("snapshot_url") or
desired_sub.get("snapshotInterval") != current_stream.get("snapshot_interval") or
desired_sub.get("cropX1") != current_stream.get("cropX1") or
desired_sub.get("cropY1") != current_stream.get("cropY1") or
desired_sub.get("cropX2") != current_stream.get("cropX2") or
desired_sub.get("cropY2") != current_stream.get("cropY2") or
desired_sub.get("modelId") != current_stream.get("modelId") or
desired_sub.get("modelName") != current_stream.get("modelName")
)
async def subscribe_internal(subscription, websocket):
"""Internal subscription logic extracted from original subscribe handler"""
subscriptionIdentifier = subscription.get("subscriptionIdentifier")
rtsp_url = subscription.get("rtspUrl")
snapshot_url = subscription.get("snapshotUrl")
snapshot_interval = subscription.get("snapshotInterval")
model_url = subscription.get("modelUrl")
modelId = subscription.get("modelId")
modelName = subscription.get("modelName")
cropX1 = subscription.get("cropX1")
cropY1 = subscription.get("cropY1")
cropX2 = subscription.get("cropX2")
cropY2 = subscription.get("cropY2")
# Extract camera_id from subscriptionIdentifier
parts = subscriptionIdentifier.split(';')
if len(parts) != 2:
logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}")
return
display_identifier, camera_identifier = parts
camera_id = subscriptionIdentifier
# Load model if needed
if model_url:
with models_lock:
if (camera_id not in models) or (modelId not in models[camera_id]):
logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
extraction_dir = os.path.join("models", camera_identifier, str(modelId))
os.makedirs(extraction_dir, exist_ok=True)
# Handle model loading (same as original)
parsed = urlparse(model_url)
if parsed.scheme in ("http", "https"):
filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
local_mpta = os.path.join(extraction_dir, filename)
local_path = download_mpta(model_url, local_mpta)
if not local_path:
logger.error(f"Failed to download model from {model_url}")
return
model_tree = load_pipeline_from_zip(local_path, extraction_dir)
else:
if not os.path.exists(model_url):
logger.error(f"Model file not found: {model_url}")
return
model_tree = load_pipeline_from_zip(model_url, extraction_dir)
if model_tree is None:
logger.error(f"Failed to load model {modelId}")
return
if camera_id not in models:
models[camera_id] = {}
models[camera_id][modelId] = model_tree
# Create stream (same logic as original)
if camera_id and (rtsp_url or snapshot_url) and len(streams) < max_streams:
camera_url = snapshot_url if snapshot_url else rtsp_url
# Check if we already have a stream for this camera URL
shared_stream = camera_streams.get(camera_url)
if shared_stream:
# Reuse existing stream
buffer = shared_stream["buffer"]
stop_event = shared_stream["stop_event"]
thread = shared_stream["thread"]
mode = shared_stream["mode"]
shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
else:
# Create new stream
buffer = queue.Queue(maxsize=1)
stop_event = threading.Event()
if snapshot_url and snapshot_interval:
thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "snapshot"
shared_stream = {
"buffer": buffer, "thread": thread, "stop_event": stop_event,
"mode": mode, "url": snapshot_url, "snapshot_interval": snapshot_interval, "ref_count": 1
}
camera_streams[camera_url] = shared_stream
elif rtsp_url:
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logger.error(f"Failed to open RTSP stream for camera {camera_id}")
return
thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "rtsp"
shared_stream = {
"buffer": buffer, "thread": thread, "stop_event": stop_event,
"mode": mode, "url": rtsp_url, "cap": cap, "ref_count": 1
}
camera_streams[camera_url] = shared_stream
else:
logger.error(f"No valid URL provided for camera {camera_id}")
return
# Create stream info
stream_info = {
"buffer": buffer, "thread": thread, "stop_event": stop_event,
"modelId": modelId, "modelName": modelName, "subscriptionIdentifier": subscriptionIdentifier,
"cropX1": cropX1, "cropY1": cropY1, "cropX2": cropX2, "cropY2": cropY2,
"mode": mode, "camera_url": camera_url, "modelUrl": model_url
}
if mode == "snapshot":
stream_info["snapshot_url"] = snapshot_url
stream_info["snapshot_interval"] = snapshot_interval
elif mode == "rtsp":
stream_info["rtsp_url"] = rtsp_url
stream_info["cap"] = shared_stream["cap"]
streams[camera_id] = stream_info
subscription_to_camera[camera_id] = camera_url
logger.info(f"Subscribed to camera {camera_id}")
async def unsubscribe_internal(subscription_id):
"""Internal unsubscription logic"""
if subscription_id in streams:
stream = streams.pop(subscription_id)
camera_url = subscription_to_camera.pop(subscription_id, None)
if camera_url and camera_url in camera_streams:
shared_stream = camera_streams[camera_url]
shared_stream["ref_count"] -= 1
if shared_stream["ref_count"] <= 0:
shared_stream["stop_event"].set()
shared_stream["thread"].join()
if "cap" in shared_stream:
shared_stream["cap"].release()
del camera_streams[camera_url]
latest_frames.pop(subscription_id, None)
logger.info(f"Unsubscribed from camera {subscription_id}")
async def process_streams():
    logger.info("Started processing streams")
    try:
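The removed reconcile_subscriptions above is built on plain set arithmetic over subscription identifiers; a minimal, self-contained sketch of that pattern with toy data (no streams, models, or websockets):

    current = {"disp-1;cam-1", "disp-1;cam-2"}     # subscriptions the worker currently holds
    desired = {"disp-1;cam-2", "disp-2;cam-5"}     # subscriptions the backend wants

    to_add = desired - current                     # {"disp-2;cam-5"}  -> subscribe
    to_remove = current - desired                  # {"disp-1;cam-1"}  -> unsubscribe
    to_check_for_changes = current & desired       # {"disp-1;cam-2"}  -> resubscribe only if params changed

    print(sorted(to_add), sorted(to_remove), sorted(to_check_for_changes))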
@@ -773,10 +567,6 @@ async def detect(websocket: WebSocket):
     "modelId": stream["modelId"],
     "modelName": stream["modelName"],
     "online": True,
-    # Include all subscription parameters for proper change detection
-    "rtspUrl": stream.get("rtsp_url"),
-    "snapshotUrl": stream.get("snapshot_url"),
-    "snapshotInterval": stream.get("snapshot_interval"),
     **{k: v for k, v in get_crop_coords(stream).items() if v is not None}
 }
 for camera_id, stream in streams.items()
@@ -805,44 +595,29 @@ async def detect(websocket: WebSocket):
 data = json.loads(msg)
 msg_type = data.get("type")
 
-if msg_type == "setSubscriptionList":
-    # Declarative approach: Backend sends list of subscriptions this worker should have
-    desired_subscriptions = data.get("subscriptions", [])
-    logger.info(f"Received subscription list with {len(desired_subscriptions)} subscriptions")
-    await reconcile_subscriptions(desired_subscriptions, websocket)
-
-elif msg_type == "subscribe":
-    # Legacy support - convert single subscription to list
-    payload = data.get("payload", {})
-    await reconcile_subscriptions([payload], websocket)
-
-elif msg_type == "unsubscribe":
-    # Legacy support - remove subscription
+if msg_type == "subscribe":
     payload = data.get("payload", {})
     subscriptionIdentifier = payload.get("subscriptionIdentifier")
-    # Remove from current subscriptions and reconcile
-    current_subs = []
-    with streams_lock:
-        for camera_id, stream in streams.items():
-            if stream["subscriptionIdentifier"] != subscriptionIdentifier:
-                # Convert stream back to subscription format
-                current_subs.append({
-                    "subscriptionIdentifier": stream["subscriptionIdentifier"],
-                    "rtspUrl": stream.get("rtsp_url"),
-                    "snapshotUrl": stream.get("snapshot_url"),
-                    "snapshotInterval": stream.get("snapshot_interval"),
-                    "modelId": stream["modelId"],
-                    "modelName": stream["modelName"],
-                    "modelUrl": stream.get("modelUrl", ""),
-                    "cropX1": stream.get("cropX1"),
-                    "cropY1": stream.get("cropY1"),
-                    "cropX2": stream.get("cropX2"),
-                    "cropY2": stream.get("cropY2")
-                })
-    await reconcile_subscriptions(current_subs, websocket)
-elif msg_type == "old_subscribe_logic_removed":
+    rtsp_url = payload.get("rtspUrl")
+    snapshot_url = payload.get("snapshotUrl")
+    snapshot_interval = payload.get("snapshotInterval")
+    model_url = payload.get("modelUrl")
+    modelId = payload.get("modelId")
+    modelName = payload.get("modelName")
+    cropX1 = payload.get("cropX1")
+    cropY1 = payload.get("cropY1")
+    cropX2 = payload.get("cropX2")
+    cropY2 = payload.get("cropY2")
+
+    # Extract camera_id from subscriptionIdentifier (format: displayIdentifier;cameraIdentifier)
+    parts = subscriptionIdentifier.split(';')
+    if len(parts) != 2:
+        logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}")
+        continue
+
+    display_identifier, camera_identifier = parts
+    camera_id = subscriptionIdentifier  # Use full subscriptionIdentifier as camera_id for mapping
+
     if model_url:
         with models_lock:
             if (camera_id not in models) or (modelId not in models[camera_id]):
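For reference, the two WebSocket message shapes implied by this hunk (the declarative setSubscriptionList being removed and the per-camera subscribe being restored), sketched as Python dicts. Field names come from the diff; all values are illustrative placeholders:

    set_subscription_list_msg = {
        "type": "setSubscriptionList",
        "subscriptions": [
            {
                "subscriptionIdentifier": "display-001;camera-001",
                "rtspUrl": "rtsp://...",
                "modelUrl": "http://.../model.mpta",
                "modelId": 7,
                "modelName": "vehicle",
            },
        ],
    }

    subscribe_msg = {
        "type": "subscribe",
        "payload": {
            "subscriptionIdentifier": "display-001;camera-001",
            "snapshotUrl": "http://...",
            "snapshotInterval": 5000,
            "modelUrl": "http://.../model.mpta",
            "modelId": 7,
            "modelName": "vehicle",
            "cropX1": 0, "cropY1": 0, "cropX2": 640, "cropY2": 480,
        },
    }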
@@ -1038,10 +813,6 @@ async def detect(websocket: WebSocket):
     "modelId": stream["modelId"],
     "modelName": stream["modelName"],
     "online": True,
-    # Include all subscription parameters for proper change detection
-    "rtspUrl": stream.get("rtsp_url"),
-    "snapshotUrl": stream.get("snapshot_url"),
-    "snapshotInterval": stream.get("snapshot_interval"),
     **{k: v for k, v in get_crop_coords(stream).items() if v is not None}
 }
 for camera_id, stream in streams.items()

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -1,7 +1,13 @@
-torch
-torchvision
-ultralytics
-opencv-python
-scipy
-filterpy
-psycopg2-binary
+torch>=1.12.0,<2.1.0
+torchvision>=0.13.0,<0.16.0
+ultralytics>=8.3.0
+opencv-python>=4.6.0,<4.9.0
+scipy>=1.9.0,<1.12.0
+filterpy>=1.4.0,<1.5.0
+psycopg2-binary>=2.9.0,<2.10.0
+easydict
+loguru
+pyzmq
+gitpython
+gdown
+lap
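A small sketch for sanity-checking an installed environment against the new pins (it assumes the packaging helper library is available; the names and specifiers below are a subset of the lines above):

    from importlib.metadata import version, PackageNotFoundError
    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    pins = {
        "torch": ">=1.12.0,<2.1.0",
        "ultralytics": ">=8.3.0",
        "opencv-python": ">=4.6.0,<4.9.0",
    }

    for name, spec in pins.items():
        try:
            installed = Version(version(name))
        except PackageNotFoundError:
            print(f"{name}: not installed")
            continue
        ok = installed in SpecifierSet(spec)
        print(f"{name} {installed}: {'OK' if ok else 'outside ' + spec}")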


@@ -1,6 +1,5 @@
-fastapi
+fastapi[standard]
 uvicorn
 websockets
-fastapi[standard]
 redis
 urllib3<2.0.0


@@ -514,65 +514,6 @@ def resolve_field_mapping(value_template, branch_results, action_context):
    logger.error(f"Error resolving field mapping '{value_template}': {e}")
    return None
def validate_pipeline_execution(node, regions_dict):
"""
Pre-validate that all required branches will execute successfully before
committing to Redis actions and database records.
Returns:
- (True, []) if pipeline can execute completely
- (False, missing_branches) if some required branches won't execute
"""
# Get all branches that parallel actions are waiting for
required_branches = set()
for action in node.get("parallelActions", []):
if action.get("type") == "postgresql_update_combined":
wait_for_branches = action.get("waitForBranches", [])
required_branches.update(wait_for_branches)
if not required_branches:
# No parallel actions requiring specific branches
logger.debug("No parallel actions with waitForBranches - validation passes")
return True, []
logger.debug(f"Pre-validation: checking if required branches {list(required_branches)} will execute")
# Check each required branch
missing_branches = []
for branch in node.get("branches", []):
branch_id = branch["modelId"]
if branch_id not in required_branches:
continue # This branch is not required by parallel actions
# Check if this branch would be triggered
trigger_classes = branch.get("triggerClasses", [])
min_conf = branch.get("minConfidence", 0)
branch_triggered = False
for det_class in regions_dict:
det_confidence = regions_dict[det_class]["confidence"]
if (det_class in trigger_classes and det_confidence >= min_conf):
branch_triggered = True
logger.debug(f"Pre-validation: branch {branch_id} WILL be triggered by {det_class} (conf={det_confidence:.3f} >= {min_conf})")
break
if not branch_triggered:
missing_branches.append(branch_id)
logger.warning(f"Pre-validation: branch {branch_id} will NOT be triggered - no matching classes or insufficient confidence")
logger.debug(f" Required: {trigger_classes} with min_conf={min_conf}")
logger.debug(f" Available: {[(cls, regions_dict[cls]['confidence']) for cls in regions_dict]}")
if missing_branches:
logger.error(f"Pipeline pre-validation FAILED: required branches {missing_branches} will not execute")
return False, missing_branches
else:
logger.info(f"Pipeline pre-validation PASSED: all required branches {list(required_branches)} will execute")
return True, []
def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
    """
    Enhanced pipeline that supports:
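A toy invocation of the removed validate_pipeline_execution, showing the pass/fail contract its docstring describes. The node and regions_dict values are invented for illustration, and the call assumes the function above (plus a module-level logger) is in scope:

    import logging
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)  # the function above expects a module-level `logger`

    node = {
        "parallelActions": [
            {"type": "postgresql_update_combined", "waitForBranches": ["brand_cls", "bodytype_cls"]},
        ],
        "branches": [
            {"modelId": "brand_cls", "triggerClasses": ["Frontal"], "minConfidence": 0.5},
            {"modelId": "bodytype_cls", "triggerClasses": ["Frontal"], "minConfidence": 0.5},
        ],
    }
    regions_dict = {"Frontal": {"confidence": 0.42}}  # below minConfidence, so both branches fail validation

    valid, missing = validate_pipeline_execution(node, regions_dict)
    print(valid, missing)  # -> False ['brand_cls', 'bodytype_cls']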
@@ -705,14 +646,6 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 else:
     logger.debug("No multi-class validation - proceeding with all detections")
 
-# ─── Pre-validate pipeline execution ────────────────────────
-pipeline_valid, missing_branches = validate_pipeline_execution(node, regions_dict)
-if not pipeline_valid:
-    logger.error(f"Pipeline execution validation FAILED - required branches {missing_branches} cannot execute")
-    logger.error("Aborting pipeline: no Redis actions or database records will be created")
-    return (None, None) if return_bbox else None
-
 # ─── Execute actions with region information ────────────────
 detection_result = {
     "detections": all_detections,
@@ -853,11 +786,9 @@ def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
 primary_detection = max(all_detections, key=lambda x: x["confidence"])
 primary_bbox = primary_detection["bbox"]
 
-# Add branch results and session_id to primary detection for compatibility
+# Add branch results to primary detection for compatibility
 if "branch_results" in detection_result:
     primary_detection["branch_results"] = detection_result["branch_results"]
-if "session_id" in detection_result:
-    primary_detection["session_id"] = detection_result["session_id"]
 
 return (primary_detection, primary_bbox) if return_bbox else primary_detection
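Finally, a hedged usage sketch of the run_pipeline entry point as it stands after this change. The frame is stubbed, model_tree is assumed to come from load_pipeline_from_zip (hypothetical path), and the context keys match those passed by app.py above:

    import numpy as np

    # Illustrative only: in the worker the frame comes from the camera stream buffer.
    frame = np.zeros((480, 640, 3), dtype=np.uint8)
    model_tree = load_pipeline_from_zip("models/example.mpta", "models/example")  # hypothetical paths

    pipeline_context = {
        "camera_id": "display-001;camera-001",  # hypothetical subscription identifier
        "display_id": "display-001",
        "session_id": None,                     # may be None if the display has no active session
    }

    # With return_bbox=True the pipeline returns (primary_detection, primary_bbox);
    # primary_detection may still be None on some code paths, so guard before use.
    primary_detection, primary_bbox = run_pipeline(frame, model_tree, return_bbox=True, context=pipeline_context)
    if primary_detection is not None:
        print(primary_detection.get("class"), primary_detection.get("confidence"), primary_bbox)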