""" Debug script to capture and compare raw PT vs TRT outputs on problematic frames. """ import torch import time from services import StreamDecoderFactory, YOLOv8Utils, TensorRTModelRepository from ultralytics import YOLO import os from dotenv import load_dotenv load_dotenv() GPU_ID = 0 MODEL_PATH = "bangchak/models/frontal_detection_v5.pt" STREAM_URL = os.getenv('CAMERA_URL_1') # Load models print("Loading models...") pt_model = YOLO(MODEL_PATH) pt_model.to(f'cuda:{GPU_ID}') repo = TensorRTModelRepository(gpu_id=GPU_ID) trt_path = "./models/trtptcache/trt/cda5e520441e12fe09a97ac2609da29b4cbac969cc2029ef1735f65697579121.trt" repo.load_model("detector", trt_path, num_contexts=1) # Start decoder print("Starting decoder...") decoder_factory = StreamDecoderFactory(gpu_id=GPU_ID) decoder = decoder_factory.create_decoder(STREAM_URL, buffer_size=30) decoder.start() time.sleep(2) torch.cuda.set_device(GPU_ID) print("\nWaiting for frames with TRT false positives...\n") frame_count = 0 found_issue = False while frame_count < 50 and not found_issue: frame = decoder.get_frame() if frame is None: time.sleep(0.01) continue frame_count += 1 # Preprocess preprocessed = YOLOv8Utils.preprocess(frame, input_size=640) # Run TRT inference trt_outputs = repo.infer("detector", {"images": preprocessed}, synchronize=True) trt_raw = trt_outputs['output0'] # (1, 5, 8400) # Check for the issue - transpose and check channel 4 trt_transposed = trt_raw.transpose(1, 2).squeeze(0) # (8400, 5) conf_channel = trt_transposed[:, 4] # (8400,) num_high_conf = (conf_channel > 0.25).sum().item() if num_high_conf > 100: found_issue = True print(f"🔴 FOUND PROBLEMATIC FRAME {frame_count}!") print(f" TRT detections > 0.25 threshold: {num_high_conf}") # Now run PT model on same frame with torch.no_grad(): pt_raw = pt_model.model(preprocessed)[0] # (1, 5, 8400) print(f"\n=== RAW OUTPUT COMPARISON ===") print(f"PT output shape: {pt_raw.shape}") print(f"TRT output shape: {trt_raw.shape}") # Compare channel 4 (confidence) pt_conf = pt_raw.transpose(1, 2).squeeze(0)[:, 4] trt_conf = trt_transposed[:, 4] print(f"\n--- Confidence Channel (channel 4) ---") print(f"PT confidence stats:") print(f" Min: {pt_conf.min().item():.6e}") print(f" Max: {pt_conf.max().item():.6e}") print(f" Mean: {pt_conf.mean().item():.6e}") print(f" >0.25: {(pt_conf > 0.25).sum().item()}") print(f" >0.5: {(pt_conf > 0.5).sum().item()}") print(f"\nTRT confidence stats:") print(f" Min: {trt_conf.min().item():.6e}") print(f" Max: {trt_conf.max().item():.6e}") print(f" Mean: {trt_conf.mean().item():.6e}") print(f" >0.25: {(trt_conf > 0.25).sum().item()}") print(f" >0.5: {(trt_conf > 0.5).sum().item()}") # Check bbox coordinates too print(f"\n--- BBox Coordinates (channels 0-3) ---") pt_bbox = pt_raw.transpose(1, 2).squeeze(0)[:, :4] trt_bbox = trt_transposed[:, :4] print(f"PT bbox stats:") print(f" Min: {pt_bbox.min().item():.3f}") print(f" Max: {pt_bbox.max().item():.3f}") print(f" Mean: {pt_bbox.mean().item():.3f}") print(f"\nTRT bbox stats:") print(f" Min: {trt_bbox.min().item():.3f}") print(f" Max: {trt_bbox.max().item():.3f}") print(f" Mean: {trt_bbox.mean().item():.3f}") # Sample some values print(f"\n--- Sample Values (first 5 anchors) ---") for i in range(5): print(f"\nAnchor {i}:") print(f" PT [cx={pt_bbox[i,0]:.1f}, cy={pt_bbox[i,1]:.1f}, w={pt_bbox[i,2]:.1f}, h={pt_bbox[i,3]:.1f}, conf={pt_conf[i]:.6f}]") print(f" TRT [cx={trt_bbox[i,0]:.1f}, cy={trt_bbox[i,1]:.1f}, w={trt_bbox[i,2]:.1f}, h={trt_bbox[i,3]:.1f}, conf={trt_conf[i]:.6f}]") # Find indices with high confidence in TRT high_conf_idx = torch.where(trt_conf > 0.25)[0][:5] print(f"\n--- High Confidence Detections in TRT (first 5) ---") for idx in high_conf_idx: i = idx.item() print(f"\nAnchor {i}:") print(f" PT [cx={pt_bbox[i,0]:.1f}, cy={pt_bbox[i,1]:.1f}, w={pt_bbox[i,2]:.1f}, h={pt_bbox[i,3]:.1f}, conf={pt_conf[i]:.6f}]") print(f" TRT [cx={trt_bbox[i,0]:.1f}, cy={trt_bbox[i,1]:.1f}, w={trt_bbox[i,2]:.1f}, h={trt_bbox[i,3]:.1f}, conf={trt_conf[i]:.6f}]") break if frame_count % 10 == 0: print(f"Checked {frame_count} frames, no issues yet...") if not found_issue: print(f"\n⚠️ No problematic frames found in {frame_count} frames") # Cleanup decoder.stop() repo.unload_model("detector") print("\n✓ Done")