"""
|
|
YOLOv8 Model Utilities
|
|
|
|
This module provides static utility functions for YOLOv8 model preprocessing
|
|
and postprocessing, compatible with TensorRT inference.
|
|
|
|
Features:
|
|
- Preprocessing: Resize and normalize frames for YOLOv8 inference
|
|
- Postprocessing: Parse YOLOv8 output format to detection boxes
|
|
- Format conversion: (cx, cy, w, h) to (x1, y1, x2, y2)
|
|
- Confidence filtering and NMS handled separately
|
|
|
|
Usage:
|
|
from services.yolo import YOLOv8Utils
|
|
|
|
# Preprocess frame
|
|
model_input = YOLOv8Utils.preprocess(frame_gpu, input_size=640)
|
|
|
|
# Run inference
|
|
outputs = model_repo.infer(model_id="yolov8", inputs={"images": model_input})
|
|
|
|
# Postprocess detections
|
|
detections = YOLOv8Utils.postprocess(outputs, conf_threshold=0.25, nms_threshold=0.45)
|
|
"""
|
|
|
|
import torch
from typing import Tuple


class YOLOv8Utils:
    """Static utility class for YOLOv8 model operations."""

    @staticmethod
    def preprocess(
        frame: torch.Tensor,
        input_size: int = 640
    ) -> torch.Tensor:
        """
        Preprocess frame for YOLOv8 inference.

        Note: the frame is stretched to a square, so non-square inputs are
        distorted; scale_boxes() applies the matching inverse stretch. A
        letterboxing alternative is sketched after this method.

        Args:
            frame: RGB frame as GPU tensor, shape (3, H, W) uint8
            input_size: Model input size (default: 640 for YOLOv8)

        Returns:
            Preprocessed frame ready for the model, shape (1, 3, input_size, input_size) float32

        Example:
            >>> frame_gpu = decoder.get_latest_frame(rgb=True)  # (3, 720, 1280)
            >>> model_input = YOLOv8Utils.preprocess(frame_gpu)  # (1, 3, 640, 640)
        """
        # Add batch dimension and convert to float
        frame_batch = frame.unsqueeze(0).float()  # (1, 3, H, W)

        # Resize to model input size
        frame_resized = torch.nn.functional.interpolate(
            frame_batch,
            size=(input_size, input_size),
            mode='bilinear',
            align_corners=False
        )

        # Normalize to [0, 1] (YOLOv8 expects normalized input)
        frame_normalized = frame_resized / 255.0

        return frame_normalized
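
    # --- Illustrative alternative, not used by the pipeline above ---
    # preprocess() stretches the frame to a square, which distorts non-square
    # inputs. Standard YOLOv8 tooling letterboxes instead: scale to fit, then
    # pad to a square with gray (value 114). A minimal sketch of that variant;
    # preprocess_letterbox is a hypothetical helper, not part of the original API.
    @staticmethod
    def preprocess_letterbox(
        frame: torch.Tensor,
        input_size: int = 640,
        pad_value: float = 114.0
    ) -> torch.Tensor:
        """Hypothetical aspect-ratio-preserving variant of preprocess()."""
        _, h, w = frame.shape
        scale = input_size / max(h, w)
        new_h, new_w = round(h * scale), round(w * scale)

        # Scale the longer side to input_size, keeping aspect ratio
        resized = torch.nn.functional.interpolate(
            frame.unsqueeze(0).float(),
            size=(new_h, new_w),
            mode='bilinear',
            align_corners=False
        )

        # Pad right/bottom up to a square canvas; boxes would then need the
        # inverse (subtract padding, divide by scale) instead of scale_boxes()
        pad = (0, input_size - new_w, 0, input_size - new_h)  # (left, right, top, bottom)
        padded = torch.nn.functional.pad(resized, pad, value=pad_value)

        return padded / 255.0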

    @staticmethod
    def postprocess(
        outputs: dict,
        conf_threshold: float = 0.25,
        nms_threshold: float = 0.45
    ) -> torch.Tensor:
        """
        Postprocess YOLOv8 TensorRT output to detection format.

        YOLOv8 output format: (1, 84, 8400)
        - 84 channels = 4 bbox coords (cx, cy, w, h) + 80 class scores
        - 8400 anchor points

        Args:
            outputs: Dictionary of model outputs from TensorRT inference
            conf_threshold: Confidence threshold for filtering detections (default: 0.25)
            nms_threshold: IoU threshold for Non-Maximum Suppression (default: 0.45)

        Returns:
            Tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id]
            - Coordinates are in model input space (0-640 for default YOLOv8)
            - N is the number of detections after NMS

        Example:
            >>> outputs = model_repo.infer(model_id="yolov8", inputs={"images": frame})
            >>> detections = YOLOv8Utils.postprocess(outputs, conf_threshold=0.5)
            >>> # detections: [[x1, y1, x2, y2, conf, class_id], ...]
        """
        from torchvision.ops import nms

        # Get output tensor (first and only output)
        output_name = next(iter(outputs))
        output = outputs[output_name]  # (1, 84, 8400)

        # Transpose and squeeze to (8400, 84) for easier processing
        output = output.transpose(1, 2).squeeze(0)  # (8400, 84)

        # Split bbox coordinates and class scores (vectorized)
        bboxes = output[:, :4]  # (8400, 4) - (cx, cy, w, h)
        class_scores = output[:, 4:]  # (8400, 80)

        # Get max class score and corresponding class ID for all anchors (vectorized)
        max_scores, class_ids = torch.max(class_scores, dim=1)  # (8400,), (8400,)

        # Filter by confidence threshold (vectorized)
        mask = max_scores > conf_threshold
        filtered_bboxes = bboxes[mask]  # (N, 4)
        filtered_scores = max_scores[mask]  # (N,)
        filtered_class_ids = class_ids[mask]  # (N,)

        # Return empty tensor if no detections
        if filtered_bboxes.shape[0] == 0:
            return torch.zeros((0, 6), device=output.device)

        # Convert from (cx, cy, w, h) to (x1, y1, x2, y2) (vectorized)
        cx, cy, w, h = filtered_bboxes[:, 0], filtered_bboxes[:, 1], filtered_bboxes[:, 2], filtered_bboxes[:, 3]
        x1 = cx - w / 2
        y1 = cy - h / 2
        x2 = cx + w / 2
        y2 = cy + h / 2

        # Stack into detections tensor: [x1, y1, x2, y2, conf, class_id]
        detections_tensor = torch.stack([
            x1, y1, x2, y2,
            filtered_scores,
            filtered_class_ids.float()
        ], dim=1)  # (N, 6)

        # Apply Non-Maximum Suppression (NMS)
        boxes = detections_tensor[:, :4]  # (N, 4)
        scores = detections_tensor[:, 4]  # (N,)

        # NMS returns indices of boxes to keep
        keep_indices = nms(boxes, scores, iou_threshold=nms_threshold)

        # Return filtered detections
        return detections_tensor[keep_indices]
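
    # --- Illustrative note, not used by postprocess() above ---
    # The nms() call in postprocess() is class-agnostic: a confident box of one
    # class can suppress an overlapping box of another class. Ultralytics-style
    # postprocessing keeps classes separate; a minimal sketch of that variant
    # using torchvision.ops.batched_nms (nms_per_class is a hypothetical helper):
    @staticmethod
    def nms_per_class(
        detections: torch.Tensor,
        nms_threshold: float = 0.45
    ) -> torch.Tensor:
        """Hypothetical per-class NMS over (N, 6) [x1, y1, x2, y2, conf, class_id] rows."""
        from torchvision.ops import batched_nms

        keep_indices = batched_nms(
            detections[:, :4],        # boxes
            detections[:, 4],         # scores
            detections[:, 5].long(),  # class ids act as independent NMS groups
            iou_threshold=nms_threshold
        )
        return detections[keep_indices]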

    @staticmethod
    def scale_boxes(
        boxes: torch.Tensor,
        from_size: Tuple[int, int],
        to_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Scale bounding boxes from one coordinate space to another.

        Args:
            boxes: Tensor of boxes, shape (N, 4) in format [x1, y1, x2, y2]
            from_size: Source size (width, height) - e.g., (640, 640) for model output
            to_size: Target size (width, height) - e.g., (1280, 720) for display

        Returns:
            Scaled boxes tensor, same shape as input

        Example:
            >>> detections = YOLOv8Utils.postprocess(outputs)  # boxes in 640x640 space
            >>> boxes = detections[:, :4]  # Extract boxes
            >>> scaled_boxes = YOLOv8Utils.scale_boxes(boxes, (640, 640), (1280, 720))
        """
        scale_x = to_size[0] / from_size[0]
        scale_y = to_size[1] / from_size[1]

        # Clone to avoid modifying original
        scaled = boxes.clone()
        scaled[:, [0, 2]] *= scale_x  # Scale x coordinates
        scaled[:, [1, 3]] *= scale_y  # Scale y coordinates

        return scaled
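
    # --- Usage note ---
    # Because preprocess() stretches the frame, this plain per-axis scaling is
    # its exact inverse. Scaled corners can still fall marginally outside the
    # frame, since model-space boxes are unclamped; a minimal clamp sketch
    # (clamp_boxes is a hypothetical helper, not part of the original API):
    @staticmethod
    def clamp_boxes(
        boxes: torch.Tensor,
        size: Tuple[int, int]
    ) -> torch.Tensor:
        """Hypothetical helper: clamp [x1, y1, x2, y2] boxes to a (width, height) frame."""
        clamped = boxes.clone()
        clamped[:, [0, 2]] = clamped[:, [0, 2]].clamp(0, size[0])  # x within [0, width]
        clamped[:, [1, 3]] = clamped[:, [1, 3]].clamp(0, size[1])  # y within [0, height]
        return clamped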


# COCO class names for YOLOv8 (80 classes)
COCO_CLASSES = {
    0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
    5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
    10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
    14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
    20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
    25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
    30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
    35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket',
    39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon',
    45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange',
    50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut',
    55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed',
    60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse',
    65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven',
    70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock',
    75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'
}
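

if __name__ == "__main__":
    # Minimal smoke test: exercise preprocess -> postprocess -> scale_boxes on
    # synthetic tensors. Shapes follow the docstrings above; the "output0" key
    # name is arbitrary here, since postprocess() reads the first entry. No
    # TensorRT engine or model_repo is involved, so this only checks plumbing.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    frame = torch.randint(0, 256, (3, 720, 1280), dtype=torch.uint8, device=device)
    model_input = YOLOv8Utils.preprocess(frame)
    assert model_input.shape == (1, 3, 640, 640)

    # Stand-in for a real YOLOv8 head output: (1, 84, 8400) random values
    fake_outputs = {"output0": torch.rand(1, 84, 8400, device=device)}
    detections = YOLOv8Utils.postprocess(fake_outputs, conf_threshold=0.9)
    print(f"{detections.shape[0]} detections after NMS")

    if detections.shape[0] > 0:
        display_boxes = YOLOv8Utils.scale_boxes(detections[:, :4], (640, 640), (1280, 720))
        label = COCO_CLASSES[int(detections[0, 5].item())]
        print(f"first detection: {label} at {display_boxes[0].tolist()}")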