python-rtsp-worker/services/pt_converter.py

"""
PTConverter - Convert PyTorch models to TensorRT.

This service converts .pt files to TensorRT format, with caching to avoid
redundant conversions: generic PyTorch models are compiled with torch_tensorrt,
while ultralytics YOLO models go through an ONNX -> TensorRT pipeline. A
persistent mapping database links each PT file hash to the hash of its
converted TensorRT engine.
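
Typical usage (an illustrative sketch; the model path and input shape below are
placeholders, not files shipped with this service):

    converter = PTConverter(gpu_id=0)
    trt_hash, trt_path = converter.convert(
        "/models/detector.pt",
        input_shapes={"images": (1, 3, 640, 640)},
    )
    # A second call with an identical .pt file returns the cached engine.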
"""
import hashlib
import json
import logging
from pathlib import Path
from typing import Optional, Dict, Tuple
import torch
import torch_tensorrt
from .modelstorage import IModelStorage, FileModelStorage
logger = logging.getLogger(__name__)
class PTConverter:
"""
    PyTorch to TensorRT converter with intelligent caching.

    Features:
    - Hash-based deduplication: the same PT file is only converted once
    - Persistent mapping database: PT hash -> TRT hash
    - Pluggable storage backend via the IModelStorage interface
    - Automatic cache management

    Architecture:
    1. Compute the SHA256 hash of the input .pt file
    2. Check the mapping database for an existing conversion
    3. If found, return the cached TRT file hash and path
    4. If not found, perform the conversion and store the new mapping
    """
def __init__(
self,
storage: Optional[IModelStorage] = None,
gpu_id: int = 0,
default_precision: torch.dtype = torch.float16
):
"""
Initialize PT converter.
Args:
storage: Storage backend (defaults to FileModelStorage)
gpu_id: GPU device ID for conversion
default_precision: Default precision for TensorRT conversion (fp16, fp32)
"""
self.storage = storage or FileModelStorage()
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.default_precision = default_precision
        # Mapping database: pt_hash -> {"trt_hash": str, "pt_file": str, "input_shapes": str, "precision": str}
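        # Example of one persisted entry in pt_to_trt_mapping.json
        # (hashes shortened and illustrative):
        #   "3a7bd3e2...": {
        #       "trt_hash": "9f86d081...",
        #       "pt_file": "/models/detector.pt",
        #       "input_shapes": "{'images': (1, 3, 640, 640)}",
        #       "precision": "torch.float16"
        #   }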
self.mapping_db_key = "pt_to_trt_mapping.json"
self.mapping_db: Dict[str, Dict] = self._load_mapping_db()
logger.info(f"PTConverter initialized on GPU {gpu_id}")
logger.info(f"Storage backend: {self.storage.__class__.__name__}")
logger.info(f"Storage path: {self.storage.get_storage_path()}")
logger.info(f"Loaded {len(self.mapping_db)} cached conversions")
def _load_mapping_db(self) -> Dict[str, Dict]:
"""
Load mapping database from storage.
Returns:
Mapping dictionary (pt_hash -> metadata)
"""
try:
data = self.storage.read(self.mapping_db_key)
if data:
db = json.loads(data.decode('utf-8'))
logger.debug(f"Loaded mapping database with {len(db)} entries")
return db
else:
logger.debug("No existing mapping database found, starting fresh")
return {}
except Exception as e:
logger.warning(f"Failed to load mapping database: {e}. Starting fresh.")
return {}
def _save_mapping_db(self):
"""Save mapping database to storage"""
try:
data = json.dumps(self.mapping_db, indent=2).encode('utf-8')
self.storage.write(self.mapping_db_key, data)
logger.debug(f"Saved mapping database with {len(self.mapping_db)} entries")
except Exception as e:
logger.error(f"Failed to save mapping database: {e}")
@staticmethod
def compute_file_hash(file_path: str) -> str:
"""
Compute SHA256 hash of a file.
Args:
file_path: Path to file
Returns:
Hexadecimal hash string
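
        Example (illustrative digest; SHA256 hex digests are 64 characters long):
            PTConverter.compute_file_hash("/models/detector.pt")  # -> "3a7bd3e2..."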
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(65536), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
def get_cached_conversion(self, pt_hash: str) -> Optional[Tuple[str, str]]:
"""
Check if PT file has already been converted.
Args:
pt_hash: SHA256 hash of the PT file
Returns:
Tuple of (trt_hash, trt_file_path) if cached, None otherwise
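
        Example (illustrative; a cache hit unpacks into hash and local path):
            cached = converter.get_cached_conversion(pt_hash)
            if cached:
                trt_hash, trt_path = cached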
"""
if pt_hash not in self.mapping_db:
return None
mapping = self.mapping_db[pt_hash]
trt_hash = mapping["trt_hash"]
trt_key = f"trt/{trt_hash}.trt"
# Verify TRT file still exists in storage
if not self.storage.exists(trt_key):
logger.warning(
f"Mapping exists for PT hash {pt_hash[:16]}... but TRT file missing. "
f"Will reconvert."
)
# Remove stale mapping
del self.mapping_db[pt_hash]
self._save_mapping_db()
return None
# Get local path
trt_path = self.storage.get_local_path(trt_key)
if trt_path is None:
logger.error(f"Could not get local path for TRT file {trt_key}")
return None
logger.info(
f"Found cached conversion for PT hash {pt_hash[:16]}... -> "
f"TRT hash {trt_hash[:16]}..."
)
return (trt_hash, trt_path)
def convert(
self,
pt_file_path: str,
input_shapes: Optional[Dict[str, Tuple]] = None,
precision: Optional[torch.dtype] = None,
**conversion_kwargs
) -> Tuple[str, str]:
"""
Convert PyTorch model to TensorRT.
If this PT file has been converted before (same hash), returns cached result.
Otherwise, performs conversion and caches the result.
Args:
pt_file_path: Path to .pt file
input_shapes: Dict of input names to shapes (e.g., {"x": (1, 3, 224, 224)})
precision: Target precision (fp16, fp32) - defaults to self.default_precision
**conversion_kwargs: Additional arguments for torch_tensorrt.compile()
Returns:
Tuple of (trt_hash, trt_file_path)
Raises:
FileNotFoundError: If PT file doesn't exist
RuntimeError: If conversion fails
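
        Example (a sketch; the file path and input name are illustrative):
            converter = PTConverter(default_precision=torch.float16)
            trt_hash, trt_path = converter.convert(
                "/models/resnet50.pt",
                input_shapes={"x": (1, 3, 224, 224)},
            )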
"""
pt_path = Path(pt_file_path).resolve()
if not pt_path.exists():
raise FileNotFoundError(f"PT file not found: {pt_file_path}")
# Compute PT file hash
logger.info(f"Computing hash for {pt_path}...")
pt_hash = self.compute_file_hash(str(pt_path))
logger.info(f"PT file hash: {pt_hash[:16]}...")
# Check cache
cached = self.get_cached_conversion(pt_hash)
if cached:
return cached
# Perform conversion
logger.info(f"Converting {pt_path.name} to TensorRT...")
trt_hash, trt_path = self._perform_conversion(
str(pt_path),
pt_hash,
input_shapes,
precision or self.default_precision,
**conversion_kwargs
)
# Store mapping
self.mapping_db[pt_hash] = {
"trt_hash": trt_hash,
"pt_file": str(pt_path),
"input_shapes": str(input_shapes),
"precision": str(precision or self.default_precision),
}
self._save_mapping_db()
logger.info(f"Conversion complete: PT {pt_hash[:16]}... -> TRT {trt_hash[:16]}...")
return (trt_hash, trt_path)
def _is_ultralytics_model(self, model) -> bool:
"""
Check if model is from ultralytics (YOLO).
Args:
model: PyTorch model
Returns:
True if ultralytics model, False otherwise
"""
model_class_name = model.__class__.__name__
model_module = model.__class__.__module__
# Check if it's an ultralytics model
is_ultralytics = (
'ultralytics' in model_module or
model_class_name in ['DetectionModel', 'SegmentationModel', 'PoseModel', 'ClassificationModel']
)
return is_ultralytics
def _convert_ultralytics_model(
self,
pt_path: str,
pt_hash: str,
input_shapes: Optional[Dict[str, Tuple]],
precision: torch.dtype,
) -> Tuple[str, str]:
"""
Convert ultralytics YOLO model using ONNX → TensorRT pipeline.
Uses the same approach as scripts/convert_pt_to_tensorrt.py
Args:
pt_path: Path to PT file
pt_hash: PT file hash
input_shapes: Input tensor shapes
precision: Target precision
Returns:
Tuple of (trt_hash, trt_file_path)
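
        Example input_shapes for a 640x640 YOLO export (illustrative):
            {"images": (1, 3, 640, 640)}

        Note:
            Only the batch size (index 0) and image size (index 2) are read from
            the shape; the "images" key is preferred if present, otherwise the
            first key in input_shapes is used.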
"""
import tensorrt as trt
import tempfile
import os
import shutil
logger.info("Detected ultralytics YOLO model, using ONNX → TensorRT pipeline...")
# Load ultralytics model
try:
from ultralytics import YOLO
model = YOLO(pt_path)
except ImportError:
raise ImportError("ultralytics package not found. Install with: pip install ultralytics")
# Determine input shape
if not input_shapes:
raise ValueError("input_shapes required for ultralytics conversion")
input_key = 'images' if 'images' in input_shapes else list(input_shapes.keys())[0]
input_shape = input_shapes[input_key]
# Export to ONNX first
logger.info(f"Exporting to ONNX (input shape: {input_shape})...")
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp_onnx:
onnx_path = tmp_onnx.name
try:
# Use ultralytics export to ONNX
model.export(format='onnx', imgsz=input_shape[2], batch=input_shape[0])
# Ultralytics saves as model_name.onnx in same directory
pt_dir = os.path.dirname(pt_path)
pt_name = os.path.splitext(os.path.basename(pt_path))[0]
onnx_export_path = os.path.join(pt_dir, f"{pt_name}.onnx")
# Move to our temp location (use shutil.move for cross-device support)
if os.path.exists(onnx_export_path):
shutil.move(onnx_export_path, onnx_path)
else:
raise RuntimeError(f"ONNX export failed, file not found: {onnx_export_path}")
logger.info(f"ONNX export complete: {onnx_path}")
# Build TensorRT engine from ONNX
logger.info("Building TensorRT engine from ONNX...")
trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, trt_logger)
# Parse ONNX
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
errors = [parser.get_error(i) for i in range(parser.num_errors)]
raise RuntimeError(f"Failed to parse ONNX: {errors}")
# Configure builder
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30) # 4GB
            # Set precision
            if precision == torch.float16:
                if builder.platform_has_fast_fp16:
                    config.set_flag(trt.BuilderFlag.FP16)
                    logger.info("FP16 mode enabled")
                else:
                    logger.warning("Platform does not have fast FP16 support; building engine in FP32")
# Build engine
logger.info("Building TensorRT engine (this may take a few minutes)...")
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine is None:
raise RuntimeError("Failed to build TensorRT engine")
# Convert IHostMemory to bytes
engine_bytes = bytes(serialized_engine)
# Save to storage
trt_hash = hashlib.sha256(engine_bytes).hexdigest()
trt_key = f"trt/{trt_hash}.trt"
self.storage.write(trt_key, engine_bytes)
trt_path = self.storage.get_local_path(trt_key)
if trt_path is None:
raise RuntimeError("Failed to get local path for TRT file")
logger.info(f"TensorRT engine built successfully: {trt_path}")
return (trt_hash, trt_path)
finally:
# Cleanup ONNX file
if os.path.exists(onnx_path):
os.unlink(onnx_path)
def _perform_conversion(
self,
pt_path: str,
pt_hash: str,
input_shapes: Optional[Dict[str, Tuple]],
precision: torch.dtype,
**conversion_kwargs
) -> Tuple[str, str]:
"""
Perform the actual PT to TRT conversion.
Args:
pt_path: Path to PT file
pt_hash: PT file hash
input_shapes: Input tensor shapes
precision: Target precision
**conversion_kwargs: Additional torch_tensorrt arguments
Returns:
Tuple of (trt_hash, trt_file_path)
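
        Note:
            Ultralytics models are routed to _convert_ultralytics_model(); all
            other models are compiled with torch_tensorrt, with FP32 kept as a
            fallback precision when FP16 is requested.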
"""
try:
# Load PyTorch model to check type
logger.debug(f"Loading PyTorch model from {pt_path}...")
# Use weights_only=False for models with custom classes (like ultralytics)
# This is safe for trusted local models
loaded = torch.load(pt_path, map_location='cpu', weights_only=False)
            # If the checkpoint is a dict, extract the wrapped model; otherwise treat it as the model itself
            if isinstance(loaded, dict) and 'model' in loaded:
                model = loaded['model']
            elif isinstance(loaded, dict) and 'state_dict' in loaded:
                raise ValueError(
                    "PT file contains a state_dict only. "
                    "Please provide a full model or use a different loading method."
                )
            else:
                model = loaded
# Check if this is an ultralytics model
if self._is_ultralytics_model(model):
logger.info("Detected ultralytics model, using ultralytics export API")
return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision)
# For non-ultralytics models, use torch_tensorrt
logger.info("Using torch_tensorrt for conversion")
model.eval()
# Convert model to target precision to avoid mixed precision issues
if precision == torch.float16:
model = model.half()
elif precision == torch.float32:
model = model.float()
# Move to GPU
model = model.to(self.device)
# Prepare inputs for tracing
if input_shapes is None:
raise ValueError(
"input_shapes must be provided for TensorRT conversion. "
"Example: {'x': (1, 3, 224, 224)}"
)
# Create sample inputs with matching precision
inputs = []
for name, shape in input_shapes.items():
sample_input = torch.randn(shape, device=self.device, dtype=precision)
inputs.append(sample_input)
# Configure torch_tensorrt
enabled_precisions = {precision}
if precision == torch.float16:
enabled_precisions.add(torch.float32) # Fallback for unsupported ops
# Compile to TensorRT
logger.info(f"Compiling to TensorRT (precision: {precision})...")
trt_model = torch_tensorrt.compile(
model,
inputs=inputs,
enabled_precisions=enabled_precisions,
**conversion_kwargs
)
# Save TRT model to temporary location
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.trt') as tmp_file:
tmp_path = tmp_file.name
torch.jit.save(trt_model, tmp_path)
logger.debug(f"Saved TRT model to temporary file: {tmp_path}")
# Compute TRT file hash
trt_hash = self.compute_file_hash(tmp_path)
logger.info(f"TRT file hash: {trt_hash[:16]}...")
# Store in storage backend
trt_key = f"trt/{trt_hash}.trt"
with open(tmp_path, 'rb') as f:
trt_data = f.read()
self.storage.write(trt_key, trt_data)
# Get local path
trt_path = self.storage.get_local_path(trt_key)
if trt_path is None:
raise RuntimeError("Failed to get local path for converted TRT file")
# Cleanup temp file
Path(tmp_path).unlink()
logger.info(f"TRT model stored successfully at {trt_path}")
return (trt_hash, trt_path)
        except Exception as e:
            logger.error(f"Conversion failed: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert PT to TensorRT: {e}") from e
def clear_cache(self):
"""Clear all cached conversions and mapping database"""
logger.warning("Clearing all cached conversions...")
self.mapping_db.clear()
self._save_mapping_db()
logger.info("Cache cleared")
def get_stats(self) -> Dict:
"""
Get conversion statistics.
Returns:
Dictionary with cache stats
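
        Example return value (illustrative values):
            {
                "total_cached_conversions": 3,
                "storage_path": "/path/to/model/storage",
                "gpu_id": 0,
                "default_precision": "torch.float16",
            }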
"""
return {
"total_cached_conversions": len(self.mapping_db),
"storage_path": self.storage.get_storage_path(),
"gpu_id": self.gpu_id,
"default_precision": str(self.default_precision),
}