beta pipeline

This commit is contained in:
Siwat Sirichai 2025-02-23 01:32:56 +07:00
parent b12e4ccb7f
commit 5da166a341
2 changed files with 622 additions and 125 deletions

365
app.py
View file

@ -1,25 +1,29 @@
from typing import List
from typing import Any, Dict
import os
import json
import time
import queue
import torch
import cv2
import base64
import logging
import threading
import requests
import asyncio
import psutil
import zipfile
from urllib.parse import urlparse
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from ultralytics import YOLO
import torch
import cv2
import base64
import numpy as np
import json
import logging
import threading
import queue
import os
import requests
from urllib.parse import urlparse
import asyncio
import psutil
app = FastAPI()
models = {}
# Global dictionaries to keep track of models and streams
# "models" now holds a nested dict: { camera_id: { modelId: model_tree } }
models: Dict[str, Dict[str, Any]] = {}
streams: Dict[str, Dict[str, Any]] = {}
with open("config.json", "r") as f:
config = json.load(f)
@ -45,51 +49,188 @@ logging.basicConfig(
# Ensure the models directory exists
os.makedirs("models", exist_ok=True)
# Add constants for heartbeat
# Constants for heartbeat and timeouts
HEARTBEAT_INTERVAL = 2 # seconds
WORKER_TIMEOUT_MS = 10000
# Add a lock for thread-safe operations on shared resources
# Locks for thread-safe operations
streams_lock = threading.Lock()
models_lock = threading.Lock()
####################################################
# Pipeline (Model)-loading helper functions
####################################################
def load_pipeline_node(node_config: dict, models_dir: str) -> dict:
"""
Recursively load a model node.
Expects node_config to have:
- modelId: a unique identifier
- modelFile: the .pt file in models_dir
- triggerClasses: list of class names that activate child branches
- crop: boolean; if True, we crop to the bounding box for the next model
- minConfidence: (optional) minimum confidence required to enter this branch
- branches: list of child node configurations
"""
model_path = os.path.join(models_dir, node_config["modelFile"])
if not os.path.exists(model_path):
logging.error(f"Model file {model_path} not found.")
raise FileNotFoundError(f"Model file {model_path} not found.")
logging.info(f"Loading model for node {node_config['modelId']} from {model_path}")
model = YOLO(model_path)
if torch.cuda.is_available():
model.to("cuda")
node = {
"modelId": node_config["modelId"],
"modelFile": node_config["modelFile"],
"triggerClasses": node_config.get("triggerClasses", []),
"crop": node_config.get("crop", False),
"minConfidence": node_config.get("minConfidence", None), # NEW FIELD
"model": model,
"branches": []
}
for child_config in node_config.get("branches", []):
child_node = load_pipeline_node(child_config, models_dir)
node["branches"].append(child_node)
return node
def load_pipeline_from_zip(zip_url: str, target_dir: str) -> dict:
"""
Download the .mpta file from zip_url, extract it to target_dir,
and load the pipeline configuration (pipeline.json).
Returns the model tree (root node) loaded with YOLO models.
"""
os.makedirs(target_dir, exist_ok=True)
zip_path = os.path.join(target_dir, "pipeline.mpta")
try:
response = requests.get(zip_url, stream=True)
if response.status_code == 200:
with open(zip_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logging.info(f"Downloaded .mpta file from {zip_url} to {zip_path}")
else:
logging.error(f"Failed to download .mpta file (status {response.status_code})")
return None
except Exception as e:
logging.error(f"Exception downloading .mpta file from {zip_url}: {e}")
return None
# Extract the .mpta file
try:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(target_dir)
logging.info(f"Extracted .mpta file to {target_dir}")
except Exception as e:
logging.error(f"Failed to extract .mpta file: {e}")
return None
finally:
if os.path.exists(zip_path):
os.remove(zip_path)
# Load pipeline.json
pipeline_json_path = os.path.join(target_dir, "pipeline.json")
if not os.path.exists(pipeline_json_path):
logging.error("pipeline.json not found in the .mpta file")
return None
try:
with open(pipeline_json_path, "r") as f:
pipeline_config = json.load(f)
# Build the model tree recursively
model_tree = load_pipeline_node(pipeline_config["pipeline"], target_dir)
return model_tree
except Exception as e:
logging.error(f"Error loading pipeline.json: {e}")
return None
####################################################
# Model execution function
####################################################
def run_pipeline(frame, node: dict):
"""
Run the model at the current node.
- Select the highest-confidence detection (if any).
- If 'crop' is True, crop to the bounding box for the next stage.
- If the detected class matches a branch's triggerClasses, check the confidence.
If the detection's confidence is below branch["minConfidence"] (if specified),
do not enter the branch and return the current detection.
Returns the final detection result (dict) or None.
"""
try:
results = node["model"].track(frame, stream=False, persist=True)
detection = None
max_conf = -1
best_box = None
for r in results:
for box in r.boxes:
box_cpu = box.cpu()
conf = float(box_cpu.conf[0])
if conf > max_conf and hasattr(box, "id") and box.id is not None:
max_conf = conf
detection = {
"class": node["model"].names[int(box_cpu.cls[0])],
"confidence": conf,
"id": box.id.item(),
}
best_box = box_cpu
# If there's a detection and crop is True, crop frame to bounding box
if detection and node.get("crop", False) and best_box is not None:
coords = best_box.xyxy[0] # [x1, y1, x2, y2]
x1, y1, x2, y2 = map(int, coords)
h, w = frame.shape[:2]
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(w, x2)
y2 = min(h, y2)
if x2 > x1 and y2 > y1:
frame = frame[y1:y2, x1:x2] # crop the frame
if detection is not None:
# Check if any branch should be entered based on trigger classes
for branch in node["branches"]:
if detection["class"] in branch.get("triggerClasses", []):
# Check for a minimum confidence threshold for this branch
min_conf = branch.get("minConfidence")
if min_conf is not None and detection["confidence"] < min_conf:
logging.debug(
f"Detection confidence {detection['confidence']} below threshold "
f"{min_conf} for branch {branch['modelId']}. Ending pipeline at current node."
)
return detection
branch_detection = run_pipeline(frame, branch)
if branch_detection is not None:
return branch_detection
return detection
return None
except Exception as e:
logging.error(f"Error running pipeline on node {node.get('modelId')}: {e}")
return None
####################################################
# Detection and frame processing functions
####################################################
@app.websocket("/")
async def detect(websocket: WebSocket):
import asyncio
import time
logging.info("WebSocket connection accepted")
persistent_data_dict = {}
streams = {}
# This function is user-modifiable
# Save data you want to persist across frames in the persistent_data dictionary
async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
async def handle_detection(camera_id, stream, frame, websocket, model_tree, persistent_data):
try:
highest_conf_box = None
max_conf = -1
for r in model.track(frame, stream=False, persist=True):
for box in r.boxes:
box_cpu = box.cpu()
conf = float(box_cpu.conf[0])
if conf > max_conf and hasattr(box, "id") and box.id is not None:
max_conf = conf
highest_conf_box = {
"class": model.names[int(box_cpu.cls[0])],
"confidence": conf,
"id": box.id.item(),
}
# Broadcast to all subscribers of this URL
detection_result = run_pipeline(frame, model_tree)
detection_data = {
"type": "imageDetection",
"cameraIdentifier": camera_id,
"timestamp": time.time(),
"data": {
"detections": highest_conf_box if highest_conf_box else None,
"modelId": stream['modelId'],
"modelName": stream['modelName']
"detection": detection_result if detection_result else None,
"modelId": stream["modelId"],
"modelName": stream["modelName"]
}
}
logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
@ -100,7 +241,6 @@ async def detect(websocket: WebSocket):
return persistent_data
def frame_reader(camera_id, cap, buffer, stop_event):
import time
retries = 0
try:
while not stop_event.is_set():
@ -114,16 +254,17 @@ async def detect(websocket: WebSocket):
if retries > max_retries and max_retries != -1:
logging.error(f"Max retries reached for camera: {camera_id}")
break
# Re-open the VideoCapture
cap = cv2.VideoCapture(streams[camera_id]['rtsp_url'])
# Re-open
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
if not cap.isOpened():
logging.error(f"Failed to reopen RTSP stream for camera: {camera_id}")
continue
continue
retries = 0 # Reset on success
retries = 0
# Overwrite old frame if buffer is full
if not buffer.empty():
try:
buffer.get_nowait() # Discard the old frame
buffer.get_nowait()
except queue.Empty:
pass
buffer.put(frame)
@ -133,10 +274,9 @@ async def detect(websocket: WebSocket):
time.sleep(reconnect_interval)
retries += 1
if retries > max_retries and max_retries != -1:
logging.error(f"Max retries reached after OpenCV error for camera: {camera_id}")
logging.error(f"Max retries reached after OpenCV error for camera {camera_id}")
break
# Re-open the VideoCapture
cap = cv2.VideoCapture(streams[camera_id]['rtsp_url'])
cap = cv2.VideoCapture(streams[camera_id]["rtsp_url"])
if not cap.isOpened():
logging.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error")
continue
@ -148,26 +288,25 @@ async def detect(websocket: WebSocket):
logging.error(f"Error in frame_reader thread for camera {camera_id}: {e}")
async def process_streams():
global models
logging.info("Started processing streams")
persistent_data_dict = {}
try:
while True:
start_time = time.time()
# Round-robin processing
with streams_lock:
current_streams = list(streams.items())
for camera_id, stream in current_streams:
buffer = stream['buffer']
buffer = stream["buffer"]
if not buffer.empty():
frame = buffer.get()
with models_lock:
model = models.get(camera_id, {}).get(stream['modelId'])
key = (camera_id, stream['modelId'])
model_tree = models.get(camera_id, {}).get(stream["modelId"])
key = (camera_id, stream["modelId"])
persistent_data = persistent_data_dict.get(key, {})
updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
updated_persistent_data = await handle_detection(
camera_id, stream, frame, websocket, model_tree, persistent_data
)
persistent_data_dict[key] = updated_persistent_data
elapsed_time = (time.time() - start_time) * 1000 # in ms
elapsed_time = (time.time() - start_time) * 1000 # ms
sleep_time = max(poll_interval - elapsed_time, 0)
logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
await asyncio.sleep(sleep_time / 1000.0)
@ -182,8 +321,8 @@ async def detect(websocket: WebSocket):
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # MB
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # MB
else:
gpu_usage = None
gpu_memory_usage = None
@ -191,8 +330,8 @@ async def detect(websocket: WebSocket):
camera_connections = [
{
"cameraIdentifier": camera_id,
"modelId": stream['modelId'],
"modelName": stream['modelName'],
"modelId": stream["modelId"],
"modelName": stream["modelName"],
"online": True
}
for camera_id, stream in streams.items()
@ -214,12 +353,10 @@ async def detect(websocket: WebSocket):
break
async def on_message():
global models
while True:
try:
msg = await websocket.receive_text()
logging.debug(f"Received message: {msg}")
print(f"Received message: {msg}")
data = json.loads(msg)
msg_type = data.get("type")
@ -227,7 +364,7 @@ async def detect(websocket: WebSocket):
payload = data.get("payload", {})
camera_id = payload.get("cameraIdentifier")
rtsp_url = payload.get("rtspUrl")
model_url = payload.get("modelUrl")
model_url = payload.get("modelUrl") # ZIP file URL
modelId = payload.get("modelId")
modelName = payload.get("modelName")
@ -236,25 +373,16 @@ async def detect(websocket: WebSocket):
if camera_id not in models:
models[camera_id] = {}
if modelId not in models[camera_id]:
print(f"Downloading model from {model_url}")
parsed_url = urlparse(model_url)
filename = os.path.basename(parsed_url.path)
model_filename = os.path.join("models", filename)
# Download the model
response = requests.get(model_url, stream=True)
if response.status_code == 200:
with open(model_filename, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logging.info(f"Downloaded model from {model_url} to {model_filename}")
model = YOLO(model_filename)
if torch.cuda.is_available():
model.to('cuda')
models[camera_id][modelId] = model
logging.info(f"Loaded model {modelId} for camera {camera_id}")
else:
logging.error(f"Failed to download model from {model_url}")
logging.info(f"Downloading model from {model_url}")
extraction_dir = os.path.join("models", camera_id, str(modelId))
os.makedirs(extraction_dir, exist_ok=True)
model_tree = load_pipeline_from_zip(model_url, extraction_dir)
if model_tree is None:
logging.error("Failed to load model from ZIP file.")
continue
models[camera_id][modelId] = model_tree
logging.info(f"Loaded model {modelId} for camera {camera_id}")
if camera_id and rtsp_url:
with streams_lock:
if camera_id not in streams and len(streams) < max_streams:
@ -268,23 +396,25 @@ async def detect(websocket: WebSocket):
thread.daemon = True
thread.start()
streams[camera_id] = {
'cap': cap,
'buffer': buffer,
'thread': thread,
'rtsp_url': rtsp_url,
'stop_event': stop_event,
'modelId': modelId,
'modelName': modelName
"cap": cap,
"buffer": buffer,
"thread": thread,
"rtsp_url": rtsp_url,
"stop_event": stop_event,
"modelId": modelId,
"modelName": modelName
}
logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName} and URL {rtsp_url}")
logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName}, URL {rtsp_url}")
elif camera_id and camera_id in streams:
# If already subscribed, unsubscribe
stream = streams.pop(camera_id)
stream['cap'].release()
stream["cap"].release()
logging.info(f"Unsubscribed from camera {camera_id}")
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
with models_lock:
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
elif msg_type == "unsubscribe":
payload = data.get("payload", {})
camera_id = payload.get("cameraIdentifier")
@ -292,21 +422,19 @@ async def detect(websocket: WebSocket):
with streams_lock:
if camera_id and camera_id in streams:
stream = streams.pop(camera_id)
stream['stop_event'].set()
stream['thread'].join()
stream['cap'].release()
stream["stop_event"].set()
stream["thread"].join()
stream["cap"].release()
logging.info(f"Unsubscribed from camera {camera_id}")
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
with models_lock:
if camera_id in models:
del models[camera_id]
elif msg_type == "requestState":
# Handle state request
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2)
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2)
else:
gpu_usage = None
gpu_memory_usage = None
@ -314,8 +442,8 @@ async def detect(websocket: WebSocket):
camera_connections = [
{
"cameraIdentifier": camera_id,
"modelId": stream['modelId'],
"modelName": stream['modelName'],
"modelId": stream["modelId"],
"modelName": stream["modelName"],
"online": True
}
for camera_id, stream in streams.items()
@ -343,22 +471,25 @@ async def detect(websocket: WebSocket):
try:
await websocket.accept()
task = asyncio.create_task(process_streams())
stream_task = asyncio.create_task(process_streams())
heartbeat_task = asyncio.create_task(send_heartbeat())
message_task = asyncio.create_task(on_message())
await asyncio.gather(heartbeat_task, message_task)
except Exception as e:
logging.error(f"Error in detect websocket: {e}")
finally:
task.cancel()
await task
stream_task.cancel()
await stream_task
with streams_lock:
for camera_id, stream in streams.items():
stream['stop_event'].set()
stream['thread'].join()
stream['cap'].release()
stream['buffer'].queue.clear()
stream["stop_event"].set()
stream["thread"].join()
stream["cap"].release()
while not stream["buffer"].empty():
try:
stream["buffer"].get_nowait()
except queue.Empty:
pass
logging.info(f"Released camera {camera_id} and cleaned up resources")
streams.clear()
with models_lock:

366
app_single.py Normal file
View file

@ -0,0 +1,366 @@
from typing import List
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError
from ultralytics import YOLO
import torch
import cv2
import base64
import numpy as np
import json
import logging
import threading
import queue
import os
import requests
from urllib.parse import urlparse
import asyncio
import psutil
app = FastAPI()
models = {}
with open("config.json", "r") as f:
config = json.load(f)
poll_interval = config.get("poll_interval_ms", 100)
reconnect_interval = config.get("reconnect_interval_sec", 5)
TARGET_FPS = config.get("target_fps", 10)
poll_interval = 1000 / TARGET_FPS
logging.info(f"Poll interval: {poll_interval}ms")
max_streams = config.get("max_streams", 5)
max_retries = config.get("max_retries", 3)
# Configure logging
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[
logging.FileHandler("app.log"),
logging.StreamHandler()
]
)
# Ensure the models directory exists
os.makedirs("models", exist_ok=True)
# Add constants for heartbeat
HEARTBEAT_INTERVAL = 2 # seconds
WORKER_TIMEOUT_MS = 10000
# Add a lock for thread-safe operations on shared resources
streams_lock = threading.Lock()
models_lock = threading.Lock()
@app.websocket("/")
async def detect(websocket: WebSocket):
import asyncio
import time
logging.info("WebSocket connection accepted")
streams = {}
# This function is user-modifiable
# Save data you want to persist across frames in the persistent_data dictionary
async def handle_detection(camera_id, stream, frame, websocket, model: YOLO, persistent_data):
try:
highest_conf_box = None
max_conf = -1
for r in model.track(frame, stream=False, persist=True):
for box in r.boxes:
box_cpu = box.cpu()
conf = float(box_cpu.conf[0])
if conf > max_conf and hasattr(box, "id") and box.id is not None:
max_conf = conf
highest_conf_box = {
"class": model.names[int(box_cpu.cls[0])],
"confidence": conf,
"id": box.id.item(),
}
# Broadcast to all subscribers of this URL
detection_data = {
"type": "imageDetection",
"cameraIdentifier": camera_id,
"timestamp": time.time(),
"data": {
"detections": highest_conf_box if highest_conf_box else None,
"modelId": stream['modelId'],
"modelName": stream['modelName']
}
}
logging.debug(f"Sending detection data for camera {camera_id}: {detection_data}")
await websocket.send_json(detection_data)
return persistent_data
except Exception as e:
logging.error(f"Error in handle_detection for camera {camera_id}: {e}")
return persistent_data
def frame_reader(camera_id, cap, buffer, stop_event):
import time
retries = 0
try:
while not stop_event.is_set():
try:
ret, frame = cap.read()
if not ret:
logging.warning(f"Connection lost for camera: {camera_id}, retry {retries+1}/{max_retries}")
cap.release()
time.sleep(reconnect_interval)
retries += 1
if retries > max_retries and max_retries != -1:
logging.error(f"Max retries reached for camera: {camera_id}")
break
# Re-open the VideoCapture
cap = cv2.VideoCapture(streams[camera_id]['rtsp_url'])
if not cap.isOpened():
logging.error(f"Failed to reopen RTSP stream for camera: {camera_id}")
continue
continue
retries = 0 # Reset on success
if not buffer.empty():
try:
buffer.get_nowait() # Discard the old frame
except queue.Empty:
pass
buffer.put(frame)
except cv2.error as e:
logging.error(f"OpenCV error for camera {camera_id}: {e}")
cap.release()
time.sleep(reconnect_interval)
retries += 1
if retries > max_retries and max_retries != -1:
logging.error(f"Max retries reached after OpenCV error for camera: {camera_id}")
break
# Re-open the VideoCapture
cap = cv2.VideoCapture(streams[camera_id]['rtsp_url'])
if not cap.isOpened():
logging.error(f"Failed to reopen RTSP stream for camera {camera_id} after OpenCV error")
continue
except Exception as e:
logging.error(f"Unexpected error for camera {camera_id}: {e}")
cap.release()
break
except Exception as e:
logging.error(f"Error in frame_reader thread for camera {camera_id}: {e}")
async def process_streams():
global models
logging.info("Started processing streams")
persistent_data_dict = {}
try:
while True:
start_time = time.time()
# Round-robin processing
with streams_lock:
current_streams = list(streams.items())
for camera_id, stream in current_streams:
buffer = stream['buffer']
if not buffer.empty():
frame = buffer.get()
with models_lock:
model = models.get(camera_id, {}).get(stream['modelId'])
key = (camera_id, stream['modelId'])
persistent_data = persistent_data_dict.get(key, {})
updated_persistent_data = await handle_detection(camera_id, stream, frame, websocket, model, persistent_data)
persistent_data_dict[key] = updated_persistent_data
elapsed_time = (time.time() - start_time) * 1000 # in ms
sleep_time = max(poll_interval - elapsed_time, 0)
logging.debug(f"Elapsed time: {elapsed_time}ms, sleeping for: {sleep_time}ms")
await asyncio.sleep(sleep_time / 1000.0)
except asyncio.CancelledError:
logging.info("Stream processing task cancelled")
except Exception as e:
logging.error(f"Error in process_streams: {e}")
async def send_heartbeat():
while True:
try:
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB
else:
gpu_usage = None
gpu_memory_usage = None
camera_connections = [
{
"cameraIdentifier": camera_id,
"modelId": stream['modelId'],
"modelName": stream['modelName'],
"online": True
}
for camera_id, stream in streams.items()
]
state_report = {
"type": "stateReport",
"cpuUsage": cpu_usage,
"memoryUsage": memory_usage,
"gpuUsage": gpu_usage,
"gpuMemoryUsage": gpu_memory_usage,
"cameraConnections": camera_connections
}
await websocket.send_text(json.dumps(state_report))
logging.debug("Sent stateReport as heartbeat")
await asyncio.sleep(HEARTBEAT_INTERVAL)
except Exception as e:
logging.error(f"Error sending stateReport heartbeat: {e}")
break
async def on_message():
global models
while True:
try:
msg = await websocket.receive_text()
logging.debug(f"Received message: {msg}")
print(f"Received message: {msg}")
data = json.loads(msg)
msg_type = data.get("type")
if msg_type == "subscribe":
payload = data.get("payload", {})
camera_id = payload.get("cameraIdentifier")
rtsp_url = payload.get("rtspUrl")
model_url = payload.get("modelUrl")
modelId = payload.get("modelId")
modelName = payload.get("modelName")
if model_url:
with models_lock:
if camera_id not in models:
models[camera_id] = {}
if modelId not in models[camera_id]:
print(f"Downloading model from {model_url}")
parsed_url = urlparse(model_url)
filename = os.path.basename(parsed_url.path)
model_filename = os.path.join("models", filename)
# Download the model
response = requests.get(model_url, stream=True)
if response.status_code == 200:
with open(model_filename, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logging.info(f"Downloaded model from {model_url} to {model_filename}")
model = YOLO(model_filename)
if torch.cuda.is_available():
model.to('cuda')
models[camera_id][modelId] = model
logging.info(f"Loaded model {modelId} for camera {camera_id}")
else:
logging.error(f"Failed to download model from {model_url}")
continue
if camera_id and rtsp_url:
with streams_lock:
if camera_id not in streams and len(streams) < max_streams:
cap = cv2.VideoCapture(rtsp_url)
if not cap.isOpened():
logging.error(f"Failed to open RTSP stream for camera {camera_id}")
continue
buffer = queue.Queue(maxsize=1)
stop_event = threading.Event()
thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
thread.daemon = True
thread.start()
streams[camera_id] = {
'cap': cap,
'buffer': buffer,
'thread': thread,
'rtsp_url': rtsp_url,
'stop_event': stop_event,
'modelId': modelId,
'modelName': modelName
}
logging.info(f"Subscribed to camera {camera_id} with modelId {modelId}, modelName {modelName} and URL {rtsp_url}")
elif camera_id and camera_id in streams:
stream = streams.pop(camera_id)
stream['cap'].release()
logging.info(f"Unsubscribed from camera {camera_id}")
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
elif msg_type == "unsubscribe":
payload = data.get("payload", {})
camera_id = payload.get("cameraIdentifier")
logging.debug(f"Unsubscribing from camera {camera_id}")
with streams_lock:
if camera_id and camera_id in streams:
stream = streams.pop(camera_id)
stream['stop_event'].set()
stream['thread'].join()
stream['cap'].release()
logging.info(f"Unsubscribed from camera {camera_id}")
if camera_id in models and modelId in models[camera_id]:
del models[camera_id][modelId]
if not models[camera_id]:
del models[camera_id]
elif msg_type == "requestState":
# Handle state request
cpu_usage = psutil.cpu_percent()
memory_usage = psutil.virtual_memory().percent
if torch.cuda.is_available():
gpu_usage = torch.cuda.memory_allocated() / (1024 ** 2) # Convert to MB
gpu_memory_usage = torch.cuda.memory_reserved() / (1024 ** 2) # Convert to MB
else:
gpu_usage = None
gpu_memory_usage = None
camera_connections = [
{
"cameraIdentifier": camera_id,
"modelId": stream['modelId'],
"modelName": stream['modelName'],
"online": True
}
for camera_id, stream in streams.items()
]
state_report = {
"type": "stateReport",
"cpuUsage": cpu_usage,
"memoryUsage": memory_usage,
"gpuUsage": gpu_usage,
"gpuMemoryUsage": gpu_memory_usage,
"cameraConnections": camera_connections
}
await websocket.send_text(json.dumps(state_report))
else:
logging.error(f"Unknown message type: {msg_type}")
except json.JSONDecodeError:
logging.error("Received invalid JSON message")
except (WebSocketDisconnect, ConnectionClosedError) as e:
logging.warning(f"WebSocket disconnected: {e}")
break
except Exception as e:
logging.error(f"Error handling message: {e}")
break
try:
await websocket.accept()
task = asyncio.create_task(process_streams())
heartbeat_task = asyncio.create_task(send_heartbeat())
message_task = asyncio.create_task(on_message())
await asyncio.gather(heartbeat_task, message_task)
except Exception as e:
logging.error(f"Error in detect websocket: {e}")
finally:
task.cancel()
await task
with streams_lock:
for camera_id, stream in streams.items():
stream['stop_event'].set()
stream['thread'].join()
stream['cap'].release()
stream['buffer'].queue.clear()
logging.info(f"Released camera {camera_id} and cleaned up resources")
streams.clear()
with models_lock:
models.clear()
logging.info("WebSocket connection closed")