feat: optimize model declaration in ram
	
		
			
	
		
	
	
		
	
		
			All checks were successful
		
		
	
	
		
			
				
	
				Build Worker Base and Application Images / check-base-changes (push) Successful in 8s
				
			
		
			
				
	
				Build Worker Base and Application Images / build-base (push) Has been skipped
				
			
		
			
				
	
				Build Worker Base and Application Images / build-docker (push) Successful in 2m47s
				
			
		
			
				
	
				Build Worker Base and Application Images / deploy-stack (push) Successful in 10s
				
			
		
		
	
	
				
					
				
			
		
			All checks were successful
		
		
	
	Build Worker Base and Application Images / check-base-changes (push) Successful in 8s
				
			Build Worker Base and Application Images / build-base (push) Has been skipped
				
			Build Worker Base and Application Images / build-docker (push) Successful in 2m47s
				
			Build Worker Base and Application Images / deploy-stack (push) Successful in 10s
				
			This commit is contained in:
		
							parent
							
								
									c715b26a2a
								
							
						
					
					
						commit
						ac85caca39
					
				
					 4 changed files with 679 additions and 216 deletions
				
			
		
							
								
								
									
										252
									
								
								app.py
									
										
									
									
									
								
							
							
						
						
									
										252
									
								
								app.py
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -28,7 +28,9 @@ from websockets.exceptions import ConnectionClosedError
 | 
			
		|||
from ultralytics import YOLO
 | 
			
		||||
 | 
			
		||||
# Import shared pipeline functions
 | 
			
		||||
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline, cleanup_camera_stability
 | 
			
		||||
from siwatsystem.pympta import load_pipeline_from_zip, run_pipeline, cleanup_camera_stability, cleanup_pipeline_node
 | 
			
		||||
from siwatsystem.model_registry import get_registry_status, cleanup_registry
 | 
			
		||||
from siwatsystem.mpta_manager import get_or_download_mpta, release_mpta, get_mpta_manager_status, cleanup_mpta_manager
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -444,30 +446,6 @@ streams_lock = threading.Lock()
 | 
			
		|||
models_lock = threading.Lock()
 | 
			
		||||
logger.debug("Initialized thread locks")
 | 
			
		||||
 | 
			
		||||
# Add helper to download mpta ZIP file from a remote URL
 | 
			
		||||
def download_mpta(url: str, dest_path: str) -> str:
 | 
			
		||||
    try:
 | 
			
		||||
        logger.info(f"Starting download of model from {url} to {dest_path}")
 | 
			
		||||
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
 | 
			
		||||
        response = requests.get(url, stream=True)
 | 
			
		||||
        if response.status_code == 200:
 | 
			
		||||
            file_size = int(response.headers.get('content-length', 0))
 | 
			
		||||
            logger.info(f"Model file size: {file_size/1024/1024:.2f} MB")
 | 
			
		||||
            downloaded = 0
 | 
			
		||||
            with open(dest_path, "wb") as f:
 | 
			
		||||
                for chunk in response.iter_content(chunk_size=8192):
 | 
			
		||||
                    f.write(chunk)
 | 
			
		||||
                    downloaded += len(chunk)
 | 
			
		||||
                    if file_size > 0 and downloaded % (file_size // 10) < 8192:  # Log approximately every 10%
 | 
			
		||||
                        logger.debug(f"Download progress: {downloaded/file_size*100:.1f}%")
 | 
			
		||||
            logger.info(f"Successfully downloaded mpta file from {url} to {dest_path}")
 | 
			
		||||
            return dest_path
 | 
			
		||||
        else:
 | 
			
		||||
            logger.error(f"Failed to download mpta file (status code {response.status_code}): {response.text}")
 | 
			
		||||
            return None
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logger.error(f"Exception downloading mpta file from {url}: {str(e)}", exc_info=True)
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
# Add helper to fetch snapshot image from HTTP/HTTPS URL
 | 
			
		||||
def fetch_snapshot(url: str):
 | 
			
		||||
| 
						 | 
				
			
			@ -703,7 +681,9 @@ async def get_lpr_debug_info():
 | 
			
		|||
            },
 | 
			
		||||
            "thread_status": {
 | 
			
		||||
                "lpr_listener_alive": lpr_listener_thread.is_alive() if lpr_listener_thread else False,
 | 
			
		||||
                "cleanup_timer_alive": cleanup_timer_thread.is_alive() if cleanup_timer_thread else False
 | 
			
		||||
                "cleanup_timer_alive": cleanup_timer_thread.is_alive() if cleanup_timer_thread else False,
 | 
			
		||||
                "model_registry": get_registry_status(),
 | 
			
		||||
                "mpta_manager": get_mpta_manager_status()
 | 
			
		||||
            },
 | 
			
		||||
            "cached_detections_by_camera": list(cached_detections.keys())
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			@ -1715,32 +1695,24 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
        display_identifier, camera_identifier = parts
 | 
			
		||||
        camera_id = subscriptionIdentifier
 | 
			
		||||
 | 
			
		||||
        # Load model if needed
 | 
			
		||||
        # Load model if needed using shared MPTA manager
 | 
			
		||||
        if model_url:
 | 
			
		||||
            with models_lock:
 | 
			
		||||
                if (camera_id not in models) or (modelId not in models[camera_id]):
 | 
			
		||||
                    logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
 | 
			
		||||
                    extraction_dir = os.path.join("models", camera_identifier, str(modelId))
 | 
			
		||||
                    os.makedirs(extraction_dir, exist_ok=True)
 | 
			
		||||
                    logger.info(f"Getting shared MPTA for camera {camera_id}, modelId {modelId}")
 | 
			
		||||
                    
 | 
			
		||||
                    # Handle model loading (same as original)
 | 
			
		||||
                    parsed = urlparse(model_url)
 | 
			
		||||
                    if parsed.scheme in ("http", "https"):
 | 
			
		||||
                        filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
 | 
			
		||||
                        local_mpta = os.path.join(extraction_dir, filename)
 | 
			
		||||
                        local_path = download_mpta(model_url, local_mpta)
 | 
			
		||||
                        if not local_path:
 | 
			
		||||
                            logger.error(f"Failed to download model from {model_url}")
 | 
			
		||||
                    # Use shared MPTA manager for optimized downloads
 | 
			
		||||
                    mpta_result = get_or_download_mpta(modelId, model_url, camera_id)
 | 
			
		||||
                    if not mpta_result:
 | 
			
		||||
                        logger.error(f"Failed to get/download MPTA for modelId {modelId}")
 | 
			
		||||
                        return
 | 
			
		||||
                        model_tree = load_pipeline_from_zip(local_path, extraction_dir)
 | 
			
		||||
                    else:
 | 
			
		||||
                        if not os.path.exists(model_url):
 | 
			
		||||
                            logger.error(f"Model file not found: {model_url}")
 | 
			
		||||
                            return
 | 
			
		||||
                        model_tree = load_pipeline_from_zip(model_url, extraction_dir)
 | 
			
		||||
                    
 | 
			
		||||
                    shared_extraction_path, local_mpta_file = mpta_result
 | 
			
		||||
                    
 | 
			
		||||
                    # Load pipeline from local MPTA file
 | 
			
		||||
                    model_tree = load_pipeline_from_zip(local_mpta_file, shared_extraction_path)
 | 
			
		||||
                    if model_tree is None:
 | 
			
		||||
                        logger.error(f"Failed to load model {modelId}")
 | 
			
		||||
                        logger.error(f"Failed to load model {modelId} from shared MPTA")
 | 
			
		||||
                        return
 | 
			
		||||
                        
 | 
			
		||||
                    if camera_id not in models:
 | 
			
		||||
