feat: inference subsystem and optimization to decoder

This commit is contained in:
Siwat Sirichai 2025-11-09 00:57:08 +07:00
commit 3c83a57e44
19 changed files with 3897 additions and 0 deletions

View file

@ -0,0 +1,631 @@
import threading
import hashlib
from typing import Optional, Dict, Any, List, Tuple
from pathlib import Path
from queue import Queue
import torch
import tensorrt as trt
from dataclasses import dataclass
@dataclass
class ModelMetadata:
"""Metadata for a loaded TensorRT model"""
file_path: str
file_hash: str
input_shapes: Dict[str, Tuple[int, ...]]
output_shapes: Dict[str, Tuple[int, ...]]
input_names: List[str]
output_names: List[str]
input_dtypes: Dict[str, torch.dtype]
output_dtypes: Dict[str, torch.dtype]
class ExecutionContext:
"""
Wrapper for TensorRT execution context with CUDA stream.
Used in context pool for load balancing.
"""
def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
context_id: int, device: torch.device):
self.context = context
self.stream = stream
self.context_id = context_id
self.device = device
self.in_use = False
self.lock = threading.Lock()
def __repr__(self):
return f"ExecutionContext(id={self.context_id}, in_use={self.in_use})"
class SharedEngine:
"""
Shared TensorRT engine with context pool for load balancing.
Architecture:
- One engine shared across all model_ids with same file hash
- Pool of N execution contexts for concurrent inference
- Contexts are borrowed/returned using mutex locks
- Load balancing: contexts distributed across requests
"""
def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
num_contexts: int, device: torch.device, metadata: ModelMetadata):
self.engine = engine
self.file_hash = file_hash
self.file_path = file_path
self.metadata = metadata
self.device = device
self.num_contexts = num_contexts
# Create context pool
self.context_pool: List[ExecutionContext] = []
self.available_contexts: Queue[ExecutionContext] = Queue()
for i in range(num_contexts):
ctx = engine.create_execution_context()
if ctx is None:
raise RuntimeError(f"Failed to create execution context {i}")
stream = torch.cuda.Stream(device=device)
exec_ctx = ExecutionContext(ctx, stream, i, device)
self.context_pool.append(exec_ctx)
self.available_contexts.put(exec_ctx)
# Model IDs referencing this engine
self.model_ids: set = set()
self.lock = threading.Lock()
print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")
def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
"""
Acquire an available execution context from the pool.
Blocks if all contexts are in use.
Args:
timeout: Max time to wait for context (None = wait forever)
Returns:
ExecutionContext or None if timeout
"""
try:
exec_ctx = self.available_contexts.get(timeout=timeout)
with exec_ctx.lock:
exec_ctx.in_use = True
return exec_ctx
except:
return None
def release_context(self, exec_ctx: ExecutionContext):
"""
Return a context to the pool.
Args:
exec_ctx: Context to release
"""
with exec_ctx.lock:
exec_ctx.in_use = False
self.available_contexts.put(exec_ctx)
def add_model_id(self, model_id: str):
"""Add a model_id reference to this engine"""
with self.lock:
self.model_ids.add(model_id)
def remove_model_id(self, model_id: str) -> int:
"""
Remove a model_id reference from this engine.
Returns the number of remaining references.
"""
with self.lock:
self.model_ids.discard(model_id)
return len(self.model_ids)
def get_reference_count(self) -> int:
"""Get number of model_ids using this engine"""
with self.lock:
return len(self.model_ids)
def cleanup(self):
"""Cleanup all contexts"""
for exec_ctx in self.context_pool:
del exec_ctx.context
self.context_pool.clear()
del self.engine
class TensorRTModelRepository:
"""
Thread-safe repository for TensorRT models with context pooling and deduplication.
Architecture:
- Deduplication: Multiple model_ids with same file share one engine
- Context Pool: Each unique engine has N execution contexts (configurable)
- Load Balancing: Contexts are borrowed/returned via mutex queue
- Scalability: Adding 100 cameras with same model = 1 engine + N contexts (not 100 contexts!)
Best Practices:
- GPU-to-GPU: All inputs/outputs stay in VRAM (zero CPU transfers)
- Thread Safety: Mutex-based context borrowing (TensorRT best practice)
- Memory Efficient: Deduplicate by file hash, share engine across model_ids
- Concurrent: N contexts allow N parallel inferences per unique model
Example:
# 100 cameras, same model file
for i in range(100):
repo.load_model(f"camera_{i}", "yolov8.trt")
# Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
"""
def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4):
"""
Initialize the model repository.
Args:
gpu_id: GPU device ID to use
default_num_contexts: Default number of execution contexts per unique engine
"""
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.default_num_contexts = default_num_contexts
# Model ID to engine mapping: model_id -> file_hash
self._model_to_hash: Dict[str, str] = {}
# Shared engines with context pools: file_hash -> SharedEngine
self._shared_engines: Dict[str, SharedEngine] = {}
# Locks for thread safety
self._repo_lock = threading.RLock()
# TensorRT logger
self.trt_logger = trt.Logger(trt.Logger.WARNING)
print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
print(f"Default context pool size: {default_num_contexts} contexts per unique model")
@staticmethod
def compute_file_hash(file_path: str) -> str:
"""
Compute SHA256 hash of a file for deduplication.
Args:
file_path: Path to the file
Returns:
Hexadecimal hash string
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
# Read in chunks to handle large files efficiently
for byte_block in iter(lambda: f.read(65536), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
def _load_engine(self, file_path: str) -> trt.ICudaEngine:
"""
Load TensorRT engine from file.
Args:
file_path: Path to .trt or .engine file
Returns:
TensorRT engine
"""
runtime = trt.Runtime(self.trt_logger)
with open(file_path, 'rb') as f:
engine_data = f.read()
engine = runtime.deserialize_cuda_engine(engine_data)
if engine is None:
raise RuntimeError(f"Failed to load TensorRT engine from {file_path}")
return engine
def _extract_metadata(self, engine: trt.ICudaEngine,
file_path: str, file_hash: str) -> ModelMetadata:
"""
Extract metadata from TensorRT engine.
Args:
engine: TensorRT engine
file_path: Path to model file
file_hash: SHA256 hash of model file
Returns:
ModelMetadata object
"""
input_shapes = {}
output_shapes = {}
input_names = []
output_names = []
input_dtypes = {}
output_dtypes = {}
# TensorRT dtype to PyTorch dtype mapping
trt_to_torch_dtype = {
trt.DataType.FLOAT: torch.float32,
trt.DataType.HALF: torch.float16,
trt.DataType.INT8: torch.int8,
trt.DataType.INT32: torch.int32,
trt.DataType.BOOL: torch.bool,
}
# Iterate through all tensors (inputs and outputs) - TensorRT 10.x API
for i in range(engine.num_io_tensors):
name = engine.get_tensor_name(i)
shape = tuple(engine.get_tensor_shape(name))
dtype = trt_to_torch_dtype.get(engine.get_tensor_dtype(name), torch.float32)
mode = engine.get_tensor_mode(name)
if mode == trt.TensorIOMode.INPUT:
input_names.append(name)
input_shapes[name] = shape
input_dtypes[name] = dtype
else:
output_names.append(name)
output_shapes[name] = shape
output_dtypes[name] = dtype
return ModelMetadata(
file_path=file_path,
file_hash=file_hash,
input_shapes=input_shapes,
output_shapes=output_shapes,
input_names=input_names,
output_names=output_names,
input_dtypes=input_dtypes,
output_dtypes=output_dtypes
)
def load_model(self, model_id: str, file_path: str,
num_contexts: Optional[int] = None,
force_reload: bool = False) -> ModelMetadata:
"""
Load a TensorRT model with the given ID.
Deduplication: If a model with the same file hash is already loaded, the model_id
is simply mapped to the existing SharedEngine (no new engine or contexts created).
Args:
model_id: User-defined identifier for this model (e.g., "camera_1")
file_path: Path to TensorRT engine file (.trt or .engine)
num_contexts: Number of execution contexts in pool (None = use default)
force_reload: If True, reload even if model_id exists
Returns:
ModelMetadata for the loaded model
Raises:
FileNotFoundError: If model file doesn't exist
RuntimeError: If engine loading fails
ValueError: If model_id already exists and force_reload is False
"""
file_path = str(Path(file_path).resolve())
if not Path(file_path).exists():
raise FileNotFoundError(f"Model file not found: {file_path}")
if num_contexts is None:
num_contexts = self.default_num_contexts
with self._repo_lock:
# Check if model_id already exists
if model_id in self._model_to_hash and not force_reload:
raise ValueError(
f"Model ID '{model_id}' already exists. "
f"Use force_reload=True to reload or choose a different ID."
)
# Unload existing model if force_reload
if model_id in self._model_to_hash and force_reload:
self.unload_model(model_id)
# Compute file hash for deduplication
print(f"Computing hash for {file_path}...")
file_hash = self.compute_file_hash(file_path)
print(f"File hash: {file_hash[:16]}...")
# Check if this file is already loaded (deduplication)
if file_hash in self._shared_engines:
shared_engine = self._shared_engines[file_hash]
print(f"Engine already loaded (hash match), reusing engine and context pool...")
print(f" Existing model_ids using this engine: {shared_engine.model_ids}")
else:
# Load new engine
print(f"Loading TensorRT engine from {file_path}...")
engine = self._load_engine(file_path)
# Extract metadata
metadata = self._extract_metadata(engine, file_path, file_hash)
# Create shared engine with context pool
shared_engine = SharedEngine(
engine=engine,
file_hash=file_hash,
file_path=file_path,
num_contexts=num_contexts,
device=self.device,
metadata=metadata
)
self._shared_engines[file_hash] = shared_engine
# Add this model_id to the shared engine
shared_engine.add_model_id(model_id)
# Map model_id to file_hash
self._model_to_hash[model_id] = file_hash
print(f"Model '{model_id}' loaded successfully")
print(f" Inputs: {shared_engine.metadata.input_names}")
for name in shared_engine.metadata.input_names:
print(f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})")
print(f" Outputs: {shared_engine.metadata.output_names}")
for name in shared_engine.metadata.output_names:
print(f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})")
print(f" Context pool size: {num_contexts}")
print(f" Model IDs sharing this engine: {shared_engine.get_reference_count()}")
print(f" Unique engines in VRAM: {len(self._shared_engines)}")
return shared_engine.metadata
def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
synchronize: bool = True, timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
"""
Run GPU-to-GPU inference with the specified model using context pooling.
All inputs must be CUDA tensors and outputs will be CUDA tensors (stays in VRAM).
Thread-safe: Borrows an execution context from the pool with mutex locking.
Args:
model_id: Model identifier
inputs: Dictionary mapping input names to CUDA tensors
synchronize: If True, wait for inference to complete. If False, async execution.
timeout: Max time to wait for available context (seconds)
Returns:
Dictionary mapping output names to CUDA tensors (in VRAM)
Raises:
KeyError: If model_id not found
ValueError: If inputs don't match expected shapes or are not on GPU
RuntimeError: If no context available within timeout
"""
# Get shared engine
if model_id not in self._model_to_hash:
raise KeyError(f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}")
file_hash = self._model_to_hash[model_id]
shared_engine = self._shared_engines[file_hash]
metadata = shared_engine.metadata
# Validate inputs
for name in metadata.input_names:
if name not in inputs:
raise ValueError(f"Missing required input: {name}")
tensor = inputs[name]
if not tensor.is_cuda:
raise ValueError(f"Input '{name}' must be a CUDA tensor (on GPU)")
# Check device
if tensor.device != self.device:
print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
inputs[name] = tensor.to(self.device)
# Acquire context from pool (mutex-based)
exec_ctx = shared_engine.acquire_context(timeout=timeout)
if exec_ctx is None:
raise RuntimeError(
f"No execution context available for model '{model_id}' within {timeout}s. "
f"All {shared_engine.num_contexts} contexts are busy."
)
try:
# Prepare output tensors
outputs = {}
# Set input tensors - TensorRT 10.x API
for name in metadata.input_names:
input_tensor = inputs[name].contiguous()
exec_ctx.context.set_tensor_address(name, input_tensor.data_ptr())
# Allocate and set output tensors
for name in metadata.output_names:
output_shape = metadata.output_shapes[name]
output_dtype = metadata.output_dtypes[name]
output_tensor = torch.empty(
output_shape,
dtype=output_dtype,
device=self.device
)
outputs[name] = output_tensor
exec_ctx.context.set_tensor_address(name, output_tensor.data_ptr())
# Execute inference on context's stream - TensorRT 10.x API
with torch.cuda.stream(exec_ctx.stream):
success = exec_ctx.context.execute_async_v3(
stream_handle=exec_ctx.stream.cuda_stream
)
if not success:
raise RuntimeError(f"Inference failed for model '{model_id}'")
# Synchronize if requested
if synchronize:
exec_ctx.stream.synchronize()
return outputs
finally:
# Always release context back to pool
shared_engine.release_context(exec_ctx)
def infer_batch(self, model_id: str, batch_inputs: List[Dict[str, torch.Tensor]],
synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
"""
Run inference on multiple inputs.
Contexts are borrowed/returned for each input, enabling parallel processing.
Args:
model_id: Model identifier
batch_inputs: List of input dictionaries
synchronize: If True, wait for all inferences to complete
Returns:
List of output dictionaries
"""
results = []
for inputs in batch_inputs:
outputs = self.infer(model_id, inputs, synchronize=synchronize)
results.append(outputs)
return results
def unload_model(self, model_id: str):
"""
Unload a model from the repository.
Removes the model_id reference from the shared engine. If this was the last
reference, the engine and all its contexts will be fully unloaded from VRAM.
Args:
model_id: Model identifier to unload
"""
with self._repo_lock:
if model_id not in self._model_to_hash:
print(f"Warning: Model '{model_id}' not found")
return
file_hash = self._model_to_hash[model_id]
# Remove model_id from shared engine
if file_hash in self._shared_engines:
shared_engine = self._shared_engines[file_hash]
remaining_refs = shared_engine.remove_model_id(model_id)
# If no more references, cleanup engine and contexts
if remaining_refs == 0:
shared_engine.cleanup()
del self._shared_engines[file_hash]
print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
else:
print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")
# Remove from model_id mapping
del self._model_to_hash[model_id]
def get_metadata(self, model_id: str) -> Optional[ModelMetadata]:
"""
Get metadata for a loaded model.
Args:
model_id: Model identifier
Returns:
ModelMetadata or None if not found
"""
if model_id not in self._model_to_hash:
return None
file_hash = self._model_to_hash[model_id]
if file_hash not in self._shared_engines:
return None
return self._shared_engines[file_hash].metadata
def list_models(self) -> Dict[str, ModelMetadata]:
"""
List all loaded models.
Returns:
Dictionary mapping model_id to ModelMetadata
"""
with self._repo_lock:
result = {}
for model_id, file_hash in self._model_to_hash.items():
if file_hash in self._shared_engines:
result[model_id] = self._shared_engines[file_hash].metadata
return result
def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get detailed information about a loaded model.
Args:
model_id: Model identifier
Returns:
Dictionary with model information or None if not found
"""
if model_id not in self._model_to_hash:
return None
file_hash = self._model_to_hash[model_id]
if file_hash not in self._shared_engines:
return None
shared_engine = self._shared_engines[file_hash]
metadata = shared_engine.metadata
return {
'model_id': model_id,
'file_path': metadata.file_path,
'file_hash': metadata.file_hash[:16] + '...',
'engine_references': shared_engine.get_reference_count(),
'context_pool_size': shared_engine.num_contexts,
'shared_with_model_ids': list(shared_engine.model_ids),
'inputs': {
name: {
'shape': metadata.input_shapes[name],
'dtype': str(metadata.input_dtypes[name])
}
for name in metadata.input_names
},
'outputs': {
name: {
'shape': metadata.output_shapes[name],
'dtype': str(metadata.output_dtypes[name])
}
for name in metadata.output_names
}
}
def get_stats(self) -> Dict[str, Any]:
"""
Get repository statistics.
Returns:
Dictionary with stats about loaded models and memory usage
"""
with self._repo_lock:
total_contexts = sum(
engine.num_contexts
for engine in self._shared_engines.values()
)
return {
'total_model_ids': len(self._model_to_hash),
'unique_engines': len(self._shared_engines),
'total_contexts': total_contexts,
'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
'gpu_id': self.gpu_id,
'models': list(self._model_to_hash.keys())
}
def __repr__(self):
with self._repo_lock:
return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
f"model_ids={len(self._model_to_hash)}, "
f"unique_engines={len(self._shared_engines)})")
def __del__(self):
"""Cleanup all models on deletion"""
with self._repo_lock:
model_ids = list(self._model_to_hash.keys())
for model_id in model_ids:
self.unload_model(model_id)