# python-rtsp-worker/services/model_repository.py

import threading
import hashlib
from typing import Optional, Dict, Any, List, Tuple
from pathlib import Path
from queue import Empty, Queue
import torch
import tensorrt as trt
from dataclasses import dataclass
@dataclass
class ModelMetadata:
"""Metadata for a loaded TensorRT model"""
file_path: str
file_hash: str
input_shapes: Dict[str, Tuple[int, ...]]
output_shapes: Dict[str, Tuple[int, ...]]
input_names: List[str]
output_names: List[str]
input_dtypes: Dict[str, torch.dtype]
output_dtypes: Dict[str, torch.dtype]
class ExecutionContext:
"""
Wrapper for TensorRT execution context with CUDA stream.
Used in context pool for load balancing.
"""
def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
context_id: int, device: torch.device):
self.context = context
self.stream = stream
self.context_id = context_id
self.device = device
self.in_use = False
self.lock = threading.Lock()
def __repr__(self):
return f"ExecutionContext(id={self.context_id}, in_use={self.in_use})"
class SharedEngine:
"""
Shared TensorRT engine with context pool for load balancing.
Architecture:
- One engine shared across all model_ids with same file hash
- Pool of N execution contexts for concurrent inference
- Contexts are borrowed/returned using mutex locks
- Load balancing: contexts distributed across requests
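    Example (illustrative sketch of the borrow/return pattern; `shared` is an
    already-constructed SharedEngine instance):
        exec_ctx = shared.acquire_context(timeout=5.0)
        if exec_ctx is None:
            raise RuntimeError("all contexts busy")
        try:
            ...  # bind tensors on exec_ctx.context and run on exec_ctx.stream
        finally:
            shared.release_context(exec_ctx)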
"""
def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
num_contexts: int, device: torch.device, metadata: ModelMetadata):
self.engine = engine
self.file_hash = file_hash
self.file_path = file_path
self.metadata = metadata
self.device = device
self.num_contexts = num_contexts
# Create context pool
self.context_pool: List[ExecutionContext] = []
self.available_contexts: Queue[ExecutionContext] = Queue()
for i in range(num_contexts):
ctx = engine.create_execution_context()
if ctx is None:
raise RuntimeError(f"Failed to create execution context {i}")
stream = torch.cuda.Stream(device=device)
exec_ctx = ExecutionContext(ctx, stream, i, device)
self.context_pool.append(exec_ctx)
self.available_contexts.put(exec_ctx)
# Model IDs referencing this engine
self.model_ids: set = set()
self.lock = threading.Lock()
print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")
def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
"""
Acquire an available execution context from the pool.
Blocks if all contexts are in use.
Args:
timeout: Max time to wait for context (None = wait forever)
Returns:
ExecutionContext or None if timeout
"""
        try:
            exec_ctx = self.available_contexts.get(timeout=timeout)
        except Empty:
            # Pool exhausted and nothing was released within the timeout
            return None
        with exec_ctx.lock:
            exec_ctx.in_use = True
        return exec_ctx
def release_context(self, exec_ctx: ExecutionContext):
"""
Return a context to the pool.
Args:
exec_ctx: Context to release
"""
with exec_ctx.lock:
exec_ctx.in_use = False
self.available_contexts.put(exec_ctx)
def add_model_id(self, model_id: str):
"""Add a model_id reference to this engine"""
with self.lock:
self.model_ids.add(model_id)
def remove_model_id(self, model_id: str) -> int:
"""
Remove a model_id reference from this engine.
Returns the number of remaining references.
"""
with self.lock:
self.model_ids.discard(model_id)
return len(self.model_ids)
def get_reference_count(self) -> int:
"""Get number of model_ids using this engine"""
with self.lock:
return len(self.model_ids)
def cleanup(self):
"""Cleanup all contexts"""
for exec_ctx in self.context_pool:
del exec_ctx.context
self.context_pool.clear()
del self.engine
class TensorRTModelRepository:
"""
Thread-safe repository for TensorRT models with context pooling and deduplication.
Architecture:
- Deduplication: Multiple model_ids with same file → share one engine
- Context Pool: Each unique engine has N execution contexts (configurable)
- Load Balancing: Contexts are borrowed/returned via mutex queue
- Scalability: Adding 100 cameras with same model = 1 engine + N contexts (not 100 contexts!)
Best Practices:
- GPU-to-GPU: All inputs/outputs stay in VRAM (zero CPU transfers)
- Thread Safety: Mutex-based context borrowing (TensorRT best practice)
- Memory Efficient: Deduplicate by file hash, share engine across model_ids
- Concurrent: N contexts allow N parallel inferences per unique model
Example:
# 100 cameras, same model file
for i in range(100):
repo.load_model(f"camera_{i}", "yolov8.trt")
# Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
"""
def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4):
"""
Initialize the model repository.
Args:
gpu_id: GPU device ID to use
default_num_contexts: Default number of execution contexts per unique engine
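        Example (minimal construction sketch; assumes GPU 0 is present):
            repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=4)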
"""
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.default_num_contexts = default_num_contexts
# Model ID to engine mapping: model_id -> file_hash
self._model_to_hash: Dict[str, str] = {}
# Shared engines with context pools: file_hash -> SharedEngine
self._shared_engines: Dict[str, SharedEngine] = {}
# Locks for thread safety
self._repo_lock = threading.RLock()
# TensorRT logger
self.trt_logger = trt.Logger(trt.Logger.WARNING)
print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
print(f"Default context pool size: {default_num_contexts} contexts per unique model")
@staticmethod
def compute_file_hash(file_path: str) -> str:
"""
Compute SHA256 hash of a file for deduplication.
Args:
file_path: Path to the file
Returns:
Hexadecimal hash string
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
# Read in chunks to handle large files efficiently
for byte_block in iter(lambda: f.read(65536), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
def _load_engine(self, file_path: str) -> trt.ICudaEngine:
"""
Load TensorRT engine from file.
Args:
file_path: Path to .trt or .engine file
Returns:
TensorRT engine
"""
runtime = trt.Runtime(self.trt_logger)
with open(file_path, 'rb') as f:
engine_data = f.read()
engine = runtime.deserialize_cuda_engine(engine_data)
if engine is None:
raise RuntimeError(f"Failed to load TensorRT engine from {file_path}")
return engine
def _extract_metadata(self, engine: trt.ICudaEngine,
file_path: str, file_hash: str) -> ModelMetadata:
"""
Extract metadata from TensorRT engine.
Args:
engine: TensorRT engine
file_path: Path to model file
file_hash: SHA256 hash of model file
Returns:
ModelMetadata object
"""
input_shapes = {}
output_shapes = {}
input_names = []
output_names = []
input_dtypes = {}
output_dtypes = {}
# TensorRT dtype to PyTorch dtype mapping
trt_to_torch_dtype = {
trt.DataType.FLOAT: torch.float32,
trt.DataType.HALF: torch.float16,
trt.DataType.INT8: torch.int8,
trt.DataType.INT32: torch.int32,
trt.DataType.BOOL: torch.bool,
}
# Iterate through all tensors (inputs and outputs) - TensorRT 10.x API
for i in range(engine.num_io_tensors):
name = engine.get_tensor_name(i)
shape = tuple(engine.get_tensor_shape(name))
dtype = trt_to_torch_dtype.get(engine.get_tensor_dtype(name), torch.float32)
mode = engine.get_tensor_mode(name)
if mode == trt.TensorIOMode.INPUT:
input_names.append(name)
input_shapes[name] = shape
input_dtypes[name] = dtype
else:
output_names.append(name)
output_shapes[name] = shape
output_dtypes[name] = dtype
return ModelMetadata(
file_path=file_path,
file_hash=file_hash,
input_shapes=input_shapes,
output_shapes=output_shapes,
input_names=input_names,
output_names=output_names,
input_dtypes=input_dtypes,
output_dtypes=output_dtypes
)
def load_model(self, model_id: str, file_path: str,
num_contexts: Optional[int] = None,
force_reload: bool = False) -> ModelMetadata:
"""
Load a TensorRT model with the given ID.
Deduplication: If a model with the same file hash is already loaded, the model_id
is simply mapped to the existing SharedEngine (no new engine or contexts created).
Args:
model_id: User-defined identifier for this model (e.g., "camera_1")
file_path: Path to TensorRT engine file (.trt or .engine)
num_contexts: Number of execution contexts in pool (None = use default)
force_reload: If True, reload even if model_id exists
Returns:
ModelMetadata for the loaded model
Raises:
FileNotFoundError: If model file doesn't exist
RuntimeError: If engine loading fails
ValueError: If model_id already exists and force_reload is False
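        Example (sketch; 'models/yolov8.trt' is a placeholder path):
            # Two cameras pointing at the same engine file share one engine + pool
            repo.load_model("camera_1", "models/yolov8.trt")
            repo.load_model("camera_2", "models/yolov8.trt")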
"""
file_path = str(Path(file_path).resolve())
if not Path(file_path).exists():
raise FileNotFoundError(f"Model file not found: {file_path}")
if num_contexts is None:
num_contexts = self.default_num_contexts
with self._repo_lock:
# Check if model_id already exists
if model_id in self._model_to_hash and not force_reload:
raise ValueError(
f"Model ID '{model_id}' already exists. "
f"Use force_reload=True to reload or choose a different ID."
)
# Unload existing model if force_reload
if model_id in self._model_to_hash and force_reload:
self.unload_model(model_id)
# Compute file hash for deduplication
print(f"Computing hash for {file_path}...")
file_hash = self.compute_file_hash(file_path)
print(f"File hash: {file_hash[:16]}...")
# Check if this file is already loaded (deduplication)
if file_hash in self._shared_engines:
shared_engine = self._shared_engines[file_hash]
print(f"Engine already loaded (hash match), reusing engine and context pool...")
print(f" Existing model_ids using this engine: {shared_engine.model_ids}")
else:
# Load new engine
print(f"Loading TensorRT engine from {file_path}...")
engine = self._load_engine(file_path)
# Extract metadata
metadata = self._extract_metadata(engine, file_path, file_hash)
# Create shared engine with context pool
shared_engine = SharedEngine(
engine=engine,
file_hash=file_hash,
file_path=file_path,
num_contexts=num_contexts,
device=self.device,
metadata=metadata
)
self._shared_engines[file_hash] = shared_engine
# Add this model_id to the shared engine
shared_engine.add_model_id(model_id)
# Map model_id to file_hash
self._model_to_hash[model_id] = file_hash
print(f"Model '{model_id}' loaded successfully")
print(f" Inputs: {shared_engine.metadata.input_names}")
for name in shared_engine.metadata.input_names:
print(f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})")
print(f" Outputs: {shared_engine.metadata.output_names}")
for name in shared_engine.metadata.output_names:
print(f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})")
print(f" Context pool size: {num_contexts}")
print(f" Model IDs sharing this engine: {shared_engine.get_reference_count()}")
print(f" Unique engines in VRAM: {len(self._shared_engines)}")
return shared_engine.metadata
def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
synchronize: bool = True, timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
"""
Run GPU-to-GPU inference with the specified model using context pooling.
All inputs must be CUDA tensors and outputs will be CUDA tensors (stays in VRAM).
Thread-safe: Borrows an execution context from the pool with mutex locking.
Args:
model_id: Model identifier
inputs: Dictionary mapping input names to CUDA tensors
synchronize: If True, wait for inference to complete. If False, async execution.
timeout: Max time to wait for available context (seconds)
Returns:
Dictionary mapping output names to CUDA tensors (in VRAM)
Raises:
KeyError: If model_id not found
ValueError: If inputs don't match expected shapes or are not on GPU
RuntimeError: If no context available within timeout
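        Example (sketch; the 'images' name and 1x3x640x640 shape are placeholders
        for whatever the loaded engine actually expects):
            frame = torch.rand(1, 3, 640, 640, device='cuda:0')
            outputs = repo.infer("camera_1", {"images": frame}, synchronize=True)
            # outputs: dict of output name -> CUDA tensor, still resident in VRAM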
"""
# Get shared engine
if model_id not in self._model_to_hash:
raise KeyError(f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}")
file_hash = self._model_to_hash[model_id]
shared_engine = self._shared_engines[file_hash]
metadata = shared_engine.metadata
# Validate inputs
for name in metadata.input_names:
if name not in inputs:
raise ValueError(f"Missing required input: {name}")
tensor = inputs[name]
if not tensor.is_cuda:
raise ValueError(f"Input '{name}' must be a CUDA tensor (on GPU)")
# Check device
if tensor.device != self.device:
print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
inputs[name] = tensor.to(self.device)
# Acquire context from pool (mutex-based)
exec_ctx = shared_engine.acquire_context(timeout=timeout)
if exec_ctx is None:
raise RuntimeError(
f"No execution context available for model '{model_id}' within {timeout}s. "
f"All {shared_engine.num_contexts} contexts are busy."
)
try:
# Prepare output tensors
outputs = {}
# Set input tensors - TensorRT 10.x API
for name in metadata.input_names:
input_tensor = inputs[name].contiguous()
exec_ctx.context.set_tensor_address(name, input_tensor.data_ptr())
# Allocate and set output tensors
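            # Shapes below come straight from the engine; dynamic dimensions would
            # appear as -1 here and are not handled by this simple allocation.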
for name in metadata.output_names:
output_shape = metadata.output_shapes[name]
output_dtype = metadata.output_dtypes[name]
output_tensor = torch.empty(
output_shape,
dtype=output_dtype,
device=self.device
)
outputs[name] = output_tensor
exec_ctx.context.set_tensor_address(name, output_tensor.data_ptr())
# Execute inference on context's stream - TensorRT 10.x API
with torch.cuda.stream(exec_ctx.stream):
success = exec_ctx.context.execute_async_v3(
stream_handle=exec_ctx.stream.cuda_stream
)
if not success:
raise RuntimeError(f"Inference failed for model '{model_id}'")
# Synchronize if requested
if synchronize:
exec_ctx.stream.synchronize()
return outputs
finally:
# Always release context back to pool
shared_engine.release_context(exec_ctx)
def infer_batch(self, model_id: str, batch_inputs: List[Dict[str, torch.Tensor]],
synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
"""
        Run inference sequentially on a list of inputs.
        Each call borrows an execution context from the pool; with synchronize=False
        the inferences are enqueued on separate CUDA streams and may overlap on the GPU.
Args:
model_id: Model identifier
batch_inputs: List of input dictionaries
synchronize: If True, wait for all inferences to complete
Returns:
List of output dictionaries
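        Example (sketch; reuses the placeholder 'images' input name from infer()):
            batch = [{"images": torch.rand(1, 3, 640, 640, device='cuda:0')}
                     for _ in range(4)]
            results = repo.infer_batch("camera_1", batch)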
"""
results = []
for inputs in batch_inputs:
outputs = self.infer(model_id, inputs, synchronize=synchronize)
results.append(outputs)
return results
def unload_model(self, model_id: str):
"""
Unload a model from the repository.
Removes the model_id reference from the shared engine. If this was the last
reference, the engine and all its contexts will be fully unloaded from VRAM.
Args:
model_id: Model identifier to unload
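        Example (sketch; assumes camera_1 and camera_2 were loaded from the same file):
            repo.unload_model("camera_1")  # engine stays in VRAM, camera_2 still uses it
            repo.unload_model("camera_2")  # last reference: engine and contexts freed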
"""
with self._repo_lock:
if model_id not in self._model_to_hash:
print(f"Warning: Model '{model_id}' not found")
return
file_hash = self._model_to_hash[model_id]
# Remove model_id from shared engine
if file_hash in self._shared_engines:
shared_engine = self._shared_engines[file_hash]
remaining_refs = shared_engine.remove_model_id(model_id)
# If no more references, cleanup engine and contexts
if remaining_refs == 0:
shared_engine.cleanup()
del self._shared_engines[file_hash]
print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
else:
print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")
# Remove from model_id mapping
del self._model_to_hash[model_id]
def get_metadata(self, model_id: str) -> Optional[ModelMetadata]:
"""
Get metadata for a loaded model.
Args:
model_id: Model identifier
Returns:
ModelMetadata or None if not found
"""
if model_id not in self._model_to_hash:
return None
file_hash = self._model_to_hash[model_id]
if file_hash not in self._shared_engines:
return None
return self._shared_engines[file_hash].metadata
def list_models(self) -> Dict[str, ModelMetadata]:
"""
List all loaded models.
Returns:
Dictionary mapping model_id to ModelMetadata
"""
with self._repo_lock:
result = {}
for model_id, file_hash in self._model_to_hash.items():
if file_hash in self._shared_engines:
result[model_id] = self._shared_engines[file_hash].metadata
return result
def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get detailed information about a loaded model.
Args:
model_id: Model identifier
Returns:
Dictionary with model information or None if not found
"""
if model_id not in self._model_to_hash:
return None
file_hash = self._model_to_hash[model_id]
if file_hash not in self._shared_engines:
return None
shared_engine = self._shared_engines[file_hash]
metadata = shared_engine.metadata
return {
'model_id': model_id,
'file_path': metadata.file_path,
'file_hash': metadata.file_hash[:16] + '...',
'engine_references': shared_engine.get_reference_count(),
'context_pool_size': shared_engine.num_contexts,
'shared_with_model_ids': list(shared_engine.model_ids),
'inputs': {
name: {
'shape': metadata.input_shapes[name],
'dtype': str(metadata.input_dtypes[name])
}
for name in metadata.input_names
},
'outputs': {
name: {
'shape': metadata.output_shapes[name],
'dtype': str(metadata.output_dtypes[name])
}
for name in metadata.output_names
}
}
def get_stats(self) -> Dict[str, Any]:
"""
Get repository statistics.
Returns:
Dictionary with stats about loaded models and memory usage
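        Example (sketch):
            stats = repo.get_stats()
            print(stats['unique_engines'], stats['total_contexts'])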
"""
with self._repo_lock:
total_contexts = sum(
engine.num_contexts
for engine in self._shared_engines.values()
)
return {
'total_model_ids': len(self._model_to_hash),
'unique_engines': len(self._shared_engines),
'total_contexts': total_contexts,
'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
'gpu_id': self.gpu_id,
'models': list(self._model_to_hash.keys())
}
def __repr__(self):
with self._repo_lock:
return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
f"model_ids={len(self._model_to_hash)}, "
f"unique_engines={len(self._shared_engines)})")
    def __del__(self):
        """Cleanup all models on deletion"""
        try:
            with self._repo_lock:
                model_ids = list(self._model_to_hash.keys())
            for model_id in model_ids:
                self.unload_model(model_id)
        except Exception:
            # Interpreter may already be tearing down; ignore cleanup failures
            pass