""" PTConverter - Convert PyTorch models to TensorRT using torch_tensorrt. This service handles conversion of .pt files to TensorRT format with caching to avoid redundant conversions. It maintains a mapping database between PT file hashes and their converted TensorRT file hashes. """ import hashlib import json import logging from pathlib import Path from typing import Optional, Dict, Tuple import torch import torch_tensorrt from .modelstorage import IModelStorage, FileModelStorage logger = logging.getLogger(__name__) class PTConverter: """ PyTorch to TensorRT converter with intelligent caching. Features: - Hash-based deduplication: Same PT file only converted once - Persistent mapping database: PT hash -> TRT hash mapping - Pluggable storage backend: IModelStorage interface - Automatic cache management Architecture: 1. Compute hash of input .pt file 2. Check mapping database for existing conversion 3. If found, return cached TRT file hash and path 4. If not found, perform conversion and store mapping """ def __init__( self, storage: Optional[IModelStorage] = None, gpu_id: int = 0, default_precision: torch.dtype = torch.float16 ): """ Initialize PT converter. Args: storage: Storage backend (defaults to FileModelStorage) gpu_id: GPU device ID for conversion default_precision: Default precision for TensorRT conversion (fp16, fp32) """ self.storage = storage or FileModelStorage() self.gpu_id = gpu_id self.device = torch.device(f'cuda:{gpu_id}') self.default_precision = default_precision # Mapping database: pt_hash -> {"trt_hash": str, "metadata": {...}} self.mapping_db_key = "pt_to_trt_mapping.json" self.mapping_db: Dict[str, Dict] = self._load_mapping_db() logger.info(f"PTConverter initialized on GPU {gpu_id}") logger.info(f"Storage backend: {self.storage.__class__.__name__}") logger.info(f"Storage path: {self.storage.get_storage_path()}") logger.info(f"Loaded {len(self.mapping_db)} cached conversions") def _load_mapping_db(self) -> Dict[str, Dict]: """ Load mapping database from storage. Returns: Mapping dictionary (pt_hash -> metadata) """ try: data = self.storage.read(self.mapping_db_key) if data: db = json.loads(data.decode('utf-8')) logger.debug(f"Loaded mapping database with {len(db)} entries") return db else: logger.debug("No existing mapping database found, starting fresh") return {} except Exception as e: logger.warning(f"Failed to load mapping database: {e}. Starting fresh.") return {} def _save_mapping_db(self): """Save mapping database to storage""" try: data = json.dumps(self.mapping_db, indent=2).encode('utf-8') self.storage.write(self.mapping_db_key, data) logger.debug(f"Saved mapping database with {len(self.mapping_db)} entries") except Exception as e: logger.error(f"Failed to save mapping database: {e}") @staticmethod def compute_file_hash(file_path: str) -> str: """ Compute SHA256 hash of a file. Args: file_path: Path to file Returns: Hexadecimal hash string """ sha256_hash = hashlib.sha256() with open(file_path, "rb") as f: for byte_block in iter(lambda: f.read(65536), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() def get_cached_conversion(self, pt_hash: str) -> Optional[Tuple[str, str]]: """ Check if PT file has already been converted. 

        Args:
            pt_hash: SHA256 hash of the PT file

        Returns:
            Tuple of (trt_hash, trt_file_path) if cached, None otherwise
        """
        if pt_hash not in self.mapping_db:
            return None

        mapping = self.mapping_db[pt_hash]
        trt_hash = mapping["trt_hash"]
        trt_key = f"trt/{trt_hash}.trt"

        # Verify TRT file still exists in storage
        if not self.storage.exists(trt_key):
            logger.warning(
                f"Mapping exists for PT hash {pt_hash[:16]}... but TRT file missing. "
                f"Will reconvert."
            )
            # Remove stale mapping
            del self.mapping_db[pt_hash]
            self._save_mapping_db()
            return None

        # Get local path
        trt_path = self.storage.get_local_path(trt_key)
        if trt_path is None:
            logger.error(f"Could not get local path for TRT file {trt_key}")
            return None

        logger.info(
            f"Found cached conversion for PT hash {pt_hash[:16]}... -> "
            f"TRT hash {trt_hash[:16]}..."
        )
        return (trt_hash, trt_path)

    def convert(
        self,
        pt_file_path: str,
        input_shapes: Optional[Dict[str, Tuple]] = None,
        precision: Optional[torch.dtype] = None,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Convert PyTorch model to TensorRT.

        If this PT file has been converted before (same hash), returns the cached
        result. Otherwise, performs conversion and caches the result.

        Args:
            pt_file_path: Path to .pt file
            input_shapes: Dict of input names to shapes (e.g., {"x": (1, 3, 224, 224)})
            precision: Target precision (fp16, fp32) - defaults to self.default_precision
            **conversion_kwargs: Additional arguments for torch_tensorrt.compile()

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            FileNotFoundError: If PT file doesn't exist
            RuntimeError: If conversion fails
        """
        pt_path = Path(pt_file_path).resolve()
        if not pt_path.exists():
            raise FileNotFoundError(f"PT file not found: {pt_file_path}")

        # Compute PT file hash
        logger.info(f"Computing hash for {pt_path}...")
        pt_hash = self.compute_file_hash(str(pt_path))
        logger.info(f"PT file hash: {pt_hash[:16]}...")

        # Check cache
        cached = self.get_cached_conversion(pt_hash)
        if cached:
            return cached

        # Perform conversion
        logger.info(f"Converting {pt_path.name} to TensorRT...")
        trt_hash, trt_path = self._perform_conversion(
            str(pt_path),
            pt_hash,
            input_shapes,
            precision or self.default_precision,
            **conversion_kwargs
        )

        # Store mapping
        self.mapping_db[pt_hash] = {
            "trt_hash": trt_hash,
            "pt_file": str(pt_path),
            "input_shapes": str(input_shapes),
            "precision": str(precision or self.default_precision),
        }
        self._save_mapping_db()

        logger.info(f"Conversion complete: PT {pt_hash[:16]}... -> TRT {trt_hash[:16]}...")
        return (trt_hash, trt_path)

    def _is_ultralytics_model(self, model) -> bool:
        """
        Check if model is from ultralytics (YOLO).

        Args:
            model: PyTorch model

        Returns:
            True if ultralytics model, False otherwise
        """
        model_class_name = model.__class__.__name__
        model_module = model.__class__.__module__

        # Check if it's an ultralytics model
        is_ultralytics = (
            'ultralytics' in model_module or
            model_class_name in [
                'DetectionModel', 'SegmentationModel', 'PoseModel', 'ClassificationModel'
            ]
        )
        return is_ultralytics

    def _convert_ultralytics_model(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
    ) -> Tuple[str, str]:
        """
        Convert ultralytics YOLO model using ONNX → TensorRT pipeline.

        Uses the same approach as scripts/convert_pt_to_tensorrt.py

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash
            input_shapes: Input tensor shapes
            precision: Target precision

        Returns:
            Tuple of (trt_hash, trt_file_path)
        """
        import tensorrt as trt
        import tempfile
        import os
        import shutil

        logger.info("Detected ultralytics YOLO model, using ONNX → TensorRT pipeline...")

        # Load ultralytics model
        try:
            from ultralytics import YOLO
            model = YOLO(pt_path)
        except ImportError:
            raise ImportError(
                "ultralytics package not found. Install with: pip install ultralytics"
            )

        # Determine input shape
        if not input_shapes:
            raise ValueError("input_shapes required for ultralytics conversion")

        input_key = 'images' if 'images' in input_shapes else list(input_shapes.keys())[0]
        input_shape = input_shapes[input_key]

        # Export to ONNX first
        logger.info(f"Exporting to ONNX (input shape: {input_shape})...")
        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp_onnx:
            onnx_path = tmp_onnx.name

        try:
            # Use ultralytics export to ONNX
            model.export(format='onnx', imgsz=input_shape[2], batch=input_shape[0])

            # Ultralytics saves as model_name.onnx in the same directory
            pt_dir = os.path.dirname(pt_path)
            pt_name = os.path.splitext(os.path.basename(pt_path))[0]
            onnx_export_path = os.path.join(pt_dir, f"{pt_name}.onnx")

            # Move to our temp location (use shutil.move for cross-device support)
            if os.path.exists(onnx_export_path):
                shutil.move(onnx_export_path, onnx_path)
            else:
                raise RuntimeError(f"ONNX export failed, file not found: {onnx_export_path}")

            logger.info(f"ONNX export complete: {onnx_path}")

            # Build TensorRT engine from ONNX
            logger.info("Building TensorRT engine from ONNX...")
            trt_logger = trt.Logger(trt.Logger.WARNING)
            builder = trt.Builder(trt_logger)
            network = builder.create_network(
                1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
            )
            parser = trt.OnnxParser(network, trt_logger)

            # Parse ONNX
            with open(onnx_path, 'rb') as f:
                if not parser.parse(f.read()):
                    errors = [parser.get_error(i) for i in range(parser.num_errors)]
                    raise RuntimeError(f"Failed to parse ONNX: {errors}")

            # Configure builder
            config = builder.create_builder_config()
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4GB

            # Set precision
            if precision == torch.float16:
                if builder.platform_has_fast_fp16:
                    config.set_flag(trt.BuilderFlag.FP16)
                    logger.info("FP16 mode enabled")

            # Build engine
            logger.info("Building TensorRT engine (this may take a few minutes)...")
            serialized_engine = builder.build_serialized_network(network, config)
            if serialized_engine is None:
                raise RuntimeError("Failed to build TensorRT engine")

            # Convert IHostMemory to bytes
            engine_bytes = bytes(serialized_engine)

            # Save to storage
            trt_hash = hashlib.sha256(engine_bytes).hexdigest()
            trt_key = f"trt/{trt_hash}.trt"
            self.storage.write(trt_key, engine_bytes)

            trt_path = self.storage.get_local_path(trt_key)
            if trt_path is None:
                raise RuntimeError("Failed to get local path for TRT file")

            logger.info(f"TensorRT engine built successfully: {trt_path}")
            return (trt_hash, trt_path)

        finally:
            # Cleanup ONNX file
            if os.path.exists(onnx_path):
                os.unlink(onnx_path)

    def _perform_conversion(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Perform the actual PT to TRT conversion.

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash
            input_shapes: Input tensor shapes
            precision: Target precision
            **conversion_kwargs: Additional torch_tensorrt arguments

        Returns:
            Tuple of (trt_hash, trt_file_path)
        """
        try:
            # Load PyTorch model to check type
            logger.debug(f"Loading PyTorch model from {pt_path}...")
            # Use weights_only=False for models with custom classes (like ultralytics).
            # This is safe for trusted local models.
            loaded = torch.load(pt_path, map_location='cpu', weights_only=False)

            # If the model is wrapped in a dict, extract the model
            if isinstance(loaded, dict):
                if 'model' in loaded:
                    model = loaded['model']
                elif 'state_dict' in loaded:
                    raise ValueError(
                        "PT file contains state_dict only. "
                        "Please provide a full model or use a different loading method."
                    )
                else:
                    model = loaded
            else:
                # The file holds the model object directly
                model = loaded

            # Check if this is an ultralytics model
            if self._is_ultralytics_model(model):
                logger.info("Detected ultralytics model, using ultralytics export API")
                return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision)

            # For non-ultralytics models, use torch_tensorrt
            logger.info("Using torch_tensorrt for conversion")
            model.eval()

            # Convert model to target precision to avoid mixed-precision issues
            if precision == torch.float16:
                model = model.half()
            elif precision == torch.float32:
                model = model.float()

            # Move to GPU
            model = model.to(self.device)

            # Prepare inputs for tracing
            if input_shapes is None:
                raise ValueError(
                    "input_shapes must be provided for TensorRT conversion. "
                    "Example: {'x': (1, 3, 224, 224)}"
                )

            # Create sample inputs with matching precision
            inputs = []
            for name, shape in input_shapes.items():
                sample_input = torch.randn(shape, device=self.device, dtype=precision)
                inputs.append(sample_input)

            # Configure torch_tensorrt
            enabled_precisions = {precision}
            if precision == torch.float16:
                enabled_precisions.add(torch.float32)  # Fallback for unsupported ops

            # Compile to TensorRT
            logger.info(f"Compiling to TensorRT (precision: {precision})...")
            trt_model = torch_tensorrt.compile(
                model,
                inputs=inputs,
                enabled_precisions=enabled_precisions,
                **conversion_kwargs
            )

            # Save TRT model to a temporary location
            import tempfile
            with tempfile.NamedTemporaryFile(delete=False, suffix='.trt') as tmp_file:
                tmp_path = tmp_file.name

            torch.jit.save(trt_model, tmp_path)
            logger.debug(f"Saved TRT model to temporary file: {tmp_path}")

            # Compute TRT file hash
            trt_hash = self.compute_file_hash(tmp_path)
            logger.info(f"TRT file hash: {trt_hash[:16]}...")

            # Store in storage backend
            trt_key = f"trt/{trt_hash}.trt"
            with open(tmp_path, 'rb') as f:
                trt_data = f.read()
            self.storage.write(trt_key, trt_data)

            # Get local path
            trt_path = self.storage.get_local_path(trt_key)
            if trt_path is None:
                raise RuntimeError("Failed to get local path for converted TRT file")

            # Cleanup temp file
            Path(tmp_path).unlink()

            logger.info(f"TRT model stored successfully at {trt_path}")
            return (trt_hash, trt_path)

        except Exception as e:
            logger.error(f"Conversion failed: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert PT to TensorRT: {e}")

    def clear_cache(self):
        """Clear all cached conversions and the mapping database."""
        logger.warning("Clearing all cached conversions...")
        self.mapping_db.clear()
        self._save_mapping_db()
        logger.info("Cache cleared")

    def get_stats(self) -> Dict:
        """
        Get conversion statistics.

        Returns:
            Dictionary with cache stats
        """
        return {
            "total_cached_conversions": len(self.mapping_db),
            "storage_path": self.storage.get_storage_path(),
            "gpu_id": self.gpu_id,
            "default_precision": str(self.default_precision),
        }
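

if __name__ == "__main__":
    # Minimal usage sketch (assumptions): the model path "models/resnet50.pt" and the
    # input name "x" below are hypothetical placeholders, FileModelStorage is assumed
    # to work with its default constructor, and a CUDA-capable GPU with torch_tensorrt
    # installed is required. Because of the relative import above, run this module
    # with `python -m <package>.<module>` rather than as a standalone script.
    logging.basicConfig(level=logging.INFO)

    converter = PTConverter(gpu_id=0, default_precision=torch.float16)

    # The first call performs the conversion; a later call with the same .pt file
    # (same SHA256 hash) returns the cached TensorRT engine from the mapping database.
    trt_hash, trt_path = converter.convert(
        "models/resnet50.pt",                   # hypothetical model path
        input_shapes={"x": (1, 3, 224, 224)},   # batch, channels, height, width
    )
    print(f"TRT hash: {trt_hash[:16]}..., engine at: {trt_path}")
    print(converter.get_stats())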