| 
						 | 
				
			
			@ -1857,6 +1829,18 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
            stream = streams.pop(subscription_id)
 | 
			
		||||
            camera_url = subscription_to_camera.pop(subscription_id, None)
 | 
			
		||||
            
 | 
			
		||||
            # Clean up model references for this camera
 | 
			
		||||
            with models_lock:
 | 
			
		||||
                if subscription_id in models:
 | 
			
		||||
                    camera_models = models[subscription_id]
 | 
			
		||||
                    for model_id, model_tree in camera_models.items():
 | 
			
		||||
                        logger.info(f"🧹 Cleaning up model references for camera {subscription_id}, modelId {model_id}")
 | 
			
		||||
                        # Release model registry references
 | 
			
		||||
                        cleanup_pipeline_node(model_tree)
 | 
			
		||||
                        # Release MPTA manager reference
 | 
			
		||||
                        release_mpta(model_id, subscription_id)
 | 
			
		||||
                    del models[subscription_id]
 | 
			
		||||
            
 | 
			
		||||
            if camera_url and camera_url in camera_streams:
 | 
			
		||||
                shared_stream = camera_streams[camera_url]
 | 
			
		||||
                shared_stream["ref_count"] -= 1
 | 
			
		||||
| 
						 | 
				
			
			@ -2015,169 +1999,6 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
                                })
 | 
			
		||||
                    await reconcile_subscriptions(current_subs, websocket)
 | 
			
		||||
                    
 | 
			
		||||
                elif msg_type == "old_subscribe_logic_removed":
 | 
			
		||||
                    if model_url:
 | 
			
		||||
                        with models_lock:
 | 
			
		||||
                            if (camera_id not in models) or (modelId not in models[camera_id]):
 | 
			
		||||
                                logger.info(f"Loading model from {model_url} for camera {camera_id}, modelId {modelId}")
 | 
			
		||||
                                extraction_dir = os.path.join("models", camera_identifier, str(modelId))
 | 
			
		||||
                                os.makedirs(extraction_dir, exist_ok=True)
 | 
			
		||||
                                # If model_url is remote, download it first.
 | 
			
		||||
                                parsed = urlparse(model_url)
 | 
			
		||||
                                if parsed.scheme in ("http", "https"):
 | 
			
		||||
                                    logger.info(f"Downloading remote .mpta file from {model_url}")
 | 
			
		||||
                                    filename = os.path.basename(parsed.path) or f"model_{modelId}.mpta"
 | 
			
		||||
                                    local_mpta = os.path.join(extraction_dir, filename)
 | 
			
		||||
                                    logger.debug(f"Download destination: {local_mpta}")
 | 
			
		||||
                                    local_path = download_mpta(model_url, local_mpta)
 | 
			
		||||
                                    if not local_path:
 | 
			
		||||
                                        logger.error(f"Failed to download the remote .mpta file from {model_url}")
 | 
			
		||||
                                        error_response = {
 | 
			
		||||
                                            "type": "error",
 | 
			
		||||
                                            "subscriptionIdentifier": subscriptionIdentifier,
 | 
			
		||||
                                            "error": f"Failed to download model from {model_url}"
 | 
			
		||||
                                        }
 | 
			
		||||
                                        ws_logger.info(f"TX -> {json.dumps(error_response, separators=(',', ':'))}")
 | 
			
		||||
                                        await websocket.send_json(error_response)
 | 
			
		||||
                                        continue
 | 
			
		||||
                                    model_tree = load_pipeline_from_zip(local_path, extraction_dir)
 | 
			
		||||
                                else:
 | 
			
		||||
                                    logger.info(f"Loading local .mpta file from {model_url}")
 | 
			
		||||
                                    # Check if file exists before attempting to load
 | 
			
		||||
                                    if not os.path.exists(model_url):
 | 
			
		||||
                                        logger.error(f"Local .mpta file not found: {model_url}")
 | 
			
		||||
                                        logger.debug(f"Current working directory: {os.getcwd()}")
 | 
			
		||||
                                        error_response = {
 | 
			
		||||
                                            "type": "error",
 | 
			
		||||
                                            "subscriptionIdentifier": subscriptionIdentifier,
 | 
			
		||||
                                            "error": f"Model file not found: {model_url}"
 | 
			
		||||
                                        }
 | 
			
		||||
                                        ws_logger.info(f"TX -> {json.dumps(error_response, separators=(',', ':'))}")
 | 
			
		||||
                                        await websocket.send_json(error_response)
 | 
			
		||||
                                        continue
 | 
			
		||||
                                    model_tree = load_pipeline_from_zip(model_url, extraction_dir)
 | 
			
		||||
                                if model_tree is None:
 | 
			
		||||
                                    logger.error(f"Failed to load model {modelId} from .mpta file for camera {camera_id}")
 | 
			
		||||
                                    error_response = {
 | 
			
		||||
                                        "type": "error",
 | 
			
		||||
                                        "subscriptionIdentifier": subscriptionIdentifier,
 | 
			
		||||
                                        "error": f"Failed to load model {modelId}"
 | 
			
		||||
                                    }
 | 
			
		||||
                                    await websocket.send_json(error_response)
 | 
			
		||||
                                    continue
 | 
			
		||||
                                if camera_id not in models:
 | 
			
		||||
                                    models[camera_id] = {}
 | 
			
		||||
                                models[camera_id][modelId] = model_tree
 | 
			
		||||
                                logger.info(f"Successfully loaded model {modelId} for camera {camera_id}")
 | 
			
		||||
                                logger.debug(f"Model extraction directory: {extraction_dir}")
 | 
			
		||||
                                
 | 
			
		||||
                                # Start LPR integration threads after first model is loaded (only once)
 | 
			
		||||
                                if not lpr_integration_started and hasattr(model_tree, 'get') and model_tree.get('redis_client'):
 | 
			
		||||
                                    try:
 | 
			
		||||
                                        start_lpr_integration()
 | 
			
		||||
                                        lpr_integration_started = True
 | 
			
		||||
                                        logger.info("🚀 LPR integration started after first model load")
 | 
			
		||||
                                    except Exception as e:
 | 
			
		||||
                                        logger.error(f"❌ Failed to start LPR integration: {e}")
 | 
			
		||||
                    if camera_id and (rtsp_url or snapshot_url):
 | 
			
		||||
                        with streams_lock:
 | 
			
		||||
                            # Determine camera URL for shared stream management
 | 
			
		||||
                            camera_url = snapshot_url if snapshot_url else rtsp_url
 | 
			
		||||
                            
 | 
			
		||||
                            if camera_id not in streams and len(streams) < max_streams:
 | 
			
		||||
                                # Check if we already have a stream for this camera URL
 | 
			
		||||
                                shared_stream = camera_streams.get(camera_url)
 | 
			
		||||
                                
 | 
			
		||||
                                if shared_stream:
 | 
			
		||||
                                    # Reuse existing stream
 | 
			
		||||
                                    logger.info(f"Reusing existing stream for camera URL: {camera_url}")
 | 
			
		||||
                                    buffer = shared_stream["buffer"]
 | 
			
		||||
                                    stop_event = shared_stream["stop_event"]
 | 
			
		||||
                                    thread = shared_stream["thread"]
 | 
			
		||||
                                    mode = shared_stream["mode"]
 | 
			
		||||
                                    
 | 
			
		||||
                                    # Increment reference count
 | 
			
		||||
                                    shared_stream["ref_count"] = shared_stream.get("ref_count", 0) + 1
 | 
			
		||||
                                else:
 | 
			
		||||
                                    # Create new stream
 | 
			
		||||
                                    buffer = queue.Queue(maxsize=1)
 | 
			
		||||
                                    stop_event = threading.Event()
 | 
			
		||||
                                    
 | 
			
		||||
                                    if snapshot_url and snapshot_interval:
 | 
			
		||||
                                        logger.info(f"Creating new snapshot stream for camera {camera_id}: {snapshot_url}")
 | 
			
		||||
                                        thread = threading.Thread(target=snapshot_reader, args=(camera_id, snapshot_url, snapshot_interval, buffer, stop_event))
 | 
			
		||||
                                        thread.daemon = True
 | 
			
		||||
                                        thread.start()
 | 
			
		||||
                                        mode = "snapshot"
 | 
			
		||||
                                        
 | 
			
		||||
                                        # Store shared stream info
 | 
			
		||||
                                        shared_stream = {
 | 
			
		||||
                                            "buffer": buffer,
 | 
			
		||||
                                            "thread": thread,
 | 
			
		||||
                                            "stop_event": stop_event,
 | 
			
		||||
                                            "mode": mode,
 | 
			
		||||
                                            "url": snapshot_url,
 | 
			
		||||
                                            "snapshot_interval": snapshot_interval,
 | 
			
		||||
                                            "ref_count": 1
 | 
			
		||||
                                        }
 | 
			
		||||
                                        camera_streams[camera_url] = shared_stream
 | 
			
		||||
                                        
 | 
			
		||||
                                    elif rtsp_url:
 | 
			
		||||
                                        logger.info(f"Creating new RTSP stream for camera {camera_id}: {rtsp_url}")
 | 
			
		||||
                                        cap = cv2.VideoCapture(rtsp_url)
 | 
			
		||||
                                        if not cap.isOpened():
 | 
			
		||||
                                            logger.error(f"Failed to open RTSP stream for camera {camera_id}")
 | 
			
		||||
                                            continue
 | 
			
		||||
                                        thread = threading.Thread(target=frame_reader, args=(camera_id, cap, buffer, stop_event))
 | 
			
		||||
                                        thread.daemon = True
 | 
			
		||||
                                        thread.start()
 | 
			
		||||
                                        mode = "rtsp"
 | 
			
		||||
                                        
 | 
			
		||||
                                        # Store shared stream info
 | 
			
		||||
                                        shared_stream = {
 | 
			
		||||
                                            "buffer": buffer,
 | 
			
		||||
                                            "thread": thread,
 | 
			
		||||
                                            "stop_event": stop_event,
 | 
			
		||||
                                            "mode": mode,
 | 
			
		||||
                                            "url": rtsp_url,
 | 
			
		||||
                                            "cap": cap,
 | 
			
		||||
                                            "ref_count": 1
 | 
			
		||||
                                        }
 | 
			
		||||
                                        camera_streams[camera_url] = shared_stream
 | 
			
		||||
                                    else:
 | 
			
		||||
                                        logger.error(f"No valid URL provided for camera {camera_id}")
 | 
			
		||||
                                        continue
 | 
			
		||||
                                
 | 
			
		||||
                                # Create stream info for this subscription
 | 
			
		||||
                                stream_info = {
 | 
			
		||||
                                    "buffer": buffer,
 | 
			
		||||
                                    "thread": thread,
 | 
			
		||||
                                    "stop_event": stop_event,
 | 
			
		||||
                                    "modelId": modelId,
 | 
			
		||||
                                    "modelName": modelName,
 | 
			
		||||
                                    "subscriptionIdentifier": subscriptionIdentifier,
 | 
			
		||||
                                    "cropX1": cropX1,
 | 
			
		||||
                                    "cropY1": cropY1,
 | 
			
		||||
                                    "cropX2": cropX2,
 | 
			
		||||
                                    "cropY2": cropY2,
 | 
			
		||||
                                    "mode": mode,
 | 
			
		||||
                                    "camera_url": camera_url
 | 
			
		||||
                                }
 | 
			
		||||
                                
 | 
			
		||||
                                if mode == "snapshot":
 | 
			
		||||
                                    stream_info["snapshot_url"] = snapshot_url
 | 
			
		||||
                                    stream_info["snapshot_interval"] = snapshot_interval
 | 
			
		||||
                                elif mode == "rtsp":
 | 
			
		||||
                                    stream_info["rtsp_url"] = rtsp_url
 | 
			
		||||
                                    stream_info["cap"] = shared_stream["cap"]
 | 
			
		||||
                                
 | 
			
		||||
                                streams[camera_id] = stream_info
 | 
			
		||||
                                subscription_to_camera[camera_id] = camera_url
 | 
			
		||||
                                
 | 
			
		||||
                            elif camera_id and camera_id in streams:
 | 
			
		||||
                                # If already subscribed, unsubscribe first
 | 
			
		||||
                                logger.info(f"Resubscribing to camera {camera_id}")
 | 
			
		||||
                                # Note: Keep models in memory for reuse across subscriptions
 | 
			
		||||
                elif msg_type == "unsubscribe":
 | 
			
		||||
                    payload = data.get("payload", {})
 | 
			
		||||
                    subscriptionIdentifier = payload.get("subscriptionIdentifier")
 | 
			
		||||
