"""
PTConverter - Convert PyTorch models to TensorRT using torch_tensorrt.

This service handles conversion of .pt files to TensorRT format with caching
to avoid redundant conversions. It maintains a mapping database between
PT file hashes and their converted TensorRT file hashes.
"""

import hashlib
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Dict, Tuple

import torch
import torch_tensorrt

from .modelstorage import IModelStorage, FileModelStorage

# Module-level logger, named after this module so that handlers and levels can
# be configured per module by the application.
logger = logging.getLogger(__name__)


class PTConverter:
    """
    PyTorch to TensorRT converter with intelligent caching.

    Features:
        - Hash-based deduplication: Same PT file only converted once
        - Persistent mapping database: PT hash -> TRT hash mapping
        - Pluggable storage backend: IModelStorage interface
        - Automatic cache management

    Architecture:
        1. Compute hash of input .pt file
        2. Check mapping database for existing conversion
        3. If found, return cached TRT file hash and path
        4. If not found, perform conversion and store mapping
    """

    # Class names treated as Ultralytics models even when the defining module
    # path does not contain "ultralytics".
    _ULTRALYTICS_CLASS_NAMES = frozenset(
        {'DetectionModel', 'SegmentationModel', 'PoseModel', 'ClassificationModel'}
    )

    def __init__(
        self,
        storage: Optional[IModelStorage] = None,
        gpu_id: int = 0,
        default_precision: torch.dtype = torch.float16
    ):
        """
        Initialize PT converter.

        Args:
            storage: Storage backend (defaults to FileModelStorage)
            gpu_id: GPU device ID for conversion
            default_precision: Default precision for TensorRT conversion (fp16, fp32)
        """
        # Compare against None explicitly: a backend object that happens to be
        # falsy (e.g. it defines __len__ and is empty) must not be replaced.
        self.storage = storage if storage is not None else FileModelStorage()
        self.gpu_id = gpu_id
        self.device = torch.device(f'cuda:{gpu_id}')
        self.default_precision = default_precision

        # Mapping database: pt_hash -> {"trt_hash": str, "metadata": {...}}
        self.mapping_db_key = "pt_to_trt_mapping.json"
        self.mapping_db: Dict[str, Dict] = self._load_mapping_db()

        logger.info(f"PTConverter initialized on GPU {gpu_id}")
        logger.info(f"Storage backend: {self.storage.__class__.__name__}")
        logger.info(f"Storage path: {self.storage.get_storage_path()}")
        logger.info(f"Loaded {len(self.mapping_db)} cached conversions")

    def _load_mapping_db(self) -> Dict[str, Dict]:
        """
        Load mapping database from storage.

        Any failure (missing entry, unreadable storage, invalid JSON, or a JSON
        payload that is not an object) is logged and results in an empty
        mapping, so a broken database never prevents the converter from
        starting.

        Returns:
            Mapping dictionary (pt_hash -> metadata)
        """
        try:
            data = self.storage.read(self.mapping_db_key)
            if not data:
                logger.debug("No existing mapping database found, starting fresh")
                return {}
            db = json.loads(data.decode('utf-8'))
        except Exception as e:
            logger.warning(f"Failed to load mapping database: {e}. Starting fresh.")
            return {}

        # Valid JSON is not necessarily a mapping; a list or scalar here would
        # break every later `pt_hash in self.mapping_db` / item assignment.
        if not isinstance(db, dict):
            logger.warning(
                f"Failed to load mapping database: expected a JSON object, "
                f"got {type(db).__name__}. Starting fresh."
            )
            return {}

        logger.debug(f"Loaded mapping database with {len(db)} entries")
        return db

    def _save_mapping_db(self):
        """
        Save mapping database to storage.

        Persistence is best-effort: a failure is logged but not raised, so the
        in-memory mapping stays usable for the current process.
        """
        try:
            data = json.dumps(self.mapping_db, indent=2).encode('utf-8')
            self.storage.write(self.mapping_db_key, data)
            logger.debug(f"Saved mapping database with {len(self.mapping_db)} entries")
        except Exception as e:
            logger.error(f"Failed to save mapping database: {e}")

    def _forget_mapping(self, pt_hash: str) -> None:
        """Remove the mapping entry for pt_hash (if any) and persist the change."""
        if self.mapping_db.pop(pt_hash, None) is not None:
            self._save_mapping_db()

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """
        Compute SHA256 hash of a file.

        The file is read in 64 KiB chunks so large model files are never held
        in memory as a whole.

        Args:
            file_path: Path to file

        Returns:
            Hexadecimal hash string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def get_cached_conversion(self, pt_hash: str) -> Optional[Tuple[str, str]]:
        """
        Check if PT file has already been converted.

        A mapping entry that is malformed, or whose engine file no longer
        exists in storage, is treated as stale: it is removed from the mapping
        database and None is returned so the caller reconverts.

        Args:
            pt_hash: SHA256 hash of the PT file

        Returns:
            Tuple of (trt_hash, trt_file_path) if cached, None otherwise
        """
        if pt_hash not in self.mapping_db:
            return None

        mapping = self.mapping_db[pt_hash]
        trt_hash = mapping.get("trt_hash") if isinstance(mapping, dict) else None
        if not trt_hash:
            # A hand-edited or partially written database must not crash lookups.
            logger.warning(
                f"Mapping for PT hash {pt_hash[:16]}... has no trt_hash. "
                f"Will reconvert."
            )
            self._forget_mapping(pt_hash)
            return None

        # Check both .engine and .trt extensions (Ultralytics uses .engine,
        # generic uses .trt); .engine is tried first.
        candidate_keys = (f"trt/{trt_hash}.engine", f"trt/{trt_hash}.trt")
        cached_key = next(
            (key for key in candidate_keys if self.storage.exists(key)),
            None,
        )

        if cached_key is None:
            logger.warning(
                f"Mapping exists for PT hash {pt_hash[:16]}... but engine file missing. "
                f"Will reconvert."
            )
            # Remove stale mapping
            self._forget_mapping(pt_hash)
            return None

        # Get local path
        cached_path = self.storage.get_local_path(cached_key)
        if cached_path is None:
            logger.error(f"Could not get local path for engine file {cached_key}")
            return None

        logger.info(
            f"Found cached conversion for PT hash {pt_hash[:16]}... -> "
            f"Engine hash {trt_hash[:16]}... ({cached_key})"
        )
        return (trt_hash, cached_path)

    def convert(
        self,
        pt_file_path: str,
        input_shapes: Optional[Dict[str, Tuple]] = None,
        precision: Optional[torch.dtype] = None,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Convert PyTorch model to TensorRT.

        If this PT file has been converted before (same hash), returns cached result.
        Otherwise, performs conversion and caches the result.

        Args:
            pt_file_path: Path to .pt file
            input_shapes: Dict of input names to shapes (e.g., {"x": (1, 3, 224, 224)})
            precision: Target precision (fp16, fp32) - defaults to self.default_precision
            **conversion_kwargs: Additional arguments for torch_tensorrt.compile()

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            FileNotFoundError: If PT file doesn't exist
            RuntimeError: If conversion fails
        """
        pt_path = Path(pt_file_path).resolve()
        if not pt_path.exists():
            raise FileNotFoundError(f"PT file not found: {pt_file_path}")

        effective_precision = precision if precision is not None else self.default_precision

        # Compute PT file hash
        logger.info(f"Computing hash for {pt_path}...")
        pt_hash = self.compute_file_hash(str(pt_path))
        logger.info(f"PT file hash: {pt_hash[:16]}...")

        # Check cache
        # NOTE(review): the cache is keyed by the PT hash only, so a request
        # with a different precision or input_shapes than the first conversion
        # still receives the earlier engine. Confirm this is intended.
        cached = self.get_cached_conversion(pt_hash)
        if cached:
            return cached

        # Perform conversion
        logger.info(f"Converting {pt_path.name} to TensorRT...")
        trt_hash, trt_path = self._perform_conversion(
            str(pt_path),
            pt_hash,
            input_shapes,
            effective_precision,
            **conversion_kwargs
        )

        # Store mapping
        self.mapping_db[pt_hash] = {
            "trt_hash": trt_hash,
            "pt_file": str(pt_path),
            "input_shapes": str(input_shapes),
            "precision": str(effective_precision),
        }
        self._save_mapping_db()

        logger.info(f"Conversion complete: PT {pt_hash[:16]}... -> TRT {trt_hash[:16]}...")
        return (trt_hash, trt_path)

    def _is_ultralytics_model(self, model) -> bool:
        """
        Check if model is from ultralytics (YOLO).

        The check is heuristic: the model counts as Ultralytics when its class
        is defined in a module whose path contains "ultralytics", or when its
        class name is one of the known Ultralytics task-model names.

        Args:
            model: PyTorch model

        Returns:
            True if ultralytics model, False otherwise
        """
        model_class = model.__class__
        return (
            'ultralytics' in model_class.__module__
            or model_class.__name__ in self._ULTRALYTICS_CLASS_NAMES
        )

    def _convert_ultralytics_model(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
    ) -> Tuple[str, str]:
        """
        Convert ultralytics YOLO model using native .engine export.
        This produces .engine files with embedded metadata (no manual input_shapes needed).

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash (unused here; kept for a signature parallel
                to _perform_conversion)
            input_shapes: Input tensor shapes (IGNORED for Ultralytics - auto-detected)
            precision: Target precision

        Returns:
            Tuple of (engine_hash, engine_file_path)

        Raises:
            ImportError: If the ultralytics package is not installed
            RuntimeError: If the storage backend cannot provide a local path
        """
        logger.info("Detected ultralytics YOLO model, using native .engine export...")

        # Only the import itself is guarded, so an ImportError raised while
        # *loading* the model is not misreported as a missing package.
        try:
            from ultralytics import YOLO
        except ImportError as err:
            raise ImportError(
                "ultralytics package not found. Install with: pip install ultralytics"
            ) from err

        # Load ultralytics model
        model = YOLO(pt_path)

        # Export to native .engine format with embedded metadata
        use_fp16 = precision == torch.float16
        logger.info(f"Exporting to native TensorRT .engine (precision: {'FP16' if use_fp16 else 'FP32'})...")

        # Ultralytics export creates .engine file in same directory as .pt
        engine_path = model.export(
            format='engine',
            half=use_fp16,
            device=self.gpu_id,
            batch=1,
            simplify=True
        )

        # Convert to string (Ultralytics returns Path object)
        engine_path = str(engine_path)
        logger.info(f"Native .engine export complete: {engine_path}")
        logger.info("Metadata embedded in .engine file (stride, imgsz, names, etc.)")

        # Read the exported .engine file
        with open(engine_path, 'rb') as f:
            engine_data = f.read()

        # Compute hash of the .engine file
        engine_hash = hashlib.sha256(engine_data).hexdigest()

        # Store in our cache (as .engine to preserve metadata)
        engine_key = f"trt/{engine_hash}.engine"
        self.storage.write(engine_key, engine_data)

        cached_path = self.storage.get_local_path(engine_key)
        if cached_path is None:
            raise RuntimeError("Failed to get local path for .engine file")

        # Clean up the original export (we've cached it), but only when it is a
        # different file from the cached copy.
        if os.path.exists(engine_path) and os.path.abspath(engine_path) != os.path.abspath(cached_path):
            logger.info(f"Removing original export (cached): {engine_path}")
            os.unlink(engine_path)
        else:
            logger.info(f"Keeping original export at: {engine_path}")

        logger.info(f"Cached .engine file: {cached_path}")
        return (engine_hash, cached_path)

    @staticmethod
    def _extract_model(loaded):
        """
        Return the model object from whatever torch.load produced.

        Args:
            loaded: Object returned by torch.load (a model, or a checkpoint dict)

        Returns:
            The model itself, or the 'model' entry of a checkpoint dict

        Raises:
            ValueError: If the checkpoint dict holds no full model
        """
        if not isinstance(loaded, dict):
            return loaded
        if 'model' in loaded:
            return loaded['model']
        if 'state_dict' in loaded:
            raise ValueError(
                "PT file contains state_dict only. "
                "Please provide a full model or use a different loading method."
            )
        # Without this branch the caller would continue with no model at all.
        raise ValueError(
            "PT file contains a dict without a 'model' entry "
            f"(keys: {sorted(str(key) for key in loaded)}). "
            "Please provide a full model."
        )

    def _store_trt_module(self, trt_model) -> Tuple[str, str]:
        """
        Serialize a compiled module, store it under its content hash and
        return (trt_hash, local_path).

        Raises:
            RuntimeError: If the storage backend cannot provide a local path
        """
        # The temp file is only used to reserve a path; it is closed before
        # torch.jit.save writes to that path by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.trt') as tmp_file:
            tmp_path = tmp_file.name

        try:
            # NOTE(review): torch.jit.save expects a TorchScript module; whether
            # torch_tensorrt.compile returns one depends on the installed
            # torch_tensorrt version / `ir` setting - confirm.
            torch.jit.save(trt_model, tmp_path)
            logger.debug(f"Saved TRT model to temporary file: {tmp_path}")

            # Compute TRT file hash
            trt_hash = self.compute_file_hash(tmp_path)
            logger.info(f"TRT file hash: {trt_hash[:16]}...")

            # Store in storage backend
            trt_key = f"trt/{trt_hash}.trt"
            with open(tmp_path, 'rb') as f:
                trt_data = f.read()
            self.storage.write(trt_key, trt_data)
        finally:
            # Always remove the temp file, including when saving, hashing or
            # the storage write fails.
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

        # Get local path
        trt_path = self.storage.get_local_path(trt_key)
        if trt_path is None:
            raise RuntimeError("Failed to get local path for converted TRT file")

        logger.info(f"TRT model stored successfully at {trt_path}")
        return (trt_hash, trt_path)

    def _perform_conversion(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Perform the actual PT to TRT conversion.

        Ultralytics models are delegated to the native .engine export; all
        other models are compiled with torch_tensorrt using random sample
        inputs built from input_shapes.

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash
            input_shapes: Input tensor shapes (required for non-Ultralytics models)
            precision: Target precision
            **conversion_kwargs: Additional torch_tensorrt arguments

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            RuntimeError: On any failure; the original exception is chained as
                the cause
        """
        try:
            # Load PyTorch model to check type
            logger.debug(f"Loading PyTorch model from {pt_path}...")
            # weights_only=False is needed for models with custom classes (like
            # ultralytics). SECURITY: this unpickles arbitrary objects and can
            # execute code - only use with trusted local model files.
            loaded = torch.load(pt_path, map_location='cpu', weights_only=False)

            # If model is wrapped in a dict, extract the model
            model = self._extract_model(loaded)

            # Check if this is an ultralytics model
            if self._is_ultralytics_model(model):
                logger.info("Detected Ultralytics YOLO model, using native .engine export")
                logger.info("Note: input_shapes parameter is ignored for Ultralytics models (auto-detected)")
                return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision)

            # For non-ultralytics models, use torch_tensorrt
            logger.info("Using torch_tensorrt for conversion (non-Ultralytics model)")

            # Non-Ultralytics models REQUIRE input_shapes; an empty dict would
            # yield no sample inputs, so it is rejected as well.
            if not input_shapes:
                raise ValueError(
                    "input_shapes required for non-Ultralytics PyTorch models. "
                    "For Ultralytics YOLO models, input_shapes is auto-detected. "
                    "Example: input_shapes={'images': (1, 3, 640, 640)}"
                )

            model.eval()

            # Convert model to target precision to avoid mixed precision issues
            if precision == torch.float16:
                model = model.half()
            elif precision == torch.float32:
                model = model.float()

            # Move to GPU
            model = model.to(self.device)

            # Create sample inputs with matching precision; inputs are passed
            # positionally in the dict's insertion order.
            inputs = [
                torch.randn(shape, device=self.device, dtype=precision)
                for shape in input_shapes.values()
            ]

            # Configure torch_tensorrt
            enabled_precisions = {precision}
            if precision == torch.float16:
                enabled_precisions.add(torch.float32)  # Fallback for unsupported ops

            # Compile to TensorRT
            logger.info(f"Compiling to TensorRT (precision: {precision})...")
            trt_model = torch_tensorrt.compile(
                model,
                inputs=inputs,
                enabled_precisions=enabled_precisions,
                **conversion_kwargs
            )

            return self._store_trt_module(trt_model)

        except Exception as e:
            logger.error(f"Conversion failed: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert PT to TensorRT: {e}") from e

    def clear_cache(self):
        """
        Clear the mapping database.

        Only the PT hash -> TRT hash mapping is emptied and persisted; engine
        files already written to the storage backend are not deleted by this
        method.
        """
        logger.warning("Clearing all cached conversions...")
        self.mapping_db.clear()
        self._save_mapping_db()
        logger.info("Cache cleared")

    def get_stats(self) -> Dict:
        """
        Get conversion statistics.

        Returns:
            Dictionary with cache stats
        """
        return {
            "total_cached_conversions": len(self.mapping_db),
            "storage_path": self.storage.get_storage_path(),
            "gpu_id": self.gpu_id,
            "default_precision": str(self.default_precision),
        }