import threading
import hashlib
from dataclasses import dataclass
from pathlib import Path
from queue import Queue, Empty
from typing import Optional, Dict, Any, List, Tuple

import torch
import tensorrt as trt


@dataclass
class ModelMetadata:
    """Metadata for a loaded TensorRT model"""
    file_path: str
    file_hash: str
    input_shapes: Dict[str, Tuple[int, ...]]
    output_shapes: Dict[str, Tuple[int, ...]]
    input_names: List[str]
    output_names: List[str]
    input_dtypes: Dict[str, torch.dtype]
    output_dtypes: Dict[str, torch.dtype]


class ExecutionContext:
    """
    Wrapper for a TensorRT execution context with its own CUDA stream.
    Used in the context pool for load balancing.
    """

    def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
                 context_id: int, device: torch.device):
        self.context = context
        self.stream = stream
        self.context_id = context_id
        self.device = device
        self.in_use = False
        self.lock = threading.Lock()

    def __repr__(self):
        return f"ExecutionContext(id={self.context_id}, in_use={self.in_use})"


class SharedEngine:
    """
    Shared TensorRT engine with a context pool for load balancing.

    Architecture:
        - One engine shared across all model_ids with the same file hash
        - Pool of N execution contexts for concurrent inference
        - Contexts are borrowed/returned using mutex locks
        - Load balancing: contexts distributed across requests
    """

    def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
                 num_contexts: int, device: torch.device, metadata: ModelMetadata):
        self.engine = engine
        self.file_hash = file_hash
        self.file_path = file_path
        self.metadata = metadata
        self.device = device
        self.num_contexts = num_contexts

        # Create context pool
        self.context_pool: List[ExecutionContext] = []
        self.available_contexts: Queue[ExecutionContext] = Queue()

        for i in range(num_contexts):
            ctx = engine.create_execution_context()
            if ctx is None:
                raise RuntimeError(f"Failed to create execution context {i}")

            stream = torch.cuda.Stream(device=device)
            exec_ctx = ExecutionContext(ctx, stream, i, device)
            self.context_pool.append(exec_ctx)
            self.available_contexts.put(exec_ctx)

        # Model IDs referencing this engine
        self.model_ids: set = set()
        self.lock = threading.Lock()

        print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")

    def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
        """
        Acquire an available execution context from the pool.
        Blocks if all contexts are in use.

        Args:
            timeout: Max time to wait for a context (None = wait forever)

        Returns:
            ExecutionContext, or None if the wait timed out
        """
        try:
            exec_ctx = self.available_contexts.get(timeout=timeout)
        except Empty:
            return None

        with exec_ctx.lock:
            exec_ctx.in_use = True
        return exec_ctx

    def release_context(self, exec_ctx: ExecutionContext):
        """
        Return a context to the pool.

        Args:
            exec_ctx: Context to release
        """
        with exec_ctx.lock:
            exec_ctx.in_use = False
        self.available_contexts.put(exec_ctx)

    def add_model_id(self, model_id: str):
        """Add a model_id reference to this engine"""
        with self.lock:
            self.model_ids.add(model_id)

    def remove_model_id(self, model_id: str) -> int:
        """
        Remove a model_id reference from this engine.
        Returns the number of remaining references.
        """
        with self.lock:
            self.model_ids.discard(model_id)
            return len(self.model_ids)

    def get_reference_count(self) -> int:
        """Get number of model_ids using this engine"""
        with self.lock:
            return len(self.model_ids)

    def cleanup(self):
        """Cleanup all contexts"""
        for exec_ctx in self.context_pool:
            del exec_ctx.context
        self.context_pool.clear()
        del self.engine


class TensorRTModelRepository:
    """
    Thread-safe repository for TensorRT models with context pooling and deduplication.

    Architecture:
        - Deduplication: Multiple model_ids with same file → share one engine
        - Context Pool: Each unique engine has N execution contexts (configurable)
        - Load Balancing: Contexts are borrowed/returned via mutex queue
        - Scalability: Adding 100 cameras with same model = 1 engine + N contexts (not 100 contexts!)

    Best Practices:
        - GPU-to-GPU: All inputs/outputs stay in VRAM (zero CPU transfers)
        - Thread Safety: Mutex-based context borrowing (TensorRT best practice)
        - Memory Efficient: Deduplicate by file hash, share engine across model_ids
        - Concurrent: N contexts allow N parallel inferences per unique model

    Example:
        # 100 cameras, same model file
        for i in range(100):
            repo.load_model(f"camera_{i}", "yolov8.trt")
        # Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
    """

    def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4):
        """
        Initialize the model repository.

        Args:
            gpu_id: GPU device ID to use
            default_num_contexts: Default number of execution contexts per unique engine
        """
        self.gpu_id = gpu_id
        self.device = torch.device(f'cuda:{gpu_id}')
        self.default_num_contexts = default_num_contexts

        # Model ID to engine mapping: model_id -> file_hash
        self._model_to_hash: Dict[str, str] = {}

        # Shared engines with context pools: file_hash -> SharedEngine
        self._shared_engines: Dict[str, SharedEngine] = {}

        # Locks for thread safety
        self._repo_lock = threading.RLock()

        # TensorRT logger
        self.trt_logger = trt.Logger(trt.Logger.WARNING)

        print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
        print(f"Default context pool size: {default_num_contexts} contexts per unique model")

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """
        Compute SHA256 hash of a file for deduplication.

        Args:
            file_path: Path to the file

        Returns:
            Hexadecimal hash string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            # Read in chunks to handle large files efficiently
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def _load_engine(self, file_path: str) -> trt.ICudaEngine:
        """
        Load TensorRT engine from file.

        Args:
            file_path: Path to .trt or .engine file

        Returns:
            TensorRT engine
        """
        runtime = trt.Runtime(self.trt_logger)

        with open(file_path, 'rb') as f:
            engine_data = f.read()

        engine = runtime.deserialize_cuda_engine(engine_data)
        if engine is None:
            raise RuntimeError(f"Failed to load TensorRT engine from {file_path}")

        return engine

    def _extract_metadata(self, engine: trt.ICudaEngine,
                          file_path: str, file_hash: str) -> ModelMetadata:
        """
        Extract metadata from TensorRT engine.

        Args:
            engine: TensorRT engine
            file_path: Path to model file
            file_hash: SHA256 hash of model file

        Returns:
            ModelMetadata object
        """
        input_shapes = {}
        output_shapes = {}
        input_names = []
        output_names = []
        input_dtypes = {}
        output_dtypes = {}

        # TensorRT dtype to PyTorch dtype mapping
        trt_to_torch_dtype = {
            trt.DataType.FLOAT: torch.float32,
            trt.DataType.HALF: torch.float16,
            trt.DataType.INT8: torch.int8,
            trt.DataType.INT32: torch.int32,
            trt.DataType.BOOL: torch.bool,
        }

        # Iterate through all tensors (inputs and outputs) - TensorRT 10.x API
        for i in range(engine.num_io_tensors):
            name = engine.get_tensor_name(i)
            shape = tuple(engine.get_tensor_shape(name))
            dtype = trt_to_torch_dtype.get(engine.get_tensor_dtype(name), torch.float32)
            mode = engine.get_tensor_mode(name)

            if mode == trt.TensorIOMode.INPUT:
                input_names.append(name)
                input_shapes[name] = shape
                input_dtypes[name] = dtype
            else:
                output_names.append(name)
                output_shapes[name] = shape
                output_dtypes[name] = dtype

        return ModelMetadata(
            file_path=file_path,
            file_hash=file_hash,
            input_shapes=input_shapes,
            output_shapes=output_shapes,
            input_names=input_names,
            output_names=output_names,
            input_dtypes=input_dtypes,
            output_dtypes=output_dtypes
        )

    def load_model(self, model_id: str, file_path: str,
                   num_contexts: Optional[int] = None,
                   force_reload: bool = False) -> ModelMetadata:
        """
        Load a TensorRT model with the given ID.

        Deduplication: If a model with the same file hash is already loaded, the model_id
        is simply mapped to the existing SharedEngine (no new engine or contexts created).

        Args:
            model_id: User-defined identifier for this model (e.g., "camera_1")
            file_path: Path to TensorRT engine file (.trt or .engine)
            num_contexts: Number of execution contexts in pool (None = use default)
            force_reload: If True, reload even if model_id exists

        Returns:
            ModelMetadata for the loaded model

        Raises:
            FileNotFoundError: If model file doesn't exist
            RuntimeError: If engine loading fails
            ValueError: If model_id already exists and force_reload is False
        """
        file_path = str(Path(file_path).resolve())

        if not Path(file_path).exists():
            raise FileNotFoundError(f"Model file not found: {file_path}")

        if num_contexts is None:
            num_contexts = self.default_num_contexts

        with self._repo_lock:
            # Check if model_id already exists
            if model_id in self._model_to_hash and not force_reload:
                raise ValueError(
                    f"Model ID '{model_id}' already exists. "
                    f"Use force_reload=True to reload or choose a different ID."
                )

            # Unload existing model if force_reload
            if model_id in self._model_to_hash and force_reload:
                self.unload_model(model_id)

            # Compute file hash for deduplication
            print(f"Computing hash for {file_path}...")
            file_hash = self.compute_file_hash(file_path)
            print(f"File hash: {file_hash[:16]}...")

            # Check if this file is already loaded (deduplication)
            if file_hash in self._shared_engines:
                shared_engine = self._shared_engines[file_hash]
                print("Engine already loaded (hash match), reusing engine and context pool...")
                print(f"  Existing model_ids using this engine: {shared_engine.model_ids}")
            else:
                # Load new engine
                print(f"Loading TensorRT engine from {file_path}...")
                engine = self._load_engine(file_path)

                # Extract metadata
                metadata = self._extract_metadata(engine, file_path, file_hash)

                # Create shared engine with context pool
                shared_engine = SharedEngine(
                    engine=engine,
                    file_hash=file_hash,
                    file_path=file_path,
                    num_contexts=num_contexts,
                    device=self.device,
                    metadata=metadata
                )
                self._shared_engines[file_hash] = shared_engine

            # Add this model_id to the shared engine
            shared_engine.add_model_id(model_id)

            # Map model_id to file_hash
            self._model_to_hash[model_id] = file_hash

            print(f"Model '{model_id}' loaded successfully")
            print(f"  Inputs: {shared_engine.metadata.input_names}")
            for name in shared_engine.metadata.input_names:
                print(f"    {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})")
            print(f"  Outputs: {shared_engine.metadata.output_names}")
            for name in shared_engine.metadata.output_names:
                print(f"    {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})")
            # Report the pool size of the engine actually in use (may differ from
            # num_contexts when an existing engine was reused).
            print(f"  Context pool size: {shared_engine.num_contexts}")
            print(f"  Model IDs sharing this engine: {shared_engine.get_reference_count()}")
            print(f"  Unique engines in VRAM: {len(self._shared_engines)}")

            return shared_engine.metadata

    def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
              synchronize: bool = True, timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
        """
        Run GPU-to-GPU inference with the specified model using context pooling.

        All inputs must be CUDA tensors and outputs will be CUDA tensors (stays in VRAM).
        Thread-safe: Borrows an execution context from the pool with mutex locking.

        Args:
            model_id: Model identifier
            inputs: Dictionary mapping input names to CUDA tensors
            synchronize: If True, wait for inference to complete. If False, async execution.
            timeout: Max time to wait for available context (seconds)

        Returns:
            Dictionary mapping output names to CUDA tensors (in VRAM)

        Raises:
            KeyError: If model_id not found
            ValueError: If inputs don't match expected shapes or are not on GPU
            RuntimeError: If no context available within timeout
        """
        # Get shared engine
        if model_id not in self._model_to_hash:
            raise KeyError(f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}")

        file_hash = self._model_to_hash[model_id]
        shared_engine = self._shared_engines[file_hash]
        metadata = shared_engine.metadata

        # Validate inputs
        for name in metadata.input_names:
            if name not in inputs:
                raise ValueError(f"Missing required input: {name}")

            tensor = inputs[name]
            if not tensor.is_cuda:
                raise ValueError(f"Input '{name}' must be a CUDA tensor (on GPU)")

            # Check device
            if tensor.device != self.device:
                print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
                inputs[name] = tensor.to(self.device)

        # Acquire context from pool (mutex-based)
        exec_ctx = shared_engine.acquire_context(timeout=timeout)
        if exec_ctx is None:
            raise RuntimeError(
                f"No execution context available for model '{model_id}' within {timeout}s. "
                f"All {shared_engine.num_contexts} contexts are busy."
            )

        try:
            # Prepare output tensors
            outputs = {}

            # Set input tensors - TensorRT 10.x API
            for name in metadata.input_names:
                input_tensor = inputs[name].contiguous()
                exec_ctx.context.set_tensor_address(name, input_tensor.data_ptr())

            # Allocate and set output tensors
            for name in metadata.output_names:
                output_shape = metadata.output_shapes[name]
                output_dtype = metadata.output_dtypes[name]

                output_tensor = torch.empty(
                    output_shape,
                    dtype=output_dtype,
                    device=self.device
                )

                outputs[name] = output_tensor
                exec_ctx.context.set_tensor_address(name, output_tensor.data_ptr())

            # Execute inference on the context's stream - TensorRT 10.x API
            with torch.cuda.stream(exec_ctx.stream):
                success = exec_ctx.context.execute_async_v3(
                    stream_handle=exec_ctx.stream.cuda_stream
                )

            if not success:
                raise RuntimeError(f"Inference failed for model '{model_id}'")

            # Synchronize if requested
            if synchronize:
                exec_ctx.stream.synchronize()

            return outputs

        finally:
            # Always release the context back to the pool
            shared_engine.release_context(exec_ctx)

    def infer_batch(self, model_id: str, batch_inputs: List[Dict[str, torch.Tensor]],
                    synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
        """
        Run inference on multiple inputs sequentially.

        A context is borrowed and returned for each input. For true parallelism,
        call infer() concurrently from multiple threads; the context pool then
        serves those calls in parallel.

        Args:
            model_id: Model identifier
            batch_inputs: List of input dictionaries
            synchronize: If True, wait for each inference to complete

        Returns:
            List of output dictionaries
        """
        results = []
        for inputs in batch_inputs:
            outputs = self.infer(model_id, inputs, synchronize=synchronize)
            results.append(outputs)

        return results

    def unload_model(self, model_id: str):
        """
        Unload a model from the repository.

        Removes the model_id reference from the shared engine. If this was the last
        reference, the engine and all its contexts will be fully unloaded from VRAM.

        Args:
            model_id: Model identifier to unload
        """
        with self._repo_lock:
            if model_id not in self._model_to_hash:
                print(f"Warning: Model '{model_id}' not found")
                return

            file_hash = self._model_to_hash[model_id]

            # Remove model_id from shared engine
            if file_hash in self._shared_engines:
                shared_engine = self._shared_engines[file_hash]
                remaining_refs = shared_engine.remove_model_id(model_id)

                # If no more references, cleanup engine and contexts
                if remaining_refs == 0:
                    shared_engine.cleanup()
                    del self._shared_engines[file_hash]
                    print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
                else:
                    print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")

            # Remove from model_id mapping
            del self._model_to_hash[model_id]

    def get_metadata(self, model_id: str) -> Optional[ModelMetadata]:
        """
        Get metadata for a loaded model.

        Args:
            model_id: Model identifier

        Returns:
            ModelMetadata or None if not found
        """
        if model_id not in self._model_to_hash:
            return None

        file_hash = self._model_to_hash[model_id]
        if file_hash not in self._shared_engines:
            return None

        return self._shared_engines[file_hash].metadata

    def list_models(self) -> Dict[str, ModelMetadata]:
        """
        List all loaded models.

        Returns:
            Dictionary mapping model_id to ModelMetadata
        """
        with self._repo_lock:
            result = {}
            for model_id, file_hash in self._model_to_hash.items():
                if file_hash in self._shared_engines:
                    result[model_id] = self._shared_engines[file_hash].metadata
            return result

    def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get detailed information about a loaded model.

        Args:
            model_id: Model identifier

        Returns:
            Dictionary with model information or None if not found
        """
        if model_id not in self._model_to_hash:
            return None

        file_hash = self._model_to_hash[model_id]
        if file_hash not in self._shared_engines:
            return None

        shared_engine = self._shared_engines[file_hash]
        metadata = shared_engine.metadata

        return {
            'model_id': model_id,
            'file_path': metadata.file_path,
            'file_hash': metadata.file_hash[:16] + '...',
            'engine_references': shared_engine.get_reference_count(),
            'context_pool_size': shared_engine.num_contexts,
            'shared_with_model_ids': list(shared_engine.model_ids),
            'inputs': {
                name: {
                    'shape': metadata.input_shapes[name],
                    'dtype': str(metadata.input_dtypes[name])
                }
                for name in metadata.input_names
            },
            'outputs': {
                name: {
                    'shape': metadata.output_shapes[name],
                    'dtype': str(metadata.output_dtypes[name])
                }
                for name in metadata.output_names
            }
        }

    def get_stats(self) -> Dict[str, Any]:
        """
        Get repository statistics.

        Returns:
            Dictionary with stats about loaded models and memory usage
        """
        with self._repo_lock:
            total_contexts = sum(
                engine.num_contexts
                for engine in self._shared_engines.values()
            )

            return {
                'total_model_ids': len(self._model_to_hash),
                'unique_engines': len(self._shared_engines),
                'total_contexts': total_contexts,
                'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
                'gpu_id': self.gpu_id,
                'models': list(self._model_to_hash.keys())
            }

    def __repr__(self):
        with self._repo_lock:
            return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
                    f"model_ids={len(self._model_to_hash)}, "
                    f"unique_engines={len(self._shared_engines)})")

    def __del__(self):
        """Cleanup all models on deletion"""
        # Guard against partially constructed instances and errors raised
        # during interpreter shutdown.
        try:
            with self._repo_lock:
                for model_id in list(self._model_to_hash.keys()):
                    self.unload_model(model_id)
        except Exception:
            pass
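

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): exercises the repository API defined above.
# The engine path "yolov8n.trt" is an assumption for demonstration; point it at
# a real TensorRT engine built with static input shapes before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from concurrent.futures import ThreadPoolExecutor

    repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=4)

    # Two model IDs backed by the same file share one engine (deduplication).
    engine_path = "yolov8n.trt"  # hypothetical path, replace with your engine
    metadata = repo.load_model("camera_1", engine_path)
    repo.load_model("camera_2", engine_path)

    # Build a dummy CUDA input matching the first input binding (assumes the
    # engine has static shapes, so no -1 dimensions appear in the metadata).
    input_name = metadata.input_names[0]
    dummy = torch.zeros(
        metadata.input_shapes[input_name],
        dtype=metadata.input_dtypes[input_name],
        device=repo.device,
    )

    def run(model_id: str) -> Dict[str, torch.Tensor]:
        # Each call borrows an execution context from the shared pool.
        return repo.infer(model_id, {input_name: dummy}, synchronize=True)

    # Concurrent calls from multiple threads exercise the context pool.
    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(run, f"camera_{1 + i % 2}") for i in range(8)]
        for fut in futures:
            outputs = fut.result()
            print({name: tuple(t.shape) for name, t in outputs.items()})

    print(repo.get_stats())

    repo.unload_model("camera_1")
    repo.unload_model("camera_2")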