refactor: move old code to archive

This commit is contained in:
ziesorx 2025-09-22 17:16:32 +07:00
parent dced713ac5
commit b2d5726b7a
3 changed files with 1912 additions and 0 deletions

903
archive/app.py Normal file
View file

@ -0,0 +1,903 @@
from typing import Any, Dict
import os
import json
import time
import queue
import torch
import cv2
import numpy as np
import base64
import logging
import threading
import requests
import asyncio
import psutil
import zipfile
from urllib.parse import urlparse
from fastapi import FastAPI, WebSocket, HTTPException
from fastapi.websockets import WebSocketDisconnect
from fastapi.responses import Response
from websockets.exceptions import ConnectionClosedError
from ultralytics import YOLO
# Import shared pipeline functions
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
app = FastAPI()
# Global dictionaries to keep track of models and streams
# "models" now holds a nested dict: { camera_id: { modelId: model_tree } }
models: Dict[str, Dict[str, Any]] = {}
streams: Dict[str, Dict[str, Any]] = {}
# Store session IDs per display
session_ids: Dict[str, int] = {}
# Track shared camera streams by camera URL
camera_streams: Dict[str, Dict[str, Any]] = {}
# Map subscriptions to their camera URL
subscription_to_camera: Dict[str, str] = {}
# Store latest frames for REST API access (separate from processing buffer)
latest_frames: Dict[str, Any] = {}
with open("config.json", "r") as f:
config = json.load(f)
poll_interval = config.get("poll_interval_ms", 100)
reconnect_interval = config.get("reconnect_interval_sec", 5)
TARGET_FPS = config.get("target_fps", 10)
poll_interval = 1000 / TARGET_FPS
logging.info(f"Poll interval: {poll_interval}ms")
max_streams = config.get("max_streams", 5)
max_retries = config.get("max_retries", 3)
# Configure logging
logging.basicConfig(
level=logging.INFO, # Set to INFO level for less verbose output
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[
logging.FileHandler("detector_worker.log"), # Write logs to a file
logging.StreamHandler() # Also output to console
]
)
# Create a logger specifically for this application
logger = logging.getLogger("detector_worker")
logger.setLevel(logging.DEBUG) # Set app-specific logger to DEBUG level
# Ensure all other libraries (including root) use at least INFO level
logging.getLogger().setLevel(logging.INFO)
logger.info("Starting detector worker application")
logger.info(f"Configuration: Target FPS: {TARGET_FPS}, Max streams: {max_streams}, Max retries: {max_retries}")
# Ensure the models directory exists
os.makedirs("models", exist_ok=True)
logger.info("Ensured models directory exists")
# Constants for heartbeat and timeouts
HEARTBEAT_INTERVAL = 2 # seconds
WORKER_TIMEOUT_MS = 10000
logger.debug(f"Heartbeat interval set to {HEARTBEAT_INTERVAL} seconds")
# Locks for thread-safe operations
streams_lock = threading.Lock()
models_lock = threading.Lock()
logger.debug("Initialized thread locks")
# Add helper to download mpta ZIP file from a remote URL
def download_mpta(url: str, dest_path: str) -> str:
try:
logger.info(f"Starting download of model from {url} to {dest_path}")
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
response = requests.get(url, stream=True)
if response.status_code == 200:
file_size = int(response.headers.get('content-length', 0))
logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")
downloaded = 0
with open(dest_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
downloaded += len(chunk)
if file_size > 0 and downloaded % (file_size // 10) < 8192: # Log approximately every 10%
logger.debug(f"Download progress: {downloaded/file_size*100:.1f}%")
logger.info(f"Successfully downloaded mpta file from {url} to {dest_path}")
return dest_path
else:
logger.error(f"Failed to download mpta file (status code {response.status_code}): {response.text}")
return None
except Exception as e:
logger.error(f"Exception downloading mpta file from {url}: {str(e)}", exc_info=True)
return None
# Add helper to fetch snapshot image from HTTP/HTTPS URL
def fetch_snapshot(url: str):
try:
from requests.auth import HTTPBasicAuth, HTTPDigestAuth
# Parse URL to extract credentials
parsed = urlparse(url)
# Prepare headers - some cameras require User-Agent
headers = {
'User-Agent': 'Mozilla/5.0 (compatible; DetectorWorker/1.0)'
}
# Reconstruct URL without credentials
clean_url = f"{parsed.scheme}://{parsed.hostname}"
if parsed.port:
clean_url += f":{parsed.port}"
clean_url += parsed.path
if parsed.query:
clean_url += f"?{parsed.query}"
auth = None
if parsed.username and parsed.password:
# Try HTTP Digest authentication first (common for IP cameras)
try:
auth = HTTPDigestAuth(parsed.username, parsed.password)
response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
if response.status_code == 200:
logger.debug(f"Successfully authenticated using HTTP Digest for {clean_url}")
elif response.status_code == 401:
# If Digest fails, try Basic auth
logger.debug(f"HTTP Digest failed, trying Basic auth for {clean_url}")
auth = HTTPBasicAuth(parsed.username, parsed.password)
response = requests.get(clean_url, auth=auth, headers=headers, timeout=10)
if response.status_code == 200:
logger.debug(f"Successfully authenticated using HTTP Basic for {clean_url}")
except Exception as auth_error:
logger.debug(f"Authentication setup error: {auth_error}")
# Fallback to original URL with embedded credentials
response = requests.get(url, headers=headers, timeout=10)
else:
# No credentials in URL, make request as-is
response = requests.get(url, headers=headers, timeout=10)
if response.status_code == 200:
# Convert response content to numpy array
nparr = np.frombuffer(response.content, np.uint8)
# Decode image
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if frame is not None:
logger.debug(f"Successfully fetched snapshot from {clean_url}, shape: {frame.shape}")
return frame
else:
logger.error(f"Failed to decode image from snapshot URL: {clean_url}")
return None
else:
logger.error(f"Failed to fetch snapshot (status code {response.status_code}): {clean_url}")
return None
except Exception as e:
logger.error(f"Exception fetching snapshot from {url}: {str(e)}")
return None
# Helper to get crop coordinates from stream
def get_crop_coords(stream):
return {
"cropX1": stream.get("cropX1"),
"cropY1": stream.get("cropY1"),
"cropX2": stream.get("cropX2"),
"cropY2": stream.get("cropY2")
}
####################################################
# REST API endpoint for image retrieval
####################################################
@app.get("/camera/{camera_id}/image")
async def get_camera_image(camera_id: str):
"""
Get the current frame from a camera as JPEG image
"""
try:
# URL decode the camera_id to handle encoded characters like %3B for semicolon
from urllib.parse import unquote
original_camera_id = camera_id
camera_id = unquote(camera_id)
logger.debug(f"REST API request: original='{original_camera_id}', decoded='{camera_id}'")
with streams_lock:
if camera_id not in streams:
logger.warning(f"Camera ID '{camera_id}' not found in streams. Current streams: {list(streams.keys())}")
raise HTTPException(status_code=404, detail=f"Camera {camera_id} not found or not active")
# Check if we have a cached frame for this camera
if camera_id not in latest_frames:
logger.warning(f"No cached frame available for camera '{camera_id}'.")
raise HTTPException(status_code=404, detail=f"No frame available for camera {camera_id}")
frame = latest_frames[camera_id]
logger.debug(f"Retrieved cached frame for camera '{camera_id}', frame shape: {frame.shape}")
# Encode frame as JPEG
success, buffer_img = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
if not success:
raise HTTPException(status_code=500, detail="Failed to encode image as JPEG")
# Return image as binary response
return Response(content=buffer_img.tobytes(), media_type="image/jpeg")
except HTTPException:
raise
except Exception as e:
logger.error(f"Error retrieving image for camera {camera_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
####################################################
# Detection and frame processing functions
####################################################
@app.websocket("/")
async def detect(websocket: WebSocket):
logger.info("WebSocket connection accepted")
persistent_data_dict = {}
async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data):
try:
# Apply crop if specified
cropped_frame = frame
if all(coord is not None for coord in [stream.get("cropX1"), stream.get("cropY1"), stream.get("cropX2"), stream.get("cropY2")]):
cropX1, cropY1, cropX2, cropY2 = stream["cropX1"], stream["cropY1"], stream["cropX2"], stream["cropY2"]
cropped_frame = frame[cropY1:cropY2, cropX1:cropX2]
logger.debug(f"Applied crop coordinates ({cropX1}, {cropY1}, {cropX2}, {cropY2}) to frame for camera {camera_id}")
logger.debug(f"Processing frame for camera {camera_id} with model {stream['modelId']}")
start_time = time.time()
# Extract display identifier for session ID lookup
subscription_parts = stream["subscriptionIdentifier"].split(';')
display_identifier = subscription_parts[0] if subscription_parts else None
session_id = session_ids.get(display_identifier) if display_identifier else None
# Create context for pipeline execution
pipeline_context = {
"camera_id": camera_id,
"display_id": display_identifier,
"session_id": session_id
}
detection_result = run_pipeline(cropped_frame, model_tree, context=pipeline_context)
process_time = (time.time() - start_time) * 1000
logger.debug(f"Detection for camera {camera_id} completed in {process_time:.2f}ms")
# Log the raw detection result for debugging
logger.debug(f"Raw detection result for camera {camera_id}:\n{json.dumps(detection_result, indent=2, default=str)}")
# Direct class result (no detections/classifications structure)
if detection_result and isinstance(detection_result, dict) and "class" in detection_result and "confidence" in detection_result:
highest_confidence_detection = {
"class": detection_result.get("class", "none"),
"confidence": detection_result.get("confidence", 1.0),
"box": [0, 0, 0, 0] # Empty bounding box for classifications
}
# Handle case when no detections found or result is empty
elif not detection_result or not detection_result.get("detections"):
# Check if we have classification results
if detection_result and detection_result.get("classifications"):
# Get the highest confidence classification
classifications = detection_result.get("classifications", [])
highest_confidence_class = max(classifications, key=lambda x: x.get("confidence", 0)) if classifications else None
if highest_confidence_class:
highest_confidence_detection = {
"class": highest_confidence_class.get("class", "none"),
"confidence": highest_confidence_class.get("confidence", 1.0),
"box": [0, 0, 0, 0] # Empty bounding box for classifications
}
else:
highest_confidence_detection = {
"class": "none",
"confidence": 1.0,
"box": [0, 0, 0, 0]
}
else:
highest_confidence_detection = {
"class": "none",
"confidence": 1.0,
"box": [0, 0, 0, 0]
}
else:
# Find detection with highest confidence
detections = detection_result.get("detections", [])
highest_confidence_detection = max(detections, key=lambda x: x.get("confidence", 0)) if detections else {
"class": "none",
"confidence": 1.0,
"box": [0, 0, 0, 0]
}
# Convert detection format to match protocol - flatten detection attributes
detection_dict = {}
# Handle different detection result formats
if isinstance(highest_confidence_detection, dict):
# Copy all fields from the detection result
for key, value in highest_confidence_detection.items():
if key not in ["box", "id"]: # Skip internal fields
detection_dict[key] = value
detection_data = {
"type": "imageDetection",
"subscriptionIdentifier": stream["subscriptionIdentifier"],
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ", time.gmtime()),
"data": {
"detection": detection_dict,
"modelId": stream["modelId"],
"modelName": stream["modelName"]
}
}
# Add session ID if available
if session_id is not None:
detection_data["sessionId"] = session_id
if highest_confidence_detection["class"] != "none":
logger.info(f"Camera {camera_id}: Detected {highest_confidence_detection['class']} with confidence {highest_confidence_detection['confidence']:.2f} using model {stream['modelName']}")
# Log session ID if available
if session_id:
logger.debug(f"Detection associated with session ID: {session_id}")
await websocket.send_json(detection_data)
logger.debug(f"Sent detection data to client for camera {camera_id}")
return persistent_data
except Exception as e:
logger.error(f"Error in handle_detection for camera {camera_id}: {str(e)}", exc_info=True)
return persistent_data
def frame_reader(camera_id, cap, buffer, stop_event):
retries = 0
logger.info(f"Starting frame reader thread for camera {camera_id}")
frame_count = 0
last_log_time = time.time()
try:
# Log initial camera status and properties
if cap.isOpened():
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
logger.info(f"Camera {camera_id} opened successfully with resolution {width}x{height}, FPS: {fps}")
else:
logger.error(f"Camera {camera_id} failed to open initially")
while not stop_event.is_set():
try:
if not cap.isOpened():
logger.error(f"Camera {camera_id} is not open before trying to read")
# Attempt to reopen
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
time.sleep(reconnect_interval)
continue
logger.debug(f"Attempting to read frame from camera {camera_id}")
ret, frame = cap.read()
if not ret:
logger.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}")
cap.release()
time.sleep(reconnect_interval)
retries += 1
if retries > max_retries and max_retries != -1:
logger.error(f"Max retries reached for camera: {camera_id}, stopping frame reader")
break
# Re-open
logger.info(f"Attempting to reopen RTSP stream for camera: {camera_id}")
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
if not cap.isOpened():
logger.error(f"Failed to reopen RTSP stream for camera: {camera_id}")
continue
logger.info(f"Successfully reopened RTSP stream for camera: {camera_id}")
continue
# Successfully read a frame
frame_count += 1
current_time = time.time()
# Log frame stats every 5 seconds
if current_time - last_log_time > 5:
logger.info(f"Camera {camera_id}: Read {frame_count} frames in the last {current_time - last_log_time:.1f} seconds")
frame_count = 0
last_log_time = current_time
logger.debug(f"Successfully read frame from camera {camera_id}, shape: {frame.shape}")
retries = 0
# Overwrite old frame if buffer is full
if not buffer.empty():
try:
buffer.get_nowait()
logger.debug(f"[frame_reader] Removed old frame from buffer for camera {camera_id}")
except queue.Empty:
pass
buffer.put(frame)
logger.debug(f"[frame_reader] Added new frame to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}")
# Short sleep to avoid CPU overuse
time.sleep(0.01)
except cv2.error as e:
logger.error(f"OpenCV error for camera {camera_id}: {e}", exc_info=True)
cap.release()
time.sleep(reconnect_interval)
retries += 1
if retries > max_retries and max_retries != -1:
logger.error(f"Max retries reached after OpenCV error for camera {camera_id}")
break
logger.info(f"Attempting to reopen RTSP stream after OpenCV error for camera: {camera_id}")
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
if not cap.isOpened():
logger.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error")
continue
logger.info(f"Successfully reopened RTSP stream after OpenCV error for camera: {camera_id}")
except Exception as e:
logger.error(f"Unexpected error for camera {camera_id}: {str(e)}", exc_info=True)
cap.release()
break
except Exception as e:
logger.error(f"Error in frame_reader thread for camera {camera_id}: {str(e)}", exc_info=True)
finally:
logger.info(f"Frame reader thread for camera {camera_id} is exiting")
if cap and cap.isOpened():
cap.release()
def snapshot_reader(camera_id, snapshot_url, snapshot_interval, buffer, stop_event):
"""Frame reader that fetches snapshots from HTTP/HTTPS URL at specified intervals"""
retries = 0
logger.info(f"Starting snapshot reader thread for camera {camera_id} from {snapshot_url}")
frame_count = 0
last_log_time = time.time()
try:
interval_seconds = snapshot_interval / 1000.0 # Convert milliseconds to seconds
logger.info(f"Snapshot interval for camera {camera_id}: {interval_seconds}s")
while not stop_event.is_set():
try:
start_time = time.time()
frame = fetch_snapshot(snapshot_url)
if frame is None:
logger.warning(f"Failed to fetch snapshot for camera: {camera_id}, retry {retries+1}/{max_retries}")
retries += 1
if retries > max_retries and max_retries != -1:
logger.error(f"Max retries reached for snapshot camera: {camera_id}, stopping reader")
break
time.sleep(min(interval_seconds, reconnect_interval))
continue
# Successfully fetched a frame
frame_count += 1
current_time = time.time()
# Log frame stats every 5 seconds
if current_time - last_log_time > 5:
logger.info(f"Camera {camera_id}: Fetched {frame_count} snapshots in the last {current_time - last_log_time:.1f} seconds")
frame_count = 0
last_log_time = current_time
logger.debug(f"Successfully fetched snapshot from camera {camera_id}, shape: {frame.shape}")
retries = 0
# Overwrite old frame if buffer is full
if not buffer.empty():
try:
buffer.get_nowait()
logger.debug(f"[snapshot_reader] Removed old snapshot from buffer for camera {camera_id}")
except queue.Empty:
pass
buffer.put(frame)
logger.debug(f"[snapshot_reader] Added new snapshot to buffer for camera {camera_id}. Buffer size: {buffer.qsize()}")
# Wait for the specified interval
elapsed = time.time() - start_time
sleep_time = max(interval_seconds - elapsed, 0)
if sleep_time > 0:
time.sleep(sleep_time)
except Exception as e:
logger.error(f"Unexpected error fetching snapshot for camera {camera_id}: {str(e)}", exc_info=True)
retries += 1
if retries > max_retries and max_retries != -1:
logger.error(f"Max retries reached after error for snapshot camera {camera_id}")
break
time.sleep(min(interval_seconds, reconnect_interval))
except Exception as e:
logger.error(f"Error in snapshot_reader thread for camera {camera_id}: {str(e)}", exc_info=True)
finally:
logger.info(f"Snapshot reader thread for camera {camera_id} is exiting")
async def process_streams():
logger.info("Started processing streams")
try:
while True:
start_time = time.time()
with streams_lock:
current_streams = list(streams.items())
if current_streams:
logger.debug(f"Processing {len(current_streams)} active streams")
else:
logger.debug("No active streams to process")
for camera_id, stream in current_streams:
buffer = stream["buffer"]
if buffer.empty():
logger.debug(f"Frame buffer is empty for camera {camera_id}")
continue
logger.debug(f"Got frame from buffer for camera {camera_id}")
frame = buffer.get()
# Cache the frame for REST API access
latest_frames[camera_id] = frame.copy()
logger.debug(f"Cached frame for REST API access for camera {camera_id}")
with models_lock:
model_tree = models.get(camera_id, {}).get(stream["modelId"])
if not model_tree:
logger.warning(f"Model not found for camera {camera_id}, modelId {stream['modelId']}")
continue
logger.debug(f"Found model tree for camera {camera_id}, modelId {stream['modelId']}")
key = (camera_id, stream["modelId"])
persistent_data = persistent_data_dict.get(key, {})
logger.debug(f"Starting detection for camera {camera_id} with modelId {stream['modelId']}")
updated_persistent_data = await handle_detection(
camera_id, stream, frame, websocket, model_tree, persistent_data
)
persistent_data_dict[key] = updated_persistent_data
elapsed_time = (time.time() - start_time) * 1000 # ms
sleep_time = max(poll_interval - elapsed_time, 0)
logger.debug(f"Frame processing cycle: {elapsed_time:.2f}ms, sleeping for: {sleep_time:.2f}ms")
await asyncio.sleep(sleep_time / 1000.0)
except asyncio.CancelledError:
logger.info("Stream processing task cancelled")
except Exception as e:
logger.error(f"Error in process_streams: {str(e)}", exc_info=True)
async def send_heartbeat():
while True:
try:
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
else:
gpu_usage = None
gpu_memory_usage = None
camera_connections = [
{
"subscriptionIdentifier": stream["subscriptionIdentifier"],
"modelId": stream["modelId"],
"modelName": stream["modelName"],
"online": True,
**{k: v for k, v in get_crop_coords(stream).items() if v is not None}
}
for camera_id, stream in streams.items()
]
state_report = {
"type": "stateReport",
"cpuUsage": cpu_usage,
"memoryUsage": memory_usage,
"gpuUsage": gpu_usage,
"gpuMemoryUsage": gpu_memory_usage,
"cameraConnections": camera_connections
}
await websocket.send_text(json.dumps(state_report))
logger.debug(f"Sent stateReport as heartbeat: CPU {cpu_usage:.1f}%, Memory {memory_usage:.1f}%, {len(camera_connections)} active cameras")
await asyncio.sleep(HEARTBEAT_INTERVAL)
except Exception as e:
logger.error(f"Error sending stateReport heartbeat: {e}")
break
async def on_message():
while True:
try:
msg = await websocket.receive_text()
logger.debug(f"Received message: {msg}")
data = json.loads(msg)
msg_type = data.get("type")
if msg_type == "subscribe":
payload = data.get("payload", {})
subscriptionIdentifier = payload.get("subscriptionIdentifier")
rtsp_url = payload.get("rtspUrl")
snapshot_url = payload.get("snapshotUrl")
snapshot_interval = payload.get("snapshotInterval")
model_url = payload.get("modelUrl")
modelId = payload.get("modelId")
modelName = payload.get("modelName")
cropX1 = payload.get("cropX1")
cropY1 = payload.get("cropY1")
cropX2 = payload.get("cropX2")
cropY2 = payload.get("cropY2")
# Extract camera_id from subscriptionIdentifier (format: displayIdentifier;cameraIdentifier)
parts = subscriptionIdentifier.split(';')
if len(parts) != 2:
logger.error(f"Invalid subscriptionIdentifier format: {subscriptionIdentifier}")
continue
display_identifier, camera_identifier = parts
camera_id = subscriptionIdentifier # Use full subscriptionIdentifier as camera_id for mapping
if model_url:
with models_lock:
if (camera_id not in models) or (modelId not in models[camera_id]):
logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
extraction_dir = os.path.join("models", camera_identifier, str(modelId))
os.makedirs(extraction_dir, exist_ok=True)
# If model_url is remote, download it first.
parsed = urlparse(model_url)
if parsed.scheme in ("http", "https"):
logger.info(f"Downloading remote .mpta file from {model_url}")
filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
local_mpta = os.path.join(extraction_dir, filename)
logger.debug(f"Download destination: {local_mpta}")
local_path = download_mpta(model_url, local_mpta)
if not local_path:
logger.error(f"Failed to download the remote .mpta file from {model_url}")
error_response = {
"type": "error",
"subscriptionIdentifier": subscriptionIdentifier,
"error": f"Failed to download model from {model_url}"
}
await websocket.send_json(error_response)
continue
model_tree = load_pipeline_from_zip(local_path, extraction_dir)
else:
logger.info(f"Loading local .mpta file from {model_url}")
# Check if file exists before attempting to load
if not os.path.exists(model_url):
logger.error(f"Local .mpta file not found: {model_url}")
logger.debug(f"Current working directory: {os.getcwd()}")
error_response = {
"type": "error",
"subscriptionIdentifier": subscriptionIdentifier,
"error": f"Model file not found: {model_url}"
}
await websocket.send_json(error_response)
continue
model_tree = load_pipeline_from_zip(model_url, extraction_dir)
if model_tree is None:
logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}")
error_response = {
"type": "error",
"subscriptionIdentifier": subscriptionIdentifier,
"error": f"Failed to load model {modelId}"
}
await websocket.send_json(error_response)
continue
if camera_id not in models:
models[camera_id] = {}
models[camera_id][modelId] = model_tree
logger.info(f"Successfully loaded model {modelId} for camera {camera_id}")
logger.debug(f"Model extraction directory: {extraction_dir}")
if camera_id and (rtsp_url or snapshot_url):
with streams_lock:
# Determine camera URL for shared stream management
camera_url = snapshot_url if snapshot_url else rtsp_url
if camera_id not in streams and len(streams) < max_streams:
# Check if we already have a stream for this camera URL
shared_stream = camera_streams.get(camera_url)
if shared_stream:
# Reuse existing stream
logger.info(f"Reusing existing stream for camera URL: {camera_url}")
buffer = shared_stream["buffer"]
stop_event = shared_stream["stop_event"]
thread = shared_stream["thread"]
mode = shared_stream["mode"]
# Increment reference count
shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
else:
# Create new stream
buffer = queue.Queue(maxsize=1)
stop_event = threading.Event()
if snapshot_url and snapshot_interval:
logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "snapshot"
# Store shared stream info
shared_stream = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"mode": mode,
"url": snapshot_url,
"snapshot_interval": snapshot_interval,
"ref_count": 1
}
camera_streams[camera_url] = shared_stream
elif rtsp_url:
logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}")
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logger.error(f"Failed to open RTSP stream for camera {camera_id}")
continue
thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
thread.daemon = True
thread.start()
mode = "rtsp"
# Store shared stream info
shared_stream = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"mode": mode,
"url": rtsp_url,
"cap": cap,
"ref_count": 1
}
camera_streams[camera_url] = shared_stream
else:
logger.error(f"No valid URL provided for camera {camera_id}")
continue
# Create stream info for this subscription
stream_info = {
"buffer": buffer,
"thread": thread,
"stop_event": stop_event,
"modelId": modelId,
"modelName": modelName,
"subscriptionIdentifier": subscriptionIdentifier,
"cropX1": cropX1,
"cropY1": cropY1,
"cropX2": cropX2,
"cropY2": cropY2,
"mode": mode,
"camera_url": camera_url
}
if mode == "snapshot":
stream_info["snapshot_url"] = snapshot_url
stream_info["snapshot_interval"] = snapshot_interval
elif mode == "rtsp":
stream_info["rtsp_url"] = rtsp_url
stream_info["cap"] = shared_stream["cap"]
streams[camera_id] = stream_info
subscription_to_camera[camera_id] = camera_url
elif camera_id and camera_id in streams:
# If already subscribed, unsubscribe first
logger.info(f"Resubscribing to camera {camera_id}")
# Note: Keep models in memory for reuse across subscriptions
elif msg_type == "unsubscribe":
payload = data.get("payload", {})
subscriptionIdentifier = payload.get("subscriptionIdentifier")
camera_id = subscriptionIdentifier
with streams_lock:
if camera_id and camera_id in streams:
stream = streams.pop(camera_id)
camera_url = subscription_to_camera.pop(camera_id, None)
if camera_url and camera_url in camera_streams:
shared_stream = camera_streams[camera_url]
shared_stream["ref_count"] -= 1
# If no more references, stop the shared stream
if shared_stream["ref_count"] <= 0:
logger.info(f"Stopping shared stream for camera URL: {camera_url}")
shared_stream["stop_event"].set()
shared_stream["thread"].join()
if "cap" in shared_stream:
shared_stream["cap"].release()
del camera_streams[camera_url]
else:
logger.info(f"Shared stream for {camera_url} still has {shared_stream['ref_count']} references")
# Clean up cached frame
latest_frames.pop(camera_id, None)
logger.info(f"Unsubscribed from camera {camera_id}")
# Note: Keep models in memory for potential reuse
elif msg_type == "requestState":
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.utilization() if hasattr(torch.cuda, 'utilization') else None
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
else:
gpu_usage = None
gpu_memory_usage = None
camera_connections = [
{
"subscriptionIdentifier": stream["subscriptionIdentifier"],
"modelId": stream["modelId"],
"modelName": stream["modelName"],
"online": True,
**{k: v for k, v in get_crop_coords(stream).items() if v is not None}
}
for camera_id, stream in streams.items()
]
state_report = {
"type": "stateReport",
"cpuUsage": cpu_usage,
"memoryUsage": memory_usage,
"gpuUsage": gpu_usage,
"gpuMemoryUsage": gpu_memory_usage,
"cameraConnections": camera_connections
}
await websocket.send_text(json.dumps(state_report))
elif msg_type == "setSessionId":
payload = data.get("payload", {})
display_identifier = payload.get("displayIdentifier")
session_id = payload.get("sessionId")
if display_identifier:
# Store session ID for this display
if session_id is None:
session_ids.pop(display_identifier, None)
logger.info(f"Cleared session ID for display {display_identifier}")
else:
session_ids[display_identifier] = session_id
logger.info(f"Set session ID {session_id} for display {display_identifier}")
elif msg_type == "patchSession":
session_id = data.get("sessionId")
patch_data = data.get("data", {})
# For now, just acknowledge the patch - actual implementation depends on backend requirements
response = {
"type": "patchSessionResult",
"payload": {
"sessionId": session_id,
"success": True,
"message": "Session patch acknowledged"
}
}
await websocket.send_json(response)
logger.info(f"Acknowledged patch for session {session_id}")
else:
logger.error(f"Unknown message type: {msg_type}")
except json.JSONDecodeError:
logger.error("Received invalid JSON message")
except (WebSocketDisconnect, ConnectionClosedError) as e:
logger.warning(f"WebSocket disconnected: {e}")
break
except Exception as e:
logger.error(f"Error handling message: {e}")
break
try:
await websocket.accept()
stream_task = asyncio.create_task(process_streams())
heartbeat_task = asyncio.create_task(send_heartbeat())
message_task = asyncio.create_task(on_message())
await asyncio.gather(heartbeat_task, message_task)
except Exception as e:
logger.error(f"Error in detect websocket: {e}")
finally:
stream_task.cancel()
await stream_task
with streams_lock:
# Clean up shared camera streams
for camera_url, shared_stream in camera_streams.items():
shared_stream["stop_event"].set()
shared_stream["thread"].join()
if "cap" in shared_stream:
shared_stream["cap"].release()
while not shared_stream["buffer"].empty():
try:
shared_stream["buffer"].get_nowait()
except queue.Empty:
pass
logger.info(f"Released shared camera stream for {camera_url}")
streams.clear()
camera_streams.clear()
subscription_to_camera.clear()
with models_lock:
models.clear()
latest_frames.clear()
session_ids.clear()
logger.info("WebSocket connection closed")

