"""
|
|
YOLO Model Inference Wrapper - Handles model loading and inference optimization
|
|
"""
|
|
|
|
import logging
|
|
import torch
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Any, Tuple, Union
|
|
from threading import Lock
|
|
from dataclasses import dataclass
|
|
import cv2
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|


@dataclass
class Detection:
    """Represents a single detection result"""
    bbox: List[float]  # [x1, y1, x2, y2]
    confidence: float
    class_id: int
    class_name: str
    track_id: Optional[int] = None


@dataclass
class InferenceResult:
    """Result from model inference"""
    detections: List[Detection]
    image_shape: Tuple[int, int]  # (height, width)
    inference_time: float
    model_id: str
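
# Illustrative example (not executed): a Detection produced by YOLOWrapper.infer() might
# look like the following, with pixel coordinates and a hypothetical class mapping:
#   Detection(bbox=[12.0, 34.0, 120.0, 240.0], confidence=0.87,
#             class_id=0, class_name="person", track_id=None)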


class YOLOWrapper:
    """Wrapper for YOLO models with caching and optimization"""

    # Class-level model cache shared across all instances
    _model_cache: Dict[str, Any] = {}
    _cache_lock = Lock()

    def __init__(self, model_path: Path, model_id: str, device: Optional[str] = None):
        """
        Initialize YOLO wrapper

        Args:
            model_path: Path to the .pt model file
            model_id: Unique identifier for the model
            device: Device to run inference on ('cuda', 'cpu', or None for auto)
        """
        self.model_path = model_path
        self.model_id = model_id

        # Auto-detect device if not specified
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.model = None
        # Class names are keyed by class id, matching the Dict[int, str] accessors below
        self._class_names: Dict[int, str] = {}

        self._load_model()

        logger.info(f"Initialized YOLO wrapper for {model_id} on {self.device}")

    def _load_model(self) -> None:
        """Load the YOLO model with caching"""
        cache_key = str(self.model_path)

        with self._cache_lock:
            # Check if model is already cached
            if cache_key in self._model_cache:
                logger.info(f"Loading model {self.model_id} from cache")
                self.model = self._model_cache[cache_key]
                self._extract_class_names()
                return

            # Load model
            try:
                from ultralytics import YOLO

                logger.info(f"Loading YOLO model from {self.model_path}")
                self.model = YOLO(str(self.model_path))

                # Move model to device
                if self.device == 'cuda' and torch.cuda.is_available():
                    self.model.to('cuda')
                    logger.info(f"Model {self.model_id} moved to GPU")

                # Cache the model
                self._model_cache[cache_key] = self.model
                self._extract_class_names()

                logger.info(f"Successfully loaded model {self.model_id}")

            except ImportError:
                logger.error("Ultralytics YOLO not installed. Install with: pip install ultralytics")
                raise
            except Exception as e:
                logger.error(f"Failed to load YOLO model {self.model_id}: {str(e)}", exc_info=True)
                raise
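
    # Note: a cache hit returns the model object as it was first loaded, so it stays on
    # whatever device the first wrapper placed it on; the `device` argument of later
    # wrappers sharing the same weights file has no effect on a cached model.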

    def _extract_class_names(self) -> None:
        """Extract class names from the model"""
        try:
            if hasattr(self.model, 'names'):
                names = self.model.names
            elif hasattr(self.model, 'model') and hasattr(self.model.model, 'names'):
                names = self.model.model.names
            else:
                logger.warning(f"Could not extract class names from model {self.model_id}")
                self._class_names = {}
                return

            # Normalize to {class_id: name} so the dict-style lookups below always work
            self._class_names = names if isinstance(names, dict) else dict(enumerate(names))
        except Exception as e:
            logger.error(f"Failed to extract class names: {str(e)}")
            self._class_names = {}

    def infer(
        self,
        image: np.ndarray,
        confidence_threshold: float = 0.5,
        trigger_classes: Optional[List[str]] = None,
        iou_threshold: float = 0.45
    ) -> InferenceResult:
        """
        Run inference on an image

        Args:
            image: Input image as numpy array (BGR format)
            confidence_threshold: Minimum confidence for detections
            trigger_classes: List of class names to filter (None = all classes)
            iou_threshold: IoU threshold for NMS

        Returns:
            InferenceResult containing detections
        """
        if self.model is None:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        try:
            import time
            start_time = time.time()

            # Run inference
            results = self.model(
                image,
                conf=confidence_threshold,
                iou=iou_threshold,
                verbose=False
            )

            inference_time = time.time() - start_time

            # Parse results
            detections = self._parse_results(results[0], trigger_classes)

            return InferenceResult(
                detections=detections,
                image_shape=(image.shape[0], image.shape[1]),
                inference_time=inference_time,
                model_id=self.model_id
            )

        except Exception as e:
            logger.error(f"Inference failed for model {self.model_id}: {str(e)}", exc_info=True)
            raise
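
    # Illustrative usage sketch (the weights path and image file are placeholders, not
    # part of this module):
    #
    #   wrapper = YOLOWrapper(Path("models/detector.pt"), model_id="detector")
    #   frame = cv2.imread("frame.jpg")  # BGR, as expected by infer()
    #   result = wrapper.infer(frame, confidence_threshold=0.4, trigger_classes=["car"])
    #   for det in result.detections:
    #       print(det.class_name, det.confidence, det.bbox)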

    def _parse_results(
        self,
        result: Any,
        trigger_classes: Optional[List[str]] = None
    ) -> List[Detection]:
        """
        Parse YOLO results into Detection objects

        Args:
            result: YOLO result object
            trigger_classes: Optional list of class names to filter

        Returns:
            List of Detection objects
        """
        detections = []

        try:
            if result.boxes is None:
                return detections

            boxes = result.boxes
            for i in range(len(boxes)):
                # Get box coordinates
                box = boxes.xyxy[i].cpu().numpy()
                x1, y1, x2, y2 = box

                # Get confidence and class
                conf = float(boxes.conf[i])
                cls_id = int(boxes.cls[i])

                # Get class name
                class_name = self._class_names.get(cls_id, f"class_{cls_id}")

                # Filter by trigger classes if specified
                if trigger_classes and class_name not in trigger_classes:
                    continue

                # Get track ID if available
                track_id = None
                if hasattr(boxes, 'id') and boxes.id is not None:
                    track_id = int(boxes.id[i])

                detection = Detection(
                    bbox=[float(x1), float(y1), float(x2), float(y2)],
                    confidence=conf,
                    class_id=cls_id,
                    class_name=class_name,
                    track_id=track_id
                )
                detections.append(detection)

        except Exception as e:
            logger.error(f"Failed to parse results: {str(e)}", exc_info=True)

        return detections

    def track(
        self,
        image: np.ndarray,
        confidence_threshold: float = 0.5,
        trigger_classes: Optional[List[str]] = None,
        persist: bool = True,
        camera_id: Optional[str] = None
    ) -> InferenceResult:
        """
        Run detection (tracking will be handled by external tracker)

        Args:
            image: Input image as numpy array (BGR format)
            confidence_threshold: Minimum confidence for detections
            trigger_classes: List of class names to filter
            persist: Ignored - tracking handled externally
            camera_id: Ignored - tracking handled externally

        Returns:
            InferenceResult containing detections (no track IDs from YOLO)
        """
        # Just do detection - no YOLO tracking
        return self.infer(image, confidence_threshold, trigger_classes)

    def predict_classification(
        self,
        image: np.ndarray,
        top_k: int = 1
    ) -> Dict[str, float]:
        """
        Run classification on an image

        Args:
            image: Input image as numpy array (BGR format)
            top_k: Number of top predictions to return

        Returns:
            Dictionary of class_name -> confidence scores
        """
        if self.model is None:
            raise RuntimeError(f"Model {self.model_id} not loaded")

        try:
            # Run inference
            results = self.model(image, verbose=False)

            # For classification models, extract probabilities
            if hasattr(results[0], 'probs'):
                probs = results[0].probs
                top_indices = probs.top5[:top_k]
                top_conf = probs.top5conf[:top_k].cpu().numpy()

                predictions = {}
                for idx, conf in zip(top_indices, top_conf):
                    class_name = self._class_names.get(int(idx), f"class_{idx}")
                    predictions[class_name] = float(conf)

                return predictions
            else:
                logger.warning(f"Model {self.model_id} does not support classification")
                return {}

        except Exception as e:
            logger.error(f"Classification failed for model {self.model_id}: {str(e)}", exc_info=True)
            raise

    def crop_detection(
        self,
        image: np.ndarray,
        detection: Detection,
        padding: int = 0
    ) -> np.ndarray:
        """
        Crop image to detection bounding box

        Args:
            image: Original image
            detection: Detection to crop
            padding: Additional padding around the box

        Returns:
            Cropped image region
        """
        h, w = image.shape[:2]
        x1, y1, x2, y2 = detection.bbox

        # Add padding and clip to image boundaries
        x1 = max(0, int(x1) - padding)
        y1 = max(0, int(y1) - padding)
        x2 = min(w, int(x2) + padding)
        y2 = min(h, int(y2) + padding)

        return image[y1:y2, x1:x2]
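
    # Illustrative two-stage sketch (detector and classifier weights files are hypothetical):
    #
    #   detector = YOLOWrapper(Path("models/detector.pt"), "detector")
    #   classifier = YOLOWrapper(Path("models/classifier.pt"), "classifier")
    #   result = detector.infer(frame, trigger_classes=["car"])
    #   for det in result.detections:
    #       crop = detector.crop_detection(frame, det, padding=10)
    #       labels = classifier.predict_classification(crop, top_k=3)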

    def get_class_names(self) -> Dict[int, str]:
        """Get the class names dictionary"""
        return self._class_names.copy()

    def get_num_classes(self) -> int:
        """Get the number of classes the model can detect"""
        return len(self._class_names)

    def clear_cache(self) -> None:
        """Clear the model cache"""
        with self._cache_lock:
            cache_key = str(self.model_path)
            if cache_key in self._model_cache:
                del self._model_cache[cache_key]
                logger.info(f"Cleared cache for model {self.model_id}")

    @classmethod
    def clear_all_cache(cls) -> None:
        """Clear all cached models"""
        with cls._cache_lock:
            cls._model_cache.clear()
            logger.info("Cleared all model cache")

    def warmup(self, image_size: Tuple[int, int] = (640, 640)) -> None:
        """
        Warmup the model with a dummy inference

        Args:
            image_size: Size of dummy image (height, width)
        """
        try:
            dummy_image = np.zeros((image_size[0], image_size[1], 3), dtype=np.uint8)
            self.infer(dummy_image, confidence_threshold=0.5)
            logger.info(f"Model {self.model_id} warmed up")
        except Exception as e:
            logger.warning(f"Failed to warmup model {self.model_id}: {str(e)}")


class ModelInferenceManager:
    """Manages multiple YOLO models for a pipeline"""

    def __init__(self, model_dir: Path):
        """
        Initialize the inference manager

        Args:
            model_dir: Directory containing model files
        """
        self.model_dir = model_dir
        self.models: Dict[str, YOLOWrapper] = {}
        self._lock = Lock()

        logger.info(f"Initialized ModelInferenceManager with model directory: {model_dir}")

    def load_model(
        self,
        model_id: str,
        model_file: str,
        device: Optional[str] = None
    ) -> YOLOWrapper:
        """
        Load a model for inference

        Args:
            model_id: Unique identifier for the model
            model_file: Filename of the model
            device: Device to run on

        Returns:
            YOLOWrapper instance
        """
        with self._lock:
            # Check if already loaded
            if model_id in self.models:
                logger.debug(f"Model {model_id} already loaded")
                return self.models[model_id]

            # Load the model
            model_path = self.model_dir / model_file
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")

            wrapper = YOLOWrapper(model_path, model_id, device)
            self.models[model_id] = wrapper

            return wrapper

    def get_model(self, model_id: str) -> Optional[YOLOWrapper]:
        """
        Get a loaded model

        Args:
            model_id: Model identifier

        Returns:
            YOLOWrapper instance or None if not loaded
        """
        return self.models.get(model_id)

    def unload_model(self, model_id: str) -> bool:
        """
        Unload a model to free memory

        Args:
            model_id: Model identifier

        Returns:
            True if unloaded, False if not found
        """
        with self._lock:
            if model_id in self.models:
                self.models[model_id].clear_cache()
                del self.models[model_id]
                logger.info(f"Unloaded model {model_id}")
                return True
            return False

    def unload_all(self) -> None:
        """Unload all models"""
        with self._lock:
            for model_id in list(self.models.keys()):
                self.models[model_id].clear_cache()
            self.models.clear()
            logger.info("Unloaded all models")