136 lines
4.8 KiB
Python
136 lines
4.8 KiB
Python
"""
Debug script to capture and compare raw PyTorch (PT) vs TensorRT (TRT) outputs on problematic frames.
"""

import os
import time

# Non-stdlib imports keep the relative order they had originally.
import torch
from services import StreamDecoderFactory, YOLOv8Utils, TensorRTModelRepository
from ultralytics import YOLO
from dotenv import load_dotenv

load_dotenv()

GPU_ID = 0
MODEL_PATH = "bangchak/models/frontal_detection_v5.pt"
STREAM_URL = os.getenv('CAMERA_URL_1')

# Stop after inspecting this many decoded frames.
MAX_FRAMES = 50
# Channel-4 score above which an anchor counts as a detection. The literal
# "0.25" inside the report strings below mirrors this value.
CONF_THRESHOLD = 0.25
# A frame is flagged when more anchors than this exceed CONF_THRESHOLD.
FALSE_POSITIVE_LIMIT = 100
# Give up when the decoder returns no frame for this many seconds in a row;
# without a bound a dead stream would keep the polling loop spinning forever.
NO_FRAME_TIMEOUT_S = 30.0


def _split_channels(raw):
    """Return ``(bbox, conf)`` views of a raw model output.

    ``raw`` is transposed on dims 1/2 and has dim 0 squeezed, so it is
    expected to look like (1, 5, 8400) -- TODO confirm for other models.
    Columns 0-3 are returned as ``bbox`` and column 4 as ``conf``.
    """
    per_anchor = raw.transpose(1, 2).squeeze(0)
    return per_anchor[:, :4], per_anchor[:, 4]


def _print_conf_stats(header, conf):
    """Print ``header`` followed by min/max/mean and threshold counts of ``conf``."""
    print(header)
    print(f" Min: {conf.min().item():.6e}")
    print(f" Max: {conf.max().item():.6e}")
    print(f" Mean: {conf.mean().item():.6e}")
    print(f" >0.25: {(conf > 0.25).sum().item()}")
    print(f" >0.5: {(conf > 0.5).sum().item()}")


def _print_bbox_stats(header, bbox):
    """Print ``header`` followed by min/max/mean over all values in ``bbox``."""
    print(header)
    print(f" Min: {bbox.min().item():.3f}")
    print(f" Max: {bbox.max().item():.3f}")
    print(f" Mean: {bbox.mean().item():.3f}")


def _format_row(bbox, conf, i):
    """Render row ``i`` of ``bbox``/``conf`` as ``[cx=.., cy=.., w=.., h=.., conf=..]``."""
    return (
        f"[cx={bbox[i,0]:.1f}, cy={bbox[i,1]:.1f}, "
        f"w={bbox[i,2]:.1f}, h={bbox[i,3]:.1f}, conf={conf[i]:.6f}]"
    )


def _print_anchor(i, pt_bbox, pt_conf, trt_bbox, trt_conf):
    """Print the PT and TRT values of anchor ``i`` on consecutive lines."""
    print(f"\nAnchor {i}:")
    print(f" PT {_format_row(pt_bbox, pt_conf, i)}")
    print(f" TRT {_format_row(trt_bbox, trt_conf, i)}")


def _report_comparison(pt_raw, trt_raw):
    """Print the side-by-side report of the raw PT and TRT outputs."""
    print("\n=== RAW OUTPUT COMPARISON ===")
    print(f"PT output shape: {pt_raw.shape}")
    print(f"TRT output shape: {trt_raw.shape}")

    pt_bbox, pt_conf = _split_channels(pt_raw)
    trt_bbox, trt_conf = _split_channels(trt_raw)

    # Compare channel 4 (confidence)
    print("\n--- Confidence Channel (channel 4) ---")
    _print_conf_stats("PT confidence stats:", pt_conf)
    _print_conf_stats("\nTRT confidence stats:", trt_conf)

    # Check bbox coordinates too
    print("\n--- BBox Coordinates (channels 0-3) ---")
    _print_bbox_stats("PT bbox stats:", pt_bbox)
    _print_bbox_stats("\nTRT bbox stats:", trt_bbox)

    # Sample some values
    print("\n--- Sample Values (first 5 anchors) ---")
    for i in range(5):
        _print_anchor(i, pt_bbox, pt_conf, trt_bbox, trt_conf)

    # Anchors the TRT engine scores above the threshold, shown next to the
    # PT values for the same anchor index.
    high_conf_idx = torch.where(trt_conf > CONF_THRESHOLD)[0][:5]
    print("\n--- High Confidence Detections in TRT (first 5) ---")
    for i in high_conf_idx.tolist():
        _print_anchor(i, pt_bbox, pt_conf, trt_bbox, trt_conf)


# Fail fast before loading any model: os.getenv returns None when the
# variable is missing, and the decoder would otherwise be handed that None.
if not STREAM_URL:
    raise SystemExit("CAMERA_URL_1 is not set; cannot open the camera stream")

# Load models
print("Loading models...")
pt_model = YOLO(MODEL_PATH)
pt_model.to(f'cuda:{GPU_ID}')

repo = TensorRTModelRepository(gpu_id=GPU_ID)
trt_path = "./models/trtptcache/trt/cda5e520441e12fe09a97ac2609da29b4cbac969cc2029ef1735f65697579121.trt"
repo.load_model("detector", trt_path, num_contexts=1)

frame_count = 0
found_issue = False

# The nested try/finally blocks release the decoder and the TRT model (in the
# original order) even when decoding or inference raises.
try:
    # Start decoder
    print("Starting decoder...")
    decoder_factory = StreamDecoderFactory(gpu_id=GPU_ID)
    decoder = decoder_factory.create_decoder(STREAM_URL, buffer_size=30)
    decoder.start()
    try:
        # Brief pause before polling -- presumably lets the decoder connect
        # and buffer some frames; TODO confirm.
        time.sleep(2)

        torch.cuda.set_device(GPU_ID)

        print("\nWaiting for frames with TRT false positives...\n")

        last_frame_time = time.monotonic()
        while frame_count < MAX_FRAMES:
            frame = decoder.get_frame()
            if frame is None:
                if time.monotonic() - last_frame_time > NO_FRAME_TIMEOUT_S:
                    print(f"\n⚠️ No frame received for {NO_FRAME_TIMEOUT_S:.0f}s, giving up")
                    break
                time.sleep(0.01)
                continue

            last_frame_time = time.monotonic()
            frame_count += 1

            # Preprocess
            preprocessed = YOLOv8Utils.preprocess(frame, input_size=640)

            # Run TRT inference
            trt_outputs = repo.infer("detector", {"images": preprocessed}, synchronize=True)
            trt_raw = trt_outputs['output0']  # expected (1, 5, 8400)

            # Count anchors whose channel-4 value exceeds the threshold.
            _, trt_conf = _split_channels(trt_raw)
            num_high_conf = (trt_conf > CONF_THRESHOLD).sum().item()

            if num_high_conf > FALSE_POSITIVE_LIMIT:
                found_issue = True
                print(f"🔴 FOUND PROBLEMATIC FRAME {frame_count}!")
                print(f" TRT detections > 0.25 threshold: {num_high_conf}")

                # Now run PT model on same frame
                with torch.no_grad():
                    pt_raw = pt_model.model(preprocessed)[0]  # expected (1, 5, 8400)

                _report_comparison(pt_raw, trt_raw)
                break

            if frame_count % 10 == 0:
                print(f"Checked {frame_count} frames, no issues yet...")

        if not found_issue:
            print(f"\n⚠️ No problematic frames found in {frame_count} frames")
    finally:
        # Cleanup
        decoder.stop()
finally:
    repo.unload_model("detector")

print("\n✓ Done")