yolo util class
parent bea895d3d8
commit 81bbb0074e

5 changed files with 1058 additions and 1 deletions

services/__init__.py

@@ -7,6 +7,7 @@ from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg
from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionContext, SharedEngine
from .tracking_controller import TrackingController, TrackedObject
from .tracking_factory import TrackingFactory
from .yolo import YOLOv8Utils, COCO_CLASSES

__all__ = [
    'StreamDecoderFactory',

@@ -21,4 +22,6 @@ __all__ = [
    'TrackingController',
    'TrackedObject',
    'TrackingFactory',
    'YOLOv8Utils',
    'COCO_CLASSES',
]

198  services/yolo.py  Normal file

@@ -0,0 +1,198 @@
"""
|
||||
YOLOv8 Model Utilities
|
||||
|
||||
This module provides static utility functions for YOLOv8 model preprocessing
|
||||
and postprocessing, compatible with TensorRT inference.
|
||||
|
||||
Features:
|
||||
- Preprocessing: Resize and normalize frames for YOLOv8 inference
|
||||
- Postprocessing: Parse YOLOv8 output format to detection boxes
|
||||
- Format conversion: (cx, cy, w, h) to (x1, y1, x2, y2)
|
||||
- Confidence filtering and NMS handled separately
|
||||
|
||||
Usage:
|
||||
from services.yolo import YOLOv8Utils
|
||||
|
||||
# Preprocess frame
|
||||
model_input = YOLOv8Utils.preprocess(frame_gpu, input_size=640)
|
||||
|
||||
# Run inference
|
||||
outputs = model_repo.infer(model_id="yolov8", inputs={"images": model_input})
|
||||
|
||||
# Postprocess detections
|
||||
detections = YOLOv8Utils.postprocess(outputs, conf_threshold=0.25, nms_threshold=0.45)
|
||||
"""
|
||||
|
||||
import torch
from typing import Tuple, Optional


class YOLOv8Utils:
    """Static utility class for YOLOv8 model operations."""

    @staticmethod
    def preprocess(
        frame: torch.Tensor,
        input_size: int = 640
    ) -> torch.Tensor:
        """
        Preprocess frame for YOLOv8 inference.

        Args:
            frame: RGB frame as GPU tensor, shape (3, H, W) uint8
            input_size: Model input size (default: 640 for YOLOv8)

        Returns:
            Preprocessed frame ready for model, shape (1, 3, input_size, input_size) float32

        Example:
            >>> frame_gpu = decoder.get_latest_frame(rgb=True)  # (3, 720, 1280)
            >>> model_input = YOLOv8Utils.preprocess(frame_gpu)  # (1, 3, 640, 640)
        """
        # Add batch dimension and convert to float
        frame_batch = frame.unsqueeze(0).float()  # (1, 3, H, W)

        # Resize to model input size
        frame_resized = torch.nn.functional.interpolate(
            frame_batch,
            size=(input_size, input_size),
            mode='bilinear',
            align_corners=False
        )

        # Normalize to [0, 1] (YOLOv8 expects normalized input)
        frame_normalized = frame_resized / 255.0

        return frame_normalized

    @staticmethod
    def postprocess(
        outputs: dict,
        conf_threshold: float = 0.25,
        nms_threshold: float = 0.45
    ) -> torch.Tensor:
        """
        Postprocess YOLOv8 TensorRT output to detection format.

        YOLOv8 output format: (1, 84, 8400)
        - 84 channels = 4 bbox coords (cx, cy, w, h) + 80 class scores
        - 8400 anchor points

        Args:
            outputs: Dictionary of model outputs from TensorRT inference
            conf_threshold: Confidence threshold for filtering detections (default: 0.25)
            nms_threshold: IoU threshold for Non-Maximum Suppression (default: 0.45)

        Returns:
            Tensor of shape (N, 6): [x1, y1, x2, y2, conf, class_id]
            - Coordinates are in model input space (0-640 for default YOLOv8)
            - N is the number of detections after NMS

        Example:
            >>> outputs = model_repo.infer(model_id="yolov8", inputs={"images": frame})
            >>> detections = YOLOv8Utils.postprocess(outputs, conf_threshold=0.5)
            >>> # detections: [[x1, y1, x2, y2, conf, class_id], ...]
        """
        from torchvision.ops import nms

        # Get output tensor (first and only output)
        output_name = list(outputs.keys())[0]
        output = outputs[output_name]  # (1, 84, 8400)

        # Transpose to (1, 8400, 84) for easier processing
        output = output.transpose(1, 2)

        # Process first batch (batch size is always 1 for single image inference)
        detections = []
        for detection in output[0]:  # Iterate over 8400 anchor points
            # Split bbox coordinates and class scores
            bbox = detection[:4]  # (cx, cy, w, h)
            class_scores = detection[4:]  # 80 class scores

            # Get max class score and corresponding class ID
            max_score, class_id = torch.max(class_scores, 0)

            # Filter by confidence threshold
            if max_score > conf_threshold:
                # Convert from (cx, cy, w, h) to (x1, y1, x2, y2)
                cx, cy, w, h = bbox
                x1 = cx - w / 2
                y1 = cy - h / 2
                x2 = cx + w / 2
                y2 = cy + h / 2

                # Append detection: [x1, y1, x2, y2, conf, class_id]
                detections.append([
                    x1.item(), y1.item(), x2.item(), y2.item(),
                    max_score.item(), class_id.item()
                ])

        # Return empty tensor if no detections
        if not detections:
            return torch.zeros((0, 6), device=output.device)

        # Convert list to tensor
        detections_tensor = torch.tensor(detections, device=output.device)

        # Apply Non-Maximum Suppression (NMS)
        boxes = detections_tensor[:, :4]  # (N, 4)
        scores = detections_tensor[:, 4]  # (N,)

        # NMS returns indices of boxes to keep
        keep_indices = nms(boxes, scores, iou_threshold=nms_threshold)

        # Return filtered detections
        return detections_tensor[keep_indices]

    @staticmethod
    def scale_boxes(
        boxes: torch.Tensor,
        from_size: Tuple[int, int],
        to_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Scale bounding boxes from one coordinate space to another.

        Args:
            boxes: Tensor of boxes, shape (N, 4) in format [x1, y1, x2, y2]
            from_size: Source size (width, height) - e.g., (640, 640) for model output
            to_size: Target size (width, height) - e.g., (1280, 720) for display

        Returns:
            Scaled boxes tensor, same shape as input

        Example:
            >>> detections = YOLOv8Utils.postprocess(outputs)  # boxes in 640x640 space
            >>> boxes = detections[:, :4]  # Extract boxes
            >>> scaled_boxes = YOLOv8Utils.scale_boxes(boxes, (640, 640), (1280, 720))
        """
        scale_x = to_size[0] / from_size[0]
        scale_y = to_size[1] / from_size[1]

        # Clone to avoid modifying original
        scaled = boxes.clone()
        scaled[:, [0, 2]] *= scale_x  # Scale x coordinates
        scaled[:, [1, 3]] *= scale_y  # Scale y coordinates

        return scaled


# COCO class names for YOLOv8 (80 classes)
COCO_CLASSES = {
    0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
    5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
    10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
    14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
    20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
    25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
    30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat',
    35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket',
    39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon',
    45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange',
    50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut',
    55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed',
    60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse',
    65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven',
    70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock',
    75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'
}
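
Taken together, the two changes above expose the new utilities from the services package and add the standalone module. A minimal end-to-end sketch of how the pieces compose is shown below; the decoder and model_repo objects and the "yolov8" model_id come from the usage examples in the docstrings and are assumptions about the surrounding pipeline, not part of this commit.

# Hedged usage sketch: `decoder` and `model_repo` are assumed to exist elsewhere in the
# pipeline (see the docstring examples); only YOLOv8Utils and COCO_CLASSES are new here.
from services import YOLOv8Utils, COCO_CLASSES

frame_gpu = decoder.get_latest_frame(rgb=True)        # (3, 720, 1280) uint8 on GPU
model_input = YOLOv8Utils.preprocess(frame_gpu)       # (1, 3, 640, 640) float32 in [0, 1]

outputs = model_repo.infer(model_id="yolov8", inputs={"images": model_input})
detections = YOLOv8Utils.postprocess(outputs, conf_threshold=0.25, nms_threshold=0.45)

# Map boxes from 640x640 model space back to the original 1280x720 frame
scaled = YOLOv8Utils.scale_boxes(detections[:, :4], (640, 640), (1280, 720))
for box, det in zip(scaled, detections):
    print(f"{COCO_CLASSES[int(det[5])]}: conf={det[4]:.2f}, box={box.tolist()}")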
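
As a side note on the (1, 84, 8400) output layout documented in postprocess: the per-anchor Python loop above is easy to follow, but iterating 8400 anchors per frame can become a bottleneck. A vectorized sketch of the same decoding (best class score per anchor, confidence mask, cxcywh-to-xyxy conversion, then torchvision NMS) is shown below as an optional alternative; it is not part of this commit, and the name postprocess_vectorized is hypothetical.

import torch
from torchvision.ops import nms


def postprocess_vectorized(outputs: dict,
                           conf_threshold: float = 0.25,
                           nms_threshold: float = 0.45) -> torch.Tensor:
    # Same (1, 84, 8400) layout as YOLOv8Utils.postprocess, decoded without a Python loop.
    output = next(iter(outputs.values()))            # (1, 84, 8400)
    preds = output[0].transpose(0, 1)                # (8400, 84): 4 bbox coords + 80 class scores
    boxes_cxcywh = preds[:, :4]
    scores, class_ids = preds[:, 4:].max(dim=1)      # best class score and ID per anchor

    # Confidence filtering
    keep = scores > conf_threshold
    boxes_cxcywh, scores, class_ids = boxes_cxcywh[keep], scores[keep], class_ids[keep]
    if scores.numel() == 0:
        return torch.zeros((0, 6), device=output.device)

    # (cx, cy, w, h) -> (x1, y1, x2, y2)
    half_wh = boxes_cxcywh[:, 2:] / 2
    boxes_xyxy = torch.cat([boxes_cxcywh[:, :2] - half_wh,
                            boxes_cxcywh[:, :2] + half_wh], dim=1)

    # Non-Maximum Suppression, then assemble (N, 6) = [x1, y1, x2, y2, conf, class_id]
    keep_idx = nms(boxes_xyxy, scores, iou_threshold=nms_threshold)
    return torch.cat([boxes_xyxy[keep_idx],
                      scores[keep_idx, None],
                      class_ids[keep_idx, None].float()], dim=1)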