ultralytics export

This commit is contained in:
Siwat Sirichai 2025-11-11 01:28:19 +07:00
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions

View file

@ -21,10 +21,11 @@ Usage:
import argparse
import sys
from pathlib import Path
from typing import Tuple, List, Optional
import torch
import tensorrt as trt
from typing import List, Optional, Tuple
import numpy as np
import tensorrt as trt
import torch
class TensorRTConverter:
@ -39,7 +40,7 @@ class TensorRTConverter:
verbose: Enable verbose logging
"""
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.device = torch.device(f"cuda:{gpu_id}")
# TensorRT logger
log_level = trt.Logger.VERBOSE if verbose else trt.Logger.WARNING
@ -71,13 +72,15 @@ class TensorRTConverter:
raise FileNotFoundError(f"Model file not found: {model_path}")
# Load model (weights_only=False for models with custom classes)
checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
checkpoint = torch.load(
model_path, map_location=self.device, weights_only=False
)
# Handle different checkpoint formats
if isinstance(checkpoint, dict):
if 'model' in checkpoint:
model = checkpoint['model']
elif 'state_dict' in checkpoint:
if "model" in checkpoint:
model = checkpoint["model"]
elif "state_dict" in checkpoint:
# Need model architecture - this is a limitation
raise ValueError(
"Checkpoint contains only state_dict. "
@ -95,9 +98,15 @@ class TensorRTConverter:
print(f"✓ Model loaded successfully")
return model
def export_to_onnx(self, model: torch.nn.Module, input_shape: Tuple[int, ...],
onnx_path: str, dynamic_batch: bool = False,
input_names: List[str] = None, output_names: List[str] = None) -> str:
def export_to_onnx(
self,
model: torch.nn.Module,
input_shape: Tuple[int, ...],
onnx_path: str,
dynamic_batch: bool = False,
input_names: List[str] = None,
output_names: List[str] = None,
) -> str:
"""
Export PyTorch model to ONNX format (intermediate step).
@ -118,9 +127,9 @@ class TensorRTConverter:
# Default names
if input_names is None:
input_names = ['input']
input_names = ["input"]
if output_names is None:
output_names = ['output']
output_names = ["output"]
# Create dummy input
dummy_input = torch.randn(*input_shape, device=self.device)
@ -128,10 +137,7 @@ class TensorRTConverter:
# Dynamic axes configuration
dynamic_axes = None
if dynamic_batch:
dynamic_axes = {
input_names[0]: {0: 'batch'},
output_names[0]: {0: 'batch'}
}
dynamic_axes = {input_names[0]: {0: "batch"}, output_names[0]: {0: "batch"}}
# Export to ONNX
torch.onnx.export(
@ -143,16 +149,22 @@ class TensorRTConverter:
dynamic_axes=dynamic_axes,
opset_version=17, # Use recent ONNX opset
do_constant_folding=True,
verbose=False
verbose=False,
)
print(f"✓ ONNX model exported to {onnx_path}")
return onnx_path
def build_tensorrt_engine_from_onnx(self, onnx_path: str, engine_path: str,
fp16: bool = False, int8: bool = False,
max_workspace_size: int = 4,
min_batch: int = 1, opt_batch: int = 1, max_batch: int = 1) -> str:
def build_tensorrt_engine_from_onnx(
self,
onnx_path: str,
engine_path: str,
fp16: bool = False,
int8: bool = False,
min_batch: int = 1,
opt_batch: int = 1,
max_batch: int = 1,
) -> str:
"""
Build TensorRT engine from ONNX model.
@ -161,7 +173,6 @@ class TensorRTConverter:
engine_path: Output path for TensorRT engine
fp16: Enable FP16 precision
int8: Enable INT8 precision (requires calibration)
max_workspace_size: Maximum workspace size in GB
min_batch: Minimum batch size for optimization
opt_batch: Optimal batch size for optimization
max_batch: Maximum batch size for optimization
@ -171,7 +182,6 @@ class TensorRTConverter:
"""
print(f"\nBuilding TensorRT engine from ONNX...")
print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
print(f"Workspace size: {max_workspace_size} GB")
# Create builder and network
builder = trt.Builder(self.logger)
@ -182,7 +192,7 @@ class TensorRTConverter:
# Parse ONNX model
print(f"Loading ONNX file from {onnx_path}...")
with open(onnx_path, 'rb') as f:
with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
print("ERROR: Failed to parse the ONNX file:")
for error in range(parser.num_errors):
@ -206,12 +216,6 @@ class TensorRTConverter:
# Create builder config
config = builder.create_builder_config()
# Set workspace size
config.set_memory_pool_limit(
trt.MemoryPoolType.WORKSPACE,
max_workspace_size * (1 << 30) # GB to bytes
)
# Enable precision modes
if fp16:
if not builder.platform_has_fast_fp16:
@ -226,7 +230,9 @@ class TensorRTConverter:
else:
config.set_flag(trt.BuilderFlag.INT8)
print("✓ INT8 mode enabled")
print("Note: INT8 calibration not implemented. Results may be suboptimal.")
print(
"Note: INT8 calibration not implemented. Results may be suboptimal."
)
# Set optimization profile for dynamic shapes
if max_batch > 1 or min_batch != max_batch:
@ -260,7 +266,7 @@ class TensorRTConverter:
# Save engine to file
print(f"Saving engine to {engine_path}...")
with open(engine_path, 'wb') as f:
with open(engine_path, "wb") as f:
f.write(serialized_engine)
# Get file size
@ -270,15 +276,19 @@ class TensorRTConverter:
return engine_path
def convert(self, model_path: str, output_path: str,
input_shape: Tuple[int, ...] = (1, 3, 640, 640),
fp16: bool = False, int8: bool = False,
dynamic_batch: bool = False,
max_batch: int = 16,
workspace_size: int = 4,
input_names: List[str] = None,
output_names: List[str] = None,
keep_onnx: bool = False) -> str:
def convert(
self,
model_path: str,
output_path: str,
input_shape: Tuple[int, ...] = (1, 3, 640, 640),
fp16: bool = False,
int8: bool = False,
dynamic_batch: bool = False,
max_batch: int = 16,
input_names: List[str] = None,
output_names: List[str] = None,
keep_onnx: bool = False,
) -> str:
"""
Convert PyTorch or ONNX model to TensorRT engine.
@ -290,7 +300,6 @@ class TensorRTConverter:
int8: Enable INT8 precision
dynamic_batch: Enable dynamic batch size
max_batch: Maximum batch size (for dynamic batching)
workspace_size: TensorRT workspace size in GB
input_names: Custom input names (for PyTorch export)
output_names: Custom output names (for PyTorch export)
keep_onnx: Keep intermediate ONNX file
@ -304,7 +313,7 @@ class TensorRTConverter:
# Check if input is already ONNX
model_path_obj = Path(model_path)
is_onnx = model_path_obj.suffix.lower() == '.onnx'
is_onnx = model_path_obj.suffix.lower() == ".onnx"
if is_onnx:
# Direct ONNX to TensorRT conversion
@ -319,10 +328,9 @@ class TensorRTConverter:
engine_path=output_path,
fp16=fp16,
int8=int8,
max_workspace_size=workspace_size,
min_batch=min_batch,
opt_batch=opt_batch,
max_batch=max_batch_size
max_batch=max_batch_size,
)
print(f"\n{'=' * 80}")
@ -350,7 +358,7 @@ class TensorRTConverter:
onnx_path=onnx_path,
dynamic_batch=dynamic_batch,
input_names=input_names,
output_names=output_names
output_names=output_names,
)
# Step 3: Build TensorRT engine
@ -363,10 +371,9 @@ class TensorRTConverter:
engine_path=output_path,
fp16=fp16,
int8=int8,
max_workspace_size=workspace_size,
min_batch=min_batch,
opt_batch=opt_batch,
max_batch=max_batch_size
max_batch=max_batch_size,
)
print(f"\n{'=' * 80}")
@ -392,7 +399,7 @@ class TensorRTConverter:
def parse_shape(shape_str: str) -> Tuple[int, ...]:
"""Parse shape string like '1,3,640,640' to tuple"""
try:
return tuple(int(x) for x in shape_str.split(','))
return tuple(int(x) for x in shape_str.split(","))
except ValueError:
raise argparse.ArgumentTypeError(
f"Invalid shape format: {shape_str}. Expected format: 1,3,640,640"
@ -426,97 +433,82 @@ Examples:
# Keep intermediate ONNX file for debugging
python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
--keep-onnx
"""
""",
)
# Required arguments
parser.add_argument(
'--model', '-m',
"--model",
"-m",
type=str,
required=True,
help='Path to PyTorch model file (.pt or .pth)'
help="Path to PyTorch model file (.pt or .pth)",
)
parser.add_argument(
'--output', '-o',
"--output",
"-o",
type=str,
required=True,
help='Output path for TensorRT engine (.trt or .engine)'
help="Output path for TensorRT engine (.trt or .engine)",
)
# Optional arguments
parser.add_argument(
'--input-shape', '-s',
"--input-shape",
"-s",
type=parse_shape,
default=(1, 3, 640, 640),
help='Input tensor shape as B,C,H,W (default: 1,3,640,640)'
help="Input tensor shape as B,C,H,W (default: 1,3,640,640)",
)
parser.add_argument(
'--fp16',
action='store_true',
help='Enable FP16 precision (faster inference, slightly lower accuracy)'
"--fp16",
action="store_true",
help="Enable FP16 precision (faster inference, slightly lower accuracy)",
)
parser.add_argument(
'--int8',
action='store_true',
help='Enable INT8 precision (fastest, requires calibration)'
"--int8",
action="store_true",
help="Enable INT8 precision (fastest, requires calibration)",
)
parser.add_argument(
'--dynamic-batch',
action='store_true',
help='Enable dynamic batch size support'
"--dynamic-batch", action="store_true", help="Enable dynamic batch size support"
)
parser.add_argument(
'--max-batch',
"--max-batch",
type=int,
default=16,
help='Maximum batch size for dynamic batching (default: 16)'
help="Maximum batch size for dynamic batching (default: 16)",
)
parser.add_argument(
'--workspace-size',
type=int,
default=4,
help='TensorRT workspace size in GB (default: 4)'
)
parser.add_argument("--gpu", type=int, default=0, help="GPU device ID (default: 0)")
parser.add_argument(
'--gpu',
type=int,
default=0,
help='GPU device ID (default: 0)'
)
parser.add_argument(
'--input-names',
"--input-names",
type=str,
nargs='+',
nargs="+",
default=None,
help='Custom input tensor names (default: ["input"])'
help='Custom input tensor names (default: ["input"])',
)
parser.add_argument(
'--output-names',
"--output-names",
type=str,
nargs='+',
nargs="+",
default=None,
help='Custom output tensor names (default: ["output"])'
help='Custom output tensor names (default: ["output"])',
)
parser.add_argument(
'--keep-onnx',
action='store_true',
help='Keep intermediate ONNX file'
"--keep-onnx", action="store_true", help="Keep intermediate ONNX file"
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Enable verbose logging'
"--verbose", "-v", action="store_true", help="Enable verbose logging"
)
args = parser.parse_args()
@ -542,10 +534,9 @@ Examples:
int8=args.int8,
dynamic_batch=args.dynamic_batch,
max_batch=args.max_batch,
workspace_size=args.workspace_size,
input_names=args.input_names,
output_names=args.output_names,
keep_onnx=args.keep_onnx
keep_onnx=args.keep_onnx,
)
print("\n✓ Conversion successful!")
@ -554,6 +545,7 @@ Examples:
print(f"\n✗ Conversion failed: {e}")
if args.verbose:
import traceback
traceback.print_exc()
sys.exit(1)