new buffer paradigm

Siwat Sirichai 2025-11-11 02:02:12 +07:00
parent fdaeb9981c
commit a519dea130
6 changed files with 341 additions and 327 deletions


@@ -9,6 +9,7 @@ Provides a unified interface for different inference backends:
All engines support zero-copy GPU tensor inference where possible.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
@@ -17,6 +18,8 @@ from typing import Any, Dict, List, Optional, Tuple
import torch

logger = logging.getLogger(__name__)


class BackendType(Enum):
    """Supported inference backend types"""
@@ -423,9 +426,18 @@ class UltralyticsEngine(IInferenceEngine):
            final_model_path = engine_path
            print(f"Using TensorRT engine: {engine_path}")
            # CRITICAL: Update _model_path to point to the .engine file for metadata extraction
            self._model_path = engine_path

        # Load model (Ultralytics handles .engine files natively)
        self._model = YOLO(final_model_path)
        logger.info(f"Loaded Ultralytics model: {type(self._model)}")
        if hasattr(self._model, "predictor"):
            logger.info(
                f"Model has predictor: {type(self._model.predictor) if self._model.predictor else None}"
            )

        # Move to device if needed (only for .pt models, .engine already on specific device)
        if hasattr(self._model, "model") and self._model.model is not None:
            # Check if it's actually a torch model (not a string path for .engine files)
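The device-move guard the comments above describe, condensed into a standalone helper (hypothetical name and signature, for illustration only): for .pt weights Ultralytics exposes a real torch.nn.Module here, while for .engine files the attribute can be a plain string path that must not be moved.

import torch

def maybe_move_to_device(model_attr, device: str = "cuda:0") -> None:
    # Only genuine torch modules get .to(device); string engine paths are skipped.
    if isinstance(model_attr, torch.nn.Module):
        model_attr.to(device)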
@@ -437,6 +449,39 @@ class UltralyticsEngine(IInferenceEngine):
        return self._metadata

    def _read_batch_size_from_engine_file(self, engine_path: str) -> int:
        """
        Read batch size from the metadata JSON file saved next to the engine.
        Much simpler than parsing the TensorRT engine itself.
        """
        try:
            import json
            from pathlib import Path

            # The metadata file is named: <engine_path_without_extension>_metadata.json
            engine_file = Path(engine_path)
            metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json")
            print(f"[UltralyticsEngine] Looking for metadata file: {metadata_file}")

            if metadata_file.exists():
                with open(metadata_file, "r") as f:
                    metadata = json.load(f)

                batch_size = metadata.get("batch", -1)
                print(
                    f"[UltralyticsEngine] Found metadata: batch={batch_size}, imgsz={metadata.get('imgsz')}"
                )
                return batch_size
            else:
                print(f"[UltralyticsEngine] Metadata file not found: {metadata_file}")
        except Exception as e:
            print(
                f"[UltralyticsEngine] Could not read batch size from metadata file: {e}"
            )

        return -1  # Default to dynamic
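To illustrate the naming convention: for an engine exported as yolov8n.engine (hypothetical filename), the method above looks for yolov8n_metadata.json in the same directory. Given illustrative sidecar contents like {"batch": 4, "imgsz": [640, 640]}, it returns 4:

import json
from pathlib import Path

engine_file = Path("yolov8n.engine")  # hypothetical engine path
metadata_file = engine_file.with_name(f"{engine_file.stem}_metadata.json")
print(metadata_file)  # yolov8n_metadata.json
print(json.loads(metadata_file.read_text()).get("batch", -1))  # -> 4 for the example contents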
    def _extract_metadata(self) -> EngineMetadata:
        """Extract metadata from Ultralytics model"""
        # Ultralytics models typically expect (B, 3, H, W) input
@@ -447,6 +492,17 @@ class UltralyticsEngine(IInferenceEngine):
        imgsz = 640
        input_shape = (batch_size, 3, imgsz, imgsz)

        # CRITICAL: for .engine files, read the batch size from the metadata JSON exported alongside the engine
        print(f"[UltralyticsEngine] _model_path={self._model_path}")
        if self._model_path.endswith(".engine"):
            print("[UltralyticsEngine] Reading batch size from engine metadata...")
            batch_size = self._read_batch_size_from_engine_file(self._model_path)
            print(f"[UltralyticsEngine] Read batch_size={batch_size} for .engine file")
            if batch_size > 0:
                input_shape = (batch_size, 3, imgsz, imgsz)
        else:
            print("[UltralyticsEngine] Not an .engine file, skipping direct read")

        if hasattr(self._model, "model") and self._model.model is not None:
            # Try to get actual input shape from model
            try:
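The branch above, condensed into a pure function (a simplified, hypothetical helper; the real code also keeps whatever batch_size was previously set when the metadata read fails) to make the two outcomes explicit: a positive batch from the engine's sidecar metadata pins a static input shape, anything else stays dynamic at -1.

def resolve_input_shape(model_path: str, meta_batch: int, imgsz: int = 640):
    # Positive batch from the sidecar metadata -> static shape; otherwise dynamic.
    batch = meta_batch if model_path.endswith(".engine") and meta_batch > 0 else -1
    return (batch, 3, imgsz, imgsz)

assert resolve_input_shape("yolov8n.engine", 4) == (4, 3, 640, 640)
assert resolve_input_shape("yolov8n.pt", 4) == (-1, 3, 640, 640)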
@@ -508,6 +564,10 @@ class UltralyticsEngine(IInferenceEngine):
            logger.warning(f"Could not extract full metadata: {e}")
            pass

        logger.info(
            f"Extracted Ultralytics metadata: batch_size={batch_size}, imgsz={imgsz}, input_shape={input_shape}"
        )

        return EngineMetadata(
            engine_type="ultralytics",
            model_path=self._model_path,
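The diff truncates the EngineMetadata constructor; the real dataclass is defined earlier in this module with more fields than shown here. As a rough stand-in only (every field beyond the two visible keyword arguments is an assumption):

from dataclasses import dataclass
from typing import Tuple

@dataclass
class EngineMetadata:
    # Minimal stand-in: only the two visible fields plus an assumed input_shape.
    engine_type: str
    model_path: str
    input_shape: Tuple[int, int, int, int] = (-1, 3, 640, 640)

meta = EngineMetadata(engine_type="ultralytics", model_path="yolov8n.engine")
print(meta.input_shape)  # (-1, 3, 640, 640)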