#!/usr/bin/env python3
"""
PyTorch to TensorRT Model Conversion Script

This script converts PyTorch models (.pt, .pth) to TensorRT engines (.trt) for optimized inference.

Features:
- Selectable FP32/FP16/INT8 precision modes
- Dynamic batch size support
- Input shape validation
- Optimization profiles for dynamic shapes
- ONNX intermediate format
- GPU-accelerated conversion

Usage:
    python convert_pt_to_tensorrt.py --model path/to/model.pt --output models/model.trt
    python convert_pt_to_tensorrt.py --model yolov8n.pt --output yolov8n.trt --input-shape 1,3,640,640 --fp16
    python convert_pt_to_tensorrt.py --model model.pt --output model.trt --dynamic-batch --max-batch 16
"""
import argparse
import sys
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import tensorrt as trt
import torch


class TensorRTConverter:
    """Converts PyTorch models to TensorRT engines"""

    def __init__(self, gpu_id: int = 0, verbose: bool = True):
        """
        Initialize the converter.

        Args:
            gpu_id: GPU device ID to use for conversion
            verbose: Enable verbose logging
        """
        self.gpu_id = gpu_id
        self.device = torch.device(f"cuda:{gpu_id}")

        # TensorRT logger
        log_level = trt.Logger.VERBOSE if verbose else trt.Logger.WARNING
        self.logger = trt.Logger(log_level)

        # Set CUDA device
        torch.cuda.set_device(gpu_id)

        print(f"Initialized TensorRT Converter on GPU {gpu_id}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"TensorRT version: {trt.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device: {torch.cuda.get_device_name(gpu_id)}")

    def load_pytorch_model(self, model_path: str) -> torch.nn.Module:
        """
        Load PyTorch model from file.

        Args:
            model_path: Path to .pt or .pth file

        Returns:
            Loaded PyTorch model in eval mode
        """
        print(f"\nLoading PyTorch model from {model_path}...")

        if not Path(model_path).exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Load model (weights_only=False for models with custom classes)
        checkpoint = torch.load(
            model_path, map_location=self.device, weights_only=False
        )

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict):
            if "model" in checkpoint:
                model = checkpoint["model"]
            elif "state_dict" in checkpoint:
                # Need model architecture - this is a limitation
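                # A minimal sketch for adapting this script to a state_dict-only
                # checkpoint (MyNet is a hypothetical placeholder for your class):
                #   model = MyNet()
                #   model.load_state_dict(checkpoint["state_dict"])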
                raise ValueError(
                    "Checkpoint contains only state_dict. "
                    "Please provide the complete model or modify this script to load your architecture."
                )
            else:
                raise ValueError("Unknown checkpoint format")
        else:
            model = checkpoint

        # Set to eval mode
        model.eval()
        model.to(self.device)

        print(f"✓ Model loaded successfully")
        return model

    def export_to_onnx(
        self,
        model: torch.nn.Module,
        input_shape: Tuple[int, ...],
        onnx_path: str,
        dynamic_batch: bool = False,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
    ) -> str:
        """
        Export PyTorch model to ONNX format (intermediate step).

        Args:
            model: PyTorch model
            input_shape: Input tensor shape (B, C, H, W)
            onnx_path: Output path for ONNX file
            dynamic_batch: Enable dynamic batch dimension
            input_names: List of input tensor names
            output_names: List of output tensor names

        Returns:
            Path to exported ONNX file
        """
        print(f"\nExporting to ONNX format...")
        print(f"Input shape: {input_shape}")
        print(f"Dynamic batch: {dynamic_batch}")

        # Default names
        if input_names is None:
            input_names = ["input"]
        if output_names is None:
            output_names = ["output"]

        # Create dummy input
        dummy_input = torch.randn(*input_shape, device=self.device)

        # Dynamic axes configuration
        dynamic_axes = None
        if dynamic_batch:
            dynamic_axes = {input_names[0]: {0: "batch"}, output_names[0]: {0: "batch"}}

        # Export to ONNX
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=17,  # Use recent ONNX opset
            do_constant_folding=True,
            verbose=False,
        )
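        # Optional sanity check of the exported graph (a sketch; requires the
        # `onnx` package, which is not imported above):
        #   import onnx
        #   onnx.checker.check_model(onnx.load(onnx_path))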
        print(f"✓ ONNX model exported to {onnx_path}")
        return onnx_path

    def build_tensorrt_engine_from_onnx(
        self,
        onnx_path: str,
        engine_path: str,
        fp16: bool = False,
        int8: bool = False,
        min_batch: int = 1,
        opt_batch: int = 1,
        max_batch: int = 1,
    ) -> str:
        """
        Build TensorRT engine from ONNX model.

        Args:
            onnx_path: Path to ONNX model
            engine_path: Output path for TensorRT engine
            fp16: Enable FP16 precision
            int8: Enable INT8 precision (requires calibration)
            min_batch: Minimum batch size for optimization
            opt_batch: Optimal batch size for optimization
            max_batch: Maximum batch size for optimization

        Returns:
            Path to built TensorRT engine
        """
        print(f"\nBuilding TensorRT engine from ONNX...")
        print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")

        # Create builder and network
        builder = trt.Builder(self.logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, self.logger)

        # Parse ONNX model
        print(f"Loading ONNX file from {onnx_path}...")
        with open(onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                print("ERROR: Failed to parse the ONNX file:")
                for error in range(parser.num_errors):
                    print(f"  {parser.get_error(error)}")
                raise RuntimeError("Failed to parse ONNX model")

        print(f"✓ ONNX model parsed successfully")

        # Print network info
        print(f"\nNetwork Information:")
        print(f"  Inputs: {network.num_inputs}")
        for i in range(network.num_inputs):
            inp = network.get_input(i)
            print(f"    [{i}] {inp.name}: {inp.shape} ({inp.dtype})")

        print(f"  Outputs: {network.num_outputs}")
        for i in range(network.num_outputs):
            out = network.get_output(i)
            print(f"    [{i}] {out.name}: {out.shape} ({out.dtype})")

        # Create builder config
        config = builder.create_builder_config()
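        # Optionally cap the builder workspace memory here (a sketch; the exact
        # call depends on the TensorRT version, e.g. TensorRT >= 8.4):
        #   config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB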
        # Enable precision modes
        if fp16:
            if not builder.platform_has_fast_fp16:
                print("Warning: FP16 not supported on this platform, using FP32")
            else:
                config.set_flag(trt.BuilderFlag.FP16)
                print("✓ FP16 mode enabled")

        if int8:
            if not builder.platform_has_fast_int8:
                print("Warning: INT8 not supported on this platform, using FP32/FP16")
            else:
                config.set_flag(trt.BuilderFlag.INT8)
                print("✓ INT8 mode enabled")
                print(
                    "Note: INT8 calibration not implemented. Results may be suboptimal."
                )
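                # Proper INT8 requires a calibrator fed with representative data
                # (a sketch only; MyCalibrator is a hypothetical subclass of
                # trt.IInt8EntropyCalibrator2):
                #   config.int8_calibrator = MyCalibrator()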
        # Set optimization profile for dynamic shapes
        if max_batch > 1 or min_batch != max_batch:
            profile = builder.create_optimization_profile()

            for i in range(network.num_inputs):
                inp = network.get_input(i)
                shape = list(inp.shape)

                # Handle dynamic batch dimension
                if shape[0] == -1:
                    # Min, opt, max shapes
                    min_shape = [min_batch] + shape[1:]
                    opt_shape = [opt_batch] + shape[1:]
                    max_shape = [max_batch] + shape[1:]

                    profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
                    print(f"  Dynamic shape for {inp.name}:")
                    print(f"    Min: {min_shape}")
                    print(f"    Opt: {opt_shape}")
                    print(f"    Max: {max_shape}")

            config.add_optimization_profile(profile)
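            # For example, with the CLI defaults (input shape 1,3,640,640,
            # --dynamic-batch, --max-batch 16) convert() passes min/opt/max
            # batches of 1/8/16, giving shapes [1,3,640,640] / [8,3,640,640] /
            # [16,3,640,640].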
        # Build engine
        print(f"\nBuilding TensorRT engine (this may take a few minutes)...")
        serialized_engine = builder.build_serialized_network(network, config)

        if serialized_engine is None:
            raise RuntimeError("Failed to build TensorRT engine")

        # Save engine to file
        print(f"Saving engine to {engine_path}...")
        with open(engine_path, "wb") as f:
            f.write(serialized_engine)

        # Get file size
        file_size_mb = Path(engine_path).stat().st_size / (1024 * 1024)
        print(f"✓ TensorRT engine built successfully")
        print(f"  Engine size: {file_size_mb:.2f} MB")
        return engine_path

    def convert(
        self,
        model_path: str,
        output_path: str,
        input_shape: Tuple[int, ...] = (1, 3, 640, 640),
        fp16: bool = False,
        int8: bool = False,
        dynamic_batch: bool = False,
        max_batch: int = 16,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
        keep_onnx: bool = False,
    ) -> str:
        """
        Convert PyTorch or ONNX model to TensorRT engine.

        Args:
            model_path: Path to PyTorch model (.pt, .pth) or ONNX model (.onnx)
            output_path: Path for output TensorRT engine (.trt)
            input_shape: Input tensor shape (B, C, H, W) - required for PyTorch models
            fp16: Enable FP16 precision
            int8: Enable INT8 precision
            dynamic_batch: Enable dynamic batch size
            max_batch: Maximum batch size (for dynamic batching)
            input_names: Custom input names (for PyTorch export)
            output_names: Custom output names (for PyTorch export)
            keep_onnx: Keep intermediate ONNX file

        Returns:
            Path to created TensorRT engine
        """
        # Create output directory
        output_dir = Path(output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True)

        # Check if input is already ONNX
        model_path_obj = Path(model_path)
        is_onnx = model_path_obj.suffix.lower() == ".onnx"

        if is_onnx:
            # Direct ONNX to TensorRT conversion
            print(f"Input is ONNX model, converting directly to TensorRT...")

            min_batch = 1
            opt_batch = input_shape[0] if not dynamic_batch else max(1, max_batch // 2)
            max_batch_size = max_batch if dynamic_batch else input_shape[0]

            engine_path = self.build_tensorrt_engine_from_onnx(
                onnx_path=model_path,
                engine_path=output_path,
                fp16=fp16,
                int8=int8,
                min_batch=min_batch,
                opt_batch=opt_batch,
                max_batch=max_batch_size,
            )

            print(f"\n{'=' * 80}")
            print(f"CONVERSION COMPLETED SUCCESSFULLY")
            print(f"{'=' * 80}")
            print(f"Input: {model_path}")
            print(f"Output: {engine_path}")
            print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
            print(f"{'=' * 80}")

            return engine_path

        # PyTorch to TensorRT conversion (via ONNX)
        # Temporary ONNX path
        onnx_path = str(output_dir / "temp_model.onnx")

        try:
            # Step 1: Load PyTorch model
            model = self.load_pytorch_model(model_path)

            # Step 2: Export to ONNX
            self.export_to_onnx(
                model=model,
                input_shape=input_shape,
                onnx_path=onnx_path,
                dynamic_batch=dynamic_batch,
                input_names=input_names,
                output_names=output_names,
            )

            # Step 3: Build TensorRT engine
            min_batch = 1
            opt_batch = input_shape[0] if not dynamic_batch else max(1, max_batch // 2)
            max_batch_size = max_batch if dynamic_batch else input_shape[0]

            engine_path = self.build_tensorrt_engine_from_onnx(
                onnx_path=onnx_path,
                engine_path=output_path,
                fp16=fp16,
                int8=int8,
                min_batch=min_batch,
                opt_batch=opt_batch,
                max_batch=max_batch_size,
            )

            print(f"\n{'=' * 80}")
            print(f"CONVERSION COMPLETED SUCCESSFULLY")
            print(f"{'=' * 80}")
            print(f"Input: {model_path}")
            print(f"Output: {engine_path}")
            print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
            print(f"Dynamic batch: {dynamic_batch}")
            if dynamic_batch:
                print(f"Batch range: [1, {max_batch}]")
            print(f"{'=' * 80}")

            return engine_path

        finally:
            # Cleanup temporary ONNX file
            if not keep_onnx and Path(onnx_path).exists():
                Path(onnx_path).unlink()
                print(f"Cleaned up temporary ONNX file")
def parse_shape(shape_str: str) -> Tuple[int, ...]:
    """Parse a shape string like '1,3,640,640' into a tuple of ints."""
    try:
        return tuple(int(x) for x in shape_str.split(","))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"Invalid shape format: {shape_str}. Expected format: 1,3,640,640"
        )


def main():
    parser = argparse.ArgumentParser(
        description="Convert PyTorch models to TensorRT engines",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic conversion (FP32)
  python convert_pt_to_tensorrt.py --model yolov8n.pt --output models/yolov8n.trt

  # FP16 precision for faster inference
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt --fp16

  # Custom input shape
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
    --input-shape 1,3,416,416

  # Dynamic batch size (1 to 16)
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
    --dynamic-batch --max-batch 16

  # INT8 quantization for maximum speed (requires calibration)
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
    --fp16 --int8

  # Keep intermediate ONNX file for debugging
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
    --keep-onnx
        """,
    )

    # Required arguments
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        required=True,
        help="Path to PyTorch model file (.pt, .pth) or ONNX model (.onnx)",
    )

    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Output path for TensorRT engine (.trt or .engine)",
    )

    # Optional arguments
    parser.add_argument(
        "--input-shape",
        "-s",
        type=parse_shape,
        default=(1, 3, 640, 640),
        help="Input tensor shape as B,C,H,W (default: 1,3,640,640)",
    )

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Enable FP16 precision (faster inference, slightly lower accuracy)",
    )

    parser.add_argument(
        "--int8",
        action="store_true",
        help="Enable INT8 precision (fastest, requires calibration)",
    )

    parser.add_argument(
        "--dynamic-batch", action="store_true", help="Enable dynamic batch size support"
    )

    parser.add_argument(
        "--max-batch",
        type=int,
        default=16,
        help="Maximum batch size for dynamic batching (default: 16)",
    )

    parser.add_argument("--gpu", type=int, default=0, help="GPU device ID (default: 0)")

    parser.add_argument(
        "--input-names",
        type=str,
        nargs="+",
        default=None,
        help='Custom input tensor names (default: ["input"])',
    )

    parser.add_argument(
        "--output-names",
        type=str,
        nargs="+",
        default=None,
        help='Custom output tensor names (default: ["output"])',
    )

    parser.add_argument(
        "--keep-onnx", action="store_true", help="Keep intermediate ONNX file"
    )

    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Enable verbose logging"
    )

    args = parser.parse_args()

    # Validate arguments
    if not Path(args.model).exists():
        print(f"Error: Model file not found: {args.model}")
        sys.exit(1)

    if args.int8 and not args.fp16:
        print("Warning: INT8 mode works best with FP16 enabled. Adding --fp16 flag.")
        args.fp16 = True

    # Run conversion
    try:
        converter = TensorRTConverter(gpu_id=args.gpu, verbose=args.verbose)

        converter.convert(
            model_path=args.model,
            output_path=args.output,
            input_shape=args.input_shape,
            fp16=args.fp16,
            int8=args.int8,
            dynamic_batch=args.dynamic_batch,
            max_batch=args.max_batch,
            input_names=args.input_names,
            output_names=args.output_names,
            keep_onnx=args.keep_onnx,
        )

        print("\n✓ Conversion successful!")

    except Exception as e:
        print(f"\n✗ Conversion failed: {e}")
        if args.verbose:
            import traceback

            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()