converter system

This commit is contained in:
Siwat Sirichai 2025-11-09 19:54:35 +07:00
parent d3dbf9a580
commit 748fb71980
9 changed files with 1012 additions and 14 deletions

View file

@ -6,6 +6,9 @@ from queue import Queue
import torch
import tensorrt as trt
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
@ -158,17 +161,19 @@ class TensorRTModelRepository:
# Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
"""
def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4):
def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True):
"""
Initialize the model repository.
Args:
gpu_id: GPU device ID to use
default_num_contexts: Default number of execution contexts per unique engine
enable_pt_conversion: Enable automatic PyTorch to TensorRT conversion
"""
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.default_num_contexts = default_num_contexts
self.enable_pt_conversion = enable_pt_conversion
# Model ID to engine mapping: model_id -> file_hash
self._model_to_hash: Dict[str, str] = {}
@ -182,8 +187,22 @@ class TensorRTModelRepository:
# TensorRT logger
self.trt_logger = trt.Logger(trt.Logger.WARNING)
# PT converter (lazy initialization)
self._pt_converter = None
print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
print(f"Default context pool size: {default_num_contexts} contexts per unique model")
if enable_pt_conversion:
print(f"PyTorch to TensorRT conversion: enabled")
@property
def pt_converter(self):
"""Lazy initialization of PT converter"""
if self._pt_converter is None and self.enable_pt_conversion:
from .pt_converter import PTConverter
self._pt_converter = PTConverter(gpu_id=self.gpu_id)
logger.info("PT converter initialized")
return self._pt_converter
@staticmethod
def compute_file_hash(file_path: str) -> str:
@ -282,18 +301,26 @@ class TensorRTModelRepository:
def load_model(self, model_id: str, file_path: str,
num_contexts: Optional[int] = None,
force_reload: bool = False) -> ModelMetadata:
force_reload: bool = False,
pt_input_shapes: Optional[Dict[str, Tuple]] = None,
pt_precision: Optional[torch.dtype] = None,
**pt_conversion_kwargs) -> ModelMetadata:
"""
Load a TensorRT model with the given ID.
Supports both .trt and .pt files. PT files are automatically converted to TensorRT.
Deduplication: If a model with the same file hash is already loaded, the model_id
is simply mapped to the existing SharedEngine (no new engine or contexts created).
Args:
model_id: User-defined identifier for this model (e.g., "camera_1")
file_path: Path to TensorRT engine file (.trt or .engine)
file_path: Path to TensorRT engine file (.trt, .engine) or PyTorch file (.pt, .pth)
num_contexts: Number of execution contexts in pool (None = use default)
force_reload: If True, reload even if model_id exists
pt_input_shapes: Required for .pt files - dict of input shapes (e.g., {"x": (1, 3, 224, 224)})
pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
**pt_conversion_kwargs: Additional arguments for torch_tensorrt.compile()
Returns:
ModelMetadata for the loaded model
@ -301,13 +328,37 @@ class TensorRTModelRepository:
Raises:
FileNotFoundError: If model file doesn't exist
RuntimeError: If engine loading fails
ValueError: If model_id already exists and force_reload is False
ValueError: If model_id already exists and force_reload is False, or PT conversion requires input_shapes
"""
file_path = str(Path(file_path).resolve())
if not Path(file_path).exists():
raise FileNotFoundError(f"Model file not found: {file_path}")
# Check if file is PyTorch model
file_ext = Path(file_path).suffix.lower()
if file_ext in ['.pt', '.pth']:
if not self.enable_pt_conversion:
raise ValueError(
f"PT file provided but PT conversion is disabled. "
f"Enable with enable_pt_conversion=True or provide a .trt file."
)
logger.info(f"Detected PyTorch model file: {file_path}")
logger.info("Converting to TensorRT...")
# Convert PT to TRT
trt_hash, trt_path = self.pt_converter.convert(
file_path,
input_shapes=pt_input_shapes,
precision=pt_precision,
**pt_conversion_kwargs
)
# Update file_path to use converted TRT file
file_path = trt_path
logger.info(f"Will load converted TensorRT model from: {file_path}")
if num_contexts is None:
num_contexts = self.default_num_contexts