| 
						 | 
				
			
			@ -2473,7 +2294,22 @@ async def detect(websocket: WebSocket):
 | 
			
		|||
            camera_streams.clear()
 | 
			
		||||
            subscription_to_camera.clear()
 | 
			
		||||
        with models_lock:
 | 
			
		||||
            # Clean up all model references before clearing models dict
 | 
			
		||||
            for camera_id, camera_models in models.items():
 | 
			
		||||
                for model_id, model_tree in camera_models.items():
 | 
			
		||||
                    logger.info(f"🧹 Shutdown cleanup: Releasing model {model_id} for camera {camera_id}")
 | 
			
		||||
                    # Release model registry references
 | 
			
		||||
                    cleanup_pipeline_node(model_tree)
 | 
			
		||||
                    # Release MPTA manager reference
 | 
			
		||||
                    release_mpta(model_id, camera_id)
 | 
			
		||||
            models.clear()
 | 
			
		||||
            
 | 
			
		||||
        # Clean up the entire model registry and MPTA manager
 | 
			
		||||
        # logger.info("🏭 Performing final model registry cleanup...")
 | 
			
		||||
        # cleanup_registry()
 | 
			
		||||
        # logger.info("🏭 Performing final MPTA manager cleanup...")
 | 
			
		||||
        # cleanup_mpta_manager()
 | 
			
		||||
        
 | 
			
		||||
        latest_frames.clear()
 | 
			
		||||
        cached_detections.clear()
 | 
			
		||||
        frame_skip_flags.clear()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										242
									
								
								siwatsystem/model_registry.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										242
									
								
								siwatsystem/model_registry.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,242 @@
 | 
			
		|||
