import hashlib
import threading
from dataclasses import dataclass
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Dict, List, Optional, Tuple

import torch
import tensorrt as trt


@dataclass
class ModelMetadata:
    """Metadata for a loaded TensorRT model."""
    file_path: str
    file_hash: str
    input_shapes: Dict[str, Tuple[int, ...]]
    output_shapes: Dict[str, Tuple[int, ...]]
    input_names: List[str]
    output_names: List[str]
    input_dtypes: Dict[str, torch.dtype]
    output_dtypes: Dict[str, torch.dtype]


class ExecutionContext:
    """
    Wrapper around a TensorRT execution context and its dedicated CUDA stream.
    Instances live in a context pool and are borrowed for load balancing.
    """

    def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
                 context_id: int, device: torch.device):
        self.context = context
        self.stream = stream
        self.context_id = context_id
        self.device = device
        self.in_use = False
        self.lock = threading.Lock()

    def __repr__(self):
        return f"ExecutionContext(id={self.context_id}, in_use={self.in_use})"


class SharedEngine:
    """
    Shared TensorRT engine with a context pool for load balancing.

    Architecture:
    - One engine shared across all model_ids with the same file hash
    - Pool of N execution contexts for concurrent inference
    - Contexts are borrowed/returned using mutex locks
    - Load balancing: contexts are distributed across requests
    """

    def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
                 num_contexts: int, device: torch.device, metadata: ModelMetadata):
        self.engine = engine
        self.file_hash = file_hash
        self.file_path = file_path
        self.metadata = metadata
        self.device = device
        self.num_contexts = num_contexts

        # Create the context pool
        self.context_pool: List[ExecutionContext] = []
        self.available_contexts: Queue[ExecutionContext] = Queue()

        for i in range(num_contexts):
            ctx = engine.create_execution_context()
            if ctx is None:
                raise RuntimeError(f"Failed to create execution context {i}")
            stream = torch.cuda.Stream(device=device)
            exec_ctx = ExecutionContext(ctx, stream, i, device)
            self.context_pool.append(exec_ctx)
            self.available_contexts.put(exec_ctx)

        # Model IDs referencing this engine
        self.model_ids: set = set()
        self.lock = threading.Lock()

        print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")

    def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
        """
        Acquire an available execution context from the pool.
        Blocks if all contexts are in use.

        Args:
            timeout: Max time to wait for a context (None = wait forever)

        Returns:
            ExecutionContext, or None on timeout
        """
        try:
            exec_ctx = self.available_contexts.get(timeout=timeout)
        except Empty:
            return None
        with exec_ctx.lock:
            exec_ctx.in_use = True
        return exec_ctx

    def release_context(self, exec_ctx: ExecutionContext):
        """
        Return a context to the pool.

        Args:
            exec_ctx: Context to release
        """
        with exec_ctx.lock:
            exec_ctx.in_use = False
        self.available_contexts.put(exec_ctx)

    def add_model_id(self, model_id: str):
        """Add a model_id reference to this engine."""
        with self.lock:
            self.model_ids.add(model_id)

    def remove_model_id(self, model_id: str) -> int:
        """
        Remove a model_id reference from this engine.
        Returns the number of remaining references.
""" with self.lock: self.model_ids.discard(model_id) return len(self.model_ids) def get_reference_count(self) -> int: """Get number of model_ids using this engine""" with self.lock: return len(self.model_ids) def cleanup(self): """Cleanup all contexts""" for exec_ctx in self.context_pool: del exec_ctx.context self.context_pool.clear() del self.engine class TensorRTModelRepository: """ Thread-safe repository for TensorRT models with context pooling and deduplication. Architecture: - Deduplication: Multiple model_ids with same file → share one engine - Context Pool: Each unique engine has N execution contexts (configurable) - Load Balancing: Contexts are borrowed/returned via mutex queue - Scalability: Adding 100 cameras with same model = 1 engine + N contexts (not 100 contexts!) Best Practices: - GPU-to-GPU: All inputs/outputs stay in VRAM (zero CPU transfers) - Thread Safety: Mutex-based context borrowing (TensorRT best practice) - Memory Efficient: Deduplicate by file hash, share engine across model_ids - Concurrent: N contexts allow N parallel inferences per unique model Example: # 100 cameras, same model file for i in range(100): repo.load_model(f"camera_{i}", "yolov8.trt") # Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts! """ def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4): """ Initialize the model repository. Args: gpu_id: GPU device ID to use default_num_contexts: Default number of execution contexts per unique engine """ self.gpu_id = gpu_id self.device = torch.device(f'cuda:{gpu_id}') self.default_num_contexts = default_num_contexts # Model ID to engine mapping: model_id -> file_hash self._model_to_hash: Dict[str, str] = {} # Shared engines with context pools: file_hash -> SharedEngine self._shared_engines: Dict[str, SharedEngine] = {} # Locks for thread safety self._repo_lock = threading.RLock() # TensorRT logger self.trt_logger = trt.Logger(trt.Logger.WARNING) print(f"TensorRT Model Repository initialized on GPU {gpu_id}") print(f"Default context pool size: {default_num_contexts} contexts per unique model") @staticmethod def compute_file_hash(file_path: str) -> str: """ Compute SHA256 hash of a file for deduplication. Args: file_path: Path to the file Returns: Hexadecimal hash string """ sha256_hash = hashlib.sha256() with open(file_path, "rb") as f: # Read in chunks to handle large files efficiently for byte_block in iter(lambda: f.read(65536), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() def _load_engine(self, file_path: str) -> trt.ICudaEngine: """ Load TensorRT engine from file. Args: file_path: Path to .trt or .engine file Returns: TensorRT engine """ runtime = trt.Runtime(self.trt_logger) with open(file_path, 'rb') as f: engine_data = f.read() engine = runtime.deserialize_cuda_engine(engine_data) if engine is None: raise RuntimeError(f"Failed to load TensorRT engine from {file_path}") return engine def _extract_metadata(self, engine: trt.ICudaEngine, file_path: str, file_hash: str) -> ModelMetadata: """ Extract metadata from TensorRT engine. 

        Args:
            engine: TensorRT engine
            file_path: Path to the model file
            file_hash: SHA256 hash of the model file

        Returns:
            ModelMetadata object
        """
        input_shapes = {}
        output_shapes = {}
        input_names = []
        output_names = []
        input_dtypes = {}
        output_dtypes = {}

        # TensorRT dtype to PyTorch dtype mapping
        trt_to_torch_dtype = {
            trt.DataType.FLOAT: torch.float32,
            trt.DataType.HALF: torch.float16,
            trt.DataType.INT8: torch.int8,
            trt.DataType.INT32: torch.int32,
            trt.DataType.BOOL: torch.bool,
        }

        # Iterate through all I/O tensors (inputs and outputs) - TensorRT 10.x API
        for i in range(engine.num_io_tensors):
            name = engine.get_tensor_name(i)
            shape = tuple(engine.get_tensor_shape(name))
            dtype = trt_to_torch_dtype.get(engine.get_tensor_dtype(name), torch.float32)
            mode = engine.get_tensor_mode(name)

            if mode == trt.TensorIOMode.INPUT:
                input_names.append(name)
                input_shapes[name] = shape
                input_dtypes[name] = dtype
            else:
                output_names.append(name)
                output_shapes[name] = shape
                output_dtypes[name] = dtype

        return ModelMetadata(
            file_path=file_path,
            file_hash=file_hash,
            input_shapes=input_shapes,
            output_shapes=output_shapes,
            input_names=input_names,
            output_names=output_names,
            input_dtypes=input_dtypes,
            output_dtypes=output_dtypes
        )

    def load_model(self, model_id: str, file_path: str,
                   num_contexts: Optional[int] = None,
                   force_reload: bool = False) -> ModelMetadata:
        """
        Load a TensorRT model with the given ID.

        Deduplication: if a model with the same file hash is already loaded,
        the model_id is simply mapped to the existing SharedEngine (no new
        engine or contexts are created).

        Args:
            model_id: User-defined identifier for this model (e.g., "camera_1")
            file_path: Path to the TensorRT engine file (.trt or .engine)
            num_contexts: Number of execution contexts in the pool (None = use default)
            force_reload: If True, reload even if model_id exists

        Returns:
            ModelMetadata for the loaded model

        Raises:
            FileNotFoundError: If the model file doesn't exist
            RuntimeError: If engine loading fails
            ValueError: If model_id already exists and force_reload is False
        """
        file_path = str(Path(file_path).resolve())

        if not Path(file_path).exists():
            raise FileNotFoundError(f"Model file not found: {file_path}")

        if num_contexts is None:
            num_contexts = self.default_num_contexts

        with self._repo_lock:
            # Check if model_id already exists
            if model_id in self._model_to_hash and not force_reload:
                raise ValueError(
                    f"Model ID '{model_id}' already exists. "
                    f"Use force_reload=True to reload or choose a different ID."
                )

            # Unload the existing model if force_reload is requested
            if model_id in self._model_to_hash and force_reload:
                self.unload_model(model_id)

            # Compute the file hash for deduplication
            print(f"Computing hash for {file_path}...")
            file_hash = self.compute_file_hash(file_path)
            print(f"File hash: {file_hash[:16]}...")

            # Check if this file is already loaded (deduplication)
            if file_hash in self._shared_engines:
                shared_engine = self._shared_engines[file_hash]
                print(f"Engine already loaded (hash match), reusing engine and context pool...")
                print(f"  Existing model_ids using this engine: {shared_engine.model_ids}")
            else:
                # Load a new engine
                print(f"Loading TensorRT engine from {file_path}...")
                engine = self._load_engine(file_path)

                # Extract metadata
                metadata = self._extract_metadata(engine, file_path, file_hash)

                # Create a shared engine with a context pool
                shared_engine = SharedEngine(
                    engine=engine,
                    file_hash=file_hash,
                    file_path=file_path,
                    num_contexts=num_contexts,
                    device=self.device,
                    metadata=metadata
                )
                self._shared_engines[file_hash] = shared_engine

            # Add this model_id to the shared engine
            shared_engine.add_model_id(model_id)

            # Map model_id to file_hash
            self._model_to_hash[model_id] = file_hash

            print(f"Model '{model_id}' loaded successfully")
            print(f"  Inputs: {shared_engine.metadata.input_names}")
            for name in shared_engine.metadata.input_names:
                print(f"    {name}: {shared_engine.metadata.input_shapes[name]} "
                      f"({shared_engine.metadata.input_dtypes[name]})")
            print(f"  Outputs: {shared_engine.metadata.output_names}")
            for name in shared_engine.metadata.output_names:
                print(f"    {name}: {shared_engine.metadata.output_shapes[name]} "
                      f"({shared_engine.metadata.output_dtypes[name]})")
            # Report the actual pool size (a reused engine keeps its original pool)
            print(f"  Context pool size: {shared_engine.num_contexts}")
            print(f"  Model IDs sharing this engine: {shared_engine.get_reference_count()}")
            print(f"  Unique engines in VRAM: {len(self._shared_engines)}")

            return shared_engine.metadata

    def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
              synchronize: bool = True,
              timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
        """
        Run GPU-to-GPU inference with the specified model using context pooling.

        All inputs must be CUDA tensors, and outputs are returned as CUDA tensors
        (everything stays in VRAM).

        Thread-safe: borrows an execution context from the pool with mutex locking.

        Args:
            model_id: Model identifier
            inputs: Dictionary mapping input names to CUDA tensors
            synchronize: If True, wait for inference to complete; if False, execute asynchronously
            timeout: Max time to wait for an available context (seconds)

        Returns:
            Dictionary mapping output names to CUDA tensors (in VRAM)

        Raises:
            KeyError: If model_id is not found
            ValueError: If inputs don't match the expected shapes or are not on the GPU
            RuntimeError: If no context becomes available within the timeout
        """
        # Get the shared engine
        if model_id not in self._model_to_hash:
            raise KeyError(
                f"Model '{model_id}' not found. "
                f"Available: {list(self._model_to_hash.keys())}"
            )

        file_hash = self._model_to_hash[model_id]
        shared_engine = self._shared_engines[file_hash]
        metadata = shared_engine.metadata

        # Validate inputs
        for name in metadata.input_names:
            if name not in inputs:
                raise ValueError(f"Missing required input: {name}")

            tensor = inputs[name]
            if not tensor.is_cuda:
                raise ValueError(f"Input '{name}' must be a CUDA tensor (on GPU)")

            # Check the device
            if tensor.device != self.device:
                print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
                inputs[name] = tensor.to(self.device)

        # Acquire a context from the pool (mutex-based)
        exec_ctx = shared_engine.acquire_context(timeout=timeout)
        if exec_ctx is None:
            raise RuntimeError(
                f"No execution context available for model '{model_id}' within {timeout}s. "
                f"All {shared_engine.num_contexts} contexts are busy."
            )

        try:
            # Prepare output tensors
            outputs = {}

            # Set input tensor addresses - TensorRT 10.x API
            for name in metadata.input_names:
                input_tensor = inputs[name].contiguous()
                exec_ctx.context.set_tensor_address(name, input_tensor.data_ptr())

            # Allocate and set output tensors
            for name in metadata.output_names:
                output_shape = metadata.output_shapes[name]
                output_dtype = metadata.output_dtypes[name]
                output_tensor = torch.empty(
                    output_shape,
                    dtype=output_dtype,
                    device=self.device
                )
                outputs[name] = output_tensor
                exec_ctx.context.set_tensor_address(name, output_tensor.data_ptr())

            # Execute inference on the context's stream - TensorRT 10.x API
            with torch.cuda.stream(exec_ctx.stream):
                success = exec_ctx.context.execute_async_v3(
                    stream_handle=exec_ctx.stream.cuda_stream
                )

            if not success:
                raise RuntimeError(f"Inference failed for model '{model_id}'")

            # Synchronize if requested
            if synchronize:
                exec_ctx.stream.synchronize()

            return outputs

        finally:
            # Always release the context back to the pool
            shared_engine.release_context(exec_ctx)

    def infer_batch(self, model_id: str,
                    batch_inputs: List[Dict[str, torch.Tensor]],
                    synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
        """
        Run inference on multiple inputs.

        Contexts are borrowed and returned for each input, enabling parallel processing.

        Args:
            model_id: Model identifier
            batch_inputs: List of input dictionaries
            synchronize: If True, wait for all inferences to complete

        Returns:
            List of output dictionaries
        """
        results = []
        for inputs in batch_inputs:
            outputs = self.infer(model_id, inputs, synchronize=synchronize)
            results.append(outputs)
        return results

    def unload_model(self, model_id: str):
        """
        Unload a model from the repository.

        Removes the model_id reference from the shared engine. If this was the
        last reference, the engine and all of its contexts are fully unloaded
        from VRAM.

        Args:
            model_id: Model identifier to unload
        """
        with self._repo_lock:
            if model_id not in self._model_to_hash:
                print(f"Warning: Model '{model_id}' not found")
                return

            file_hash = self._model_to_hash[model_id]

            # Remove model_id from the shared engine
            if file_hash in self._shared_engines:
                shared_engine = self._shared_engines[file_hash]
                remaining_refs = shared_engine.remove_model_id(model_id)

                # If there are no more references, clean up the engine and contexts
                if remaining_refs == 0:
                    shared_engine.cleanup()
                    del self._shared_engines[file_hash]
                    print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
                else:
                    print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")

            # Remove from the model_id mapping
            del self._model_to_hash[model_id]

    def get_metadata(self, model_id: str) -> Optional[ModelMetadata]:
        """
        Get metadata for a loaded model.

        Args:
            model_id: Model identifier

        Returns:
            ModelMetadata, or None if not found
        """
        if model_id not in self._model_to_hash:
            return None

        file_hash = self._model_to_hash[model_id]
        if file_hash not in self._shared_engines:
            return None

        return self._shared_engines[file_hash].metadata

    def list_models(self) -> Dict[str, ModelMetadata]:
        """
        List all loaded models.

        Returns:
            Dictionary mapping model_id to ModelMetadata
        """
        with self._repo_lock:
            result = {}
            for model_id, file_hash in self._model_to_hash.items():
                if file_hash in self._shared_engines:
                    result[model_id] = self._shared_engines[file_hash].metadata
            return result

    def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get detailed information about a loaded model.

        Args:
            model_id: Model identifier

        Returns:
            Dictionary with model information, or None if not found
        """
        if model_id not in self._model_to_hash:
            return None

        file_hash = self._model_to_hash[model_id]
        if file_hash not in self._shared_engines:
            return None

        shared_engine = self._shared_engines[file_hash]
        metadata = shared_engine.metadata

        return {
            'model_id': model_id,
            'file_path': metadata.file_path,
            'file_hash': metadata.file_hash[:16] + '...',
            'engine_references': shared_engine.get_reference_count(),
            'context_pool_size': shared_engine.num_contexts,
            'shared_with_model_ids': list(shared_engine.model_ids),
            'inputs': {
                name: {
                    'shape': metadata.input_shapes[name],
                    'dtype': str(metadata.input_dtypes[name])
                }
                for name in metadata.input_names
            },
            'outputs': {
                name: {
                    'shape': metadata.output_shapes[name],
                    'dtype': str(metadata.output_dtypes[name])
                }
                for name in metadata.output_names
            }
        }

    def get_stats(self) -> Dict[str, Any]:
        """
        Get repository statistics.

        Returns:
            Dictionary with stats about loaded models and memory usage
        """
        with self._repo_lock:
            total_contexts = sum(
                engine.num_contexts for engine in self._shared_engines.values()
            )

            return {
                'total_model_ids': len(self._model_to_hash),
                'unique_engines': len(self._shared_engines),
                'total_contexts': total_contexts,
                'memory_efficiency': (
                    f"{len(self._model_to_hash)} model IDs using only "
                    f"{len(self._shared_engines)} engines"
                ),
                'gpu_id': self.gpu_id,
                'models': list(self._model_to_hash.keys())
            }

    def __repr__(self):
        with self._repo_lock:
            return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
                    f"model_ids={len(self._model_to_hash)}, "
                    f"unique_engines={len(self._shared_engines)})")

    def __del__(self):
        """Cleanup all models on deletion."""
        with self._repo_lock:
            model_ids = list(self._model_to_hash.keys())
            for model_id in model_ids:
                self.unload_model(model_id)
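

if __name__ == "__main__":
    # Minimal usage sketch (not part of the repository API): it shows deduplication,
    # GPU-to-GPU inference, and concurrent calls against the context pool.
    # Assumptions for illustration only: a serialized engine at "yolov8.trt" with
    # static shapes and a floating-point first input; substitute the names and
    # shapes that load_model() prints for your own engine.
    repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=4)

    engine_path = "yolov8.trt"  # hypothetical engine file
    if Path(engine_path).exists():
        # Two model_ids backed by the same file share one engine and one context pool
        metadata = repo.load_model("camera_1", engine_path)
        repo.load_model("camera_2", engine_path)
        print(repo.get_stats())

        # Build a dummy CUDA input matching the engine's first input
        input_name = metadata.input_names[0]
        dummy = torch.rand(
            metadata.input_shapes[input_name],
            device=repo.device,
        ).to(metadata.input_dtypes[input_name])

        outputs = repo.infer("camera_1", {input_name: dummy}, synchronize=True)
        for name, tensor in outputs.items():
            print(f"{name}: {tuple(tensor.shape)} on {tensor.device}")

        # The context pool also serves concurrent callers (e.g., one thread per
        # camera); a ThreadPoolExecutor is used here purely for illustration.
        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=4) as pool:
            futures = [
                pool.submit(repo.infer, f"camera_{1 + (i % 2)}", {input_name: dummy})
                for i in range(8)
            ]
            for future in futures:
                future.result()

        # Unloading one reference keeps the engine; unloading both frees VRAM
        repo.unload_model("camera_1")
        repo.unload_model("camera_2")
    else:
        print(f"Engine file '{engine_path}' not found; skipping the demo.")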