#!/usr/bin/env python3
"""
PyTorch to TensorRT Model Conversion Script

This script converts PyTorch models (.pt, .pth) to TensorRT engines (.trt)
for optimized inference.

Features:
- Selectable FP32/FP16/INT8 precision modes
- Dynamic batch size support
- Input shape validation
- Optimization profiles for dynamic shapes
- ONNX intermediate format
- GPU-accelerated conversion

Usage:
    python convert_pt_to_tensorrt.py --model path/to/model.pt --output models/model.trt
    python convert_pt_to_tensorrt.py --model yolov8n.pt --output yolov8n.trt --input-shape 1,3,640,640 --fp16
    python convert_pt_to_tensorrt.py --model model.pt --output model.trt --dynamic-batch --max-batch 16
"""

import argparse
import sys
from pathlib import Path
from typing import List, Optional, Tuple

import tensorrt as trt
import torch


class TensorRTConverter:
    """Converts PyTorch (or ONNX) models to TensorRT engines."""

    def __init__(self, gpu_id: int = 0, verbose: bool = True):
        """
        Initialize the converter.

        Args:
            gpu_id: GPU device ID to use for conversion
            verbose: Enable verbose logging
        """
        self.gpu_id = gpu_id
        self.device = torch.device(f"cuda:{gpu_id}")

        # TensorRT logger
        log_level = trt.Logger.VERBOSE if verbose else trt.Logger.WARNING
        self.logger = trt.Logger(log_level)

        # Set CUDA device
        torch.cuda.set_device(gpu_id)

        print(f"Initialized TensorRT Converter on GPU {gpu_id}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"TensorRT version: {trt.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device: {torch.cuda.get_device_name(gpu_id)}")

    def load_pytorch_model(self, model_path: str) -> torch.nn.Module:
        """
        Load PyTorch model from file.

        Args:
            model_path: Path to .pt or .pth file

        Returns:
            Loaded PyTorch model in eval mode
        """
        print(f"\nLoading PyTorch model from {model_path}...")

        if not Path(model_path).exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Load model (weights_only=False for models with custom classes)
        checkpoint = torch.load(
            model_path, map_location=self.device, weights_only=False
        )

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict):
            if "model" in checkpoint:
                model = checkpoint["model"]
            elif "state_dict" in checkpoint:
                # Need the model architecture - this is a limitation
                raise ValueError(
                    "Checkpoint contains only a state_dict. "
                    "Please provide the complete model or modify this script to load your architecture."
                )
            else:
                raise ValueError("Unknown checkpoint format")
        else:
            model = checkpoint

        # Set to eval mode
        model.eval()
        model.to(self.device)

        print("✓ Model loaded successfully")
        return model

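    # --- Hedged illustrative helper (not called anywhere in this script) ---------
    # The ValueError above notes that a state_dict-only checkpoint needs the model
    # architecture. The sketch below shows one way to handle that case, assuming
    # you can pass a zero-argument factory that builds your architecture, e.g.
    # ``lambda: MyModel(num_classes=80)``. ``model_factory`` and ``MyModel`` are
    # hypothetical names, not part of the original script.
    def load_from_state_dict(self, model_path: str, model_factory) -> torch.nn.Module:
        """Load a state_dict-only checkpoint when the architecture is known."""
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        # Accept both a bare state_dict and a {"state_dict": ...} wrapper
        state_dict = checkpoint.get("state_dict", checkpoint)
        model = model_factory()
        model.load_state_dict(state_dict)
        model.eval()
        return model.to(self.device)
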
    def export_to_onnx(
        self,
        model: torch.nn.Module,
        input_shape: Tuple[int, ...],
        onnx_path: str,
        dynamic_batch: bool = False,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
    ) -> str:
        """
        Export PyTorch model to ONNX format (intermediate step).

        Args:
            model: PyTorch model
            input_shape: Input tensor shape (B, C, H, W)
            onnx_path: Output path for ONNX file
            dynamic_batch: Enable dynamic batch dimension
            input_names: List of input tensor names
            output_names: List of output tensor names

        Returns:
            Path to exported ONNX file
        """
        print("\nExporting to ONNX format...")
        print(f"Input shape: {input_shape}")
        print(f"Dynamic batch: {dynamic_batch}")

        # Default names
        if input_names is None:
            input_names = ["input"]
        if output_names is None:
            output_names = ["output"]

        # Create dummy input
        dummy_input = torch.randn(*input_shape, device=self.device)

        # Dynamic axes configuration
        dynamic_axes = None
        if dynamic_batch:
            dynamic_axes = {
                input_names[0]: {0: "batch"},
                output_names[0]: {0: "batch"},
            }

        # Export to ONNX
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=17,  # Use recent ONNX opset
            do_constant_folding=True,
            verbose=False,
        )

        print(f"✓ ONNX model exported to {onnx_path}")
        return onnx_path

    def build_tensorrt_engine_from_onnx(
        self,
        onnx_path: str,
        engine_path: str,
        fp16: bool = False,
        int8: bool = False,
        min_batch: int = 1,
        opt_batch: int = 1,
        max_batch: int = 1,
    ) -> str:
        """
        Build TensorRT engine from ONNX model.

        Args:
            onnx_path: Path to ONNX model
            engine_path: Output path for TensorRT engine
            fp16: Enable FP16 precision
            int8: Enable INT8 precision (requires calibration)
            min_batch: Minimum batch size for optimization
            opt_batch: Optimal batch size for optimization
            max_batch: Maximum batch size for optimization

        Returns:
            Path to built TensorRT engine
        """
        print("\nBuilding TensorRT engine from ONNX...")
        print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")

        # Create builder and network
        builder = trt.Builder(self.logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, self.logger)

        # Parse ONNX model
        print(f"Loading ONNX file from {onnx_path}...")
        with open(onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                print("ERROR: Failed to parse the ONNX file:")
                for error in range(parser.num_errors):
                    print(f"  {parser.get_error(error)}")
                raise RuntimeError("Failed to parse ONNX model")

        print("✓ ONNX model parsed successfully")

        # Print network info
        print("\nNetwork Information:")
        print(f"  Inputs: {network.num_inputs}")
        for i in range(network.num_inputs):
            inp = network.get_input(i)
            print(f"    [{i}] {inp.name}: {inp.shape} ({inp.dtype})")
        print(f"  Outputs: {network.num_outputs}")
        for i in range(network.num_outputs):
            out = network.get_output(i)
            print(f"    [{i}] {out.name}: {out.shape} ({out.dtype})")

        # Create builder config
        config = builder.create_builder_config()

        # Enable precision modes
        if fp16:
            if not builder.platform_has_fast_fp16:
                print("Warning: FP16 not supported on this platform, using FP32")
            else:
                config.set_flag(trt.BuilderFlag.FP16)
                print("✓ FP16 mode enabled")

        if int8:
            if not builder.platform_has_fast_int8:
                print("Warning: INT8 not supported on this platform, using FP32/FP16")
            else:
                config.set_flag(trt.BuilderFlag.INT8)
                print("✓ INT8 mode enabled")
                print(
                    "Note: INT8 calibration not implemented. Results may be suboptimal."
                )

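        # Hedged note (an addition, not part of the original script): a meaningful
        # INT8 engine needs a calibrator so TensorRT can derive activation scales
        # from representative input data. If you implement one (for example a
        # subclass of trt.IInt8EntropyCalibrator2 fed with real preprocessed
        # batches), it would be attached to the config roughly like this:
        #
        #     if int8:
        #         config.int8_calibrator = my_calibrator  # hypothetical object
        #
        # Without a calibrator or explicit dynamic ranges, the INT8 path may be
        # rejected or give poor accuracy, which is what the note above warns about.
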
        # Set optimization profile for dynamic shapes
        if max_batch > 1 or min_batch != max_batch:
            profile = builder.create_optimization_profile()
            has_dynamic_input = False

            for i in range(network.num_inputs):
                inp = network.get_input(i)
                shape = list(inp.shape)

                # Handle dynamic batch dimension
                if shape[0] == -1:
                    has_dynamic_input = True
                    # Min, opt, max shapes
                    min_shape = [min_batch] + shape[1:]
                    opt_shape = [opt_batch] + shape[1:]
                    max_shape = [max_batch] + shape[1:]

                    profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
                    print(f"  Dynamic shape for {inp.name}:")
                    print(f"    Min: {min_shape}")
                    print(f"    Opt: {opt_shape}")
                    print(f"    Max: {max_shape}")

            # Only register the profile if at least one dynamic input was configured
            if has_dynamic_input:
                config.add_optimization_profile(profile)

        # Build engine
        print("\nBuilding TensorRT engine (this may take a few minutes)...")
        serialized_engine = builder.build_serialized_network(network, config)

        if serialized_engine is None:
            raise RuntimeError("Failed to build TensorRT engine")

        # Save engine to file
        print(f"Saving engine to {engine_path}...")
        with open(engine_path, "wb") as f:
            f.write(serialized_engine)

        # Get file size
        file_size_mb = Path(engine_path).stat().st_size / (1024 * 1024)

        print("✓ TensorRT engine built successfully")
        print(f"  Engine size: {file_size_mb:.2f} MB")

        return engine_path

    def convert(
        self,
        model_path: str,
        output_path: str,
        input_shape: Tuple[int, ...] = (1, 3, 640, 640),
        fp16: bool = False,
        int8: bool = False,
        dynamic_batch: bool = False,
        max_batch: int = 16,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
        keep_onnx: bool = False,
    ) -> str:
        """
        Convert PyTorch or ONNX model to TensorRT engine.

        Args:
            model_path: Path to PyTorch model (.pt, .pth) or ONNX model (.onnx)
            output_path: Path for output TensorRT engine (.trt)
            input_shape: Input tensor shape (B, C, H, W) - required for PyTorch models
            fp16: Enable FP16 precision
            int8: Enable INT8 precision
            dynamic_batch: Enable dynamic batch size
            max_batch: Maximum batch size (for dynamic batching)
            input_names: Custom input names (for PyTorch export)
            output_names: Custom output names (for PyTorch export)
            keep_onnx: Keep intermediate ONNX file

        Returns:
            Path to created TensorRT engine
        """
        # Create output directory
        output_dir = Path(output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True)

        # Check if input is already ONNX
        model_path_obj = Path(model_path)
        is_onnx = model_path_obj.suffix.lower() == ".onnx"

        if is_onnx:
            # Direct ONNX to TensorRT conversion
            print("Input is ONNX model, converting directly to TensorRT...")
            min_batch = 1
            opt_batch = input_shape[0] if not dynamic_batch else max(1, max_batch // 2)
            max_batch_size = max_batch if dynamic_batch else input_shape[0]

            engine_path = self.build_tensorrt_engine_from_onnx(
                onnx_path=model_path,
                engine_path=output_path,
                fp16=fp16,
                int8=int8,
                min_batch=min_batch,
                opt_batch=opt_batch,
                max_batch=max_batch_size,
            )

            print(f"\n{'=' * 80}")
            print("CONVERSION COMPLETED SUCCESSFULLY")
            print(f"{'=' * 80}")
            print(f"Input: {model_path}")
            print(f"Output: {engine_path}")
            print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
            print(f"{'=' * 80}")

            return engine_path

        # PyTorch to TensorRT conversion (via ONNX)
        # Temporary ONNX path
        onnx_path = str(output_dir / "temp_model.onnx")

        try:
            # Step 1: Load PyTorch model
            model = self.load_pytorch_model(model_path)

            # Step 2: Export to ONNX
            self.export_to_onnx(
                model=model,
                input_shape=input_shape,
                onnx_path=onnx_path,
                dynamic_batch=dynamic_batch,
                input_names=input_names,
                output_names=output_names,
            )

            # Step 3: Build TensorRT engine
            min_batch = 1
            opt_batch = input_shape[0] if not dynamic_batch else max(1, max_batch // 2)
            max_batch_size = max_batch if dynamic_batch else input_shape[0]

            engine_path = self.build_tensorrt_engine_from_onnx(
                onnx_path=onnx_path,
                engine_path=output_path,
                fp16=fp16,
                int8=int8,
                min_batch=min_batch,
                opt_batch=opt_batch,
                max_batch=max_batch_size,
            )

            print(f"\n{'=' * 80}")
            print("CONVERSION COMPLETED SUCCESSFULLY")
            print(f"{'=' * 80}")
            print(f"Input: {model_path}")
            print(f"Output: {engine_path}")
            print(f"Precision: FP{'16' if fp16 else '32'}{' + INT8' if int8 else ''}")
            print(f"Dynamic batch: {dynamic_batch}")
            if dynamic_batch:
                print(f"Batch range: [1, {max_batch}]")
            print(f"{'=' * 80}")

            return engine_path

        finally:
            # Cleanup temporary ONNX file
            if not keep_onnx and Path(onnx_path).exists():
                Path(onnx_path).unlink()
                print("Cleaned up temporary ONNX file")


def parse_shape(shape_str: str) -> Tuple[int, ...]:
    """Parse a shape string like '1,3,640,640' into a tuple."""
    try:
        return tuple(int(x) for x in shape_str.split(","))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"Invalid shape format: {shape_str}. Expected format: 1,3,640,640"
        )


def main():
    parser = argparse.ArgumentParser(
        description="Convert PyTorch (or ONNX) models to TensorRT engines",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic conversion (FP32)
  python convert_pt_to_tensorrt.py --model yolov8n.pt --output models/yolov8n.trt

  # FP16 precision for faster inference
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt --fp16

  # Custom input shape
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
      --input-shape 1,3,416,416

  # Dynamic batch size (1 to 16)
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
      --dynamic-batch --max-batch 16

  # INT8 quantization for maximum speed (requires calibration)
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
      --fp16 --int8

  # Keep intermediate ONNX file for debugging
  python convert_pt_to_tensorrt.py --model model.pt --output model.trt \\
      --keep-onnx
""",
    )

    # Required arguments
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        required=True,
        help="Path to PyTorch model file (.pt, .pth) or ONNX model (.onnx)",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Output path for TensorRT engine (.trt or .engine)",
    )

    # Optional arguments
    parser.add_argument(
        "--input-shape",
        "-s",
        type=parse_shape,
        default=(1, 3, 640, 640),
        help="Input tensor shape as B,C,H,W (default: 1,3,640,640)",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Enable FP16 precision (faster inference, slightly lower accuracy)",
    )
    parser.add_argument(
        "--int8",
        action="store_true",
        help="Enable INT8 precision (fastest, requires calibration)",
    )
    parser.add_argument(
        "--dynamic-batch", action="store_true", help="Enable dynamic batch size support"
    )
    parser.add_argument(
        "--max-batch",
        type=int,
        default=16,
        help="Maximum batch size for dynamic batching (default: 16)",
    )
    parser.add_argument("--gpu", type=int, default=0, help="GPU device ID (default: 0)")
    parser.add_argument(
        "--input-names",
        type=str,
        nargs="+",
        default=None,
        help='Custom input tensor names (default: ["input"])',
    )
    parser.add_argument(
        "--output-names",
        type=str,
        nargs="+",
        default=None,
        help='Custom output tensor names (default: ["output"])',
    )
    parser.add_argument(
        "--keep-onnx", action="store_true", help="Keep intermediate ONNX file"
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Enable verbose logging"
    )

    args = parser.parse_args()

    # Validate arguments
    if not Path(args.model).exists():
        print(f"Error: Model file not found: {args.model}")
        sys.exit(1)

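    # Hedged convenience check (an assumption, not part of the original flow): warn
    # when the output extension is unusual, since downstream loaders typically
    # expect .trt or .engine files.
    if Path(args.output).suffix.lower() not in (".trt", ".engine"):
        print(
            f"Warning: unusual output extension '{Path(args.output).suffix}'; "
            "TensorRT engines are conventionally saved as .trt or .engine"
        )
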
    if args.int8 and not args.fp16:
        print("Warning: INT8 mode works best with FP16 enabled. Adding --fp16 flag.")
        args.fp16 = True

    # Run conversion
    try:
        converter = TensorRTConverter(gpu_id=args.gpu, verbose=args.verbose)
        converter.convert(
            model_path=args.model,
            output_path=args.output,
            input_shape=args.input_shape,
            fp16=args.fp16,
            int8=args.int8,
            dynamic_batch=args.dynamic_batch,
            max_batch=args.max_batch,
            input_names=args.input_names,
            output_names=args.output_names,
            keep_onnx=args.keep_onnx,
        )
        print("\n✓ Conversion successful!")

    except Exception as e:
        print(f"\n✗ Conversion failed: {e}")
        if args.verbose:
            import traceback

            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()