"""
 | 
			
		||||
Shared Model Registry for Memory Optimization
 | 
			
		||||
 | 
			
		||||
This module implements a global shared model registry to prevent duplicate model loading
 | 
			
		||||
in memory when multiple cameras use the same model. This significantly reduces RAM and
 | 
			
		||||
GPU VRAM usage by ensuring only one instance of each unique model is loaded.
 | 
			
		||||
 | 
			
		||||
Key Features:
 | 
			
		||||
- Thread-safe model loading and access
 | 
			
		||||
- Reference counting for proper cleanup
 | 
			
		||||
- Automatic model lifecycle management
 | 
			
		||||
- Maintains compatibility with existing pipeline system
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import threading
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Dict, Any, Optional, Set
 | 
			
		||||
import torch
 | 
			
		||||
from ultralytics import YOLO
 | 
			
		||||
 | 
			
		||||
# Create a logger for this module
 | 
			
		||||
logger = logging.getLogger("detector_worker.model_registry")
 | 
			
		||||
 | 
			
		||||
class ModelRegistry:
 | 
			
		||||
    """
 | 
			
		||||
    Singleton class for managing shared YOLO models across multiple cameras.
 | 
			
		||||
    
 | 
			
		||||
    This registry ensures that each unique model is loaded only once in memory,
 | 
			
		||||
    dramatically reducing RAM and GPU VRAM usage when multiple cameras use the
 | 
			
		||||
    same model.
 | 
			
		||||
    """
 | 
			
		||||
    
 | 
			
		||||
    _instance = None
 | 
			
		||||
    _lock = threading.Lock()
 | 
			
		||||
    
 | 
			
		||||
    def __new__(cls):
 | 
			
		||||
        if cls._instance is None:
 | 
			
		||||
            with cls._lock:
 | 
			
		||||
                if cls._instance is None:
 | 
			
		||||
                    cls._instance = super(ModelRegistry, cls).__new__(cls)
 | 
			
		||||
                    cls._instance._initialized = False
 | 
			
		||||
        return cls._instance
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        if self._initialized:
 | 
			
		||||
            return
 | 
			
		||||
            
 | 
			
		||||
        self._initialized = True
 | 
			
		||||
        
 | 
			
		||||
        # Thread-safe storage for loaded models
 | 
			
		||||
        self._models: Dict[str, YOLO] = {}  # modelId -> YOLO model instance
 | 
			
		||||
        self._model_files: Dict[str, str] = {}  # modelId -> file path
 | 
			
		||||
        self._reference_counts: Dict[str, int] = {}  # modelId -> reference count
 | 
			
		||||
        self._model_lock = threading.RLock()  # Reentrant lock for nested calls
 | 
			
		||||
        
 | 
			
		||||
        logger.info("🏭 Shared Model Registry initialized - ready for memory-optimized model loading")
 | 
			
		||||
    
 | 
			
		||||
    def get_model(self, model_id: str, model_file_path: str) -> YOLO:
 | 
			
		||||
        """
 | 
			
		||||
        Get or load a YOLO model. Returns shared instance if already loaded.
 | 
			
		||||
        
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: Unique identifier for the model
 | 
			
		||||
            model_file_path: Path to the model file
 | 
			
		||||
            
 | 
			
		||||
        Returns:
 | 
			
		||||
            YOLO model instance (shared across all callers)
 | 
			
		||||
        """
 | 
			
		||||
        with self._model_lock:
 | 
			
		||||
            if model_id in self._models:
 | 
			
		||||
                # Model already loaded - increment reference count and return
 | 
			
		||||
                self._reference_counts[model_id] += 1
 | 
			
		||||
                logger.info(f"📖 Model '{model_id}' reused (ref_count: {self._reference_counts[model_id]}) - SAVED MEMORY!")
 | 
			
		||||
                return self._models[model_id]
 | 
			
		||||
            
 | 
			
		||||
            # Model not loaded yet - load it
 | 
			
		||||
            logger.info(f"🔄 Loading NEW model '{model_id}' from {model_file_path}")
 | 
			
		||||
            
 | 
			
		||||
            if not os.path.exists(model_file_path):
 | 
			
		||||
                raise FileNotFoundError(f"Model file {model_file_path} not found")
 | 
			
		||||
            
 | 
			
		||||
            try:
 | 
			
		||||
                # Load the YOLO model
 | 
			
		||||
                model = YOLO(model_file_path)
 | 
			
		||||
                
 | 
			
		||||
                # Move to GPU if available
 | 
			
		||||
                if torch.cuda.is_available():
 | 
			
		||||
                    logger.info(f"🚀 CUDA available. Moving model '{model_id}' to GPU VRAM")
 | 
			
		||||
                    model.to("cuda")
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.info(f"💻 CUDA not available. Using CPU for model '{model_id}'")
 | 
			
		||||
                
 | 
			
		||||
                # Store in registry
 | 
			
		||||
                self._models[model_id] = model
 | 
			
		||||
                self._model_files[model_id] = model_file_path
 | 
			
		||||
                self._reference_counts[model_id] = 1
 | 
			
		||||
                
 | 
			
		||||
                logger.info(f"✅ Model '{model_id}' loaded and registered (ref_count: 1)")
 | 
			
		||||
                self._log_registry_status()
 | 
			
		||||
                
 | 
			
		||||
                return model
 | 
			
		||||
                
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.error(f"❌ Failed to load model '{model_id}' from {model_file_path}: {e}")
 | 
			
		||||
                raise
 | 
			
		||||
    
 | 
			
		||||
    def release_model(self, model_id: str) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Release a reference to a model. If reference count reaches zero,
 | 
			
		||||
        the model may be unloaded to free memory.
 | 
			
		||||
        
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: Unique identifier for the model to release
 | 
			
		||||
        """
 | 
			
		||||
        with self._model_lock:
 | 
			
		||||
            if model_id not in self._reference_counts:
 | 
			
		||||
                logger.warning(f"⚠️ Attempted to release unknown model '{model_id}'")
 | 
			
		||||
                return
 | 
			
		||||
            
 | 
			
		||||
            self._reference_counts[model_id] -= 1
 | 
			
		||||
            logger.info(f"📉 Model '{model_id}' reference count decreased to {self._reference_counts[model_id]}")
 | 
			
		||||
            
 | 
			
		||||
            # For now, keep models in memory even when ref count reaches 0
 | 
			
		||||
            # This prevents reload overhead if the same model is needed again soon
 | 
			
		||||
            # In the future, we could implement LRU eviction policy
 | 
			
		||||
            # if self._reference_counts[model_id] <= 0:
 | 
			
		||||
            #     logger.info(f"💤 Model '{model_id}' has 0 references but keeping in memory for reuse")
 | 
			
		||||
                # Optionally: self._unload_model(model_id)
 | 
			
		||||
    
 | 
			
		||||
    def _unload_model(self, model_id: str) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Internal method to unload a model from memory.
 | 
			
		||||
        Currently not used to prevent reload overhead.
 | 
			
		||||
        """
 | 
			
		||||
        with self._model_lock:
 | 
			
		||||
            if model_id in self._models:
 | 
			
		||||
                logger.info(f"🗑️ Unloading model '{model_id}' from memory")
 | 
			
		||||
                
 | 
			
		||||
                # Clear GPU memory if model was on GPU
 | 
			
		||||
                model = self._models[model_id]
 | 
			
		||||
                if hasattr(model, 'model') and hasattr(model.model, 'cuda'):
 | 
			
		||||
                    try:
 | 
			
		||||
                        # Move model to CPU before deletion to free GPU memory
 | 
			
		||||
                        model.to('cpu')
 | 
			
		||||
                    except Exception as e:
 | 
			
		||||
                        logger.warning(f"⚠️ Failed to move model '{model_id}' to CPU: {e}")
 | 
			
		||||
                
 | 
			
		||||
                # Remove from registry
 | 
			
		||||
                del self._models[model_id]
 | 
			
		||||
                del self._model_files[model_id]
 | 
			
		||||
                del self._reference_counts[model_id]
 | 
			
		||||
                
 | 
			
		||||
                # Force garbage collection
 | 
			
		||||
                import gc
 | 
			
		||||
                gc.collect()
 | 
			
		||||
                if torch.cuda.is_available():
 | 
			
		||||
                    torch.cuda.empty_cache()
 | 
			
		||||
                
 | 
			
		||||
                logger.info(f"✅ Model '{model_id}' unloaded and memory freed")
 | 
			
		||||
                self._log_registry_status()
 | 
			
		||||
    
 | 
			
		||||
    def get_registry_status(self) -> Dict[str, Any]:
 | 
			
		||||
        """
 | 
			
		||||
        Get current status of the model registry.
 | 
			
		||||
        
 | 
			
		||||
        Returns:
 | 
			
		||||
            Dictionary with registry statistics
 | 
			
		||||
        """
 | 
			
		||||
        with self._model_lock:
 | 
			
		||||
            return {
 | 
			
		||||
                "total_models": len(self._models),
 | 
			
		||||
                "models": {
 | 
			
		||||
                    model_id: {
 | 
			
		||||
                        "file_path": self._model_files[model_id],
 | 
			
		||||
                        "reference_count": self._reference_counts[model_id]
 | 
			
		||||
                    }
 | 
			
		||||
                    for model_id in self._models
 | 
			
		||||
                },
 | 
			
		||||
                "total_references": sum(self._reference_counts.values())
 | 
			
		||||
            }
 | 
			
		||||
    
 | 
			
		||||
    def _log_registry_status(self) -> None:
 | 
			
		||||
        """Log current registry status for debugging."""
 | 
			
		||||
        status = self.get_registry_status()
 | 
			
		||||
        logger.info(f"📊 Model Registry Status: {status['total_models']} unique models, {status['total_references']} total references")
 | 
			
		||||
        for model_id, info in status['models'].items():
 | 
			
		||||
            logger.debug(f"   📋 '{model_id}': refs={info['reference_count']}, file={os.path.basename(info['file_path'])}")
 | 
			
		||||
    
 | 
			
		||||
    def cleanup_all(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Clean up all models from the registry. Used during shutdown.
 | 
			
		||||
        """
 | 
			
		||||
        with self._model_lock:
 | 
			
		||||
            model_ids = list(self._models.keys())
 | 
			
		||||
            logger.info(f"🧹 Cleaning up {len(model_ids)} models from registry")
 | 
			
		||||
            
 | 
			
		||||
            for model_id in model_ids:
 | 
			
		||||
                self._unload_model(model_id)
 | 
			
		||||
            
 | 
			
		||||
            logger.info("✅ Model registry cleanup complete")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Global singleton instance
 | 
			
		||||
