ultralytics export
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions
@@ -1,13 +1,14 @@
-import threading
 import hashlib
 import json
-from typing import Optional, Dict, Any, List, Tuple
+import logging
+import threading
+from dataclasses import dataclass
 from pathlib import Path
 from queue import Queue
-import torch
+from typing import Any, Dict, List, Optional, Tuple
+
 import tensorrt as trt
-from dataclasses import dataclass
-import logging
+import torch
 
 logger = logging.getLogger(__name__)
 
@@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
 @dataclass
 class ModelMetadata:
     """Metadata for a loaded TensorRT model"""
+
     file_path: str
     file_hash: str
     input_shapes: Dict[str, Tuple[int, ...]]
@@ -30,8 +32,14 @@ class ExecutionContext:
     Wrapper for TensorRT execution context with CUDA stream.
     Used in context pool for load balancing.
     """
-    def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
-                 context_id: int, device: torch.device):
+
+    def __init__(
+        self,
+        context: trt.IExecutionContext,
+        stream: torch.cuda.Stream,
+        context_id: int,
+        device: torch.device,
+    ):
         self.context = context
         self.stream = stream
         self.context_id = context_id
@@ -53,8 +61,16 @@ class SharedEngine:
     - Contexts are borrowed/returned using mutex locks
     - Load balancing: contexts distributed across requests
     """
-    def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
-                 num_contexts: int, device: torch.device, metadata: ModelMetadata):
+
+    def __init__(
+        self,
+        engine: trt.ICudaEngine,
+        file_hash: str,
+        file_path: str,
+        num_contexts: int,
+        device: torch.device,
+        metadata: ModelMetadata,
+    ):
         self.engine = engine
         self.file_hash = file_hash
         self.file_path = file_path
@@ -80,9 +96,13 @@ class SharedEngine:
         self.model_ids: set = set()
         self.lock = threading.Lock()
 
-        print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")
+        print(
+            f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}..."
+        )
 
-    def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
+    def acquire_context(
+        self, timeout: Optional[float] = None
+    ) -> Optional[ExecutionContext]:
         """
         Acquire an available execution context from the pool.
         Blocks if all contexts are in use.
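The docstrings above describe the core concurrency pattern: a fixed pool of execution contexts that callers borrow and return, blocking when the pool is exhausted. Below is a minimal, self-contained sketch of that borrow/return pattern; the class and attribute names are illustrative stand-ins, not the repository's actual internals.

import threading
from queue import Empty, Queue
from typing import Optional


class PooledContext:
    """Illustrative stand-in for ExecutionContext (context + stream + id)."""

    def __init__(self, context_id: int):
        self.context_id = context_id
        self.in_use = False
        self.lock = threading.Lock()


class ContextPool:
    """Fixed-size pool: acquire blocks until a context is free or the timeout expires."""

    def __init__(self, num_contexts: int):
        self.available: Queue = Queue()
        for i in range(num_contexts):
            self.available.put(PooledContext(i))

    def acquire(self, timeout: Optional[float] = None) -> Optional[PooledContext]:
        try:
            ctx = self.available.get(timeout=timeout)
        except Empty:
            return None  # every context is busy and the timeout expired
        with ctx.lock:
            ctx.in_use = True
        return ctx

    def release(self, ctx: PooledContext) -> None:
        with ctx.lock:
            ctx.in_use = False
        self.available.put(ctx)  # hand it back to the next waiting request


pool = ContextPool(num_contexts=4)
ctx = pool.acquire(timeout=5.0)
try:
    pass  # run inference with ctx here
finally:
    if ctx is not None:
        pool.release(ctx)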
@@ -162,7 +182,13 @@ class TensorRTModelRepository:
     # Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
     """
 
-    def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True, cache_dir: str = ".trt_cache"):
+    def __init__(
+        self,
+        gpu_id: int = 0,
+        default_num_contexts: int = 4,
+        enable_pt_conversion: bool = True,
+        cache_dir: str = ".trt_cache",
+    ):
         """
         Initialize the model repository.
 
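The class comment promises that many model IDs backed by the same engine file share a single engine in VRAM and one context pool. A hedged usage sketch of that behaviour, assuming the module name and file path shown here for illustration (the stats keys come from get_stats later in this diff):

from model_repository import TensorRTModelRepository  # assumed module name

repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=4)

# 100 logical models, all pointing at the same engine file on disk.
for i in range(100):
    repo.load_model(model_id=f"camera_{i}", file_path="models/yolov8n.engine")

stats = repo.get_stats()
print(stats["unique_engines"])   # expected: 1 engine resident in VRAM
print(stats["total_contexts"])   # expected: 4 pooled execution contexts, not 100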
@@ -173,7 +199,7 @@ class TensorRTModelRepository:
             cache_dir: Directory for caching stripped TensorRT engines and metadata
         """
         self.gpu_id = gpu_id
-        self.device = torch.device(f'cuda:{gpu_id}')
+        self.device = torch.device(f"cuda:{gpu_id}")
         self.default_num_contexts = default_num_contexts
         self.enable_pt_conversion = enable_pt_conversion
         self.cache_dir = Path(cache_dir)
@@ -195,7 +221,9 @@ class TensorRTModelRepository:
         self._pt_converter = None
 
         print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
-        print(f"Default context pool size: {default_num_contexts} contexts per unique model")
+        print(
+            f"Default context pool size: {default_num_contexts} contexts per unique model"
+        )
         print(f"Cache directory: {self.cache_dir}")
         if enable_pt_conversion:
             print(f"PyTorch to TensorRT conversion: enabled")
@@ -205,6 +233,7 @@ class TensorRTModelRepository:
         """Lazy initialization of PT converter"""
         if self._pt_converter is None and self.enable_pt_conversion:
             from .pt_converter import PTConverter
+
             self._pt_converter = PTConverter(gpu_id=self.gpu_id)
             logger.info("PT converter initialized")
         return self._pt_converter
@@ -255,11 +284,11 @@ class TensorRTModelRepository:
         # Check if stripped engine already cached
         if cache_engine_path.exists():
             logger.info(f"Loading cached stripped engine from {cache_engine_path}")
-            with open(cache_engine_path, 'rb') as f:
+            with open(cache_engine_path, "rb") as f:
                 engine_data = f.read()
         else:
             # Read and process original file
-            with open(file_path, 'rb') as f:
+            with open(file_path, "rb") as f:
                 # Try to read Ultralytics metadata header (first 4 bytes = metadata length)
                 try:
                     meta_len_bytes = f.read(4)
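The comment above relies on the Ultralytics export layout: a 4-byte little-endian length prefix, a JSON metadata blob of that length, then the raw serialized TensorRT engine. A standalone sketch of stripping that header follows; the layout is taken from Ultralytics' exporter and should be treated as an assumption for engines produced by other tools.

import json
from pathlib import Path


def split_ultralytics_engine(path: str):
    """Return (metadata or None, raw_engine_bytes) for an exported .engine file."""
    data = Path(path).read_bytes()
    try:
        meta_len = int.from_bytes(data[:4], byteorder="little")
        metadata = json.loads(data[4 : 4 + meta_len].decode("utf-8"))
        return metadata, data[4 + meta_len :]
    except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
        # No Ultralytics header found: treat the whole file as a raw TRT engine.
        return None, data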
@@ -278,13 +307,15 @@ class TensorRTModelRepository:
                     # Save stripped engine to cache
                     logger.info(f"Detected Ultralytics engine format")
                     logger.info(f"Ultralytics metadata: {metadata}")
-                    logger.info(f"Caching stripped engine to {cache_engine_path}")
+                    logger.info(
+                        f"Caching stripped engine to {cache_engine_path}"
+                    )
 
-                    with open(cache_engine_path, 'wb') as cache_f:
+                    with open(cache_engine_path, "wb") as cache_f:
                         cache_f.write(engine_data)
 
                     # Save metadata separately
-                    with open(cache_metadata_path, 'w') as meta_f:
+                    with open(cache_metadata_path, "w") as meta_f:
                         json.dump(metadata, meta_f, indent=2)
 
                 except (UnicodeDecodeError, json.JSONDecodeError):
@@ -301,13 +332,15 @@ class TensorRTModelRepository:
 
                 except Exception as e:
                     # Any error, rewind and read entire file
-                    logger.warning(f"Error reading engine metadata: {e}, treating as raw TRT engine")
+                    logger.warning(
+                        f"Error reading engine metadata: {e}, treating as raw TRT engine"
+                    )
                     f.seek(0)
                     engine_data = f.read()
 
             # Cache the engine data (even if it was already raw TRT)
             if not cache_engine_path.exists():
-                with open(cache_engine_path, 'wb') as cache_f:
+                with open(cache_engine_path, "wb") as cache_f:
                     cache_f.write(engine_data)
 
         engine = runtime.deserialize_cuda_engine(engine_data)
@@ -316,8 +349,9 @@ class TensorRTModelRepository:
 
         return engine
 
-    def _extract_metadata(self, engine: trt.ICudaEngine,
-                          file_path: str, file_hash: str) -> ModelMetadata:
+    def _extract_metadata(
+        self, engine: trt.ICudaEngine, file_path: str, file_hash: str
+    ) -> ModelMetadata:
         """
         Extract metadata from TensorRT engine.
 
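_extract_metadata walks the engine's I/O tensors to recover names, shapes, and dtypes. A minimal sketch of how that lookup typically works with the TensorRT 8.5+ tensor API; the dtype table and function name are illustrative, not the repository's exact code.

import tensorrt as trt
import torch

# Rough TensorRT-to-torch dtype table; extend as needed.
TRT_TO_TORCH_DTYPE = {
    trt.DataType.FLOAT: torch.float32,
    trt.DataType.HALF: torch.float16,
    trt.DataType.INT32: torch.int32,
    trt.DataType.INT8: torch.int8,
}


def describe_io(engine: trt.ICudaEngine):
    """Collect (shape, dtype) per input/output tensor of a deserialized engine."""
    inputs, outputs = {}, {}
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        shape = tuple(engine.get_tensor_shape(name))
        dtype = TRT_TO_TORCH_DTYPE.get(engine.get_tensor_dtype(name))
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            inputs[name] = (shape, dtype)
        else:
            outputs[name] = (shape, dtype)
    return inputs, outputs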
@@ -369,15 +403,19 @@ class TensorRTModelRepository:
             input_names=input_names,
             output_names=output_names,
             input_dtypes=input_dtypes,
-            output_dtypes=output_dtypes
+            output_dtypes=output_dtypes,
         )
 
-    def load_model(self, model_id: str, file_path: str,
-                   num_contexts: Optional[int] = None,
-                   force_reload: bool = False,
-                   pt_input_shapes: Optional[Dict[str, Tuple]] = None,
-                   pt_precision: Optional[torch.dtype] = None,
-                   **pt_conversion_kwargs) -> ModelMetadata:
+    def load_model(
+        self,
+        model_id: str,
+        file_path: str,
+        num_contexts: Optional[int] = None,
+        force_reload: bool = False,
+        pt_input_shapes: Optional[Dict[str, Tuple]] = None,
+        pt_precision: Optional[torch.dtype] = None,
+        **pt_conversion_kwargs,
+    ) -> ModelMetadata:
         """
         Load a TensorRT model with the given ID.
 
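A hedged usage sketch of the two load paths this signature supports: a prebuilt .engine loaded directly, and a .pt checkpoint routed through the PT converter. The module name, file paths, input name, and shapes are assumptions for illustration only.

import torch

from model_repository import TensorRTModelRepository  # assumed module name

repo = TensorRTModelRepository(gpu_id=0)

# Prebuilt TensorRT engine: loaded as-is and deduplicated by file hash.
repo.load_model(model_id="detector_trt", file_path="models/yolov8n.engine")

# PyTorch checkpoint: converted to TensorRT first (requires enable_pt_conversion=True).
repo.load_model(
    model_id="detector_pt",
    file_path="models/yolov8n.pt",
    pt_input_shapes={"images": (1, 3, 640, 640)},  # assumed input name and shape
    pt_precision=torch.float16,
)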
@@ -410,7 +448,7 @@ class TensorRTModelRepository:
 
         # Check if file is PyTorch model
         file_ext = Path(file_path).suffix.lower()
-        if file_ext in ['.pt', '.pth']:
+        if file_ext in [".pt", ".pth"]:
             if not self.enable_pt_conversion:
                 raise ValueError(
                     f"PT file provided but PT conversion is disabled. "
@@ -425,7 +463,7 @@ class TensorRTModelRepository:
                 file_path,
                 input_shapes=pt_input_shapes,
                 precision=pt_precision,
-                **pt_conversion_kwargs
+                **pt_conversion_kwargs,
             )
 
             # Update file_path to use converted TRT file
@@ -455,8 +493,12 @@ class TensorRTModelRepository:
             # Check if this file is already loaded (deduplication)
             if file_hash in self._shared_engines:
                 shared_engine = self._shared_engines[file_hash]
-                print(f"Engine already loaded (hash match), reusing engine and context pool...")
-                print(f" Existing model_ids using this engine: {shared_engine.model_ids}")
+                print(
+                    f"Engine already loaded (hash match), reusing engine and context pool..."
+                )
+                print(
+                    f" Existing model_ids using this engine: {shared_engine.model_ids}"
+                )
             else:
                 # Load new engine
                 print(f"Loading TensorRT engine from {file_path}...")
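Deduplication keys on a hash of the engine file, computed before this check. The hashing helper itself is not part of this hunk; a typical chunked SHA-256 implementation would look like the following sketch (function name and chunk size are illustrative).

import hashlib


def compute_file_hash(file_path: str, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large engines never need to fit in memory at once."""
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
    return sha256.hexdigest()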
@@ -472,7 +514,7 @@ class TensorRTModelRepository:
                     file_path=file_path,
                     num_contexts=num_contexts,
                     device=self.device,
-                    metadata=metadata
+                    metadata=metadata,
                 )
                 self._shared_engines[file_hash] = shared_engine
 
@@ -485,18 +527,29 @@ class TensorRTModelRepository:
             print(f"Model '{model_id}' loaded successfully")
             print(f" Inputs: {shared_engine.metadata.input_names}")
             for name in shared_engine.metadata.input_names:
-                print(f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})")
+                print(
+                    f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})"
+                )
             print(f" Outputs: {shared_engine.metadata.output_names}")
             for name in shared_engine.metadata.output_names:
-                print(f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})")
+                print(
+                    f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})"
+                )
             print(f" Context pool size: {num_contexts}")
-            print(f" Model IDs sharing this engine: {shared_engine.get_reference_count()}")
+            print(
+                f" Model IDs sharing this engine: {shared_engine.get_reference_count()}"
+            )
             print(f" Unique engines in VRAM: {len(self._shared_engines)}")
 
             return shared_engine.metadata
 
-    def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
-              synchronize: bool = True, timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
+    def infer(
+        self,
+        model_id: str,
+        inputs: Dict[str, torch.Tensor],
+        synchronize: bool = True,
+        timeout: Optional[float] = 5.0,
+    ) -> Dict[str, torch.Tensor]:
         """
         Run GPU-to-GPU inference with the specified model using context pooling.
 
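The infer docstring promises GPU-to-GPU inference: inputs stay on the CUDA device and outputs come back as CUDA tensors. A hedged call sketch, continuing the repo and "detector_trt" model assumed in the earlier load_model sketch; the input tensor name and shape are assumptions for a typical YOLO engine.

import torch

frame = torch.rand(1, 3, 640, 640, dtype=torch.float32, device="cuda:0")

outputs = repo.infer(
    model_id="detector_trt",
    inputs={"images": frame},  # assumed input tensor name
    synchronize=True,          # block until results are ready
    timeout=5.0,               # max seconds to wait for a free context
)

for name, tensor in outputs.items():
    print(name, tuple(tensor.shape), tensor.device)  # outputs stay on cuda:0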
@@ -519,7 +572,9 @@ class TensorRTModelRepository:
         """
         # Get shared engine
         if model_id not in self._model_to_hash:
-            raise KeyError(f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}")
+            raise KeyError(
+                f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}"
+            )
 
         file_hash = self._model_to_hash[model_id]
         shared_engine = self._shared_engines[file_hash]
@@ -536,7 +591,9 @@ class TensorRTModelRepository:
 
             # Check device
             if tensor.device != self.device:
-                print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
+                print(
+                    f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}"
+                )
                 inputs[name] = tensor.to(self.device)
 
         # Acquire context from pool (mutex-based)
@@ -562,9 +619,7 @@ class TensorRTModelRepository:
                 output_dtype = metadata.output_dtypes[name]
 
                 output_tensor = torch.empty(
-                    output_shape,
-                    dtype=output_dtype,
-                    device=self.device
+                    output_shape, dtype=output_dtype, device=self.device
                 )
 
                 # NOTE: Don't track these tensors - they're returned to caller and consumed
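The preallocated torch output tensors above are handed to TensorRT by raw device address. The exact call site is not visible in this hunk; with the TensorRT 8.5+/10 tensor API the binding and launch typically look like this sketch (function name and parameters are illustrative).

import tensorrt as trt
import torch


def launch(context: trt.IExecutionContext,
           stream: torch.cuda.Stream,
           tensors: dict) -> bool:
    """Bind torch tensors by device pointer, then enqueue on the CUDA stream."""
    for name, tensor in tensors.items():
        context.set_tensor_address(name, tensor.data_ptr())
    # execute_async_v3 enqueues the inference on the raw stream handle;
    # the caller decides when to synchronize.
    return context.execute_async_v3(stream_handle=stream.cuda_stream)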
@@ -584,9 +639,23 @@ class TensorRTModelRepository:
                 if not success:
                     raise RuntimeError(f"Inference failed for model '{model_id}'")
 
-                # Synchronize if requested
-                if synchronize:
-                    exec_ctx.stream.synchronize()
+                # CRITICAL: Always synchronize before releasing context
+                # Even if caller requested async execution, we MUST sync before
+                # releasing the context to prevent race conditions where the next
+                # inference using this context overwrites tensor addresses while
+                # the current batch is still being processed.
+                exec_ctx.stream.synchronize()
+
+                # Clone outputs to new tensors to ensure memory safety
+                # This prevents race conditions where the next batch using this context
+                # could overwrite the output tensor addresses before the caller
+                # finishes processing these results.
+                if not synchronize:
+                    # For async mode, clone to decouple from context
+                    cloned_outputs = {}
+                    for name, tensor in outputs.items():
+                        cloned_outputs[name] = tensor.clone()
+                    outputs = cloned_outputs
 
                 return outputs
 
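The new comments justify the unconditional synchronize plus the async-mode clone by pointing at concurrent callers sharing one context pool. A sketch of that usage pattern, reusing the repo and model ID assumed in the earlier sketches; thread count and tensor shapes are illustrative.

import threading

import torch


def worker(repo, model_id: str, n_frames: int = 10) -> None:
    for _ in range(n_frames):
        frame = torch.rand(1, 3, 640, 640, device="cuda:0")
        # Each call borrows a context, synchronizes, and releases it, so
        # results from concurrent workers cannot overwrite one another.
        repo.infer(model_id, {"images": frame}, synchronize=True)


threads = [threading.Thread(target=worker, args=(repo, "detector_trt")) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()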
@@ -594,8 +663,12 @@ class TensorRTModelRepository:
             # Always release context back to pool
             shared_engine.release_context(exec_ctx)
 
-    def infer_batch(self, model_id: str, batch_inputs: List[Dict[str, torch.Tensor]],
-                    synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
+    def infer_batch(
+        self,
+        model_id: str,
+        batch_inputs: List[Dict[str, torch.Tensor]],
+        synchronize: bool = True,
+    ) -> List[Dict[str, torch.Tensor]]:
         """
         Run inference on multiple inputs.
         Contexts are borrowed/returned for each input, enabling parallel processing.
@@ -641,9 +714,13 @@ class TensorRTModelRepository:
             if remaining_refs == 0:
                 shared_engine.cleanup()
                 del self._shared_engines[file_hash]
-                print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
+                print(
+                    f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)"
+                )
             else:
-                print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")
+                print(
+                    f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)"
+                )
 
             # Remove from model_id mapping
             del self._model_to_hash[model_id]
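A hedged sketch of the reference-counting behaviour these messages describe: the engine leaves VRAM only when the last model ID referencing it is unloaded. The unload_model method name is inferred from context and may differ in the actual code.

repo.load_model("cam_a", "models/yolov8n.engine")
repo.load_model("cam_b", "models/yolov8n.engine")  # same file -> same shared engine

repo.unload_model("cam_a")  # engine kept in VRAM (1 reference remaining)
repo.unload_model("cam_b")  # engine removed from VRAM (0 references)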
@@ -702,26 +779,26 @@ class TensorRTModelRepository:
             metadata = shared_engine.metadata
 
             return {
-                'model_id': model_id,
-                'file_path': metadata.file_path,
-                'file_hash': metadata.file_hash[:16] + '...',
-                'engine_references': shared_engine.get_reference_count(),
-                'context_pool_size': shared_engine.num_contexts,
-                'shared_with_model_ids': list(shared_engine.model_ids),
-                'inputs': {
+                "model_id": model_id,
+                "file_path": metadata.file_path,
+                "file_hash": metadata.file_hash[:16] + "...",
+                "engine_references": shared_engine.get_reference_count(),
+                "context_pool_size": shared_engine.num_contexts,
+                "shared_with_model_ids": list(shared_engine.model_ids),
+                "inputs": {
                     name: {
-                        'shape': metadata.input_shapes[name],
-                        'dtype': str(metadata.input_dtypes[name])
+                        "shape": metadata.input_shapes[name],
+                        "dtype": str(metadata.input_dtypes[name]),
                     }
                     for name in metadata.input_names
                 },
-                'outputs': {
+                "outputs": {
                     name: {
-                        'shape': metadata.output_shapes[name],
-                        'dtype': str(metadata.output_dtypes[name])
+                        "shape": metadata.output_shapes[name],
+                        "dtype": str(metadata.output_dtypes[name]),
                     }
                     for name in metadata.output_names
-                }
+                },
             }
 
     def get_stats(self) -> Dict[str, Any]:
@@ -733,24 +810,25 @@ class TensorRTModelRepository:
         """
         with self._repo_lock:
             total_contexts = sum(
-                engine.num_contexts
-                for engine in self._shared_engines.values()
+                engine.num_contexts for engine in self._shared_engines.values()
             )
 
             return {
-                'total_model_ids': len(self._model_to_hash),
-                'unique_engines': len(self._shared_engines),
-                'total_contexts': total_contexts,
-                'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
-                'gpu_id': self.gpu_id,
-                'models': list(self._model_to_hash.keys())
+                "total_model_ids": len(self._model_to_hash),
+                "unique_engines": len(self._shared_engines),
+                "total_contexts": total_contexts,
+                "memory_efficiency": f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
+                "gpu_id": self.gpu_id,
+                "models": list(self._model_to_hash.keys()),
             }
 
     def __repr__(self):
         with self._repo_lock:
-            return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
-                    f"model_ids={len(self._model_to_hash)}, "
-                    f"unique_engines={len(self._shared_engines)})")
+            return (
+                f"TensorRTModelRepository(gpu={self.gpu_id}, "
+                f"model_ids={len(self._model_to_hash)}, "
+                f"unique_engines={len(self._shared_engines)})"
+            )
 
     def __del__(self):
         """Cleanup all models on deletion"""