"""
Ultralytics YOLO Model Exporter with Caching

Exports YOLO .pt models to TensorRT .engine format using the Ultralytics library.
Provides proper NMS and postprocessing built into the engine.
Caches exported engines to avoid redundant exports.
"""

import hashlib
import json
import logging
import os
import shutil
from pathlib import Path
from typing import Optional, Tuple

logger = logging.getLogger(__name__)


class UltralyticsExporter:
    """
    Export YOLO models using Ultralytics with caching.

    Features:
    - Exports .pt models to TensorRT .engine format
    - Caches exported engines by a hash of the source file and export settings
    - Saves metadata about exported models
    - Reuses cached engines when available

    Cache layout inside ``cache_dir``:
    - ``<config_hash>.engine``: the exported engine (its presence is the
      cache-hit signal, so it is only ever published complete)
    - ``<config_hash>_metadata.json``: JSON description of the export

    ``config_hash`` is the SHA256 of the sorted-key JSON encoding of
    ``{"model_hash", "device", "half", "imgsz", "batch", **export_kwargs}``.
    """

    def __init__(self, cache_dir: str = ".ultralytics_cache"):
        """
        Initialize exporter.

        Args:
            cache_dir: Directory for caching exported engines; created
                (including parents) if missing.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Ultralytics exporter cache directory: {self.cache_dir}")

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """
        Compute SHA256 hash of a file, reading it in 64 KiB chunks.

        Args:
            file_path: Path to file

        Returns:
            Hexadecimal hash string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    @staticmethod
    def _config_hash(
        model_hash: str,
        device: int,
        half: bool,
        imgsz: int,
        batch: int,
        export_kwargs: dict,
    ) -> str:
        """Return the cache key shared by export() and get_cached_engine()."""
        export_config = {
            "model_hash": model_hash,
            "device": device,
            "half": half,
            "imgsz": imgsz,
            "batch": batch,
            **export_kwargs,
        }
        # default=str stops non-JSON kwargs (e.g. pathlib.Path) from raising;
        # JSON-serializable configs encode exactly as before, so cache keys
        # written by earlier versions stay valid.
        config_str = json.dumps(export_config, sort_keys=True, default=str)
        return hashlib.sha256(config_str.encode()).hexdigest()

    @staticmethod
    def _atomic_write(final_path: Path, writer) -> None:
        """
        Produce ``final_path`` atomically.

        ``writer(tmp_path)`` must create the file at ``tmp_path``; it is then
        renamed over ``final_path`` with ``os.replace``. On any failure the
        temporary file is removed and the exception propagates, so readers
        never observe a partially written ``final_path``.
        """
        # Same directory as the target so os.replace stays a same-filesystem
        # rename; pid + random suffix keeps concurrent writers apart.
        tmp_path = final_path.with_name(
            f"{final_path.name}.{os.getpid()}-{os.urandom(4).hex()}.tmp"
        )
        try:
            writer(tmp_path)
            os.replace(tmp_path, final_path)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

    @staticmethod
    def _read_metadata(metadata_path: Path) -> Optional[dict]:
        """Best-effort load of a metadata sidecar; None if missing or unreadable."""
        try:
            with open(metadata_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None
        except (OSError, ValueError) as e:
            # Metadata is informational only; a corrupt sidecar must not
            # turn a valid cache hit into an error.
            logger.warning(f"Ignoring unreadable cache metadata {metadata_path}: {e}")
            return None

    def export(
        self,
        model_path: str,
        device: int = 0,
        half: bool = False,
        imgsz: int = 640,
        batch: int = 1,
        **export_kwargs,
    ) -> Tuple[str, str]:
        """
        Export YOLO model to TensorRT engine with caching.

        Args:
            model_path: Path to .pt model file
            device: GPU device ID
            half: Use FP16 precision
            imgsz: Input image size (default: 640)
            batch: Maximum batch size for inference
            **export_kwargs: Additional arguments for Ultralytics export
                (also part of the cache key). ``verbose`` defaults to True
                and may be overridden here.

        Returns:
            Tuple of (config_hash, engine_path), where engine_path points
            into the cache directory.

        Raises:
            FileNotFoundError: If model file doesn't exist
            RuntimeError: If export fails (including when ultralytics is not
                installed); the original exception is chained as __cause__.
        """
        model_path = Path(model_path).resolve()

        if not model_path.is_file():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Compute hash of source model
        logger.info(f"Computing hash for {model_path}...")
        model_hash = self.compute_file_hash(str(model_path))
        logger.info(f"Model hash: {model_hash[:16]}...")

        config_hash = self._config_hash(
            model_hash, device, half, imgsz, batch, export_kwargs
        )
        cache_engine_path = self.cache_dir / f"{config_hash}.engine"
        cache_metadata_path = self.cache_dir / f"{config_hash}_metadata.json"

        # Check cache
        if cache_engine_path.exists():
            logger.info(f"Found cached engine: {cache_engine_path}")
            logger.info(f"Reusing cached export (config hash: {config_hash[:16]}...)")
            metadata = self._read_metadata(cache_metadata_path)
            if metadata is not None:
                logger.info(f"Cached engine metadata: {metadata}")
            return config_hash, str(cache_engine_path)

        # Export using Ultralytics
        logger.info("Exporting YOLO model to TensorRT engine...")
        logger.info(f"  Source: {model_path}")
        logger.info(f"  Device: GPU {device}")
        logger.info(f"  Precision: {'FP16' if half else 'FP32'}")
        logger.info(f"  Image size: {imgsz}")
        logger.info(f"  Batch size: {batch}")

        try:
            # Imported lazily so the cache can be queried without ultralytics.
            from ultralytics import YOLO

            model = YOLO(str(model_path))

            # Caller-supplied kwargs win over the verbose default; passing
            # verbose both explicitly and via **export_kwargs was a TypeError.
            call_kwargs = {"verbose": True, **export_kwargs}
            exported_path = model.export(
                format="engine",
                device=device,
                half=half,
                imgsz=imgsz,
                batch=batch,
                **call_kwargs,
            )
            if not exported_path or not Path(exported_path).is_file():
                raise RuntimeError(
                    f"Ultralytics did not produce an engine file (got {exported_path!r})"
                )
            logger.info(f"Export complete: {exported_path}")

            metadata = {
                "source_model": str(model_path),
                "model_hash": model_hash,
                "config_hash": config_hash,
                "device": device,
                "half": half,
                "imgsz": imgsz,
                "batch": batch,
                "export_kwargs": export_kwargs,
                "exported_path": str(exported_path),
                "cached_path": str(cache_engine_path),
            }

            # Metadata first, engine last: the engine file's existence is the
            # cache-hit signal, so it must only appear once everything else
            # succeeded and its contents are complete.
            metadata_text = json.dumps(metadata, indent=2, default=str)
            self._atomic_write(
                cache_metadata_path,
                lambda tmp: tmp.write_text(metadata_text, encoding="utf-8"),
            )
            logger.info(f"Saved metadata: {cache_metadata_path}")

            self._atomic_write(
                cache_engine_path,
                lambda tmp: shutil.copy(exported_path, tmp),
            )
            logger.info(f"Cached engine: {cache_engine_path}")

            return config_hash, str(cache_engine_path)

        except Exception as e:
            logger.exception(f"Export failed: {e}")
            raise RuntimeError(f"Failed to export YOLO model: {e}") from e

    def get_cached_engine(
        self,
        model_path: str,
        device: int = 0,
        half: bool = False,
        imgsz: int = 640,
        batch: int = 1,
        **export_kwargs,
    ) -> Optional[str]:
        """
        Get cached engine path if it exists.

        Uses the same cache key and defaults as ``export``, so a call with
        the same arguments finds the engine ``export`` would return.

        Args:
            model_path: Path to .pt model
            device: GPU device ID (same default as ``export``)
            half: Use FP16 precision (same default as ``export``)
            imgsz: Input image size (same default as ``export``)
            batch: Maximum batch size (same default as ``export``)
            **export_kwargs: Additional export parameters (must match cached export)

        Returns:
            Path to cached engine or None if not cached (or the lookup failed)
        """
        try:
            model_path = Path(model_path).resolve()

            if not model_path.is_file():
                return None

            model_hash = self.compute_file_hash(str(model_path))
            config_hash = self._config_hash(
                model_hash, device, half, imgsz, batch, export_kwargs
            )
        except Exception as e:
            # Deliberate best-effort probe: a failed lookup is a cache miss.
            logger.warning(f"Failed to check cache: {e}")
            return None

        cache_engine_path = self.cache_dir / f"{config_hash}.engine"
        return str(cache_engine_path) if cache_engine_path.exists() else None

    def clear_cache(self):
        """
        Clear all cached engines.

        Removes the entire cache directory (including any other files placed
        in it) and recreates it empty.
        """
        if self.cache_dir.exists():
            shutil.rmtree(self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Cache cleared")
|