""" PTConverter - Convert PyTorch models to TensorRT using torch_tensorrt. This service handles conversion of .pt files to TensorRT format with caching to avoid redundant conversions. It maintains a mapping database between PT file hashes and their converted TensorRT file hashes. """ import hashlib import json import logging from pathlib import Path from typing import Optional, Dict, Tuple import torch import torch_tensorrt from .modelstorage import IModelStorage, FileModelStorage logger = logging.getLogger(__name__) class PTConverter: """ PyTorch to TensorRT converter with intelligent caching. Features: - Hash-based deduplication: Same PT file only converted once - Persistent mapping database: PT hash -> TRT hash mapping - Pluggable storage backend: IModelStorage interface - Automatic cache management Architecture: 1. Compute hash of input .pt file 2. Check mapping database for existing conversion 3. If found, return cached TRT file hash and path 4. If not found, perform conversion and store mapping """ def __init__( self, storage: Optional[IModelStorage] = None, gpu_id: int = 0, default_precision: torch.dtype = torch.float16 ): """ Initialize PT converter. Args: storage: Storage backend (defaults to FileModelStorage) gpu_id: GPU device ID for conversion default_precision: Default precision for TensorRT conversion (fp16, fp32) """ self.storage = storage or FileModelStorage() self.gpu_id = gpu_id self.device = torch.device(f'cuda:{gpu_id}') self.default_precision = default_precision # Mapping database: pt_hash -> {"trt_hash": str, "metadata": {...}} self.mapping_db_key = "pt_to_trt_mapping.json" self.mapping_db: Dict[str, Dict] = self._load_mapping_db() logger.info(f"PTConverter initialized on GPU {gpu_id}") logger.info(f"Storage backend: {self.storage.__class__.__name__}") logger.info(f"Storage path: {self.storage.get_storage_path()}") logger.info(f"Loaded {len(self.mapping_db)} cached conversions") def _load_mapping_db(self) -> Dict[str, Dict]: """ Load mapping database from storage. Returns: Mapping dictionary (pt_hash -> metadata) """ try: data = self.storage.read(self.mapping_db_key) if data: db = json.loads(data.decode('utf-8')) logger.debug(f"Loaded mapping database with {len(db)} entries") return db else: logger.debug("No existing mapping database found, starting fresh") return {} except Exception as e: logger.warning(f"Failed to load mapping database: {e}. Starting fresh.") return {} def _save_mapping_db(self): """Save mapping database to storage""" try: data = json.dumps(self.mapping_db, indent=2).encode('utf-8') self.storage.write(self.mapping_db_key, data) logger.debug(f"Saved mapping database with {len(self.mapping_db)} entries") except Exception as e: logger.error(f"Failed to save mapping database: {e}") @staticmethod def compute_file_hash(file_path: str) -> str: """ Compute SHA256 hash of a file. Args: file_path: Path to file Returns: Hexadecimal hash string """ sha256_hash = hashlib.sha256() with open(file_path, "rb") as f: for byte_block in iter(lambda: f.read(65536), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() def get_cached_conversion(self, pt_hash: str) -> Optional[Tuple[str, str]]: """ Check if PT file has already been converted. Args: pt_hash: SHA256 hash of the PT file Returns: Tuple of (trt_hash, trt_file_path) if cached, None otherwise """ if pt_hash not in self.mapping_db: return None mapping = self.mapping_db[pt_hash] trt_hash = mapping["trt_hash"] # Check both .engine and .trt extensions (Ultralytics uses .engine, generic uses .trt) engine_key = f"trt/{trt_hash}.engine" trt_key = f"trt/{trt_hash}.trt" # Try .engine first (Ultralytics native format) if self.storage.exists(engine_key): cached_key = engine_key elif self.storage.exists(trt_key): cached_key = trt_key else: logger.warning( f"Mapping exists for PT hash {pt_hash[:16]}... but engine file missing. " f"Will reconvert." ) # Remove stale mapping del self.mapping_db[pt_hash] self._save_mapping_db() return None # Get local path cached_path = self.storage.get_local_path(cached_key) if cached_path is None: logger.error(f"Could not get local path for engine file {cached_key}") return None logger.info( f"Found cached conversion for PT hash {pt_hash[:16]}... -> " f"Engine hash {trt_hash[:16]}... ({cached_key})" ) return (trt_hash, cached_path) def convert( self, pt_file_path: str, input_shapes: Optional[Dict[str, Tuple]] = None, precision: Optional[torch.dtype] = None, **conversion_kwargs ) -> Tuple[str, str]: """ Convert PyTorch model to TensorRT. If this PT file has been converted before (same hash), returns cached result. Otherwise, performs conversion and caches the result. Args: pt_file_path: Path to .pt file input_shapes: Dict of input names to shapes (e.g., {"x": (1, 3, 224, 224)}) precision: Target precision (fp16, fp32) - defaults to self.default_precision **conversion_kwargs: Additional arguments for torch_tensorrt.compile() Returns: Tuple of (trt_hash, trt_file_path) Raises: FileNotFoundError: If PT file doesn't exist RuntimeError: If conversion fails """ pt_path = Path(pt_file_path).resolve() if not pt_path.exists(): raise FileNotFoundError(f"PT file not found: {pt_file_path}") # Compute PT file hash logger.info(f"Computing hash for {pt_path}...") pt_hash = self.compute_file_hash(str(pt_path)) logger.info(f"PT file hash: {pt_hash[:16]}...") # Check cache cached = self.get_cached_conversion(pt_hash) if cached: return cached # Perform conversion logger.info(f"Converting {pt_path.name} to TensorRT...") trt_hash, trt_path = self._perform_conversion( str(pt_path), pt_hash, input_shapes, precision or self.default_precision, **conversion_kwargs ) # Store mapping self.mapping_db[pt_hash] = { "trt_hash": trt_hash, "pt_file": str(pt_path), "input_shapes": str(input_shapes), "precision": str(precision or self.default_precision), } self._save_mapping_db() logger.info(f"Conversion complete: PT {pt_hash[:16]}... -> TRT {trt_hash[:16]}...") return (trt_hash, trt_path) def _is_ultralytics_model(self, model) -> bool: """ Check if model is from ultralytics (YOLO). Args: model: PyTorch model Returns: True if ultralytics model, False otherwise """ model_class_name = model.__class__.__name__ model_module = model.__class__.__module__ # Check if it's an ultralytics model is_ultralytics = ( 'ultralytics' in model_module or model_class_name in ['DetectionModel', 'SegmentationModel', 'PoseModel', 'ClassificationModel'] ) return is_ultralytics def _convert_ultralytics_model( self, pt_path: str, pt_hash: str, input_shapes: Optional[Dict[str, Tuple]], precision: torch.dtype, ) -> Tuple[str, str]: """ Convert ultralytics YOLO model using native .engine export. This produces .engine files with embedded metadata (no manual input_shapes needed). Args: pt_path: Path to PT file pt_hash: PT file hash input_shapes: Input tensor shapes (IGNORED for Ultralytics - auto-detected) precision: Target precision Returns: Tuple of (engine_hash, engine_file_path) """ import os logger.info("Detected ultralytics YOLO model, using native .engine export...") # Load ultralytics model try: from ultralytics import YOLO model = YOLO(pt_path) except ImportError: raise ImportError("ultralytics package not found. Install with: pip install ultralytics") # Export to native .engine format with embedded metadata logger.info(f"Exporting to native TensorRT .engine (precision: {'FP16' if precision == torch.float16 else 'FP32'})...") # Ultralytics export creates .engine file in same directory as .pt engine_path = model.export( format='engine', half=(precision == torch.float16), device=self.gpu_id, batch=1, simplify=True ) # Convert to string (Ultralytics returns Path object) engine_path = str(engine_path) logger.info(f"Native .engine export complete: {engine_path}") logger.info("Metadata embedded in .engine file (stride, imgsz, names, etc.)") # Read the exported .engine file with open(engine_path, 'rb') as f: engine_data = f.read() # Compute hash of the .engine file engine_hash = hashlib.sha256(engine_data).hexdigest() # Store in our cache (as .engine to preserve metadata) engine_key = f"trt/{engine_hash}.engine" self.storage.write(engine_key, engine_data) cached_path = self.storage.get_local_path(engine_key) if cached_path is None: raise RuntimeError("Failed to get local path for .engine file") # Clean up the original export (we've cached it) # Only delete if it's different from cached path if os.path.exists(engine_path) and os.path.abspath(engine_path) != os.path.abspath(cached_path): logger.info(f"Removing original export (cached): {engine_path}") os.unlink(engine_path) else: logger.info(f"Keeping original export at: {engine_path}") logger.info(f"Cached .engine file: {cached_path}") return (engine_hash, cached_path) def _perform_conversion( self, pt_path: str, pt_hash: str, input_shapes: Optional[Dict[str, Tuple]], precision: torch.dtype, **conversion_kwargs ) -> Tuple[str, str]: """ Perform the actual PT to TRT conversion. Args: pt_path: Path to PT file pt_hash: PT file hash input_shapes: Input tensor shapes precision: Target precision **conversion_kwargs: Additional torch_tensorrt arguments Returns: Tuple of (trt_hash, trt_file_path) """ try: # Load PyTorch model to check type logger.debug(f"Loading PyTorch model from {pt_path}...") # Use weights_only=False for models with custom classes (like ultralytics) # This is safe for trusted local models loaded = torch.load(pt_path, map_location='cpu', weights_only=False) # If model is wrapped in a dict, extract the model if isinstance(loaded, dict): if 'model' in loaded: model = loaded['model'] elif 'state_dict' in loaded: raise ValueError( "PT file contains state_dict only. " "Please provide a full model or use a different loading method." ) else: model = loaded # Check if this is an ultralytics model if self._is_ultralytics_model(model): logger.info("Detected Ultralytics YOLO model, using native .engine export") logger.info("Note: input_shapes parameter is ignored for Ultralytics models (auto-detected)") return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision) # For non-ultralytics models, use torch_tensorrt logger.info("Using torch_tensorrt for conversion (non-Ultralytics model)") # Non-Ultralytics models REQUIRE input_shapes if input_shapes is None: raise ValueError( "input_shapes required for non-Ultralytics PyTorch models. " "For Ultralytics YOLO models, input_shapes is auto-detected. " "Example: input_shapes={'images': (1, 3, 640, 640)}" ) model.eval() # Convert model to target precision to avoid mixed precision issues if precision == torch.float16: model = model.half() elif precision == torch.float32: model = model.float() # Move to GPU model = model.to(self.device) # Prepare inputs for tracing if input_shapes is None: raise ValueError( "input_shapes must be provided for TensorRT conversion. " "Example: {'x': (1, 3, 224, 224)}" ) # Create sample inputs with matching precision inputs = [] for name, shape in input_shapes.items(): sample_input = torch.randn(shape, device=self.device, dtype=precision) inputs.append(sample_input) # Configure torch_tensorrt enabled_precisions = {precision} if precision == torch.float16: enabled_precisions.add(torch.float32) # Fallback for unsupported ops # Compile to TensorRT logger.info(f"Compiling to TensorRT (precision: {precision})...") trt_model = torch_tensorrt.compile( model, inputs=inputs, enabled_precisions=enabled_precisions, **conversion_kwargs ) # Save TRT model to temporary location import tempfile with tempfile.NamedTemporaryFile(delete=False, suffix='.trt') as tmp_file: tmp_path = tmp_file.name torch.jit.save(trt_model, tmp_path) logger.debug(f"Saved TRT model to temporary file: {tmp_path}") # Compute TRT file hash trt_hash = self.compute_file_hash(tmp_path) logger.info(f"TRT file hash: {trt_hash[:16]}...") # Store in storage backend trt_key = f"trt/{trt_hash}.trt" with open(tmp_path, 'rb') as f: trt_data = f.read() self.storage.write(trt_key, trt_data) # Get local path trt_path = self.storage.get_local_path(trt_key) if trt_path is None: raise RuntimeError("Failed to get local path for converted TRT file") # Cleanup temp file Path(tmp_path).unlink() logger.info(f"TRT model stored successfully at {trt_path}") return (trt_hash, trt_path) except Exception as e: logger.error(f"Conversion failed: {e}", exc_info=True) raise RuntimeError(f"Failed to convert PT to TensorRT: {e}") def clear_cache(self): """Clear all cached conversions and mapping database""" logger.warning("Clearing all cached conversions...") self.mapping_db.clear() self._save_mapping_db() logger.info("Cache cleared") def get_stats(self) -> Dict: """ Get conversion statistics. Returns: Dictionary with cache stats """ return { "total_cached_conversions": len(self.mapping_db), "storage_path": self.storage.get_storage_path(), "gpu_id": self.gpu_id, "default_precision": str(self.default_precision), }