fix: gpu memory leaks
This commit is contained in:
parent 3a47920186
commit 593611cdb7
13 changed files with 420 additions and 166 deletions
debug_trt_output.py (new file, 136 lines)
@@ -0,0 +1,136 @@
"""
|
||||
Debug script to capture and compare raw PT vs TRT outputs on problematic frames.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import time
|
||||
from services import StreamDecoderFactory, YOLOv8Utils, TensorRTModelRepository
|
||||
from ultralytics import YOLO
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
GPU_ID = 0
|
||||
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
|
||||
STREAM_URL = os.getenv('CAMERA_URL_1')
|
||||
|
||||
# Load models
|
||||
print("Loading models...")
|
||||
pt_model = YOLO(MODEL_PATH)
|
||||
pt_model.to(f'cuda:{GPU_ID}')
|
||||
|
||||
repo = TensorRTModelRepository(gpu_id=GPU_ID)
|
||||
trt_path = "./models/trtptcache/trt/cda5e520441e12fe09a97ac2609da29b4cbac969cc2029ef1735f65697579121.trt"
|
||||
repo.load_model("detector", trt_path, num_contexts=1)
|
||||
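# (Note: the hashed .trt file under trtptcache/ is presumably the cached TensorRT
#  conversion of the same checkpoint, and num_contexts=1 is assumed to mean a single
#  execution context, which is enough for this single-threaded debug script.)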

# Start decoder
print("Starting decoder...")
decoder_factory = StreamDecoderFactory(gpu_id=GPU_ID)
decoder = decoder_factory.create_decoder(STREAM_URL, buffer_size=30)
decoder.start()
time.sleep(2)
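# (The 2 s pause gives the decoder time to connect to the stream and start
#  filling its 30-frame buffer before polling begins.)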

torch.cuda.set_device(GPU_ID)
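# (torch.cuda.set_device pins subsequent CUDA work in this script to GPU_ID,
#  the same device the decoder and TRT engine were created on.)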

print("\nWaiting for frames with TRT false positives...\n")

frame_count = 0
found_issue = False

while frame_count < 50 and not found_issue:
    frame = decoder.get_frame()
    if frame is None:
        time.sleep(0.01)
        continue

    frame_count += 1

    # Preprocess
    preprocessed = YOLOv8Utils.preprocess(frame, input_size=640)
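    # (YOLOv8Utils.preprocess is project-internal; it is assumed to resize/letterbox
    #  the decoded GPU frame to 640x640 and return a normalized NCHW float tensor
    #  that both the PT model and the TRT engine accept.)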

    # Run TRT inference
    trt_outputs = repo.infer("detector", {"images": preprocessed}, synchronize=True)
    trt_raw = trt_outputs['output0']  # (1, 5, 8400)
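    # Output layout: 5 channels x 8400 anchors. For a single-class YOLOv8 model the
    # channels are cx, cy, w, h (in input pixels) plus one class-confidence score;
    # 8400 = 80*80 + 40*40 + 20*20 anchor positions at a 640x640 input.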

    # Check for the issue - transpose and check channel 4
    trt_transposed = trt_raw.transpose(1, 2).squeeze(0)  # (8400, 5)
    conf_channel = trt_transposed[:, 4]  # (8400,)

    num_high_conf = (conf_channel > 0.25).sum().item()
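    # More than 100 raw anchors above 0.25 before NMS is treated as the TRT
    # false-positive burst this script is hunting for.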

    if num_high_conf > 100:
        found_issue = True
        print(f"🔴 FOUND PROBLEMATIC FRAME {frame_count}!")
        print(f"   TRT detections > 0.25 threshold: {num_high_conf}")

        # Now run PT model on same frame
        with torch.no_grad():
            pt_raw = pt_model.model(preprocessed)[0]  # (1, 5, 8400)
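        # pt_model.model is the underlying detection nn.Module, so this call skips
        # Ultralytics' built-in pre/post-processing and yields the raw head output
        # in the same (1, 5, 8400) layout as the TRT engine, for a direct comparison.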

        print(f"\n=== RAW OUTPUT COMPARISON ===")
        print(f"PT output shape: {pt_raw.shape}")
        print(f"TRT output shape: {trt_raw.shape}")

        # Compare channel 4 (confidence)
        pt_conf = pt_raw.transpose(1, 2).squeeze(0)[:, 4]
        trt_conf = trt_transposed[:, 4]

        print(f"\n--- Confidence Channel (channel 4) ---")
        print(f"PT confidence stats:")
        print(f"  Min:  {pt_conf.min().item():.6e}")
        print(f"  Max:  {pt_conf.max().item():.6e}")
        print(f"  Mean: {pt_conf.mean().item():.6e}")
        print(f"  >0.25: {(pt_conf > 0.25).sum().item()}")
        print(f"  >0.5:  {(pt_conf > 0.5).sum().item()}")

        print(f"\nTRT confidence stats:")
        print(f"  Min:  {trt_conf.min().item():.6e}")
        print(f"  Max:  {trt_conf.max().item():.6e}")
        print(f"  Mean: {trt_conf.mean().item():.6e}")
        print(f"  >0.25: {(trt_conf > 0.25).sum().item()}")
        print(f"  >0.5:  {(trt_conf > 0.5).sum().item()}")

        # Check bbox coordinates too
        print(f"\n--- BBox Coordinates (channels 0-3) ---")
        pt_bbox = pt_raw.transpose(1, 2).squeeze(0)[:, :4]
        trt_bbox = trt_transposed[:, :4]

        print(f"PT bbox stats:")
        print(f"  Min:  {pt_bbox.min().item():.3f}")
        print(f"  Max:  {pt_bbox.max().item():.3f}")
        print(f"  Mean: {pt_bbox.mean().item():.3f}")

        print(f"\nTRT bbox stats:")
        print(f"  Min:  {trt_bbox.min().item():.3f}")
        print(f"  Max:  {trt_bbox.max().item():.3f}")
        print(f"  Mean: {trt_bbox.mean().item():.3f}")

        # Sample some values
        print(f"\n--- Sample Values (first 5 anchors) ---")
        for i in range(5):
            print(f"\nAnchor {i}:")
            print(f"  PT  [cx={pt_bbox[i,0]:.1f}, cy={pt_bbox[i,1]:.1f}, w={pt_bbox[i,2]:.1f}, h={pt_bbox[i,3]:.1f}, conf={pt_conf[i]:.6f}]")
            print(f"  TRT [cx={trt_bbox[i,0]:.1f}, cy={trt_bbox[i,1]:.1f}, w={trt_bbox[i,2]:.1f}, h={trt_bbox[i,3]:.1f}, conf={trt_conf[i]:.6f}]")

        # Find indices with high confidence in TRT
        high_conf_idx = torch.where(trt_conf > 0.25)[0][:5]
        print(f"\n--- High Confidence Detections in TRT (first 5) ---")
        for idx in high_conf_idx:
            i = idx.item()
            print(f"\nAnchor {i}:")
            print(f"  PT  [cx={pt_bbox[i,0]:.1f}, cy={pt_bbox[i,1]:.1f}, w={pt_bbox[i,2]:.1f}, h={pt_bbox[i,3]:.1f}, conf={pt_conf[i]:.6f}]")
            print(f"  TRT [cx={trt_bbox[i,0]:.1f}, cy={trt_bbox[i,1]:.1f}, w={trt_bbox[i,2]:.1f}, h={trt_bbox[i,3]:.1f}, conf={trt_conf[i]:.6f}]")

        break

    if frame_count % 10 == 0:
        print(f"Checked {frame_count} frames, no issues yet...")

if not found_issue:
    print(f"\n⚠️  No problematic frames found in {frame_count} frames")

# Cleanup
decoder.stop()
repo.unload_model("detector")
print("\n✓ Done")