NMS optimization

This commit is contained in:
Siwat Sirichai 2025-11-09 11:47:18 +07:00
parent 81bbb0074e
commit 8e20496fa7
5 changed files with 907 additions and 26 deletions

View file

@ -100,39 +100,38 @@ class YOLOv8Utils:
output = outputs[output_name] # (1, 84, 8400)
# Transpose to (1, 8400, 84) for easier processing
output = output.transpose(1, 2)
output = output.transpose(1, 2).squeeze(0) # (8400, 84)
# Process first batch (batch size is always 1 for single image inference)
detections = []
for detection in output[0]: # Iterate over 8400 anchor points
# Split bbox coordinates and class scores
bbox = detection[:4] # (cx, cy, w, h)
class_scores = detection[4:] # 80 class scores
# Split bbox coordinates and class scores (vectorized)
bboxes = output[:, :4] # (8400, 4) - (cx, cy, w, h)
class_scores = output[:, 4:] # (8400, 80)
# Get max class score and corresponding class ID
max_score, class_id = torch.max(class_scores, 0)
# Get max class score and corresponding class ID for all anchors (vectorized)
max_scores, class_ids = torch.max(class_scores, dim=1) # (8400,), (8400,)
# Filter by confidence threshold
if max_score > conf_threshold:
# Convert from (cx, cy, w, h) to (x1, y1, x2, y2)
cx, cy, w, h = bbox
x1 = cx - w / 2
y1 = cy - h / 2
x2 = cx + w / 2
y2 = cy + h / 2
# Append detection: [x1, y1, x2, y2, conf, class_id]
detections.append([
x1.item(), y1.item(), x2.item(), y2.item(),
max_score.item(), class_id.item()
])
# Filter by confidence threshold (vectorized)
mask = max_scores > conf_threshold
filtered_bboxes = bboxes[mask] # (N, 4)
filtered_scores = max_scores[mask] # (N,)
filtered_class_ids = class_ids[mask] # (N,)
# Return empty tensor if no detections
if not detections:
if filtered_bboxes.shape[0] == 0:
return torch.zeros((0, 6), device=output.device)
# Convert list to tensor
detections_tensor = torch.tensor(detections, device=output.device)
# Convert from (cx, cy, w, h) to (x1, y1, x2, y2) (vectorized)
cx, cy, w, h = filtered_bboxes[:, 0], filtered_bboxes[:, 1], filtered_bboxes[:, 2], filtered_bboxes[:, 3]
x1 = cx - w / 2
y1 = cy - h / 2
x2 = cx + w / 2
y2 = cy + h / 2
# Stack into detections tensor: [x1, y1, x2, y2, conf, class_id]
detections_tensor = torch.stack([
x1, y1, x2, y2,
filtered_scores,
filtered_class_ids.float()
], dim=1) # (N, 6)
# Apply Non-Maximum Suppression (NMS)
boxes = detections_tensor[:, :4] # (N, 4)