ultralytic export
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions
@@ -21,10 +21,11 @@ Usage:
 import argparse
 import sys
 from pathlib import Path
-from typing import Tuple, List, Optional
-import torch
-import tensorrt as trt
+from typing import List, Optional, Tuple
+
+import numpy as np
+import tensorrt as trt
+import torch
 
 
 class TensorRTConverter:
@@ -39,7 +40,7 @@ class TensorRTConverter:
             verbose: Enable verbose logging
         """
         self.gpu_id = gpu_id
-        self.device = torch.device(f'cuda:{gpu_id}')
+        self.device = torch.device(f"cuda:{gpu_id}")
 
         # TensorRT logger
         log_level = trt.Logger.VERBOSE if verbose else trt.Logger.WARNING
@@ -71,13 +72,15 @@ class TensorRTConverter:
             raise FileNotFoundError(f"Model file not found: {model_path}")
 
         # Load model (weights_only=False for models with custom classes)
-        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
+        checkpoint = torch.load(
+            model_path, map_location=self.device, weights_only=False
+        )
 
         # Handle different checkpoint formats
         if isinstance(checkpoint, dict):
-            if 'model' in checkpoint:
-                model = checkpoint['model']
-            elif 'state_dict' in checkpoint:
+            if "model" in checkpoint:
+                model = checkpoint["model"]
+            elif "state_dict" in checkpoint:
                 # Need model architecture - this is a limitation
                 raise ValueError(
                     "Checkpoint contains only state_dict. "
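Side note: the state_dict-only path raises by design, since the architecture is not in the file. If the model class were importable, a minimal sketch of the workaround could look like this (MyModel is a hypothetical placeholder, not part of this commit):

    import torch

    model = MyModel()  # assumed: the class that produced the checkpoint
    ckpt = torch.load("model.pt", map_location="cuda:0", weights_only=False)
    state = ckpt["state_dict"] if isinstance(ckpt, dict) else ckpt
    model.load_state_dict(state)
    model.eval()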
@@ -95,9 +98,15 @@ class TensorRTConverter:
         print(f"✓ Model loaded successfully")
         return model
 
-    def export_to_onnx(self, model: torch.nn.Module, input_shape: Tuple[int, ...],
-                       onnx_path: str, dynamic_batch: bool = False,
-                       input_names: List[str] = None, output_names: List[str] = None) -> str:
+    def export_to_onnx(
+        self,
+        model: torch.nn.Module,
+        input_shape: Tuple[int, ...],
+        onnx_path: str,
+        dynamic_batch: bool = False,
+        input_names: List[str] = None,
+        output_names: List[str] = None,
+    ) -> str:
         """
         Export PyTorch model to ONNX format (intermediate step).
 
@@ -118,9 +127,9 @@ class TensorRTConverter:
 
         # Default names
         if input_names is None:
-            input_names = ['input']
+            input_names = ["input"]
         if output_names is None:
-            output_names = ['output']
+            output_names = ["output"]
 
         # Create dummy input
         dummy_input = torch.randn(*input_shape, device=self.device)
@@ -128,10 +137,7 @@ class TensorRTConverter:
         # Dynamic axes configuration
         dynamic_axes = None
         if dynamic_batch:
-            dynamic_axes = {
-                input_names[0]: {0: 'batch'},
-                output_names[0]: {0: 'batch'}
-            }
+            dynamic_axes = {input_names[0]: {0: "batch"}, output_names[0]: {0: "batch"}}
 
         # Export to ONNX
         torch.onnx.export(
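For context, dynamic_axes marks dimension 0 as symbolic so the exported ONNX graph accepts any batch size at runtime. A minimal, self-contained sketch of the same pattern (the tiny net here is a stand-in for the real model):

    import torch

    net = torch.nn.Conv2d(3, 8, 3)  # stand-in model
    dummy = torch.randn(1, 3, 640, 640)
    torch.onnx.export(
        net,
        dummy,
        "model.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        opset_version=17,
    )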
@@ -143,16 +149,22 @@ class TensorRTConverter:
             dynamic_axes=dynamic_axes,
             opset_version=17,  # Use recent ONNX opset
             do_constant_folding=True,
-            verbose=False
+            verbose=False,
         )
 
         print(f"✓ ONNX model exported to {onnx_path}")
         return onnx_path
 
-    def build_tensorrt_engine_from_onnx(self, onnx_path: str, engine_path: str,
-                                        fp16: bool = False, int8: bool = False,
-                                        max_workspace_size: int = 4,
-                                        min_batch: int = 1, opt_batch: int = 1, max_batch: int = 1) -> str:
+    def build_tensorrt_engine_from_onnx(
+        self,
+        onnx_path: str,
+        engine_path: str,
+        fp16: bool = False,
+        int8: bool = False,
+        min_batch: int = 1,
+        opt_batch: int = 1,
+        max_batch: int = 1,
+    ) -> str:
         """
         Build TensorRT engine from ONNX model.
 
@@ -161,7 +173,6 @@ class TensorRTConverter:
             engine_path: Output path for TensorRT engine
             fp16: Enable FP16 precision
             int8: Enable INT8 precision (requires calibration)
-            max_workspace_size: Maximum workspace size in GB
             min_batch: Minimum batch size for optimization
             opt_batch: Optimal batch size for optimization
             max_batch: Maximum batch size for optimization
@@ -171,7 +182,6 @@ class TensorRTConverter:
         """
         print(f"\nBuilding TensorRT engine from ONNX...")
         print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
-        print(f"Workspace size: {max_workspace_size} GB")
 
         # Create builder and network
         builder = trt.Builder(self.logger)
@@ -182,7 +192,7 @@ class TensorRTConverter:
 
         # Parse ONNX model
         print(f"Loading ONNX file from {onnx_path}...")
-        with open(onnx_path, 'rb') as f:
+        with open(onnx_path, "rb") as f:
             if not parser.parse(f.read()):
                 print("ERROR: Failed to parse the ONNX file:")
                 for error in range(parser.num_errors):
@@ -206,12 +216,6 @@ class TensorRTConverter:
         # Create builder config
         config = builder.create_builder_config()
 
-        # Set workspace size
-        config.set_memory_pool_limit(
-            trt.MemoryPoolType.WORKSPACE,
-            max_workspace_size * (1 << 30)  # GB to bytes
-        )
-
         # Enable precision modes
         if fp16:
             if not builder.platform_has_fast_fp16:
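Note on the deleted block: recent TensorRT releases default the workspace pool to the device's available memory, so the explicit cap is optional. If a cap were still wanted, the removed call is how to restore it, e.g.:

    # Optional: cap the builder workspace pool at 4 GB.
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)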
@@ -226,7 +230,9 @@ class TensorRTConverter:
         else:
             config.set_flag(trt.BuilderFlag.INT8)
             print("✓ INT8 mode enabled")
-            print("Note: INT8 calibration not implemented. Results may be suboptimal.")
+            print(
+                "Note: INT8 calibration not implemented. Results may be suboptimal."
+            )
 
         # Set optimization profile for dynamic shapes
         if max_batch > 1 or min_batch != max_batch:
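Since the commit leaves INT8 calibration unimplemented, here is a rough skeleton of what plugging one in could look like; MyCalibrator and device_batches are hypothetical names, and batch preparation is elided:

    class MyCalibrator(trt.IInt8EntropyCalibrator2):  # hypothetical
        def __init__(self, batches, cache_path="calib.cache"):
            super().__init__()
            self.batches = iter(batches)  # iterable of device pointers
            self.cache_path = cache_path

        def get_batch_size(self):
            return 1

        def get_batch(self, names):
            # Return device pointers for the next batch, or None when done.
            try:
                return [next(self.batches)]
            except StopIteration:
                return None

        def read_calibration_cache(self):
            try:
                with open(self.cache_path, "rb") as f:
                    return f.read()
            except FileNotFoundError:
                return None

        def write_calibration_cache(self, cache):
            with open(self.cache_path, "wb") as f:
                f.write(cache)

    # config.int8_calibrator = MyCalibrator(device_batches)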
@@ -260,7 +266,7 @@ class TensorRTConverter:
 
         # Save engine to file
         print(f"Saving engine to {engine_path}...")
-        with open(engine_path, 'wb') as f:
+        with open(engine_path, "wb") as f:
             f.write(serialized_engine)
 
         # Get file size
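For completeness, a sketch of loading the serialized engine back for inference (standard TensorRT runtime API; the file name is illustrative):

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    with open("model.trt", "rb") as f:
        engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()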
@@ -270,15 +276,19 @@ class TensorRTConverter:
 
         return engine_path
 
-    def convert(self, model_path: str, output_path: str,
-                input_shape: Tuple[int, ...] = (1, 3, 640, 640),
-                fp16: bool = False, int8: bool = False,
-                dynamic_batch: bool = False,
-                max_batch: int = 16,
-                workspace_size: int = 4,
-                input_names: List[str] = None,
-                output_names: List[str] = None,
-                keep_onnx: bool = False) -> str:
+    def convert(
+        self,
+        model_path: str,
+        output_path: str,
+        input_shape: Tuple[int, ...] = (1, 3, 640, 640),
+        fp16: bool = False,
+        int8: bool = False,
+        dynamic_batch: bool = False,
+        max_batch: int = 16,
+        input_names: List[str] = None,
+        output_names: List[str] = None,
+        keep_onnx: bool = False,
+    ) -> str:
         """
         Convert PyTorch or ONNX model to TensorRT engine.
 
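A programmatic call mirroring the reformatted signature, assuming the constructor keywords shown earlier in the diff (values are illustrative; note that workspace_size is gone):

    converter = TensorRTConverter(gpu_id=0, verbose=False)
    engine_path = converter.convert(
        model_path="model.pt",
        output_path="model.trt",
        input_shape=(1, 3, 640, 640),
        fp16=True,
        dynamic_batch=True,
        max_batch=16,
    )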
@@ -290,7 +300,6 @@ class TensorRTConverter:
             int8: Enable INT8 precision
             dynamic_batch: Enable dynamic batch size
             max_batch: Maximum batch size (for dynamic batching)
-            workspace_size: TensorRT workspace size in GB
             input_names: Custom input names (for PyTorch export)
             output_names: Custom output names (for PyTorch export)
             keep_onnx: Keep intermediate ONNX file
@@ -304,7 +313,7 @@ class TensorRTConverter:
 
         # Check if input is already ONNX
         model_path_obj = Path(model_path)
-        is_onnx = model_path_obj.suffix.lower() == '.onnx'
+        is_onnx = model_path_obj.suffix.lower() == ".onnx"
 
         if is_onnx:
             # Direct ONNX to TensorRT conversion
@@ -319,10 +328,9 @@ class TensorRTConverter:
                 engine_path=output_path,
                 fp16=fp16,
                 int8=int8,
-                max_workspace_size=workspace_size,
                 min_batch=min_batch,
                 opt_batch=opt_batch,
-                max_batch=max_batch_size
+                max_batch=max_batch_size,
             )
 
             print(f"\n{'=' * 80}")
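For context, the min/opt/max batch values forwarded here typically end up in a TensorRT optimization profile; a minimal sketch, assuming an input tensor named "input" and this script's default spatial shape:

    profile = builder.create_optimization_profile()
    profile.set_shape("input", (1, 3, 640, 640), (8, 3, 640, 640), (16, 3, 640, 640))
    config.add_optimization_profile(profile)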
@@ -350,7 +358,7 @@ class TensorRTConverter:
             onnx_path=onnx_path,
             dynamic_batch=dynamic_batch,
             input_names=input_names,
-            output_names=output_names
+            output_names=output_names,
         )
 
         # Step 3: Build TensorRT engine
@@ -363,10 +371,9 @@ class TensorRTConverter:
             engine_path=output_path,
             fp16=fp16,
             int8=int8,
-            max_workspace_size=workspace_size,
             min_batch=min_batch,
             opt_batch=opt_batch,
-            max_batch=max_batch_size
+            max_batch=max_batch_size,
         )
 
         print(f"\n{'=' * 80}")
@@ -392,7 +399,7 @@
 def parse_shape(shape_str: str) -> Tuple[int, ...]:
     """Parse shape string like '1,3,640,640' to tuple"""
     try:
-        return tuple(int(x) for x in shape_str.split(','))
+        return tuple(int(x) for x in shape_str.split(","))
     except ValueError:
         raise argparse.ArgumentTypeError(
             f"Invalid shape format: {shape_str}. Expected format: 1,3,640,640"
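Quick usage check for parse_shape:

    parse_shape("1,3,640,640")  # -> (1, 3, 640, 640)
    parse_shape("1,3,x")        # raises argparse.ArgumentTypeError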
@@ -426,97 +433,82 @@ Examples:
     # Keep intermediate ONNX file for debugging
     python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
         --keep-onnx
-        """
+        """,
     )
 
     # Required arguments
     parser.add_argument(
-        '--model', '-m',
+        "--model",
+        "-m",
         type=str,
         required=True,
-        help='Path to PyTorch model file (.pt or .pth)'
+        help="Path to PyTorch model file (.pt or .pth)",
     )
 
     parser.add_argument(
-        '--output', '-o',
+        "--output",
+        "-o",
         type=str,
         required=True,
-        help='Output path for TensorRT engine (.trt or .engine)'
+        help="Output path for TensorRT engine (.trt or .engine)",
     )
 
     # Optional arguments
     parser.add_argument(
-        '--input-shape', '-s',
+        "--input-shape",
+        "-s",
         type=parse_shape,
        default=(1, 3, 640, 640),
-        help='Input tensor shape as B,C,H,W (default: 1,3,640,640)'
+        help="Input tensor shape as B,C,H,W (default: 1,3,640,640)",
     )
 
     parser.add_argument(
-        '--fp16',
-        action='store_true',
-        help='Enable FP16 precision (faster inference, slightly lower accuracy)'
+        "--fp16",
+        action="store_true",
+        help="Enable FP16 precision (faster inference, slightly lower accuracy)",
     )
 
     parser.add_argument(
-        '--int8',
-        action='store_true',
-        help='Enable INT8 precision (fastest, requires calibration)'
+        "--int8",
+        action="store_true",
+        help="Enable INT8 precision (fastest, requires calibration)",
     )
 
     parser.add_argument(
-        '--dynamic-batch',
-        action='store_true',
-        help='Enable dynamic batch size support'
+        "--dynamic-batch", action="store_true", help="Enable dynamic batch size support"
     )
 
     parser.add_argument(
-        '--max-batch',
+        "--max-batch",
         type=int,
         default=16,
-        help='Maximum batch size for dynamic batching (default: 16)'
+        help="Maximum batch size for dynamic batching (default: 16)",
     )
 
-    parser.add_argument(
-        '--workspace-size',
-        type=int,
-        default=4,
-        help='TensorRT workspace size in GB (default: 4)'
-    )
+    parser.add_argument("--gpu", type=int, default=0, help="GPU device ID (default: 0)")
 
-    parser.add_argument(
-        '--gpu',
-        type=int,
-        default=0,
-        help='GPU device ID (default: 0)'
-    )
-
     parser.add_argument(
-        '--input-names',
+        "--input-names",
         type=str,
-        nargs='+',
+        nargs="+",
         default=None,
-        help='Custom input tensor names (default: ["input"])'
+        help='Custom input tensor names (default: ["input"])',
     )
 
     parser.add_argument(
-        '--output-names',
+        "--output-names",
         type=str,
-        nargs='+',
+        nargs="+",
        default=None,
-        help='Custom output tensor names (default: ["output"])'
+        help='Custom output tensor names (default: ["output"])',
     )
 
     parser.add_argument(
-        '--keep-onnx',
-        action='store_true',
-        help='Keep intermediate ONNX file'
+        "--keep-onnx", action="store_true", help="Keep intermediate ONNX file"
     )
 
     parser.add_argument(
-        '--verbose', '-v',
-        action='store_true',
-        help='Enable verbose logging'
+        "--verbose", "-v", action="store_true", help="Enable verbose logging"
     )
 
     args = parser.parse_args()
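An invocation consistent with the flag set after this change (the --workspace-size option no longer exists; file names are illustrative):

    python convert_pt_to_tensorrt.py --model model.pt --output model.trt \
        --fp16 --dynamic-batch --max-batch 16 --gpu 0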
@@ -542,10 +534,9 @@ Examples:
             int8=args.int8,
             dynamic_batch=args.dynamic_batch,
             max_batch=args.max_batch,
-            workspace_size=args.workspace_size,
             input_names=args.input_names,
             output_names=args.output_names,
-            keep_onnx=args.keep_onnx
+            keep_onnx=args.keep_onnx,
         )
 
         print("\n✓ Conversion successful!")
@@ -554,6 +545,7 @@ Examples:
         print(f"\n✗ Conversion failed: {e}")
         if args.verbose:
             import traceback
+
             traceback.print_exc()
         sys.exit(1)