converter system

commit 748fb71980 (parent d3dbf9a580)
9 changed files with 1012 additions and 14 deletions
@@ -6,6 +6,9 @@ from queue import Queue
 import torch
 import tensorrt as trt
 from dataclasses import dataclass
+import logging
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
@@ -158,17 +161,19 @@ class TensorRTModelRepository:
         # Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
     """
 
-    def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4):
+    def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True):
         """
         Initialize the model repository.
 
         Args:
             gpu_id: GPU device ID to use
             default_num_contexts: Default number of execution contexts per unique engine
+            enable_pt_conversion: Enable automatic PyTorch to TensorRT conversion
         """
         self.gpu_id = gpu_id
         self.device = torch.device(f'cuda:{gpu_id}')
         self.default_num_contexts = default_num_contexts
+        self.enable_pt_conversion = enable_pt_conversion
 
         # Model ID to engine mapping: model_id -> file_hash
         self._model_to_hash: Dict[str, str] = {}
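
For reference, a minimal usage sketch of the extended constructor; the importable module name is an assumption, not part of the commit:

    from model_repository import TensorRTModelRepository  # assumed module name

    # One repository per GPU; PT conversion is controlled by the new flag
    repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=4,
                                   enable_pt_conversion=True)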
@@ -182,8 +187,22 @@ class TensorRTModelRepository:
         # TensorRT logger
         self.trt_logger = trt.Logger(trt.Logger.WARNING)
 
+        # PT converter (lazy initialization)
+        self._pt_converter = None
+
         print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
         print(f"Default context pool size: {default_num_contexts} contexts per unique model")
+        if enable_pt_conversion:
+            print(f"PyTorch to TensorRT conversion: enabled")
 
+    @property
+    def pt_converter(self):
+        """Lazy initialization of PT converter"""
+        if self._pt_converter is None and self.enable_pt_conversion:
+            from .pt_converter import PTConverter
+            self._pt_converter = PTConverter(gpu_id=self.gpu_id)
+            logger.info("PT converter initialized")
+        return self._pt_converter
+
     @staticmethod
     def compute_file_hash(file_path: str) -> str:
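
The property above defers both the .pt_converter import and converter construction until first access, so deployments that only load .trt engines never pay the torch_tensorrt import cost. A small sketch of the resulting behavior (PTConverter internals assumed):

    repo = TensorRTModelRepository(enable_pt_conversion=True)
    assert repo._pt_converter is None       # nothing built yet
    converter = repo.pt_converter           # first access: lazy import + construction
    assert converter is repo.pt_converter   # later accesses reuse the cached instance

One caveat: with enable_pt_conversion=False the property returns None rather than raising, so callers are expected to go through load_model(), which raises a ValueError for .pt files in that case.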
@@ -282,18 +301,26 @@ class TensorRTModelRepository:
 
     def load_model(self, model_id: str, file_path: str,
                    num_contexts: Optional[int] = None,
-                   force_reload: bool = False) -> ModelMetadata:
+                   force_reload: bool = False,
+                   pt_input_shapes: Optional[Dict[str, Tuple]] = None,
+                   pt_precision: Optional[torch.dtype] = None,
+                   **pt_conversion_kwargs) -> ModelMetadata:
         """
         Load a TensorRT model with the given ID.
 
+        Supports both .trt and .pt files. PT files are automatically converted to TensorRT.
+
         Deduplication: If a model with the same file hash is already loaded, the model_id
         is simply mapped to the existing SharedEngine (no new engine or contexts created).
 
         Args:
             model_id: User-defined identifier for this model (e.g., "camera_1")
-            file_path: Path to TensorRT engine file (.trt or .engine)
+            file_path: Path to TensorRT engine file (.trt, .engine) or PyTorch file (.pt, .pth)
             num_contexts: Number of execution contexts in pool (None = use default)
             force_reload: If True, reload even if model_id exists
+            pt_input_shapes: Required for .pt files - dict of input shapes (e.g., {"x": (1, 3, 224, 224)})
+            pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
+            **pt_conversion_kwargs: Additional arguments for torch_tensorrt.compile()
 
         Returns:
             ModelMetadata for the loaded model
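
A sketch of how the extended signature is meant to be called, following the docstring; paths and shapes are illustrative, not from the commit:

    import torch

    # .pt checkpoint: conversion kicks in, so input shapes are required
    metadata = repo.load_model(
        model_id="camera_1",
        file_path="models/detector.pt",            # hypothetical path
        pt_input_shapes={"x": (1, 3, 224, 224)},
        pt_precision=torch.float16,
    )

    # Prebuilt engine: behaves exactly as before
    metadata = repo.load_model("camera_2", "models/detector.trt")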
@@ -301,13 +328,37 @@ class TensorRTModelRepository:
         Raises:
             FileNotFoundError: If model file doesn't exist
             RuntimeError: If engine loading fails
-            ValueError: If model_id already exists and force_reload is False
+            ValueError: If model_id already exists and force_reload is False, or PT conversion requires input_shapes
         """
         file_path = str(Path(file_path).resolve())
 
         if not Path(file_path).exists():
             raise FileNotFoundError(f"Model file not found: {file_path}")
 
+        # Check if file is PyTorch model
+        file_ext = Path(file_path).suffix.lower()
+        if file_ext in ['.pt', '.pth']:
+            if not self.enable_pt_conversion:
+                raise ValueError(
+                    f"PT file provided but PT conversion is disabled. "
+                    f"Enable with enable_pt_conversion=True or provide a .trt file."
+                )
+
+            logger.info(f"Detected PyTorch model file: {file_path}")
+            logger.info("Converting to TensorRT...")
+
+            # Convert PT to TRT
+            trt_hash, trt_path = self.pt_converter.convert(
+                file_path,
+                input_shapes=pt_input_shapes,
+                precision=pt_precision,
+                **pt_conversion_kwargs
+            )
+
+            # Update file_path to use converted TRT file
+            file_path = trt_path
+            logger.info(f"Will load converted TensorRT model from: {file_path}")
+
         if num_contexts is None:
             num_contexts = self.default_num_contexts
 
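
pt_converter.py itself is not shown in these hunks. Going only by the docstring's pointer to torch_tensorrt.compile() and the (trt_hash, trt_path) pair consumed above, a plausible sketch of convert() might look like the following; the hashing scheme, cache directory, and TorchScript route are all assumptions, not the committed implementation:

    import hashlib
    from pathlib import Path
    from typing import Dict, Optional, Tuple

    import torch
    import torch_tensorrt

    def convert(pt_path: str,
                input_shapes: Dict[str, Tuple],
                precision: Optional[torch.dtype] = None,
                **kwargs) -> Tuple[str, str]:
        # Key the cache on the source file's content hash (assumed scheme)
        pt_hash = hashlib.sha256(Path(pt_path).read_bytes()).hexdigest()
        trt_path = Path("converted_models") / f"{pt_hash}.trt"  # hypothetical cache dir
        trt_path.parent.mkdir(exist_ok=True)

        if not trt_path.exists():
            model = torch.load(pt_path).eval().cuda()
            example_inputs = [torch.randn(s).cuda() for s in input_shapes.values()]
            scripted = torch.jit.trace(model, example_inputs)
            # Serialize a raw TensorRT engine so the repository's TRT runtime can load it
            engine_bytes = torch_tensorrt.ts.convert_method_to_trt_engine(
                scripted, "forward",
                inputs=[torch_tensorrt.Input(s) for s in input_shapes.values()],
                enabled_precisions={precision or torch.float16},
                **kwargs)
            trt_path.write_bytes(engine_bytes)

        return pt_hash, str(trt_path)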