View file

@ -0,0 +1,211 @@
import psycopg2
import psycopg2.extras
from typing import Optional, Dict, Any
import logging
import uuid
logger = logging.getLogger(__name__)
class DatabaseManager:
def __init__(self, config: Dict[str, Any]):
self.config = config
self.connection: Optional[psycopg2.extensions.connection] = None
def connect(self) -> bool:
try:
self.connection = psycopg2.connect(
host=self.config['host'],
port=self.config['port'],
database=self.config['database'],
user=self.config['username'],
password=self.config['password']
)
logger.info("PostgreSQL connection established successfully")
return True
except Exception as e:
logger.error(f"Failed to connect to PostgreSQL: {e}")
return False
def disconnect(self):
if self.connection:
self.connection.close()
self.connection = None
logger.info("PostgreSQL connection closed")
def is_connected(self) -> bool:
try:
if self.connection and not self.connection.closed:
cur = self.connection.cursor()
cur.execute("SELECT 1")
cur.fetchone()
cur.close()
return True
except:
pass
return False
def update_car_info(self, session_id: str, brand: str, model: str, body_type: str) -> bool:
if not self.is_connected():
if not self.connect():
return False
try:
cur = self.connection.cursor()
query = """
INSERT INTO car_frontal_info (session_id, car_brand, car_model, car_body_type, updated_at)
VALUES (%s, %s, %s, %s, NOW())
ON CONFLICT (session_id)
DO UPDATE SET
car_brand = EXCLUDED.car_brand,
car_model = EXCLUDED.car_model,
car_body_type = EXCLUDED.car_body_type,
updated_at = NOW()
"""
cur.execute(query, (session_id, brand, model, body_type))
self.connection.commit()
cur.close()
logger.info(f"Updated car info for session {session_id}: {brand} {model} ({body_type})")
return True
except Exception as e:
logger.error(f"Failed to update car info: {e}")
if self.connection:
self.connection.rollback()
return False
def execute_update(self, table: str, key_field: str, key_value: str, fields: Dict[str, str]) -> bool:
if not self.is_connected():
if not self.connect():
return False
try:
cur = self.connection.cursor()
# Build the UPDATE query dynamically
set_clauses = []
values = []
for field, value in fields.items():
if value == "NOW()":
set_clauses.append(f"{field} = NOW()")
else:
set_clauses.append(f"{field} = %s")
values.append(value)
# Add schema prefix if table doesn't already have it
full_table_name = table if '.' in table else f"gas_station_1.{table}"
query = f"""
INSERT INTO {full_table_name} ({key_field}, {', '.join(fields.keys())})
VALUES (%s, {', '.join(['%s'] * len(fields))})
ON CONFLICT ({key_field})
DO UPDATE SET {', '.join(set_clauses)}
"""
# Add key_value to the beginning of values list
all_values = [key_value] + list(fields.values()) + values
cur.execute(query, all_values)
self.connection.commit()
cur.close()
logger.info(f"Updated {table} for {key_field}={key_value}")
return True
except Exception as e:
logger.error(f"Failed to execute update on {table}: {e}")
if self.connection:
self.connection.rollback()
return False
def create_car_frontal_info_table(self) -> bool:
"""Create the car_frontal_info table in gas_station_1 schema if it doesn't exist."""
if not self.is_connected():
if not self.connect():
return False
try:
cur = self.connection.cursor()
# Create schema if it doesn't exist
cur.execute("CREATE SCHEMA IF NOT EXISTS gas_station_1")
# Create table if it doesn't exist
create_table_query = """
CREATE TABLE IF NOT EXISTS gas_station_1.car_frontal_info (
display_id VARCHAR(255),
captured_timestamp VARCHAR(255),
session_id VARCHAR(255) PRIMARY KEY,
license_character VARCHAR(255) DEFAULT NULL,
license_type VARCHAR(255) DEFAULT 'No model available',
car_brand VARCHAR(255) DEFAULT NULL,
car_model VARCHAR(255) DEFAULT NULL,
car_body_type VARCHAR(255) DEFAULT NULL,
updated_at TIMESTAMP DEFAULT NOW()
)
"""
cur.execute(create_table_query)
# Add columns if they don't exist (for existing tables)
alter_queries = [
"ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_brand VARCHAR(255) DEFAULT NULL",
"ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_model VARCHAR(255) DEFAULT NULL",
"ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS car_body_type VARCHAR(255) DEFAULT NULL",
"ALTER TABLE gas_station_1.car_frontal_info ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP DEFAULT NOW()"
]
for alter_query in alter_queries:
try:
cur.execute(alter_query)
logger.debug(f"Executed: {alter_query}")
except Exception as e:
# Ignore errors if column already exists (for older PostgreSQL versions)
if "already exists" in str(e).lower():
logger.debug(f"Column already exists, skipping: {alter_query}")
else:
logger.warning(f"Error in ALTER TABLE: {e}")
self.connection.commit()
cur.close()
logger.info("Successfully created/verified car_frontal_info table with all required columns")
return True
except Exception as e:
logger.error(f"Failed to create car_frontal_info table: {e}")
if self.connection:
self.connection.rollback()
return False
def insert_initial_detection(self, display_id: str, captured_timestamp: str, session_id: str = None) -> str:
"""Insert initial detection record and return the session_id."""
if not self.is_connected():
if not self.connect():
return None
# Generate session_id if not provided
if not session_id:
session_id = str(uuid.uuid4())
try:
# Ensure table exists
if not self.create_car_frontal_info_table():
logger.error("Failed to create/verify table before insertion")
return None
cur = self.connection.cursor()
insert_query = """
INSERT INTO gas_station_1.car_frontal_info
(display_id, captured_timestamp, session_id, license_character, license_type, car_brand, car_model, car_body_type)
VALUES (%s, %s, %s, NULL, 'No model available', NULL, NULL, NULL)
ON CONFLICT (session_id) DO NOTHING
"""
cur.execute(insert_query, (display_id, captured_timestamp, session_id))
self.connection.commit()
cur.close()
logger.info(f"Inserted initial detection record with session_id: {session_id}")
return session_id
except Exception as e:
logger.error(f"Failed to insert initial detection record: {e}")
if self.connection:
self.connection.rollback()
return None

View file

@ -0,0 +1,798 @@
import os
import json
import logging
import torch
import cv2
import zipfile
import shutil
import traceback
import redis
import time
import uuid
import concurrent.futures
from ultralytics import YOLO
from urllib.parse import urlparse
from .database import DatabaseManager
# Create a logger specifically for this module
logger = logging.getLogger("detector_worker.pympta")
def validate_redis_config(redis_config: dict) -> bool:
"""Validate Redis configuration parameters."""
required_fields = ["host", "port"]
for field in required_fields:
if field not in redis_config:
logger.error(f"Missing required Redis config field: {field}")
return False
if not isinstance(redis_config["port"], int) or redis_config["port"] <= 0:
logger.error(f"Invalid Redis port: {redis_config['port']}")
return False
return True
def validate_postgresql_config(pg_config: dict) -> bool:
"""Validate PostgreSQL configuration parameters."""
required_fields = ["host", "port", "database", "username", "password"]
for field in required_fields:
if field not in pg_config:
logger.error(f"Missing required PostgreSQL config field: {field}")
return False
if not isinstance(pg_config["port"], int) or pg_config["port"] <= 0:
logger.error(f"Invalid PostgreSQL port: {pg_config['port']}")
return False
return True
def crop_region_by_class(frame, regions_dict, class_name):
"""Crop a specific region from frame based on detected class."""
if class_name not in regions_dict:
logger.warning(f"Class '{class_name}' not found in detected regions")
return None
bbox = regions_dict[class_name]['bbox']
x1, y1, x2, y2 = bbox
cropped = frame[y1:y2, x1:x2]
if cropped.size == 0:
logger.warning(f"Empty crop for class '{class_name}' with bbox {bbox}")
return None
return cropped
def format_action_context(base_context, additional_context=None):
"""Format action context with dynamic values."""
context = {**base_context}
if additional_context:
context.update(additional_context)
return context
def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manager=None) -> dict:
# Recursively load a model node from configuration.
model_path = os.path.join(mpta_dir, node_config["modelFile"])
if not os.path.exists(model_path):
logger.error(f"Model file {model_path} not found. Current directory: {os.getcwd()}")
logger.error(f"Directory content: {os.listdir(os.path.dirname(model_path))}")
raise FileNotFoundError(f"Model file {model_path} not found.")
logger.info(f"Loading model for node {node_config['modelId']} from {model_path}")
model = YOLO(model_path)
if torch.cuda.is_available():
logger.info(f"CUDA available. Moving model {node_config['modelId']} to GPU")
model.to("cuda")
else:
logger.info(f"CUDA not available. Using CPU for model {node_config['modelId']}")
# Prepare trigger class indices for optimization
trigger_classes = node_config.get("triggerClasses", [])
trigger_class_indices = None
if trigger_classes and hasattr(model, "names"):
# Convert class names to indices for the model
trigger_class_indices = [i for i, name in model.names.items()
if name in trigger_classes]
logger.debug(f"Converted trigger classes to indices: {trigger_class_indices}")
node = {
"modelId": node_config["modelId"],
"modelFile": node_config["modelFile"],
"triggerClasses": trigger_classes,
"triggerClassIndices": trigger_class_indices,
"crop": node_config.get("crop", False),
"cropClass": node_config.get("cropClass"),
"minConfidence": node_config.get("minConfidence", None),
"multiClass": node_config.get("multiClass", False),
"expectedClasses": node_config.get("expectedClasses", []),
"parallel": node_config.get("parallel", False),
"actions": node_config.get("actions", []),
"parallelActions": node_config.get("parallelActions", []),
"model": model,
"branches": [],
"redis_client": redis_client,
"db_manager": db_manager
}
logger.debug(f"Configured node {node_config['modelId']} with trigger classes: {node['triggerClasses']}")
for child in node_config.get("branches", []):
logger.debug(f"Loading branch for parent node {node_config['modelId']}")
node["branches"].append(load_pipeline_node(child, mpta_dir, redis_client, db_manager))
return node
def load_pipeline_from_zip(zip_source: str, target_dir: str) -> dict:
logger.info(f"Attempting to load pipeline from {zip_source} to {target_dir}")
os.makedirs(target_dir, exist_ok=True)
zip_path = os.path.join(target_dir, "pipeline.mpta")
# Parse the source; only local files are supported here.
parsed = urlparse(zip_source)
if parsed.scheme in ("", "file"):
local_path = parsed.path if parsed.scheme == "file" else zip_source
logger.debug(f"Checking if local file exists: {local_path}")
if os.path.exists(local_path):
try:
shutil.copy(local_path, zip_path)
logger.info(f"Copied local .mpta file from {local_path} to {zip_path}")
except Exception as e:
logger.error(f"Failed to copy local .mpta file from {local_path}: {str(e)}", exc_info=True)
return None
else:
logger.error(f"Local file {local_path} does not exist. Current directory: {os.getcwd()}")
# List all subdirectories of models directory to help debugging
if os.path.exists("models"):
logger.error(f"Content of models directory: {os.listdir('models')}")
for root, dirs, files in os.walk("models"):
logger.error(f"Directory {root} contains subdirs: {dirs} and files: {files}")
else:
logger.error("The models directory doesn't exist")
return None
else:
logger.error(f"HTTP download functionality has been moved. Use a local file path here. Received: {zip_source}")
return None
try:
if not os.path.exists(zip_path):
logger.error(f"Zip file not found at expected location: {zip_path}")
return None
logger.debug(f"Extracting .mpta file from {zip_path} to {target_dir}")
# Extract contents and track the directories created
extracted_dirs = []
with zipfile.ZipFile(zip_path, "r") as zip_ref:
file_list = zip_ref.namelist()
logger.debug(f"Files in .mpta archive: {file_list}")
# Extract and track the top-level directories
for file_path in file_list:
parts = file_path.split('/')
if len(parts) > 1:
top_dir = parts[0]
if top_dir and top_dir not in extracted_dirs:
extracted_dirs.append(top_dir)
# Now extract the files
zip_ref.extractall(target_dir)
logger.info(f"Successfully extracted .mpta file to {target_dir}")
logger.debug(f"Extracted directories: {extracted_dirs}")
# Check what was actually created after extraction
actual_dirs = [d for d in os.listdir(target_dir) if os.path.isdir(os.path.join(target_dir, d))]
logger.debug(f"Actual directories created: {actual_dirs}")
except zipfile.BadZipFile as e:
logger.error(f"Bad zip file {zip_path}: {str(e)}", exc_info=True)
return None
except Exception as e:
logger.error(f"Failed to extract .mpta file {zip_path}: {str(e)}", exc_info=True)
return None
finally:
if os.path.exists(zip_path):
os.remove(zip_path)
logger.debug(f"Removed temporary zip file: {zip_path}")
# Use the first extracted directory if it exists, otherwise use the expected name
pipeline_name = os.path.basename(zip_source)
pipeline_name = os.path.splitext(pipeline_name)[0]
# Find the directory with pipeline.json
mpta_dir = None
# First try the expected directory name
expected_dir = os.path.join(target_dir, pipeline_name)
if os.path.exists(expected_dir) and os.path.exists(os.path.join(expected_dir, "pipeline.json")):
mpta_dir = expected_dir
logger.debug(f"Found pipeline.json in the expected directory: {mpta_dir}")
else:
# Look through all subdirectories for pipeline.json
for subdir in actual_dirs:
potential_dir = os.path.join(target_dir, subdir)
if os.path.exists(os.path.join(potential_dir, "pipeline.json")):
mpta_dir = potential_dir
logger.info(f"Found pipeline.json in directory: {mpta_dir} (different from expected: {expected_dir})")
break
if not mpta_dir:
logger.error(f"Could not find pipeline.json in any extracted directory. Directory content: {os.listdir(target_dir)}")
return None
pipeline_json_path = os.path.join(mpta_dir, "pipeline.json")
if not os.path.exists(pipeline_json_path):
logger.error(f"pipeline.json not found in the .mpta file. Files in directory: {os.listdir(mpta_dir)}")
return None
try:
with open(pipeline_json_path, "r") as f:
pipeline_config = json.load(f)
logger.info(f"Successfully loaded pipeline configuration from {pipeline_json_path}")
logger.debug(f"Pipeline config: {json.dumps(pipeline_config, indent=2)}")
# Establish Redis connection if configured
redis_client = None
if "redis" in pipeline_config:
redis_config = pipeline_config["redis"]
if not validate_redis_config(redis_config):
logger.error("Invalid Redis configuration, skipping Redis connection")
else:
try:
redis_client = redis.Redis(
host=redis_config["host"],
port=redis_config["port"],
password=redis_config.get("password"),
db=redis_config.get("db", 0),
decode_responses=True
)
redis_client.ping()
logger.info(f"Successfully connected to Redis at {redis_config['host']}:{redis_config['port']}")
except redis.exceptions.ConnectionError as e:
logger.error(f"Failed to connect to Redis: {e}")
redis_client = None
# Establish PostgreSQL connection if configured
db_manager = None
if "postgresql" in pipeline_config:
pg_config = pipeline_config["postgresql"]
if not validate_postgresql_config(pg_config):
logger.error("Invalid PostgreSQL configuration, skipping database connection")
else:
try:
db_manager = DatabaseManager(pg_config)
if db_manager.connect():
logger.info(f"Successfully connected to PostgreSQL at {pg_config['host']}:{pg_config['port']}")
else:
logger.error("Failed to connect to PostgreSQL")
db_manager = None
except Exception as e:
logger.error(f"Error initializing PostgreSQL connection: {e}")
db_manager = None
return load_pipeline_node(pipeline_config["pipeline"], mpta_dir, redis_client, db_manager)
except json.JSONDecodeError as e:
logger.error(f"Error parsing pipeline.json: {str(e)}", exc_info=True)
return None
except KeyError as e:
logger.error(f"Missing key in pipeline.json: {str(e)}", exc_info=True)
return None
except Exception as e:
logger.error(f"Error loading pipeline.json: {str(e)}", exc_info=True)
return None
def execute_actions(node, frame, detection_result, regions_dict=None):
if not node["redis_client"] or not node["actions"]:
return
# Create a dynamic context for this detection event
from datetime import datetime
action_context = {
**detection_result,
"timestamp_ms": int(time.time() * 1000),
"uuid": str(uuid.uuid4()),
"timestamp": datetime.now().strftime("%Y-%m-%dT%H-%M-%S"),
"filename": f"{uuid.uuid4()}.jpg"
}
for action in node["actions"]:
try:
if action["type"] == "redis_save_image":
key = action["key"].format(**action_context)
# Check if we need to crop a specific region
region_name = action.get("region")
image_to_save = frame
if region_name and regions_dict:
cropped_image = crop_region_by_class(frame, regions_dict, region_name)
if cropped_image is not None:
image_to_save = cropped_image
logger.debug(f"Cropped region '{region_name}' for redis_save_image")
else:
logger.warning(f"Could not crop region '{region_name}', saving full frame instead")
# Encode image with specified format and quality (default to JPEG)
img_format = action.get("format", "jpeg").lower()
quality = action.get("quality", 90)
if img_format == "jpeg":
encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
success, buffer = cv2.imencode('.jpg', image_to_save, encode_params)
elif img_format == "png":
success, buffer = cv2.imencode('.png', image_to_save)
else:
success, buffer = cv2.imencode('.jpg', image_to_save, [cv2.IMWRITE_JPEG_QUALITY, quality])
if not success:
logger.error(f"Failed to encode image for redis_save_image")
continue
expire_seconds = action.get("expire_seconds")
if expire_seconds:
node["redis_client"].setex(key, expire_seconds, buffer.tobytes())
logger.info(f"Saved image to Redis with key: {key} (expires in {expire_seconds}s)")
else:
node["redis_client"].set(key, buffer.tobytes())
logger.info(f"Saved image to Redis with key: {key}")
action_context["image_key"] = key
elif action["type"] == "redis_publish":
channel = action["channel"]
try:
# Handle JSON message format by creating it programmatically
message_template = action["message"]
# Check if the message is JSON-like (starts and ends with braces)
if message_template.strip().startswith('{') and message_template.strip().endswith('}'):
# Create JSON data programmatically to avoid formatting issues
json_data = {}
# Add common fields
json_data["event"] = "frontal_detected"
json_data["display_id"] = action_context.get("display_id", "unknown")
json_data["session_id"] = action_context.get("session_id")
json_data["timestamp"] = action_context.get("timestamp", "")
json_data["image_key"] = action_context.get("image_key", "")
# Convert to JSON string
message = json.dumps(json_data)
else:
# Use regular string formatting for non-JSON messages
message = message_template.format(**action_context)
# Publish to Redis
if not node["redis_client"]:
logger.error("Redis client is None, cannot publish message")
continue
# Test Redis connection
try:
node["redis_client"].ping()
logger.debug("Redis connection is active")
except Exception as ping_error:
logger.error(f"Redis connection test failed: {ping_error}")
continue
result = node["redis_client"].publish(channel, message)
logger.info(f"Published message to Redis channel '{channel}': {message}")
logger.info(f"Redis publish result (subscribers count): {result}")
# Additional debug info
if result == 0:
logger.warning(f"No subscribers listening to channel '{channel}'")
else:
logger.info(f"Message delivered to {result} subscriber(s)")
except KeyError as e:
logger.error(f"Missing key in redis_publish message template: {e}")
logger.debug(f"Available context keys: {list(action_context.keys())}")
except Exception as e:
logger.error(f"Error in redis_publish action: {e}")
logger.debug(f"Message template: {action['message']}")
logger.debug(f"Available context keys: {list(action_context.keys())}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
except Exception as e:
logger.error(f"Error executing action {action['type']}: {e}")
def execute_parallel_actions(node, frame, detection_result, regions_dict):
"""Execute parallel actions after all required branches have completed."""
if not node.get("parallelActions"):
return
logger.debug("Executing parallel actions...")
branch_results = detection_result.get("branch_results", {})
for action in node["parallelActions"]:
try:
action_type = action.get("type")
logger.debug(f"Processing parallel action: {action_type}")
if action_type == "postgresql_update_combined":
# Check if all required branches have completed
wait_for_branches = action.get("waitForBranches", [])
missing_branches = [branch for branch in wait_for_branches if branch not in branch_results]
if missing_branches:
logger.warning(f"Cannot execute postgresql_update_combined: missing branch results for {missing_branches}")
continue
logger.info(f"All required branches completed: {wait_for_branches}")
# Execute the database update
execute_postgresql_update_combined(node, action, detection_result, branch_results)
else:
logger.warning(f"Unknown parallel action type: {action_type}")
except Exception as e:
logger.error(f"Error executing parallel action {action.get('type', 'unknown')}: {e}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
def execute_postgresql_update_combined(node, action, detection_result, branch_results):
"""Execute a PostgreSQL update with combined branch results."""
if not node.get("db_manager"):
logger.error("No database manager available for postgresql_update_combined action")
return
try:
table = action["table"]
key_field = action["key_field"]
key_value_template = action["key_value"]
fields = action["fields"]
# Create context for key value formatting
action_context = {**detection_result}
key_value = key_value_template.format(**action_context)
logger.info(f"Executing database update: table={table}, {key_field}={key_value}")
# Process field mappings
mapped_fields = {}
for db_field, value_template in fields.items():
try:
mapped_value = resolve_field_mapping(value_template, branch_results, action_context)
if mapped_value is not None:
mapped_fields[db_field] = mapped_value
logger.debug(f"Mapped field: {db_field} = {mapped_value}")
else:
logger.warning(f"Could not resolve field mapping for {db_field}: {value_template}")
except Exception as e:
logger.error(f"Error mapping field {db_field} with template '{value_template}': {e}")
if not mapped_fields:
logger.warning("No fields mapped successfully, skipping database update")
return
# Execute the database update
success = node["db_manager"].execute_update(table, key_field, key_value, mapped_fields)
if success:
logger.info(f"Successfully updated database: {table} with {len(mapped_fields)} fields")
else:
logger.error(f"Failed to update database: {table}")
except KeyError as e:
logger.error(f"Missing required field in postgresql_update_combined action: {e}")
except Exception as e:
logger.error(f"Error in postgresql_update_combined action: {e}")
import traceback
logger.debug(f"Full traceback: {traceback.format_exc()}")
def resolve_field_mapping(value_template, branch_results, action_context):
"""Resolve field mapping templates like {car_brand_cls_v1.brand}."""
try:
# Handle simple context variables first (non-branch references)
if not '.' in value_template:
return value_template.format(**action_context)
# Handle branch result references like {model_id.field}
import re
branch_refs = re.findall(r'\{([^}]+\.[^}]+)\}', value_template)
resolved_template = value_template
for ref in branch_refs:
try:
model_id, field_name = ref.split('.', 1)
if model_id in branch_results:
branch_data = branch_results[model_id]
if field_name in branch_data:
field_value = branch_data[field_name]
resolved_template = resolved_template.replace(f'{{{ref}}}', str(field_value))
logger.debug(f"Resolved {ref} to {field_value}")
else:
logger.warning(f"Field '{field_name}' not found in branch '{model_id}' results. Available fields: {list(branch_data.keys())}")
return None
else:
logger.warning(f"Branch '{model_id}' not found in results. Available branches: {list(branch_results.keys())}")
return None
except ValueError as e:
logger.error(f"Invalid branch reference format: {ref}")
return None
# Format any remaining simple variables
try:
final_value = resolved_template.format(**action_context)
return final_value
except KeyError as e:
logger.warning(f"Could not resolve context variable in template: {e}")
return resolved_template
except Exception as e:
logger.error(f"Error resolving field mapping '{value_template}': {e}")
return None
def run_pipeline(frame, node: dict, return_bbox: bool=False, context=None):
"""
Enhanced pipeline that supports:
- Multi-class detection (detecting multiple classes simultaneously)
- Parallel branch processing
- Region-based actions and cropping
- Context passing for session/camera information
"""
try:
task = getattr(node["model"], "task", None)
# ─── Classification stage ───────────────────────────────────
if task == "classify":
results = node["model"].predict(frame, stream=False)
if not results:
return (None, None) if return_bbox else None
r = results[0]
probs = r.probs
if probs is None:
return (None, None) if return_bbox else None
top1_idx = int(probs.top1)
top1_conf = float(probs.top1conf)
class_name = node["model"].names[top1_idx]
det = {
"class": class_name,
"confidence": top1_conf,
"id": None,
class_name: class_name # Add class name as key for backward compatibility
}
# Add specific field mappings for database operations based on model type
model_id = node.get("modelId", "").lower()
if "brand" in model_id or "brand_cls" in model_id:
det["brand"] = class_name
elif "bodytype" in model_id or "body" in model_id:
det["body_type"] = class_name
elif "color" in model_id:
det["color"] = class_name
execute_actions(node, frame, det)
return (det, None) if return_bbox else det
# ─── Detection stage - Multi-class support ──────────────────
tk = node["triggerClassIndices"]
logger.debug(f"Running detection for node {node['modelId']} with trigger classes: {node.get('triggerClasses', [])} (indices: {tk})")
logger.debug(f"Node configuration: minConfidence={node['minConfidence']}, multiClass={node.get('multiClass', False)}")
res = node["model"].track(
frame,
stream=False,
persist=True,
**({"classes": tk} if tk else {})
)[0]
# Collect all detections above confidence threshold
all_detections = []
all_boxes = []
regions_dict = {}
logger.debug(f"Raw detection results from model: {len(res.boxes) if res.boxes is not None else 0} detections")
for i, box in enumerate(res.boxes):
conf = float(box.cpu().conf[0])
cid = int(box.cpu().cls[0])
name = node["model"].names[cid]
logger.debug(f"Detection {i}: class='{name}' (id={cid}), confidence={conf:.3f}, threshold={node['minConfidence']}")
if conf < node["minConfidence"]:
logger.debug(f" -> REJECTED: confidence {conf:.3f} < threshold {node['minConfidence']}")
continue
xy = box.cpu().xyxy[0]
x1, y1, x2, y2 = map(int, xy)
bbox = (x1, y1, x2, y2)
detection = {
"class": name,
"confidence": conf,
"id": box.id.item() if hasattr(box, "id") else None,
"bbox": bbox
}
all_detections.append(detection)
all_boxes.append(bbox)
logger.debug(f" -> ACCEPTED: {name} with confidence {conf:.3f}, bbox={bbox}")
# Store highest confidence detection for each class
if name not in regions_dict or conf > regions_dict[name]["confidence"]:
regions_dict[name] = {
"bbox": bbox,
"confidence": conf,
"detection": detection
}
logger.debug(f" -> Updated regions_dict['{name}'] with confidence {conf:.3f}")
logger.info(f"Detection summary: {len(all_detections)} accepted detections from {len(res.boxes) if res.boxes is not None else 0} total")
logger.info(f"Detected classes: {list(regions_dict.keys())}")
if not all_detections:
logger.warning("No detections above confidence threshold - returning null")
return (None, None) if return_bbox else None
# ─── Multi-class validation ─────────────────────────────────
if node.get("multiClass", False) and node.get("expectedClasses"):
expected_classes = node["expectedClasses"]
detected_classes = list(regions_dict.keys())
logger.info(f"Multi-class validation: expected={expected_classes}, detected={detected_classes}")
# Check if at least one expected class is detected (flexible mode)
matching_classes = [cls for cls in expected_classes if cls in detected_classes]
missing_classes = [cls for cls in expected_classes if cls not in detected_classes]
logger.debug(f"Matching classes: {matching_classes}, Missing classes: {missing_classes}")
if not matching_classes:
# No expected classes found at all
logger.warning(f"PIPELINE REJECTED: No expected classes detected. Expected: {expected_classes}, Detected: {detected_classes}")
return (None, None) if return_bbox else None
if missing_classes:
logger.info(f"Partial multi-class detection: {matching_classes} found, {missing_classes} missing")
else:
logger.info(f"Complete multi-class detection success: {detected_classes}")
else:
logger.debug("No multi-class validation - proceeding with all detections")
# ─── Execute actions with region information ────────────────
detection_result = {
"detections": all_detections,
"regions": regions_dict,
**(context or {})
}
# ─── Create initial database record when Car+Frontal detected ────
if node.get("db_manager") and node.get("multiClass", False):
# Only create database record if we have both Car and Frontal
has_car = "Car" in regions_dict
has_frontal = "Frontal" in regions_dict
if has_car and has_frontal:
# Generate UUID session_id since client session is None for now
import uuid as uuid_lib
from datetime import datetime
generated_session_id = str(uuid_lib.uuid4())
# Insert initial detection record
display_id = detection_result.get("display_id", "unknown")
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
inserted_session_id = node["db_manager"].insert_initial_detection(
display_id=display_id,
captured_timestamp=timestamp,
session_id=generated_session_id
)
if inserted_session_id:
# Update detection_result with the generated session_id for actions and branches
detection_result["session_id"] = inserted_session_id
detection_result["timestamp"] = timestamp # Update with proper timestamp
logger.info(f"Created initial database record with session_id: {inserted_session_id}")
else:
logger.debug(f"Database record not created - missing required classes. Has Car: {has_car}, Has Frontal: {has_frontal}")
execute_actions(node, frame, detection_result, regions_dict)
# ─── Parallel branch processing ─────────────────────────────
if node["branches"]:
branch_results = {}
# Filter branches that should be triggered
active_branches = []
for br in node["branches"]:
trigger_classes = br.get("triggerClasses", [])
min_conf = br.get("minConfidence", 0)
logger.debug(f"Evaluating branch {br['modelId']}: trigger_classes={trigger_classes}, min_conf={min_conf}")
# Check if any detected class matches branch trigger
branch_triggered = False
for det_class in regions_dict:
det_confidence = regions_dict[det_class]["confidence"]
logger.debug(f" Checking detected class '{det_class}' (confidence={det_confidence:.3f}) against triggers {trigger_classes}")
if (det_class in trigger_classes and det_confidence >= min_conf):
active_branches.append(br)
branch_triggered = True
logger.info(f"Branch {br['modelId']} activated by class '{det_class}' (conf={det_confidence:.3f} >= {min_conf})")
break
if not branch_triggered:
logger.debug(f"Branch {br['modelId']} not triggered - no matching classes or insufficient confidence")
if active_branches:
if node.get("parallel", False) or any(br.get("parallel", False) for br in active_branches):
# Run branches in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_branches)) as executor:
futures = {}
for br in active_branches:
crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
sub_frame = frame
logger.info(f"Starting parallel branch: {br['modelId']}, crop_class: {crop_class}")
if br.get("crop", False) and crop_class:
cropped = crop_region_by_class(frame, regions_dict, crop_class)
if cropped is not None:
sub_frame = cv2.resize(cropped, (224, 224))
logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
else:
logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
continue
future = executor.submit(run_pipeline, sub_frame, br, True, context)
futures[future] = br
# Collect results
for future in concurrent.futures.as_completed(futures):
br = futures[future]
try:
result, _ = future.result()
if result:
branch_results[br["modelId"]] = result
logger.info(f"Branch {br['modelId']} completed: {result}")
except Exception as e:
logger.error(f"Branch {br['modelId']} failed: {e}")
else:
# Run branches sequentially
for br in active_branches:
crop_class = br.get("cropClass", br.get("triggerClasses", [])[0] if br.get("triggerClasses") else None)
sub_frame = frame
logger.info(f"Starting sequential branch: {br['modelId']}, crop_class: {crop_class}")
if br.get("crop", False) and crop_class:
cropped = crop_region_by_class(frame, regions_dict, crop_class)
if cropped is not None:
sub_frame = cv2.resize(cropped, (224, 224))
logger.debug(f"Successfully cropped {crop_class} region for {br['modelId']}")
else:
logger.warning(f"Failed to crop {crop_class} region for {br['modelId']}, skipping branch")
continue
try:
result, _ = run_pipeline(sub_frame, br, True, context)
if result:
branch_results[br["modelId"]] = result
logger.info(f"Branch {br['modelId']} completed: {result}")
else:
logger.warning(f"Branch {br['modelId']} returned no result")
except Exception as e:
logger.error(f"Error in sequential branch {br['modelId']}: {e}")
import traceback
logger.debug(f"Branch error traceback: {traceback.format_exc()}")
# Store branch results in detection_result for parallel actions
detection_result["branch_results"] = branch_results
# ─── Execute Parallel Actions ───────────────────────────────
if node.get("parallelActions") and "branch_results" in detection_result:
execute_parallel_actions(node, frame, detection_result, regions_dict)
# ─── Return detection result ────────────────────────────────
primary_detection = max(all_detections, key=lambda x: x["confidence"])
primary_bbox = primary_detection["bbox"]
# Add branch results to primary detection for compatibility
if "branch_results" in detection_result:
primary_detection["branch_results"] = detection_result["branch_results"]
return (primary_detection, primary_bbox) if return_bbox else primary_detection
except Exception as e:
logger.error(f"Error in node {node.get('modelId')}: {e}")
traceback.print_exc()
return (None, None) if return_bbox else None