python-rtsp-worker/services/yolo.py
2025-11-09 11:47:18 +07:00

197 lines
7.4 KiB
Python

"""
YOLOv8 Model Utilities
This module provides static utility functions for YOLOv8 model preprocessing
and postprocessing, compatible with TensorRT inference.
Features:
- Preprocessing: Resize and normalize frames for YOLOv8 inference
- Postprocessing: Parse YOLOv8 output format to detection boxes
- Format conversion: (cx, cy, w, h) to (x1, y1, x2, y2)
- Confidence filtering and Non-Maximum Suppression (NMS), both applied in postprocess()
Usage:
from services.yolo import YOLOv8Utils
# Preprocess frame
model_input = YOLOv8Utils.preprocess(frame_gpu, input_size=640)
# Run inference
outputs = model_repo.infer(model_id="yolov8", inputs={"images": model_input})
# Postprocess detections
detections = YOLOv8Utils.postprocess(outputs, conf_threshold=0.25, nms_threshold=0.45)
"""
import torch
from typing import Tuple, Optional
class YOLOv8Utils:
    """Static utility class for YOLOv8 model operations.

    Every method is a ``@staticmethod`` operating on ``torch`` tensors; results
    stay on the device of the input tensor.
    """

    @staticmethod
    def preprocess(
        frame: torch.Tensor,
        input_size: int = 640
    ) -> torch.Tensor:
        """
        Preprocess frame(s) for YOLOv8 inference.

        The image is stretched directly to ``input_size x input_size`` (no
        letterboxing, aspect ratio is not preserved), so boxes can be mapped
        back with plain per-axis factors via ``scale_boxes``.

        Args:
            frame: RGB frame as a tensor of shape (3, H, W), or an already
                batched tensor of shape (B, 3, H, W). Expected to be uint8
                with values in [0, 255].
            input_size: Model input size (default: 640 for YOLOv8)

        Returns:
            float32 tensor of shape (B, 3, input_size, input_size), where B is 1
            for a single (3, H, W) frame. Values are in [0, 1] when the input
            is in [0, 255].

        Example:
            >>> frame_gpu = decoder.get_latest_frame(rgb=True)  # (3, 720, 1280)
            >>> model_input = YOLOv8Utils.preprocess(frame_gpu)  # (1, 3, 640, 640)
        """
        # Accept a single frame or a batch; only a single frame needs a batch dim.
        # Any other rank still fails inside interpolate, exactly as before.
        batch = frame if frame.dim() == 4 else frame.unsqueeze(0)

        # interpolate requires a floating-point tensor.
        batch = batch.float()

        resized = torch.nn.functional.interpolate(
            batch,
            size=(input_size, input_size),
            mode='bilinear',
            align_corners=False
        )

        # Scale pixel values from [0, 255] to [0, 1].
        return resized / 255.0

    @staticmethod
    def postprocess(
        outputs: dict,
        conf_threshold: float = 0.25,
        nms_threshold: float = 0.45,
        class_agnostic: bool = True
    ) -> torch.Tensor:
        """
        Postprocess YOLOv8 TensorRT output to detection format.

        YOLOv8 output format: (1, 4 + num_classes, num_anchors), e.g. (1, 84, 8400)
        - first 4 channels = bbox coords (cx, cy, w, h)
        - remaining channels = per-class scores (80 for COCO)
        - one column per anchor point (8400 for a 640x640 input)

        Args:
            outputs: Dictionary of model outputs from TensorRT inference. Only
                the first value is used; it may have shape (1, C, A) or the
                unbatched shape (C, A).
            conf_threshold: An anchor is kept only if its best class score is
                strictly greater than this value (default: 0.25)
            nms_threshold: IoU threshold for Non-Maximum Suppression (default: 0.45)
            class_agnostic: If True (default, original behavior), NMS treats all
                boxes as one class, so overlapping boxes of different classes
                can suppress each other. If False, NMS runs separately for each
                class id.

        Returns:
            Tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id]
            - Coordinates are in model input space (0-640 for default YOLOv8)
            - class_id is stored as a float; use int(...) for lookups
            - N is the number of detections after NMS

        Raises:
            IndexError: If ``outputs`` is empty.
            ValueError: If the output tensor is batched with batch size != 1,
                or is neither 2-D nor 3-D.

        Example:
            >>> outputs = model_repo.infer(model_id="yolov8", inputs={"images": frame})
            >>> detections = YOLOv8Utils.postprocess(outputs, conf_threshold=0.5)
            >>> # detections: [[x1, y1, x2, y2, conf, class_id], ...]
        """
        # Use the first output tensor; any further outputs are ignored.
        output = list(outputs.values())[0]

        # Normalize to an unbatched (C, A) tensor. Batches > 1 used to slip
        # through squeeze(0) and fail later with an unrelated indexing error.
        if output.dim() == 3:
            if output.shape[0] != 1:
                raise ValueError(
                    f"postprocess expects batch size 1, got output shape {tuple(output.shape)}"
                )
            output = output[0]
        elif output.dim() != 2:
            raise ValueError(
                f"expected output of shape (1, C, A) or (C, A), got {tuple(output.shape)}"
            )

        # (C, A) -> (A, C): one row per anchor.
        predictions = output.transpose(0, 1)
        boxes_cxcywh = predictions[:, :4]
        class_scores = predictions[:, 4:]

        # Best class and its score for every anchor.
        max_scores, class_ids = torch.max(class_scores, dim=1)

        # Confidence filtering.
        keep_mask = max_scores > conf_threshold
        boxes_cxcywh = boxes_cxcywh[keep_mask]
        scores = max_scores[keep_mask]
        class_ids = class_ids[keep_mask]

        if scores.numel() == 0:
            return torch.zeros((0, 6), device=output.device)

        # (cx, cy, w, h) -> (x1, y1, x2, y2)
        centers = boxes_cxcywh[:, :2]
        half_sizes = boxes_cxcywh[:, 2:] / 2
        boxes_xyxy = torch.cat([centers - half_sizes, centers + half_sizes], dim=1)

        # Rows of [x1, y1, x2, y2, conf, class_id].
        detections = torch.cat(
            [boxes_xyxy, scores.unsqueeze(1), class_ids.float().unsqueeze(1)],
            dim=1
        )

        # Imported here, after the early return, so torchvision is only
        # needed when there are boxes to suppress.
        from torchvision.ops import batched_nms, nms

        nms_boxes = detections[:, :4]
        nms_scores = detections[:, 4]
        if class_agnostic:
            keep_indices = nms(nms_boxes, nms_scores, iou_threshold=nms_threshold)
        else:
            # class_ids is the int64 tensor from torch.max, as batched_nms expects.
            keep_indices = batched_nms(
                nms_boxes, nms_scores, class_ids, iou_threshold=nms_threshold
            )

        return detections[keep_indices]

    @staticmethod
    def scale_boxes(
        boxes: torch.Tensor,
        from_size: Tuple[int, int],
        to_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Scale bounding boxes from one coordinate space to another.

        Uses independent x/y factors, which matches the stretch-resize done by
        ``preprocess``.

        Args:
            boxes: Floating-point tensor of shape (N, 4+) whose first four
                columns are [x1, y1, x2, y2]. Any extra columns (e.g. conf and
                class_id from ``postprocess``) are copied unchanged. The input
                tensor is not modified.
            from_size: Source size (width, height) - e.g., (640, 640) for model output
            to_size: Target size (width, height) - e.g., (1280, 720) for display

        Returns:
            Scaled copy of ``boxes``, same shape as the input.

        Example:
            >>> detections = YOLOv8Utils.postprocess(outputs)  # boxes in 640x640 space
            >>> boxes = detections[:, :4]  # Extract boxes
            >>> scaled_boxes = YOLOv8Utils.scale_boxes(boxes, (640, 640), (1280, 720))
        """
        scale_x = to_size[0] / from_size[0]
        scale_y = to_size[1] / from_size[1]

        # Clone so the caller's tensor is left untouched.
        scaled = boxes.clone()
        scaled[:, [0, 2]] *= scale_x  # x1, x2
        scaled[:, [1, 3]] *= scale_y  # y1, y2
        return scaled
# COCO class names for YOLOv8 (80 classes), keyed by integer class id 0-79.
# YOLOv8Utils.postprocess() stores class_id as a float in column 5, so look
# names up with COCO_CLASSES[int(det[5])].
COCO_CLASSES = dict(enumerate((
    # 0-9
    'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light',
    # 10-19
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
    'cat', 'dog', 'horse', 'sheep', 'cow',
    # 20-29
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
    'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    # 30-39
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    # 40-49
    'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange',
    # 50-59
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed',
    # 60-69
    'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 70-79
    'toaster', 'sink', 'refrigerator', 'book', 'clock',
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
)))