447 lines
		
	
	
		
			No EOL
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			447 lines
		
	
	
		
			No EOL
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
 | 
						|
YOLO Model Inference Wrapper - Handles model loading and inference optimization
 | 
						|
"""
 | 
						|
 | 
						|
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Lock
from typing import Dict, List, Optional, Any, Tuple, Union

import cv2
import numpy as np
import torch
 | 
						|
 | 
						|
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
 | 
						|
@dataclass
class Detection:
    """Represents a single detection result"""

    # Box corners in the order [x1, y1, x2, y2]
    bbox: List[float]
    # Score reported by the model for this box
    confidence: float
    # Numeric class index reported by the model
    class_id: int
    # Human-readable name resolved for class_id
    class_name: str
    # Tracker-assigned identifier; None when no track id is available
    track_id: Optional[int] = None
 | 
						|
 | 
						|
 | 
						|
@dataclass
class InferenceResult:
    """Result from model inference"""

    # Detections kept after confidence / class filtering
    detections: List["Detection"]
    # Shape of the input image as (height, width)
    image_shape: Tuple[int, int]
    # Wall-clock duration of the model call, in seconds
    inference_time: float
    # Identifier of the model that produced this result
    model_id: str
 | 
						|
 | 
						|
 | 
						|
class YOLOWrapper:
    """Wrapper for YOLO models with caching and optimization

    Loaded models are kept in a class-level dictionary keyed by the model
    file path, so wrappers created for the same file reuse one model object
    instead of loading it again.
    """

    # Class-level model cache shared across all instances (key: str(model_path))
    _model_cache: Dict[str, Any] = {}
    # Guards every access to _model_cache
    _cache_lock = Lock()

    def __init__(self, model_path: Path, model_id: str, device: Optional[str] = None):
        """
        Initialize YOLO wrapper and load (or fetch from cache) the model

        Args:
            model_path: Path to the .pt model file
            model_id: Unique identifier for the model
            device: Device to run inference on ('cuda', 'cpu', or None for auto)

        Raises:
            ImportError: If the ultralytics package is not installed
            Exception: Any error raised while loading the model is re-raised
        """
        self.model_path = model_path
        self.model_id = model_id

        # Auto-detect device if not specified
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.model = None
        # Mapping of class id -> class name, filled in by _extract_class_names().
        # It must be a dict: lookups below use .get() and callers receive a copy.
        self._class_names: Dict[int, str] = {}

        self._load_model()

        logger.info(f"Initialized YOLO wrapper for {model_id} on {self.device}")

    def _load_model(self) -> None:
        """
        Load the YOLO model with caching

        The cache lock is held for the whole load, so concurrent wrappers for
        the same path do not load the file twice.

        Raises:
            ImportError: If the ultralytics package is not installed
            Exception: Any other loading error is logged and re-raised
        """
        cache_key = str(self.model_path)

        with self._cache_lock:
            # Reuse an already loaded model when one exists for this path
            if cache_key in self._model_cache:
                logger.info(f"Loading model {self.model_id} from cache")
                self.model = self._model_cache[cache_key]
                self._extract_class_names()
                return

            try:
                # Imported lazily so that a missing package surfaces here with
                # an installation hint instead of at module import time
                from ultralytics import YOLO

                logger.info(f"Loading YOLO model from {self.model_path}")
                self.model = YOLO(str(self.model_path))

                # Move model to device
                if self.device == 'cuda' and torch.cuda.is_available():
                    self.model.to('cuda')
                    logger.info(f"Model {self.model_id} moved to GPU")

                # Cache the model
                self._model_cache[cache_key] = self.model
                self._extract_class_names()

                logger.info(f"Successfully loaded model {self.model_id}")

            except ImportError:
                logger.error("Ultralytics YOLO not installed. Install with: pip install ultralytics")
                raise
            except Exception as e:
                logger.error(f"Failed to load YOLO model {self.model_id}: {e}", exc_info=True)
                raise

    def _extract_class_names(self) -> None:
        """
        Extract class names from the model into ``self._class_names``

        Looks at ``model.names`` first and falls back to ``model.model.names``.
        A list/tuple of names is converted to an ``{index: name}`` dict so the
        rest of the class can always rely on dict semantics. On any failure
        the mapping is reset to an empty dict.
        """
        try:
            names = getattr(self.model, 'names', None)
            if names is None:
                inner = getattr(self.model, 'model', None)
                names = getattr(inner, 'names', None)

            if names is None:
                logger.warning(f"Could not extract class names from model {self.model_id}")
                self._class_names = {}
                return

            if isinstance(names, (list, tuple)):
                names = dict(enumerate(names))

            # Copy so later changes on the model object do not leak in here
            self._class_names = dict(names)
        except Exception as e:
            logger.error(f"Failed to extract class names: {e}")
            self._class_names = {}

    def infer(
        self,
        image: np.ndarray,
        confidence_threshold: float = 0.5,
        trigger_classes: Optional[List[str]] = None,
        iou_threshold: float = 0.45
    ) -> "InferenceResult":
        """
        Run inference on an image

        Args:
            image: Input image as numpy array (BGR format)
            confidence_threshold: Minimum confidence for detections
            trigger_classes: List of class names to filter (None = all classes)
            iou_threshold: IoU threshold for NMS

        Returns:
            InferenceResult containing detections

        Raises:
            RuntimeError: If the model is not loaded
            Exception: Any error raised during inference is logged and re-raised
        """
        if self.model is None:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        try:
            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by system clock adjustments
            started = time.perf_counter()

            results = self.model(
                image,
                conf=confidence_threshold,
                iou=iou_threshold,
                verbose=False
            )

            inference_time = time.perf_counter() - started

            # A single image was passed in, so only the first result is parsed
            detections = self._parse_results(results[0], trigger_classes)

            return InferenceResult(
                detections=detections,
                image_shape=(image.shape[0], image.shape[1]),
                inference_time=inference_time,
                model_id=self.model_id
            )

        except Exception as e:
            logger.error(f"Inference failed for model {self.model_id}: {e}", exc_info=True)
            raise

    def _parse_results(
        self,
        result: Any,
        trigger_classes: Optional[List[str]] = None
    ) -> List["Detection"]:
        """
        Parse YOLO results into Detection objects

        Parsing is best-effort: if an error occurs it is logged and the
        detections collected up to that point are returned.

        Args:
            result: YOLO result object
            trigger_classes: Optional list of class names to filter
                (None or empty = keep all classes)

        Returns:
            List of Detection objects
        """
        detections: List["Detection"] = []

        try:
            boxes = result.boxes
            if boxes is None:
                return detections

            # Build the class filter once instead of scanning a list per box
            wanted = set(trigger_classes) if trigger_classes else None

            # Convert each tensor to host-side Python values once rather than
            # issuing a separate device-to-host copy for every box
            coords = boxes.xyxy.cpu().tolist()
            scores = boxes.conf.cpu().tolist()
            class_ids = boxes.cls.cpu().tolist()
            raw_track_ids = getattr(boxes, 'id', None)
            track_ids = raw_track_ids.cpu().tolist() if raw_track_ids is not None else None

            for index, ((x1, y1, x2, y2), score, raw_cls) in enumerate(
                zip(coords, scores, class_ids)
            ):
                cls_id = int(raw_cls)
                class_name = self._class_names.get(cls_id, f"class_{cls_id}")

                # Filter by trigger classes if specified
                if wanted is not None and class_name not in wanted:
                    continue

                track_id = int(track_ids[index]) if track_ids is not None else None

                detections.append(
                    Detection(
                        bbox=[float(x1), float(y1), float(x2), float(y2)],
                        confidence=float(score),
                        class_id=cls_id,
                        class_name=class_name,
                        track_id=track_id
                    )
                )

        except Exception as e:
            logger.error(f"Failed to parse results: {e}", exc_info=True)

        return detections

    def track(
        self,
        image: np.ndarray,
        confidence_threshold: float = 0.5,
        trigger_classes: Optional[List[str]] = None,
        persist: bool = True,
        camera_id: Optional[str] = None
    ) -> "InferenceResult":
        """
        Run detection (tracking will be handled by external tracker)

        Args:
            image: Input image as numpy array (BGR format)
            confidence_threshold: Minimum confidence for detections
            trigger_classes: List of class names to filter
            persist: Ignored - tracking handled externally
            camera_id: Ignored - tracking handled externally

        Returns:
            InferenceResult containing detections (no track IDs from YOLO)
        """
        # Just do detection - no YOLO tracking
        return self.infer(image, confidence_threshold, trigger_classes)

    def predict_classification(
        self,
        image: np.ndarray,
        top_k: int = 1
    ) -> Dict[str, float]:
        """
        Run classification on an image

        Args:
            image: Input image as numpy array (BGR format)
            top_k: Number of top predictions to return. Predictions are taken
                from the result's ``top5`` list, so no more entries than that
                list holds can be returned.

        Returns:
            Dictionary of class_name -> confidence scores; empty when the
            model result carries no classification probabilities

        Raises:
            RuntimeError: If the model is not loaded
            Exception: Any error raised during inference is logged and re-raised
        """
        if self.model is None:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        try:
            results = self.model(image, verbose=False)

            # The attribute can exist but be None (e.g. for a non-classification
            # model), so test the value rather than using hasattr()
            probs = getattr(results[0], 'probs', None)
            if probs is None:
                logger.warning(f"Model {self.model_id} does not support classification")
                return {}

            top_indices = probs.top5[:top_k]
            top_conf = probs.top5conf[:top_k].cpu().numpy()

            return {
                self._class_names.get(int(idx), f"class_{idx}"): float(conf)
                for idx, conf in zip(top_indices, top_conf)
            }

        except Exception as e:
            logger.error(f"Classification failed for model {self.model_id}: {e}", exc_info=True)
            raise

    def crop_detection(
        self,
        image: np.ndarray,
        detection: "Detection",
        padding: int = 0
    ) -> np.ndarray:
        """
        Crop image to detection bounding box

        Box coordinates are truncated to integers, expanded by ``padding`` and
        clipped to the image bounds before slicing.

        Args:
            image: Original image
            detection: Detection to crop
            padding: Additional padding around the box

        Returns:
            Cropped image region (a slice of ``image``, not a copy)
        """
        height, width = image.shape[:2]
        x1, y1, x2, y2 = detection.bbox

        # Add padding and clip to image boundaries
        left = max(0, int(x1) - padding)
        top = max(0, int(y1) - padding)
        right = min(width, int(x2) + padding)
        bottom = min(height, int(y2) + padding)

        return image[top:bottom, left:right]

    def get_class_names(self) -> Dict[int, str]:
        """Get a copy of the class id -> class name dictionary"""
        return dict(self._class_names)

    def get_num_classes(self) -> int:
        """Get the number of classes the model can detect"""
        return len(self._class_names)

    def clear_cache(self) -> None:
        """
        Remove this wrapper's model from the shared model cache

        Only the cache entry is dropped; ``self.model`` keeps its reference.
        """
        with self._cache_lock:
            cache_key = str(self.model_path)
            if cache_key in self._model_cache:
                del self._model_cache[cache_key]
                logger.info(f"Cleared cache for model {self.model_id}")

    @classmethod
    def clear_all_cache(cls) -> None:
        """Clear all cached models"""
        with cls._cache_lock:
            cls._model_cache.clear()
            logger.info("Cleared all model cache")

    def warmup(self, image_size: Tuple[int, int] = (640, 640)) -> None:
        """
        Warmup the model with a dummy inference

        Failures are logged as warnings and never propagated.

        Args:
            image_size: Size of dummy image (height, width)
        """
        try:
            dummy_image = np.zeros((image_size[0], image_size[1], 3), dtype=np.uint8)
            self.infer(dummy_image, confidence_threshold=0.5)
            logger.info(f"Model {self.model_id} warmed up")
        except Exception as e:
            logger.warning(f"Failed to warmup model {self.model_id}: {e}")
 | 
						|
 | 
						|
 | 
						|
class ModelInferenceManager:
    """Manages multiple YOLO models for a pipeline"""

    def __init__(self, model_dir: Path):
        """
        Initialize the inference manager

        Args:
            model_dir: Directory containing model files
        """
        self.model_dir = model_dir
        # Loaded wrappers keyed by model_id
        self.models: Dict[str, "YOLOWrapper"] = {}
        # Serializes load/unload operations on self.models
        self._lock = Lock()

        logger.info(f"Initialized ModelInferenceManager with model directory: {model_dir}")

    def load_model(
        self,
        model_id: str,
        model_file: str,
        device: Optional[str] = None
    ) -> "YOLOWrapper":
        """
        Load a model for inference

        If ``model_id`` is already loaded, the existing wrapper is returned
        and ``model_file`` / ``device`` are not consulted.

        Args:
            model_id: Unique identifier for the model
            model_file: Filename of the model, relative to the model directory
            device: Device to run on

        Returns:
            YOLOWrapper instance

        Raises:
            FileNotFoundError: If the model file does not exist
        """
        with self._lock:
            existing = self.models.get(model_id)
            if existing is not None:
                logger.debug(f"Model {model_id} already loaded")
                return existing

            model_path = self.model_dir / model_file
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")

            wrapper = YOLOWrapper(model_path, model_id, device)
            self.models[model_id] = wrapper

            return wrapper

    def get_model(self, model_id: str) -> Optional["YOLOWrapper"]:
        """
        Get a loaded model

        Args:
            model_id: Model identifier

        Returns:
            YOLOWrapper instance or None if not loaded
        """
        return self.models.get(model_id)

    def unload_model(self, model_id: str) -> bool:
        """
        Unload a model to free memory

        The wrapper's entry in the shared model cache is cleared before the
        wrapper is dropped from this manager.

        Args:
            model_id: Model identifier

        Returns:
            True if unloaded, False if not found
        """
        with self._lock:
            wrapper = self.models.pop(model_id, None)
            if wrapper is None:
                return False

            wrapper.clear_cache()
            logger.info(f"Unloaded model {model_id}")
            return True

    def unload_all(self) -> None:
        """Unload all models and clear their shared cache entries"""
        with self._lock:
            for wrapper in self.models.values():
                wrapper.clear_cache()
            self.models.clear()
            logger.info("Unloaded all models")