#!/usr/bin/env python3
"""
Test script for the car_frontal_detection_v1.pt model.

Usage: python test.py --image <image_path> [--confidence <threshold>] [--save-output]
"""
# Example: python test.py --image sample.jpg --confidence 0.6 --save-output
|
|
|
|
import argparse
|
|
import cv2
|
|
import torch
|
|
import numpy as np
|
|
from pathlib import Path
|
|
import sys
|
|
|
|
def load_model_direct(model_path):
    """Load a checkpoint with ``torch.load`` so its contents can be inspected.

    Args:
        model_path: Path of the checkpoint file to read.

    Returns:
        The loaded checkpoint dict, or ``None`` when the file cannot be
        loaded or does not contain a dict (callers index it by key).
    """
    try:
        # SECURITY: weights_only=False lets pickle execute arbitrary code
        # embedded in the file. It is kept for compatibility with checkpoints
        # that store full model objects -- only use it on trusted files.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    except Exception as e:
        print(f"Direct torch.load failed: {e}")
        return None

    # Report an unexpected payload explicitly instead of letting `.keys()`
    # fail and be mislabelled as a load error.
    if not isinstance(checkpoint, dict):
        print(
            "Direct torch.load failed: expected a dict checkpoint, "
            f"got {type(checkpoint).__name__}"
        )
        return None

    print(f"Model checkpoint keys: {list(checkpoint.keys())}")

    # Show the class names stored on the model object, when there are any.
    if 'model' in checkpoint:
        names = getattr(checkpoint['model'], 'names', None)
        if names is not None:
            print(f"Model class names: {names}")
        else:
            print("Model info available: No names found")

    return checkpoint
|
|
|
|
def _class_label(names, cls_id):
    """Return the display name for class index ``cls_id``.

    ``names`` is looked up by index, so both a dict keyed by class id and an
    index-addressable sequence work. Anything that cannot be looked up yields
    ``Class_<id>``.
    """
    try:
        return names[cls_id]
    except (KeyError, IndexError, TypeError):
        return f"Class_{cls_id}"


def _describe_detections(boxes, names):
    """Return one list of report lines per box in ``boxes``."""
    reports = []
    for number, box in enumerate(boxes, start=1):
        x1, y1, x2, y2 = (float(v) for v in box.xyxy[0].cpu().numpy())
        conf = float(box.conf[0].cpu().numpy())
        cls_id = int(box.cls[0].cpu().numpy())
        reports.append([
            f"Detection {number}:",
            f" Class: {_class_label(names, cls_id)}",
            f" Confidence: {conf:.3f}",
            f" Bounding box: ({x1:.1f}, {y1:.1f}, {x2:.1f}, {y2:.1f})",
            f" Size: {x2 - x1:.1f}x{y2 - y1:.1f}",
        ])
    return reports


def _print_checkpoint_summary(checkpoint):
    """Print what can be learned from a directly loaded checkpoint dict."""
    print("Model loaded directly - this is for inspection only")
    print("Available keys in checkpoint:", list(checkpoint.keys()))

    if 'model' in checkpoint:
        model_obj = checkpoint['model']
        print(f"Model object type: {type(model_obj)}")
        if hasattr(model_obj, 'names'):
            print(f"Model classes: {model_obj.names}")
        if hasattr(model_obj, 'yaml'):
            print(f"Model YAML config available: {bool(model_obj.yaml)}")

    print("\nTo run inference, you need a compatible Ultralytics version.")
    print("Consider upgrading ultralytics: pip install ultralytics --upgrade")


