python-rtsp-worker/services/pt_converter.py

464 lines
16 KiB
Python

"""
PTConverter - Convert PyTorch models to TensorRT using torch_tensorrt.
This service handles conversion of .pt files to TensorRT format with caching
to avoid redundant conversions. It maintains a mapping database between
PT file hashes and their converted TensorRT file hashes.
"""
import hashlib
import json
import logging
from pathlib import Path
from typing import Optional, Dict, Tuple
import torch
import torch_tensorrt
from .modelstorage import IModelStorage, FileModelStorage
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class PTConverter:
    """
    PyTorch to TensorRT converter with intelligent caching.

    Features:
    - Hash-based deduplication: the same PT file at the same precision is
      only converted once
    - Persistent mapping database: PT hash -> TRT hash mapping
    - Pluggable storage backend: IModelStorage interface
    - Automatic cache management: mappings whose engine file has vanished
      from storage are dropped and the model is reconverted

    Architecture:
    1. Compute hash of input .pt file
    2. Check mapping database for an existing conversion at the requested
       precision
    3. If found, return cached TRT file hash and path
    4. If not found, perform conversion and store mapping

    Ultralytics YOLO models are exported with Ultralytics' native exporter
    and stored as ``trt/<hash>.engine``; every other model is compiled with
    ``torch_tensorrt``, serialized with ``torch.jit.save`` and stored as
    ``trt/<hash>.trt``.
    """

    def __init__(
        self,
        storage: Optional[IModelStorage] = None,
        gpu_id: int = 0,
        default_precision: torch.dtype = torch.float16
    ):
        """
        Initialize PT converter.

        Args:
            storage: Storage backend (defaults to FileModelStorage)
            gpu_id: GPU device ID for conversion
            default_precision: Default precision for TensorRT conversion
                (torch.float16 or torch.float32)
        """
        self.storage = storage if storage is not None else FileModelStorage()
        self.gpu_id = gpu_id
        self.device = torch.device(f'cuda:{gpu_id}')
        self.default_precision = default_precision

        # Mapping database: pt_hash -> {"trt_hash": str, "pt_file": str,
        #                                "input_shapes": str, "precision": str}
        self.mapping_db_key = "pt_to_trt_mapping.json"
        self.mapping_db: Dict[str, Dict] = self._load_mapping_db()

        logger.info(f"PTConverter initialized on GPU {gpu_id}")
        logger.info(f"Storage backend: {self.storage.__class__.__name__}")
        logger.info(f"Storage path: {self.storage.get_storage_path()}")
        logger.info(f"Loaded {len(self.mapping_db)} cached conversions")

    def _load_mapping_db(self) -> Dict[str, Dict]:
        """
        Load mapping database from storage.

        Loading is best-effort: a missing, unreadable or malformed database
        is logged and replaced with an empty one.

        Returns:
            Mapping dictionary (pt_hash -> metadata)
        """
        try:
            data = self.storage.read(self.mapping_db_key)
            if not data:
                logger.debug("No existing mapping database found, starting fresh")
                return {}

            db = json.loads(data.decode('utf-8'))
            # A non-object payload (list, string, ...) would break every lookup.
            if not isinstance(db, dict):
                logger.warning("Mapping database is not a JSON object. Starting fresh.")
                return {}

            logger.debug(f"Loaded mapping database with {len(db)} entries")
            return db
        except Exception as e:
            logger.warning(f"Failed to load mapping database: {e}. Starting fresh.")
            return {}

    def _save_mapping_db(self):
        """
        Save mapping database to storage.

        Best-effort: failures are logged, not raised, so a storage hiccup
        does not fail an otherwise successful conversion.
        """
        try:
            data = json.dumps(self.mapping_db, indent=2).encode('utf-8')
            self.storage.write(self.mapping_db_key, data)
            logger.debug(f"Saved mapping database with {len(self.mapping_db)} entries")
        except Exception as e:
            logger.error(f"Failed to save mapping database: {e}")

    def _drop_mapping(self, pt_hash: str):
        """Remove the mapping entry for ``pt_hash`` (if any) and persist the change."""
        self.mapping_db.pop(pt_hash, None)
        self._save_mapping_db()

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """
        Compute SHA256 hash of a file.

        The file is read in 64 KiB chunks, so large model files are not
        loaded into memory at once.

        Args:
            file_path: Path to file

        Returns:
            Hexadecimal hash string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def get_cached_conversion(
        self,
        pt_hash: str,
        precision: Optional[torch.dtype] = None,
    ) -> Optional[Tuple[str, str]]:
        """
        Check if PT file has already been converted.

        Args:
            pt_hash: SHA256 hash of the PT file
            precision: If given, only a conversion recorded with this precision
                counts as a hit. Entries that recorded no precision are
                accepted for any precision.

        Returns:
            Tuple of (trt_hash, trt_file_path) if cached, None otherwise.
            Mappings whose engine file is missing (or that are malformed) are
            removed from the database.
        """
        mapping = self.mapping_db.get(pt_hash)
        if mapping is None:
            return None

        trt_hash = mapping.get("trt_hash") if isinstance(mapping, dict) else None
        if not trt_hash:
            logger.warning(
                f"Malformed mapping for PT hash {pt_hash[:16]}... Will reconvert."
            )
            self._drop_mapping(pt_hash)
            return None

        # A conversion at another precision is a different artifact; treat it
        # as a miss (convert() will overwrite the entry with the new result).
        recorded_precision = mapping.get("precision")
        if (
            precision is not None
            and recorded_precision is not None
            and recorded_precision != str(precision)
        ):
            logger.info(
                f"Cached conversion for PT hash {pt_hash[:16]}... uses precision "
                f"{recorded_precision}, but {precision} was requested. Will reconvert."
            )
            return None

        # Check both extensions (Ultralytics uses .engine, generic uses .trt);
        # .engine (Ultralytics native format) is tried first.
        cached_key = None
        for candidate in (f"trt/{trt_hash}.engine", f"trt/{trt_hash}.trt"):
            if self.storage.exists(candidate):
                cached_key = candidate
                break

        if cached_key is None:
            logger.warning(
                f"Mapping exists for PT hash {pt_hash[:16]}... but engine file missing. "
                f"Will reconvert."
            )
            self._drop_mapping(pt_hash)
            return None

        cached_path = self.storage.get_local_path(cached_key)
        if cached_path is None:
            logger.error(f"Could not get local path for engine file {cached_key}")
            return None

        logger.info(
            f"Found cached conversion for PT hash {pt_hash[:16]}... -> "
            f"Engine hash {trt_hash[:16]}... ({cached_key})"
        )
        return (trt_hash, cached_path)

    def convert(
        self,
        pt_file_path: str,
        input_shapes: Optional[Dict[str, Tuple]] = None,
        precision: Optional[torch.dtype] = None,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Convert PyTorch model to TensorRT.

        If this PT file has been converted before (same hash and same
        precision), returns the cached result. Otherwise, performs the
        conversion and caches the result.

        NOTE: input_shapes is recorded in the mapping but is not part of the
        cache lookup; reconverting the same file with different shapes returns
        the earlier conversion.

        Args:
            pt_file_path: Path to .pt file
            input_shapes: Dict of input names to shapes (e.g., {"x": (1, 3, 224, 224)});
                required for non-Ultralytics models, ignored for Ultralytics ones
            precision: Target precision (fp16, fp32) - defaults to self.default_precision
            **conversion_kwargs: Additional arguments for torch_tensorrt.compile()

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            FileNotFoundError: If PT file doesn't exist (or is not a regular file)
            RuntimeError: If conversion fails
        """
        pt_path = Path(pt_file_path).resolve()
        if not pt_path.is_file():
            raise FileNotFoundError(f"PT file not found: {pt_file_path}")

        target_precision = self.default_precision if precision is None else precision

        logger.info(f"Computing hash for {pt_path}...")
        pt_hash = self.compute_file_hash(str(pt_path))
        logger.info(f"PT file hash: {pt_hash[:16]}...")

        cached = self.get_cached_conversion(pt_hash, target_precision)
        if cached:
            return cached

        logger.info(f"Converting {pt_path.name} to TensorRT...")
        trt_hash, trt_path = self._perform_conversion(
            str(pt_path),
            pt_hash,
            input_shapes,
            target_precision,
            **conversion_kwargs
        )

        # Store mapping (replaces any entry recorded for another precision).
        self.mapping_db[pt_hash] = {
            "trt_hash": trt_hash,
            "pt_file": str(pt_path),
            "input_shapes": str(input_shapes),
            "precision": str(target_precision),
        }
        self._save_mapping_db()

        logger.info(f"Conversion complete: PT {pt_hash[:16]}... -> TRT {trt_hash[:16]}...")
        return (trt_hash, trt_path)

    def _is_ultralytics_model(self, model) -> bool:
        """
        Check if model is from ultralytics (YOLO).

        Matches on the defining module name, or on one of the Ultralytics
        task-model class names.

        Args:
            model: PyTorch model

        Returns:
            True if ultralytics model, False otherwise
        """
        model_class = model.__class__
        return (
            'ultralytics' in model_class.__module__
            or model_class.__name__ in (
                'DetectionModel', 'SegmentationModel', 'PoseModel', 'ClassificationModel'
            )
        )

    def _convert_ultralytics_model(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
    ) -> Tuple[str, str]:
        """
        Convert ultralytics YOLO model using native .engine export.

        This produces .engine files with embedded metadata (no manual
        input_shapes needed). The exported file is copied into storage under
        its own SHA256 hash, and the original export is deleted unless it is
        the stored file itself.

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash (unused; kept for a uniform signature)
            input_shapes: Input tensor shapes (IGNORED for Ultralytics - auto-detected)
            precision: Target precision (torch.float16 enables half export)

        Returns:
            Tuple of (engine_hash, engine_file_path)

        Raises:
            ImportError: If the ultralytics package is not installed
            RuntimeError: If storage cannot provide a local path for the engine
        """
        logger.info("Detected ultralytics YOLO model, using native .engine export...")

        # Guard only the import, so ImportErrors raised while loading the
        # model are not misreported as a missing package.
        try:
            from ultralytics import YOLO
        except ImportError as e:
            raise ImportError(
                "ultralytics package not found. Install with: pip install ultralytics"
            ) from e
        model = YOLO(pt_path)

        use_half = precision == torch.float16
        logger.info(
            f"Exporting to native TensorRT .engine (precision: {'FP16' if use_half else 'FP32'})..."
        )

        # Ultralytics export creates the .engine file in the same directory as the .pt
        exported = model.export(
            format='engine',
            half=use_half,
            device=self.gpu_id,
            batch=1,
            simplify=True
        )
        # Normalize the returned location (str or Path) to a Path.
        export_path = Path(exported)

        logger.info(f"Native .engine export complete: {export_path}")
        logger.info("Metadata embedded in .engine file (stride, imgsz, names, etc.)")

        engine_data = export_path.read_bytes()
        engine_hash = hashlib.sha256(engine_data).hexdigest()

        # Store in our cache (as .engine to preserve metadata)
        engine_key = f"trt/{engine_hash}.engine"
        self.storage.write(engine_key, engine_data)

        cached_path = self.storage.get_local_path(engine_key)
        if cached_path is None:
            raise RuntimeError("Failed to get local path for .engine file")

        # Remove the original export now that it is cached, unless storage
        # resolved to that very file. resolve() follows symlinks, so an
        # aliased path to the cached file is never mistaken for a copy.
        if export_path.exists() and export_path.resolve() != Path(cached_path).resolve():
            logger.info(f"Removing original export (cached): {export_path}")
            export_path.unlink()
        else:
            logger.info(f"Keeping original export at: {export_path}")

        logger.info(f"Cached .engine file: {cached_path}")
        return (engine_hash, cached_path)

    def _load_pt_model(self, pt_path: str):
        """
        Load a .pt file and return the model object it contains.

        Args:
            pt_path: Path to PT file

        Returns:
            The loaded model (unwrapped from a checkpoint dict if necessary)

        Raises:
            ValueError: If the file holds only a state_dict, or a dict with
                no usable model entry
        """
        logger.debug(f"Loading PyTorch model from {pt_path}...")
        # weights_only=False is needed for models with custom classes (like
        # ultralytics). It unpickles arbitrary objects and can execute code,
        # so only ever pass trusted .pt files to this converter.
        loaded = torch.load(pt_path, map_location='cpu', weights_only=False)

        if not isinstance(loaded, dict):
            return loaded

        # Checkpoint dict: prefer 'model'; fall back to 'ema', which some
        # training checkpoints (e.g. Ultralytics) use while 'model' is None.
        model = loaded.get('model')
        if model is None:
            model = loaded.get('ema')
        if model is not None:
            return model

        if 'state_dict' in loaded:
            raise ValueError(
                "PT file contains state_dict only. "
                "Please provide a full model or use a different loading method."
            )
        raise ValueError(
            "PT file is a dict without a 'model' entry "
            f"(keys: {list(loaded.keys())[:10]}); cannot extract a model."
        )

    def _convert_with_torch_tensorrt(
        self,
        model,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Compile a non-Ultralytics model with torch_tensorrt and store the result.

        Args:
            model: Loaded PyTorch model
            input_shapes: Input tensor shapes (required); one random sample
                input is built per entry, in dict order
            precision: Target precision
            **conversion_kwargs: Additional torch_tensorrt.compile() arguments

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            ValueError: If input_shapes is None
            RuntimeError: If storage cannot provide a local path for the file
        """
        import tempfile

        if input_shapes is None:
            raise ValueError(
                "input_shapes required for non-Ultralytics PyTorch models. "
                "For Ultralytics YOLO models, input_shapes is auto-detected. "
                "Example: input_shapes={'images': (1, 3, 640, 640)}"
            )

        model.eval()

        # Convert model to target precision to avoid mixed precision issues
        if precision == torch.float16:
            model = model.half()
        elif precision == torch.float32:
            model = model.float()
        model = model.to(self.device)

        # Sample inputs for tracing, matching the model's precision.
        inputs = [
            torch.randn(shape, device=self.device, dtype=precision)
            for shape in input_shapes.values()
        ]

        enabled_precisions = {precision}
        if precision == torch.float16:
            enabled_precisions.add(torch.float32)  # Fallback for unsupported ops

        logger.info(f"Compiling to TensorRT (precision: {precision})...")
        trt_model = torch_tensorrt.compile(
            model,
            inputs=inputs,
            enabled_precisions=enabled_precisions,
            **conversion_kwargs
        )

        # Reserve a temp path, closing our handle before torch writes to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.trt') as tmp_file:
            tmp_path = Path(tmp_file.name)

        # The temp file is removed on every path, including failures.
        try:
            torch.jit.save(trt_model, str(tmp_path))
            logger.debug(f"Saved TRT model to temporary file: {tmp_path}")

            trt_hash = self.compute_file_hash(str(tmp_path))
            logger.info(f"TRT file hash: {trt_hash[:16]}...")

            trt_key = f"trt/{trt_hash}.trt"
            self.storage.write(trt_key, tmp_path.read_bytes())

            trt_path = self.storage.get_local_path(trt_key)
            if trt_path is None:
                raise RuntimeError("Failed to get local path for converted TRT file")
        finally:
            tmp_path.unlink(missing_ok=True)

        logger.info(f"TRT model stored successfully at {trt_path}")
        return (trt_hash, trt_path)

    def _perform_conversion(
        self,
        pt_path: str,
        pt_hash: str,
        input_shapes: Optional[Dict[str, Tuple]],
        precision: torch.dtype,
        **conversion_kwargs
    ) -> Tuple[str, str]:
        """
        Perform the actual PT to TRT conversion.

        Routes Ultralytics models to the native .engine exporter and every
        other model to torch_tensorrt.

        Args:
            pt_path: Path to PT file
            pt_hash: PT file hash
            input_shapes: Input tensor shapes
            precision: Target precision
            **conversion_kwargs: Additional torch_tensorrt arguments

        Returns:
            Tuple of (trt_hash, trt_file_path)

        Raises:
            RuntimeError: Wrapping any failure during loading or conversion
        """
        try:
            model = self._load_pt_model(pt_path)

            if self._is_ultralytics_model(model):
                logger.info("Note: input_shapes parameter is ignored for Ultralytics models (auto-detected)")
                return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision)

            logger.info("Using torch_tensorrt for conversion (non-Ultralytics model)")
            return self._convert_with_torch_tensorrt(
                model, input_shapes, precision, **conversion_kwargs
            )
        except Exception as e:
            logger.error(f"Conversion failed: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert PT to TensorRT: {e}") from e

    def clear_cache(self):
        """
        Clear the mapping database.

        Only the PT -> TRT mapping is cleared (and persisted); engine files
        already written to storage are not deleted, but they will no longer be
        found, so subsequent convert() calls reconvert.
        """
        logger.warning("Clearing all cached conversions...")
        self.mapping_db.clear()
        self._save_mapping_db()
        logger.info("Cache cleared")

    def get_stats(self) -> Dict:
        """
        Get conversion statistics.

        Returns:
            Dictionary with cache stats
        """
        return {
            "total_cached_conversions": len(self.mapping_db),
            "storage_path": self.storage.get_storage_path(),
            "gpu_id": self.gpu_id,
            "default_precision": str(self.default_precision),
        }