_registry = ModelRegistry()
 | 
			
		||||
 | 
			
		||||
def get_shared_model(model_id: str, model_file_path: str) -> YOLO:
 | 
			
		||||
    """
 | 
			
		||||
    Convenience function to get a shared model instance.
 | 
			
		||||
    
 | 
			
		||||
    Args:
 | 
			
		||||
        model_id: Unique identifier for the model
 | 
			
		||||
        model_file_path: Path to the model file
 | 
			
		||||
        
 | 
			
		||||
    Returns:
 | 
			
		||||
        YOLO model instance (shared across all callers)
 | 
			
		||||
    """
 | 
			
		||||
    return _registry.get_model(model_id, model_file_path)
 | 
			
		||||
 | 
			
		||||
def release_shared_model(model_id: str) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Convenience function to release a shared model reference.
 | 
			
		||||
    
 | 
			
		||||
    Args:
 | 
			
		||||
        model_id: Unique identifier for the model to release
 | 
			
		||||
    """
 | 
			
		||||
    _registry.release_model(model_id)
 | 
			
		||||
 | 
			
		||||
def get_registry_status() -> Dict[str, Any]:
 | 
			
		||||
    """
 | 
			
		||||
    Convenience function to get registry status.
 | 
			
		||||
    
 | 
			
		||||
    Returns:
 | 
			
		||||
        Dictionary with registry statistics
 | 
			
		||||
    """
 | 
			
		||||
    return _registry.get_registry_status()
 | 
			
		||||
 | 
			
		||||
def cleanup_registry() -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Convenience function to cleanup the entire registry.
 | 
			
		||||
    """
 | 
			
		||||
    _registry.cleanup_all()
 | 
			
		||||
							
								
								
									
										375
									
								
								siwatsystem/mpta_manager.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										375
									
								
								siwatsystem/mpta_manager.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,375 @@
 | 
			
		|||
