ultralytics export
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions
222  services/ultralytics_exporter.py  Normal file
@@ -0,0 +1,222 @@
"""
|
||||
Ultralytics YOLO Model Exporter with Caching
|
||||
|
||||
Exports YOLO .pt models to TensorRT .engine format using Ultralytics library.
|
||||
Provides proper NMS and postprocessing built into the engine.
|
||||
Caches exported engines to avoid redundant exports.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UltralyticsExporter:
    """
    Export YOLO models using Ultralytics with caching.

    Features:
    - Exports .pt models to TensorRT .engine format
    - Caches exported engines by source file hash
    - Saves metadata about exported models
    - Reuses cached engines when available
    """

    def __init__(self, cache_dir: str = ".ultralytics_cache"):
        """
        Initialize exporter.

        Args:
            cache_dir: Directory for caching exported engines
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Ultralytics exporter cache directory: {self.cache_dir}")

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """
        Compute SHA256 hash of a file.

        Args:
            file_path: Path to file

        Returns:
            Hexadecimal hash string
        """
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            for byte_block in iter(lambda: f.read(65536), b""):
                sha256_hash.update(byte_block)
        return sha256_hash.hexdigest()

    def export(
        self,
        model_path: str,
        device: int = 0,
        half: bool = False,
        imgsz: int = 640,
        batch: int = 1,
        **export_kwargs,
    ) -> Tuple[str, str]:
        """
        Export YOLO model to TensorRT engine with caching.

        Args:
            model_path: Path to .pt model file
            device: GPU device ID
            half: Use FP16 precision
            imgsz: Input image size (default: 640)
            batch: Maximum batch size for inference
            **export_kwargs: Additional arguments for Ultralytics export

        Returns:
            Tuple of (config_hash, engine_path)

        Raises:
            FileNotFoundError: If model file doesn't exist
            RuntimeError: If export fails
        """
        model_path = Path(model_path).resolve()

        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Compute hash of source model
        logger.info(f"Computing hash for {model_path}...")
        model_hash = self.compute_file_hash(str(model_path))
        logger.info(f"Model hash: {model_hash[:16]}...")

        # Create export config hash (includes export parameters)
        export_config = {
            "model_hash": model_hash,
            "device": device,
            "half": half,
            "imgsz": imgsz,
            "batch": batch,
            **export_kwargs,
        }
        config_str = json.dumps(export_config, sort_keys=True)
        config_hash = hashlib.sha256(config_str.encode()).hexdigest()

        # Check cache
        cache_engine_path = self.cache_dir / f"{config_hash}.engine"
        cache_metadata_path = self.cache_dir / f"{config_hash}_metadata.json"

        if cache_engine_path.exists():
            logger.info(f"Found cached engine: {cache_engine_path}")
            logger.info(f"Reusing cached export (config hash: {config_hash[:16]}...)")

            # Load and return metadata
            if cache_metadata_path.exists():
                with open(cache_metadata_path, "r") as f:
                    metadata = json.load(f)
                logger.info(f"Cached engine metadata: {metadata}")

            return config_hash, str(cache_engine_path)

        # Export using Ultralytics
        logger.info("Exporting YOLO model to TensorRT engine...")
        logger.info(f"  Source: {model_path}")
        logger.info(f"  Device: GPU {device}")
        logger.info(f"  Precision: {'FP16' if half else 'FP32'}")
        logger.info(f"  Image size: {imgsz}")
        logger.info(f"  Batch size: {batch}")

        try:
            from ultralytics import YOLO

            # Load model
            model = YOLO(str(model_path))

            # Export to TensorRT
            exported_path = model.export(
                format="engine",
                device=device,
                half=half,
                imgsz=imgsz,
                batch=batch,
                verbose=True,
                **export_kwargs,
            )

            logger.info(f"Export complete: {exported_path}")

            # Copy to cache
            import shutil

            shutil.copy(exported_path, cache_engine_path)
            logger.info(f"Cached engine: {cache_engine_path}")

            # Save metadata
            metadata = {
                "source_model": str(model_path),
                "model_hash": model_hash,
                "config_hash": config_hash,
                "device": device,
                "half": half,
                "imgsz": imgsz,
                "batch": batch,
                "export_kwargs": export_kwargs,
                "exported_path": str(exported_path),
                "cached_path": str(cache_engine_path),
            }

            with open(cache_metadata_path, "w") as f:
                json.dump(metadata, f, indent=2)

            logger.info(f"Saved metadata: {cache_metadata_path}")

            return config_hash, str(cache_engine_path)

        except Exception as e:
            logger.error(f"Export failed: {e}")
            raise RuntimeError(f"Failed to export YOLO model: {e}") from e

    def get_cached_engine(self, model_path: str, **export_kwargs) -> Optional[str]:
        """
        Get cached engine path if it exists.

        Args:
            model_path: Path to .pt model
            **export_kwargs: Export parameters (must match cached export)

        Returns:
            Path to cached engine or None if not cached
        """
        try:
            model_path = Path(model_path).resolve()

            if not model_path.exists():
                return None

            # Compute hashes
            model_hash = self.compute_file_hash(str(model_path))

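            # NOTE: the cache key here is rebuilt from **export_kwargs alone, so the
            # caller must pass the same parameters that export() hashed (device,
            # half, imgsz, batch, and any extras) for the config hash to match.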
            export_config = {"model_hash": model_hash, **export_kwargs}
            config_str = json.dumps(export_config, sort_keys=True)
            config_hash = hashlib.sha256(config_str.encode()).hexdigest()

            cache_engine_path = self.cache_dir / f"{config_hash}.engine"

            if cache_engine_path.exists():
                return str(cache_engine_path)

            return None

        except Exception as e:
            logger.warning(f"Failed to check cache: {e}")
            return None

    def clear_cache(self):
        """Clear all cached engines"""
        import shutil

        if self.cache_dir.exists():
            shutil.rmtree(self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Cache cleared")
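
A minimal usage sketch of the exporter added in this commit (the import path follows the new file's location; the checkpoint path and parameter values below are hypothetical):

    # Hypothetical caller-side sketch, not part of this commit.
    from services.ultralytics_exporter import UltralyticsExporter

    exporter = UltralyticsExporter(cache_dir=".ultralytics_cache")

    # First call exports the model via Ultralytics and caches the .engine file;
    # repeat calls with the same model file and parameters reuse the cached engine.
    config_hash, engine_path = exporter.export(
        "models/yolo_detector.pt",  # hypothetical checkpoint path
        device=0,
        half=True,
        imgsz=640,
        batch=1,
    )
    print(f"TensorRT engine at {engine_path} (config hash {config_hash[:16]})")

    # Check for a cached engine without triggering an export; parameters must be
    # passed exactly as export() hashed them for the lookup to hit.
    cached = exporter.get_cached_engine(
        "models/yolo_detector.pt", device=0, half=True, imgsz=640, batch=1
    )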