fix: gpu memory leaks

This commit is contained in:
Siwat Sirichai 2025-11-10 22:10:46 +07:00
parent 3a47920186
commit 593611cdb7
13 changed files with 420 additions and 166 deletions

136
debug_trt_output.py Normal file
View file

@ -0,0 +1,136 @@
"""
Debug script to capture and compare raw PT vs TRT outputs on problematic frames.
"""
import torch
import time
from services import StreamDecoderFactory, YOLOv8Utils, TensorRTModelRepository
from ultralytics import YOLO
import os
from dotenv import load_dotenv
load_dotenv()
GPU_ID = 0
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
STREAM_URL = os.getenv('CAMERA_URL_1')
# Load models
print("Loading models...")
pt_model = YOLO(MODEL_PATH)
pt_model.to(f'cuda:{GPU_ID}')
repo = TensorRTModelRepository(gpu_id=GPU_ID)
trt_path = "./models/trtptcache/trt/cda5e520441e12fe09a97ac2609da29b4cbac969cc2029ef1735f65697579121.trt"
repo.load_model("detector", trt_path, num_contexts=1)
# Start decoder
print("Starting decoder...")
decoder_factory = StreamDecoderFactory(gpu_id=GPU_ID)
decoder = decoder_factory.create_decoder(STREAM_URL, buffer_size=30)
decoder.start()
time.sleep(2)
torch.cuda.set_device(GPU_ID)
print("\nWaiting for frames with TRT false positives...\n")
frame_count = 0
found_issue = False
while frame_count < 50 and not found_issue:
frame = decoder.get_frame()
if frame is None:
time.sleep(0.01)
continue
frame_count += 1
# Preprocess
preprocessed = YOLOv8Utils.preprocess(frame, input_size=640)
# Run TRT inference
trt_outputs = repo.infer("detector", {"images": preprocessed}, synchronize=True)
trt_raw = trt_outputs['output0'] # (1, 5, 8400)
# Check for the issue - transpose and check channel 4
trt_transposed = trt_raw.transpose(1, 2).squeeze(0) # (8400, 5)
conf_channel = trt_transposed[:, 4] # (8400,)
num_high_conf = (conf_channel > 0.25).sum().item()
if num_high_conf > 100:
found_issue = True
print(f"🔴 FOUND PROBLEMATIC FRAME {frame_count}!")
print(f" TRT detections > 0.25 threshold: {num_high_conf}")
# Now run PT model on same frame
with torch.no_grad():
pt_raw = pt_model.model(preprocessed)[0] # (1, 5, 8400)
print(f"\n=== RAW OUTPUT COMPARISON ===")
print(f"PT output shape: {pt_raw.shape}")
print(f"TRT output shape: {trt_raw.shape}")
# Compare channel 4 (confidence)
pt_conf = pt_raw.transpose(1, 2).squeeze(0)[:, 4]
trt_conf = trt_transposed[:, 4]
print(f"\n--- Confidence Channel (channel 4) ---")
print(f"PT confidence stats:")
print(f" Min: {pt_conf.min().item():.6e}")
print(f" Max: {pt_conf.max().item():.6e}")
print(f" Mean: {pt_conf.mean().item():.6e}")
print(f" >0.25: {(pt_conf > 0.25).sum().item()}")
print(f" >0.5: {(pt_conf > 0.5).sum().item()}")
print(f"\nTRT confidence stats:")
print(f" Min: {trt_conf.min().item():.6e}")
print(f" Max: {trt_conf.max().item():.6e}")
print(f" Mean: {trt_conf.mean().item():.6e}")
print(f" >0.25: {(trt_conf > 0.25).sum().item()}")
print(f" >0.5: {(trt_conf > 0.5).sum().item()}")
# Check bbox coordinates too
print(f"\n--- BBox Coordinates (channels 0-3) ---")
pt_bbox = pt_raw.transpose(1, 2).squeeze(0)[:, :4]
trt_bbox = trt_transposed[:, :4]
print(f"PT bbox stats:")
print(f" Min: {pt_bbox.min().item():.3f}")
print(f" Max: {pt_bbox.max().item():.3f}")
print(f" Mean: {pt_bbox.mean().item():.3f}")
print(f"\nTRT bbox stats:")
print(f" Min: {trt_bbox.min().item():.3f}")
print(f" Max: {trt_bbox.max().item():.3f}")
print(f" Mean: {trt_bbox.mean().item():.3f}")
# Sample some values
print(f"\n--- Sample Values (first 5 anchors) ---")
for i in range(5):
print(f"\nAnchor {i}:")
print(f" PT [cx={pt_bbox[i,0]:.1f}, cy={pt_bbox[i,1]:.1f}, w={pt_bbox[i,2]:.1f}, h={pt_bbox[i,3]:.1f}, conf={pt_conf[i]:.6f}]")
print(f" TRT [cx={trt_bbox[i,0]:.1f}, cy={trt_bbox[i,1]:.1f}, w={trt_bbox[i,2]:.1f}, h={trt_bbox[i,3]:.1f}, conf={trt_conf[i]:.6f}]")
# Find indices with high confidence in TRT
high_conf_idx = torch.where(trt_conf > 0.25)[0][:5]
print(f"\n--- High Confidence Detections in TRT (first 5) ---")
for idx in high_conf_idx:
i = idx.item()
print(f"\nAnchor {i}:")
print(f" PT [cx={pt_bbox[i,0]:.1f}, cy={pt_bbox[i,1]:.1f}, w={pt_bbox[i,2]:.1f}, h={pt_bbox[i,3]:.1f}, conf={pt_conf[i]:.6f}]")
print(f" TRT [cx={trt_bbox[i,0]:.1f}, cy={trt_bbox[i,1]:.1f}, w={trt_bbox[i,2]:.1f}, h={trt_bbox[i,3]:.1f}, conf={trt_conf[i]:.6f}]")
break
if frame_count % 10 == 0:
print(f"Checked {frame_count} frames, no issues yet...")
if not found_issue:
print(f"\n⚠️ No problematic frames found in {frame_count} frames")
# Cleanup
decoder.stop()
repo.unload_model("detector")
print("\n✓ Done")