"""
 | 
			
		||||
Shared MPTA Manager for Disk Space Optimization
 | 
			
		||||
 | 
			
		||||
This module implements shared MPTA file management to prevent duplicate downloads
 | 
			
		||||
and extractions when multiple cameras use the same model. MPTA files are stored
 | 
			
		||||
in modelId-based directories and shared across all cameras using that model.
 | 
			
		||||
 | 
			
		||||
Key Features:
 | 
			
		||||
- Thread-safe MPTA downloading and extraction
 | 
			
		||||
- ModelId-based directory structure: models/{modelId}/
 | 
			
		||||
- Reference counting for proper cleanup
 | 
			
		||||
- Eliminates duplicate MPTA downloads
 | 
			
		||||
- Maintains compatibility with existing pipeline system
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import threading
 | 
			
		||||
import logging
 | 
			
		||||
import shutil
 | 
			
		||||
import requests
 | 
			
		||||
from typing import Dict, Set, Optional
 | 
			
		||||
from urllib.parse import urlparse
 | 
			
		||||
from .pympta import load_pipeline_from_zip
 | 
			
		||||
 | 
			
		||||
# Create a logger for this module
 | 
			
		||||
logger = logging.getLogger("detector_worker.mpta_manager")
 | 
			
		||||
 | 
			
		||||
class MPTAManager:
 | 
			
		||||
    """
 | 
			
		||||
    Singleton class for managing shared MPTA files across multiple cameras.
 | 
			
		||||
    
 | 
			
		||||
    This manager ensures that each unique modelId is downloaded and extracted
 | 
			
		||||
    only once, dramatically reducing disk usage and download time when multiple
 | 
			
		||||
    cameras use the same model.
 | 
			
		||||
    """
 | 
			
		||||
    
 | 
			
		||||
    _instance = None
 | 
			
		||||
    _lock = threading.Lock()
 | 
			
		||||
    
 | 
			
		||||
    def __new__(cls):
 | 
			
		||||
        if cls._instance is None:
 | 
			
		||||
            with cls._lock:
 | 
			
		||||
                if cls._instance is None:
 | 
			
		||||
                    cls._instance = super(MPTAManager, cls).__new__(cls)
 | 
			
		||||
                    cls._instance._initialized = False
 | 
			
		||||
        return cls._instance
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        if self._initialized:
 | 
			
		||||
            return
 | 
			
		||||
            
 | 
			
		||||
        self._initialized = True
 | 
			
		||||
        
 | 
			
		||||
        # Thread-safe storage for MPTA management
 | 
			
		||||
        self._model_paths: Dict[int, str] = {}  # modelId -> shared_extraction_path
 | 
			
		||||
        self._mpta_file_paths: Dict[int, str] = {}  # modelId -> local_mpta_file_path
 | 
			
		||||
        self._reference_counts: Dict[int, int] = {}  # modelId -> reference count  
 | 
			
		||||
        self._download_locks: Dict[int, threading.Lock] = {}  # modelId -> download lock
 | 
			
		||||
        self._cameras_using_model: Dict[int, Set[str]] = {}  # modelId -> set of camera_ids
 | 
			
		||||
        self._manager_lock = threading.RLock()  # Reentrant lock for nested calls
 | 
			
		||||
        
 | 
			
		||||
        logger.info("🏭 Shared MPTA Manager initialized - ready for disk-optimized MPTA management")
 | 
			
		||||
    
 | 
			
		||||
    def get_or_download_mpta(self, model_id: int, model_url: str, camera_id: str) -> Optional[tuple[str, str]]:
 | 
			
		||||
        """
 | 
			
		||||
        Get or download an MPTA file. Returns (extraction_path, mpta_file_path) if successful.
 | 
			
		||||
        
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: Unique identifier for the model
 | 
			
		||||
            model_url: URL to download the MPTA file from
 | 
			
		||||
            camera_id: Identifier for the requesting camera
 | 
			
		||||
            
 | 
			
		||||
        Returns:
 | 
			
		||||
            Tuple of (extraction_path, mpta_file_path), or None if failed
 | 
			
		||||
        """
 | 
			
		||||
        with self._manager_lock:
 | 
			
		||||
            # Track camera usage
 | 
			
		||||
            if model_id not in self._cameras_using_model:
 | 
			
		||||
                self._cameras_using_model[model_id] = set()
 | 
			
		||||
            self._cameras_using_model[model_id].add(camera_id)
 | 
			
		||||
            
 | 
			
		||||
            # Check if model directory already exists on disk (from previous sessions)
 | 
			
		||||
            if model_id not in self._model_paths:
 | 
			
		||||
                potential_path = f"models/{model_id}"
 | 
			
		||||
                if os.path.exists(potential_path) and os.path.isdir(potential_path):
 | 
			
		||||
                    # Directory exists from previous session, find the MPTA file
 | 
			
		||||
                    mpta_files = [f for f in os.listdir(potential_path) if f.endswith('.mpta')]
 | 
			
		||||
                    if mpta_files:
 | 
			
		||||
                        # Use the first .mpta file found
 | 
			
		||||
                        mpta_file_path = os.path.join(potential_path, mpta_files[0])
 | 
			
		||||
                        self._model_paths[model_id] = potential_path
 | 
			
		||||
                        self._mpta_file_paths[model_id] = mpta_file_path
 | 
			
		||||
                        self._reference_counts[model_id] = 0  # Will be incremented below
 | 
			
		||||
                        logger.info(f"📂 Found existing MPTA modelId {model_id} from previous session")
 | 
			
		||||
            
 | 
			
		||||
            # Check if already available
 | 
			
		||||
            if model_id in self._model_paths:
 | 
			
		||||
                shared_path = self._model_paths[model_id]
 | 
			
		||||
                mpta_file_path = self._mpta_file_paths.get(model_id)
 | 
			
		||||
                if os.path.exists(shared_path) and mpta_file_path and os.path.exists(mpta_file_path):
 | 
			
		||||
                    self._reference_counts[model_id] += 1
 | 
			
		||||
                    logger.info(f"📂 MPTA modelId {model_id} reused for camera {camera_id} (ref_count: {self._reference_counts[model_id]}) - SAVED DOWNLOAD!")
 | 
			
		||||
                    return (shared_path, mpta_file_path)
 | 
			
		||||
                else:
 | 
			
		||||
                    # Path was deleted externally, clean up our records
 | 
			
		||||
                    logger.warning(f"⚠️ MPTA path for modelId {model_id} was deleted externally, will re-download")
 | 
			
		||||
                    del self._model_paths[model_id]
 | 
			
		||||
                    self._mpta_file_paths.pop(model_id, None)
 | 
			
		||||
                    self._reference_counts.pop(model_id, 0)
 | 
			
		||||
            
 | 
			
		||||
            # Need to download - get or create download lock for this modelId
 | 
			
		||||
            if model_id not in self._download_locks:
 | 
			
		||||
                self._download_locks[model_id] = threading.Lock()
 | 
			
		||||
            
 | 
			
		||||
        # Download with model-specific lock (released _manager_lock to allow other models)
 | 
			
		||||
        download_lock = self._download_locks[model_id]
 | 
			
		||||
        with download_lock:
 | 
			
		||||
            # Double-check after acquiring download lock
 | 
			
		||||
            with self._manager_lock:
 | 
			
		||||
                if model_id in self._model_paths and os.path.exists(self._model_paths[model_id]):
 | 
			
		||||
                    mpta_file_path = self._mpta_file_paths.get(model_id)
 | 
			
		||||
                    if mpta_file_path and os.path.exists(mpta_file_path):
 | 
			
		||||
                        self._reference_counts[model_id] += 1
 | 
			
		||||
                        logger.info(f"📂 MPTA modelId {model_id} became available during wait (ref_count: {self._reference_counts[model_id]})")
 | 
			
		||||
                        return (self._model_paths[model_id], mpta_file_path)
 | 
			
		||||
            
 | 
			
		||||
            # Actually download and extract
 | 
			
		||||
            shared_path = f"models/{model_id}"
 | 
			
		||||
            logger.info(f"🔄 Downloading NEW MPTA for modelId {model_id} from {model_url}")
 | 
			
		||||
            
 | 
			
		||||
            try:
 | 
			
		||||
                # Ensure directory exists
 | 
			
		||||
                os.makedirs(shared_path, exist_ok=True)
 | 
			
		||||
                
 | 
			
		||||
                # Download MPTA file
 | 
			
		||||
                mpta_filename = self._extract_filename_from_url(model_url) or f"model_{model_id}.mpta"
 | 
			
		||||
                local_mpta_path = os.path.join(shared_path, mpta_filename)
 | 
			
		||||
                
 | 
			
		||||
                if not self._download_file(model_url, local_mpta_path):
 | 
			
		||||
                    logger.error(f"❌ Failed to download MPTA for modelId {model_id}")
 | 
			
		||||
                    return None
 | 
			
		||||
                
 | 
			
		||||
                # Extract MPTA
 | 
			
		||||
                pipeline_tree = load_pipeline_from_zip(local_mpta_path, shared_path)
 | 
			
		||||
                if pipeline_tree is None:
 | 
			
		||||
                    logger.error(f"❌ Failed to extract MPTA for modelId {model_id}")
 | 
			
		||||
                    return None
 | 
			
		||||
                
 | 
			
		||||
                # Success - register in manager
 | 
			
		||||
                with self._manager_lock:
 | 
			
		||||
                    self._model_paths[model_id] = shared_path
 | 
			
		||||
                    self._mpta_file_paths[model_id] = local_mpta_path
 | 
			
		||||
                    self._reference_counts[model_id] = 1
 | 
			
		||||
                    
 | 
			
		||||
                    logger.info(f"✅ MPTA modelId {model_id} downloaded and registered (ref_count: 1)")
 | 
			
		||||
                    self._log_manager_status()
 | 
			
		||||
                
 | 
			
		||||
                return (shared_path, local_mpta_path)
 | 
			
		||||
                
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.error(f"❌ Error downloading/extracting MPTA for modelId {model_id}: {e}")
 | 
			
		||||
                # Clean up partial download
 | 
			
		||||
                if os.path.exists(shared_path):
 | 
			
		||||
                    shutil.rmtree(shared_path, ignore_errors=True)
 | 
			
		||||
                return None
 | 
			
		||||
    
 | 
			
		||||
    def release_mpta(self, model_id: int, camera_id: str) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Release a reference to an MPTA. If reference count reaches zero,
 | 
			
		||||
        the MPTA directory may be cleaned up to free disk space.
 | 
			
		||||
        
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: Unique identifier for the model to release
 | 
			
		||||
            camera_id: Identifier for the camera releasing the reference
 | 
			
		||||
        """
 | 
			
		||||
        with self._manager_lock:
 | 
			
		||||
            if model_id not in self._reference_counts:
 | 
			
		||||
                logger.warning(f"⚠️ Attempted to release unknown MPTA modelId {model_id} for camera {camera_id}")
 | 
			
		||||
                return
 | 
			
		||||
            
 | 
			
		||||
            # Remove camera from usage tracking
 | 
			
		||||
            if model_id in self._cameras_using_model:
 | 
			
		||||
                self._cameras_using_model[model_id].discard(camera_id)
 | 
			
		||||
            
 | 
			
		||||
            self._reference_counts[model_id] -= 1
 | 
			
		||||
            logger.info(f"📉 MPTA modelId {model_id} reference count decreased to {self._reference_counts[model_id]} (released by {camera_id})")
 | 
			
		||||
            
 | 
			
		||||
            # Clean up if no more references
 | 
			
		||||
            # if self._reference_counts[model_id] <= 0:
 | 
			
		||||
            #     self._cleanup_mpta(model_id)
 | 
			
		||||
    
 | 
			
		||||
    def _cleanup_mpta(self, model_id: int) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Internal method to clean up an MPTA directory and free disk space.
 | 
			
		||||
        """
 | 
			
		||||
        if model_id in self._model_paths:
 | 
			
		||||
            shared_path = self._model_paths[model_id]
 | 
			
		||||
            
 | 
			
		||||
            try:
 | 
			
		||||
                if os.path.exists(shared_path):
 | 
			
		||||
                    shutil.rmtree(shared_path)
 | 
			
		||||
                    logger.info(f"🗑️ Cleaned up MPTA directory: {shared_path}")
 | 
			
		||||
                
 | 
			
		||||
                # Remove from tracking
 | 
			
		||||
                del self._model_paths[model_id]
 | 
			
		||||
                self._mpta_file_paths.pop(model_id, None)
 | 
			
		||||
                del self._reference_counts[model_id]
 | 
			
		||||
                self._cameras_using_model.pop(model_id, None)
 | 
			
		||||
                
 | 
			
		||||
                # Clean up download lock (optional, could keep for future use)
 | 
			
		||||
                self._download_locks.pop(model_id, None)
 | 
			
		||||
                
 | 
			
		||||
                logger.info(f"✅ MPTA modelId {model_id} fully cleaned up and disk space freed")
 | 
			
		||||
                self._log_manager_status()
 | 
			
		||||
                
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.error(f"❌ Error cleaning up MPTA modelId {model_id}: {e}")
 | 
			
		||||
    
 | 
			
		||||
    def get_shared_path(self, model_id: int) -> Optional[str]:
 | 
			
		||||
        """
 | 
			
		||||
        Get the shared extraction path for a modelId without downloading.
 | 
			
		||||
        
 | 
			
		||||
        Args:
 | 
			
		||||
            model_id: Model identifier to look up
 | 
			
		||||
            
 | 
			
		||||
        Returns:
 | 
			
		||||
            Shared path if available, None otherwise
 | 
			
		||||
        """
 | 
			
		||||
        with self._manager_lock:
 | 
			
		||||
            return self._model_paths.get(model_id)
 | 
			
		||||
    
 | 
			
		||||
    def get_manager_status(self) -> Dict:
 | 
			
		||||
        """
 | 
			
		||||
        Get current status of the MPTA manager.
 | 
			
		||||
        
 | 
			
		||||
        Returns:
 | 
			
		||||
            Dictionary with manager statistics
 | 
			
		||||
        """
 | 
			
		||||
        with self._manager_lock:
 | 
			
		||||
            return {
 | 
			
		||||
                "total_mpta_models": len(self._model_paths),
 | 
			
		||||
                "models": {
 | 
			
		||||
                    str(model_id): {
 | 
			
		||||
                        "shared_path": path,
 | 
			
		||||
                        "reference_count": self._reference_counts.get(model_id, 0),
 | 
			
		||||
                        "cameras_using": list(self._cameras_using_model.get(model_id, set()))
 | 
			
		||||
                    }
 | 
			
		||||
                    for model_id, path in self._model_paths.items()
 | 
			
		||||
                },
 | 
			
		||||
                "total_references": sum(self._reference_counts.values()),
 | 
			
		||||
                "active_downloads": len(self._download_locks)
 | 
			
		||||
            }
 | 
			
		||||
    
 | 
			
		||||
    def _log_manager_status(self) -> None:
 | 
			
		||||
        """Log current manager status for debugging."""
 | 
			
		||||
        status = self.get_manager_status()
 | 
			
		||||
        logger.info(f"📊 MPTA Manager Status: {status['total_mpta_models']} unique models, {status['total_references']} total references")
 | 
			
		||||
        for model_id, info in status['models'].items():
 | 
			
		||||
            cameras_str = ','.join(info['cameras_using'][:3])  # Show first 3 cameras
 | 
			
		||||
            if len(info['cameras_using']) > 3:
 | 
			
		||||
                cameras_str += f"+{len(info['cameras_using'])-3} more"
 | 
			
		||||
            logger.debug(f"   📋 ModelId {model_id}: refs={info['reference_count']}, cameras=[{cameras_str}]")
 | 
			
		||||
    
 | 
			
		||||
    def cleanup_all(self) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Clean up all MPTA directories. Used during shutdown.
 | 
			
		||||
        """
 | 
			
		||||
        with self._manager_lock:
 | 
			
		||||
            model_ids = list(self._model_paths.keys())
 | 
			
		||||
            logger.info(f"🧹 Cleaning up {len(model_ids)} MPTA directories")
 | 
			
		||||
            
 | 
			
		||||
            for model_id in model_ids:
 | 
			
		||||
                self._cleanup_mpta(model_id)
 | 
			
		||||
            
 | 
			
		||||
            # Clear all tracking data
 | 
			
		||||
            self._download_locks.clear()
 | 
			
		||||
            logger.info("✅ MPTA manager cleanup complete")
 | 
			
		||||
    
 | 
			
		||||
    def _download_file(self, url: str, local_path: str) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Download a file from URL to local path with progress logging.
 | 
			
		||||
        
 | 
			
		||||
        Args:
 | 
			
		||||
            url: URL to download from
 | 
			
		||||
            local_path: Local path to save to
 | 
			
		||||
            
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if successful, False otherwise
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            logger.info(f"⬇️ Starting download from {url}")
 | 
			
		||||
            
 | 
			
		||||
            response = requests.get(url, stream=True)
 | 
			
		||||
            response.raise_for_status()
 | 
			
		||||
            
 | 
			
		||||
            total_size = int(response.headers.get('content-length', 0))
 | 
			
		||||
            if total_size > 0:
 | 
			
		||||
                logger.info(f"📦 File size: {total_size / 1024 / 1024:.2f} MB")
 | 
			
		||||
            
 | 
			
		||||
            downloaded = 0
 | 
			
		||||
            last_logged_progress = 0
 | 
			
		||||
            with open(local_path, 'wb') as f:
 | 
			
		||||
                for chunk in response.iter_content(chunk_size=8192):
 | 
			
		||||
                    if chunk:
 | 
			
		||||
                        f.write(chunk)
 | 
			
		||||
                        downloaded += len(chunk)
 | 
			
		||||
                        
 | 
			
		||||
                        if total_size > 0:
 | 
			
		||||
                            progress = int((downloaded / total_size) * 100)
 | 
			
		||||
                            # Log at 10% intervals (10%, 20%, 30%, etc.)
 | 
			
		||||
                            if progress >= last_logged_progress + 10 and progress <= 100:
 | 
			
		||||
                                logger.debug(f"Download progress: {progress}%")
 | 
			
		||||
                                last_logged_progress = progress
 | 
			
		||||
            
 | 
			
		||||
            logger.info(f"✅ Successfully downloaded to {local_path}")
 | 
			
		||||
            return True
 | 
			
		||||
            
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.error(f"❌ Download failed: {e}")
 | 
			
		||||
            # Clean up partial file
 | 
			
		||||
            if os.path.exists(local_path):
 | 
			
		||||
                os.remove(local_path)
 | 
			
		||||
            return False
 | 
			
		||||
    
 | 
			
		||||
    def _extract_filename_from_url(self, url: str) -> Optional[str]:
 | 
			
		||||
        """Extract filename from URL."""
 | 
			
		||||
        try:
 | 
			
		||||
            parsed = urlparse(url)
 | 
			
		||||
            filename = os.path.basename(parsed.path)
 | 
			
		||||
            return filename if filename else None
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Global singleton instance
 | 
			
		||||
