fix: gpu memory leaks

This commit is contained in:
Siwat Sirichai 2025-11-10 22:10:46 +07:00
parent 3a47920186
commit 593611cdb7
13 changed files with 420 additions and 166 deletions

@@ -1,5 +1,6 @@
import threading
import hashlib
import json
from typing import Optional, Dict, Any, List, Tuple
from pathlib import Path
from queue import Queue
@@ -161,7 +162,7 @@ class TensorRTModelRepository:
        # Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
    """

    def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True):
    def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True, cache_dir: str = ".trt_cache"):
        """
        Initialize the model repository.
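A rough sketch of the dedup the docstring above describes: many model IDs can share one engine file, and each unique file gets a single engine plus a small pool of execution contexts. The structures and the register helper below are illustrative assumptions, not the repository's actual internals.

from queue import Queue

model_to_hash: dict[str, str] = {}    # model_id -> file_hash
context_pools: dict[str, Queue] = {}  # file_hash -> pool of execution contexts

def register(model_id: str, file_hash: str, num_contexts: int = 4) -> None:
    model_to_hash[model_id] = file_hash
    if file_hash not in context_pools:     # engine file seen for the first time
        pool = Queue()
        for i in range(num_contexts):      # placeholder stand-ins for TRT contexts
            pool.put(f"context-{i}")
        context_pools[file_hash] = pool

# 100 model IDs backed by the same engine file -> 1 pool with 4 contexts
for cam in range(100):
    register(f"camera_{cam}", file_hash="abc123")
print(len(context_pools), context_pools["abc123"].qsize())  # prints: 1 4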
@@ -169,11 +170,14 @@ class TensorRTModelRepository:
            gpu_id: GPU device ID to use
            default_num_contexts: Default number of execution contexts per unique engine
            enable_pt_conversion: Enable automatic PyTorch to TensorRT conversion
            cache_dir: Directory for caching stripped TensorRT engines and metadata
        """
        self.gpu_id = gpu_id
        self.device = torch.device(f'cuda:{gpu_id}')
        self.default_num_contexts = default_num_contexts
        self.enable_pt_conversion = enable_pt_conversion
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Model ID to engine mapping: model_id -> file_hash
        self._model_to_hash: Dict[str, str] = {}
@@ -192,6 +196,7 @@ class TensorRTModelRepository:
        print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
        print(f"Default context pool size: {default_num_contexts} contexts per unique model")
        print(f"Cache directory: {self.cache_dir}")

        if enable_pt_conversion:
            print(f"PyTorch to TensorRT conversion: enabled")
@@ -226,6 +231,14 @@
        """
        Load TensorRT engine from file.

        Supports both raw TensorRT engines and Ultralytics .engine files
        (which have embedded JSON metadata at the beginning).

        For Ultralytics engines:
        - Strips metadata and caches pure TensorRT engine in cache_dir
        - Saves metadata as separate JSON file
        - Reuses cached stripped engine on subsequent loads

        Args:
            file_path: Path to .trt or .engine file
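The Ultralytics .engine layout handled by the loader below is a 4-byte little-endian metadata length, a JSON metadata blob, then the raw TensorRT engine. A minimal standalone reader, assuming a well-formed file, could look like this:

import json

def split_ultralytics_engine(path: str):
    # Split an Ultralytics .engine file into (metadata dict, raw TRT engine bytes)
    with open(path, "rb") as f:
        meta_len = int.from_bytes(f.read(4), byteorder="little")
        metadata = json.loads(f.read(meta_len).decode("utf-8"))
        engine_bytes = f.read()  # everything after the header is the pure engine
    return metadata, engine_bytes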
@@ -234,8 +247,68 @@
        """
        runtime = trt.Runtime(self.trt_logger)

        with open(file_path, 'rb') as f:
            engine_data = f.read()

        # Compute hash of original file for cache lookup
        file_hash = self.compute_file_hash(file_path)
        cache_engine_path = self.cache_dir / f"{file_hash}.trt"
        cache_metadata_path = self.cache_dir / f"{file_hash}_metadata.json"

        # Check if stripped engine already cached
        if cache_engine_path.exists():
            logger.info(f"Loading cached stripped engine from {cache_engine_path}")
            with open(cache_engine_path, 'rb') as f:
                engine_data = f.read()
        else:
            # Read and process original file
            with open(file_path, 'rb') as f:
                # Try to read Ultralytics metadata header (first 4 bytes = metadata length)
                try:
                    meta_len_bytes = f.read(4)
                    if len(meta_len_bytes) == 4:
                        meta_len = int.from_bytes(meta_len_bytes, byteorder="little")

                        # Sanity check: metadata length should be reasonable (< 100KB)
                        if 0 < meta_len < 100000:
                            try:
                                metadata_bytes = f.read(meta_len)
                                metadata = json.loads(metadata_bytes.decode("utf-8"))

                                # This is an Ultralytics engine - read remaining pure TRT data
                                engine_data = f.read()

                                # Save stripped engine to cache
                                logger.info(f"Detected Ultralytics engine format")
                                logger.info(f"Ultralytics metadata: {metadata}")
                                logger.info(f"Caching stripped engine to {cache_engine_path}")
                                with open(cache_engine_path, 'wb') as cache_f:
                                    cache_f.write(engine_data)

                                # Save metadata separately
                                with open(cache_metadata_path, 'w') as meta_f:
                                    json.dump(metadata, meta_f, indent=2)
                            except (UnicodeDecodeError, json.JSONDecodeError):
                                # Not Ultralytics format, rewind and read entire file
                                f.seek(0)
                                engine_data = f.read()
                        else:
                            # Invalid metadata length, rewind and read entire file
                            f.seek(0)
                            engine_data = f.read()
                    else:
                        # File too small, just use what we read
                        engine_data = meta_len_bytes
                except Exception as e:
                    # Any error, rewind and read entire file
                    logger.warning(f"Error reading engine metadata: {e}, treating as raw TRT engine")
                    f.seek(0)
                    engine_data = f.read()

            # Cache the engine data (even if it was already raw TRT)
            if not cache_engine_path.exists():
                with open(cache_engine_path, 'wb') as cache_f:
                    cache_f.write(engine_data)

        engine = runtime.deserialize_cuda_engine(engine_data)
        if engine is None:
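Sketch of how cache entries are keyed: compute_file_hash is not shown in this diff, so a SHA-256 over the original engine file is assumed here for illustration; the resulting hash names both cached files.

import hashlib
from pathlib import Path

def file_sha256(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            h.update(chunk)
    return h.hexdigest()

cache_dir = Path(".trt_cache")
h = "0" * 64                                      # placeholder hash value
stripped_engine = cache_dir / f"{h}.trt"          # cached pure TensorRT engine
metadata_json = cache_dir / f"{h}_metadata.json"  # cached Ultralytics metadata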
@@ -494,6 +567,11 @@
                device=self.device
            )

            # NOTE: Don't track these tensors - they're returned to caller and consumed
            # by postprocessing, then automatically freed by PyTorch's garbage collector.
            # Tracking them would show false "leaks" since we can't track when the caller
            # finishes using them and PyTorch deallocates them.
            outputs[name] = output_tensor
            exec_ctx.context.set_tensor_address(name, output_tensor.data_ptr())
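The NOTE above relies on PyTorch's normal tensor lifetime: once the caller drops its last reference to an output tensor, the allocation goes back to the caching allocator on its own, so tracking those tensors inside the repository would only report phantom leaks. A small illustration (device index and shape are arbitrary):

import torch

if torch.cuda.is_available():
    before = torch.cuda.memory_allocated(0)
    out = torch.empty(1, 84, 8400, device="cuda:0")  # stands in for an inference output
    del out                                          # caller is done with the tensor
    after = torch.cuda.memory_allocated(0)
    print(before, after)  # the allocation is returned, so both values match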