new buffer paradigm
This commit is contained in:
parent
fdaeb9981c
commit
a519dea130
6 changed files with 341 additions and 327 deletions
|
|
@ -9,6 +9,7 @@ Provides a unified interface for different inference backends:
|
|||
All engines support zero-copy GPU tensor inference where possible.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
|
@ -17,6 +18,8 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BackendType(Enum):
|
||||
"""Supported inference backend types"""
|
||||
|
|
@ -423,9 +426,18 @@ class UltralyticsEngine(IInferenceEngine):
|
|||
final_model_path = engine_path
|
||||
print(f"Using TensorRT engine: {engine_path}")
|
||||
|
||||
# CRITICAL: Update _model_path to point to the .engine file for metadata extraction
|
||||
self._model_path = engine_path
|
||||
|
||||
# Load model (Ultralytics handles .engine files natively)
|
||||
self._model = YOLO(final_model_path)
|
||||
|
||||
logger.info(f"Loaded Ultralytics model: {type(self._model)}")
|
||||
if hasattr(self._model, "predictor"):
|
||||
logger.info(
|
||||
f"Model has predictor: {type(self._model.predictor) if self._model.predictor else None}"
|
||||
)
|
||||
|
||||
# Move to device if needed (only for .pt models, .engine already on specific device)
|
||||
if hasattr(self._model, "model") and self._model.model is not None:
|
||||
# Check if it's actually a torch model (not a string path for .engine files)
|
||||
|
|
@ -437,6 +449,39 @@ class UltralyticsEngine(IInferenceEngine):
|
|||
|
||||
return self._metadata
|
||||
|
||||
def _read_batch_size_from_engine_file(self, engine_path: str) -> int:
|
||||
"""
|
||||
Read batch size from the metadata JSON file saved next to the engine.
|
||||
|
||||
Much simpler than parsing TensorRT engine!
|
||||
"""
|
||||
try:
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# The metadata file is named: <engine_path_without_extension>_metadata.json
|
||||
engine_file = Path(engine_path)
|
||||
metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json")
|
||||
|
||||
print(f"[UltralyticsEngine] Looking for metadata file: {metadata_file}")
|
||||
|
||||
if metadata_file.exists():
|
||||
with open(metadata_file, "r") as f:
|
||||
metadata = json.load(f)
|
||||
batch_size = metadata.get("batch", -1)
|
||||
print(
|
||||
f"[UltralyticsEngine] Found metadata: batch={batch_size}, imgsz={metadata.get('imgsz')}"
|
||||
)
|
||||
return batch_size
|
||||
else:
|
||||
print(f"[UltralyticsEngine] Metadata file not found: {metadata_file}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[UltralyticsEngine] Could not read batch size from metadata file: {e}"
|
||||
)
|
||||
|
||||
return -1 # Default to dynamic
|
||||
|
||||
def _extract_metadata(self) -> EngineMetadata:
|
||||
"""Extract metadata from Ultralytics model"""
|
||||
# Ultralytics models typically expect (B, 3, H, W) input
|
||||
|
|
@ -447,6 +492,17 @@ class UltralyticsEngine(IInferenceEngine):
|
|||
imgsz = 640
|
||||
input_shape = (batch_size, 3, imgsz, imgsz)
|
||||
|
||||
# CRITICAL: For .engine files, read batch size directly from the TensorRT engine file
|
||||
print(f"[UltralyticsEngine] _model_path={self._model_path}")
|
||||
if self._model_path.endswith(".engine"):
|
||||
print(f"[UltralyticsEngine] Reading batch size from engine file...")
|
||||
batch_size = self._read_batch_size_from_engine_file(self._model_path)
|
||||
print(f"[UltralyticsEngine] Read batch_size={batch_size} from .engine file")
|
||||
if batch_size > 0:
|
||||
input_shape = (batch_size, 3, imgsz, imgsz)
|
||||
else:
|
||||
print(f"[UltralyticsEngine] Not an .engine file, skipping direct read")
|
||||
|
||||
if hasattr(self._model, "model") and self._model.model is not None:
|
||||
# Try to get actual input shape from model
|
||||
try:
|
||||
|
|
@ -508,6 +564,10 @@ class UltralyticsEngine(IInferenceEngine):
|
|||
logger.warning(f"Could not extract full metadata: {e}")
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"Extracted Ultralytics metadata: batch_size={batch_size}, imgsz={imgsz}, input_shape={input_shape}"
|
||||
)
|
||||
|
||||
return EngineMetadata(
|
||||
engine_type="ultralytics",
|
||||
model_path=self._model_path,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue