# python-detector-worker/app.py

from typing import Any, Dict, Optional
import os
import json
import time
import queue
import torch
import cv2
import base64
import logging
import threading
import requests
import asyncio
import psutil
import zipfile
from urllib.parse import urlparse
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from ultralytics import YOLO
# Import shared pipeline functions
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline
app = FastAPI()
# Global dictionaries to keep track of models and streams
# "models" now holds a nested dict: { camera_id: { modelId: model_tree } }
models: Dict[str, Dict[str, Any]] = {}
streams: Dict[str, Dict[str, Any]] = {}
with open("config.json", "r") as f:
    config = json.load(f)

poll_interval = config.get("poll_interval_ms", 100)
reconnect_interval = config.get("reconnect_interval_sec", 5)
TARGET_FPS = config.get("target_fps", 10)
poll_interval = 1000 / TARGET_FPS  # derived from target FPS; overrides poll_interval_ms
max_streams = config.get("max_streams", 5)
max_retries = config.get("max_retries", 3)
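# Example config.json (values here are illustrative; the keys above are the
# ones this worker actually reads):
# {
#     "poll_interval_ms": 100,
#     "reconnect_interval_sec": 5,
#     "target_fps": 10,
#     "max_streams": 5,
#     "max_retries": 3
# }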
# Configure logging before any logging calls so the handlers and format take effect
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logging.info(f"Poll interval: {poll_interval}ms")
# Ensure the models directory exists
os.makedirs("models", exist_ok=True)
# Constants for heartbeat and timeouts
HEARTBEAT_INTERVAL = 2 # seconds
WORKER_TIMEOUT_MS = 10000
# Locks for thread-safe operations
streams_lock = threading.Lock()
models_lock = threading.Lock()
# Helper to download an .mpta ZIP file from a remote URL
def download_mpta(url: str, dest_path: str) -> Optional[str]:
    try:
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        response = requests.get(url, stream=True)
        if response.status_code == 200:
            with open(dest_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            logging.info(f"Downloaded mpta file from {url} to {dest_path}")
            return dest_path
        else:
            logging.error(f"Failed to download mpta file (status code {response.status_code})")
            return None
    except Exception as e:
        logging.error(f"Exception downloading mpta file from {url}: {e}")
        return None
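# Example call (hypothetical URL and destination path, for illustration only):
#   local_path = download_mpta("https://example.com/models/detector.mpta",
#                              "models/cam-001/42/detector.mpta")
#   # local_path is None if the download failed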
####################################################
# Detection and frame processing functions
####################################################
@app.websocket("/")
async def detect(websocket: WebSocket):
    logging.info("WebSocket connection accepted")
    persistent_data_dict = {}

    async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data):
        try:
            detection_result = run_pipeline(frame, model_tree)
            detection_data = {
                "type": "imageDetection",
                "cameraIdentifier": camera_id,
                "timestamp": time.time(),
                "data": {
                    "detection": detection_result if detection_result else None,
                    "modelId": stream["modelId"],
                    "modelName": stream["modelName"]
                }
            }
            logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
            await websocket.send_json(detection_data)
            return persistent_data
        except Exception as e:
            logging.error(f"Error in handle_detection for camera {camera_id}: {e}")
            return persistent_data
    def frame_reader(camera_id, cap, buffer, stop_event):
        retries = 0
        try:
            while not stop_event.is_set():
                try:
                    ret, frame = cap.read()
                    if not ret:
                        logging.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}")
                        cap.release()
                        time.sleep(reconnect_interval)
                        retries += 1
                        if retries > max_retries and max_retries != -1:
                            logging.error(f"Max retries reached for camera: {camera_id}")
                            break
                        # Re-open
                        cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
                        if not cap.isOpened():
                            logging.error(f"Failed to reopen RTSP stream for camera: {camera_id}")
                            continue
                        continue
                    retries = 0
                    # Overwrite old frame if buffer is full
                    if not buffer.empty():
                        try:
                            buffer.get_nowait()
                        except queue.Empty:
                            pass
                    buffer.put(frame)
                except cv2.error as e:
                    logging.error(f"OpenCV error for camera {camera_id}: {e}")
                    cap.release()
                    time.sleep(reconnect_interval)
                    retries += 1
                    if retries > max_retries and max_retries != -1:
                        logging.error(f"Max retries reached after OpenCV error for camera {camera_id}")
                        break
                    cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
                    if not cap.isOpened():
                        logging.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error")
                        continue
                except Exception as e:
                    logging.error(f"Unexpected error for camera {camera_id}: {e}")
                    cap.release()
                    break
        except Exception as e:
            logging.error(f"Error in frame_reader thread for camera {camera_id}: {e}")
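    # Design note: frame_reader runs blocking cv2 reads in a daemon thread and
    # keeps only the newest frame in a maxsize=1 queue; process_streams below
    # drains that queue on the asyncio side at roughly TARGET_FPS, so stale
    # frames are dropped instead of queuing up.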
    async def process_streams():
        logging.info("Started processing streams")
        try:
            while True:
                start_time = time.time()
                with streams_lock:
                    current_streams = list(streams.items())
                for camera_id, stream in current_streams:
                    buffer = stream["buffer"]
                    if not buffer.empty():
                        frame = buffer.get()
                        with models_lock:
                            model_tree = models.get(camera_id, {}).get(stream["modelId"])
                        key = (camera_id, stream["modelId"])
                        persistent_data = persistent_data_dict.get(key, {})
                        updated_persistent_data = await handle_detection(
                            camera_id, stream, frame, websocket, model_tree, persistent_data
                        )
                        persistent_data_dict[key] = updated_persistent_data
                elapsed_time = (time.time() - start_time) * 1000  # ms
                sleep_time = max(poll_interval - elapsed_time, 0)
                logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
                await asyncio.sleep(sleep_time / 1000.0)
        except asyncio.CancelledError:
            logging.info("Stream processing task cancelled")
        except Exception as e:
            logging.error(f"Error in process_streams: {e}")
    async def send_heartbeat():
        while True:
            try:
                cpu_usage = psutil.cpu_percent()
                memory_usage = psutil.virtual_memory().percent
                if torch.cuda.is_available():
                    gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2)  # MB
                    gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)  # MB
                else:
                    gpu_usage = None
                    gpu_memory_usage = None
                camera_connections = [
                    {
                        "cameraIdentifier": camera_id,
                        "modelId": stream["modelId"],
                        "modelName": stream["modelName"],
                        "online": True
                    }
                    for camera_id, stream in streams.items()
                ]
                state_report = {
                    "type": "stateReport",
                    "cpuUsage": cpu_usage,
                    "memoryUsage": memory_usage,
                    "gpuUsage": gpu_usage,
                    "gpuMemoryUsage": gpu_memory_usage,
                    "cameraConnections": camera_connections
                }
                await websocket.send_text(json.dumps(state_report))
                logging.debug("Sent stateReport as heartbeat")
                await asyncio.sleep(HEARTBEAT_INTERVAL)
            except Exception as e:
                logging.error(f"Error sending stateReport heartbeat: {e}")
                break
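    # Message shapes handled by on_message (field names match what the code
    # reads below; the concrete values are illustrative only):
    #   {"type": "subscribe", "payload": {"cameraIdentifier": "cam-001",
    #       "rtspUrl": "rtsp://...", "modelUrl": "https://.../model.mpta",
    #       "modelId": 42, "modelName": "vehicle-detector"}}
    #   {"type": "unsubscribe", "payload": {"cameraIdentifier": "cam-001"}}
    #   {"type": "requestState"}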
    async def on_message():
        while True:
            try:
                msg = await websocket.receive_text()
                logging.debug(f"Received message: {msg}")
                data = json.loads(msg)
                msg_type = data.get("type")

                if msg_type == "subscribe":
                    payload = data.get("payload", {})
                    camera_id = payload.get("cameraIdentifier")
                    rtsp_url = payload.get("rtspUrl")
                    model_url = payload.get("modelUrl")  # may be remote or local
                    modelId = payload.get("modelId")
                    modelName = payload.get("modelName")

                    if model_url:
                        with models_lock:
                            if camera_id not in models:
                                models[camera_id] = {}
                            if modelId not in models[camera_id]:
                                logging.info(f"Loading model from {model_url}")
                                extraction_dir = os.path.join("models", camera_id, str(modelId))
                                os.makedirs(extraction_dir, exist_ok=True)
                                # If model_url is remote, download it first.
                                parsed = urlparse(model_url)
                                if parsed.scheme in ("http", "https"):
                                    local_mpta = os.path.join(extraction_dir, os.path.basename(parsed.path))
                                    local_path = download_mpta(model_url, local_mpta)
                                    if not local_path:
                                        logging.error("Failed to download the remote mpta file.")
                                        continue
                                    model_tree = load_pipeline_from_zip(local_path, extraction_dir)
                                else:
                                    model_tree = load_pipeline_from_zip(model_url, extraction_dir)
                                if model_tree is None:
                                    logging.error("Failed to load model from mpta file.")
                                    continue
                                models[camera_id][modelId] = model_tree
                                logging.info(f"Loaded model {modelId} for camera {camera_id}")
                    if camera_id and rtsp_url:
                        with streams_lock:
                            if camera_id not in streams and len(streams) < max_streams:
                                cap = cv2.VideoCapture(rtsp_url)
                                if not cap.isOpened():
                                    logging.error(f"Failed to open RTSP stream for camera {camera_id}")
                                    continue
                                buffer = queue.Queue(maxsize=1)
                                stop_event = threading.Event()
                                thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
                                thread.daemon = True
                                thread.start()
                                streams[camera_id] = {
                                    "cap": cap,
                                    "buffer": buffer,
                                    "thread": thread,
                                    "rtsp_url": rtsp_url,
                                    "stop_event": stop_event,
                                    "modelId": modelId,
                                    "modelName": modelName
                                }
                                logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName}, URL {rtsp_url}")
                            elif camera_id and camera_id in streams:
                                # A repeated subscribe for an existing camera is treated as an unsubscribe
                                stream = streams.pop(camera_id)
                                stream["stop_event"].set()
                                stream["thread"].join()
                                stream["cap"].release()
                                logging.info(f"Unsubscribed from camera {camera_id}")
                                with models_lock:
                                    if camera_id in models and modelId in models[camera_id]:
                                        del models[camera_id][modelId]
                                        if not models[camera_id]:
                                            del models[camera_id]
                elif msg_type == "unsubscribe":
                    payload = data.get("payload", {})
                    camera_id = payload.get("cameraIdentifier")
                    logging.debug(f"Unsubscribing from camera {camera_id}")
                    with streams_lock:
                        if camera_id and camera_id in streams:
                            stream = streams.pop(camera_id)
                            stream["stop_event"].set()
                            stream["thread"].join()
                            stream["cap"].release()
                            logging.info(f"Unsubscribed from camera {camera_id}")
                    with models_lock:
                        if camera_id in models:
                            del models[camera_id]
                elif msg_type == "requestState":
                    cpu_usage = psutil.cpu_percent()
                    memory_usage = psutil.virtual_memory().percent
                    if torch.cuda.is_available():
                        gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2)
                        gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
                    else:
                        gpu_usage = None
                        gpu_memory_usage = None
                    camera_connections = [
                        {
                            "cameraIdentifier": camera_id,
                            "modelId": stream["modelId"],
                            "modelName": stream["modelName"],
                            "online": True
                        }
                        for camera_id, stream in streams.items()
                    ]
                    state_report = {
                        "type": "stateReport",
                        "cpuUsage": cpu_usage,
                        "memoryUsage": memory_usage,
                        "gpuUsage": gpu_usage,
                        "gpuMemoryUsage": gpu_memory_usage,
                        "cameraConnections": camera_connections
                    }
                    await websocket.send_text(json.dumps(state_report))
                else:
                    logging.error(f"Unknown message type: {msg_type}")
            except json.JSONDecodeError:
                logging.error("Received invalid JSON message")
            except (WebSocketDisconnect, ConnectionClosedError) as e:
                logging.warning(f"WebSocket disconnected: {e}")
                break
            except Exception as e:
                logging.error(f"Error handling message: {e}")
                break
    stream_task = None
    try:
        await websocket.accept()
        stream_task = asyncio.create_task(process_streams())
        heartbeat_task = asyncio.create_task(send_heartbeat())
        message_task = asyncio.create_task(on_message())
        await asyncio.gather(heartbeat_task, message_task)
    except Exception as e:
        logging.error(f"Error in detect websocket: {e}")
    finally:
        if stream_task is not None:
            stream_task.cancel()
            await stream_task
        with streams_lock:
            for camera_id, stream in streams.items():
                stream["stop_event"].set()
                stream["thread"].join()
                stream["cap"].release()
                while not stream["buffer"].empty():
                    try:
                        stream["buffer"].get_nowait()
                    except queue.Empty:
                        pass
                logging.info(f"Released camera {camera_id} and cleaned up resources")
            streams.clear()
        with models_lock:
            models.clear()
        logging.info("WebSocket connection closed")
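
# Example entry point (not shown in the original file): the worker is a FastAPI
# app, so one way to run it is with uvicorn. The host and port below are
# assumptions, not values taken from the project.
if __name__ == "__main__":
    import uvicorn
    # Serves the WebSocket endpoint defined above at ws://<host>:<port>/
    uvicorn.run(app, host="0.0.0.0", port=8000)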