_mpta_manager = MPTAManager()
 | 
			
		||||
 | 
			
		||||
def get_or_download_mpta(model_id: int, model_url: str, camera_id: str) -> Optional[tuple[str, str]]:
 | 
			
		||||
    """
 | 
			
		||||
    Convenience function to get or download a shared MPTA.
 | 
			
		||||
    
 | 
			
		||||
    Args:
 | 
			
		||||
        model_id: Unique identifier for the model
 | 
			
		||||
        model_url: URL to download the MPTA file from
 | 
			
		||||
        camera_id: Identifier for the requesting camera
 | 
			
		||||
        
 | 
			
		||||
    Returns:
 | 
			
		||||
        Tuple of (extraction_path, mpta_file_path), or None if failed
 | 
			
		||||
    """
 | 
			
		||||
    return _mpta_manager.get_or_download_mpta(model_id, model_url, camera_id)
 | 
			
		||||
 | 
			
		||||
def release_mpta(model_id: int, camera_id: str) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Convenience function to release a shared MPTA reference.
 | 
			
		||||
    
 | 
			
		||||
    Args:
 | 
			
		||||
        model_id: Unique identifier for the model to release
 | 
			
		||||
        camera_id: Identifier for the camera releasing the reference
 | 
			
		||||
    """
 | 
			
		||||
    _mpta_manager.release_mpta(model_id, camera_id)
 | 
			
		||||
 | 
			
		||||
def get_mpta_manager_status() -> Dict:
 | 
			
		||||
    """
 | 
			
		||||
    Convenience function to get MPTA manager status.
 | 
			
		||||
    
 | 
			
		||||
    Returns:
 | 
			
		||||
        Dictionary with manager statistics
 | 
			
		||||
    """
 | 
			
		||||
    return _mpta_manager.get_manager_status()
 | 
			
		||||
 | 
			
		||||
def cleanup_mpta_manager() -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Convenience function to cleanup the entire MPTA manager.
 | 
			
		||||
    """
 | 
			
		||||
    _mpta_manager.cleanup_all()
 | 
			
		||||
