import threading
import hashlib
import json
from typing import Optional, Dict, Any, List, Tuple
from pathlib import Path
from queue import Queue, Empty
import torch
import tensorrt as trt
from dataclasses import dataclass
import logging

logger = logging.getLogger(__name__)


@dataclass
class ModelMetadata:
    """Metadata for a loaded TensorRT model"""
    file_path: str
    file_hash: str
    input_shapes: Dict[str, Tuple[int, ...]]
    output_shapes: Dict[str, Tuple[int, ...]]
    input_names: List[str]
    output_names: List[str]
    input_dtypes: Dict[str, torch.dtype]
    output_dtypes: Dict[str, torch.dtype]


class ExecutionContext:
    """
    Wrapper for TensorRT execution context with CUDA stream.
    Used in context pool for load balancing.
    """

    def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
                 context_id: int, device: torch.device):
        self.context = context
        self.stream = stream
        self.context_id = context_id
        self.device = device
        self.in_use = False
        self.lock = threading.Lock()

    def __repr__(self):
        return f"ExecutionContext(id={self.context_id}, in_use={self.in_use})"


class SharedEngine:
    """
    Shared TensorRT engine with context pool for load balancing.

    Architecture:
    - One engine shared across all model_ids with the same file hash
    - Pool of N execution contexts for concurrent inference
    - Contexts are borrowed/returned using mutex locks
    - Load balancing: contexts distributed across requests
    """

    def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
                 num_contexts: int, device: torch.device, metadata: ModelMetadata):
        self.engine = engine
        self.file_hash = file_hash
        self.file_path = file_path
        self.metadata = metadata
        self.device = device
        self.num_contexts = num_contexts

        # Create context pool
        self.context_pool: List[ExecutionContext] = []
        self.available_contexts: Queue[ExecutionContext] = Queue()

        for i in range(num_contexts):
            ctx = engine.create_execution_context()
            if ctx is None:
                raise RuntimeError(f"Failed to create execution context {i}")
            stream = torch.cuda.Stream(device=device)
            exec_ctx = ExecutionContext(ctx, stream, i, device)
            self.context_pool.append(exec_ctx)
            self.available_contexts.put(exec_ctx)

        # Model IDs referencing this engine
        self.model_ids: set = set()
        self.lock = threading.Lock()

        print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")

    def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
        """
        Acquire an available execution context from the pool.
        Blocks if all contexts are in use.

        Args:
            timeout: Max time to wait for context (None = wait forever)

        Returns:
            ExecutionContext or None if timeout
        """
        try:
            exec_ctx = self.available_contexts.get(timeout=timeout)
            with exec_ctx.lock:
                exec_ctx.in_use = True
            return exec_ctx
        except Empty:
            return None

    def release_context(self, exec_ctx: ExecutionContext):
        """
        Return a context to the pool.

        Args:
            exec_ctx: Context to release
        """
        with exec_ctx.lock:
            exec_ctx.in_use = False
        self.available_contexts.put(exec_ctx)

    def add_model_id(self, model_id: str):
        """Add a model_id reference to this engine"""
        with self.lock:
            self.model_ids.add(model_id)

    def remove_model_id(self, model_id: str) -> int:
        """
        Remove a model_id reference from this engine.
        Returns the number of remaining references.
""" with self.lock: self.model_ids.discard(model_id) return len(self.model_ids) def get_reference_count(self) -> int: """Get number of model_ids using this engine""" with self.lock: return len(self.model_ids) def cleanup(self): """Cleanup all contexts""" for exec_ctx in self.context_pool: del exec_ctx.context self.context_pool.clear() del self.engine class TensorRTModelRepository: """ Thread-safe repository for TensorRT models with context pooling and deduplication. Architecture: - Deduplication: Multiple model_ids with same file → share one engine - Context Pool: Each unique engine has N execution contexts (configurable) - Load Balancing: Contexts are borrowed/returned via mutex queue - Scalability: Adding 100 cameras with same model = 1 engine + N contexts (not 100 contexts!) Best Practices: - GPU-to-GPU: All inputs/outputs stay in VRAM (zero CPU transfers) - Thread Safety: Mutex-based context borrowing (TensorRT best practice) - Memory Efficient: Deduplicate by file hash, share engine across model_ids - Concurrent: N contexts allow N parallel inferences per unique model Example: # 100 cameras, same model file for i in range(100): repo.load_model(f"camera_{i}", "yolov8.trt") # Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts! """ def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True, cache_dir: str = ".trt_cache"): """ Initialize the model repository. Args: gpu_id: GPU device ID to use default_num_contexts: Default number of execution contexts per unique engine enable_pt_conversion: Enable automatic PyTorch to TensorRT conversion cache_dir: Directory for caching stripped TensorRT engines and metadata """ self.gpu_id = gpu_id self.device = torch.device(f'cuda:{gpu_id}') self.default_num_contexts = default_num_contexts self.enable_pt_conversion = enable_pt_conversion self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) # Model ID to engine mapping: model_id -> file_hash self._model_to_hash: Dict[str, str] = {} # Shared engines with context pools: file_hash -> SharedEngine self._shared_engines: Dict[str, SharedEngine] = {} # Locks for thread safety self._repo_lock = threading.RLock() # TensorRT logger self.trt_logger = trt.Logger(trt.Logger.WARNING) # PT converter (lazy initialization) self._pt_converter = None print(f"TensorRT Model Repository initialized on GPU {gpu_id}") print(f"Default context pool size: {default_num_contexts} contexts per unique model") print(f"Cache directory: {self.cache_dir}") if enable_pt_conversion: print(f"PyTorch to TensorRT conversion: enabled") @property def pt_converter(self): """Lazy initialization of PT converter""" if self._pt_converter is None and self.enable_pt_conversion: from .pt_converter import PTConverter self._pt_converter = PTConverter(gpu_id=self.gpu_id) logger.info("PT converter initialized") return self._pt_converter @staticmethod def compute_file_hash(file_path: str) -> str: """ Compute SHA256 hash of a file for deduplication. Args: file_path: Path to the file Returns: Hexadecimal hash string """ sha256_hash = hashlib.sha256() with open(file_path, "rb") as f: # Read in chunks to handle large files efficiently for byte_block in iter(lambda: f.read(65536), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() def _load_engine(self, file_path: str) -> trt.ICudaEngine: """ Load TensorRT engine from file. Supports both raw TensorRT engines and Ultralytics .engine files (which have embedded JSON metadata at the beginning). 
        For Ultralytics engines:
        - Strips metadata and caches pure TensorRT engine in cache_dir
        - Saves metadata as separate JSON file
        - Reuses cached stripped engine on subsequent loads

        Args:
            file_path: Path to .trt or .engine file

        Returns:
            TensorRT engine
        """
        runtime = trt.Runtime(self.trt_logger)

        # Compute hash of original file for cache lookup
        file_hash = self.compute_file_hash(file_path)
        cache_engine_path = self.cache_dir / f"{file_hash}.trt"
        cache_metadata_path = self.cache_dir / f"{file_hash}_metadata.json"

        # Check if stripped engine already cached
        if cache_engine_path.exists():
            logger.info(f"Loading cached stripped engine from {cache_engine_path}")
            with open(cache_engine_path, 'rb') as f:
                engine_data = f.read()
        else:
            # Read and process original file
            with open(file_path, 'rb') as f:
                # Try to read Ultralytics metadata header (first 4 bytes = metadata length)
                try:
                    meta_len_bytes = f.read(4)
                    if len(meta_len_bytes) == 4:
                        meta_len = int.from_bytes(meta_len_bytes, byteorder="little")

                        # Sanity check: metadata length should be reasonable (< 100KB)
                        if 0 < meta_len < 100000:
                            try:
                                metadata_bytes = f.read(meta_len)
                                metadata = json.loads(metadata_bytes.decode("utf-8"))

                                # This is an Ultralytics engine - read remaining pure TRT data
                                engine_data = f.read()

                                # Save stripped engine to cache
                                logger.info("Detected Ultralytics engine format")
                                logger.info(f"Ultralytics metadata: {metadata}")
                                logger.info(f"Caching stripped engine to {cache_engine_path}")

                                with open(cache_engine_path, 'wb') as cache_f:
                                    cache_f.write(engine_data)

                                # Save metadata separately
                                with open(cache_metadata_path, 'w') as meta_f:
                                    json.dump(metadata, meta_f, indent=2)
                            except (UnicodeDecodeError, json.JSONDecodeError):
                                # Not Ultralytics format, rewind and read entire file
                                f.seek(0)
                                engine_data = f.read()
                        else:
                            # Invalid metadata length, rewind and read entire file
                            f.seek(0)
                            engine_data = f.read()
                    else:
                        # File too small, just use what we read
                        engine_data = meta_len_bytes
                except Exception as e:
                    # Any error, rewind and read entire file
                    logger.warning(f"Error reading engine metadata: {e}, treating as raw TRT engine")
                    f.seek(0)
                    engine_data = f.read()

            # Cache the engine data (even if it was already raw TRT)
            if not cache_engine_path.exists():
                with open(cache_engine_path, 'wb') as cache_f:
                    cache_f.write(engine_data)

        engine = runtime.deserialize_cuda_engine(engine_data)
        if engine is None:
            raise RuntimeError(f"Failed to load TensorRT engine from {file_path}")

        return engine

    def _extract_metadata(self, engine: trt.ICudaEngine, file_path: str,
                          file_hash: str) -> ModelMetadata:
        """
        Extract metadata from TensorRT engine.
        Args:
            engine: TensorRT engine
            file_path: Path to model file
            file_hash: SHA256 hash of model file

        Returns:
            ModelMetadata object
        """
        input_shapes = {}
        output_shapes = {}
        input_names = []
        output_names = []
        input_dtypes = {}
        output_dtypes = {}

        # TensorRT dtype to PyTorch dtype mapping
        trt_to_torch_dtype = {
            trt.DataType.FLOAT: torch.float32,
            trt.DataType.HALF: torch.float16,
            trt.DataType.INT8: torch.int8,
            trt.DataType.INT32: torch.int32,
            trt.DataType.BOOL: torch.bool,
        }

        # Iterate through all tensors (inputs and outputs) - TensorRT 10.x API
        for i in range(engine.num_io_tensors):
            name = engine.get_tensor_name(i)
            shape = tuple(engine.get_tensor_shape(name))
            dtype = trt_to_torch_dtype.get(engine.get_tensor_dtype(name), torch.float32)
            mode = engine.get_tensor_mode(name)

            if mode == trt.TensorIOMode.INPUT:
                input_names.append(name)
                input_shapes[name] = shape
                input_dtypes[name] = dtype
            else:
                output_names.append(name)
                output_shapes[name] = shape
                output_dtypes[name] = dtype

        return ModelMetadata(
            file_path=file_path,
            file_hash=file_hash,
            input_shapes=input_shapes,
            output_shapes=output_shapes,
            input_names=input_names,
            output_names=output_names,
            input_dtypes=input_dtypes,
            output_dtypes=output_dtypes
        )

    def load_model(self, model_id: str, file_path: str,
                   num_contexts: Optional[int] = None,
                   force_reload: bool = False,
                   pt_input_shapes: Optional[Dict[str, Tuple]] = None,
                   pt_precision: Optional[torch.dtype] = None,
                   **pt_conversion_kwargs) -> ModelMetadata:
        """
        Load a TensorRT model with the given ID.

        Supports both .trt and .pt files. PT files are automatically converted to TensorRT.

        Deduplication: If a model with the same file hash is already loaded, the model_id
        is simply mapped to the existing SharedEngine (no new engine or contexts created).

        Args:
            model_id: User-defined identifier for this model (e.g., "camera_1")
            file_path: Path to TensorRT engine file (.trt, .engine) or PyTorch file (.pt, .pth)
            num_contexts: Number of execution contexts in pool (None = use default)
            force_reload: If True, reload even if model_id exists
            pt_input_shapes: Required for .pt files - dict of input shapes (e.g., {"x": (1, 3, 224, 224)})
            pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
            **pt_conversion_kwargs: Additional arguments for torch_tensorrt.compile()

        Returns:
            ModelMetadata for the loaded model

        Raises:
            FileNotFoundError: If model file doesn't exist
            RuntimeError: If engine loading fails
            ValueError: If model_id already exists and force_reload is False,
                or PT conversion requires input_shapes
        """
        file_path = str(Path(file_path).resolve())

        if not Path(file_path).exists():
            raise FileNotFoundError(f"Model file not found: {file_path}")

        # Check if file is PyTorch model
        file_ext = Path(file_path).suffix.lower()
        if file_ext in ['.pt', '.pth']:
            if not self.enable_pt_conversion:
                raise ValueError(
                    f"PT file provided but PT conversion is disabled. "
                    f"Enable with enable_pt_conversion=True or provide a .trt file."
                )

            logger.info(f"Detected PyTorch model file: {file_path}")
            logger.info("Converting to TensorRT...")

            # Convert PT to TRT
            trt_hash, trt_path = self.pt_converter.convert(
                file_path,
                input_shapes=pt_input_shapes,
                precision=pt_precision,
                **pt_conversion_kwargs
            )

            # Update file_path to use converted TRT file
            file_path = trt_path
            logger.info(f"Will load converted TensorRT model from: {file_path}")

        if num_contexts is None:
            num_contexts = self.default_num_contexts

        with self._repo_lock:
            # Check if model_id already exists
            if model_id in self._model_to_hash and not force_reload:
                raise ValueError(
                    f"Model ID '{model_id}' already exists. "
                    f"Use force_reload=True to reload or choose a different ID."
                )

            # Unload existing model if force_reload
            if model_id in self._model_to_hash and force_reload:
                self.unload_model(model_id)

            # Compute file hash for deduplication
            print(f"Computing hash for {file_path}...")
            file_hash = self.compute_file_hash(file_path)
            print(f"File hash: {file_hash[:16]}...")

            # Check if this file is already loaded (deduplication)
            if file_hash in self._shared_engines:
                shared_engine = self._shared_engines[file_hash]
                print("Engine already loaded (hash match), reusing engine and context pool...")
                print(f"  Existing model_ids using this engine: {shared_engine.model_ids}")
            else:
                # Load new engine
                print(f"Loading TensorRT engine from {file_path}...")
                engine = self._load_engine(file_path)

                # Extract metadata
                metadata = self._extract_metadata(engine, file_path, file_hash)

                # Create shared engine with context pool
                shared_engine = SharedEngine(
                    engine=engine,
                    file_hash=file_hash,
                    file_path=file_path,
                    num_contexts=num_contexts,
                    device=self.device,
                    metadata=metadata
                )
                self._shared_engines[file_hash] = shared_engine

            # Add this model_id to the shared engine
            shared_engine.add_model_id(model_id)

            # Map model_id to file_hash
            self._model_to_hash[model_id] = file_hash

            print(f"Model '{model_id}' loaded successfully")
            print(f"  Inputs: {shared_engine.metadata.input_names}")
            for name in shared_engine.metadata.input_names:
                print(f"    {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})")
            print(f"  Outputs: {shared_engine.metadata.output_names}")
            for name in shared_engine.metadata.output_names:
                print(f"    {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})")
            print(f"  Context pool size: {num_contexts}")
            print(f"  Model IDs sharing this engine: {shared_engine.get_reference_count()}")
            print(f"  Unique engines in VRAM: {len(self._shared_engines)}")

            return shared_engine.metadata

    def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
              synchronize: bool = True,
              timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
        """
        Run GPU-to-GPU inference with the specified model using context pooling.

        All inputs must be CUDA tensors and outputs will be CUDA tensors (stays in VRAM).
        Thread-safe: Borrows an execution context from the pool with mutex locking.

        Args:
            model_id: Model identifier
            inputs: Dictionary mapping input names to CUDA tensors
            synchronize: If True, wait for inference to complete. If False, async execution.
            timeout: Max time to wait for available context (seconds)

        Returns:
            Dictionary mapping output names to CUDA tensors (in VRAM)

        Raises:
            KeyError: If model_id not found
            ValueError: If inputs don't match expected shapes or are not on GPU
            RuntimeError: If no context available within timeout
        """
        # Get shared engine
        if model_id not in self._model_to_hash:
            raise KeyError(
                f"Model '{model_id}' not found. "
                f"Available: {list(self._model_to_hash.keys())}"
            )
        file_hash = self._model_to_hash[model_id]
        shared_engine = self._shared_engines[file_hash]
        metadata = shared_engine.metadata

        # Validate inputs
        for name in metadata.input_names:
            if name not in inputs:
                raise ValueError(f"Missing required input: {name}")

            tensor = inputs[name]
            if not tensor.is_cuda:
                raise ValueError(f"Input '{name}' must be a CUDA tensor (on GPU)")

            # Check device
            if tensor.device != self.device:
                print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
                inputs[name] = tensor.to(self.device)

        # Acquire context from pool (mutex-based)
        exec_ctx = shared_engine.acquire_context(timeout=timeout)
        if exec_ctx is None:
            raise RuntimeError(
                f"No execution context available for model '{model_id}' within {timeout}s. "
                f"All {shared_engine.num_contexts} contexts are busy."
            )

        try:
            # Prepare output tensors
            outputs = {}

            # Set input tensors - TensorRT 10.x API
            for name in metadata.input_names:
                input_tensor = inputs[name].contiguous()
                exec_ctx.context.set_tensor_address(name, input_tensor.data_ptr())

            # Allocate and set output tensors
            for name in metadata.output_names:
                output_shape = metadata.output_shapes[name]
                output_dtype = metadata.output_dtypes[name]
                output_tensor = torch.empty(
                    output_shape,
                    dtype=output_dtype,
                    device=self.device
                )
                # NOTE: Don't track these tensors - they're returned to caller and consumed
                # by postprocessing, then automatically freed by PyTorch's garbage collector.
                # Tracking them would show false "leaks" since we can't track when the caller
                # finishes using them and PyTorch deallocates them.
                outputs[name] = output_tensor
                exec_ctx.context.set_tensor_address(name, output_tensor.data_ptr())

            # Execute inference on context's stream - TensorRT 10.x API
            with torch.cuda.stream(exec_ctx.stream):
                success = exec_ctx.context.execute_async_v3(
                    stream_handle=exec_ctx.stream.cuda_stream
                )
                if not success:
                    raise RuntimeError(f"Inference failed for model '{model_id}'")

            # Synchronize if requested
            if synchronize:
                exec_ctx.stream.synchronize()

            return outputs

        finally:
            # Always release context back to pool
            shared_engine.release_context(exec_ctx)

    def infer_batch(self, model_id: str,
                    batch_inputs: List[Dict[str, torch.Tensor]],
                    synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
        """
        Run inference on multiple inputs.

        Contexts are borrowed/returned for each input, enabling parallel processing.

        Args:
            model_id: Model identifier
            batch_inputs: List of input dictionaries
            synchronize: If True, wait for all inferences to complete

        Returns:
            List of output dictionaries
        """
        results = []
        for inputs in batch_inputs:
            outputs = self.infer(model_id, inputs, synchronize=synchronize)
            results.append(outputs)
        return results

    def unload_model(self, model_id: str):
        """
        Unload a model from the repository.

        Removes the model_id reference from the shared engine. If this was the last
        reference, the engine and all its contexts will be fully unloaded from VRAM.
        Args:
            model_id: Model identifier to unload
        """
        with self._repo_lock:
            if model_id not in self._model_to_hash:
                print(f"Warning: Model '{model_id}' not found")
                return

            file_hash = self._model_to_hash[model_id]

            # Remove model_id from shared engine
            if file_hash in self._shared_engines:
                shared_engine = self._shared_engines[file_hash]
                remaining_refs = shared_engine.remove_model_id(model_id)

                # If no more references, cleanup engine and contexts
                if remaining_refs == 0:
                    shared_engine.cleanup()
                    del self._shared_engines[file_hash]
                    print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
                else:
                    print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")

            # Remove from model_id mapping
            del self._model_to_hash[model_id]

    def get_metadata(self, model_id: str) -> Optional[ModelMetadata]:
        """
        Get metadata for a loaded model.

        Args:
            model_id: Model identifier

        Returns:
            ModelMetadata or None if not found
        """
        if model_id not in self._model_to_hash:
            return None
        file_hash = self._model_to_hash[model_id]
        if file_hash not in self._shared_engines:
            return None
        return self._shared_engines[file_hash].metadata

    def list_models(self) -> Dict[str, ModelMetadata]:
        """
        List all loaded models.

        Returns:
            Dictionary mapping model_id to ModelMetadata
        """
        with self._repo_lock:
            result = {}
            for model_id, file_hash in self._model_to_hash.items():
                if file_hash in self._shared_engines:
                    result[model_id] = self._shared_engines[file_hash].metadata
            return result

    def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get detailed information about a loaded model.

        Args:
            model_id: Model identifier

        Returns:
            Dictionary with model information or None if not found
        """
        if model_id not in self._model_to_hash:
            return None

        file_hash = self._model_to_hash[model_id]
        if file_hash not in self._shared_engines:
            return None

        shared_engine = self._shared_engines[file_hash]
        metadata = shared_engine.metadata

        return {
            'model_id': model_id,
            'file_path': metadata.file_path,
            'file_hash': metadata.file_hash[:16] + '...',
            'engine_references': shared_engine.get_reference_count(),
            'context_pool_size': shared_engine.num_contexts,
            'shared_with_model_ids': list(shared_engine.model_ids),
            'inputs': {
                name: {
                    'shape': metadata.input_shapes[name],
                    'dtype': str(metadata.input_dtypes[name])
                }
                for name in metadata.input_names
            },
            'outputs': {
                name: {
                    'shape': metadata.output_shapes[name],
                    'dtype': str(metadata.output_dtypes[name])
                }
                for name in metadata.output_names
            }
        }

    def get_stats(self) -> Dict[str, Any]:
        """
        Get repository statistics.

        Returns:
            Dictionary with stats about loaded models and memory usage
        """
        with self._repo_lock:
            total_contexts = sum(
                engine.num_contexts for engine in self._shared_engines.values()
            )

            return {
                'total_model_ids': len(self._model_to_hash),
                'unique_engines': len(self._shared_engines),
                'total_contexts': total_contexts,
                'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
                'gpu_id': self.gpu_id,
                'models': list(self._model_to_hash.keys())
            }

    def __repr__(self):
        with self._repo_lock:
            return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
                    f"model_ids={len(self._model_to_hash)}, "
                    f"unique_engines={len(self._shared_engines)})")

    def __del__(self):
        """Cleanup all models on deletion"""
        with self._repo_lock:
            model_ids = list(self._model_to_hash.keys())
            for model_id in model_ids:
                self.unload_model(model_id)
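

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the repository API).
# Assumptions: a serialized TensorRT engine exists at "yolov8.trt", its input
# shapes are static (no dynamic dimensions), and GPU 0 is available. The model
# IDs "camera_0" .. "camera_3" are hypothetical names for this example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=4)

    # Load the same engine file under several model IDs; deduplication keeps a
    # single engine (plus one context pool) in VRAM.
    for i in range(4):
        repo.load_model(f"camera_{i}", "yolov8.trt")

    meta = repo.get_metadata("camera_0")
    if meta is not None:
        # Build a dummy CUDA input for the engine's first input tensor.
        # torch.zeros is used so the sketch also works for integer dtypes.
        name = meta.input_names[0]
        dummy = torch.zeros(
            meta.input_shapes[name],
            dtype=meta.input_dtypes[name],
            device=repo.device,
        )
        outputs = repo.infer("camera_0", {name: dummy})
        for out_name, out_tensor in outputs.items():
            print(out_name, tuple(out_tensor.shape), out_tensor.dtype)

    print(repo.get_stats())

    # Unload every model ID; the engine is freed once the last reference goes.
    for i in range(4):
        repo.unload_model(f"camera_{i}")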