#!/usr/bin/env python3
"""
Quick verification script for a TensorRT model.

Loads a serialized TensorRT engine, prints its input/output specifications,
runs one inference on a random dummy input, and unloads the model.
"""

import traceback

import torch

from services.model_repository import TensorRTModelRepository


def _print_banner(title, *, leading_newline=True):
    """Print ``title`` between two 80-character rules of '='.

    Args:
        title: Section heading text.
        leading_newline: Emit a blank line before the top rule so consecutive
            sections are visually separated.
    """
    rule = "=" * 80
    print(("\n" if leading_newline else "") + rule)
    print(title)
    print(rule)


def _print_model_info(info):
    """Print the identity and I/O tensor specs from a model-info mapping.

    Reads the keys 'model_id', 'file_path', 'file_hash', 'inputs' and
    'outputs'; 'inputs'/'outputs' are expected to map tensor names to dicts
    with 'shape' and 'dtype' entries. A missing key raises ``KeyError``.
    """
    print(f"Model ID: {info['model_id']}")
    print(f"File: {info['file_path']}")
    print(f"File hash: {info['file_hash']}")
    for section in ("inputs", "outputs"):
        print(f"\n{section.capitalize()}:")
        for name, spec in info[section].items():
            print(f" {name}: {spec['shape']} ({spec['dtype']})")


def _run_test_inference(repo, model_id, input_name, input_shape, device):
    """Run one synchronous inference on random input and print the outputs.

    Any exception raised while building the input, running inference or
    printing the outputs is reported with a traceback and turned into a
    ``False`` return, so the caller can still clean up.

    Returns:
        True if inference and output reporting succeeded, False otherwise.
    """
    try:
        # Random float32 data simulating an input batch of ``input_shape``
        # (by default a single 3x640x640 image).
        input_tensor = torch.rand(*input_shape, dtype=torch.float32, device=device)
        print(f"Input tensor: {input_tensor.shape} on {input_tensor.device}")

        outputs = repo.infer(
            model_id=model_id,
            inputs={input_name: input_tensor},
            synchronize=True,
        )

        print("\n✓ Inference successful!")
        print("\nOutputs:")
        for name, tensor in outputs.items():
            print(f" {name}: {tensor.shape} on {tensor.device} ({tensor.dtype})")
    except Exception as e:  # best-effort check: report, then signal failure
        print(f"\n✗ Inference failed: {e}")
        traceback.print_exc()
        return False
    return True


def verify_model(
    *,
    model_id="yolov8n_test",
    file_path="models/yolov8n.trt",
    gpu_id=0,
    num_contexts=2,
    input_name="images",
    input_shape=(1, 3, 640, 640),
):
    """Load a TensorRT engine, print its metadata, run one test inference.

    Creates a ``TensorRTModelRepository``, loads the engine, and prints the
    mapping returned by ``get_model_info``. It then runs a single synchronous
    inference on random float32 input and unloads the model. Unloading
    happens in a ``finally`` clause, so it runs even if reporting raises.
    Load and inference errors are printed rather than raised.

    Args:
        model_id: Identifier the engine is registered under in the repository.
        file_path: Path to the serialized TensorRT engine file.
        gpu_id: GPU index passed to the repository; the dummy input is also
            allocated on ``cuda:{gpu_id}`` so both agree.
        num_contexts: Passed as the repository's ``default_num_contexts`` and
            as ``num_contexts`` when loading this model.
        input_name: Name of the engine input the dummy tensor is bound to.
        input_shape: Shape of the random dummy input tensor.

    Returns:
        True if the model loaded and the test inference succeeded, False
        otherwise.
    """
    _print_banner("TensorRT Model Verification", leading_newline=False)

    repo = TensorRTModelRepository(gpu_id=gpu_id, default_num_contexts=num_contexts)

    print(f"\nLoading TensorRT engine '{model_id}' from {file_path}...")
    try:
        repo.load_model(
            model_id=model_id,
            file_path=file_path,
            num_contexts=num_contexts,
        )
    except Exception as e:
        print(f"✗ Failed to load model: {e}")
        return False
    print("✓ Model loaded successfully!")

    # From here on the model is loaded; the finally clause guarantees
    # unload_model runs even if an unexpected error escapes the reporting.
    inference_ok = False
    try:
        _print_banner("Model Information")
        info = repo.get_model_info(model_id)
        if info:
            _print_model_info(info)
        else:
            print(f"(no model info returned for '{model_id}')")

        _print_banner("Running Test Inference")
        inference_ok = _run_test_inference(
            repo, model_id, input_name, input_shape, f"cuda:{gpu_id}"
        )
    finally:
        _print_banner("Cleanup")
        repo.unload_model(model_id)
        print("✓ Model unloaded")

    _print_banner(
        "Verification Complete!" if inference_ok else "Verification Failed!"
    )
    return inference_ok


if __name__ == "__main__":
    # Exit non-zero only when verify_model() explicitly reports failure by
    # returning False, so shell/CI callers can detect a failed verification;
    # a None return is treated as success (exit code 0).
    raise SystemExit(1 if verify_model() is False else 0)