def _save_outputs(args, annotated_image, detections):
    """Write the annotated image and a text report into ``args.output_dir``.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    output_dir = Path(args.output_dir)
    # parents=True so a nested --output-dir such as "runs/test/out" works.
    output_dir.mkdir(parents=True, exist_ok=True)

    input_path = Path(args.image)
    output_path = output_dir / f"{input_path.stem}_detected{input_path.suffix}"
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(str(output_path), annotated_image):
        raise OSError(f"Could not write annotated image to '{output_path}'")
    print(f"Output saved to: {output_path}")

    results_path = output_dir / f"{input_path.stem}_results.txt"
    with open(results_path, 'w', encoding='utf-8') as f:
        f.write(f"Model: {args.model}\n")
        f.write(f"Image: {args.image}\n")
        f.write(f"Confidence threshold: {args.confidence}\n")
        f.write(f"Detections: {len(detections)}\n\n")
        for report in detections:
            f.write("\n".join(report) + "\n\n")
    print(f"Results saved to: {results_path}")


def main():
    """Run the detection model on one image from the command line.

    Parses the CLI arguments, loads the model with Ultralytics YOLO, runs
    inference on the image, prints every detection and, with
    ``--save-output``, writes an annotated image plus a text report.

    If YOLO loading is disabled (``--no-yolo``) or fails, the checkpoint is
    loaded directly, summarised for inspection, and the function returns
    without running inference.

    Exits with status 1 on a missing model or image file, or on a load,
    inference or save error. Exits with status 2 (argparse) on invalid
    arguments.
    """
    parser = argparse.ArgumentParser(description='Test car frontal detection model')
    parser.add_argument('--image', required=True, help='Path to input image')
    parser.add_argument('--model', default='car_frontal_detection_v1.pt', help='Path to model file')
    parser.add_argument('--confidence', type=float, default=0.5, help='Confidence threshold (default: 0.5)')
    parser.add_argument('--save-output', action='store_true', help='Save output image with detections')
    parser.add_argument('--output-dir', default='output', help='Output directory for results')
    parser.add_argument('--use-yolo', action='store_true', default=True, help='Use YOLO loading (default: True)')
    # --use-yolo defaults to True and could never be switched off; this flag
    # shares its destination so the direct-inspection path is reachable.
    parser.add_argument('--no-yolo', dest='use_yolo', action='store_false',
                        help='Skip YOLO loading and only inspect the checkpoint')

    args = parser.parse_args()

    # A confidence threshold outside [0, 1] (or NaN) is meaningless.
    if not 0.0 <= args.confidence <= 1.0:
        parser.error(f"--confidence must be between 0 and 1, got {args.confidence}")

    if not Path(args.model).exists():
        print(f"Error: Model file '{args.model}' not found")
        sys.exit(1)

    if not Path(args.image).exists():
        print(f"Error: Image file '{args.image}' not found")
        sys.exit(1)

    print(f"Loading model: {args.model}")

    model = None
    if args.use_yolo:
        try:
            from ultralytics import YOLO
            model = YOLO(args.model)
            print("Model loaded successfully with YOLO")
            print(f"Model classes: {model.names}")
        except Exception as e:
            print(f"Error loading model with YOLO: {e}")
            print("Falling back to direct loading...")

    if model is None:
        # Inspection only: a raw checkpoint cannot be used for inference here.
        checkpoint = load_model_direct(args.model)
        if checkpoint is None:
            print("Failed to load model with any method")
            sys.exit(1)
        _print_checkpoint_summary(checkpoint)
        return

    print(f"Loading image: {args.image}")
    try:
        image = cv2.imread(args.image)
        if image is None:
            raise ValueError("Could not load image")
        print(f"Image shape: {image.shape}")
    except Exception as e:
        print(f"Error loading image: {e}")
        sys.exit(1)

    print(f"Running inference with confidence threshold: {args.confidence}")
    try:
        results = model(image, conf=args.confidence)
        result = results[0] if len(results) > 0 else None
        # `boxes` may be absent or None; treat both as "no detections".
        boxes = getattr(result, 'boxes', None)
        detections = _describe_detections(boxes, model.names) if boxes is not None else []
    except Exception as e:
        print(f"Error during inference: {e}")
        sys.exit(1)

    if detections:
        print(f"Detections found: {len(detections)}")
        for report in detections:
            print("\n".join(report))
    else:
        print("No detections found")

    if args.save_output:
        try:
            # Without a result there is nothing to draw; save the input as-is.
            annotated_image = result.plot() if result is not None else image
            _save_outputs(args, annotated_image, detections)
        except Exception as e:
            print(f"Error saving output: {e}")
            sys.exit(1)

    print("Test completed successfully!")
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()