feat: inference subsystem and optimization to decoder

commit 3c83a57e44
19 changed files with 3897 additions and 0 deletions

scripts/convert_pt_to_tensorrt.py (new executable file, +562 lines)
@@ -0,0 +1,562 @@
#!/usr/bin/env python3
"""
PyTorch to TensorRT Model Conversion Script

This script converts PyTorch models (.pt, .pth) to TensorRT engines (.trt) for optimized inference.

Features:
- Automatic FP32/FP16/INT8 precision modes
- Dynamic batch size support
- Input shape validation
- Optimization profiles for dynamic shapes
- ONNX intermediate format
- GPU-accelerated conversion

Usage:
    python convert_pt_to_tensorrt.py --model path/to/model.pt --output models/model.trt
    python convert_pt_to_tensorrt.py --model yolov8n.pt --input-shape 1,3,640,640 --fp16
    python convert_pt_to_tensorrt.py --model model.pt --dynamic-batch --max-batch 16
"""

import argparse
import sys
from pathlib import Path
from typing import Tuple, List, Optional

import torch
import tensorrt as trt
import numpy as np

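# Illustrative only: a minimal sketch of running a built engine, assuming the
# TensorRT 8.x Python API (execute_v2 / binding indices) and torch-managed GPU
# buffers. It is not called anywhere in this script; the function name, shapes,
# and the single-input/single-output assumption are placeholders.
def run_trt_engine_sketch(engine_path: str, input_shape: Tuple[int, ...] = (1, 3, 640, 640)) -> torch.Tensor:
    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    with open(engine_path, 'rb') as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()

    # Allocate input/output buffers with torch and pass raw device pointers.
    inp = torch.randn(*input_shape, device='cuda')
    out = torch.empty(tuple(context.get_binding_shape(1)), device='cuda')  # assumes FP32 output at binding 1
    context.execute_v2([int(inp.data_ptr()), int(out.data_ptr())])
    return out
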
class TensorRTConverter:
    """Converts PyTorch models to TensorRT engines"""

    def __init__(self, gpu_id: int = 0, verbose: bool = True):
        """
        Initialize the converter.

        Args:
            gpu_id: GPU device ID to use for conversion
            verbose: Enable verbose logging
        """
        self.gpu_id = gpu_id
        self.device = torch.device(f'cuda:{gpu_id}')

        # TensorRT logger
        log_level = trt.Logger.VERBOSE if verbose else trt.Logger.WARNING
        self.logger = trt.Logger(log_level)

        # Set CUDA device
        torch.cuda.set_device(gpu_id)

        print(f"Initialized TensorRT Converter on GPU {gpu_id}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"TensorRT version: {trt.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device: {torch.cuda.get_device_name(gpu_id)}")

    def load_pytorch_model(self, model_path: str) -> torch.nn.Module:
        """
        Load PyTorch model from file.

        Args:
            model_path: Path to .pt or .pth file

        Returns:
            Loaded PyTorch model in eval mode
        """
        print(f"\nLoading PyTorch model from {model_path}...")

        if not Path(model_path).exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Load model (weights_only=False for models with custom classes)
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict):
            if 'model' in checkpoint:
                model = checkpoint['model']
            elif 'state_dict' in checkpoint:
                # Need the model architecture - this is a limitation
                raise ValueError(
                    "Checkpoint contains only state_dict. "
                    "Please provide the complete model or modify this script to load your architecture."
                )
            else:
                raise ValueError("Unknown checkpoint format")
        else:
            model = checkpoint

        # Set to eval mode
        model.eval()
        model.to(self.device)

        print(f"✓ Model loaded successfully")
        return model
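    # Illustrative only: one way a state_dict-only checkpoint could be handled,
    # which load_pytorch_model() above rejects. `model_factory` is a hypothetical
    # zero-argument callable (e.g. `lambda: MyNet(num_classes=80)`) that builds the
    # matching architecture; it is not part of this script's CLI.
    def load_from_state_dict_sketch(self, model_factory, model_path: str) -> torch.nn.Module:
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        state_dict = checkpoint.get('state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
        model = model_factory()
        model.load_state_dict(state_dict)
        return model.eval().to(self.device)
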
    def export_to_onnx(self, model: torch.nn.Module, input_shape: Tuple[int, ...],
                       onnx_path: str, dynamic_batch: bool = False,
                       input_names: List[str] = None, output_names: List[str] = None) -> str:
        """
        Export PyTorch model to ONNX format (intermediate step).

        Args:
            model: PyTorch model
            input_shape: Input tensor shape (B, C, H, W)
            onnx_path: Output path for ONNX file
            dynamic_batch: Enable dynamic batch dimension
            input_names: List of input tensor names
            output_names: List of output tensor names

        Returns:
            Path to exported ONNX file
        """
        print(f"\nExporting to ONNX format...")
        print(f"Input shape: {input_shape}")
        print(f"Dynamic batch: {dynamic_batch}")

        # Default names
        if input_names is None:
            input_names = ['input']
        if output_names is None:
            output_names = ['output']

        # Create dummy input
        dummy_input = torch.randn(*input_shape, device=self.device)

        # Dynamic axes configuration
        dynamic_axes = None
        if dynamic_batch:
            dynamic_axes = {
                input_names[0]: {0: 'batch'},
                output_names[0]: {0: 'batch'}
            }

        # Export to ONNX
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=17,  # Use recent ONNX opset
            do_constant_folding=True,
            verbose=False
        )

        print(f"✓ ONNX model exported to {onnx_path}")
        return onnx_path
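    # Illustrative only: an optional sanity check of the exported ONNX file. It is
    # not called anywhere in this script and assumes the optional `onnx` and
    # `onnxruntime` packages are installed, so the imports are kept local.
    def verify_onnx_sketch(self, onnx_path: str, input_shape: Tuple[int, ...]) -> None:
        import onnx
        import onnxruntime as ort

        # Structural check of the exported graph
        onnx.checker.check_model(onnx.load(onnx_path))

        # Run one dummy batch on CPU to confirm the graph executes end to end
        session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
        dummy = np.random.randn(*input_shape).astype(np.float32)
        outputs = session.run(None, {session.get_inputs()[0].name: dummy})
        print(f"ONNX check passed; first output shape: {outputs[0].shape}")
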
    def build_tensorrt_engine_from_onnx(self, onnx_path: str, engine_path: str,
                                        fp16: bool = False, int8: bool = False,
                                        max_workspace_size: int = 4,
                                        min_batch: int = 1, opt_batch: int = 1, max_batch: int = 1) -> str:
        """
        Build TensorRT engine from ONNX model.

        Args:
            onnx_path: Path to ONNX model
            engine_path: Output path for TensorRT engine
            fp16: Enable FP16 precision
            int8: Enable INT8 precision (requires calibration)
            max_workspace_size: Maximum workspace size in GB
            min_batch: Minimum batch size for optimization
            opt_batch: Optimal batch size for optimization
            max_batch: Maximum batch size for optimization

        Returns:
            Path to built TensorRT engine
        """
        print(f"\nBuilding TensorRT engine from ONNX...")
        print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
        print(f"Workspace size: {max_workspace_size} GB")

        # Create builder and network
        builder = trt.Builder(self.logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, self.logger)

        # Parse ONNX model
        print(f"Loading ONNX file from {onnx_path}...")
        with open(onnx_path, 'rb') as f:
            if not parser.parse(f.read()):
                print("ERROR: Failed to parse the ONNX file:")
                for error in range(parser.num_errors):
                    print(f"  {parser.get_error(error)}")
                raise RuntimeError("Failed to parse ONNX model")

        print(f"✓ ONNX model parsed successfully")

        # Print network info
        print(f"\nNetwork Information:")
        print(f"  Inputs: {network.num_inputs}")
        for i in range(network.num_inputs):
            inp = network.get_input(i)
            print(f"    [{i}] {inp.name}: {inp.shape} ({inp.dtype})")

        print(f"  Outputs: {network.num_outputs}")
        for i in range(network.num_outputs):
            out = network.get_output(i)
            print(f"    [{i}] {out.name}: {out.shape} ({out.dtype})")

        # Create builder config
        config = builder.create_builder_config()

        # Set workspace size
        config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE,
            max_workspace_size * (1 << 30)  # GB to bytes
        )

        # Enable precision modes
        if fp16:
            if not builder.platform_has_fast_fp16:
                print("Warning: FP16 not supported on this platform, using FP32")
            else:
                config.set_flag(trt.BuilderFlag.FP16)
                print("✓ FP16 mode enabled")

        if int8:
            if not builder.platform_has_fast_int8:
                print("Warning: INT8 not supported on this platform, using FP32/FP16")
            else:
                config.set_flag(trt.BuilderFlag.INT8)
                print("✓ INT8 mode enabled")
                print("Note: INT8 calibration not implemented. Results may be suboptimal.")
                # A calibrator would be assigned to config.int8_calibrator here;
                # see the illustrative sketch after this class.

        # Set optimization profile for dynamic shapes
        if max_batch > 1 or min_batch != max_batch:
            profile = builder.create_optimization_profile()

            for i in range(network.num_inputs):
                inp = network.get_input(i)
                shape = list(inp.shape)

                # Handle dynamic batch dimension
                if shape[0] == -1:
                    # Min, opt, max shapes
                    min_shape = [min_batch] + shape[1:]
                    opt_shape = [opt_batch] + shape[1:]
                    max_shape = [max_batch] + shape[1:]

                    profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
                    print(f"  Dynamic shape for {inp.name}:")
                    print(f"    Min: {min_shape}")
                    print(f"    Opt: {opt_shape}")
                    print(f"    Max: {max_shape}")

            config.add_optimization_profile(profile)

        # Build engine
        print(f"\nBuilding TensorRT engine (this may take a few minutes)...")
        serialized_engine = builder.build_serialized_network(network, config)

        if serialized_engine is None:
            raise RuntimeError("Failed to build TensorRT engine")

        # Save engine to file
        print(f"Saving engine to {engine_path}...")
        with open(engine_path, 'wb') as f:
            f.write(serialized_engine)

        # Get file size
        file_size_mb = Path(engine_path).stat().st_size / (1024 * 1024)
        print(f"✓ TensorRT engine built successfully")
        print(f"  Engine size: {file_size_mb:.2f} MB")

        return engine_path
    def convert(self, model_path: str, output_path: str,
                input_shape: Tuple[int, ...] = (1, 3, 640, 640),
                fp16: bool = False, int8: bool = False,
                dynamic_batch: bool = False,
                max_batch: int = 16,
                workspace_size: int = 4,
                input_names: List[str] = None,
                output_names: List[str] = None,
                keep_onnx: bool = False) -> str:
        """
        Convert PyTorch or ONNX model to TensorRT engine.

        Args:
            model_path: Path to PyTorch model (.pt, .pth) or ONNX model (.onnx)
            output_path: Path for output TensorRT engine (.trt)
            input_shape: Input tensor shape (B, C, H, W) - required for PyTorch models
            fp16: Enable FP16 precision
            int8: Enable INT8 precision
            dynamic_batch: Enable dynamic batch size
            max_batch: Maximum batch size (for dynamic batching)
            workspace_size: TensorRT workspace size in GB
            input_names: Custom input names (for PyTorch export)
            output_names: Custom output names (for PyTorch export)
            keep_onnx: Keep intermediate ONNX file

        Returns:
            Path to created TensorRT engine
        """
        # Create output directory
        output_dir = Path(output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True)

        # Check if input is already ONNX
        model_path_obj = Path(model_path)
        is_onnx = model_path_obj.suffix.lower() == '.onnx'

        if is_onnx:
            # Direct ONNX to TensorRT conversion
            print(f"Input is ONNX model, converting directly to TensorRT...")

            min_batch = 1
            opt_batch = input_shape[0] if not dynamic_batch else max(1, max_batch // 2)
            max_batch_size = max_batch if dynamic_batch else input_shape[0]

            engine_path = self.build_tensorrt_engine_from_onnx(
                onnx_path=model_path,
                engine_path=output_path,
                fp16=fp16,
                int8=int8,
                max_workspace_size=workspace_size,
                min_batch=min_batch,
                opt_batch=opt_batch,
                max_batch=max_batch_size
            )

            print(f"\n{'=' * 80}")
            print(f"CONVERSION COMPLETED SUCCESSFULLY")
            print(f"{'=' * 80}")
            print(f"Input: {model_path}")
            print(f"Output: {engine_path}")
            print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
            print(f"{'=' * 80}")

            return engine_path

        # PyTorch to TensorRT conversion (via ONNX)
        # Temporary ONNX path
        onnx_path = str(output_dir / "temp_model.onnx")

        try:
            # Step 1: Load PyTorch model
            model = self.load_pytorch_model(model_path)

            # Step 2: Export to ONNX
            self.export_to_onnx(
                model=model,
                input_shape=input_shape,
                onnx_path=onnx_path,
                dynamic_batch=dynamic_batch,
                input_names=input_names,
                output_names=output_names
            )

            # Step 3: Build TensorRT engine
            min_batch = 1
            opt_batch = input_shape[0] if not dynamic_batch else max(1, max_batch // 2)
            max_batch_size = max_batch if dynamic_batch else input_shape[0]

            engine_path = self.build_tensorrt_engine_from_onnx(
                onnx_path=onnx_path,
                engine_path=output_path,
                fp16=fp16,
                int8=int8,
                max_workspace_size=workspace_size,
                min_batch=min_batch,
                opt_batch=opt_batch,
                max_batch=max_batch_size
            )

            print(f"\n{'=' * 80}")
            print(f"CONVERSION COMPLETED SUCCESSFULLY")
            print(f"{'=' * 80}")
            print(f"Input: {model_path}")
            print(f"Output: {engine_path}")
            print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
            print(f"Dynamic batch: {dynamic_batch}")
            if dynamic_batch:
                print(f"Batch range: [1, {max_batch}]")
            print(f"{'=' * 80}")

            return engine_path

        finally:
            # Cleanup temporary ONNX file
            if not keep_onnx and Path(onnx_path).exists():
                Path(onnx_path).unlink()
                print(f"Cleaned up temporary ONNX file")
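
# Illustrative only: a minimal INT8 calibrator skeleton, referenced from
# build_tensorrt_engine_from_onnx() above. The converter does not wire this in;
# to use it you would construct it with real preprocessed batches and assign it
# to `config.int8_calibrator` before building. Buffer handling below assumes
# torch-managed GPU memory and the TensorRT 8.x Python API.
class EntropyCalibratorSketch(trt.IInt8EntropyCalibrator2):
    def __init__(self, calibration_batches, cache_file: str = "calibration.cache"):
        trt.IInt8EntropyCalibrator2.__init__(self)
        self.batches = iter(calibration_batches)  # iterable of float32 numpy arrays shaped (N, C, H, W)
        self.cache_file = Path(cache_file)
        self.device_input = None  # kept as an attribute so the GPU buffer stays alive

    def get_batch_size(self):
        return 1

    def get_batch(self, names):
        try:
            batch = next(self.batches)
        except StopIteration:
            return None  # no more data: calibration ends
        self.device_input = torch.from_numpy(batch).float().cuda()
        return [int(self.device_input.data_ptr())]

    def read_calibration_cache(self):
        return self.cache_file.read_bytes() if self.cache_file.exists() else None

    def write_calibration_cache(self, cache):
        self.cache_file.write_bytes(cache)
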
def parse_shape(shape_str: str) -> Tuple[int, ...]:
    """Parse shape string like '1,3,640,640' to tuple"""
    try:
        return tuple(int(x) for x in shape_str.split(','))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"Invalid shape format: {shape_str}. Expected format: 1,3,640,640"
        )

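# For example:
#   parse_shape("1,3,640,640")  ->  (1, 3, 640, 640)
# which is also the default used for --input-shape below.
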
def main():
    parser = argparse.ArgumentParser(
        description="Convert PyTorch models to TensorRT engines",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic conversion (FP32)
  python convert_pt_to_tensorrt.py --model yolov8n.pt --output models/yolov8n.trt

  # FP16 precision for faster inference
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt --fp16

  # Custom input shape
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
      --input-shape 1,3,416,416

  # Dynamic batch size (1 to 16)
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
      --dynamic-batch --max-batch 16

  # INT8 quantization for maximum speed (requires calibration)
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
      --fp16 --int8

  # Keep intermediate ONNX file for debugging
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
      --keep-onnx
"""
    )

    # Required arguments
    parser.add_argument(
        '--model', '-m',
        type=str,
        required=True,
        help='Path to PyTorch model file (.pt, .pth) or ONNX model (.onnx)'
    )

    parser.add_argument(
        '--output', '-o',
        type=str,
        required=True,
        help='Output path for TensorRT engine (.trt or .engine)'
    )

    # Optional arguments
    parser.add_argument(
        '--input-shape', '-s',
        type=parse_shape,
        default=(1, 3, 640, 640),
        help='Input tensor shape as B,C,H,W (default: 1,3,640,640)'
    )

    parser.add_argument(
        '--fp16',
        action='store_true',
        help='Enable FP16 precision (faster inference, slightly lower accuracy)'
    )

    parser.add_argument(
        '--int8',
        action='store_true',
        help='Enable INT8 precision (fastest, requires calibration)'
    )

    parser.add_argument(
        '--dynamic-batch',
        action='store_true',
        help='Enable dynamic batch size support'
    )

    parser.add_argument(
        '--max-batch',
        type=int,
        default=16,
        help='Maximum batch size for dynamic batching (default: 16)'
    )

    parser.add_argument(
        '--workspace-size',
        type=int,
        default=4,
        help='TensorRT workspace size in GB (default: 4)'
    )

    parser.add_argument(
        '--gpu',
        type=int,
        default=0,
        help='GPU device ID (default: 0)'
    )

    parser.add_argument(
        '--input-names',
        type=str,
        nargs='+',
        default=None,
        help='Custom input tensor names (default: ["input"])'
    )

    parser.add_argument(
        '--output-names',
        type=str,
        nargs='+',
        default=None,
        help='Custom output tensor names (default: ["output"])'
    )

    parser.add_argument(
        '--keep-onnx',
        action='store_true',
        help='Keep intermediate ONNX file'
    )

    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Enable verbose logging'
    )

    args = parser.parse_args()

    # Validate arguments
    if not Path(args.model).exists():
        print(f"Error: Model file not found: {args.model}")
        sys.exit(1)

    if args.int8 and not args.fp16:
        print("Warning: INT8 mode works best with FP16 enabled. Adding --fp16 flag.")
        args.fp16 = True

    # Run conversion
    try:
        converter = TensorRTConverter(gpu_id=args.gpu, verbose=args.verbose)

        converter.convert(
            model_path=args.model,
            output_path=args.output,
            input_shape=args.input_shape,
            fp16=args.fp16,
            int8=args.int8,
            dynamic_batch=args.dynamic_batch,
            max_batch=args.max_batch,
            workspace_size=args.workspace_size,
            input_names=args.input_names,
            output_names=args.output_names,
            keep_onnx=args.keep_onnx
        )

        print("\n✓ Conversion successful!")

    except Exception as e:
        print(f"\n✗ Conversion failed: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
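
# Illustrative library usage (hypothetical paths), bypassing the CLI:
#
#   converter = TensorRTConverter(gpu_id=0)
#   converter.convert("model.pt", "models/model.trt",
#                     input_shape=(1, 3, 640, 640), fp16=True)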