diff --git a/app.py b/app.py
index 4c15324..3a1c84a 100644
--- a/app.py
+++ b/app.py
@@ -28,7 +28,9 @@ from websockets.exceptions import ConnectionClosedError
 from ultralytics import YOLO
 
 # Import shared pipeline functions
-from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline, cleanup_camera_stability
+from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline, cleanup_camera_stability, cleanup_pipeline_node
+from siwatsystem.model_registry import get_registry_status, cleanup_registry
+from siwatsystem.mpta_manager import get_or_download_mpta, release_mpta, get_mpta_manager_status, cleanup_mpta_manager
 
 app = FastAPI()
 
@@ -444,30 +446,6 @@ streams_lock = threading.Lock()
 models_lock = threading.Lock()
 logger.debug("Initialized thread locks")
 
-# Add helper to download mpta ZIP file from a remote URL
-def download_mpta(url: str, dest_path: str) -> str:
-    try:
-        logger.info(f"Starting download of model from {url} to {dest_path}")
-        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
-        response = requests.get(url, stream=True)
-        if response.status_code == 200:
-            file_size = int(response.headers.get('content-length', 0))
-            logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")
-            downloaded = 0
-            with open(dest_path, "wb") as f:
-                for chunk in response.iter_content(chunk_size=8192):
-                    f.write(chunk)
-                    downloaded += len(chunk)
-                    if file_size > 0 and downloaded % (file_size // 10) < 8192:  # Log approximately every 10%
-                        logger.debug(f"Download progress: {downloaded/file_size*100:.1f}%")
-            logger.info(f"Successfully downloaded mpta file from {url} to {dest_path}")
-            return dest_path
-        else:
-            logger.error(f"Failed to download mpta file (status code {response.status_code}): {response.text}")
-            return None
-    except Exception as e:
-        logger.error(f"Exception downloading mpta file from {url}: {str(e)}", exc_info=True)
-        return None
 
 # Add helper to fetch snapshot image from HTTP/HTTPS URL
 def fetch_snapshot(url: str):
@@ -703,7 +681,9 @@ async def get_lpr_debug_info():
         },
         "thread_status": {
             "lpr_listener_alive": lpr_listener_thread.is_alive() if lpr_listener_thread else False,
-            "cleanup_timer_alive": cleanup_timer_thread.is_alive() if cleanup_timer_thread else False
+            "cleanup_timer_alive": cleanup_timer_thread.is_alive() if cleanup_timer_thread else False,
+            "model_registry": get_registry_status(),
+            "mpta_manager": get_mpta_manager_status()
         },
         "cached_detections_by_camera": list(cached_detections.keys())
     }
@@ -1715,32 +1695,24 @@ async def detect(websocket: WebSocket):
                     display_identifier, camera_identifier = parts
                     camera_id = subscriptionIdentifier
 
-                    # Load model if needed
+                    # Load model if needed using shared MPTA manager
                     if model_url:
                         with models_lock:
                             if (camera_id not in models) or (modelId not in models[camera_id]):
-                                logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
-                                extraction_dir = os.path.join("models", camera_identifier, str(modelId))
-                                os.makedirs(extraction_dir, exist_ok=True)
+                                logger.info(f"Getting shared MPTA for camera {camera_id}, modelId {modelId}")
 
-                                # Handle model loading (same as original)
-                                parsed = urlparse(model_url)
-                                if parsed.scheme in ("http", "https"):
-                                    filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
-                                    local_mpta = os.path.join(extraction_dir, filename)
-                                    local_path = download_mpta(model_url, local_mpta)
-                                    if not local_path:
-                                        logger.error(f"Failed to download model from {model_url}")
-                                        return
-                                    model_tree = load_pipeline_from_zip(local_path, extraction_dir)
-                                else:
-                                    if not os.path.exists(model_url):
-                                        logger.error(f"Model file not found: {model_url}")
-                                        return
-                                    model_tree = load_pipeline_from_zip(model_url, extraction_dir)
+                                # Use shared MPTA manager for optimized downloads
+                                mpta_result = get_or_download_mpta(modelId, model_url, camera_id)
+                                if not mpta_result:
+                                    logger.error(f"Failed to get/download MPTA for modelId {modelId}")
+                                    return
+                                shared_extraction_path, local_mpta_file = mpta_result
+
+                                # Load pipeline from local MPTA file
+                                model_tree = load_pipeline_from_zip(local_mpta_file, shared_extraction_path)
 
                                 if model_tree is None:
-                                    logger.error(f"Failed to load model {modelId}")
+                                    logger.error(f"Failed to load model {modelId} from shared MPTA")
                                     return
 
                                 if camera_id not in models:
@@ -1857,6 +1829,18 @@ async def detect(websocket: WebSocket):
                 stream = streams.pop(subscription_id)
                 camera_url = subscription_to_camera.pop(subscription_id, None)
 
+                # Clean up model references for this camera
+                with models_lock:
+                    if subscription_id in models:
+                        camera_models = models[subscription_id]
+                        for model_id, model_tree in camera_models.items():
+                            logger.info(f"🧹 Cleaning up model references for camera {subscription_id}, modelId {model_id}")
+                            # Release model registry references
+                            cleanup_pipeline_node(model_tree)
+                            # Release MPTA manager reference
+                            release_mpta(model_id, subscription_id)
+                        del models[subscription_id]
+
                 if camera_url and camera_url in camera_streams:
                     shared_stream = camera_streams[camera_url]
                     shared_stream["ref_count"] -= 1
@@ -2015,169 +1999,6 @@ async def detect(websocket: WebSocket):
                     })
                     await reconcile_subscriptions(current_subs, websocket)
 
-            elif msg_type == "old_subscribe_logic_removed":
-                if model_url:
-                    with models_lock:
-                        if (camera_id not in models) or (modelId not in models[camera_id]):
-                            logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
-                            extraction_dir = os.path.join("models", camera_identifier, str(modelId))
-                            os.makedirs(extraction_dir, exist_ok=True)
-                            # If model_url is remote, download it first.
-                            parsed = urlparse(model_url)
-                            if parsed.scheme in ("http", "https"):
-                                logger.info(f"Downloading remote .mpta file from {model_url}")
-                                filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
-                                local_mpta = os.path.join(extraction_dir, filename)
-                                logger.debug(f"Download destination: {local_mpta}")
-                                local_path = download_mpta(model_url, local_mpta)
-                                if not local_path:
-                                    logger.error(f"Failed to download the remote .mpta file from {model_url}")
-                                    error_response = {
-                                        "type": "error",
-                                        "subscriptionIdentifier": subscriptionIdentifier,
-                                        "error": f"Failed to download model from {model_url}"
-                                    }
-                                    ws_logger.info(f"TX -> {json.dumps(error_response, separators=(',', ':'))}")
-                                    await websocket.send_json(error_response)
-                                    continue
-                                model_tree = load_pipeline_from_zip(local_path, extraction_dir)
-                            else:
-                                logger.info(f"Loading local .mpta file from {model_url}")
-                                # Check if file exists before attempting to load
-                                if not os.path.exists(model_url):
-                                    logger.error(f"Local .mpta file not found: {model_url}")
-                                    logger.debug(f"Current working directory: {os.getcwd()}")
-                                    error_response = {
-                                        "type": "error",
-                                        "subscriptionIdentifier": subscriptionIdentifier,
-                                        "error": f"Model file not found: {model_url}"
-                                    }
-                                    ws_logger.info(f"TX -> {json.dumps(error_response, separators=(',', ':'))}")
-                                    await websocket.send_json(error_response)
-                                    continue
-                                model_tree = load_pipeline_from_zip(model_url, extraction_dir)
-                            if model_tree is None:
-                                logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}")
-                                error_response = {
-                                    "type": "error",
-                                    "subscriptionIdentifier": subscriptionIdentifier,
-                                    "error": f"Failed to load model {modelId}"
-                                }
-                                await websocket.send_json(error_response)
-                                continue
-                            if camera_id not in models:
-                                models[camera_id] = {}
-                            models[camera_id][modelId] = model_tree
-                            logger.info(f"Successfully loaded model {modelId} for camera {camera_id}")
-                            logger.debug(f"Model extraction directory: {extraction_dir}")
-
-                            # Start LPR integration threads after first model is loaded (only once)
-                            if not lpr_integration_started and hasattr(model_tree, 'get') and model_tree.get('redis_client'):
-                                try:
-                                    start_lpr_integration()
-                                    lpr_integration_started = True
-                                    logger.info("🚀 LPR integration started after first model load")
-                                except Exception as e:
-                                    logger.error(f"❌ Failed to start LPR integration: {e}")
-                if camera_id and (rtsp_url or snapshot_url):
-                    with streams_lock:
-                        # Determine camera URL for shared stream management
-                        camera_url = snapshot_url if snapshot_url else rtsp_url
-
-                        if camera_id not in streams and len(streams) < max_streams:
-                            # Check if we already have a stream for this camera URL
-                            shared_stream = camera_streams.get(camera_url)
-
-                            if shared_stream:
-                                # Reuse existing stream
-                                logger.info(f"Reusing existing stream for camera URL: {camera_url}")
-                                buffer = shared_stream["buffer"]
-                                stop_event = shared_stream["stop_event"]
-                                thread = shared_stream["thread"]
-                                mode = shared_stream["mode"]
-
-                                # Increment reference count
-                                shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
-                            else:
-                                # Create new stream
-                                buffer = queue.Queue(maxsize=1)
-                                stop_event = threading.Event()
-
-                                if snapshot_url and snapshot_interval:
-                                    logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
-                                    thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
-                                    thread.daemon = True
-                                    thread.start()
-                                    mode = "snapshot"
-
-                                    # Store shared stream info
-                                    shared_stream = {
-                                        "buffer": buffer,
-                                        "thread": thread,
-                                        "stop_event": stop_event,
-                                        "mode": mode,
-                                        "url": snapshot_url,
-                                        "snapshot_interval": snapshot_interval,
-                                        "ref_count": 1
-                                    }
-                                    camera_streams[camera_url] = shared_stream
-
-                                elif rtsp_url:
-                                    logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}")
-                                    cap = cv2.VideoCapture(rtsp_url)
-                                    if not cap.isOpened():
-                                        logger.error(f"Failed to open RTSP stream for camera {camera_id}")
-                                        continue
-                                    thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
-                                    thread.daemon = True
-                                    thread.start()
-                                    mode = "rtsp"
-
-                                    # Store shared stream info
-                                    shared_stream = {
-                                        "buffer": buffer,
-                                        "thread": thread,
-                                        "stop_event": stop_event,
-                                        "mode": mode,
-                                        "url": rtsp_url,
-                                        "cap": cap,
-                                        "ref_count": 1
-                                    }
-                                    camera_streams[camera_url] = shared_stream
-                                else:
-                                    logger.error(f"No valid URL provided for camera {camera_id}")
-                                    continue
-
-                            # Create stream info for this subscription
-                            stream_info = {
-                                "buffer": buffer,
-                                "thread": thread,
-                                "stop_event": stop_event,
-                                "modelId": modelId,
-                                "modelName": modelName,
-                                "subscriptionIdentifier": subscriptionIdentifier,
-                                "cropX1": cropX1,
-                                "cropY1": cropY1,
-                                "cropX2": cropX2,
-                                "cropY2": cropY2,
-                                "mode": mode,
-                                "camera_url": camera_url
-                            }
-
-                            if mode == "snapshot":
-                                stream_info["snapshot_url"] = snapshot_url
-                                stream_info["snapshot_interval"] = snapshot_interval
-                            elif mode == "rtsp":
-                                stream_info["rtsp_url"] = rtsp_url
-                                stream_info["cap"] = shared_stream["cap"]
-
-                            streams[camera_id] = stream_info
-                            subscription_to_camera[camera_id] = camera_url
-
-                        elif camera_id and camera_id in streams:
-                            # If already subscribed, unsubscribe first
-                            logger.info(f"Resubscribing to camera {camera_id}")
-                            # Note: Keep models in memory for reuse across subscriptions
             elif msg_type == "unsubscribe":
                 payload = data.get("payload", {})
                 subscriptionIdentifier = payload.get("subscriptionIdentifier")
@@ -2473,7 +2294,22 @@ async def detect(websocket: WebSocket):
             camera_streams.clear()
             subscription_to_camera.clear()
         with models_lock:
+            # Clean up all model references before clearing models dict
+            for camera_id, camera_models in models.items():
+                for model_id, model_tree in camera_models.items():
+                    logger.info(f"🧹 Shutdown cleanup: Releasing model {model_id} for camera {camera_id}")
+                    # Release model registry references
+                    cleanup_pipeline_node(model_tree)
+                    # Release MPTA manager reference
+                    release_mpta(model_id, camera_id)
            models.clear()
+
+        # Clean up the entire model registry and MPTA manager
+        # logger.info("🏭 Performing final model registry cleanup...")
+        # cleanup_registry()
+        # logger.info("🏭 Performing final MPTA manager cleanup...")
+        # cleanup_mpta_manager()
+
         latest_frames.clear()
         cached_detections.clear()
         frame_skip_flags.clear()
diff --git a/siwatsystem/model_registry.py b/siwatsystem/model_registry.py
new file mode 100644
index 0000000..95daf3b
--- /dev/null
+++ b/siwatsystem/model_registry.py
@@ -0,0 +1,242 @@
+"""
+Shared Model Registry for Memory Optimization
+
+This module implements a global shared model registry to prevent duplicate model loading
+in memory when multiple cameras use the same model. This significantly reduces RAM and
+GPU VRAM usage by ensuring only one instance of each unique model is loaded.
+
+Key Features:
+- Thread-safe model loading and access
+- Reference counting for proper cleanup
+- Automatic model lifecycle management
+- Maintains compatibility with existing pipeline system
+"""
+
+import os
+import threading
+import logging
+from typing import Dict, Any, Optional, Set
+import torch
+from ultralytics import YOLO
+
+# Create a logger for this module
+logger = logging.getLogger("detector_worker.model_registry")
+
+class ModelRegistry:
+    """
+    Singleton class for managing shared YOLO models across multiple cameras.
+
+    This registry ensures that each unique model is loaded only once in memory,
+    dramatically reducing RAM and GPU VRAM usage when multiple cameras use the
+    same model.
+    """
+
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super(ModelRegistry, cls).__new__(cls)
+                    cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+
+        self._initialized = True
+
+        # Thread-safe storage for loaded models
+        self._models: Dict[str, YOLO] = {}           # modelId -> YOLO model instance
+        self._model_files: Dict[str, str] = {}       # modelId -> file path
+        self._reference_counts: Dict[str, int] = {}  # modelId -> reference count
+        self._model_lock = threading.RLock()         # Reentrant lock for nested calls
+
+        logger.info("🏭 Shared Model Registry initialized - ready for memory-optimized model loading")
+
+    def get_model(self, model_id: str, model_file_path: str) -> YOLO:
+        """
+        Get or load a YOLO model. Returns shared instance if already loaded.
+
+        Args:
+            model_id: Unique identifier for the model
+            model_file_path: Path to the model file
+
+        Returns:
+            YOLO model instance (shared across all callers)
+        """
+        with self._model_lock:
+            if model_id in self._models:
+                # Model already loaded - increment reference count and return
+                self._reference_counts[model_id] += 1
+                logger.info(f"📖 Model '{model_id}' reused (ref_count: {self._reference_counts[model_id]}) - SAVED MEMORY!")
+                return self._models[model_id]
+
+            # Model not loaded yet - load it
+            logger.info(f"🔄 Loading NEW model '{model_id}' from {model_file_path}")
+
+            if not os.path.exists(model_file_path):
+                raise FileNotFoundError(f"Model file {model_file_path} not found")
+
+            try:
+                # Load the YOLO model
+                model = YOLO(model_file_path)
+
+                # Move to GPU if available
+                if torch.cuda.is_available():
+                    logger.info(f"🚀 CUDA available. Moving model '{model_id}' to GPU VRAM")
+                    model.to("cuda")
+                else:
+                    logger.info(f"💻 CUDA not available. Using CPU for model '{model_id}'")
+
+                # Store in registry
+                self._models[model_id] = model
+                self._model_files[model_id] = model_file_path
+                self._reference_counts[model_id] = 1
+
+                logger.info(f"✅ Model '{model_id}' loaded and registered (ref_count: 1)")
+                self._log_registry_status()
+
+                return model
+
+            except Exception as e:
+                logger.error(f"❌ Failed to load model '{model_id}' from {model_file_path}: {e}")
+                raise
+
+    def release_model(self, model_id: str) -> None:
+        """
+        Release a reference to a model. If reference count reaches zero,
+        the model may be unloaded to free memory.
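+
+        Illustrative call (the id is hypothetical; pairs with an earlier
+        get_model() on the same id):
+            ModelRegistry().release_model("detector-v1")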
+
+        Args:
+            model_id: Unique identifier for the model to release
+        """
+        with self._model_lock:
+            if model_id not in self._reference_counts:
+                logger.warning(f"⚠️ Attempted to release unknown model '{model_id}'")
+                return
+
+            self._reference_counts[model_id] -= 1
+            logger.info(f"📉 Model '{model_id}' reference count decreased to {self._reference_counts[model_id]}")
+
+            # For now, keep models in memory even when ref count reaches 0
+            # This prevents reload overhead if the same model is needed again soon
+            # In the future, we could implement LRU eviction policy
+            # if self._reference_counts[model_id] <= 0:
+            #     logger.info(f"💤 Model '{model_id}' has 0 references but keeping in memory for reuse")
+            #     # Optionally: self._unload_model(model_id)
+
+    def _unload_model(self, model_id: str) -> None:
+        """
+        Internal method to unload a model from memory.
+        Currently not used to prevent reload overhead.
+        """
+        with self._model_lock:
+            if model_id in self._models:
+                logger.info(f"🗑️ Unloading model '{model_id}' from memory")
+
+                # Clear GPU memory if model was on GPU
+                model = self._models[model_id]
+                if hasattr(model, 'model') and hasattr(model.model, 'cuda'):
+                    try:
+                        # Move model to CPU before deletion to free GPU memory
+                        model.to('cpu')
+                    except Exception as e:
+                        logger.warning(f"⚠️ Failed to move model '{model_id}' to CPU: {e}")
+
+                # Remove from registry
+                del self._models[model_id]
+                del self._model_files[model_id]
+                del self._reference_counts[model_id]
+
+                # Force garbage collection
+                import gc
+                gc.collect()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+                logger.info(f"✅ Model '{model_id}' unloaded and memory freed")
+                self._log_registry_status()
+
+    def get_registry_status(self) -> Dict[str, Any]:
+        """
+        Get current status of the model registry.
+
+        Returns:
+            Dictionary with registry statistics
+        """
+        with self._model_lock:
+            return {
+                "total_models": len(self._models),
+                "models": {
+                    model_id: {
+                        "file_path": self._model_files[model_id],
+                        "reference_count": self._reference_counts[model_id]
+                    }
+                    for model_id in self._models
+                },
+                "total_references": sum(self._reference_counts.values())
+            }
+
+    def _log_registry_status(self) -> None:
+        """Log current registry status for debugging."""
+        status = self.get_registry_status()
+        logger.info(f"📊 Model Registry Status: {status['total_models']} unique models, {status['total_references']} total references")
+        for model_id, info in status['models'].items():
+            logger.debug(f"  📋 '{model_id}': refs={info['reference_count']}, file={os.path.basename(info['file_path'])}")
+
+    def cleanup_all(self) -> None:
+        """
+        Clean up all models from the registry. Used during shutdown.
+        """
+        with self._model_lock:
+            model_ids = list(self._models.keys())
+            logger.info(f"🧹 Cleaning up {len(model_ids)} models from registry")
+
+            for model_id in model_ids:
+                self._unload_model(model_id)
+
+            logger.info("✅ Model registry cleanup complete")
+
+
+# Global singleton instance
+_registry = ModelRegistry()
+
+def get_shared_model(model_id: str, model_file_path: str) -> YOLO:
+    """
+    Convenience function to get a shared model instance.
+
+    Args:
+        model_id: Unique identifier for the model
+        model_file_path: Path to the model file
+
+    Returns:
+        YOLO model instance (shared across all callers)
+    """
+    return _registry.get_model(model_id, model_file_path)
+
+def release_shared_model(model_id: str) -> None:
+    """
+    Convenience function to release a shared model reference.
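+
+    Example (illustrative id and path; the release pairs with an earlier acquisition):
+        model = get_shared_model("detector-v1", "models/8/detector.pt")
+        ...
+        release_shared_model("detector-v1")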
+
+    Args:
+        model_id: Unique identifier for the model to release
+    """
+    _registry.release_model(model_id)
+
+def get_registry_status() -> Dict[str, Any]:
+    """
+    Convenience function to get registry status.
+
+    Returns:
+        Dictionary with registry statistics
+    """
+    return _registry.get_registry_status()
+
+def cleanup_registry() -> None:
+    """
+    Convenience function to cleanup the entire registry.
+    """
+    _registry.cleanup_all()
\ No newline at end of file
diff --git a/siwatsystem/mpta_manager.py b/siwatsystem/mpta_manager.py
new file mode 100644
index 0000000..1abda3f
--- /dev/null
+++ b/siwatsystem/mpta_manager.py
@@ -0,0 +1,375 @@
+"""
+Shared MPTA Manager for Disk Space Optimization
+
+This module implements shared MPTA file management to prevent duplicate downloads
+and extractions when multiple cameras use the same model. MPTA files are stored
+in modelId-based directories and shared across all cameras using that model.
+
+Key Features:
+- Thread-safe MPTA downloading and extraction
+- ModelId-based directory structure: models/{modelId}/
+- Reference counting for proper cleanup
+- Eliminates duplicate MPTA downloads
+- Maintains compatibility with existing pipeline system
+"""
+
+import os
+import threading
+import logging
+import shutil
+import requests
+from typing import Dict, Set, Optional
+from urllib.parse import urlparse
+from .pympta import load_pipeline_from_zip
+
+# Create a logger for this module
+logger = logging.getLogger("detector_worker.mpta_manager")
+
+class MPTAManager:
+    """
+    Singleton class for managing shared MPTA files across multiple cameras.
+
+    This manager ensures that each unique modelId is downloaded and extracted
+    only once, dramatically reducing disk usage and download time when multiple
+    cameras use the same model.
+    """
+
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super(MPTAManager, cls).__new__(cls)
+                    cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+
+        self._initialized = True
+
+        # Thread-safe storage for MPTA management
+        self._model_paths: Dict[int, str] = {}                # modelId -> shared_extraction_path
+        self._mpta_file_paths: Dict[int, str] = {}            # modelId -> local_mpta_file_path
+        self._reference_counts: Dict[int, int] = {}           # modelId -> reference count
+        self._download_locks: Dict[int, threading.Lock] = {}  # modelId -> download lock
+        self._cameras_using_model: Dict[int, Set[str]] = {}   # modelId -> set of camera_ids
+        self._manager_lock = threading.RLock()                # Reentrant lock for nested calls
+
+        logger.info("🏭 Shared MPTA Manager initialized - ready for disk-optimized MPTA management")
+
+    def get_or_download_mpta(self, model_id: int, model_url: str, camera_id: str) -> Optional[tuple[str, str]]:
+        """
+        Get or download an MPTA file. Returns (extraction_path, mpta_file_path) if successful.
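+
+        Example (URL and ids are illustrative; the paths follow the
+        models/{modelId}/ scheme used throughout this class):
+            result = MPTAManager().get_or_download_mpta(8, "https://example.com/model_8.mpta", "cam-01")
+            if result:
+                extraction_path, mpta_file = result  # e.g. ("models/8", "models/8/model_8.mpta")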
+
+        Args:
+            model_id: Unique identifier for the model
+            model_url: URL to download the MPTA file from
+            camera_id: Identifier for the requesting camera
+
+        Returns:
+            Tuple of (extraction_path, mpta_file_path), or None if failed
+        """
+        with self._manager_lock:
+            # Track camera usage
+            if model_id not in self._cameras_using_model:
+                self._cameras_using_model[model_id] = set()
+            self._cameras_using_model[model_id].add(camera_id)
+
+            # Check if model directory already exists on disk (from previous sessions)
+            if model_id not in self._model_paths:
+                potential_path = f"models/{model_id}"
+                if os.path.exists(potential_path) and os.path.isdir(potential_path):
+                    # Directory exists from previous session, find the MPTA file
+                    mpta_files = [f for f in os.listdir(potential_path) if f.endswith('.mpta')]
+                    if mpta_files:
+                        # Use the first .mpta file found
+                        mpta_file_path = os.path.join(potential_path, mpta_files[0])
+                        self._model_paths[model_id] = potential_path
+                        self._mpta_file_paths[model_id] = mpta_file_path
+                        self._reference_counts[model_id] = 0  # Will be incremented below
+                        logger.info(f"📂 Found existing MPTA modelId {model_id} from previous session")
+
+            # Check if already available
+            if model_id in self._model_paths:
+                shared_path = self._model_paths[model_id]
+                mpta_file_path = self._mpta_file_paths.get(model_id)
+                if os.path.exists(shared_path) and mpta_file_path and os.path.exists(mpta_file_path):
+                    self._reference_counts[model_id] += 1
+                    logger.info(f"📂 MPTA modelId {model_id} reused for camera {camera_id} (ref_count: {self._reference_counts[model_id]}) - SAVED DOWNLOAD!")
+                    return (shared_path, mpta_file_path)
+                else:
+                    # Path was deleted externally, clean up our records
+                    logger.warning(f"⚠️ MPTA path for modelId {model_id} was deleted externally, will re-download")
+                    del self._model_paths[model_id]
+                    self._mpta_file_paths.pop(model_id, None)
+                    self._reference_counts.pop(model_id, 0)
+
+            # Need to download - get or create download lock for this modelId
+            if model_id not in self._download_locks:
+                self._download_locks[model_id] = threading.Lock()
+
+        # Download under the model-specific lock; _manager_lock has been released
+        # so downloads of other models can proceed in parallel
+        download_lock = self._download_locks[model_id]
+        with download_lock:
+            # Double-check after acquiring download lock
+            with self._manager_lock:
+                if model_id in self._model_paths and os.path.exists(self._model_paths[model_id]):
+                    mpta_file_path = self._mpta_file_paths.get(model_id)
+                    if mpta_file_path and os.path.exists(mpta_file_path):
+                        self._reference_counts[model_id] += 1
+                        logger.info(f"📂 MPTA modelId {model_id} became available during wait (ref_count: {self._reference_counts[model_id]})")
+                        return (self._model_paths[model_id], mpta_file_path)
+
+            # Actually download and extract
+            shared_path = f"models/{model_id}"
+            logger.info(f"🔄 Downloading NEW MPTA for modelId {model_id} from {model_url}")
+
+            try:
+                # Ensure directory exists
+                os.makedirs(shared_path, exist_ok=True)
+
+                # Download MPTA file
+                mpta_filename = self._extract_filename_from_url(model_url) or f"model_{model_id}.mpta"
+                local_mpta_path = os.path.join(shared_path, mpta_filename)
+
+                if not self._download_file(model_url, local_mpta_path):
+                    logger.error(f"❌ Failed to download MPTA for modelId {model_id}")
+                    return None
+
+                # Extract MPTA
+                pipeline_tree = load_pipeline_from_zip(local_mpta_path, shared_path)
+                if pipeline_tree is None:
+                    logger.error(f"❌ Failed to extract MPTA for modelId {model_id}")
+                    return None
+
+                # Success - register in manager
+                with self._manager_lock:
+                    self._model_paths[model_id] = shared_path
+                    self._mpta_file_paths[model_id] = local_mpta_path
+                    self._reference_counts[model_id] = 1
+
+                logger.info(f"✅ MPTA modelId {model_id} downloaded and registered (ref_count: 1)")
+                self._log_manager_status()
+
+                return (shared_path, local_mpta_path)
+
+            except Exception as e:
+                logger.error(f"❌ Error downloading/extracting MPTA for modelId {model_id}: {e}")
+                # Clean up partial download
+                if os.path.exists(shared_path):
+                    shutil.rmtree(shared_path, ignore_errors=True)
+                return None
+
+    def release_mpta(self, model_id: int, camera_id: str) -> None:
+        """
+        Release a reference to an MPTA. If reference count reaches zero,
+        the MPTA directory may be cleaned up to free disk space.
+
+        Args:
+            model_id: Unique identifier for the model to release
+            camera_id: Identifier for the camera releasing the reference
+        """
+        with self._manager_lock:
+            if model_id not in self._reference_counts:
+                logger.warning(f"⚠️ Attempted to release unknown MPTA modelId {model_id} for camera {camera_id}")
+                return
+
+            # Remove camera from usage tracking
+            if model_id in self._cameras_using_model:
+                self._cameras_using_model[model_id].discard(camera_id)
+
+            self._reference_counts[model_id] -= 1
+            logger.info(f"📉 MPTA modelId {model_id} reference count decreased to {self._reference_counts[model_id]} (released by {camera_id})")
+
+            # Clean up if no more references
+            # if self._reference_counts[model_id] <= 0:
+            #     self._cleanup_mpta(model_id)
+
+    def _cleanup_mpta(self, model_id: int) -> None:
+        """
+        Internal method to clean up an MPTA directory and free disk space.
+        """
+        if model_id in self._model_paths:
+            shared_path = self._model_paths[model_id]
+
+            try:
+                if os.path.exists(shared_path):
+                    shutil.rmtree(shared_path)
+                    logger.info(f"🗑️ Cleaned up MPTA directory: {shared_path}")
+
+                # Remove from tracking
+                del self._model_paths[model_id]
+                self._mpta_file_paths.pop(model_id, None)
+                del self._reference_counts[model_id]
+                self._cameras_using_model.pop(model_id, None)
+
+                # Clean up download lock (optional, could keep for future use)
+                self._download_locks.pop(model_id, None)
+
+                logger.info(f"✅ MPTA modelId {model_id} fully cleaned up and disk space freed")
+                self._log_manager_status()
+
+            except Exception as e:
+                logger.error(f"❌ Error cleaning up MPTA modelId {model_id}: {e}")
+
+    def get_shared_path(self, model_id: int) -> Optional[str]:
+        """
+        Get the shared extraction path for a modelId without downloading.
+
+        Args:
+            model_id: Model identifier to look up
+
+        Returns:
+            Shared path if available, None otherwise
+        """
+        with self._manager_lock:
+            return self._model_paths.get(model_id)
+
+    def get_manager_status(self) -> Dict:
+        """
+        Get current status of the MPTA manager.
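+
+        Example return shape (values illustrative; keys match the dict built below):
+            {
+                "total_mpta_models": 1,
+                "models": {"8": {"shared_path": "models/8",
+                                 "reference_count": 2,
+                                 "cameras_using": ["cam-01", "cam-02"]}},
+                "total_references": 2,
+                "active_downloads": 1
+            }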
+
+        Returns:
+            Dictionary with manager statistics
+        """
+        with self._manager_lock:
+            return {
+                "total_mpta_models": len(self._model_paths),
+                "models": {
+                    str(model_id): {
+                        "shared_path": path,
+                        "reference_count": self._reference_counts.get(model_id, 0),
+                        "cameras_using": list(self._cameras_using_model.get(model_id, set()))
+                    }
+                    for model_id, path in self._model_paths.items()
+                },
+                "total_references": sum(self._reference_counts.values()),
+                "active_downloads": len(self._download_locks)
+            }
+
+    def _log_manager_status(self) -> None:
+        """Log current manager status for debugging."""
+        status = self.get_manager_status()
+        logger.info(f"📊 MPTA Manager Status: {status['total_mpta_models']} unique models, {status['total_references']} total references")
+        for model_id, info in status['models'].items():
+            cameras_str = ','.join(info['cameras_using'][:3])  # Show first 3 cameras
+            if len(info['cameras_using']) > 3:
+                cameras_str += f"+{len(info['cameras_using'])-3} more"
+            logger.debug(f"  📋 ModelId {model_id}: refs={info['reference_count']}, cameras=[{cameras_str}]")
+
+    def cleanup_all(self) -> None:
+        """
+        Clean up all MPTA directories. Used during shutdown.
+        """
+        with self._manager_lock:
+            model_ids = list(self._model_paths.keys())
+            logger.info(f"🧹 Cleaning up {len(model_ids)} MPTA directories")
+
+            for model_id in model_ids:
+                self._cleanup_mpta(model_id)
+
+            # Clear all tracking data
+            self._download_locks.clear()
+            logger.info("✅ MPTA manager cleanup complete")
+
+    def _download_file(self, url: str, local_path: str) -> bool:
+        """
+        Download a file from URL to local path with progress logging.
+
+        Args:
+            url: URL to download from
+            local_path: Local path to save to
+
+        Returns:
+            True if successful, False otherwise
+        """
+        try:
+            logger.info(f"⬇️ Starting download from {url}")
+
+            response = requests.get(url, stream=True)
+            response.raise_for_status()
+
+            total_size = int(response.headers.get('content-length', 0))
+            if total_size > 0:
+                logger.info(f"📦 File size: {total_size / 1024 / 1024:.2f} MB")
+
+            downloaded = 0
+            last_logged_progress = 0
+            with open(local_path, 'wb') as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    if chunk:
+                        f.write(chunk)
+                        downloaded += len(chunk)
+
+                        if total_size > 0:
+                            progress = int((downloaded / total_size) * 100)
+                            # Log at 10% intervals (10%, 20%, 30%, etc.)
+                            if progress >= last_logged_progress + 10 and progress <= 100:
+                                logger.debug(f"Download progress: {progress}%")
+                                last_logged_progress = progress
+
+            logger.info(f"✅ Successfully downloaded to {local_path}")
+            return True
+
+        except Exception as e:
+            logger.error(f"❌ Download failed: {e}")
+            # Clean up partial file
+            if os.path.exists(local_path):
+                os.remove(local_path)
+            return False
+
+    def _extract_filename_from_url(self, url: str) -> Optional[str]:
+        """Extract filename from URL."""
+        try:
+            parsed = urlparse(url)
+            filename = os.path.basename(parsed.path)
+            return filename if filename else None
+        except Exception:
+            return None
+
+
+# Global singleton instance
+_mpta_manager = MPTAManager()
+
+def get_or_download_mpta(model_id: int, model_url: str, camera_id: str) -> Optional[tuple[str, str]]:
+    """
+    Convenience function to get or download a shared MPTA.
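+
+    Example (mirrors the call site in app.py; the URL is illustrative):
+        mpta_result = get_or_download_mpta(modelId, "https://example.com/model.mpta", camera_id)
+        if mpta_result:
+            shared_extraction_path, local_mpta_file = mpta_result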
+
+    Args:
+        model_id: Unique identifier for the model
+        model_url: URL to download the MPTA file from
+        camera_id: Identifier for the requesting camera
+
+    Returns:
+        Tuple of (extraction_path, mpta_file_path), or None if failed
+    """
+    return _mpta_manager.get_or_download_mpta(model_id, model_url, camera_id)
+
+def release_mpta(model_id: int, camera_id: str) -> None:
+    """
+    Convenience function to release a shared MPTA reference.
+
+    Args:
+        model_id: Unique identifier for the model to release
+        camera_id: Identifier for the camera releasing the reference
+    """
+    _mpta_manager.release_mpta(model_id, camera_id)
+
+def get_mpta_manager_status() -> Dict:
+    """
+    Convenience function to get MPTA manager status.
+
+    Returns:
+        Dictionary with manager statistics
+    """
+    return _mpta_manager.get_manager_status()
+
+def cleanup_mpta_manager() -> None:
+    """
+    Convenience function to cleanup the entire MPTA manager.
+    """
+    _mpta_manager.cleanup_all()
\ No newline at end of file
diff --git a/siwatsystem/pympta.py b/siwatsystem/pympta.py
index 975ee36..ac34d88 100644
--- a/siwatsystem/pympta.py
+++ b/siwatsystem/pympta.py
@@ -13,6 +13,7 @@ import concurrent.futures
 from ultralytics import YOLO
 from urllib.parse import urlparse
 from .database import DatabaseManager
+from .model_registry import get_shared_model, release_shared_model
 from datetime import datetime
 
 # Create a logger specifically for this module
@@ -98,13 +99,11 @@ def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manage
             logger.error(f"Model file {model_path} not found. Current directory: {os.getcwd()}")
             logger.error(f"Directory content: {os.listdir(os.path.dirname(model_path))}")
             raise FileNotFoundError(f"Model file {model_path} not found.")
-    logger.info(f"Loading model for node {node_config['modelId']} from {model_path}")
-    model = YOLO(model_path)
-    if torch.cuda.is_available():
-        logger.info(f"CUDA available. Moving model {node_config['modelId']} to GPU VRAM")
-        model.to("cuda")
-    else:
-        logger.info(f"CUDA not available. Using CPU for model {node_config['modelId']}")
+
+    # Use shared model registry to prevent duplicate loading
+    model_id = node_config['modelId']
+    logger.info(f"Getting shared model for node {model_id} from {model_path}")
+    model = get_shared_model(model_id, model_path)
 
     # Prepare trigger class indices for optimization
     trigger_classes = node_config.get("triggerClasses", [])
@@ -1108,6 +1107,17 @@ def is_camera_active(camera_id, model_id):
 
     return session_state.get("active", True)
 
+def cleanup_pipeline_node(node: dict):
+    """Clean up a pipeline node and release its model reference."""
+    if node and "modelId" in node:
+        model_id = node["modelId"]
+        logger.info(f"🧹 Cleaning up pipeline node: {model_id}")
+        release_shared_model(model_id)
+
+        # Recursively clean up branches
+        for branch in node.get("branches", []):
+            cleanup_pipeline_node(branch)
+
 def cleanup_camera_stability(camera_id):
     """Clean up stability tracking data when a camera is disconnected."""
     global _camera_stability_tracking