fix: gpu memory leaks

commit 593611cdb7 (parent 3a47920186)
13 changed files with 420 additions and 166 deletions

.gitignore (vendored, 1 change)
@@ -5,3 +5,4 @@ __pycache__/
.claude
/models/
/tracked_objects.json
.trt_cache

bangchak/models/.gitignore (vendored, new file, 2 changes)
@@ -0,0 +1,2 @@
*.onnx
*.engine

debug_trt_output.py (new file, 136 changes)
@@ -0,0 +1,136 @@
"""
Debug script to capture and compare raw PT vs TRT outputs on problematic frames.
"""

import torch
import time
from services import StreamDecoderFactory, YOLOv8Utils, TensorRTModelRepository
from ultralytics import YOLO
import os
from dotenv import load_dotenv

load_dotenv()

GPU_ID = 0
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
STREAM_URL = os.getenv('CAMERA_URL_1')

# Load models
print("Loading models...")
pt_model = YOLO(MODEL_PATH)
pt_model.to(f'cuda:{GPU_ID}')

repo = TensorRTModelRepository(gpu_id=GPU_ID)
trt_path = "./models/trtptcache/trt/cda5e520441e12fe09a97ac2609da29b4cbac969cc2029ef1735f65697579121.trt"
repo.load_model("detector", trt_path, num_contexts=1)

# Start decoder
print("Starting decoder...")
decoder_factory = StreamDecoderFactory(gpu_id=GPU_ID)
decoder = decoder_factory.create_decoder(STREAM_URL, buffer_size=30)
decoder.start()
time.sleep(2)

torch.cuda.set_device(GPU_ID)

print("\nWaiting for frames with TRT false positives...\n")

frame_count = 0
found_issue = False

while frame_count < 50 and not found_issue:
    frame = decoder.get_frame()
    if frame is None:
        time.sleep(0.01)
        continue

    frame_count += 1

    # Preprocess
    preprocessed = YOLOv8Utils.preprocess(frame, input_size=640)

    # Run TRT inference
    trt_outputs = repo.infer("detector", {"images": preprocessed}, synchronize=True)
    trt_raw = trt_outputs['output0'] # (1, 5, 8400)

    # Check for the issue - transpose and check channel 4
    trt_transposed = trt_raw.transpose(1, 2).squeeze(0) # (8400, 5)
    conf_channel = trt_transposed[:, 4] # (8400,)

    num_high_conf = (conf_channel > 0.25).sum().item()

    if num_high_conf > 100:
        found_issue = True
        print(f"🔴 FOUND PROBLEMATIC FRAME {frame_count}!")
        print(f" TRT detections > 0.25 threshold: {num_high_conf}")

        # Now run PT model on same frame
        with torch.no_grad():
            pt_raw = pt_model.model(preprocessed)[0] # (1, 5, 8400)

        print(f"\n=== RAW OUTPUT COMPARISON ===")
        print(f"PT output shape: {pt_raw.shape}")
        print(f"TRT output shape: {trt_raw.shape}")

        # Compare channel 4 (confidence)
        pt_conf = pt_raw.transpose(1, 2).squeeze(0)[:, 4]
        trt_conf = trt_transposed[:, 4]

        print(f"\n--- Confidence Channel (channel 4) ---")
        print(f"PT confidence stats:")
        print(f" Min: {pt_conf.min().item():.6e}")
        print(f" Max: {pt_conf.max().item():.6e}")
        print(f" Mean: {pt_conf.mean().item():.6e}")
        print(f" >0.25: {(pt_conf > 0.25).sum().item()}")
        print(f" >0.5: {(pt_conf > 0.5).sum().item()}")

        print(f"\nTRT confidence stats:")
        print(f" Min: {trt_conf.min().item():.6e}")
        print(f" Max: {trt_conf.max().item():.6e}")
        print(f" Mean: {trt_conf.mean().item():.6e}")
        print(f" >0.25: {(trt_conf > 0.25).sum().item()}")
        print(f" >0.5: {(trt_conf > 0.5).sum().item()}")

        # Check bbox coordinates too
        print(f"\n--- BBox Coordinates (channels 0-3) ---")
        pt_bbox = pt_raw.transpose(1, 2).squeeze(0)[:, :4]
        trt_bbox = trt_transposed[:, :4]

        print(f"PT bbox stats:")
        print(f" Min: {pt_bbox.min().item():.3f}")
        print(f" Max: {pt_bbox.max().item():.3f}")
        print(f" Mean: {pt_bbox.mean().item():.3f}")

        print(f"\nTRT bbox stats:")
        print(f" Min: {trt_bbox.min().item():.3f}")
        print(f" Max: {trt_bbox.max().item():.3f}")
        print(f" Mean: {trt_bbox.mean().item():.3f}")

        # Sample some values
        print(f"\n--- Sample Values (first 5 anchors) ---")
        for i in range(5):
            print(f"\nAnchor {i}:")
            print(f" PT [cx={pt_bbox[i,0]:.1f}, cy={pt_bbox[i,1]:.1f}, w={pt_bbox[i,2]:.1f}, h={pt_bbox[i,3]:.1f}, conf={pt_conf[i]:.6f}]")
            print(f" TRT [cx={trt_bbox[i,0]:.1f}, cy={trt_bbox[i,1]:.1f}, w={trt_bbox[i,2]:.1f}, h={trt_bbox[i,3]:.1f}, conf={trt_conf[i]:.6f}]")

        # Find indices with high confidence in TRT
        high_conf_idx = torch.where(trt_conf > 0.25)[0][:5]
        print(f"\n--- High Confidence Detections in TRT (first 5) ---")
        for idx in high_conf_idx:
            i = idx.item()
            print(f"\nAnchor {i}:")
            print(f" PT [cx={pt_bbox[i,0]:.1f}, cy={pt_bbox[i,1]:.1f}, w={pt_bbox[i,2]:.1f}, h={pt_bbox[i,3]:.1f}, conf={pt_conf[i]:.6f}]")
            print(f" TRT [cx={trt_bbox[i,0]:.1f}, cy={trt_bbox[i,1]:.1f}, w={trt_bbox[i,2]:.1f}, h={trt_bbox[i,3]:.1f}, conf={trt_conf[i]:.6f}]")

        break

    if frame_count % 10 == 0:
        print(f"Checked {frame_count} frames, no issues yet...")

if not found_issue:
    print(f"\n⚠️ No problematic frames found in {frame_count} frames")

# Cleanup
decoder.stop()
repo.unload_model("detector")
print("\n✓ Done")

new_buffer_design.txt (new file, 9 changes)
@@ -0,0 +1,9 @@
The Post-Decoded Buffer should just be the ping-pong ring buffer.
Let's get the relationships in order:
- the ping-pong ring is per model
- many cameras may use the same model
- this buffer is filled when we memcpy from the decode buffer
But I need some more ground rules:
- in the model buffer, only one frame per camera may be in the buffer; if an older frame from the same camera exists, evict it. This is a real-time system, so the buffer should be as fresh as possible.
- the goal of batching is not to pool up processing for the same camera, but to pool up multiple cameras.
- if all cameras in the pool have already posted their frame, flush the buffer too (see the sketch below).
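A minimal sketch of the per-model buffer described in these notes. This is illustrative only and not part of the commit; ModelFrameBuffer, post, and on_flush are hypothetical names chosen for the example.

import threading
from typing import Callable, Dict, Set

import torch


class ModelFrameBuffer:
    """One buffer per model; holds at most one (freshest) frame per camera."""

    def __init__(self, expected_cameras: Set[str], on_flush: Callable[[Dict[str, torch.Tensor]], None]):
        self.expected_cameras = expected_cameras  # cameras that share this model
        self.on_flush = on_flush                  # receives {camera_id: frame} when the batch is ready
        self.slots: Dict[str, torch.Tensor] = {}  # camera_id -> freshest frame
        self.lock = threading.Lock()

    def post(self, camera_id: str, frame: torch.Tensor):
        with self.lock:
            # One frame per camera: a newer frame from the same camera evicts the older one.
            self.slots[camera_id] = frame
            # Flush as soon as every camera in the pool has posted a frame.
            if self.expected_cameras.issubset(self.slots.keys()):
                batch, self.slots = self.slots, {}
                self.on_flush(batch)

A real implementation would also flush on a timeout (compare FORCE_TIMEOUT in the example scripts further down) so that one stalled camera cannot hold back the whole batch.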

@@ -117,9 +117,10 @@ class ModelController:
"""
try:
metadata = self.model_repository.get_metadata(self.model_id)
# Get first input tensor shape
first_input = list(metadata.inputs.values())[0]
batch_dim = first_input["shape"][0]
# Get first input tensor shape (ModelMetadata has input_shapes, not inputs)
first_input_name = metadata.input_names[0]
input_shape = metadata.input_shapes[first_input_name]
batch_dim = input_shape[0]

# batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
if batch_dim == -1:

@@ -1,5 +1,6 @@
import threading
import hashlib
import json
from typing import Optional, Dict, Any, List, Tuple
from pathlib import Path
from queue import Queue

@@ -161,7 +162,7 @@ class TensorRTModelRepository:
# Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
"""

def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True):
def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True, cache_dir: str = ".trt_cache"):
"""
Initialize the model repository.

@@ -169,11 +170,14 @@ class TensorRTModelRepository:
gpu_id: GPU device ID to use
default_num_contexts: Default number of execution contexts per unique engine
enable_pt_conversion: Enable automatic PyTorch to TensorRT conversion
cache_dir: Directory for caching stripped TensorRT engines and metadata
"""
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.default_num_contexts = default_num_contexts
self.enable_pt_conversion = enable_pt_conversion
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)

# Model ID to engine mapping: model_id -> file_hash
self._model_to_hash: Dict[str, str] = {}

@@ -192,6 +196,7 @@ class TensorRTModelRepository:

print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
print(f"Default context pool size: {default_num_contexts} contexts per unique model")
print(f"Cache directory: {self.cache_dir}")
if enable_pt_conversion:
print(f"PyTorch to TensorRT conversion: enabled")

@@ -226,6 +231,14 @@ class TensorRTModelRepository:
"""
Load TensorRT engine from file.

Supports both raw TensorRT engines and Ultralytics .engine files
(which have embedded JSON metadata at the beginning).

For Ultralytics engines:
- Strips metadata and caches pure TensorRT engine in cache_dir
- Saves metadata as separate JSON file
- Reuses cached stripped engine on subsequent loads

Args:
file_path: Path to .trt or .engine file

@@ -234,8 +247,68 @@ class TensorRTModelRepository:
"""
runtime = trt.Runtime(self.trt_logger)

with open(file_path, 'rb') as f:
engine_data = f.read()
# Compute hash of original file for cache lookup
file_hash = self.compute_file_hash(file_path)
cache_engine_path = self.cache_dir / f"{file_hash}.trt"
cache_metadata_path = self.cache_dir / f"{file_hash}_metadata.json"

# Check if stripped engine already cached
if cache_engine_path.exists():
logger.info(f"Loading cached stripped engine from {cache_engine_path}")
with open(cache_engine_path, 'rb') as f:
engine_data = f.read()
else:
# Read and process original file
with open(file_path, 'rb') as f:
# Try to read Ultralytics metadata header (first 4 bytes = metadata length)
try:
meta_len_bytes = f.read(4)
if len(meta_len_bytes) == 4:
meta_len = int.from_bytes(meta_len_bytes, byteorder="little")

# Sanity check: metadata length should be reasonable (< 100KB)
if 0 < meta_len < 100000:
try:
metadata_bytes = f.read(meta_len)
metadata = json.loads(metadata_bytes.decode("utf-8"))

# This is an Ultralytics engine - read remaining pure TRT data
engine_data = f.read()

# Save stripped engine to cache
logger.info(f"Detected Ultralytics engine format")
logger.info(f"Ultralytics metadata: {metadata}")
logger.info(f"Caching stripped engine to {cache_engine_path}")

with open(cache_engine_path, 'wb') as cache_f:
cache_f.write(engine_data)

# Save metadata separately
with open(cache_metadata_path, 'w') as meta_f:
json.dump(metadata, meta_f, indent=2)

except (UnicodeDecodeError, json.JSONDecodeError):
# Not Ultralytics format, rewind and read entire file
f.seek(0)
engine_data = f.read()
else:
# Invalid metadata length, rewind and read entire file
f.seek(0)
engine_data = f.read()
else:
# File too small, just use what we read
engine_data = meta_len_bytes

except Exception as e:
# Any error, rewind and read entire file
logger.warning(f"Error reading engine metadata: {e}, treating as raw TRT engine")
f.seek(0)
engine_data = f.read()

# Cache the engine data (even if it was already raw TRT)
if not cache_engine_path.exists():
with open(cache_engine_path, 'wb') as cache_f:
cache_f.write(engine_data)

engine = runtime.deserialize_cuda_engine(engine_data)
if engine is None:

@@ -494,6 +567,11 @@ class TensorRTModelRepository:
device=self.device
)

# NOTE: Don't track these tensors - they're returned to caller and consumed
# by postprocessing, then automatically freed by PyTorch's garbage collector.
# Tracking them would show false "leaks" since we can't track when the caller
# finishes using them and PyTorch deallocates them.

outputs[name] = output_tensor
exec_ctx.context.set_tensor_address(name, output_tensor.data_ptr())

@@ -125,12 +125,19 @@ class PTConverter:

mapping = self.mapping_db[pt_hash]
trt_hash = mapping["trt_hash"]

# Check both .engine and .trt extensions (Ultralytics uses .engine, generic uses .trt)
engine_key = f"trt/{trt_hash}.engine"
trt_key = f"trt/{trt_hash}.trt"

# Verify TRT file still exists in storage
if not self.storage.exists(trt_key):
# Try .engine first (Ultralytics native format)
if self.storage.exists(engine_key):
cached_key = engine_key
elif self.storage.exists(trt_key):
cached_key = trt_key
else:
logger.warning(
f"Mapping exists for PT hash {pt_hash[:16]}... but TRT file missing. "
f"Mapping exists for PT hash {pt_hash[:16]}... but engine file missing. "
f"Will reconvert."
)
# Remove stale mapping

@@ -139,16 +146,16 @@ class PTConverter:
return None

# Get local path
trt_path = self.storage.get_local_path(trt_key)
if trt_path is None:
logger.error(f"Could not get local path for TRT file {trt_key}")
cached_path = self.storage.get_local_path(cached_key)
if cached_path is None:
logger.error(f"Could not get local path for engine file {cached_key}")
return None

logger.info(
f"Found cached conversion for PT hash {pt_hash[:16]}... -> "
f"TRT hash {trt_hash[:16]}..."
f"Engine hash {trt_hash[:16]}... ({cached_key})"
)
return (trt_hash, trt_path)
return (trt_hash, cached_path)

def convert(
self,

@@ -241,24 +248,21 @@ class PTConverter:
precision: torch.dtype,
) -> Tuple[str, str]:
"""
Convert ultralytics YOLO model using ONNX → TensorRT pipeline.
Uses the same approach as scripts/convert_pt_to_tensorrt.py
Convert ultralytics YOLO model using native .engine export.
This produces .engine files with embedded metadata (no manual input_shapes needed).

Args:
pt_path: Path to PT file
pt_hash: PT file hash
input_shapes: Input tensor shapes
input_shapes: Input tensor shapes (IGNORED for Ultralytics - auto-detected)
precision: Target precision

Returns:
Tuple of (trt_hash, trt_file_path)
Tuple of (engine_hash, engine_file_path)
"""
import tensorrt as trt
import tempfile
import os
import shutil

logger.info("Detected ultralytics YOLO model, using ONNX → TensorRT pipeline...")
logger.info("Detected ultralytics YOLO model, using native .engine export...")

# Load ultralytics model
try:

@@ -267,83 +271,48 @@ class PTConverter:
except ImportError:
raise ImportError("ultralytics package not found. Install with: pip install ultralytics")

# Determine input shape
if not input_shapes:
raise ValueError("input_shapes required for ultralytics conversion")
# Export to native .engine format with embedded metadata
logger.info(f"Exporting to native TensorRT .engine (precision: {'FP16' if precision == torch.float16 else 'FP32'})...")

input_key = 'images' if 'images' in input_shapes else list(input_shapes.keys())[0]
input_shape = input_shapes[input_key]
# Ultralytics export creates .engine file in same directory as .pt
engine_path = model.export(
format='engine',
half=(precision == torch.float16),
device=self.gpu_id,
batch=1,
simplify=True
)

# Export to ONNX first
logger.info(f"Exporting to ONNX (input shape: {input_shape})...")
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp_onnx:
onnx_path = tmp_onnx.name
# Convert to string (Ultralytics returns Path object)
engine_path = str(engine_path)
logger.info(f"Native .engine export complete: {engine_path}")
logger.info("Metadata embedded in .engine file (stride, imgsz, names, etc.)")

try:
# Use ultralytics export to ONNX
model.export(format='onnx', imgsz=input_shape[2], batch=input_shape[0])
# Ultralytics saves as model_name.onnx in same directory
pt_dir = os.path.dirname(pt_path)
pt_name = os.path.splitext(os.path.basename(pt_path))[0]
onnx_export_path = os.path.join(pt_dir, f"{pt_name}.onnx")
# Read the exported .engine file
with open(engine_path, 'rb') as f:
engine_data = f.read()

# Move to our temp location (use shutil.move for cross-device support)
if os.path.exists(onnx_export_path):
shutil.move(onnx_export_path, onnx_path)
else:
raise RuntimeError(f"ONNX export failed, file not found: {onnx_export_path}")
# Compute hash of the .engine file
engine_hash = hashlib.sha256(engine_data).hexdigest()

logger.info(f"ONNX export complete: {onnx_path}")
# Store in our cache (as .engine to preserve metadata)
engine_key = f"trt/{engine_hash}.engine"
self.storage.write(engine_key, engine_data)

# Build TensorRT engine from ONNX
logger.info("Building TensorRT engine from ONNX...")
trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, trt_logger)
cached_path = self.storage.get_local_path(engine_key)
if cached_path is None:
raise RuntimeError("Failed to get local path for .engine file")

# Parse ONNX
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
errors = [parser.get_error(i) for i in range(parser.num_errors)]
raise RuntimeError(f"Failed to parse ONNX: {errors}")
# Clean up the original export (we've cached it)
# Only delete if it's different from cached path
if os.path.exists(engine_path) and os.path.abspath(engine_path) != os.path.abspath(cached_path):
logger.info(f"Removing original export (cached): {engine_path}")
os.unlink(engine_path)
else:
logger.info(f"Keeping original export at: {engine_path}")

# Configure builder
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30) # 4GB

# Set precision
if precision == torch.float16:
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
logger.info("FP16 mode enabled")

# Build engine
logger.info("Building TensorRT engine (this may take a few minutes)...")
serialized_engine = builder.build_serialized_network(network, config)

if serialized_engine is None:
raise RuntimeError("Failed to build TensorRT engine")

# Convert IHostMemory to bytes
engine_bytes = bytes(serialized_engine)

# Save to storage
trt_hash = hashlib.sha256(engine_bytes).hexdigest()
trt_key = f"trt/{trt_hash}.trt"
self.storage.write(trt_key, engine_bytes)

trt_path = self.storage.get_local_path(trt_key)
if trt_path is None:
raise RuntimeError("Failed to get local path for TRT file")

logger.info(f"TensorRT engine built successfully: {trt_path}")
return (trt_hash, trt_path)

finally:
# Cleanup ONNX file
if os.path.exists(onnx_path):
os.unlink(onnx_path)
logger.info(f"Cached .engine file: {cached_path}")
return (engine_hash, cached_path)

def _perform_conversion(
self,

@@ -387,11 +356,21 @@ class PTConverter:

# Check if this is an ultralytics model
if self._is_ultralytics_model(model):
logger.info("Detected ultralytics model, using ultralytics export API")
logger.info("Detected Ultralytics YOLO model, using native .engine export")
logger.info("Note: input_shapes parameter is ignored for Ultralytics models (auto-detected)")
return self._convert_ultralytics_model(pt_path, pt_hash, input_shapes, precision)

# For non-ultralytics models, use torch_tensorrt
logger.info("Using torch_tensorrt for conversion")
logger.info("Using torch_tensorrt for conversion (non-Ultralytics model)")

# Non-Ultralytics models REQUIRE input_shapes
if input_shapes is None:
raise ValueError(
"input_shapes required for non-Ultralytics PyTorch models. "
"For Ultralytics YOLO models, input_shapes is auto-detected. "
"Example: input_shapes={'images': (1, 3, 640, 640)}"
)

model.eval()

# Convert model to target precision to avoid mixed precision issues

@@ -127,15 +127,17 @@ class StreamConnection:
self.status = ConnectionStatus.DISCONNECTED
logger.info(f"Stream {self.stream_id} stopped")

def _on_frame_decoded(self, frame: torch.Tensor):
def _on_frame_decoded(self, frame_ref):
"""
Event handler called by decoder when a new frame is decoded.
This is the event-driven replacement for polling.

Args:
frame: RGB frame tensor on GPU (3, H, W)
frame_ref: FrameReference object containing the RGB frame tensor
"""
if not self.running:
# If not running, free the frame immediately
frame_ref.free()
return

try:

@@ -143,12 +145,14 @@ class StreamConnection:
self.frame_count += 1

# Submit to model controller for batched inference
# Pass the FrameReference in metadata so we can free it later
self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=frame,
frame=frame_ref.rgb_tensor,
metadata={
"frame_number": self.frame_count,
"shape": tuple(frame.shape),
"shape": tuple(frame_ref.rgb_tensor.shape),
"frame_ref": frame_ref, # Store reference for later cleanup
}
)


@@ -164,6 +168,8 @@ class StreamConnection:
logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
self.error_queue.put(e)
self.status = ConnectionStatus.ERROR
# Free the frame on error
frame_ref.free()

def _handle_inference_result(self, result: Dict[str, Any]):
"""

@@ -173,12 +179,17 @@ class StreamConnection:
Args:
result: Inference result dictionary
"""
frame_ref = None
try:
# Extract detections
detections = result["detections"]

# Run tracking (synchronous)
tracked_objects = self._run_tracking_sync(detections)
# Get FrameReference from metadata (if present)
frame_ref = result["metadata"].get("frame_ref")

# Run tracking (synchronous) with frame shape for bbox scaling
frame_shape = result["metadata"].get("shape")
tracked_objects = self._run_tracking_sync(detections, frame_shape)

# Create tracking result
tracking_result = TrackingResult(

@@ -196,13 +207,18 @@ class StreamConnection:
except Exception as e:
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
self.error_queue.put(e)
finally:
# Free the frame reference - this is the last point in the pipeline
if frame_ref is not None:
frame_ref.free()

def _run_tracking_sync(self, detections, min_confidence=0.7):
def _run_tracking_sync(self, detections, frame_shape=None, min_confidence=0.7):
"""
Run tracking synchronously (called from executor).

Args:
detections: Detection tensor (N, 6) [x1, y1, x2, y2, conf, class_id]
frame_shape: Original frame shape (C, H, W) for scaling bboxes
min_confidence: Minimum confidence threshold for detections

Returns:

@@ -226,8 +242,8 @@ class StreamConnection:
class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
))

# Update tracker with detections (lightweight, no model dependency!)
return self.tracking_controller.update(detection_list)
# Update tracker with detections (will scale bboxes to frame space)
return self.tracking_controller.update(detection_list, frame_shape=frame_shape)

def tracking_results(self):
"""

@@ -339,15 +355,31 @@ class StreamConnectionManager:
"""
Initialize the manager with a model.

Supports transparent loading of .pt (YOLO), .engine, and .trt files.
For Ultralytics YOLO models (.pt), metadata is auto-detected - no manual
input_shapes or precision needed! Non-YOLO models still require input_shapes.

Args:
model_path: Path to TensorRT or PyTorch model file (.trt, .pt, .pth)
model_path: Path to model file (.trt, .engine, .pt, .pth)
- .engine: Ultralytics native format (recommended)
- .pt: Auto-converts to .engine (YOLO models only)
- .trt: Raw TensorRT engine
model_id: Model identifier (default: "detector")
preprocess_fn: Preprocessing function (e.g., YOLOv8Utils.preprocess)
postprocess_fn: Postprocessing function (e.g., YOLOv8Utils.postprocess)
num_contexts: Number of TensorRT execution contexts (default: 4)
pt_input_shapes: Required for PT files - dict of input shapes
pt_precision: Precision for PT conversion (torch.float16 or torch.float32)
pt_input_shapes: [Optional] Only required for non-YOLO PyTorch models
YOLO models auto-detect from embedded metadata
pt_precision: [Optional] Precision for PT conversion (auto-detected for YOLO)
**pt_conversion_kwargs: Additional PT conversion arguments

Example:
# YOLO model - no manual parameters needed:
manager.initialize(
model_path="model.pt", # or .engine
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess
)
"""
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")


@@ -12,26 +12,31 @@ from .jpeg_encoder import encode_frame_to_jpeg

class FrameReference:
"""
CPU-side reference object for a GPU frame.
Reference-counted frame wrapper for zero-copy memory management.

This object holds a cloned RGB tensor that is independent of PyNvVideoCodec's
DecodedFrame lifecycle. We don't keep the DecodedFrame to avoid conflicts
with PyNvVideoCodec's internal frame pool management.
This allows multiple parts of the pipeline to hold references to the same
cloned frame, and tracks when all references are released so the decoder
knows when buffer slots can be reused.
"""
def __init__(self, rgb_tensor: torch.Tensor, buffer_index: int, decoder):
self.rgb_tensor = rgb_tensor # Cloned RGB tensor (independent copy)
self.rgb_tensor = rgb_tensor # Cloned RGB tensor (one clone per frame)
self.buffer_index = buffer_index
self.decoder = decoder # Reference to decoder for marking as free
self.decoder = decoder
self._freed = False

def free(self):
"""Mark this frame as no longer in use"""
"""Mark this reference as freed - called by the last user of the frame"""
if not self._freed:
self._freed = True

# Release GPU memory immediately
if self.rgb_tensor is not None:
del self.rgb_tensor
self.rgb_tensor = None
self.decoder._mark_frame_free(self.buffer_index)

def is_freed(self) -> bool:
"""Check if this frame has been freed"""
"""Check if this reference has been freed"""
return self._freed

def __del__(self):

@@ -212,13 +217,10 @@ class StreamDecoder:
self.status = ConnectionStatus.DISCONNECTED
self._status_lock = threading.Lock()

# Frame buffer (ring buffer) - stores FrameReference objects
# Frame buffer (ring buffer) - stores cloned RGB tensors
self.frame_buffer = deque(maxlen=buffer_size)
self._buffer_lock = threading.RLock()

# Track which buffer slots are in use (list of FrameReference objects)
self._in_use_frames = [] # List of FrameReference objects currently held by callbacks

# Decoder and container instances
self.decoder = None
self.container = None

@@ -236,6 +238,10 @@ class StreamDecoder:
self._frame_callbacks = []
self._callback_lock = threading.Lock()

# Track frames currently in use (referenced by callbacks/pipeline)
self._in_use_frames = [] # List of FrameReference objects
self._frame_index_counter = 0 # Monotonically increasing frame index

def register_frame_callback(self, callback: Callable):
"""
Register a callback to be called when a new frame is decoded.

@@ -396,19 +402,7 @@ class StreamDecoder:
# Add frames to ring buffer and fire callbacks
with self._buffer_lock:
for frame in decoded_frames:
# Check for buffer overflow - discard oldest if needed
if len(self.frame_buffer) >= self.buffer_size:
# Check if oldest frame is still in use
if len(self._in_use_frames) > 0:
oldest_ref = self.frame_buffer[0] if len(self.frame_buffer) > 0 else None
if oldest_ref and not oldest_ref.is_freed():
# Force free the oldest frame to prevent overflow
print(f"[WARNING] Buffer overflow, force-freeing oldest frame (buffer_index={oldest_ref.buffer_index})")
oldest_ref.free()

# Deque will automatically remove oldest when at maxlen

# Convert to tensor
# Convert to tensor immediately after NVDEC
try:
# Convert DecodedFrame to PyTorch tensor using DLPack (zero-copy)
nv12_tensor = torch.from_dlpack(frame)

@@ -417,32 +411,32 @@ class StreamDecoder:
if self.frame_height is not None and self.frame_width is not None:
rgb_tensor = nv12_to_rgb_gpu(nv12_tensor, self.frame_height, self.frame_width)

# CRITICAL: Clone the RGB tensor to break CUDA memory dependency
# The nv12_to_rgb_gpu creates a new tensor, but it still references
# the same CUDA context/stream. We need an independent copy.
rgb_tensor_cloned = rgb_tensor.clone()
# CLONE ONCE into our post-decode buffer
# This breaks the dependency on PyNvVideoCodec's DecodedFrame
# After this, the tensor is fully ours and can be used throughout the pipeline
rgb_cloned = rgb_tensor.clone()

# Create FrameReference object for C++-style memory management
# We don't keep the DecodedFrame to avoid conflicts with PyNvVideoCodec's
# internal frame pool - the clone is fully independent
buffer_index = self.frame_count
# Create FrameReference for reference counting
frame_ref = FrameReference(
rgb_tensor=rgb_tensor_cloned, # Independent cloned tensor
buffer_index=buffer_index,
rgb_tensor=rgb_cloned,
buffer_index=self._frame_index_counter,
decoder=self
)
self._frame_index_counter += 1

# Add to buffer and in-use tracking
# Add FrameReference to ring buffer (deque automatically removes oldest when full)
self.frame_buffer.append(frame_ref)
self._in_use_frames.append(frame_ref)
self.frame_count += 1

# Fire callbacks with the cloned RGB tensor from FrameReference
# The tensor is now independent of the DecodedFrame lifecycle
# Track this frame as in-use
self._in_use_frames.append(frame_ref)

# Fire callbacks with the FrameReference
# The callback receivers should call .free() when done
with self._callback_lock:
for callback in self._frame_callbacks:
try:
callback(frame_ref.rgb_tensor)
callback(frame_ref)
except Exception as e:
print(f"Error in frame callback: {e}")
except Exception as e:
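A minimal consumer-side sketch of the FrameReference lifecycle introduced above (illustrative only; DummyDecoder and on_frame are hypothetical stand-ins, and FrameReference is the class from this commit):

import torch


class DummyDecoder:
    """Stand-in for StreamDecoder; only implements the slot-release hook."""

    def _mark_frame_free(self, buffer_index: int):
        print(f"buffer slot {buffer_index} can be reused")


def on_frame(frame_ref):
    try:
        # Use the cloned RGB tensor (preprocess, inference, ...).
        _ = frame_ref.rgb_tensor.mean()
    finally:
        # The last consumer in the pipeline frees the reference,
        # which drops the tensor and notifies the decoder.
        frame_ref.free()


ref = FrameReference(rgb_tensor=torch.zeros(3, 720, 1280), buffer_index=0, decoder=DummyDecoder())
on_frame(ref)
assert ref.is_freed()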

@@ -272,12 +272,14 @@ class ObjectTracker:
for tid in stale_track_ids:
del self._tracks[tid]

def update(self, detections: List[Detection]) -> List[TrackedObject]:
def update(self, detections: List[Detection], frame_shape: tuple = None, model_input_size: int = 640) -> List[TrackedObject]:
"""
Update tracker with new detections (decoupled from inference).

Args:
detections: List of Detection objects from model inference
frame_shape: Original frame shape (C, H, W) for scaling bboxes back from model space
model_input_size: Model input size (default: 640 for YOLOv8)

Returns:
List of currently tracked objects

@@ -291,6 +293,22 @@ class ObjectTracker:
self._cleanup_stale_tracks()
return list(self._tracks.values())

# Scale detections from model space (640x640) to frame space (H x W)
if frame_shape is not None:
_, frame_h, frame_w = frame_shape
scale_x = frame_w / model_input_size
scale_y = frame_h / model_input_size

# Scale all detection bboxes
for det in detections:
x1, y1, x2, y2 = det.bbox
det.bbox = [
x1 * scale_x,
y1 * scale_y,
x2 * scale_x,
y2 * scale_y
]

# Convert detections to tensor for GPU processing
det_tensor = torch.tensor(
[[*det.bbox, det.confidence, det.class_id] for det in detections],

@@ -63,6 +63,10 @@ class YOLOv8Utils:
# Normalize to [0, 1] (YOLOv8 expects normalized input)
frame_normalized = frame_resized / 255.0

# NOTE: Don't track these tensors - they're short-lived inputs to TensorRT
# that get automatically freed by PyTorch after inference completes.
# Tracking them would show false "leaks" since we can't track when TensorRT consumes them.

return frame_normalized

@staticmethod

@@ -43,8 +43,8 @@ async def example_callback_pattern():
poll_interval=0.01, # 100 FPS
)

# Initialize with YOLOv8 model
model_path = "models/yolov8n.trt" # Adjust path as needed
# Initialize with YOLOv8 model (transparent loading: .pt, .engine, or .trt)
model_path = "models/yolov8n.trt" # Can also use .pt or .engine
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return

@@ -53,7 +53,8 @@ async def example_callback_pattern():
model_path=model_path,
model_id="yolo",
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
postprocess_fn=YOLOv8Utils.postprocess
# Note: No manual parameters needed for YOLO models
)

# Define callback for tracking results

@@ -24,7 +24,6 @@ from services import (
# Load environment variables
load_dotenv()


def main_single_stream():
"""Single stream example with event-driven architecture."""
print("=" * 80)

@@ -33,7 +32,7 @@ def main_single_stream():

# Configuration
GPU_ID = 0
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # PT file will be auto-converted
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # Transparent loading: .pt, .engine, or .trt
STREAM_URL = os.getenv('CAMERA_URL_1', 'rtsp://localhost:8554/test')
BATCH_SIZE = 4
FORCE_TIMEOUT = 0.05

@@ -59,10 +58,10 @@ def main_single_stream():
)
print("✓ Manager created")

# Initialize with PT model (auto-conversion)
print("\n[2/3] Initializing with PT model...")
print("Note: First load will convert PT to TensorRT (3-5 minutes)")
print("Subsequent loads will use cached TensorRT engine\n")
# Initialize with model (transparent loading - no manual parameters needed)
print("\n[2/3] Initializing model...")
print("Note: YOLO models auto-convert to native TensorRT .engine (first time only)")
print("Metadata is auto-detected from model - no manual input_shapes needed!\n")

try:
manager.initialize(

@@ -70,11 +69,10 @@ def main_single_stream():
model_id="detector",
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
num_contexts=4,
pt_input_shapes={"images": (1, 3, 640, 640)},
pt_precision=torch.float16
num_contexts=4
# Note: No pt_input_shapes or pt_precision needed for YOLO models!
)
print("✓ Manager initialized (PT converted to TensorRT)")
print("✓ Manager initialized")
except Exception as e:
print(f"✗ Failed to initialize: {e}")
import traceback

@@ -176,6 +174,7 @@ def main_single_stream():
class_counts[obj.class_name] = class_counts.get(obj.class_name, 0) + 1
print(f" Classes: {class_counts}")


except KeyboardInterrupt:
print(f"\n✓ Interrupted by user")

@@ -206,7 +205,7 @@ def main_multi_stream():

# Configuration
GPU_ID = 0
MODEL_PATH = "models/yolov8n.pt" # PT file will be auto-converted
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" # Transparent loading: .pt, .engine, or .trt
BATCH_SIZE = 16
FORCE_TIMEOUT = 0.05

@@ -241,17 +240,16 @@ def main_multi_stream():
)
print("✓ Manager created")

# Initialize with PT model
print("\n[2/3] Initializing with PT model...")
# Initialize model (transparent loading)
print("\n[2/3] Initializing model...")
try:
manager.initialize(
model_path=MODEL_PATH,
model_id="detector",
preprocess_fn=YOLOv8Utils.preprocess,
postprocess_fn=YOLOv8Utils.postprocess,
num_contexts=8,
pt_input_shapes={"images": (1, 3, 640, 640)},
pt_precision=torch.float16
num_contexts=8
# Note: No pt_input_shapes or pt_precision needed for YOLO models!
)
print("✓ Manager initialized")
except Exception as e:

@@ -312,6 +310,7 @@ def main_multi_stream():
s_fps = stats['count'] / s_elapsed if s_elapsed > 0 else 0
print(f" {sid}: {stats['count']} ({s_fps:.1f} FPS)")


except KeyboardInterrupt:
print(f"\n✓ Interrupted")