feat: optimize model declaration in RAM
All checks were successful
Build Worker Base and Application Images / check-base-changes (push) Successful in 8s
Build Worker Base and Application Images / build-base (push) Has been skipped
Build Worker Base and Application Images / build-docker (push) Successful in 2m47s
Build Worker Base and Application Images / deploy-stack (push) Successful in 10s

ziesorx committed 2025-09-01 18:36:39 +07:00
parent c715b26a2a
commit ac85caca39

4 changed files with 679 additions and 216 deletions

app.py (254 changes)

@@ -28,7 +28,9 @@ from websockets.exceptions import ConnectionClosedError
from ultralytics import YOLO

# Import shared pipeline functions
-from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline, cleanup_camera_stability
+from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline, cleanup_camera_stability, cleanup_pipeline_node
+from siwatsystem.model_registry import get_registry_status, cleanup_registry
+from siwatsystem.mpta_manager import get_or_download_mpta, release_mpta, get_mpta_manager_status, cleanup_mpta_manager

app = FastAPI()
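Note: the new imports introduce two sharing layers: a model registry that keeps a single copy of each model in RAM, and an MPTA manager that shares downloaded .mpta archives across cameras. The registry internals are not part of this diff; a minimal sketch of what get_registry_status() could sit on top of, assuming a ref-counted dict keyed by weights path:

# Hypothetical sketch of siwatsystem/model_registry.py (not shown in this diff).
import threading

from ultralytics import YOLO

_registry = {}            # model_path -> {"model": YOLO, "ref_count": int}
_registry_lock = threading.Lock()

def get_shared_model(model_path: str):
    """Load each weights file at most once; later callers share the instance."""
    with _registry_lock:
        entry = _registry.get(model_path)
        if entry is None:
            entry = {"model": YOLO(model_path), "ref_count": 0}
            _registry[model_path] = entry
        entry["ref_count"] += 1
        return entry["model"]

def get_registry_status() -> dict:
    """Field names here are assumptions; the real shape lives in model_registry.py."""
    with _registry_lock:
        return {
            "loaded_models": len(_registry),
            "total_references": sum(e["ref_count"] for e in _registry.values()),
        }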
@@ -444,30 +446,6 @@ streams_lock = threading.Lock()
models_lock = threading.Lock()
logger.debug("Initialized thread locks")

-# Add helper to download mpta ZIP file from a remote URL
-def download_mpta(url: str, dest_path: str) -> str:
-    try:
-        logger.info(f"Starting download of model from {url} to {dest_path}")
-        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
-        response = requests.get(url, stream=True)
-        if response.status_code == 200:
-            file_size = int(response.headers.get('content-length', 0))
-            logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")
-            downloaded = 0
-            with open(dest_path, "wb") as f:
-                for chunk in response.iter_content(chunk_size=8192):
-                    f.write(chunk)
-                    downloaded += len(chunk)
-                    if file_size > 0 and downloaded % (file_size // 10) < 8192:  # Log approximately every 10%
-                        logger.debug(f"Download progress: {downloaded/file_size*100:.1f}%")
-            logger.info(f"Successfully downloaded mpta file from {url} to {dest_path}")
-            return dest_path
-        else:
-            logger.error(f"Failed to download mpta file (status code {response.status_code}): {response.text}")
-            return None
-    except Exception as e:
-        logger.error(f"Exception downloading mpta file from {url}: {str(e)}", exc_info=True)
-        return None

# Add helper to fetch snapshot image from HTTP/HTTPS URL
def fetch_snapshot(url: str):
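Note: the deleted download_mpta helper is superseded by the shared MPTA manager imported above. A plausible sketch of the behavior behind get_or_download_mpta/release_mpta, assuming one shared extraction directory per modelId with per-camera reference tracking (all internals here are assumptions, not code from this commit):

# Hypothetical sketch of siwatsystem/mpta_manager.py (not shown in this diff).
import os
import threading

import requests

_mpta_cache = {}          # model_id -> {"path": str, "file": str, "cameras": set}
_mpta_lock = threading.Lock()

def get_or_download_mpta(model_id, url, camera_id):
    """Return (shared_extraction_path, local_mpta_file); download once per model_id."""
    with _mpta_lock:
        entry = _mpta_cache.get(model_id)
        if entry is None:
            path = os.path.join("models", "shared", str(model_id))
            os.makedirs(path, exist_ok=True)
            mpta_file = os.path.join(path, f"model_{model_id}.mpta")
            resp = requests.get(url, stream=True, timeout=300)
            if resp.status_code != 200:
                return None
            with open(mpta_file, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    f.write(chunk)
            entry = {"path": path, "file": mpta_file, "cameras": set()}
            _mpta_cache[model_id] = entry
        entry["cameras"].add(camera_id)       # track which cameras hold a reference
        return entry["path"], entry["file"]

def release_mpta(model_id, camera_id):
    """Drop one camera's reference; keep the archive cached for quick re-subscribes."""
    with _mpta_lock:
        entry = _mpta_cache.get(model_id)
        if entry:
            entry["cameras"].discard(camera_id)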
@@ -703,7 +681,9 @@ async def get_lpr_debug_info():
        },
        "thread_status": {
            "lpr_listener_alive": lpr_listener_thread.is_alive() if lpr_listener_thread else False,
-           "cleanup_timer_alive": cleanup_timer_thread.is_alive() if cleanup_timer_thread else False
+           "cleanup_timer_alive": cleanup_timer_thread.is_alive() if cleanup_timer_thread else False,
+           "model_registry": get_registry_status(),
+           "mpta_manager": get_mpta_manager_status()
        },
        "cached_detections_by_camera": list(cached_detections.keys())
    }
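Note: the exact payloads of get_registry_status() and get_mpta_manager_status() are not visible in this diff; under the registry sketch above, the debug endpoint's thread_status block might come back roughly as:

# Hypothetical example response fragment (all field values and the manager's
# field names are assumptions, shown only to illustrate the new monitoring data).
example_thread_status = {
    "lpr_listener_alive": True,
    "cleanup_timer_alive": True,
    "model_registry": {"loaded_models": 2, "total_references": 5},
    "mpta_manager": {"cached_mptas": 2, "cameras_attached": 5},
}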
@@ -1715,32 +1695,24 @@ async def detect(websocket: WebSocket):
    display_identifier, camera_identifier = parts
    camera_id = subscriptionIdentifier

-   # Load model if needed
+   # Load model if needed using shared MPTA manager
    if model_url:
        with models_lock:
            if (camera_id not in models) or (modelId not in models[camera_id]):
-               logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
-               extraction_dir = os.path.join("models", camera_identifier, str(modelId))
-               os.makedirs(extraction_dir, exist_ok=True)
+               logger.info(f"Getting shared MPTA for camera {camera_id}, modelId {modelId}")
-               # Handle model loading (same as original)
-               parsed = urlparse(model_url)
-               if parsed.scheme in ("http", "https"):
-                   filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
-                   local_mpta = os.path.join(extraction_dir, filename)
-                   local_path = download_mpta(model_url, local_mpta)
-                   if not local_path:
-                       logger.error(f"Failed to download model from {model_url}")
-                       return
-                   model_tree = load_pipeline_from_zip(local_path, extraction_dir)
-               else:
-                   if not os.path.exists(model_url):
-                       logger.error(f"Model file not found: {model_url}")
-                       return
-                   model_tree = load_pipeline_from_zip(model_url, extraction_dir)
+               # Use shared MPTA manager for optimized downloads
+               mpta_result = get_or_download_mpta(modelId, model_url, camera_id)
+               if not mpta_result:
+                   logger.error(f"Failed to get/download MPTA for modelId {modelId}")
+                   return
+               shared_extraction_path, local_mpta_file = mpta_result
+               # Load pipeline from local MPTA file
+               model_tree = load_pipeline_from_zip(local_mpta_file, shared_extraction_path)
                if model_tree is None:
-                   logger.error(f"Failed to load model {modelId}")
+                   logger.error(f"Failed to load model {modelId} from shared MPTA")
                    return

                if camera_id not in models:
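Note: with this change, a second subscription using the same modelId should hit the cache instead of re-downloading. A usage sketch under the manager semantics assumed above (the modelId, URL, and subscription identifiers are illustrative):

# Hypothetical: two cameras, one modelId -> one download, one shared extraction.
result_a = get_or_download_mpta(52, "https://example.com/models/52.mpta", "display-001;cam-001")
result_b = get_or_download_mpta(52, "https://example.com/models/52.mpta", "display-001;cam-002")
if result_a and result_b:
    path_a, file_a = result_a
    path_b, file_b = result_b
    assert path_a == path_b   # same shared extraction directory, downloaded once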
@@ -1857,6 +1829,18 @@ async def detect(websocket: WebSocket):
        stream = streams.pop(subscription_id)
        camera_url = subscription_to_camera.pop(subscription_id, None)

+       # Clean up model references for this camera
+       with models_lock:
+           if subscription_id in models:
+               camera_models = models[subscription_id]
+               for model_id, model_tree in camera_models.items():
+                   logger.info(f"🧹 Cleaning up model references for camera {subscription_id}, modelId {model_id}")
+                   # Release model registry references
+                   cleanup_pipeline_node(model_tree)
+                   # Release MPTA manager reference
+                   release_mpta(model_id, subscription_id)
+               del models[subscription_id]

        if camera_url and camera_url in camera_streams:
            shared_stream = camera_streams[camera_url]
            shared_stream["ref_count"] -= 1
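Note: cleanup_pipeline_node is new in this commit's imports, but its body is not shown; presumably it walks the pipeline tree releasing each node's registry reference. A sketch continuing the registry example above (the key names "modelFile" and "branches" are assumptions about the pympta node layout):

# Hypothetical counterpart to get_shared_model in the registry sketch above.
def release_shared_model(model_path: str) -> None:
    """Decrement the ref count; evict the model once no camera uses it."""
    with _registry_lock:
        entry = _registry.get(model_path)
        if entry:
            entry["ref_count"] -= 1
            if entry["ref_count"] <= 0:
                del _registry[model_path]   # let the YOLO instance be garbage-collected

def cleanup_pipeline_node(node) -> None:
    """Recursively drop registry references held by a pipeline tree."""
    if not isinstance(node, dict):
        return
    if node.get("modelFile"):
        release_shared_model(node["modelFile"])
    for branch in node.get("branches", []):
        cleanup_pipeline_node(branch)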
@@ -2015,169 +1999,6 @@ async def detect(websocket: WebSocket):
            })
            await reconcile_subscriptions(current_subs, websocket)

-       elif msg_type == "old_subscribe_logic_removed":
-           if model_url:
-               with models_lock:
-                   if (camera_id not in models) or (modelId not in models[camera_id]):
-                       logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
-                       extraction_dir = os.path.join("models", camera_identifier, str(modelId))
-                       os.makedirs(extraction_dir, exist_ok=True)
-                       # If model_url is remote, download it first.
-                       parsed = urlparse(model_url)
-                       if parsed.scheme in ("http", "https"):
-                           logger.info(f"Downloading remote .mpta file from {model_url}")
-                           filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
-                           local_mpta = os.path.join(extraction_dir, filename)
-                           logger.debug(f"Download destination: {local_mpta}")
-                           local_path = download_mpta(model_url, local_mpta)
-                           if not local_path:
-                               logger.error(f"Failed to download the remote .mpta file from {model_url}")
-                               error_response = {
-                                   "type": "error",
-                                   "subscriptionIdentifier": subscriptionIdentifier,
-                                   "error": f"Failed to download model from {model_url}"
-                               }
-                               ws_logger.info(f"TX -> {json.dumps(error_response, separators=(',', ':'))}")
-                               await websocket.send_json(error_response)
-                               continue
-                           model_tree = load_pipeline_from_zip(local_path, extraction_dir)
-                       else:
-                           logger.info(f"Loading local .mpta file from {model_url}")
-                           # Check if file exists before attempting to load
-                           if not os.path.exists(model_url):
-                               logger.error(f"Local .mpta file not found: {model_url}")
-                               logger.debug(f"Current working directory: {os.getcwd()}")
-                               error_response = {
-                                   "type": "error",
-                                   "subscriptionIdentifier": subscriptionIdentifier,
-                                   "error": f"Model file not found: {model_url}"
-                               }
-                               ws_logger.info(f"TX -> {json.dumps(error_response, separators=(',', ':'))}")
-                               await websocket.send_json(error_response)
-                               continue
-                           model_tree = load_pipeline_from_zip(model_url, extraction_dir)
-                       if model_tree is None:
-                           logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}")
-                           error_response = {
-                               "type": "error",
-                               "subscriptionIdentifier": subscriptionIdentifier,
-                               "error": f"Failed to load model {modelId}"
-                           }
-                           await websocket.send_json(error_response)
-                           continue
-                       if camera_id not in models:
-                           models[camera_id] = {}
-                       models[camera_id][modelId] = model_tree
-                       logger.info(f"Successfully loaded model {modelId} for camera {camera_id}")
-                       logger.debug(f"Model extraction directory: {extraction_dir}")
-                       # Start LPR integration threads after first model is loaded (only once)
-                       if not lpr_integration_started and hasattr(model_tree, 'get') and model_tree.get('redis_client'):
-                           try:
-                               start_lpr_integration()
-                               lpr_integration_started = True
-                               logger.info("🚀 LPR integration started after first model load")
-                           except Exception as e:
-                               logger.error(f"❌ Failed to start LPR integration: {e}")
-           if camera_id and (rtsp_url or snapshot_url):
-               with streams_lock:
-                   # Determine camera URL for shared stream management
-                   camera_url = snapshot_url if snapshot_url else rtsp_url
-                   if camera_id not in streams and len(streams) < max_streams:
-                       # Check if we already have a stream for this camera URL
-                       shared_stream = camera_streams.get(camera_url)
-                       if shared_stream:
-                           # Reuse existing stream
-                           logger.info(f"Reusing existing stream for camera URL: {camera_url}")
-                           buffer = shared_stream["buffer"]
-                           stop_event = shared_stream["stop_event"]
-                           thread = shared_stream["thread"]
-                           mode = shared_stream["mode"]
-                           # Increment reference count
-                           shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
-                       else:
-                           # Create new stream
-                           buffer = queue.Queue(maxsize=1)
-                           stop_event = threading.Event()
-                           if snapshot_url and snapshot_interval:
-                               logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
-                               thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
-                               thread.daemon = True
-                               thread.start()
-                               mode = "snapshot"
-                               # Store shared stream info
-                               shared_stream = {
-                                   "buffer": buffer,
-                                   "thread": thread,
-                                   "stop_event": stop_event,
-                                   "mode": mode,
-                                   "url": snapshot_url,
-                                   "snapshot_interval": snapshot_interval,
-                                   "ref_count": 1
-                               }
-                               camera_streams[camera_url] = shared_stream
-                           elif rtsp_url:
-                               logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}")
-                               cap = cv2.VideoCapture(rtsp_url)
-                               if not cap.isOpened():
-                                   logger.error(f"Failed to open RTSP stream for camera {camera_id}")
-                                   continue
-                               thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
-                               thread.daemon = True
-                               thread.start()
-                               mode = "rtsp"
-                               # Store shared stream info
-                               shared_stream = {
-                                   "buffer": buffer,
-                                   "thread": thread,
-                                   "stop_event": stop_event,
-                                   "mode": mode,
-                                   "url": rtsp_url,
-                                   "cap": cap,
-                                   "ref_count": 1
-                               }
-                               camera_streams[camera_url] = shared_stream
-                           else:
-                               logger.error(f"No valid URL provided for camera {camera_id}")
-                               continue
-                       # Create stream info for this subscription
-                       stream_info = {
-                           "buffer": buffer,
-                           "thread": thread,
-                           "stop_event": stop_event,
-                           "modelId": modelId,
-                           "modelName": modelName,
-                           "subscriptionIdentifier": subscriptionIdentifier,
-                           "cropX1": cropX1,
-                           "cropY1": cropY1,
-                           "cropX2": cropX2,
-                           "cropY2": cropY2,
-                           "mode": mode,
-                           "camera_url": camera_url
-                       }
-                       if mode == "snapshot":
-                           stream_info["snapshot_url"] = snapshot_url
-                           stream_info["snapshot_interval"] = snapshot_interval
-                       elif mode == "rtsp":
-                           stream_info["rtsp_url"] = rtsp_url
-                           stream_info["cap"] = shared_stream["cap"]
-                       streams[camera_id] = stream_info
-                       subscription_to_camera[camera_id] = camera_url
-                   elif camera_id and camera_id in streams:
-                       # If already subscribed, unsubscribe first
-                       logger.info(f"Resubscribing to camera {camera_id}")
-                       # Note: Keep models in memory for reuse across subscriptions

        elif msg_type == "unsubscribe":
            payload = data.get("payload", {})
            subscriptionIdentifier = payload.get("subscriptionIdentifier")
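Note: the deleted branch was dead code (its handler name says as much), duplicating per-camera model loading and shared-stream setup that now lives in reconcile_subscriptions. Its ref-counting pattern for camera streams is still the live design; roughly, as a standalone sketch (function names are illustrative, the real code stores more fields per stream):

# Minimal sketch of the shared-stream ref-counting pattern used in app.py.
def acquire_stream(camera_streams: dict, camera_url: str, make_stream):
    """Reuse an existing reader for camera_url or create one with ref_count=1."""
    shared = camera_streams.get(camera_url)
    if shared:
        shared["ref_count"] = shared.get("ref_count", 0) + 1
    else:
        shared = make_stream()                # builds buffer, thread, stop_event, mode
        shared["ref_count"] = 1
        camera_streams[camera_url] = shared
    return shared

def release_stream(camera_streams: dict, camera_url: str) -> None:
    """Drop one subscriber; stop the reader thread when nobody is left."""
    shared = camera_streams.get(camera_url)
    if shared:
        shared["ref_count"] -= 1
        if shared["ref_count"] <= 0:
            shared["stop_event"].set()        # signal the reader thread to exit
            camera_streams.pop(camera_url)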
@@ -2473,7 +2294,22 @@ async def detect(websocket: WebSocket):
        camera_streams.clear()
        subscription_to_camera.clear()

        with models_lock:
+           # Clean up all model references before clearing models dict
+           for camera_id, camera_models in models.items():
+               for model_id, model_tree in camera_models.items():
+                   logger.info(f"🧹 Shutdown cleanup: Releasing model {model_id} for camera {camera_id}")
+                   # Release model registry references
+                   cleanup_pipeline_node(model_tree)
+                   # Release MPTA manager reference
+                   release_mpta(model_id, camera_id)
            models.clear()

+           # Clean up the entire model registry and MPTA manager
+           # logger.info("🏭 Performing final model registry cleanup...")
+           # cleanup_registry()
+           # logger.info("🏭 Performing final MPTA manager cleanup...")
+           # cleanup_mpta_manager()

        latest_frames.clear()
        cached_detections.clear()
        frame_skip_flags.clear()
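Note: the final registry/MPTA cleanup is left commented out, presumably so cached models survive websocket reconnects and can be reused by the next session. If it were wired up, one plausible place is the app shutdown hook (the placement is an assumption; the diff does not show where global cleanup would run):

# Hypothetical shutdown hook if the global cleanup were enabled.
@app.on_event("shutdown")
async def _final_cleanup():
    cleanup_registry()        # drop every cached model from RAM
    cleanup_mpta_manager()    # remove shared .mpta downloads/extractions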