converter system

Siwat Sirichai 2025-11-09 19:54:35 +07:00
parent d3dbf9a580
commit 748fb71980
9 changed files with 1012 additions and 14 deletions

@@ -10,6 +10,8 @@ from .tracking_factory import TrackingFactory
from .yolo import YOLOv8Utils, COCO_CLASSES
from .model_controller import ModelController, BatchFrame, BufferState
from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult
from .pt_converter import PTConverter
from .modelstorage import IModelStorage, FileModelStorage
__all__ = [
'StreamDecoderFactory',
@@ -32,4 +34,7 @@ __all__ = [
'StreamConnectionManager',
'StreamConnection',
'TrackingResult',
'PTConverter',
'IModelStorage',
'FileModelStorage',
]
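
As a quick orientation, here is a minimal sketch of how the pieces exported above can be wired together. It is illustrative only and not part of the commit; the package name "services" is inferred from the file paths in this commit, and the constructor arguments are the ones defined in the diffs below.

# Hedged usage sketch, not part of the commit.
from services import PTConverter, FileModelStorage  # package name assumed

# Storage backend defaults to ./models/trtptcache when not overridden
storage = FileModelStorage(base_path="./models/trtptcache")
converter = PTConverter(storage=storage, gpu_id=0)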

@@ -6,6 +6,9 @@ from queue import Queue
import torch
import tensorrt as trt
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
@@ -158,17 +161,19 @@ class TensorRTModelRepository:
# Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
"""
-def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4):
+def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True):
"""
Initialize the model repository.
Args:
gpu_id: GPU device ID to use
default_num_contexts: Default number of execution contexts per unique engine
enable_pt_conversion: Enable automatic PyTorch to TensorRT conversion
"""
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.default_num_contexts = default_num_contexts
self.enable_pt_conversion = enable_pt_conversion
# Model ID to engine mapping: model_id -> file_hash
self._model_to_hash: Dict[str, str] = {}
@@ -182,8 +187,22 @@ class TensorRTModelRepository:
# TensorRT logger
self.trt_logger = trt.Logger(trt.Logger.WARNING)
# PT converter (lazy initialization)
self._pt_converter = None
print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
print(f"Default context pool size: {default_num_contexts} contexts per unique model")
if enable_pt_conversion:
print(f"PyTorch to TensorRT conversion: enabled")
@property
def pt_converter(self):
"""Lazy initialization of PT converter"""
if self._pt_converter is None and self.enable_pt_conversion:
from .pt_converter import PTConverter
self._pt_converter = PTConverter(gpu_id=self.gpu_id)
logger.info("PT converter initialized")
return self._pt_converter
@staticmethod
def compute_file_hash(file_path: str) -> str:
@@ -282,18 +301,26 @@ class TensorRTModelRepository:
def load_model(self, model_id: str, file_path: str,
num_contexts: Optional[int] = None,
-force_reload: bool = False) -> ModelMetadata:
+force_reload: bool = False,
+pt_input_shapes: Optional[Dict[str, Tuple]] = None,
+pt_precision: Optional[torch.dtype] = None,
+**pt_conversion_kwargs) -> ModelMetadata:
"""
Load a TensorRT model with the given ID.
Supports both .trt and .pt files. PT files are automatically converted to TensorRT.
Deduplication: If a model with the same file hash is already loaded, the model_id
is simply mapped to the existing SharedEngine (no new engine or contexts created).
Args:
model_id: User-defined identifier for this model (e.g., "camera_1")
-file_path: Path to TensorRT engine file (.trt or .engine)
+file_path: Path to TensorRT engine file (.trt, .engine) or PyTorch file (.pt, .pth)
num_contexts: Number of execution contexts in pool (None = use default)
force_reload: If True, reload even if model_id exists
pt_input_shapes: Required for .pt files - dict of input shapes (e.g., {"x": (1, 3, 224, 224)})
pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
**pt_conversion_kwargs: Additional arguments for torch_tensorrt.compile()
Returns:
ModelMetadata for the loaded model
@@ -301,13 +328,37 @@ class TensorRTModelRepository:
Raises:
FileNotFoundError: If model file doesn't exist
RuntimeError: If engine loading fails
-ValueError: If model_id already exists and force_reload is False
+ValueError: If model_id already exists and force_reload is False, or if a .pt file is given without pt_input_shapes
"""
file_path = str(Path(file_path).resolve())
if not Path(file_path).exists():
raise FileNotFoundError(f"Model file not found: {file_path}")
# Check if file is PyTorch model
file_ext = Path(file_path).suffix.lower()
if file_ext in ['.pt', '.pth']:
if not self.enable_pt_conversion:
raise ValueError(
f"PT file provided but PT conversion is disabled. "
f"Enable with enable_pt_conversion=True or provide a .trt file."
)
logger.info(f"Detected PyTorch model file: {file_path}")
logger.info("Converting to TensorRT...")
# Convert PT to TRT
trt_hash, trt_path = self.pt_converter.convert(
file_path,
input_shapes=pt_input_shapes,
precision=pt_precision,
**pt_conversion_kwargs
)
# Update file_path to use converted TRT file
file_path = trt_path
logger.info(f"Will load converted TensorRT model from: {file_path}")
if num_contexts is None:
num_contexts = self.default_num_contexts
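
For illustration, a hedged sketch of the new .pt load path from the caller's side. Only the parameters shown in this diff are taken from the code; the import path and the model file name are assumptions.

# Hedged usage sketch, not part of the commit.
import torch
from services.model_repository import TensorRTModelRepository  # assumed module path

repo = TensorRTModelRepository(gpu_id=0, enable_pt_conversion=True)

# .pt files require pt_input_shapes so the converter can export/trace the model;
# the first load converts to TensorRT, later loads reuse the cached engine.
metadata = repo.load_model(
    model_id="camera_1",
    file_path="models/yolov8n.pt",                # hypothetical model file
    pt_input_shapes={"images": (1, 3, 640, 640)},
    pt_precision=torch.float16,
)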

@@ -0,0 +1,8 @@
"""
Model storage module for managing TensorRT and PyTorch model files.
"""
from .interface import IModelStorage
from .file_storage import FileModelStorage
__all__ = ['IModelStorage', 'FileModelStorage']

@@ -0,0 +1,161 @@
"""
FileModelStorage - Local filesystem implementation of IModelStorage.
Stores model files in a local directory structure.
"""
import os
from pathlib import Path
from typing import Optional
import logging
from .interface import IModelStorage
logger = logging.getLogger(__name__)
class FileModelStorage(IModelStorage):
"""
Local filesystem storage for model files.
Stores files in a directory structure:
./models/trtptcache/
trt/
<hash1>.trt
<hash2>.trt
pt/
<hash3>.pt
"""
def __init__(self, base_path: str = "./models/trtptcache"):
"""
Initialize file storage.
Args:
base_path: Base directory for storing files (default: ./models/trtptcache)
"""
self.base_path = Path(base_path).resolve()
self._ensure_directories()
def _ensure_directories(self):
"""Create base directory structure if it doesn't exist"""
self.base_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Model storage initialized at: {self.base_path}")
def _get_full_path(self, key: str) -> Path:
"""
Get full filesystem path for a key.
Args:
key: Storage key (e.g., "trt/hash123.trt")
Returns:
Full filesystem path
"""
return self.base_path / key
def write(self, key: str, data: bytes) -> None:
"""
Write data to filesystem.
Args:
key: Storage key (e.g., "trt/hash123.trt")
data: Binary data to write
Raises:
IOError: If write operation fails
"""
file_path = self._get_full_path(key)
# Ensure parent directory exists
file_path.parent.mkdir(parents=True, exist_ok=True)
try:
with open(file_path, 'wb') as f:
f.write(data)
logger.debug(f"Wrote {len(data)} bytes to {file_path}")
except Exception as e:
raise IOError(f"Failed to write to {file_path}: {e}")
def read(self, key: str) -> Optional[bytes]:
"""
Read data from filesystem.
Args:
key: Storage key
Returns:
Binary data if found, None otherwise
Raises:
IOError: If read operation fails
"""
file_path = self._get_full_path(key)
if not file_path.exists():
return None
try:
with open(file_path, 'rb') as f:
data = f.read()
logger.debug(f"Read {len(data)} bytes from {file_path}")
return data
except Exception as e:
raise IOError(f"Failed to read from {file_path}: {e}")
def exists(self, key: str) -> bool:
"""
Check if file exists.
Args:
key: Storage key
Returns:
True if file exists, False otherwise
"""
return self._get_full_path(key).exists()
def delete(self, key: str) -> bool:
"""
Delete file from filesystem.
Args:
key: Storage key
Returns:
True if deleted successfully, False if file didn't exist
"""
file_path = self._get_full_path(key)
if not file_path.exists():
return False
try:
file_path.unlink()
logger.debug(f"Deleted {file_path}")
return True
except Exception as e:
logger.error(f"Failed to delete {file_path}: {e}")
return False
def get_local_path(self, key: str) -> Optional[str]:
"""
Get local filesystem path for a key.
Args:
key: Storage key
Returns:
Local path if file exists, None otherwise
"""
file_path = self._get_full_path(key)
return str(file_path) if file_path.exists() else None
def get_storage_path(self) -> str:
"""
Get the base storage path.
Returns:
Base path where files are stored
"""
return str(self.base_path)
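
A short, hedged round trip showing how FileModelStorage is used; the import path is an assumption, and the key follows the trt/<hash>.trt convention used above.

# Hedged usage sketch, not part of the commit.
from services.modelstorage import FileModelStorage  # assumed import path

storage = FileModelStorage(base_path="./models/trtptcache")

storage.write("trt/deadbeef.trt", b"engine bytes")    # parent dir created on demand
assert storage.exists("trt/deadbeef.trt")
data = storage.read("trt/deadbeef.trt")               # -> b"engine bytes"
path = storage.get_local_path("trt/deadbeef.trt")     # absolute path under base_path
storage.delete("trt/deadbeef.trt")                    # -> True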

@@ -0,0 +1,91 @@
"""
IModelStorage - Interface for model file storage.
Defines the contract for storing and retrieving model files (TensorRT, PyTorch, etc.)
"""
from abc import ABC, abstractmethod
from typing import Optional, BinaryIO
from pathlib import Path
import io
class IModelStorage(ABC):
"""
Interface for model file storage.
This abstraction allows swapping storage backends (local filesystem, S3, etc.)
without changing the model conversion and loading logic.
"""
@abstractmethod
def write(self, key: str, data: bytes) -> None:
"""
Write data to storage with the given key.
Args:
key: Storage key (e.g., "trt/hash123.trt" or "pt/hash456.pt")
data: Binary data to write
Raises:
IOError: If write operation fails
"""
pass
@abstractmethod
def read(self, key: str) -> Optional[bytes]:
"""
Read data from storage by key.
Args:
key: Storage key
Returns:
Binary data if found, None otherwise
Raises:
IOError: If read operation fails
"""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""
Check if a key exists in storage.
Args:
key: Storage key
Returns:
True if key exists, False otherwise
"""
pass
@abstractmethod
def delete(self, key: str) -> bool:
"""
Delete data from storage.
Args:
key: Storage key
Returns:
True if deleted successfully, False if key didn't exist
"""
pass
@abstractmethod
def get_local_path(self, key: str) -> Optional[str]:
"""
Get a local filesystem path for the key.
For local storage, this returns the direct path.
For remote storage (S3), this may download to a temp location or return None.
Args:
key: Storage key
Returns:
Local path if available/downloaded, None if not supported
"""
pass
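
To make the pluggable-backend idea concrete, here is a hedged sketch of a minimal in-memory implementation of IModelStorage, e.g. for tests. It is not part of the commit; get_storage_path is included only because PTConverter calls it on its storage backend.

# Hedged sketch, not part of the commit.
from typing import Dict, Optional
from services.modelstorage import IModelStorage  # assumed import path

class InMemoryModelStorage(IModelStorage):
    """Illustrative in-memory backend; nothing is persisted."""

    def __init__(self):
        self._blobs: Dict[str, bytes] = {}

    def write(self, key: str, data: bytes) -> None:
        self._blobs[key] = data

    def read(self, key: str) -> Optional[bytes]:
        return self._blobs.get(key)

    def exists(self, key: str) -> bool:
        return key in self._blobs

    def delete(self, key: str) -> bool:
        return self._blobs.pop(key, None) is not None

    def get_local_path(self, key: str) -> Optional[str]:
        return None  # no filesystem backing; callers must fall back to read()

    def get_storage_path(self) -> str:
        return "<in-memory>"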

services/pt_converter.py (new file, 485 additions)

@@ -0,0 +1,485 @@
"""
PTConverter - Convert PyTorch models to TensorRT using torch_tensorrt.
This service handles conversion of .pt files to TensorRT format with caching
to avoid redundant conversions. It maintains a mapping database between
PT file hashes and their converted TensorRT file hashes.
"""
import hashlib
import json
import logging
from pathlib import Path
from typing import Optional, Dict, Tuple
import torch
import torch_tensorrt
from .modelstorage import IModelStorage, FileModelStorage
logger = logging.getLogger(__name__)
class PTConverter:
"""
PyTorch to TensorRT converter with intelligent caching.
Features:
- Hash-based deduplication: Same PT file only converted once
- Persistent mapping database: PT hash -> TRT hash mapping
- Pluggable storage backend: IModelStorage interface
- Automatic cache management
Architecture:
1. Compute hash of input .pt file
2. Check mapping database for existing conversion
3. If found, return cached TRT file hash and path
4. If not found, perform conversion and store mapping
"""
def __init__(
self,
storage: Optional[IModelStorage] = None,
gpu_id: int = 0,
default_precision: torch.dtype = torch.float16
):
"""
Initialize PT converter.
Args:
storage: Storage backend (defaults to FileModelStorage)
gpu_id: GPU device ID for conversion
default_precision: Default precision for TensorRT conversion (fp16, fp32)
"""
self.storage = storage or FileModelStorage()
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.default_precision = default_precision
# Mapping database: pt_hash -> {"trt_hash": str, "metadata": {...}}
self.mapping_db_key = "pt_to_trt_mapping.json"
self.mapping_db: Dict[str, Dict] = self._load_mapping_db()
logger.info(f"PTConverter initialized on GPU {gpu_id}")
logger.info(f"Storage backend: {self.storage.__class__.__name__}")
logger.info(f"Storage path: {self.storage.get_storage_path()}")
logger.info(f"Loaded {len(self.mapping_db)} cached conversions")
def _load_mapping_db(self) -> Dict[str, Dict]:
"""
Load mapping database from storage.
Returns:
Mapping dictionary (pt_hash -> metadata)
"""
try:
data = self.storage.read(self.mapping_db_key)
if data:
db = json.loads(data.decode('utf-8'))
logger.debug(f"Loaded mapping database with {len(db)} entries")
return db
else:
logger.debug("No existing mapping database found, starting fresh")
return {}
except Exception as e:
logger.warning(f"Failed to load mapping database: {e}. Starting fresh.")
return {}
def _save_mapping_db(self):
"""Save mapping database to storage"""
try:
data = json.dumps(self.mapping_db, indent=2).encode('utf-8')
self.storage.write(self.mapping_db_key, data)
logger.debug(f"Saved mapping database with {len(self.mapping_db)} entries")
except Exception as e:
logger.error(f"Failed to save mapping database: {e}")
@staticmethod
def compute_file_hash(file_path: str) -> str:
"""
Compute SHA256 hash of a file.
Args:
file_path: Path to file
Returns:
Hexadecimal hash string
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(65536), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
def get_cached_conversion(self, pt_hash: str) -> Optional[Tuple[str, str]]:
"""
Check if PT file has already been converted.
Args:
pt_hash: SHA256 hash of the PT file
Returns:
Tuple of (trt_hash, trt_file_path) if cached, None otherwise
"""
if pt_hash not in self.mapping_db:
return None
mapping = self.mapping_db[pt_hash]
trt_hash = mapping["trt_hash"]
trt_key = f"trt/{trt_hash}.trt"
# Verify TRT file still exists in storage
if not self.storage.exists(trt_key):
logger.warning(
f"Mapping exists for PT hash {pt_hash[:16]}... but TRT file missing. "
f"Will reconvert."
)
# Remove stale mapping
del self.mapping_db[pt_hash]
self._save_mapping_db()
return None
# Get local path
trt_path = self.storage.get_local_path(trt_key)
if trt_path is None:
logger.error(f"Could not get local path for TRT file {trt_key}")
return None
logger.info(
f"Found cached conversion for PT hash {pt_hash[:16]}... -> "
f"TRT hash {trt_hash[:16]}..."
)
return (trt_hash, trt_path)
def convert(
self,
pt_file_path: str,
input_shapes: Optional[Dict[str, Tuple]] = None,
precision: Optional[torch.dtype] = None,
**conversion_kwargs
) -> Tuple[str, str]:
"""
Convert PyTorch model to TensorRT.
If this PT file has been converted before (same hash), returns cached result.
Otherwise, performs conversion and caches the result.
Args:
pt_file_path: Path to .pt file
input_shapes: Dict of input names to shapes (e.g., {"x": (1, 3, 224, 224)})
precision: Target precision (fp16, fp32) - defaults to self.default_precision
**conversion_kwargs: Additional arguments for torch_tensorrt.compile()
Returns:
Tuple of (trt_hash, trt_file_path)
Raises:
FileNotFoundError: If PT file doesn't exist
RuntimeError: If conversion fails
"""
pt_path = Path(pt_file_path).resolve()
if not pt_path.exists():
raise FileNotFoundError(f"PT file not found: {pt_file_path}")
# Compute PT file hash
logger.info(f"Computing hash for {pt_path}...")
pt_hash = self.compute_file_hash(str(pt_path))
logger.info(f"PT file hash: {pt_hash[:16]}...")
# Check cache
cached = self.get_cached_conversion(pt_hash)
if cached:
return cached
# Perform conversion
logger.info(f"Converting {pt_path.name} to TensorRT...")
trt_hash, trt_path = self._perform_conversion(
str(pt_path),
pt_hash,
input_shapes,
precision or self.default_precision,
**conversion_kwargs
)
# Store mapping
self.mapping_db[pt_hash] = {
"trt_hash": trt_hash,
"pt_file": str(pt_path),
"input_shapes": str(input_shapes),
"precision": str(precision or self.default_precision),
}
self._save_mapping_db()
logger.info(f"Conversion complete: PT {pt_hash[:16]}... -> TRT {trt_hash[:16]}...")
return (trt_hash, trt_path)
def _is_ultralytics_model(self, model) -> bool:
"""
Check if model is from ultralytics (YOLO).
Args:
model: PyTorch model
Returns:
True if ultralytics model, False otherwise
"""
model_class_name = model.__class__.__name__
model_module = model.__class__.__module__
# Check if it's an ultralytics model
is_ultralytics = (
'ultralytics' in model_module or
model_class_name in ['DetectionModel', 'SegmentationModel', 'PoseModel', 'ClassificationModel']
)
return is_ultralytics
def _convert_ultralytics_model(
self,
pt_path: str,
pt_hash: str,
input_shapes: Optional[Dict[str, Tuple]],
precision: torch.dtype,
) -> Tuple[str, str]:
"""
Convert ultralytics YOLO model using ONNX TensorRT pipeline.
Uses the same approach as scripts/convert_pt_to_tensorrt.py
Args:
pt_path: Path to PT file
pt_hash: PT file hash
input_shapes: Input tensor shapes
precision: Target precision
Returns:
Tuple of (trt_hash, trt_file_path)
"""
import tensorrt as trt
import tempfile
import os
import shutil
logger.info("Detected ultralytics YOLO model, using ONNX → TensorRT pipeline...")
# Load ultralytics model
try:
from ultralytics import YOLO
model = YOLO(pt_path)
except ImportError:
raise ImportError("ultralytics package not found. Install with: pip install ultralytics")
# Determine input shape
if not input_shapes:
raise ValueError("input_shapes required for ultralytics conversion")
input_key = 'images' if 'images' in input_shapes else list(input_shapes.keys())[0]
input_shape = input_shapes[input_key]
# Export to ONNX first
logger.info(f"Exporting to ONNX (input shape: {input_shape})...")
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp_onnx:
onnx_path = tmp_onnx.name
try:
# Use ultralytics export to ONNX
model.export(format='onnx', imgsz=input_shape[2], batch=input_shape[0])
# Ultralytics saves as model_name.onnx in same directory
pt_dir = os.path.dirname(pt_path)
pt_name = os.path.splitext(os.path.basename(pt_path))[0]
onnx_export_path = os.path.join(pt_dir, f"{pt_name}.onnx")
# Move to our temp location (use shutil.move for cross-device support)
if os.path.exists(onnx_export_path):
shutil.move(onnx_export_path, onnx_path)
else:
raise RuntimeError(f"ONNX export failed, file not found: {onnx_export_path}")
logger.info(f"ONNX export complete: {onnx_path}")
# Build TensorRT engine from ONNX
logger.info("Building TensorRT engine from ONNX...")
trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, trt_logger)
# Parse ONNX
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
errors = [parser.get_error(i) for i in range(parser.num_errors)]
raise RuntimeError(f"Failed to parse ONNX: {errors}")
# Configure builder
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30) # 4GB
# Set precision
if precision == torch.float16:
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
logger.info("FP16 mode enabled")
# Build engine
logger.info("Building TensorRT engine (this may take a few minutes)...")
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine is None:
raise RuntimeError("Failed to build TensorRT engine")
# Convert IHostMemory to bytes
engine_bytes = bytes(serialized_engine)
# Save to storage
trt_hash = hashlib.sha256(engine_bytes).hexdigest()
trt_key = f"trt/{trt_hash}.trt"
self.storage.write(trt_key, engine_bytes)
trt_path = self.storage.get_local_path(trt_key)
if trt_path is None:
raise RuntimeError("Failed to get local path for TRT file")
logger.info(f"TensorRT engine built successfully: {trt_path}")
return (trt_hash, trt_path)
finally:
# Cleanup ONNX file
if os.path.exists(onnx_path):
os.unlink(onnx_path)
def _perform_conversion(
self,
pt_path: str,
pt_hash: str,
input_shapes: Optional[Dict[str, Tuple]],
precision: torch.dtype,
**conversion_kwargs
) -> Tuple[str, str]:
"""
Perform the actual PT to TRT conversion.
Args:
pt_path: Path to PT file
pt_hash: PT file hash
input_shapes: Input tensor shapes
precision: Target precision
**conversion_kwargs: Additional torch_tensorrt arguments
Returns:
Tuple of (trt_hash, trt_file_path)
"""
try:
# Load PyTorch model to check type
logger.debug(f"Loading PyTorch model from {pt_path}...")
# Use weights_only=False for models with custom classes (like ultralytics)
# This is safe for trusted local models
loaded = torch.load(pt_path, map_location='cpu', weights_only=False)
# If model is wrapped in a dict, extract the model
if isinstance(loaded, dict):
if 'model' in loaded:
model = loaded['model']
elif 'state_dict' in loaded:
raise ValueError(
"PT file contains state_dict only. "
"Please provide a full model or use a different loading method."
)
else:
model = loaded
# Check if this is an ultralytics model
if self._is_ultralytics_model(model):
logger.info("Detected ultralytics model, using ultralytics export API")
return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision)
# For non-ultralytics models, use torch_tensorrt
logger.info("Using torch_tensorrt for conversion")
model.eval()
# Convert model to target precision to avoid mixed precision issues
if precision == torch.float16:
model = model.half()
elif precision == torch.float32:
model = model.float()
# Move to GPU
model = model.to(self.device)
# Prepare inputs for tracing
if input_shapes is None:
raise ValueError(
"input_shapes must be provided for TensorRT conversion. "
"Example: {'x': (1, 3, 224, 224)}"
)
# Create sample inputs with matching precision
inputs = []
for name, shape in input_shapes.items():
sample_input = torch.randn(shape, device=self.device, dtype=precision)
inputs.append(sample_input)
# Configure torch_tensorrt
enabled_precisions = {precision}
if precision == torch.float16:
enabled_precisions.add(torch.float32) # Fallback for unsupported ops
# Compile to TensorRT
logger.info(f"Compiling to TensorRT (precision: {precision})...")
trt_model = torch_tensorrt.compile(
model,
inputs=inputs,
enabled_precisions=enabled_precisions,
**conversion_kwargs
)
# Save TRT model to temporary location
import tempfile
with tempfile.NamedTemporaryFile(delete=False, suffix='.trt') as tmp_file:
tmp_path = tmp_file.name
torch.jit.save(trt_model, tmp_path)
logger.debug(f"Saved TRT model to temporary file: {tmp_path}")
# Compute TRT file hash
trt_hash = self.compute_file_hash(tmp_path)
logger.info(f"TRT file hash: {trt_hash[:16]}...")
# Store in storage backend
trt_key = f"trt/{trt_hash}.trt"
with open(tmp_path, 'rb') as f:
trt_data = f.read()
self.storage.write(trt_key, trt_data)
# Get local path
trt_path = self.storage.get_local_path(trt_key)
if trt_path is None:
raise RuntimeError("Failed to get local path for converted TRT file")
# Cleanup temp file
Path(tmp_path).unlink()
logger.info(f"TRT model stored successfully at {trt_path}")
return (trt_hash, trt_path)
except Exception as e:
logger.error(f"Conversion failed: {e}", exc_info=True)
raise RuntimeError(f"Failed to convert PT to TensorRT: {e}")
def clear_cache(self):
"""Clear all cached conversions and mapping database"""
logger.warning("Clearing all cached conversions...")
self.mapping_db.clear()
self._save_mapping_db()
logger.info("Cache cleared")
def get_stats(self) -> Dict:
"""
Get conversion statistics.
Returns:
Dictionary with cache stats
"""
return {
"total_cached_conversions": len(self.mapping_db),
"storage_path": self.storage.get_storage_path(),
"gpu_id": self.gpu_id,
"default_precision": str(self.default_precision),
}
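
Finally, a hedged sketch of using PTConverter on its own, without going through the model repository; the file name and input shapes are illustrative.

# Hedged usage sketch, not part of the commit.
import torch
from services.pt_converter import PTConverter

converter = PTConverter(gpu_id=0, default_precision=torch.float16)

# First call converts and caches; repeat calls with the same .pt hash hit the cache.
trt_hash, trt_path = converter.convert(
    "models/yolov8n.pt",                         # hypothetical model file
    input_shapes={"images": (1, 3, 640, 640)},
)
print(trt_hash[:16], trt_path)
print(converter.get_stats())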