485 lines
17 KiB
Python
485 lines
17 KiB
Python
"""
|
|
PTConverter - Convert PyTorch models to TensorRT using torch_tensorrt.
|
|
|
|
This service handles conversion of .pt files to TensorRT format with caching
|
|
to avoid redundant conversions. It maintains a mapping database between
|
|
PT file hashes and their converted TensorRT file hashes.
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, Tuple
|
|
import torch
|
|
import torch_tensorrt
|
|
|
|
from .modelstorage import IModelStorage, FileModelStorage
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PTConverter:
    """
    PyTorch to TensorRT converter with hash-based caching.

    Features:
    - Hash-based deduplication: the same PT file is only converted once per
      (input_shapes, precision) configuration
    - Persistent mapping database: PT hash -> TRT hash mapping plus metadata
    - Pluggable storage backend: IModelStorage interface
    - Stale mappings (missing or malformed entries) are dropped automatically

    Architecture:
    1. Compute hash of input .pt file
    2. Check mapping database for existing conversion
    3. If found (and built for the requested configuration), return the
       cached TRT file hash and path
    4. Otherwise, perform conversion and store mapping
    """

    def __init__(
        self,
        storage: "Optional[IModelStorage]" = None,
        gpu_id: int = 0,
        default_precision: torch.dtype = torch.float16,
    ):
        """
        Initialize PT converter.

        Args:
            storage: Storage backend (defaults to FileModelStorage)
            gpu_id: GPU device ID for conversion
            default_precision: Default precision for TensorRT conversion (fp16, fp32)
        """
        # Compare against None explicitly so a storage object that happens to
        # be falsy (e.g. defines __len__ and is empty) is not replaced.
        self.storage = storage if storage is not None else FileModelStorage()
        self.gpu_id = gpu_id
        self.device = torch.device(f'cuda:{gpu_id}')
        self.default_precision = default_precision

        # Mapping database:
        # pt_hash -> {"trt_hash": str, "pt_file": str, "input_shapes": str, "precision": str}
        self.mapping_db_key = "pt_to_trt_mapping.json"
        self.mapping_db: Dict[str, Dict] = self._load_mapping_db()

        logger.info(f"PTConverter initialized on GPU {gpu_id}")
        logger.info(f"Storage backend: {self.storage.__class__.__name__}")
        logger.info(f"Storage path: {self.storage.get_storage_path()}")
        logger.info(f"Loaded {len(self.mapping_db)} cached conversions")

    def _load_mapping_db(self) -> Dict[str, Dict]:
        """
        Load mapping database from storage.

        Any failure (storage error, undecodable bytes, invalid JSON, or JSON
        that is not an object) is logged and results in an empty mapping.

        Returns:
            Mapping dictionary (pt_hash -> metadata)
        """
        try:
            data = self.storage.read(self.mapping_db_key)
            if not data:
                logger.debug("No existing mapping database found, starting fresh")
                return {}

            db = json.loads(data.decode('utf-8'))
        except Exception as e:
            # Broad on purpose: the storage backend is pluggable and a broken
            # cache index must never prevent the converter from starting.
            logger.warning(f"Failed to load mapping database: {e}. Starting fresh.")
            return {}

        # Valid JSON is not necessarily a mapping (e.g. a list or a string);
        # everything downstream assumes a dict.
        if not isinstance(db, dict):
            logger.warning(
                f"Mapping database has unexpected type {type(db).__name__}. Starting fresh."
            )
            return {}

        logger.debug(f"Loaded mapping database with {len(db)} entries")
        return db

    def _save_mapping_db(self):
        """Save mapping database to storage (best effort: failures are logged, not raised)"""
        try:
            data = json.dumps(self.mapping_db, indent=2).encode('utf-8')
            self.storage.write(self.mapping_db_key, data)
            logger.debug(f"Saved mapping database with {len(self.mapping_db)} entries")
        except Exception as e:
            logger.error(f"Failed to save mapping database: {e}")

    def _drop_mapping(self, pt_hash: str) -> None:
        """Remove the mapping for pt_hash (if present) and persist the database."""
        self.mapping_db.pop(pt_hash, None)
        self._save_mapping_db()

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """
        Compute SHA256 hash of a file.

        The file is read in 64 KiB chunks so large models are not loaded into
        memory at once.

        Args:
            file_path: Path to file

        Returns:
            Hexadecimal hash string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def get_cached_conversion(self, pt_hash: str) -> Optional[Tuple[str, str]]:
        """
        Check if PT file has already been converted.

        Mappings that are malformed or whose TRT file no longer exists in
        storage are removed from the database as a side effect.

        Args:
            pt_hash: SHA256 hash of the PT file

        Returns:
            Tuple of (trt_hash, trt_file_path) if cached, None otherwise
        """
        if pt_hash not in self.mapping_db:
            return None

        mapping = self.mapping_db[pt_hash]
        trt_hash = mapping.get("trt_hash") if isinstance(mapping, dict) else None
        if not isinstance(trt_hash, str) or not trt_hash:
            # A hand-edited or partially written database must not crash lookups.
            logger.warning(
                f"Malformed mapping for PT hash {pt_hash[:16]}... (no TRT hash). "
                f"Will reconvert."
            )
            self._drop_mapping(pt_hash)
            return None

        trt_key = f"trt/{trt_hash}.trt"

        # Verify TRT file still exists in storage
        if not self.storage.exists(trt_key):
            logger.warning(
                f"Mapping exists for PT hash {pt_hash[:16]}... but TRT file missing. "
                f"Will reconvert."
            )
            # Remove stale mapping
            self._drop_mapping(pt_hash)
            return None

        # Get local path
        trt_path = self.storage.get_local_path(trt_key)
        if trt_path is None:
            logger.error(f"Could not get local path for TRT file {trt_key}")
            return None

        logger.info(
            f"Found cached conversion for PT hash {pt_hash[:16]}... -> "
            f"TRT hash {trt_hash[:16]}..."
        )
        return (trt_hash, trt_path)

    @staticmethod
    def _cache_entry_matches(
        entry,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: Optional[torch.dtype],
    ) -> bool:
        """
        Tell whether a cached mapping was built for the requested configuration.

        Only arguments the caller passed explicitly (not None) are compared,
        and only against metadata the entry actually recorded, so lookups
        without arguments and entries lacking metadata keep hitting the cache.
        """
        if not isinstance(entry, dict):
            return True

        requested = (("input_shapes", input_shapes), ("precision", precision))
        for key, value in requested:
            if value is None:
                continue
            recorded = entry.get(key)
            # Metadata is stored as str(...) in convert(), so compare in that form.
            if recorded is not None and recorded != str(value):
                return False
        return True

    def convert(
        self,
        pt_file_path: str,
        input_shapes: Optional[Dict[str, Tuple]] = None,
        precision: Optional[torch.dtype] = None,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Convert PyTorch model to TensorRT.

        If this PT file has been converted before (same hash) for a compatible
        configuration, returns the cached result. Otherwise, performs the
        conversion and caches the result. A cached engine is considered
        incompatible when input_shapes or precision is passed explicitly and
        differs from what was recorded at conversion time; conversion_kwargs
        are not part of that comparison.

        Args:
            pt_file_path: Path to .pt file
            input_shapes: Dict of input names to shapes (e.g., {"x": (1, 3, 224, 224)})
            precision: Target precision (fp16, fp32) - defaults to self.default_precision
            **conversion_kwargs: Additional arguments for torch_tensorrt.compile()

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            FileNotFoundError: If PT file doesn't exist
            RuntimeError: If conversion fails
        """
        pt_path = Path(pt_file_path).resolve()
        if not pt_path.exists():
            raise FileNotFoundError(f"PT file not found: {pt_file_path}")

        # Compute PT file hash
        logger.info(f"Computing hash for {pt_path}...")
        pt_hash = self.compute_file_hash(str(pt_path))
        logger.info(f"PT file hash: {pt_hash[:16]}...")

        # Check cache
        cached = self.get_cached_conversion(pt_hash)
        if cached:
            if self._cache_entry_matches(self.mapping_db.get(pt_hash), input_shapes, precision):
                return cached
            # Returning an engine built for other shapes/precision would be
            # silently wrong, so rebuild and overwrite the mapping instead.
            logger.info(
                f"Cached conversion for PT hash {pt_hash[:16]}... was built with a "
                f"different input_shapes/precision. Will reconvert."
            )

        effective_precision = precision or self.default_precision

        # Perform conversion
        logger.info(f"Converting {pt_path.name} to TensorRT...")
        trt_hash, trt_path = self._perform_conversion(
            str(pt_path),
            pt_hash,
            input_shapes,
            effective_precision,
            **conversion_kwargs
        )

        # Store mapping
        self.mapping_db[pt_hash] = {
            "trt_hash": trt_hash,
            "pt_file": str(pt_path),
            "input_shapes": str(input_shapes),
            "precision": str(effective_precision),
        }
        self._save_mapping_db()

        logger.info(f"Conversion complete: PT {pt_hash[:16]}... -> TRT {trt_hash[:16]}...")
        return (trt_hash, trt_path)

    def _is_ultralytics_model(self, model) -> bool:
        """
        Check if model is from ultralytics (YOLO).

        The check is heuristic: it matches on the defining module name or on
        a set of well-known ultralytics class names.

        Args:
            model: PyTorch model

        Returns:
            True if ultralytics model, False otherwise
        """
        model_type = model.__class__
        if 'ultralytics' in model_type.__module__:
            return True
        return model_type.__name__ in (
            'DetectionModel', 'SegmentationModel', 'PoseModel', 'ClassificationModel'
        )

    def _store_trt_artifact(self, trt_data: bytes) -> Tuple[str, str]:
        """Write a TRT artifact under its content hash and return (trt_hash, local_path)."""
        trt_hash = hashlib.sha256(trt_data).hexdigest()
        logger.info(f"TRT file hash: {trt_hash[:16]}...")

        trt_key = f"trt/{trt_hash}.trt"
        self.storage.write(trt_key, trt_data)

        trt_path = self.storage.get_local_path(trt_key)
        if trt_path is None:
            raise RuntimeError("Failed to get local path for converted TRT file")
        return (trt_hash, trt_path)

    def _convert_ultralytics_model(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
    ) -> Tuple[str, str]:
        """
        Convert ultralytics YOLO model using ONNX → TensorRT pipeline.
        Uses the same approach as scripts/convert_pt_to_tensorrt.py

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash (currently unused by this method)
            input_shapes: Input tensor shapes; the 'images' entry is used if
                present, otherwise the first entry. The shape is indexed as
                (batch, channels, height[, width]).
            precision: Target precision

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            ImportError: If ultralytics (or tensorrt) is not installed
            ValueError: If input_shapes is missing or has too few dimensions
            RuntimeError: If ONNX export, parsing or engine build fails
        """
        import os
        import shutil
        import tempfile

        import tensorrt as trt

        logger.info("Detected ultralytics YOLO model, using ONNX → TensorRT pipeline...")

        # Validate the cheap preconditions before loading the model
        if not input_shapes:
            raise ValueError("input_shapes required for ultralytics conversion")

        input_key = 'images' if 'images' in input_shapes else next(iter(input_shapes))
        input_shape = input_shapes[input_key]
        if len(input_shape) < 3:
            raise ValueError(
                f"input shape for '{input_key}' must have at least 3 dimensions "
                f"(batch, channels, height[, width]), got {input_shape}"
            )

        # Load ultralytics model
        try:
            from ultralytics import YOLO
        except ImportError as e:
            raise ImportError(
                "ultralytics package not found. Install with: pip install ultralytics"
            ) from e
        model = YOLO(pt_path)

        # Keep the historical square-int form when possible; pass (height, width)
        # for non-square inputs instead of silently dropping the width.
        height = input_shape[2]
        width = input_shape[3] if len(input_shape) > 3 else height
        imgsz = height if width == height else (height, width)

        # Export to ONNX first
        logger.info(f"Exporting to ONNX (input shape: {input_shape})...")
        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp_onnx:
            onnx_path = tmp_onnx.name

        try:
            # Use ultralytics export to ONNX
            exported = model.export(format='onnx', imgsz=imgsz, batch=input_shape[0])

            # Prefer the path export() reports (per the ultralytics docs it
            # returns the output file); fall back to the conventional
            # <model_name>.onnx next to the PT file.
            pt_name = os.path.splitext(os.path.basename(pt_path))[0]
            default_export_path = os.path.join(os.path.dirname(pt_path), f"{pt_name}.onnx")
            onnx_export_path = next(
                (
                    candidate
                    for candidate in (exported, default_export_path)
                    if isinstance(candidate, (str, os.PathLike)) and os.path.exists(candidate)
                ),
                None,
            )
            if onnx_export_path is None:
                raise RuntimeError(f"ONNX export failed, file not found: {default_export_path}")

            # Move to our temp location (use shutil.move for cross-device support)
            shutil.move(onnx_export_path, onnx_path)
            logger.info(f"ONNX export complete: {onnx_path}")

            # Build TensorRT engine from ONNX
            logger.info("Building TensorRT engine from ONNX...")
            trt_logger = trt.Logger(trt.Logger.WARNING)
            builder = trt.Builder(trt_logger)
            network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
            parser = trt.OnnxParser(network, trt_logger)

            # Parse ONNX
            with open(onnx_path, 'rb') as f:
                if not parser.parse(f.read()):
                    errors = [parser.get_error(i) for i in range(parser.num_errors)]
                    raise RuntimeError(f"Failed to parse ONNX: {errors}")

            # Configure builder
            config = builder.create_builder_config()
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4GB

            # Set precision
            if precision == torch.float16:
                if builder.platform_has_fast_fp16:
                    config.set_flag(trt.BuilderFlag.FP16)
                    logger.info("FP16 mode enabled")
                else:
                    # Make the silent downgrade visible to operators.
                    logger.warning(
                        "FP16 requested but platform has no fast FP16 support; "
                        "building without the FP16 flag"
                    )

            # Build engine
            logger.info("Building TensorRT engine (this may take a few minutes)...")
            serialized_engine = builder.build_serialized_network(network, config)
            if serialized_engine is None:
                raise RuntimeError("Failed to build TensorRT engine")

            # Convert IHostMemory to bytes and save to storage
            trt_hash, trt_path = self._store_trt_artifact(bytes(serialized_engine))

            logger.info(f"TensorRT engine built successfully: {trt_path}")
            return (trt_hash, trt_path)

        finally:
            # Cleanup ONNX file
            if os.path.exists(onnx_path):
                os.unlink(onnx_path)

    def _perform_conversion(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Perform the actual PT to TRT conversion.

        Ultralytics models are delegated to _convert_ultralytics_model; all
        other models are compiled with torch_tensorrt and saved via
        torch.jit.save.

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash (passed through to the ultralytics path)
            input_shapes: Input tensor shapes
            precision: Target precision
            **conversion_kwargs: Additional torch_tensorrt arguments

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            RuntimeError: If any step fails; the original exception is chained
                as __cause__.
        """
        import tempfile

        try:
            # Load PyTorch model to check type
            logger.debug(f"Loading PyTorch model from {pt_path}...")
            # Use weights_only=False for models with custom classes (like ultralytics).
            # SECURITY: this unpickles arbitrary objects and can execute code
            # embedded in the file - only use with trusted local models.
            loaded = torch.load(pt_path, map_location='cpu', weights_only=False)

            # If model is wrapped in a dict, extract the model
            if isinstance(loaded, dict):
                if 'model' in loaded:
                    model = loaded['model']
                elif 'state_dict' in loaded:
                    raise ValueError(
                        "PT file contains state_dict only. "
                        "Please provide a full model or use a different loading method."
                    )
                else:
                    # Without this branch `model` would be unbound below.
                    raise ValueError(
                        "PT file contains a dict without a 'model' entry "
                        f"(keys: {sorted(map(str, loaded))}). "
                        "Please provide a full model."
                    )
            else:
                model = loaded

            # Check if this is an ultralytics model
            if self._is_ultralytics_model(model):
                logger.info("Detected ultralytics model, using ultralytics export API")
                return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision)

            # For non-ultralytics models, use torch_tensorrt
            logger.info("Using torch_tensorrt for conversion")

            # Fail fast (before any GPU work) when there is nothing to trace
            # with; an empty dict is as unusable as None.
            if not input_shapes:
                raise ValueError(
                    "input_shapes must be provided for TensorRT conversion. "
                    "Example: {'x': (1, 3, 224, 224)}"
                )

            model.eval()

            # Convert model to target precision to avoid mixed precision issues
            if precision == torch.float16:
                model = model.half()
            elif precision == torch.float32:
                model = model.float()

            # Move to GPU
            model = model.to(self.device)

            # Create sample inputs with matching precision (dict order defines
            # the positional order of the inputs)
            inputs = [
                torch.randn(shape, device=self.device, dtype=precision)
                for shape in input_shapes.values()
            ]

            # Configure torch_tensorrt
            enabled_precisions = {precision}
            if precision == torch.float16:
                enabled_precisions.add(torch.float32)  # Fallback for unsupported ops

            # Compile to TensorRT
            logger.info(f"Compiling to TensorRT (precision: {precision})...")
            trt_model = torch_tensorrt.compile(
                model,
                inputs=inputs,
                enabled_precisions=enabled_precisions,
                **conversion_kwargs
            )

            # Save TRT model to a temporary file, read it back once, and always
            # remove the temp file - even when saving fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.trt') as tmp_file:
                tmp_path = Path(tmp_file.name)
            try:
                torch.jit.save(trt_model, str(tmp_path))
                logger.debug(f"Saved TRT model to temporary file: {tmp_path}")
                trt_data = tmp_path.read_bytes()
            finally:
                if tmp_path.exists():
                    tmp_path.unlink()

            # Hash and store in storage backend
            trt_hash, trt_path = self._store_trt_artifact(trt_data)

            logger.info(f"TRT model stored successfully at {trt_path}")
            return (trt_hash, trt_path)

        except Exception as e:
            logger.error(f"Conversion failed: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert PT to TensorRT: {e}") from e

    def clear_cache(self):
        """
        Clear the mapping database (in memory and in storage).

        Only the PT -> TRT mappings are removed; this method does not delete
        the converted TRT files from the storage backend.
        """
        logger.warning("Clearing all cached conversions...")
        self.mapping_db.clear()
        self._save_mapping_db()
        logger.info("Cache cleared")

    def get_stats(self) -> Dict:
        """
        Get conversion statistics.

        Returns:
            Dictionary with cache stats: number of cached conversions,
            storage path, GPU id and default precision (as a string)
        """
        return {
            "total_cached_conversions": len(self.mapping_db),
            "storage_path": self.storage.get_storage_path(),
            "gpu_id": self.gpu_id,
            "default_precision": str(self.default_precision),
        }
|