| 
						 | 
				
			
			@ -13,6 +13,7 @@ import concurrent.futures
 | 
			
		|||
from ultralytics import YOLO
 | 
			
		||||
from urllib.parse import urlparse
 | 
			
		||||
from .database import DatabaseManager
 | 
			
		||||
from .model_registry import get_shared_model, release_shared_model
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
 | 
			
		||||
# Create a logger specifically for this module
 | 
			
		||||
| 
						 | 
				
			
			@ -98,13 +99,11 @@ def load_pipeline_node(node_config: dict, mpta_dir: str, redis_client, db_manage
 | 
			
		|||
        logger.error(f"Model file {model_path} not found. Current directory: {os.getcwd()}")
 | 
			
		||||
        logger.error(f"Directory content: {os.listdir(os.path.dirname(model_path))}")
 | 
			
		||||
        raise FileNotFoundError(f"Model file {model_path} not found.")
 | 
			
		||||
    logger.info(f"Loading model for node {node_config['modelId']} from {model_path}")
 | 
			
		||||
    model = YOLO(model_path)
 | 
			
		||||
    if torch.cuda.is_available():
 | 
			
		||||
        logger.info(f"CUDA available. Moving model {node_config['modelId']} to GPU VRAM")
 | 
			
		||||
        model.to("cuda")
 | 
			
		||||
    else:
 | 
			
		||||
        logger.info(f"CUDA not available. Using CPU for model {node_config['modelId']}")
 | 
			
		||||
    
 | 
			
		||||
    # Use shared model registry to prevent duplicate loading
 | 
			
		||||
    model_id = node_config['modelId']
 | 
			
		||||
    logger.info(f"Getting shared model for node {model_id} from {model_path}")
 | 
			
		||||
    model = get_shared_model(model_id, model_path)
 | 
			
		||||
 | 
			
		||||
    # Prepare trigger class indices for optimization
 | 
			
		||||
    trigger_classes = node_config.get("triggerClasses", [])
 | 
			
		||||
| 
						 | 
				
			
			@ -1108,6 +1107,17 @@ def is_camera_active(camera_id, model_id):
 | 
			
		|||
    
 | 
			
		||||
    return session_state.get("active", True)
 | 
			
		||||
 | 
			
		||||
def cleanup_pipeline_node(node: dict):
 | 
			
		||||
    """Clean up a pipeline node and release its model reference."""
 | 
			
		||||
    if node and "modelId" in node:
 | 
			
		||||
        model_id = node["modelId"]
 | 
			
		||||
        logger.info(f"🧹 Cleaning up pipeline node: {model_id}")
 | 
			
		||||
        release_shared_model(model_id)
 | 
			
		||||
        
 | 
			
		||||
        # Recursively clean up branches
 | 
			
		||||
        for branch in node.get("branches", []):
 | 
			
		||||
            cleanup_pipeline_node(branch)
 | 
			
		||||
 | 
			
		||||
def cleanup_camera_stability(camera_id):
 | 
			
		||||
    """Clean up stability tracking data when a camera is disconnected."""
 | 
			
		||||
    global _camera_stability_tracking
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue