Ultralytics export

Siwat Sirichai 2025-11-11 01:28:19 +07:00
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions


@@ -1,13 +1,14 @@
import threading
import hashlib
import json
from typing import Optional, Dict, Any, List, Tuple
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
from queue import Queue
import torch
from typing import Any, Dict, List, Optional, Tuple
import tensorrt as trt
from dataclasses import dataclass
import logging
import torch
logger = logging.getLogger(__name__)
@@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
@dataclass
class ModelMetadata:
"""Metadata for a loaded TensorRT model"""
file_path: str
file_hash: str
input_shapes: Dict[str, Tuple[int, ...]]
@@ -30,8 +32,14 @@ class ExecutionContext:
Wrapper for TensorRT execution context with CUDA stream.
Used in context pool for load balancing.
"""
def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
context_id: int, device: torch.device):
def __init__(
self,
context: trt.IExecutionContext,
stream: torch.cuda.Stream,
context_id: int,
device: torch.device,
):
self.context = context
self.stream = stream
self.context_id = context_id
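For reference, a minimal sketch of how one pool slot could be built around this wrapper, assuming a deserialized trt.ICudaEngine is already in hand (the pool-construction code sits outside the hunks shown here):

# Hypothetical illustration only, not part of this commit.
import tensorrt as trt
import torch

def make_pool_slot(engine: trt.ICudaEngine, device: torch.device, context_id: int):
    trt_ctx = engine.create_execution_context()   # one IExecutionContext per slot
    stream = torch.cuda.Stream(device=device)     # dedicated CUDA stream per slot
    return ExecutionContext(trt_ctx, stream, context_id, device)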
@@ -53,8 +61,16 @@ class SharedEngine:
- Contexts are borrowed/returned using mutex locks
- Load balancing: contexts distributed across requests
"""
def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
num_contexts: int, device: torch.device, metadata: ModelMetadata):
def __init__(
self,
engine: trt.ICudaEngine,
file_hash: str,
file_path: str,
num_contexts: int,
device: torch.device,
metadata: ModelMetadata,
):
self.engine = engine
self.file_hash = file_hash
self.file_path = file_path
@@ -80,9 +96,13 @@ class SharedEngine:
self.model_ids: set = set()
self.lock = threading.Lock()
print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")
print(
f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}..."
)
def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
def acquire_context(
self, timeout: Optional[float] = None
) -> Optional[ExecutionContext]:
"""
Acquire an available execution context from the pool.
Blocks if all contexts are in use.
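The borrow/return discipline the pool expects looks roughly like the sketch below; `shared_engine` stands in for a SharedEngine instance, and the real call sites are in infer() further down.

# Sketch of the intended borrow/return pattern for the context pool.
exec_ctx = shared_engine.acquire_context(timeout=5.0)  # blocks until a slot frees up
if exec_ctx is None:
    raise RuntimeError("No execution context available within timeout")
try:
    ...  # run inference on exec_ctx.context, issuing work on exec_ctx.stream
finally:
    shared_engine.release_context(exec_ctx)            # always return the slot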
@@ -162,7 +182,13 @@ class TensorRTModelRepository:
# Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
"""
def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True, cache_dir: str = ".trt_cache"):
def __init__(
self,
gpu_id: int = 0,
default_num_contexts: int = 4,
enable_pt_conversion: bool = True,
cache_dir: str = ".trt_cache",
):
"""
Initialize the model repository.
@@ -173,7 +199,7 @@ class TensorRTModelRepository:
cache_dir: Directory for caching stripped TensorRT engines and metadata
"""
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.device = torch.device(f"cuda:{gpu_id}")
self.default_num_contexts = default_num_contexts
self.enable_pt_conversion = enable_pt_conversion
self.cache_dir = Path(cache_dir)
@@ -195,7 +221,9 @@ class TensorRTModelRepository:
self._pt_converter = None
print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
print(f"Default context pool size: {default_num_contexts} contexts per unique model")
print(
f"Default context pool size: {default_num_contexts} contexts per unique model"
)
print(f"Cache directory: {self.cache_dir}")
if enable_pt_conversion:
print(f"PyTorch to TensorRT conversion: enabled")
@@ -205,6 +233,7 @@ class TensorRTModelRepository:
"""Lazy initialization of PT converter"""
if self._pt_converter is None and self.enable_pt_conversion:
from .pt_converter import PTConverter
self._pt_converter = PTConverter(gpu_id=self.gpu_id)
logger.info("PT converter initialized")
return self._pt_converter
@@ -255,11 +284,11 @@ class TensorRTModelRepository:
# Check if stripped engine already cached
if cache_engine_path.exists():
logger.info(f"Loading cached stripped engine from {cache_engine_path}")
with open(cache_engine_path, 'rb') as f:
with open(cache_engine_path, "rb") as f:
engine_data = f.read()
else:
# Read and process original file
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
# Try to read Ultralytics metadata header (first 4 bytes = metadata length)
try:
meta_len_bytes = f.read(4)
@@ -278,13 +307,15 @@ class TensorRTModelRepository:
# Save stripped engine to cache
logger.info(f"Detected Ultralytics engine format")
logger.info(f"Ultralytics metadata: {metadata}")
logger.info(f"Caching stripped engine to {cache_engine_path}")
logger.info(
f"Caching stripped engine to {cache_engine_path}"
)
with open(cache_engine_path, 'wb') as cache_f:
with open(cache_engine_path, "wb") as cache_f:
cache_f.write(engine_data)
# Save metadata separately
with open(cache_metadata_path, 'w') as meta_f:
with open(cache_metadata_path, "w") as meta_f:
json.dump(metadata, meta_f, indent=2)
except (UnicodeDecodeError, json.JSONDecodeError):
@@ -301,13 +332,15 @@ class TensorRTModelRepository:
except Exception as e:
# Any error, rewind and read entire file
logger.warning(f"Error reading engine metadata: {e}, treating as raw TRT engine")
logger.warning(
f"Error reading engine metadata: {e}, treating as raw TRT engine"
)
f.seek(0)
engine_data = f.read()
# Cache the engine data (even if it was already raw TRT)
if not cache_engine_path.exists():
with open(cache_engine_path, 'wb') as cache_f:
with open(cache_engine_path, "wb") as cache_f:
cache_f.write(engine_data)
engine = runtime.deserialize_cuda_engine(engine_data)
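For context, the Ultralytics .engine layout handled above is a length-prefixed JSON header followed by the raw serialized engine. A standalone sketch of the stripping step, assuming the 4-byte length prefix is little-endian (the byte-order handling itself falls outside the hunks shown):

# Sketch: split an Ultralytics-exported .engine into (metadata, raw TRT engine bytes).
import json

def strip_ultralytics_header(path: str):
    with open(path, "rb") as f:
        meta_len = int.from_bytes(f.read(4), byteorder="little")  # assumed byte order
        try:
            metadata = json.loads(f.read(meta_len).decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError):
            f.seek(0)          # no Ultralytics header: treat as a raw TRT engine
            metadata = None
        engine_data = f.read()
    return metadata, engine_data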
@@ -316,8 +349,9 @@ class TensorRTModelRepository:
return engine
def _extract_metadata(self, engine: trt.ICudaEngine,
file_path: str, file_hash: str) -> ModelMetadata:
def _extract_metadata(
self, engine: trt.ICudaEngine, file_path: str, file_hash: str
) -> ModelMetadata:
"""
Extract metadata from TensorRT engine.
@@ -369,15 +403,19 @@ class TensorRTModelRepository:
input_names=input_names,
output_names=output_names,
input_dtypes=input_dtypes,
output_dtypes=output_dtypes
output_dtypes=output_dtypes,
)
def load_model(self, model_id: str, file_path: str,
num_contexts: Optional[int] = None,
force_reload: bool = False,
pt_input_shapes: Optional[Dict[str, Tuple]] = None,
pt_precision: Optional[torch.dtype] = None,
**pt_conversion_kwargs) -> ModelMetadata:
def load_model(
self,
model_id: str,
file_path: str,
num_contexts: Optional[int] = None,
force_reload: bool = False,
pt_input_shapes: Optional[Dict[str, Tuple]] = None,
pt_precision: Optional[torch.dtype] = None,
**pt_conversion_kwargs,
) -> ModelMetadata:
"""
Load a TensorRT model with the given ID.
@@ -410,7 +448,7 @@ class TensorRTModelRepository:
# Check if file is PyTorch model
file_ext = Path(file_path).suffix.lower()
if file_ext in ['.pt', '.pth']:
if file_ext in [".pt", ".pth"]:
if not self.enable_pt_conversion:
raise ValueError(
f"PT file provided but PT conversion is disabled. "
@@ -425,7 +463,7 @@ class TensorRTModelRepository:
file_path,
input_shapes=pt_input_shapes,
precision=pt_precision,
**pt_conversion_kwargs
**pt_conversion_kwargs,
)
# Update file_path to use converted TRT file
@@ -455,8 +493,12 @@ class TensorRTModelRepository:
# Check if this file is already loaded (deduplication)
if file_hash in self._shared_engines:
shared_engine = self._shared_engines[file_hash]
print(f"Engine already loaded (hash match), reusing engine and context pool...")
print(f" Existing model_ids using this engine: {shared_engine.model_ids}")
print(
f"Engine already loaded (hash match), reusing engine and context pool..."
)
print(
f" Existing model_ids using this engine: {shared_engine.model_ids}"
)
else:
# Load new engine
print(f"Loading TensorRT engine from {file_path}...")
@@ -472,7 +514,7 @@ class TensorRTModelRepository:
file_path=file_path,
num_contexts=num_contexts,
device=self.device,
metadata=metadata
metadata=metadata,
)
self._shared_engines[file_hash] = shared_engine
@@ -485,18 +527,29 @@ class TensorRTModelRepository:
print(f"Model '{model_id}' loaded successfully")
print(f" Inputs: {shared_engine.metadata.input_names}")
for name in shared_engine.metadata.input_names:
print(f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})")
print(
f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})"
)
print(f" Outputs: {shared_engine.metadata.output_names}")
for name in shared_engine.metadata.output_names:
print(f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})")
print(
f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})"
)
print(f" Context pool size: {num_contexts}")
print(f" Model IDs sharing this engine: {shared_engine.get_reference_count()}")
print(
f" Model IDs sharing this engine: {shared_engine.get_reference_count()}"
)
print(f" Unique engines in VRAM: {len(self._shared_engines)}")
return shared_engine.metadata
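A usage sketch of the deduplication path, continuing the `repo` instance from the earlier construction sketch; model IDs and file paths are illustrative.

# Two model IDs backed by the same file share one engine and one context pool.
repo.load_model("camera_front", "models/yolov8n.engine")
repo.load_model("camera_rear", "models/yolov8n.engine")   # hash match -> engine reused

stats = repo.get_stats()
# Expect 2 model IDs, 1 unique engine, default_num_contexts total contexts.
print(stats["total_model_ids"], stats["unique_engines"], stats["total_contexts"])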
def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
synchronize: bool = True, timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
def infer(
self,
model_id: str,
inputs: Dict[str, torch.Tensor],
synchronize: bool = True,
timeout: Optional[float] = 5.0,
) -> Dict[str, torch.Tensor]:
"""
Run GPU-to-GPU inference with the specified model using context pooling.
@@ -519,7 +572,9 @@ class TensorRTModelRepository:
"""
# Get shared engine
if model_id not in self._model_to_hash:
raise KeyError(f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}")
raise KeyError(
f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}"
)
file_hash = self._model_to_hash[model_id]
shared_engine = self._shared_engines[file_hash]
@@ -536,7 +591,9 @@ class TensorRTModelRepository:
# Check device
if tensor.device != self.device:
print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
print(
f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}"
)
inputs[name] = tensor.to(self.device)
# Acquire context from pool (mutex-based)
@@ -562,9 +619,7 @@ class TensorRTModelRepository:
output_dtype = metadata.output_dtypes[name]
output_tensor = torch.empty(
output_shape,
dtype=output_dtype,
device=self.device
output_shape, dtype=output_dtype, device=self.device
)
# NOTE: Don't track these tensors - they're returned to caller and consumed
@@ -584,9 +639,23 @@ class TensorRTModelRepository:
if not success:
raise RuntimeError(f"Inference failed for model '{model_id}'")
# Synchronize if requested
if synchronize:
exec_ctx.stream.synchronize()
# CRITICAL: Always synchronize before releasing context
# Even if caller requested async execution, we MUST sync before
# releasing the context to prevent race conditions where the next
# inference using this context overwrites tensor addresses while
# the current batch is still being processed.
exec_ctx.stream.synchronize()
# Clone outputs to new tensors to ensure memory safety
# This prevents race conditions where the next batch using this context
# could overwrite the output tensor addresses before the caller
# finishes processing these results.
if not synchronize:
# For async mode, clone to decouple from context
cloned_outputs = {}
for name, tensor in outputs.items():
cloned_outputs[name] = tensor.clone()
outputs = cloned_outputs
return outputs
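A usage sketch for infer(), again using the illustrative `repo` and model ID from above; the input name and shape are placeholders, not taken from this diff.

# GPU-to-GPU inference: inputs stay on the repository's CUDA device.
import torch

frame = torch.rand(1, 3, 640, 640, device="cuda:0", dtype=torch.float16)
outputs = repo.infer(
    "camera_front",
    {"images": frame},   # keys must match the engine's input tensor names
    synchronize=True,    # stream is synchronized before the call returns
    timeout=5.0,         # max wait for a free execution context
)
for name, tensor in outputs.items():
    print(name, tuple(tensor.shape), tensor.dtype)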
@@ -594,8 +663,12 @@ class TensorRTModelRepository:
# Always release context back to pool
shared_engine.release_context(exec_ctx)
def infer_batch(self, model_id: str, batch_inputs: List[Dict[str, torch.Tensor]],
synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
def infer_batch(
self,
model_id: str,
batch_inputs: List[Dict[str, torch.Tensor]],
synchronize: bool = True,
) -> List[Dict[str, torch.Tensor]]:
"""
Run inference on multiple inputs.
Contexts are borrowed/returned for each input, enabling parallel processing.
@@ -641,9 +714,13 @@ class TensorRTModelRepository:
if remaining_refs == 0:
shared_engine.cleanup()
del self._shared_engines[file_hash]
print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
print(
f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)"
)
else:
print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")
print(
f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)"
)
# Remove from model_id mapping
del self._model_to_hash[model_id]
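A sketch of reference-counted unloading; the method name `unload_model` is assumed, since its def line falls outside the hunks shown.

# Unloading one alias keeps the shared engine alive until the last reference is gone.
repo.unload_model("camera_rear")    # engine kept in VRAM (another ID still references it)
repo.unload_model("camera_front")   # last reference -> contexts and engine freed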
@@ -702,26 +779,26 @@ class TensorRTModelRepository:
metadata = shared_engine.metadata
return {
'model_id': model_id,
'file_path': metadata.file_path,
'file_hash': metadata.file_hash[:16] + '...',
'engine_references': shared_engine.get_reference_count(),
'context_pool_size': shared_engine.num_contexts,
'shared_with_model_ids': list(shared_engine.model_ids),
'inputs': {
"model_id": model_id,
"file_path": metadata.file_path,
"file_hash": metadata.file_hash[:16] + "...",
"engine_references": shared_engine.get_reference_count(),
"context_pool_size": shared_engine.num_contexts,
"shared_with_model_ids": list(shared_engine.model_ids),
"inputs": {
name: {
'shape': metadata.input_shapes[name],
'dtype': str(metadata.input_dtypes[name])
"shape": metadata.input_shapes[name],
"dtype": str(metadata.input_dtypes[name]),
}
for name in metadata.input_names
},
'outputs': {
"outputs": {
name: {
'shape': metadata.output_shapes[name],
'dtype': str(metadata.output_dtypes[name])
"shape": metadata.output_shapes[name],
"dtype": str(metadata.output_dtypes[name]),
}
for name in metadata.output_names
}
},
}
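A sketch of the structure this accessor returns; the method name `get_model_info` is assumed (its def is outside the shown hunks) and all values below are illustrative.

info = repo.get_model_info("camera_front")
# {
#     "model_id": "camera_front",
#     "file_path": "models/yolov8n.engine",
#     "file_hash": "0a1b2c3d4e5f6a7b...",
#     "engine_references": 2,
#     "context_pool_size": 4,
#     "shared_with_model_ids": ["camera_front", "camera_rear"],
#     "inputs": {"images": {"shape": (1, 3, 640, 640), "dtype": "torch.float16"}},
#     "outputs": {"output0": {"shape": (1, 84, 8400), "dtype": "torch.float16"}},
# }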
def get_stats(self) -> Dict[str, Any]:
@@ -733,24 +810,25 @@ class TensorRTModelRepository:
"""
with self._repo_lock:
total_contexts = sum(
engine.num_contexts
for engine in self._shared_engines.values()
engine.num_contexts for engine in self._shared_engines.values()
)
return {
'total_model_ids': len(self._model_to_hash),
'unique_engines': len(self._shared_engines),
'total_contexts': total_contexts,
'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
'gpu_id': self.gpu_id,
'models': list(self._model_to_hash.keys())
"total_model_ids": len(self._model_to_hash),
"unique_engines": len(self._shared_engines),
"total_contexts": total_contexts,
"memory_efficiency": f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
"gpu_id": self.gpu_id,
"models": list(self._model_to_hash.keys()),
}
def __repr__(self):
with self._repo_lock:
return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
f"model_ids={len(self._model_to_hash)}, "
f"unique_engines={len(self._shared_engines)})")
return (
f"TensorRTModelRepository(gpu={self.gpu_id}, "
f"model_ids={len(self._model_to_hash)}, "
f"unique_engines={len(self._shared_engines)})"
)
def __del__(self):
"""Cleanup all models on deletion"""