python-detector-worker/core/models/inference.py
2025-09-26 14:22:38 +07:00

447 lines
No EOL
14 KiB
Python

"""
YOLO Model Inference Wrapper - Handles model loading and inference optimization
"""
import logging
import torch
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple, Union
from threading import Lock
from dataclasses import dataclass
import cv2
logger = logging.getLogger(__name__)
@dataclass
class Detection:
"""Represents a single detection result"""
bbox: List[float] # [x1, y1, x2, y2]
confidence: float
class_id: int
class_name: str
track_id: Optional[int] = None
@dataclass
class InferenceResult:
"""Result from model inference"""
detections: List[Detection]
image_shape: Tuple[int, int] # (height, width)
inference_time: float
model_id: str
class YOLOWrapper:
"""Wrapper for YOLO models with caching and optimization"""
# Class-level model cache shared across all instances
_model_cache: Dict[str, Any] = {}
_cache_lock = Lock()
def __init__(self, model_path: Path, model_id: str, device: Optional[str] = None):
"""
Initialize YOLO wrapper
Args:
model_path: Path to the .pt model file
model_id: Unique identifier for the model
device: Device to run inference on ('cuda', 'cpu', or None for auto)
"""
self.model_path = model_path
self.model_id = model_id
# Auto-detect device if not specified
if device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self.device = device
self.model = None
self._class_names = []
self._load_model()
logger.info(f"Initialized YOLO wrapper for {model_id} on {self.device}")
def _load_model(self) -> None:
"""Load the YOLO model with caching"""
cache_key = str(self.model_path)
with self._cache_lock:
# Check if model is already cached
if cache_key in self._model_cache:
logger.info(f"Loading model {self.model_id} from cache")
self.model = self._model_cache[cache_key]
self._extract_class_names()
return
# Load model
try:
from ultralytics import YOLO
logger.info(f"Loading YOLO model from {self.model_path}")
self.model = YOLO(str(self.model_path))
# Move model to device
if self.device == 'cuda' and torch.cuda.is_available():
self.model.to('cuda')
logger.info(f"Model {self.model_id} moved to GPU")
# Cache the model
self._model_cache[cache_key] = self.model
self._extract_class_names()
logger.info(f"Successfully loaded model {self.model_id}")
except ImportError:
logger.error("Ultralytics YOLO not installed. Install with: pip install ultralytics")
raise
except Exception as e:
logger.error(f"Failed to load YOLO model {self.model_id}: {str(e)}", exc_info=True)
raise
def _extract_class_names(self) -> None:
"""Extract class names from the model"""
try:
if hasattr(self.model, 'names'):
self._class_names = self.model.names
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'names'):
self._class_names = self.model.model.names
else:
logger.warning(f"Could not extract class names from model {self.model_id}")
self._class_names = {}
except Exception as e:
logger.error(f"Failed to extract class names: {str(e)}")
self._class_names = {}
def infer(
self,
image: np.ndarray,
confidence_threshold: float = 0.5,
trigger_classes: Optional[List[str]] = None,
iou_threshold: float = 0.45
) -> InferenceResult:
"""
Run inference on an image
Args:
image: Input image as numpy array (BGR format)
confidence_threshold: Minimum confidence for detections
trigger_classes: List of class names to filter (None = all classes)
iou_threshold: IoU threshold for NMS
Returns:
InferenceResult containing detections
"""
if self.model is None:
raise RuntimeError(f"Model {self.model_id} not loaded")
try:
import time
start_time = time.time()
# Run inference
results = self.model(
image,
conf=confidence_threshold,
iou=iou_threshold,
verbose=False
)
inference_time = time.time() - start_time
# Parse results
detections = self._parse_results(results[0], trigger_classes)
return InferenceResult(
detections=detections,
image_shape=(image.shape[0], image.shape[1]),
inference_time=inference_time,
model_id=self.model_id
)
except Exception as e:
logger.error(f"Inference failed for model {self.model_id}: {str(e)}", exc_info=True)
raise
def _parse_results(
self,
result: Any,
trigger_classes: Optional[List[str]] = None
) -> List[Detection]:
"""
Parse YOLO results into Detection objects
Args:
result: YOLO result object
trigger_classes: Optional list of class names to filter
Returns:
List of Detection objects
"""
detections = []
try:
if result.boxes is None:
return detections
boxes = result.boxes
for i in range(len(boxes)):
# Get box coordinates
box = boxes.xyxy[i].cpu().numpy()
x1, y1, x2, y2 = box
# Get confidence and class
conf = float(boxes.conf[i])
cls_id = int(boxes.cls[i])
# Get class name
class_name = self._class_names.get(cls_id, f"class_{cls_id}")
# Filter by trigger classes if specified
if trigger_classes and class_name not in trigger_classes:
continue
# Get track ID if available
track_id = None
if hasattr(boxes, 'id') and boxes.id is not None:
track_id = int(boxes.id[i])
detection = Detection(
bbox=[float(x1), float(y1), float(x2), float(y2)],
confidence=conf,
class_id=cls_id,
class_name=class_name,
track_id=track_id
)
detections.append(detection)
except Exception as e:
logger.error(f"Failed to parse results: {str(e)}", exc_info=True)
return detections
def track(
self,
image: np.ndarray,
confidence_threshold: float = 0.5,
trigger_classes: Optional[List[str]] = None,
persist: bool = True,
camera_id: Optional[str] = None
) -> InferenceResult:
"""
Run detection (tracking will be handled by external tracker)
Args:
image: Input image as numpy array (BGR format)
confidence_threshold: Minimum confidence for detections
trigger_classes: List of class names to filter
persist: Ignored - tracking handled externally
camera_id: Ignored - tracking handled externally
Returns:
InferenceResult containing detections (no track IDs from YOLO)
"""
# Just do detection - no YOLO tracking
return self.infer(image, confidence_threshold, trigger_classes)
def predict_classification(
self,
image: np.ndarray,
top_k: int = 1
) -> Dict[str, float]:
"""
Run classification on an image
Args:
image: Input image as numpy array (BGR format)
top_k: Number of top predictions to return
Returns:
Dictionary of class_name -> confidence scores
"""
if self.model is None:
raise RuntimeError(f"Model {self.model_id} not loaded")
try:
# Run inference
results = self.model(image, verbose=False)
# For classification models, extract probabilities
if hasattr(results[0], 'probs'):
probs = results[0].probs
top_indices = probs.top5[:top_k]
top_conf = probs.top5conf[:top_k].cpu().numpy()
predictions = {}
for idx, conf in zip(top_indices, top_conf):
class_name = self._class_names.get(int(idx), f"class_{idx}")
predictions[class_name] = float(conf)
return predictions
else:
logger.warning(f"Model {self.model_id} does not support classification")
return {}
except Exception as e:
logger.error(f"Classification failed for model {self.model_id}: {str(e)}", exc_info=True)
raise
def crop_detection(
self,
image: np.ndarray,
detection: Detection,
padding: int = 0
) -> np.ndarray:
"""
Crop image to detection bounding box
Args:
image: Original image
detection: Detection to crop
padding: Additional padding around the box
Returns:
Cropped image region
"""
h, w = image.shape[:2]
x1, y1, x2, y2 = detection.bbox
# Add padding and clip to image boundaries
x1 = max(0, int(x1) - padding)
y1 = max(0, int(y1) - padding)
x2 = min(w, int(x2) + padding)
y2 = min(h, int(y2) + padding)
return image[y1:y2, x1:x2]
def get_class_names(self) -> Dict[int, str]:
"""Get the class names dictionary"""
return self._class_names.copy()
def get_num_classes(self) -> int:
"""Get the number of classes the model can detect"""
return len(self._class_names)
def clear_cache(self) -> None:
"""Clear the model cache"""
with self._cache_lock:
cache_key = str(self.model_path)
if cache_key in self._model_cache:
del self._model_cache[cache_key]
logger.info(f"Cleared cache for model {self.model_id}")
@classmethod
def clear_all_cache(cls) -> None:
"""Clear all cached models"""
with cls._cache_lock:
cls._model_cache.clear()
logger.info("Cleared all model cache")
def warmup(self, image_size: Tuple[int, int] = (640, 640)) -> None:
"""
Warmup the model with a dummy inference
Args:
image_size: Size of dummy image (height, width)
"""
try:
dummy_image = np.zeros((image_size[0], image_size[1], 3), dtype=np.uint8)
self.infer(dummy_image, confidence_threshold=0.5)
logger.info(f"Model {self.model_id} warmed up")
except Exception as e:
logger.warning(f"Failed to warmup model {self.model_id}: {str(e)}")
class ModelInferenceManager:
"""Manages multiple YOLO models for a pipeline"""
def __init__(self, model_dir: Path):
"""
Initialize the inference manager
Args:
model_dir: Directory containing model files
"""
self.model_dir = model_dir
self.models: Dict[str, YOLOWrapper] = {}
self._lock = Lock()
logger.info(f"Initialized ModelInferenceManager with model directory: {model_dir}")
def load_model(
self,
model_id: str,
model_file: str,
device: Optional[str] = None
) -> YOLOWrapper:
"""
Load a model for inference
Args:
model_id: Unique identifier for the model
model_file: Filename of the model
device: Device to run on
Returns:
YOLOWrapper instance
"""
with self._lock:
# Check if already loaded
if model_id in self.models:
logger.debug(f"Model {model_id} already loaded")
return self.models[model_id]
# Load the model
model_path = self.model_dir / model_file
if not model_path.exists():
raise FileNotFoundError(f"Model file not found: {model_path}")
wrapper = YOLOWrapper(model_path, model_id, device)
self.models[model_id] = wrapper
return wrapper
def get_model(self, model_id: str) -> Optional[YOLOWrapper]:
"""
Get a loaded model
Args:
model_id: Model identifier
Returns:
YOLOWrapper instance or None if not loaded
"""
return self.models.get(model_id)
def unload_model(self, model_id: str) -> bool:
"""
Unload a model to free memory
Args:
model_id: Model identifier
Returns:
True if unloaded, False if not found
"""
with self._lock:
if model_id in self.models:
self.models[model_id].clear_cache()
del self.models[model_id]
logger.info(f"Unloaded model {model_id}")
return True
return False
def unload_all(self) -> None:
"""Unload all models"""
with self._lock:
for model_id in list(self.models.keys()):
self.models[model_id].clear_cache()
self.models.clear()
logger.info("Unloaded all models")