nms optimization
This commit is contained in:
parent
81bbb0074e
commit
8e20496fa7
5 changed files with 907 additions and 26 deletions
|
|
@ -100,39 +100,38 @@ class YOLOv8Utils:
|
|||
output = outputs[output_name] # (1, 84, 8400)
|
||||
|
||||
# Transpose to (1, 8400, 84) for easier processing
|
||||
output = output.transpose(1, 2)
|
||||
output = output.transpose(1, 2).squeeze(0) # (8400, 84)
|
||||
|
||||
# Process first batch (batch size is always 1 for single image inference)
|
||||
detections = []
|
||||
for detection in output[0]: # Iterate over 8400 anchor points
|
||||
# Split bbox coordinates and class scores
|
||||
bbox = detection[:4] # (cx, cy, w, h)
|
||||
class_scores = detection[4:] # 80 class scores
|
||||
# Split bbox coordinates and class scores (vectorized)
|
||||
bboxes = output[:, :4] # (8400, 4) - (cx, cy, w, h)
|
||||
class_scores = output[:, 4:] # (8400, 80)
|
||||
|
||||
# Get max class score and corresponding class ID
|
||||
max_score, class_id = torch.max(class_scores, 0)
|
||||
# Get max class score and corresponding class ID for all anchors (vectorized)
|
||||
max_scores, class_ids = torch.max(class_scores, dim=1) # (8400,), (8400,)
|
||||
|
||||
# Filter by confidence threshold
|
||||
if max_score > conf_threshold:
|
||||
# Convert from (cx, cy, w, h) to (x1, y1, x2, y2)
|
||||
cx, cy, w, h = bbox
|
||||
x1 = cx - w / 2
|
||||
y1 = cy - h / 2
|
||||
x2 = cx + w / 2
|
||||
y2 = cy + h / 2
|
||||
|
||||
# Append detection: [x1, y1, x2, y2, conf, class_id]
|
||||
detections.append([
|
||||
x1.item(), y1.item(), x2.item(), y2.item(),
|
||||
max_score.item(), class_id.item()
|
||||
])
|
||||
# Filter by confidence threshold (vectorized)
|
||||
mask = max_scores > conf_threshold
|
||||
filtered_bboxes = bboxes[mask] # (N, 4)
|
||||
filtered_scores = max_scores[mask] # (N,)
|
||||
filtered_class_ids = class_ids[mask] # (N,)
|
||||
|
||||
# Return empty tensor if no detections
|
||||
if not detections:
|
||||
if filtered_bboxes.shape[0] == 0:
|
||||
return torch.zeros((0, 6), device=output.device)
|
||||
|
||||
# Convert list to tensor
|
||||
detections_tensor = torch.tensor(detections, device=output.device)
|
||||
# Convert from (cx, cy, w, h) to (x1, y1, x2, y2) (vectorized)
|
||||
cx, cy, w, h = filtered_bboxes[:, 0], filtered_bboxes[:, 1], filtered_bboxes[:, 2], filtered_bboxes[:, 3]
|
||||
x1 = cx - w / 2
|
||||
y1 = cy - h / 2
|
||||
x2 = cx + w / 2
|
||||
y2 = cy + h / 2
|
||||
|
||||
# Stack into detections tensor: [x1, y1, x2, y2, conf, class_id]
|
||||
detections_tensor = torch.stack([
|
||||
x1, y1, x2, y2,
|
||||
filtered_scores,
|
||||
filtered_class_ids.float()
|
||||
], dim=1) # (N, 6)
|
||||
|
||||
# Apply Non-Maximum Suppression (NMS)
|
||||
boxes = detections_tensor[:, :4] # (N, 4)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue