ultralytics export
This commit is contained in:
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions

services/inference_engine.py (new file, 635 lines)

@@ -0,0 +1,635 @@
"""
Inference Engine Abstraction Layer

Provides a unified interface for different inference backends:
- Native TensorRT: Direct TensorRT API with zero-copy GPU tensors
- Ultralytics: YOLO models with built-in pre/postprocessing
- Future: ONNX Runtime, OpenVINO, etc.

All engines support zero-copy GPU tensor inference where possible.
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch

logger = logging.getLogger(__name__)


class BackendType(Enum):
    """Supported inference backend types"""

    TENSORRT = "tensorrt"
    ULTRALYTICS = "ultralytics"

    @classmethod
    def from_string(cls, backend: str) -> "BackendType":
        """Convert a string to a BackendType"""
        backend = backend.lower()
        for member in cls:
            if member.value == backend:
                return member
        raise ValueError(
            f"Unknown backend: {backend}. Available: {[m.value for m in cls]}"
        )
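
# Example: BackendType.from_string("TensorRT") returns BackendType.TENSORRT
# (matching is case-insensitive); unrecognized names raise ValueError.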


@dataclass
class EngineMetadata:
    """Metadata for an inference engine"""

    engine_type: str  # "tensorrt", "ultralytics", etc.
    model_path: str
    input_shapes: Dict[str, Tuple[int, ...]]
    output_shapes: Dict[str, Tuple[int, ...]]
    input_names: List[str]
    output_names: List[str]
    input_dtypes: Dict[str, torch.dtype]
    output_dtypes: Dict[str, torch.dtype]
    supports_batching: bool = True
    supports_dynamic_shapes: bool = False
    extra_info: Optional[Dict[str, Any]] = None  # Backend-specific info


class IInferenceEngine(ABC):
    """
    Abstract interface for inference engines.

    All implementations must support zero-copy GPU tensor inference:
    - Inputs: CUDA tensors on GPU
    - Outputs: CUDA tensors on GPU
    - No CPU transfers during inference
    """

    @abstractmethod
    def initialize(
        self, model_path: str, device: torch.device, **kwargs
    ) -> EngineMetadata:
        """
        Initialize the inference engine.

        Automatically detects the model type and handles conversion if needed.

        Args:
            model_path: Path to model file (.pt, .engine, .trt)
            device: GPU device to use
            **kwargs: Optional parameters (batch_size, half, workspace, etc.)

        Returns:
            EngineMetadata with model information
        """
        pass

    @abstractmethod
    def infer(
        self, inputs: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Run inference on GPU tensors (zero-copy).

        Args:
            inputs: Dict of input_name -> CUDA tensor
            **kwargs: Backend-specific inference parameters

        Returns:
            Dict of output_name -> CUDA tensor

        Raises:
            ValueError: If inputs are not CUDA tensors or have the wrong shape
        """
        pass

    @abstractmethod
    def get_metadata(self) -> EngineMetadata:
        """Get engine metadata"""
        pass

    @abstractmethod
    def cleanup(self):
        """Clean up resources"""
        pass

    @property
    @abstractmethod
    def is_initialized(self) -> bool:
        """Check whether the engine is initialized"""
        pass

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """Get the device the engine is running on"""
        pass
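

# A minimal sketch of the calling convention the interface above implies.
# Illustrative only: the "images" input name and the 1x3x640x640 shape are
# assumptions for this example, not requirements of the interface.
def _example_engine_call(engine: IInferenceEngine) -> Dict[str, torch.Tensor]:
    """Drive any initialized engine with a GPU-resident tensor."""
    # Allocating directly on the engine's device keeps the zero-copy
    # contract: no CPU -> GPU transfer happens inside infer().
    dummy = torch.rand(1, 3, 640, 640, device=engine.device)
    return engine.infer({"images": dummy})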


class NativeTensorRTEngine(IInferenceEngine):
    """
    Native TensorRT inference engine with direct API access.

    Features:
    - Zero-copy GPU tensor inference
    - Execution context pooling for concurrent inference
    - Support for .trt and .engine files
    - Automatic stripping of Ultralytics .engine metadata
    """

    def __init__(self):
        self._engine = None
        self._contexts = []
        self._metadata = None
        self._device = None
        self._trt_logger = None

    def initialize(
        self, model_path: str, device: torch.device, num_contexts: int = 1, **kwargs
    ) -> EngineMetadata:
        """
        Initialize the TensorRT engine.

        Args:
            model_path: Path to .trt or .engine file
            device: GPU device
            num_contexts: Number of execution contexts for pooling

        Returns:
            EngineMetadata
        """
        import tensorrt as trt

        self._device = device
        self._trt_logger = trt.Logger(trt.Logger.WARNING)

        # Load the engine
        runtime = trt.Runtime(self._trt_logger)

        # Read the engine file (handles the Ultralytics format)
        engine_data = self._load_engine_data(model_path)

        self._engine = runtime.deserialize_cuda_engine(engine_data)
        if self._engine is None:
            raise RuntimeError(f"Failed to load TensorRT engine from {model_path}")

        # Create execution contexts
        for i in range(num_contexts):
            ctx = self._engine.create_execution_context()
            if ctx is None:
                raise RuntimeError(f"Failed to create execution context {i}")
            self._contexts.append(ctx)

        # Extract metadata
        self._metadata = self._extract_metadata(model_path)

        return self._metadata

    def _load_engine_data(self, file_path: str) -> bytes:
        """Load engine data, stripping Ultralytics metadata if present.

        Ultralytics-exported .engine files prepend a 4-byte little-endian
        length field and a JSON metadata blob; the serialized engine
        begins right after that blob.
        """
        import json

        with open(file_path, "rb") as f:
            # Try to read an Ultralytics metadata header
            meta_len_bytes = f.read(4)
            if len(meta_len_bytes) == 4:
                meta_len = int.from_bytes(meta_len_bytes, byteorder="little")

                # Sanity check
                if 0 < meta_len < 100000:
                    try:
                        metadata_bytes = f.read(meta_len)
                        json.loads(metadata_bytes.decode("utf-8"))
                        # Valid Ultralytics metadata; the rest is the engine
                        return f.read()
                    except (UnicodeDecodeError, json.JSONDecodeError):
                        pass

            # Not the Ultralytics format, read the entire file
            f.seek(0)
            return f.read()

    def _extract_metadata(self, model_path: str) -> EngineMetadata:
        """Extract metadata from the TensorRT engine"""
        import tensorrt as trt

        input_shapes = {}
        output_shapes = {}
        input_names = []
        output_names = []
        input_dtypes = {}
        output_dtypes = {}

        trt_to_torch_dtype = {
            trt.DataType.FLOAT: torch.float32,
            trt.DataType.HALF: torch.float16,
            trt.DataType.INT8: torch.int8,
            trt.DataType.INT32: torch.int32,
            trt.DataType.BOOL: torch.bool,
        }

        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            shape = tuple(self._engine.get_tensor_shape(name))
            dtype = trt_to_torch_dtype.get(
                self._engine.get_tensor_dtype(name), torch.float32
            )
            mode = self._engine.get_tensor_mode(name)

            if mode == trt.TensorIOMode.INPUT:
                input_names.append(name)
                input_shapes[name] = shape
                input_dtypes[name] = dtype
            else:
                output_names.append(name)
                output_shapes[name] = shape
                output_dtypes[name] = dtype

        return EngineMetadata(
            engine_type="tensorrt",
            model_path=model_path,
            input_shapes=input_shapes,
            output_shapes=output_shapes,
            input_names=input_names,
            output_names=output_names,
            input_dtypes=input_dtypes,
            output_dtypes=output_dtypes,
            supports_batching=True,
            supports_dynamic_shapes=False,
        )

    def infer(
        self,
        inputs: Dict[str, torch.Tensor],
        context_id: int = 0,
        stream: Optional[torch.cuda.Stream] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Run TensorRT inference with zero-copy GPU tensors.

        Args:
            inputs: Dict of input_name -> CUDA tensor
            context_id: Which execution context to use
            stream: CUDA stream for async execution

        Returns:
            Dict of output_name -> CUDA tensor
        """
        if not self.is_initialized:
            raise RuntimeError("Engine not initialized")

        # Validate inputs
        for name in self._metadata.input_names:
            if name not in inputs:
                raise ValueError(f"Missing required input: {name}")
            if not inputs[name].is_cuda:
                raise ValueError(f"Input '{name}' must be a CUDA tensor")

        # Get the execution context
        if context_id >= len(self._contexts):
            raise ValueError(
                f"Invalid context_id {context_id}, only {len(self._contexts)} contexts available"
            )

        context = self._contexts[context_id]

        # Prepare outputs
        outputs = {}

        # Set input tensor addresses
        for name in self._metadata.input_names:
            input_tensor = inputs[name].contiguous()
            context.set_tensor_address(name, input_tensor.data_ptr())

        # Allocate and set output tensors
        for name in self._metadata.output_names:
            output_tensor = torch.empty(
                self._metadata.output_shapes[name],
                dtype=self._metadata.output_dtypes[name],
                device=self._device,
            )
            outputs[name] = output_tensor
            context.set_tensor_address(name, output_tensor.data_ptr())

        # Execute
        if stream is None:
            stream = torch.cuda.Stream(device=self._device)

        with torch.cuda.stream(stream):
            success = context.execute_async_v3(stream_handle=stream.cuda_stream)
            if not success:
                raise RuntimeError("TensorRT inference failed")

        stream.synchronize()

        return outputs

    def get_metadata(self) -> EngineMetadata:
        """Get engine metadata"""
        if self._metadata is None:
            raise RuntimeError("Engine not initialized")
        return self._metadata

    def cleanup(self):
        """Clean up TensorRT resources"""
        for ctx in self._contexts:
            del ctx
        self._contexts.clear()

        if self._engine is not None:
            del self._engine
            self._engine = None

        self._metadata = None

    @property
    def is_initialized(self) -> bool:
        return self._engine is not None

    @property
    def device(self) -> torch.device:
        return self._device
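

# Hedged usage sketch for the native backend; "model.engine" is a placeholder
# path. Input names, shapes, and dtypes come from the extracted metadata.
def _example_tensorrt_usage(engine_path: str = "model.engine") -> None:
    """Illustrative only: initialize, run one zero-copy inference, clean up."""
    engine = NativeTensorRTEngine()
    meta = engine.initialize(engine_path, torch.device("cuda:0"), num_contexts=2)
    inputs = {
        name: torch.zeros(
            meta.input_shapes[name],
            dtype=meta.input_dtypes[name],
            device=engine.device,
        )
        for name in meta.input_names
    }
    # context_id picks one context from the pool (e.g. one per worker thread)
    outputs = engine.infer(inputs, context_id=0)
    print({name: tuple(t.shape) for name, t in outputs.items()})
    engine.cleanup()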


class UltralyticsEngine(IInferenceEngine):
    """
    Ultralytics YOLO inference engine.

    Features:
    - Zero-copy GPU tensor inference
    - Built-in preprocessing/postprocessing for YOLO models
    - Supports .pt and .engine formats
    - Automatic model export to TensorRT with caching
    """

    def __init__(self):
        self._model = None
        self._metadata = None
        self._device = None
        self._model_path = None
        self._exporter = None

    def initialize(
        self,
        model_path: str,
        device: torch.device,
        batch: int = 1,
        half: bool = False,
        imgsz: int = 640,
        cache_dir: str = ".ultralytics_cache",
        **kwargs,
    ) -> EngineMetadata:
        """
        Initialize an Ultralytics YOLO model.

        Automatically exports .pt models to .engine format with caching.

        Args:
            model_path: Path to .pt or .engine file
            device: GPU device
            batch: Maximum batch size for inference
            half: Use FP16 precision
            imgsz: Input image size
            cache_dir: Directory for caching exported engines
            **kwargs: Additional export parameters

        Returns:
            EngineMetadata
        """
        from ultralytics import YOLO

        from .ultralytics_exporter import UltralyticsExporter

        self._device = device
        self._model_path = model_path

        # Check whether we need to export
        model_file = Path(model_path)
        final_model_path = model_path

        if model_file.suffix == ".pt":
            # Use the exporter with caching
            print("Checking for cached TensorRT engine...")
            self._exporter = UltralyticsExporter(cache_dir=cache_dir)

            _, engine_path = self._exporter.export(
                model_path=str(model_path),
                device=device.index if device.type == "cuda" else 0,
                half=half,
                imgsz=imgsz,
                batch=batch,
                **kwargs,
            )

            final_model_path = engine_path
            print(f"Using TensorRT engine: {engine_path}")

        # Load the model (Ultralytics handles .engine files natively)
        self._model = YOLO(final_model_path)

        # Move to device if needed (.pt models only; an .engine is already
        # built for a specific device)
        if hasattr(self._model, "model") and self._model.model is not None:
            # Check that it is actually a torch model (not a string path, as for .engine files)
            if hasattr(self._model.model, "to"):
                self._model.model = self._model.model.to(device)

        # Extract metadata
        self._metadata = self._extract_metadata()

        return self._metadata

    def _extract_metadata(self) -> EngineMetadata:
        """Extract metadata from the Ultralytics model"""
        # Ultralytics models typically expect (B, 3, H, W) input
        # and return Results objects, not raw tensors

        # Default values
        batch_size = -1  # Dynamic batching by default
        imgsz = 640
        input_shape = (batch_size, 3, imgsz, imgsz)

        if hasattr(self._model, "model") and self._model.model is not None:
            # Try to get the actual input shape from the model
            try:
                # For .engine files, check the predictor model
                if (
                    hasattr(self._model, "predictor")
                    and self._model.predictor is not None
                ):
                    predictor = self._model.predictor

                    # Get image size
                    if hasattr(predictor, "args") and hasattr(predictor.args, "imgsz"):
                        imgsz_val = predictor.args.imgsz
                        if isinstance(imgsz_val, (list, tuple)):
                            h, w = (
                                imgsz_val[0],
                                imgsz_val[1] if len(imgsz_val) > 1 else imgsz_val[0],
                            )
                        else:
                            h = w = imgsz_val
                        imgsz = h  # Use height as reference

                    # Get batch size from the model
                    if hasattr(predictor, "model"):
                        pred_model = predictor.model

                        # For TensorRT engines, check input bindings
                        if hasattr(pred_model, "bindings"):
                            # This is a TensorRT AutoBackend
                            try:
                                # Get the first input binding shape
                                if hasattr(pred_model, "input_shape"):
                                    shape = pred_model.input_shape
                                    if shape and len(shape) >= 4:
                                        batch_size = shape[0] if shape[0] > 0 else -1
                            except Exception:
                                pass

                        # Try the batch attribute
                        if batch_size == -1 and hasattr(pred_model, "batch"):
                            batch_size = (
                                pred_model.batch if pred_model.batch > 0 else -1
                            )

                # Fallback: check model args
                if hasattr(self._model.model, "args"):
                    imgsz_val = getattr(self._model.model.args, "imgsz", 640)
                    if isinstance(imgsz_val, (list, tuple)):
                        h, w = (
                            imgsz_val[0],
                            imgsz_val[1] if len(imgsz_val) > 1 else imgsz_val[0],
                        )
                    else:
                        h = w = imgsz_val
                    imgsz = h

                input_shape = (batch_size, 3, imgsz, imgsz)
            except Exception as e:
                logger.warning(f"Could not extract full metadata: {e}")

        return EngineMetadata(
            engine_type="ultralytics",
            model_path=self._model_path,
            input_shapes={"images": input_shape},
            output_shapes={"results": (-1,)},  # Dynamic, depends on detections
            input_names=["images"],
            output_names=["results"],
            input_dtypes={"images": torch.float32},
            output_dtypes={"results": torch.float32},
            supports_batching=True,
            supports_dynamic_shapes=(batch_size == -1),
            extra_info={
                "is_yolo": True,
                "has_builtin_postprocess": True,
                "batch_size": batch_size,
                "imgsz": imgsz,
            },
        )

    def infer(
        self,
        inputs: Dict[str, torch.Tensor],
        return_raw: bool = False,
        conf: float = 0.25,
        iou: float = 0.45,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Run Ultralytics inference with zero-copy GPU tensors.

        Args:
            inputs: Dict with "images" key -> CUDA tensor (B, 3, H, W), normalized [0, 1]
            return_raw: Reserved for future use; the returned dict currently
                always contains both the Results objects and the raw box tensors
            conf: Confidence threshold
            iou: IoU threshold for NMS

        Returns:
            Dict with inference results

        Note:
            The input tensor should be normalized to the [0, 1] range.
            Format: (B, 3, H, W) in RGB color space.
        """
        if not self.is_initialized:
            raise RuntimeError("Engine not initialized")

        # Get the input tensor
        if "images" not in inputs:
            raise ValueError("Input must contain 'images' key")

        images = inputs["images"]

        if not images.is_cuda:
            raise ValueError("Input must be a CUDA tensor")

        # Ensure the tensor is on the correct device
        if images.device != self._device:
            images = images.to(self._device)

        # Run inference
        results = self._model(images, conf=conf, iou=iou, verbose=False, **kwargs)

        # Return results
        # Note: Ultralytics returns Results objects, not raw tensors.
        # For compatibility, we wrap them in a dict.
        return {
            "results": results,
            "raw_predictions": results[0].boxes.data
            if len(results) > 0 and hasattr(results[0], "boxes")
            else None,
        }

    def get_metadata(self) -> EngineMetadata:
        """Get engine metadata"""
        if self._metadata is None:
            raise RuntimeError("Engine not initialized")
        return self._metadata

    def cleanup(self):
        """Clean up the Ultralytics model"""
        if self._model is not None:
            del self._model
            self._model = None
        self._metadata = None

    @property
    def is_initialized(self) -> bool:
        return self._model is not None

    @property
    def device(self) -> torch.device:
        return self._device
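

# Hedged usage sketch for the Ultralytics backend; "yolo11n.pt" is a
# placeholder model path. Note the input contract the engine documents:
# RGB, (B, 3, H, W), normalized to [0, 1].
def _example_ultralytics_usage(model_path: str = "yolo11n.pt") -> None:
    """Illustrative only: a .pt file is exported to TensorRT on first use."""
    engine = UltralyticsEngine()
    meta = engine.initialize(model_path, torch.device("cuda:0"), half=True)
    size = meta.extra_info["imgsz"]
    frame = torch.rand(1, 3, size, size, device=engine.device)  # already in [0, 1]
    out = engine.infer({"images": frame}, conf=0.25, iou=0.45)
    if out["raw_predictions"] is not None:
        print(f"{out['raw_predictions'].shape[0]} detections")
    engine.cleanup()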


def create_engine(backend: str | BackendType, **kwargs) -> IInferenceEngine:
    """
    Factory function to create an inference engine.

    Args:
        backend: Backend type (BackendType enum or string: "tensorrt", "ultralytics")
        **kwargs: Engine-specific arguments

    Returns:
        IInferenceEngine instance

    Example:
        >>> from services import create_engine, BackendType
        >>> engine = create_engine(BackendType.TENSORRT)
        >>> engine = create_engine("ultralytics")
    """
    # Convert a string to BackendType if needed
    if isinstance(backend, str):
        backend = BackendType.from_string(backend)

    engines = {
        BackendType.TENSORRT: NativeTensorRTEngine,
        BackendType.ULTRALYTICS: UltralyticsEngine,
    }

    if backend not in engines:
        raise ValueError(
            f"Unknown backend: {backend}. Available: {[b.value for b in BackendType]}"
        )

    return engines[backend]()
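

# Hedged end-to-end sketch tying the factory to both backends; guarded so
# importing this module stays side-effect free.
if __name__ == "__main__":
    if torch.cuda.is_available():
        eng = create_engine("ultralytics")  # or create_engine(BackendType.TENSORRT)
        print(f"{type(eng).__name__} created; call initialize() with a real model path")
    else:
        print("CUDA not available; these inference engines require a GPU")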