nms optimization
parent
81bbb0074e
commit
8e20496fa7
5 changed files with 907 additions and 26 deletions
test_batch_inference.py (new file, 310 lines)
@@ -0,0 +1,310 @@
"""
|
||||
Batch Inference Test - Process Multiple Cameras in Single Batch
|
||||
|
||||
This script demonstrates batch inference to eliminate sequential processing bottleneck.
|
||||
Instead of processing 4 cameras one-by-one, we process all 4 in a single batched inference.
|
||||
|
||||
Requirements:
|
||||
- TensorRT model with dynamic batching support
|
||||
- Rebuild model: python scripts/convert_pt_to_tensorrt.py --model yolov8n.pt
|
||||
--output models/yolov8n_batch4.trt --dynamic-batch --max-batch 4 --fp16
|
||||
|
||||
Performance Comparison:
|
||||
- Sequential: Process each camera separately (current bottleneck)
|
||||
- Batched: Stack all frames → single inference → split results
|
||||
"""

import os
import time

import torch
from dotenv import load_dotenv

from services import (
    StreamDecoderFactory,
    TensorRTModelRepository,
    YOLOv8Utils,
    COCO_CLASSES,
)

load_dotenv()
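
# The optimization in one sketch (names refer to the helpers defined below;
# shapes follow this file's docstrings):
#
#   frames = [d.get_latest_frame(rgb=True) for d in decoders]  # N x (3, H, W) uint8 on GPU
#   batch = preprocess_batch(frames)                           # (N, 3, 640, 640) float32
#   outputs = model_repo.infer(model_id="detector", inputs={"images": batch})
#   detections = postprocess_batch(outputs)                    # N tensors, each (M, 6)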


def preprocess_batch(frames: list[torch.Tensor], input_size: int = 640) -> torch.Tensor:
    """
    Preprocess multiple frames for batched inference.

    Args:
        frames: List of GPU tensors, each (3, H, W) uint8
        input_size: Model input size (default: 640)

    Returns:
        Batched tensor (B, 3, 640, 640) float32
    """
    # Preprocess each frame individually
    preprocessed = [YOLOv8Utils.preprocess(frame, input_size) for frame in frames]

    # Stack into batch: (B, 3, 640, 640)
    return torch.cat(preprocessed, dim=0)
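
# Note on the torch.cat above: concatenating B tensors of shape (1, 3, 640, 640)
# along dim=0 yields (B, 3, 640, 640), which implies YOLOv8Utils.preprocess
# returns tensors with a leading batch dimension of 1. If it returned
# (3, 640, 640) tensors instead, torch.stack(preprocessed, dim=0) would be the
# call to use.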


def postprocess_batch(outputs: dict, conf_threshold: float = 0.25,
                      nms_threshold: float = 0.45) -> list[torch.Tensor]:
    """
    Postprocess batched YOLOv8 output to per-image detections.

    YOLOv8 batched output: (B, 84, 8400)

    Args:
        outputs: Dictionary of model outputs from TensorRT inference
        conf_threshold: Confidence threshold
        nms_threshold: IoU threshold for NMS

    Returns:
        List of detection tensors, each (N, 6): [x1, y1, x2, y2, conf, class_id]
    """
    # Get output tensor (NMS itself happens inside YOLOv8Utils.postprocess)
    output_name = list(outputs.keys())[0]
    output = outputs[output_name]  # (B, 84, 8400)

    batch_size = output.shape[0]
    results = []

    for b in range(batch_size):
        # Extract single image from batch
        single_output = output[b:b+1]  # (1, 84, 8400)

        # Reuse existing single-image postprocessing logic
        detections = YOLOv8Utils.postprocess(
            {output_name: single_output},
            conf_threshold=conf_threshold,
            nms_threshold=nms_threshold,
        )

        results.append(detections)

    return results
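
# The loop above reuses the single-image postprocess, which is simple but runs
# NMS once per image in Python. A vectorized alternative (a sketch, not this
# project's API) could score the whole batch at once and defer to
# torchvision.ops.batched_nms, which performs per-image NMS in a single call:
#
#   from torchvision.ops import batched_nms
#   keep = batched_nms(boxes_xyxy, scores, image_indices, nms_threshold)
#
# Whether that wins depends on batch size and per-image detection counts.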


def benchmark_sequential_vs_batch(duration: int = 30):
    """
    Benchmark sequential vs batched inference.

    Args:
        duration: Test duration in seconds
    """
    print("=" * 80)
    print("BATCH INFERENCE BENCHMARK")
    print("=" * 80)

    GPU_ID = 0
    MODEL_PATH_BATCH = "models/yolov8n_batch4.trt"  # Dynamic batch model
    MODEL_PATH_SINGLE = "models/yolov8n.trt"        # Original single-batch model

    # Check if the batch model exists
    if not os.path.exists(MODEL_PATH_BATCH):
        print(f"\n⚠ Batch model not found: {MODEL_PATH_BATCH}")
        print("\nTo create it, run:")
        print("  python scripts/convert_pt_to_tensorrt.py \\")
        print("    --model yolov8n.pt \\")
        print("    --output models/yolov8n_batch4.trt \\")
        print("    --dynamic-batch --max-batch 4 --fp16")
        print("\nFalling back to the sequential-only benchmark...")
        use_true_batching = False
        MODEL_PATH = MODEL_PATH_SINGLE
    else:
        use_true_batching = True
        MODEL_PATH = MODEL_PATH_BATCH
        print(f"\n✓ Using batch model: {MODEL_PATH_BATCH}")

    # Load camera URLs from the environment (CAMERA_URL_1 .. CAMERA_URL_4)
    camera_urls = []
    for i in range(1, 5):
        url = os.getenv(f'CAMERA_URL_{i}')
        if url:
            camera_urls.append(url)

    if len(camera_urls) < 2:
        print(f"⚠ Need at least 2 cameras, found {len(camera_urls)}")
        return

    print(f"\nTesting with {len(camera_urls)} cameras")

    # Initialize model repository and stream decoders
    print("\nInitializing...")
    model_repo = TensorRTModelRepository(gpu_id=GPU_ID, default_num_contexts=4)
    model_repo.load_model("detector", MODEL_PATH, num_contexts=4)

    stream_factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoders = []

    for i, url in enumerate(camera_urls):
        decoder = stream_factory.create_decoder(url, buffer_size=30)
        decoder.start()
        decoders.append(decoder)
        print(f"  Camera {i+1}: {url}")

    print("\nWaiting for streams to connect...")
    time.sleep(10)

    # ==================== SEQUENTIAL BENCHMARK ====================
    print("\n" + "=" * 80)
    print("1. SEQUENTIAL INFERENCE (Current Method)")
    print("=" * 80)

    frame_count_seq = 0
    start_time = time.time()

    print(f"\nRunning for {duration} seconds...")

    try:
        while time.time() - start_time < duration:
            for decoder in decoders:
                frame_gpu = decoder.get_latest_frame(rgb=True)
                if frame_gpu is None:
                    continue

                # Preprocess
                preprocessed = YOLOv8Utils.preprocess(frame_gpu)

                # Inference (single frame)
                outputs = model_repo.infer(
                    model_id="detector",
                    inputs={"images": preprocessed},
                    synchronize=True
                )

                # Postprocess
                detections = YOLOv8Utils.postprocess(outputs)

                frame_count_seq += 1

    except KeyboardInterrupt:
        pass

    seq_time = time.time() - start_time
    seq_fps = frame_count_seq / seq_time

    print(f"\nSequential Results:")
    print(f"  Total frames: {frame_count_seq}")
    print(f"  Total time: {seq_time:.2f}s")
    print(f"  Combined FPS: {seq_fps:.2f}")
    print(f"  Per-camera FPS: {seq_fps / len(camera_urls):.2f}")
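
    # With synchronize=True, every camera pays a separate inference call plus a
    # host-device synchronization per frame, so per-call overhead grows linearly
    # with the number of cameras. This is the bottleneck batching targets.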

    # ==================== BATCHED BENCHMARK ====================
    print("\n" + "=" * 80)
    print("2. BATCHED INFERENCE (Optimized Method)")
    print("=" * 80)

    if not use_true_batching:
        print("\n⚠ Skipping true batch inference (model not available)")
        print("  Results would be identical without dynamic batch model")
    else:
        frame_count_batch = 0
        start_time = time.time()

        print(f"\nRunning for {duration} seconds...")

        try:
            while time.time() - start_time < duration:
                # Collect frames from all cameras
                frames = []
                for decoder in decoders:
                    frame_gpu = decoder.get_latest_frame(rgb=True)
                    if frame_gpu is not None:
                        frames.append(frame_gpu)

                if len(frames) == 0:
                    continue

                # Batch preprocess
                batch_input = preprocess_batch(frames)

                # Single batched inference
                outputs = model_repo.infer(
                    model_id="detector",
                    inputs={"images": batch_input},
                    synchronize=True
                )

                # Batch postprocess
                batch_detections = postprocess_batch(outputs)

                frame_count_batch += len(frames)

        except KeyboardInterrupt:
            pass

        batch_time = time.time() - start_time
        batch_fps = frame_count_batch / batch_time

        print(f"\nBatched Results:")
        print(f"  Total frames: {frame_count_batch}")
        print(f"  Total time: {batch_time:.2f}s")
        print(f"  Combined FPS: {batch_fps:.2f}")
        print(f"  Per-camera FPS: {batch_fps / len(camera_urls):.2f}")
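
    # Note: the batch contains whichever cameras had a fresh frame, so its size
    # varies between 1 and len(decoders) per iteration. This is why the engine
    # must be built with --dynamic-batch; a fixed-batch engine would need the
    # input padded to a constant batch size (with padded outputs discarded).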

    # ==================== COMPARISON ====================
    print("\n" + "=" * 80)
    print("COMPARISON")
    print("=" * 80)

    # Only compare if the batched benchmark actually ran; otherwise batch_fps
    # is undefined and the original code raised a NameError here.
    if use_true_batching:
        improvement = ((batch_fps - seq_fps) / seq_fps) * 100

        print(f"\nSequential: {seq_fps:.2f} FPS combined ({seq_fps / len(camera_urls):.2f} per camera)")
        print(f"Batched:    {batch_fps:.2f} FPS combined ({batch_fps / len(camera_urls):.2f} per camera)")
        print(f"\nImprovement: {improvement:+.1f}%")

        if improvement > 10:
            print("✓ Significant improvement with batch inference!")
        elif improvement > 0:
            print("✓ Moderate improvement with batch inference")
        else:
            print("⚠ No improvement - check batch model configuration")
    else:
        print("\n⚠ Batched run was skipped, so there is nothing to compare")

    # Cleanup
    print("\n" + "=" * 80)
    print("Cleanup")
    print("=" * 80)

    for i, decoder in enumerate(decoders):
        decoder.stop()
        print(f"  Stopped camera {i+1}")

    print("\n✓ Benchmark complete!")


def test_batch_preprocessing():
    """Test that batch preprocessing works correctly."""
    print("\n" + "=" * 80)
    print("BATCH PREPROCESSING TEST")
    print("=" * 80)

    # Create dummy frames
    device = torch.device('cuda:0')
    frames = [
        torch.randint(0, 256, (3, 720, 1280), dtype=torch.uint8, device=device)
        for _ in range(4)
    ]

    print(f"\nInput: {len(frames)} frames, each {frames[0].shape}")

    # Test batch preprocessing
    batch = preprocess_batch(frames)
    print(f"Output: {batch.shape} (expected: [4, 3, 640, 640])")
    print(f"dtype: {batch.dtype} (expected: torch.float32)")
    print(f"range: [{batch.min():.3f}, {batch.max():.3f}] (expected: [0.0, 1.0])")

    assert batch.shape == (4, 3, 640, 640), "Batch shape mismatch"
    assert batch.dtype == torch.float32, "Dtype mismatch"
    assert 0.0 <= batch.min() and batch.max() <= 1.0, "Value range incorrect"

    print("\n✓ Batch preprocessing test passed!")
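
# A matching smoke test for postprocess_batch could feed random activations
# through it (hypothetical sketch; "output0" is a placeholder output name, and
# random values will rarely clear the confidence threshold):
#
#   device = torch.device('cuda:0')
#   dummy = {"output0": torch.rand(4, 84, 8400, device=device)}
#   dets = postprocess_batch(dummy, conf_threshold=0.5)
#   assert len(dets) == 4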


if __name__ == "__main__":
    # Test batch preprocessing
    test_batch_preprocessing()

    # Run benchmark
    benchmark_sequential_vs_batch(duration=30)