#!/usr/bin/env python3
"""
Quick verification script for TensorRT model
"""

import torch

from services.model_repository import TensorRTModelRepository


def verify_model():
    print("=" * 80)
    print("TensorRT Model Verification")
    print("=" * 80)

    # Initialize repository
    repo = TensorRTModelRepository(gpu_id=0, default_num_contexts=2)
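
    # NOTE: num_contexts is assumed to size a pool of TensorRT execution
    # contexts so multiple threads can run inference on this engine
    # concurrently. The .trt engine must also have been built with a matching
    # TensorRT version and for this GPU's architecture; a mismatch usually
    # shows up as a load failure below.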

    # Load the model
    print("\nLoading YOLOv8n TensorRT engine...")
    try:
        metadata = repo.load_model(
            model_id="yolov8n_test",
            file_path="models/yolov8n.trt",
            num_contexts=2
        )
        print("✓ Model loaded successfully!")
    except Exception as e:
        print(f"✗ Failed to load model: {e}")
        return

    # Get model info
    print("\n" + "=" * 80)
    print("Model Information")
    print("=" * 80)

    info = repo.get_model_info("yolov8n_test")
    if info:
        print(f"Model ID: {info['model_id']}")
        print(f"File: {info['file_path']}")
        print(f"File hash: {info['file_hash']}")
        print("\nInputs:")
        for name, spec in info['inputs'].items():
            print(f"  {name}: {spec['shape']} ({spec['dtype']})")
        print("\nOutputs:")
        for name, spec in info['outputs'].items():
            print(f"  {name}: {spec['shape']} ({spec['dtype']})")

    # Run test inference
    print("\n" + "=" * 80)
    print("Running Test Inference")
    print("=" * 80)

    try:
        # Create dummy input (simulating a 640x640 image)
        input_tensor = torch.rand(1, 3, 640, 640, dtype=torch.float32, device='cuda:0')
        print(f"Input tensor: {input_tensor.shape} on {input_tensor.device}")
        # Run inference
        outputs = repo.infer(
            model_id="yolov8n_test",
            inputs={"images": input_tensor},
            synchronize=True
        )

        print("\n✓ Inference successful!")
        print("\nOutputs:")
        for name, tensor in outputs.items():
            print(f"  {name}: {tensor.shape} on {tensor.device} ({tensor.dtype})")
    except Exception as e:
        print(f"\n✗ Inference failed: {e}")
        import traceback
        traceback.print_exc()

    # Cleanup
    print("\n" + "=" * 80)
    print("Cleanup")
    print("=" * 80)
    repo.unload_model("yolov8n_test")
    print("✓ Model unloaded")

    print("\n" + "=" * 80)
    print("Verification Complete!")
    print("=" * 80)


if __name__ == "__main__":
    verify_model()