Ultralytics export

This commit is contained in:
Siwat Sirichai 2025-11-11 01:28:19 +07:00
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions


@ -2,38 +2,67 @@
Services package for RTSP stream processing with GPU acceleration.
"""
from .stream_decoder import StreamDecoderFactory, StreamDecoder, ConnectionStatus
from .base_model_controller import BaseModelController, BatchFrame, BufferState
from .inference_engine import (
BackendType,
EngineMetadata,
IInferenceEngine,
NativeTensorRTEngine,
UltralyticsEngine,
create_engine,
)
from .jpeg_encoder import JPEGEncoderFactory, encode_frame_to_jpeg
from .model_repository import TensorRTModelRepository, ModelMetadata, ExecutionContext, SharedEngine
from .tracking_controller import ObjectTracker, TrackedObject, Detection
from .yolo import YOLOv8Utils, COCO_CLASSES
from .model_controller import ModelController, BatchFrame, BufferState
from .stream_connection_manager import StreamConnectionManager, StreamConnection, TrackingResult
from .model_repository import (
ExecutionContext,
ModelMetadata,
SharedEngine,
TensorRTModelRepository,
)
from .modelstorage import FileModelStorage, IModelStorage
from .pt_converter import PTConverter
from .modelstorage import IModelStorage, FileModelStorage
from .stream_connection_manager import (
StreamConnection,
StreamConnectionManager,
TrackingResult,
)
from .stream_decoder import ConnectionStatus, StreamDecoder, StreamDecoderFactory
from .tensorrt_model_controller import TensorRTModelController
from .tracking_controller import Detection, ObjectTracker, TrackedObject
from .ultralytics_exporter import UltralyticsExporter
from .ultralytics_model_controller import UltralyticsModelController
from .yolo import COCO_CLASSES, YOLOv8Utils
__all__ = [
'StreamDecoderFactory',
'StreamDecoder',
'ConnectionStatus',
'JPEGEncoderFactory',
'encode_frame_to_jpeg',
'TensorRTModelRepository',
'ModelMetadata',
'ExecutionContext',
'SharedEngine',
'ObjectTracker',
'TrackedObject',
'Detection',
'YOLOv8Utils',
'COCO_CLASSES',
'ModelController',
'BatchFrame',
'BufferState',
'StreamConnectionManager',
'StreamConnection',
'TrackingResult',
'PTConverter',
'IModelStorage',
'FileModelStorage',
"StreamDecoderFactory",
"StreamDecoder",
"ConnectionStatus",
"JPEGEncoderFactory",
"encode_frame_to_jpeg",
"TensorRTModelRepository",
"ModelMetadata",
"ExecutionContext",
"SharedEngine",
"ObjectTracker",
"TrackedObject",
"Detection",
"YOLOv8Utils",
"COCO_CLASSES",
"BaseModelController",
"TensorRTModelController",
"UltralyticsModelController",
"BatchFrame",
"BufferState",
"StreamConnectionManager",
"StreamConnection",
"TrackingResult",
"PTConverter",
"IModelStorage",
"FileModelStorage",
"IInferenceEngine",
"NativeTensorRTEngine",
"UltralyticsEngine",
"EngineMetadata",
"BackendType",
"create_engine",
"UltralyticsExporter",
]


@ -0,0 +1,324 @@
"""
Base Model Controller - Abstract base class for batched inference controllers.
Provides a ping-pong buffer architecture with a force-switch timeout mechanism.
Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
"""
import logging
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
import torch
logger = logging.getLogger(__name__)
@dataclass
class BatchFrame:
"""Represents a frame in the batch buffer"""
stream_id: str
frame: torch.Tensor # GPU tensor (3, H, W)
timestamp: float
metadata: Dict = field(default_factory=dict)
class BufferState(Enum):
"""State of a ping-pong buffer"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
class BaseModelController(ABC):
"""
Abstract base class for batched inference with ping-pong buffers.
This controller accumulates frames from multiple streams into batches,
processes them through an inference backend, and routes results back to
stream-specific callbacks.
Features:
- Ping-pong circular buffers (BufferA/BufferB)
- Force-switch timeout to prevent batch starvation
- Event-driven processing with callbacks
- Thread-safe frame submission
Subclasses must implement:
- _run_batch_inference(): Backend-specific inference logic
"""
def __init__(
self,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
self.model_id = model_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn
# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
self.buffer_b: List[BatchFrame] = []
# Buffer states
self.active_buffer = "A"
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE
# Threading coordination
self.buffer_lock = threading.RLock()
self.last_submit_time = time.time()
# Threads
self.timeout_thread: Optional[threading.Thread] = None
self.processor_threads: Dict[str, threading.Thread] = {}
self.running = False
self.stop_event = threading.Event()
# Result callbacks (stream_id -> callback)
self.result_callbacks: Dict[str, Callable] = {}
# Statistics
self.total_frames_processed = 0
self.total_batches_processed = 0
def start(self):
"""Start the controller background threads"""
if self.running:
logger.warning("ModelController already running")
return
self.running = True
self.stop_event.clear()
# Start timeout monitor thread
self.timeout_thread = threading.Thread(
target=self._timeout_monitor, daemon=True
)
self.timeout_thread.start()
# Start processor threads for each buffer
self.processor_threads["A"] = threading.Thread(
target=self._batch_processor, args=("A",), daemon=True
)
self.processor_threads["B"] = threading.Thread(
target=self._batch_processor, args=("B",), daemon=True
)
self.processor_threads["A"].start()
self.processor_threads["B"].start()
logger.info(f"{self.__class__.__name__} started")
def stop(self):
"""Stop the controller and cleanup"""
if not self.running:
return
logger.info(f"Stopping {self.__class__.__name__}...")
self.running = False
self.stop_event.set()
# Wait for threads to finish
if self.timeout_thread and self.timeout_thread.is_alive():
self.timeout_thread.join(timeout=2.0)
for thread in self.processor_threads.values():
if thread and thread.is_alive():
thread.join(timeout=2.0)
# Process any remaining frames
self._process_remaining_buffers()
logger.info(f"{self.__class__.__name__} stopped")
def register_callback(self, stream_id: str, callback: Callable):
"""Register a callback for inference results from a stream"""
self.result_callbacks[stream_id] = callback
logger.debug(f"Registered callback for stream: {stream_id}")
def unregister_callback(self, stream_id: str):
"""Unregister a stream callback"""
self.result_callbacks.pop(stream_id, None)
logger.debug(f"Unregistered callback for stream: {stream_id}")
def submit_frame(
self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
):
"""Submit a frame for batched inference"""
with self.buffer_lock:
batch_frame = BatchFrame(
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
metadata=metadata or {},
)
# Add to active buffer
if self.active_buffer == "A":
self.buffer_a.append(batch_frame)
self.buffer_a_state = BufferState.FILLING
buffer_size = len(self.buffer_a)
else:
self.buffer_b.append(batch_frame)
self.buffer_b_state = BufferState.FILLING
buffer_size = len(self.buffer_b)
self.last_submit_time = time.time()
# Check if we should immediately swap (batch full)
if buffer_size >= self.batch_size:
self._try_swap_buffers()
def _timeout_monitor(self):
"""Monitor force-switch timeout"""
while self.running and not self.stop_event.wait(0.01):
with self.buffer_lock:
time_since_submit = time.time() - self.last_submit_time
if time_since_submit >= self.force_timeout:
active_buffer = (
self.buffer_a if self.active_buffer == "A" else self.buffer_b
)
if len(active_buffer) > 0:
self._try_swap_buffers()
def _try_swap_buffers(self):
"""Attempt to swap ping-pong buffers (called with buffer_lock held)"""
inactive_state = (
self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
)
if inactive_state != BufferState.PROCESSING:
old_active = self.active_buffer
self.active_buffer = "B" if old_active == "A" else "A"
if old_active == "A":
self.buffer_a_state = BufferState.PROCESSING
buffer_size = len(self.buffer_a)
else:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)
logger.debug(
f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
)
def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""
while self.running and not self.stop_event.is_set():
time.sleep(0.001)
with self.buffer_lock:
if buffer_name == "A":
should_process = self.buffer_a_state == BufferState.PROCESSING
else:
should_process = self.buffer_b_state == BufferState.PROCESSING
if should_process:
self._process_buffer(buffer_name)
def _process_buffer(self, buffer_name: str):
"""Process a buffer through inference"""
# Extract buffer to process
with self.buffer_lock:
if buffer_name == "A":
batch = self.buffer_a.copy()
self.buffer_a.clear()
else:
batch = self.buffer_b.copy()
self.buffer_b.clear()
if len(batch) == 0:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
return
# Process batch (outside lock to allow concurrent submissions)
try:
start_time = time.time()
results = self._run_batch_inference(batch)
inference_time = time.time() - start_time
self.total_frames_processed += len(batch)
self.total_batches_processed += 1
logger.debug(
f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
)
# Emit results to callbacks
for batch_frame, result in zip(batch, results):
callback = self.result_callbacks.get(batch_frame.stream_id)
if callback:
try:
callback(result)
except Exception as e:
logger.error(
f"Error in callback for {batch_frame.stream_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
finally:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
@abstractmethod
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run inference on a batch of frames (backend-specific).
Args:
batch: List of BatchFrame objects
Returns:
List of detection results (one per frame)
"""
pass
def _process_remaining_buffers(self):
"""Process any remaining frames in buffers during shutdown"""
if len(self.buffer_a) > 0:
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
self._process_buffer("A")
if len(self.buffer_b) > 0:
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
self._process_buffer("B")
def get_stats(self) -> Dict[str, Any]:
"""Get current buffer statistics"""
return {
"active_buffer": self.active_buffer,
"buffer_a_size": len(self.buffer_a),
"buffer_b_size": len(self.buffer_b),
"buffer_a_state": self.buffer_a_state.value,
"buffer_b_state": self.buffer_b_state.value,
"registered_streams": len(self.result_callbacks),
"total_frames_processed": self.total_frames_processed,
"total_batches_processed": self.total_batches_processed,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
if self.total_batches_processed > 0
else 0
),
}
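As a quick illustration of the contract above, here is a minimal sketch of a concrete controller with a dummy backend. It assumes the package is importable as services, that a CUDA device is available, and the result dict keys are purely illustrative.
# Minimal sketch (not part of this commit): a concrete controller that
# implements _run_batch_inference() with a dummy backend. The result dict
# keys below are illustrative; real controllers define their own schema.
import torch
from services import BaseModelController

class DummyModelController(BaseModelController):
    def _run_batch_inference(self, batch):
        # Must return one result per BatchFrame, in order.
        return [
            {
                "stream_id": bf.stream_id,
                "timestamp": bf.timestamp,
                "detections": torch.zeros((0, 6), device=bf.frame.device),
            }
            for bf in batch
        ]

controller = DummyModelController(model_id="dummy", batch_size=4, force_timeout=0.05)
controller.register_callback("cam1", lambda r: print(r["detections"].shape))
controller.start()
controller.submit_frame("cam1", torch.zeros(3, 640, 640, device="cuda"))  # (3, H, W) GPU tensor
controller.stop()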


@ -0,0 +1,635 @@
"""
Inference Engine Abstraction Layer
Provides a unified interface for different inference backends:
- Native TensorRT: Direct TensorRT API with zero-copy GPU tensors
- Ultralytics: YOLO models with built-in pre/postprocessing
- Future: ONNX Runtime, OpenVINO, etc.
All engines support zero-copy GPU tensor inference where possible.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
logger = logging.getLogger(__name__)
class BackendType(Enum):
"""Supported inference backend types"""
TENSORRT = "tensorrt"
ULTRALYTICS = "ultralytics"
@classmethod
def from_string(cls, backend: str) -> "BackendType":
"""Convert string to BackendType"""
backend = backend.lower()
for member in cls:
if member.value == backend:
return member
raise ValueError(
f"Unknown backend: {backend}. Available: {[m.value for m in cls]}"
)
@dataclass
class EngineMetadata:
"""Metadata for an inference engine"""
engine_type: str # "tensorrt", "ultralytics", etc.
model_path: str
input_shapes: Dict[str, Tuple[int, ...]]
output_shapes: Dict[str, Tuple[int, ...]]
input_names: List[str]
output_names: List[str]
input_dtypes: Dict[str, torch.dtype]
output_dtypes: Dict[str, torch.dtype]
supports_batching: bool = True
supports_dynamic_shapes: bool = False
extra_info: Dict[str, Any] = None # Backend-specific info
class IInferenceEngine(ABC):
"""
Abstract interface for inference engines.
All implementations must support zero-copy GPU tensor inference:
- Inputs: CUDA tensors on GPU
- Outputs: CUDA tensors on GPU
- No CPU transfers during inference
"""
@abstractmethod
def initialize(
self, model_path: str, device: torch.device, **kwargs
) -> EngineMetadata:
"""
Initialize the inference engine.
Automatically detects model type and handles conversion if needed.
Args:
model_path: Path to model file (.pt, .engine, .trt)
device: GPU device to use
**kwargs: Optional parameters (batch_size, half, workspace, etc.)
Returns:
EngineMetadata with model information
"""
pass
@abstractmethod
def infer(
self, inputs: Dict[str, torch.Tensor], **kwargs
) -> Dict[str, torch.Tensor]:
"""
Run inference on GPU tensors (zero-copy).
Args:
inputs: Dict of input_name -> CUDA tensor
**kwargs: Backend-specific inference parameters
Returns:
Dict of output_name -> CUDA tensor
Raises:
ValueError: If inputs are not CUDA tensors or wrong shape
"""
pass
@abstractmethod
def get_metadata(self) -> EngineMetadata:
"""Get engine metadata"""
pass
@abstractmethod
def cleanup(self):
"""Cleanup resources"""
pass
@property
@abstractmethod
def is_initialized(self) -> bool:
"""Check if engine is initialized"""
pass
@property
@abstractmethod
def device(self) -> torch.device:
"""Get device the engine is running on"""
pass
class NativeTensorRTEngine(IInferenceEngine):
"""
Native TensorRT inference engine with direct API access.
Features:
- Zero-copy GPU tensor inference
- Execution context pooling for concurrent inference
- Support for .trt, .engine files
- Automatic Ultralytics .engine metadata stripping
"""
def __init__(self):
self._engine = None
self._contexts = []
self._metadata = None
self._device = None
self._trt_logger = None
def initialize(
self, model_path: str, device: torch.device, num_contexts: int = 1, **kwargs
) -> EngineMetadata:
"""
Initialize TensorRT engine.
Args:
model_path: Path to .trt or .engine file
device: GPU device
num_contexts: Number of execution contexts for pooling
Returns:
EngineMetadata
"""
import tensorrt as trt
self._device = device
self._trt_logger = trt.Logger(trt.Logger.WARNING)
# Load engine
runtime = trt.Runtime(self._trt_logger)
# Read engine file (handle Ultralytics format)
engine_data = self._load_engine_data(model_path)
self._engine = runtime.deserialize_cuda_engine(engine_data)
if self._engine is None:
raise RuntimeError(f"Failed to load TensorRT engine from {model_path}")
# Create execution contexts
for i in range(num_contexts):
ctx = self._engine.create_execution_context()
if ctx is None:
raise RuntimeError(f"Failed to create execution context {i}")
self._contexts.append(ctx)
# Extract metadata
self._metadata = self._extract_metadata(model_path)
return self._metadata
def _load_engine_data(self, file_path: str) -> bytes:
"""Load engine data, stripping Ultralytics metadata if present"""
import json
with open(file_path, "rb") as f:
# Try to read Ultralytics metadata header
meta_len_bytes = f.read(4)
if len(meta_len_bytes) == 4:
meta_len = int.from_bytes(meta_len_bytes, byteorder="little")
# Sanity check
if 0 < meta_len < 100000:
try:
metadata_bytes = f.read(meta_len)
json.loads(metadata_bytes.decode("utf-8"))
# Valid Ultralytics metadata, rest is engine
return f.read()
except (UnicodeDecodeError, json.JSONDecodeError):
pass
# Not Ultralytics format, read entire file
f.seek(0)
return f.read()
def _extract_metadata(self, model_path: str) -> EngineMetadata:
"""Extract metadata from TensorRT engine"""
import tensorrt as trt
input_shapes = {}
output_shapes = {}
input_names = []
output_names = []
input_dtypes = {}
output_dtypes = {}
trt_to_torch_dtype = {
trt.DataType.FLOAT: torch.float32,
trt.DataType.HALF: torch.float16,
trt.DataType.INT8: torch.int8,
trt.DataType.INT32: torch.int32,
trt.DataType.BOOL: torch.bool,
}
for i in range(self._engine.num_io_tensors):
name = self._engine.get_tensor_name(i)
shape = tuple(self._engine.get_tensor_shape(name))
dtype = trt_to_torch_dtype.get(
self._engine.get_tensor_dtype(name), torch.float32
)
mode = self._engine.get_tensor_mode(name)
if mode == trt.TensorIOMode.INPUT:
input_names.append(name)
input_shapes[name] = shape
input_dtypes[name] = dtype
else:
output_names.append(name)
output_shapes[name] = shape
output_dtypes[name] = dtype
return EngineMetadata(
engine_type="tensorrt",
model_path=model_path,
input_shapes=input_shapes,
output_shapes=output_shapes,
input_names=input_names,
output_names=output_names,
input_dtypes=input_dtypes,
output_dtypes=output_dtypes,
supports_batching=True,
supports_dynamic_shapes=False,
)
def infer(
self,
inputs: Dict[str, torch.Tensor],
context_id: int = 0,
stream: Optional[torch.cuda.Stream] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Run TensorRT inference with zero-copy GPU tensors.
Args:
inputs: Dict of input_name -> CUDA tensor
context_id: Which execution context to use
stream: CUDA stream for async execution
Returns:
Dict of output_name -> CUDA tensor
"""
if not self.is_initialized:
raise RuntimeError("Engine not initialized")
# Validate inputs
for name in self._metadata.input_names:
if name not in inputs:
raise ValueError(f"Missing required input: {name}")
if not inputs[name].is_cuda:
raise ValueError(f"Input '{name}' must be a CUDA tensor")
# Get execution context
if context_id >= len(self._contexts):
raise ValueError(
f"Invalid context_id {context_id}, only {len(self._contexts)} contexts available"
)
context = self._contexts[context_id]
# Prepare outputs
outputs = {}
# Set input tensor addresses
for name in self._metadata.input_names:
input_tensor = inputs[name].contiguous()
context.set_tensor_address(name, input_tensor.data_ptr())
# Allocate and set output tensors
for name in self._metadata.output_names:
output_tensor = torch.empty(
self._metadata.output_shapes[name],
dtype=self._metadata.output_dtypes[name],
device=self._device,
)
outputs[name] = output_tensor
context.set_tensor_address(name, output_tensor.data_ptr())
# Execute
if stream is None:
stream = torch.cuda.Stream(device=self._device)
with torch.cuda.stream(stream):
success = context.execute_async_v3(stream_handle=stream.cuda_stream)
if not success:
raise RuntimeError("TensorRT inference failed")
stream.synchronize()
return outputs
def get_metadata(self) -> EngineMetadata:
"""Get engine metadata"""
if self._metadata is None:
raise RuntimeError("Engine not initialized")
return self._metadata
def cleanup(self):
"""Cleanup TensorRT resources"""
for ctx in self._contexts:
del ctx
self._contexts.clear()
if self._engine is not None:
del self._engine
self._engine = None
self._metadata = None
@property
def is_initialized(self) -> bool:
return self._engine is not None
@property
def device(self) -> torch.device:
return self._device
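For reference, a standalone sketch of the .engine layout handled by _load_engine_data() above: a 4-byte little-endian metadata length, a JSON metadata block, then the raw serialized TensorRT engine. The file path is a placeholder.
# Sketch: inspect an Ultralytics-exported .engine file using the layout
# parsed by NativeTensorRTEngine._load_engine_data() above.
import json

def read_ultralytics_engine(path: str):
    with open(path, "rb") as f:
        meta_len = int.from_bytes(f.read(4), byteorder="little")
        if not (0 < meta_len < 100000):
            raise ValueError("No Ultralytics metadata header; likely a raw TensorRT engine")
        metadata = json.loads(f.read(meta_len).decode("utf-8"))
        engine_bytes = f.read()
    return metadata, engine_bytes

# metadata, engine_bytes = read_ultralytics_engine("yolov8n.engine")  # placeholder path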
class UltralyticsEngine(IInferenceEngine):
"""
Ultralytics YOLO inference engine.
Features:
- Zero-copy GPU tensor inference
- Built-in preprocessing/postprocessing for YOLO models
- Supports .pt, .engine formats
- Automatic model export to TensorRT with caching
"""
def __init__(self):
self._model = None
self._metadata = None
self._device = None
self._model_path = None
self._exporter = None
def initialize(
self,
model_path: str,
device: torch.device,
batch: int = 1,
half: bool = False,
imgsz: int = 640,
cache_dir: str = ".ultralytics_cache",
**kwargs,
) -> EngineMetadata:
"""
Initialize Ultralytics YOLO model.
Automatically exports .pt models to .engine format with caching.
Args:
model_path: Path to .pt or .engine file
device: GPU device
batch: Maximum batch size for inference
half: Use FP16 precision
imgsz: Input image size
cache_dir: Directory for caching exported engines
**kwargs: Additional export parameters
Returns:
EngineMetadata
"""
from ultralytics import YOLO
from .ultralytics_exporter import UltralyticsExporter
self._device = device
self._model_path = model_path
# Check if we need to export
model_file = Path(model_path)
final_model_path = model_path
if model_file.suffix == ".pt":
# Use exporter with caching
print(f"Checking for cached TensorRT engine...")
self._exporter = UltralyticsExporter(cache_dir=cache_dir)
_, engine_path = self._exporter.export(
model_path=str(model_path),
device=device.index if device.type == "cuda" else 0,
half=half,
imgsz=imgsz,
batch=batch,
**kwargs,
)
final_model_path = engine_path
print(f"Using TensorRT engine: {engine_path}")
# Load model (Ultralytics handles .engine files natively)
self._model = YOLO(final_model_path)
# Move to device if needed (only for .pt models, .engine already on specific device)
if hasattr(self._model, "model") and self._model.model is not None:
# Check if it's actually a torch model (not a string path for .engine files)
if hasattr(self._model.model, "to"):
self._model.model = self._model.model.to(device)
# Extract metadata
self._metadata = self._extract_metadata()
return self._metadata
def _extract_metadata(self) -> EngineMetadata:
"""Extract metadata from Ultralytics model"""
# Ultralytics models typically expect (B, 3, H, W) input
# and return Results objects, not raw tensors
# Default values
batch_size = -1 # Dynamic batching by default
imgsz = 640
input_shape = (batch_size, 3, imgsz, imgsz)
if hasattr(self._model, "model") and self._model.model is not None:
# Try to get actual input shape from model
try:
# For .engine files, check predictor model
if (
hasattr(self._model, "predictor")
and self._model.predictor is not None
):
predictor = self._model.predictor
# Get image size
if hasattr(predictor, "args") and hasattr(predictor.args, "imgsz"):
imgsz_val = predictor.args.imgsz
if isinstance(imgsz_val, (list, tuple)):
h, w = (
imgsz_val[0],
imgsz_val[1] if len(imgsz_val) > 1 else imgsz_val[0],
)
else:
h = w = imgsz_val
imgsz = h # Use height as reference
# Get batch size from model
if hasattr(predictor, "model"):
pred_model = predictor.model
# For TensorRT engines, check input bindings
if hasattr(pred_model, "bindings"):
# This is a TensorRT AutoBackend
try:
# Get first input binding shape
if hasattr(pred_model, "input_shape"):
shape = pred_model.input_shape
if shape and len(shape) >= 4:
batch_size = shape[0] if shape[0] > 0 else -1
except Exception:
pass
# Try batch attribute
if batch_size == -1 and hasattr(pred_model, "batch"):
batch_size = (
pred_model.batch if pred_model.batch > 0 else -1
)
# Fallback: check model args
if hasattr(self._model.model, "args"):
imgsz_val = getattr(self._model.model.args, "imgsz", 640)
if isinstance(imgsz_val, (list, tuple)):
h, w = (
imgsz_val[0],
imgsz_val[1] if len(imgsz_val) > 1 else imgsz_val[0],
)
else:
h = w = imgsz_val
imgsz = h
input_shape = (batch_size, 3, imgsz, imgsz)
except Exception as e:
logger.warning(f"Could not extract full metadata: {e}")
pass
return EngineMetadata(
engine_type="ultralytics",
model_path=self._model_path,
input_shapes={"images": input_shape},
output_shapes={"results": (-1,)}, # Dynamic, depends on detections
input_names=["images"],
output_names=["results"],
input_dtypes={"images": torch.float32},
output_dtypes={"results": torch.float32},
supports_batching=True,
supports_dynamic_shapes=(batch_size == -1),
extra_info={
"is_yolo": True,
"has_builtin_postprocess": True,
"batch_size": batch_size,
"imgsz": imgsz,
},
)
def infer(
self,
inputs: Dict[str, torch.Tensor],
return_raw: bool = False,
conf: float = 0.25,
iou: float = 0.45,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Run Ultralytics inference with zero-copy GPU tensors.
Args:
inputs: Dict with "images" key -> CUDA tensor (B, 3, H, W), normalized [0, 1]
return_raw: If True, return raw tensor output. If False, return Results objects
conf: Confidence threshold
iou: IoU threshold for NMS
Returns:
Dict with inference results
Note:
Input tensor should be normalized to [0, 1] range.
Format: (B, 3, H, W) in RGB color space.
"""
if not self.is_initialized:
raise RuntimeError("Engine not initialized")
# Get input tensor
if "images" not in inputs:
raise ValueError("Input must contain 'images' key")
images = inputs["images"]
if not images.is_cuda:
raise ValueError("Input must be a CUDA tensor")
# Ensure tensor is on correct device
if images.device != self._device:
images = images.to(self._device)
# Run inference
results = self._model(images, conf=conf, iou=iou, verbose=False, **kwargs)
# Return results
# Note: Ultralytics returns Results objects, not raw tensors
# For compatibility, we wrap them in a dict
return {
"results": results,
"raw_predictions": results[0].boxes.data
if len(results) > 0 and hasattr(results[0], "boxes")
else None,
}
def get_metadata(self) -> EngineMetadata:
"""Get engine metadata"""
if self._metadata is None:
raise RuntimeError("Engine not initialized")
return self._metadata
def cleanup(self):
"""Cleanup Ultralytics model"""
if self._model is not None:
del self._model
self._model = None
self._metadata = None
@property
def is_initialized(self) -> bool:
return self._model is not None
@property
def device(self) -> torch.device:
return self._device
def create_engine(backend: str | BackendType, **kwargs) -> IInferenceEngine:
"""
Factory function to create inference engine.
Args:
backend: Backend type (BackendType enum or string: "tensorrt", "ultralytics")
**kwargs: Engine-specific arguments
Returns:
IInferenceEngine instance
Example:
>>> from services import create_engine, BackendType
>>> engine = create_engine(BackendType.TENSORRT)
>>> engine = create_engine("ultralytics")
"""
# Convert string to BackendType if needed
if isinstance(backend, str):
backend = BackendType.from_string(backend)
engines = {
BackendType.TENSORRT: NativeTensorRTEngine,
BackendType.ULTRALYTICS: UltralyticsEngine,
}
if backend not in engines:
raise ValueError(
f"Unknown backend: {backend}. Available: {[b.value for b in BackendType]}"
)
return engines[backend]()
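Expanding slightly on the docstring example, a usage sketch under the assumption of a CUDA GPU and a local model file (the path is a placeholder):
# Sketch: create, initialize and run an engine through the factory above.
import torch
from services import create_engine

engine = create_engine("ultralytics")
meta = engine.initialize("yolov8n.pt", device=torch.device("cuda:0"), batch=1, imgsz=640)
# Input contract documented above: (B, 3, H, W) RGB, normalized to [0, 1], on the GPU.
images = torch.rand(1, 3, 640, 640, device="cuda:0")
outputs = engine.infer({"images": images}, conf=0.25, iou=0.45)
print(len(outputs["results"]))  # list of Ultralytics Results objects
engine.cleanup()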


@ -5,21 +5,22 @@ This module provides batched inference coordination using ping-pong circular buf
with force-switch timeout mechanism using threading and callbacks.
"""
import threading
import torch
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import time
import logging
import queue
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
import torch
logger = logging.getLogger(__name__)
@dataclass
class BatchFrame:
"""Represents a frame in the batch buffer"""
stream_id: str
frame: torch.Tensor # GPU tensor (3, H, W)
timestamp: float
@ -28,6 +29,7 @@ class BatchFrame:
class BufferState(Enum):
"""State of a ping-pong buffer"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
@ -80,7 +82,9 @@ class ModelController:
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
)
else:
logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
logger.info(
f"Model '{model_id}' supports batch_size={self.model_batch_size}"
)
# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
@ -130,7 +134,9 @@ class ModelController:
# Fixed batch size
return batch_dim
except Exception as e:
logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
logger.warning(
f"Could not detect model batch size: {e}. Assuming batch_size=1"
)
return 1
def start(self):
@ -143,14 +149,20 @@ class ModelController:
self.stop_event.clear()
# Start timeout monitor thread
self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
self.timeout_thread = threading.Thread(
target=self._timeout_monitor, daemon=True
)
self.timeout_thread.start()
# Start processor threads for each buffer
self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
self.processor_threads['A'].start()
self.processor_threads['B'].start()
self.processor_threads["A"] = threading.Thread(
target=self._batch_processor, args=("A",), daemon=True
)
self.processor_threads["B"] = threading.Thread(
target=self._batch_processor, args=("B",), daemon=True
)
self.processor_threads["A"].start()
self.processor_threads["B"].start()
logger.info("ModelController started")
@ -197,10 +209,7 @@ class ModelController:
logger.debug(f"Unregistered callback for stream: {stream_id}")
def submit_frame(
self,
stream_id: str,
frame: torch.Tensor,
metadata: Optional[Dict] = None
self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
):
"""
Submit a frame for batched inference.
@ -215,7 +224,7 @@ class ModelController:
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
metadata=metadata or {}
metadata=metadata or {},
)
# Add to active buffer
@ -242,7 +251,9 @@ class ModelController:
# Check if timeout expired and we have frames waiting
if time_since_submit >= self.force_timeout:
active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
active_buffer = (
self.buffer_a if self.active_buffer == "A" else self.buffer_b
)
if len(active_buffer) > 0:
self._try_swap_buffers()
@ -254,7 +265,9 @@ class ModelController:
This method should be called with buffer_lock held.
"""
# Check if inactive buffer is available
inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
inactive_state = (
self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
)
if inactive_state != BufferState.PROCESSING:
# Swap active buffer
@ -269,7 +282,9 @@ class ModelController:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)
logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
logger.debug(
f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
)
def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""
@ -322,8 +337,8 @@ class ModelController:
self.total_batches_processed += 1
logger.debug(
f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
f"({inference_time*1000/len(batch):.2f}ms per frame)"
f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
)
# Emit results to callbacks
@ -334,7 +349,10 @@ class ModelController:
try:
callback(result)
except Exception as e:
logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
logger.error(
f"Error in callback for {batch_frame.stream_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
@ -365,7 +383,9 @@ class ModelController:
# Use true batching for models that support it
return self._run_batched_inference(batch)
def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
def _run_sequential_inference(
self, batch: List[BatchFrame]
) -> List[Dict[str, Any]]:
"""Run inference sequentially for batch_size=1 models"""
results = []
@ -375,13 +395,15 @@ class ModelController:
processed = self.preprocess_fn(batch_frame.frame)
else:
# Ensure we have batch dimension
processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
processed = (
batch_frame.frame.unsqueeze(0)
if batch_frame.frame.dim() == 3
else batch_frame.frame
)
# Run inference for this frame
outputs = self.model_repository.infer(
self.model_id,
{"images": processed},
synchronize=True
self.model_id, {"images": processed}, synchronize=True
)
# Postprocess
@ -389,9 +411,13 @@ class ModelController:
try:
detections = self.postprocess_fn(outputs)
except Exception as e:
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
# Return empty detections on error
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
detections = torch.zeros(
(0, 6), device=list(outputs.values())[0].device
)
else:
detections = outputs
@ -429,32 +455,37 @@ class ModelController:
f"will split into sub-batches"
)
# TODO: Handle splitting into sub-batches
batch_tensor = batch_tensor[:self.model_batch_size]
batch = batch[:self.model_batch_size]
batch_tensor = batch_tensor[: self.model_batch_size]
batch = batch[: self.model_batch_size]
# Run inference
outputs = self.model_repository.infer(
self.model_id,
{"images": batch_tensor},
synchronize=True
self.model_id, {"images": batch_tensor}, synchronize=True
)
# Postprocess results (split batch back to individual results)
results = []
for i, batch_frame in enumerate(batch):
# Extract single frame output from batch
# Extract single frame output from batch and clone to ensure memory safety
# This prevents potential race conditions if the output tensors are still
# in use when the next inference batch is processed
frame_output = {}
for k, v in outputs.items():
# v has shape (N, ...), extract index i and keep batch dimension
frame_output[k] = v[i:i+1] # Shape: (1, ...)
# Clone to decouple from shared batch output tensor
frame_output[k] = v[i : i + 1].clone() # Shape: (1, ...)
if self.postprocess_fn:
try:
detections = self.postprocess_fn(frame_output)
except Exception as e:
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
# Return empty detections on error
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
detections = torch.zeros(
(0, 6), device=list(outputs.values())[0].device
)
else:
detections = frame_output
@ -490,6 +521,8 @@ class ModelController:
"total_batches_processed": self.total_batches_processed,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
if self.total_batches_processed > 0
else 0
if self.total_batches_processed > 0 else 0
),
}
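The preprocess_fn/postprocess_fn hooks used by this controller follow an implicit contract: preprocess maps a single (3, H, W) GPU frame to a batched input tensor, and postprocess maps the raw output dict to an (N, 6) detection tensor. A hypothetical pair is sketched below; the project's own YOLOv8Utils presumably fills this role, but its exact interface is not shown in this diff.
# Hypothetical hooks matching the calls above (frame assumed to be 0-255 RGB).
import torch
import torch.nn.functional as F

def simple_preprocess(frame: torch.Tensor, size: int = 640) -> torch.Tensor:
    x = frame.unsqueeze(0).float() / 255.0  # (3, H, W) -> (1, 3, H, W), scaled to [0, 1]
    return F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)

def simple_postprocess(outputs: dict) -> torch.Tensor:
    # Placeholder: real postprocessing would decode boxes and apply NMS.
    device = next(iter(outputs.values())).device
    return torch.zeros((0, 6), device=device)  # [x1, y1, x2, y2, conf, cls]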


@ -1,13 +1,14 @@
import threading
import hashlib
import json
from typing import Optional, Dict, Any, List, Tuple
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
from queue import Queue
import torch
from typing import Any, Dict, List, Optional, Tuple
import tensorrt as trt
from dataclasses import dataclass
import logging
import torch
logger = logging.getLogger(__name__)
@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
@dataclass
class ModelMetadata:
"""Metadata for a loaded TensorRT model"""
file_path: str
file_hash: str
input_shapes: Dict[str, Tuple[int, ...]]
@ -30,8 +32,14 @@ class ExecutionContext:
Wrapper for TensorRT execution context with CUDA stream.
Used in context pool for load balancing.
"""
def __init__(self, context: trt.IExecutionContext, stream: torch.cuda.Stream,
context_id: int, device: torch.device):
def __init__(
self,
context: trt.IExecutionContext,
stream: torch.cuda.Stream,
context_id: int,
device: torch.device,
):
self.context = context
self.stream = stream
self.context_id = context_id
@ -53,8 +61,16 @@ class SharedEngine:
- Contexts are borrowed/returned using mutex locks
- Load balancing: contexts distributed across requests
"""
def __init__(self, engine: trt.ICudaEngine, file_hash: str, file_path: str,
num_contexts: int, device: torch.device, metadata: ModelMetadata):
def __init__(
self,
engine: trt.ICudaEngine,
file_hash: str,
file_path: str,
num_contexts: int,
device: torch.device,
metadata: ModelMetadata,
):
self.engine = engine
self.file_hash = file_hash
self.file_path = file_path
@ -80,9 +96,13 @@ class SharedEngine:
self.model_ids: set = set()
self.lock = threading.Lock()
print(f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}...")
print(
f"Created context pool with {num_contexts} contexts for engine {file_hash[:8]}..."
)
def acquire_context(self, timeout: Optional[float] = None) -> Optional[ExecutionContext]:
def acquire_context(
self, timeout: Optional[float] = None
) -> Optional[ExecutionContext]:
"""
Acquire an available execution context from the pool.
Blocks if all contexts are in use.
@ -162,7 +182,13 @@ class TensorRTModelRepository:
# Result: 1 engine in VRAM, N contexts (e.g., 4), not 100 contexts!
"""
def __init__(self, gpu_id: int = 0, default_num_contexts: int = 4, enable_pt_conversion: bool = True, cache_dir: str = ".trt_cache"):
def __init__(
self,
gpu_id: int = 0,
default_num_contexts: int = 4,
enable_pt_conversion: bool = True,
cache_dir: str = ".trt_cache",
):
"""
Initialize the model repository.
@ -173,7 +199,7 @@ class TensorRTModelRepository:
cache_dir: Directory for caching stripped TensorRT engines and metadata
"""
self.gpu_id = gpu_id
self.device = torch.device(f'cuda:{gpu_id}')
self.device = torch.device(f"cuda:{gpu_id}")
self.default_num_contexts = default_num_contexts
self.enable_pt_conversion = enable_pt_conversion
self.cache_dir = Path(cache_dir)
@ -195,7 +221,9 @@ class TensorRTModelRepository:
self._pt_converter = None
print(f"TensorRT Model Repository initialized on GPU {gpu_id}")
print(f"Default context pool size: {default_num_contexts} contexts per unique model")
print(
f"Default context pool size: {default_num_contexts} contexts per unique model"
)
print(f"Cache directory: {self.cache_dir}")
if enable_pt_conversion:
print(f"PyTorch to TensorRT conversion: enabled")
@ -205,6 +233,7 @@ class TensorRTModelRepository:
"""Lazy initialization of PT converter"""
if self._pt_converter is None and self.enable_pt_conversion:
from .pt_converter import PTConverter
self._pt_converter = PTConverter(gpu_id=self.gpu_id)
logger.info("PT converter initialized")
return self._pt_converter
@ -255,11 +284,11 @@ class TensorRTModelRepository:
# Check if stripped engine already cached
if cache_engine_path.exists():
logger.info(f"Loading cached stripped engine from {cache_engine_path}")
with open(cache_engine_path, 'rb') as f:
with open(cache_engine_path, "rb") as f:
engine_data = f.read()
else:
# Read and process original file
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
# Try to read Ultralytics metadata header (first 4 bytes = metadata length)
try:
meta_len_bytes = f.read(4)
@ -278,13 +307,15 @@ class TensorRTModelRepository:
# Save stripped engine to cache
logger.info(f"Detected Ultralytics engine format")
logger.info(f"Ultralytics metadata: {metadata}")
logger.info(f"Caching stripped engine to {cache_engine_path}")
logger.info(
f"Caching stripped engine to {cache_engine_path}"
)
with open(cache_engine_path, 'wb') as cache_f:
with open(cache_engine_path, "wb") as cache_f:
cache_f.write(engine_data)
# Save metadata separately
with open(cache_metadata_path, 'w') as meta_f:
with open(cache_metadata_path, "w") as meta_f:
json.dump(metadata, meta_f, indent=2)
except (UnicodeDecodeError, json.JSONDecodeError):
@ -301,13 +332,15 @@ class TensorRTModelRepository:
except Exception as e:
# Any error, rewind and read entire file
logger.warning(f"Error reading engine metadata: {e}, treating as raw TRT engine")
logger.warning(
f"Error reading engine metadata: {e}, treating as raw TRT engine"
)
f.seek(0)
engine_data = f.read()
# Cache the engine data (even if it was already raw TRT)
if not cache_engine_path.exists():
with open(cache_engine_path, 'wb') as cache_f:
with open(cache_engine_path, "wb") as cache_f:
cache_f.write(engine_data)
engine = runtime.deserialize_cuda_engine(engine_data)
@ -316,8 +349,9 @@ class TensorRTModelRepository:
return engine
def _extract_metadata(self, engine: trt.ICudaEngine,
file_path: str, file_hash: str) -> ModelMetadata:
def _extract_metadata(
self, engine: trt.ICudaEngine, file_path: str, file_hash: str
) -> ModelMetadata:
"""
Extract metadata from TensorRT engine.
@ -369,15 +403,19 @@ class TensorRTModelRepository:
input_names=input_names,
output_names=output_names,
input_dtypes=input_dtypes,
output_dtypes=output_dtypes
output_dtypes=output_dtypes,
)
def load_model(self, model_id: str, file_path: str,
num_contexts: Optional[int] = None,
force_reload: bool = False,
pt_input_shapes: Optional[Dict[str, Tuple]] = None,
pt_precision: Optional[torch.dtype] = None,
**pt_conversion_kwargs) -> ModelMetadata:
def load_model(
self,
model_id: str,
file_path: str,
num_contexts: Optional[int] = None,
force_reload: bool = False,
pt_input_shapes: Optional[Dict[str, Tuple]] = None,
pt_precision: Optional[torch.dtype] = None,
**pt_conversion_kwargs,
) -> ModelMetadata:
"""
Load a TensorRT model with the given ID.
@ -410,7 +448,7 @@ class TensorRTModelRepository:
# Check if file is PyTorch model
file_ext = Path(file_path).suffix.lower()
if file_ext in ['.pt', '.pth']:
if file_ext in [".pt", ".pth"]:
if not self.enable_pt_conversion:
raise ValueError(
f"PT file provided but PT conversion is disabled. "
@ -425,7 +463,7 @@ class TensorRTModelRepository:
file_path,
input_shapes=pt_input_shapes,
precision=pt_precision,
**pt_conversion_kwargs
**pt_conversion_kwargs,
)
# Update file_path to use converted TRT file
@ -455,8 +493,12 @@ class TensorRTModelRepository:
# Check if this file is already loaded (deduplication)
if file_hash in self._shared_engines:
shared_engine = self._shared_engines[file_hash]
print(f"Engine already loaded (hash match), reusing engine and context pool...")
print(f" Existing model_ids using this engine: {shared_engine.model_ids}")
print(
f"Engine already loaded (hash match), reusing engine and context pool..."
)
print(
f" Existing model_ids using this engine: {shared_engine.model_ids}"
)
else:
# Load new engine
print(f"Loading TensorRT engine from {file_path}...")
@ -472,7 +514,7 @@ class TensorRTModelRepository:
file_path=file_path,
num_contexts=num_contexts,
device=self.device,
metadata=metadata
metadata=metadata,
)
self._shared_engines[file_hash] = shared_engine
@ -485,18 +527,29 @@ class TensorRTModelRepository:
print(f"Model '{model_id}' loaded successfully")
print(f" Inputs: {shared_engine.metadata.input_names}")
for name in shared_engine.metadata.input_names:
print(f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})")
print(
f" {name}: {shared_engine.metadata.input_shapes[name]} ({shared_engine.metadata.input_dtypes[name]})"
)
print(f" Outputs: {shared_engine.metadata.output_names}")
for name in shared_engine.metadata.output_names:
print(f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})")
print(
f" {name}: {shared_engine.metadata.output_shapes[name]} ({shared_engine.metadata.output_dtypes[name]})"
)
print(f" Context pool size: {num_contexts}")
print(f" Model IDs sharing this engine: {shared_engine.get_reference_count()}")
print(
f" Model IDs sharing this engine: {shared_engine.get_reference_count()}"
)
print(f" Unique engines in VRAM: {len(self._shared_engines)}")
return shared_engine.metadata
def infer(self, model_id: str, inputs: Dict[str, torch.Tensor],
synchronize: bool = True, timeout: Optional[float] = 5.0) -> Dict[str, torch.Tensor]:
def infer(
self,
model_id: str,
inputs: Dict[str, torch.Tensor],
synchronize: bool = True,
timeout: Optional[float] = 5.0,
) -> Dict[str, torch.Tensor]:
"""
Run GPU-to-GPU inference with the specified model using context pooling.
@ -519,7 +572,9 @@ class TensorRTModelRepository:
"""
# Get shared engine
if model_id not in self._model_to_hash:
raise KeyError(f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}")
raise KeyError(
f"Model '{model_id}' not found. Available: {list(self._model_to_hash.keys())}"
)
file_hash = self._model_to_hash[model_id]
shared_engine = self._shared_engines[file_hash]
@ -536,7 +591,9 @@ class TensorRTModelRepository:
# Check device
if tensor.device != self.device:
print(f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}")
print(
f"Warning: Input '{name}' on {tensor.device}, moving to {self.device}"
)
inputs[name] = tensor.to(self.device)
# Acquire context from pool (mutex-based)
@ -562,9 +619,7 @@ class TensorRTModelRepository:
output_dtype = metadata.output_dtypes[name]
output_tensor = torch.empty(
output_shape,
dtype=output_dtype,
device=self.device
output_shape, dtype=output_dtype, device=self.device
)
# NOTE: Don't track these tensors - they're returned to caller and consumed
@ -584,9 +639,23 @@ class TensorRTModelRepository:
if not success:
raise RuntimeError(f"Inference failed for model '{model_id}'")
# Synchronize if requested
if synchronize:
exec_ctx.stream.synchronize()
# CRITICAL: Always synchronize before releasing context
# Even if caller requested async execution, we MUST sync before
# releasing the context to prevent race conditions where the next
# inference using this context overwrites tensor addresses while
# the current batch is still being processed.
exec_ctx.stream.synchronize()
# Clone outputs to new tensors to ensure memory safety
# This prevents race conditions where the next batch using this context
# could overwrite the output tensor addresses before the caller
# finishes processing these results.
if not synchronize:
# For async mode, clone to decouple from context
cloned_outputs = {}
for name, tensor in outputs.items():
cloned_outputs[name] = tensor.clone()
outputs = cloned_outputs
return outputs
@ -594,8 +663,12 @@ class TensorRTModelRepository:
# Always release context back to pool
shared_engine.release_context(exec_ctx)
def infer_batch(self, model_id: str, batch_inputs: List[Dict[str, torch.Tensor]],
synchronize: bool = True) -> List[Dict[str, torch.Tensor]]:
def infer_batch(
self,
model_id: str,
batch_inputs: List[Dict[str, torch.Tensor]],
synchronize: bool = True,
) -> List[Dict[str, torch.Tensor]]:
"""
Run inference on multiple inputs.
Contexts are borrowed/returned for each input, enabling parallel processing.
@ -641,9 +714,13 @@ class TensorRTModelRepository:
if remaining_refs == 0:
shared_engine.cleanup()
del self._shared_engines[file_hash]
print(f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)")
print(
f"Model '{model_id}' unloaded, engine removed from VRAM (0 references)"
)
else:
print(f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)")
print(
f"Model '{model_id}' unloaded, engine kept in VRAM ({remaining_refs} references)"
)
# Remove from model_id mapping
del self._model_to_hash[model_id]
@ -702,26 +779,26 @@ class TensorRTModelRepository:
metadata = shared_engine.metadata
return {
'model_id': model_id,
'file_path': metadata.file_path,
'file_hash': metadata.file_hash[:16] + '...',
'engine_references': shared_engine.get_reference_count(),
'context_pool_size': shared_engine.num_contexts,
'shared_with_model_ids': list(shared_engine.model_ids),
'inputs': {
"model_id": model_id,
"file_path": metadata.file_path,
"file_hash": metadata.file_hash[:16] + "...",
"engine_references": shared_engine.get_reference_count(),
"context_pool_size": shared_engine.num_contexts,
"shared_with_model_ids": list(shared_engine.model_ids),
"inputs": {
name: {
'shape': metadata.input_shapes[name],
'dtype': str(metadata.input_dtypes[name])
"shape": metadata.input_shapes[name],
"dtype": str(metadata.input_dtypes[name]),
}
for name in metadata.input_names
},
'outputs': {
"outputs": {
name: {
'shape': metadata.output_shapes[name],
'dtype': str(metadata.output_dtypes[name])
"shape": metadata.output_shapes[name],
"dtype": str(metadata.output_dtypes[name]),
}
for name in metadata.output_names
}
},
}
def get_stats(self) -> Dict[str, Any]:
@ -733,24 +810,25 @@ class TensorRTModelRepository:
"""
with self._repo_lock:
total_contexts = sum(
engine.num_contexts
for engine in self._shared_engines.values()
engine.num_contexts for engine in self._shared_engines.values()
)
return {
'total_model_ids': len(self._model_to_hash),
'unique_engines': len(self._shared_engines),
'total_contexts': total_contexts,
'memory_efficiency': f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
'gpu_id': self.gpu_id,
'models': list(self._model_to_hash.keys())
"total_model_ids": len(self._model_to_hash),
"unique_engines": len(self._shared_engines),
"total_contexts": total_contexts,
"memory_efficiency": f"{len(self._model_to_hash)} model IDs using only {len(self._shared_engines)} engines",
"gpu_id": self.gpu_id,
"models": list(self._model_to_hash.keys()),
}
def __repr__(self):
with self._repo_lock:
return (f"TensorRTModelRepository(gpu={self.gpu_id}, "
f"model_ids={len(self._model_to_hash)}, "
f"unique_engines={len(self._shared_engines)})")
return (
f"TensorRTModelRepository(gpu={self.gpu_id}, "
f"model_ids={len(self._model_to_hash)}, "
f"unique_engines={len(self._shared_engines)})"
)
def __del__(self):
"""Cleanup all models on deletion"""


@ -5,25 +5,28 @@ This module provides high-level connection management for multiple RTSP streams,
coordinating decoders, batched inference, and tracking with callbacks and threading.
"""
import threading
import time
from typing import Dict, Optional, Callable, Tuple, Any, List
from dataclasses import dataclass
from enum import Enum
import logging
import queue
import threading
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from .model_controller import ModelController
from .stream_decoder import StreamDecoderFactory
from .base_model_controller import BaseModelController
from .model_repository import TensorRTModelRepository
from .stream_decoder import StreamDecoderFactory
from .tensorrt_model_controller import TensorRTModelController
from .ultralytics_model_controller import UltralyticsModelController
logger = logging.getLogger(__name__)
class ConnectionStatus(Enum):
"""Status of a stream connection"""
CONNECTING = "connecting"
CONNECTED = "connected"
DISCONNECTED = "disconnected"
@ -33,6 +36,7 @@ class ConnectionStatus(Enum):
@dataclass
class TrackingResult:
"""Result emitted to user callbacks"""
stream_id: str
timestamp: float
tracked_objects: List # List of TrackedObject from TrackingController
@ -61,7 +65,7 @@ class StreamConnection:
self,
stream_id: str,
decoder,
model_controller: ModelController,
model_controller: BaseModelController,
tracking_controller,
poll_interval: float = 0.01,
):
@ -107,7 +111,9 @@ class StreamConnection:
break
else:
# Timeout - but don't fail hard, let it try to connect in background
logger.warning(f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying...")
logger.warning(
f"Stream {self.stream_id} not connected after {max_wait}s, will continue trying..."
)
self.status = ConnectionStatus.CONNECTING
def stop(self):
@ -144,28 +150,42 @@ class StreamConnection:
self.last_frame_time = time.time()
self.frame_count += 1
# CRITICAL: Clone the GPU tensor to decouple from decoder's frame buffer
# The decoder reuses frame buffer memory, so we must copy the tensor
# before submitting to async batched inference to prevent race conditions
# where the decoder overwrites memory while inference is still reading it.
cloned_tensor = frame_ref.rgb_tensor.clone()
# Submit to model controller for batched inference
# Pass the FrameReference in metadata so we can free it later
self.model_controller.submit_frame(
stream_id=self.stream_id,
frame=frame_ref.rgb_tensor,
frame=cloned_tensor, # Use cloned tensor, not original
metadata={
"frame_number": self.frame_count,
"shape": tuple(frame_ref.rgb_tensor.shape),
"shape": tuple(cloned_tensor.shape),
"frame_ref": frame_ref, # Store reference for later cleanup
}
},
)
# Update connection status based on decoder status
if self.decoder.is_connected() and self.status != ConnectionStatus.CONNECTED:
if (
self.decoder.is_connected()
and self.status != ConnectionStatus.CONNECTED
):
logger.info(f"Stream {self.stream_id} reconnected")
self.status = ConnectionStatus.CONNECTED
elif not self.decoder.is_connected() and self.status == ConnectionStatus.CONNECTED:
elif (
not self.decoder.is_connected()
and self.status == ConnectionStatus.CONNECTED
):
logger.warning(f"Stream {self.stream_id} disconnected")
self.status = ConnectionStatus.DISCONNECTED
except Exception as e:
logger.error(f"Error processing frame for {self.stream_id}: {e}", exc_info=True)
logger.error(
f"Error processing frame for {self.stream_id}: {e}", exc_info=True
)
self.error_queue.put(e)
self.status = ConnectionStatus.ERROR
# Free the frame on error
@ -205,7 +225,10 @@ class StreamConnection:
self.result_queue.put(tracking_result)
except Exception as e:
logger.error(f"Error handling inference result for {self.stream_id}: {e}", exc_info=True)
logger.error(
f"Error handling inference result for {self.stream_id}: {e}",
exc_info=True,
)
self.error_queue.put(e)
finally:
# Free the frame reference - this is the last point in the pipeline
@ -235,12 +258,16 @@ class StreamConnection:
if confidence < min_confidence:
continue
detection_list.append(Detection(
bbox=det[:4].cpu().tolist(),
confidence=confidence,
class_id=int(det[5]) if det.shape[0] > 5 else 0,
class_name=f"class_{int(det[5])}" if det.shape[0] > 5 else "unknown"
))
detection_list.append(
Detection(
bbox=det[:4].cpu().tolist(),
confidence=confidence,
class_id=int(det[5]) if det.shape[0] > 5 else 0,
class_name=f"class_{int(det[5])}"
if det.shape[0] > 5
else "unknown",
)
)
# Update tracker with detections (will scale bboxes to frame space)
return self.tracking_controller.update(detection_list, frame_shape=frame_shape)
@ -319,21 +346,38 @@ class StreamConnectionManager:
force_timeout: float = 0.05,
poll_interval: float = 0.01,
enable_pt_conversion: bool = True,
backend: str = "tensorrt", # "tensorrt" or "ultralytics"
):
self.gpu_id = gpu_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.poll_interval = poll_interval
self.backend = backend.lower()
# Factories
self.decoder_factory = StreamDecoderFactory(gpu_id=gpu_id)
self.model_repository = TensorRTModelRepository(
gpu_id=gpu_id,
enable_pt_conversion=enable_pt_conversion
)
# Initialize inference engine based on backend
self.inference_engine = None
self.model_repository = None # Legacy - will be removed
if self.backend == "ultralytics":
# Use Ultralytics native YOLO inference
from .inference_engine import UltralyticsEngine
self.inference_engine = UltralyticsEngine()
logger.info("Using Ultralytics inference engine")
else:
# Use native TensorRT inference
self.model_repository = TensorRTModelRepository(
gpu_id=gpu_id, enable_pt_conversion=enable_pt_conversion
)
logger.info("Using native TensorRT inference engine")
# Controllers
self.model_controller: Optional[ModelController] = None
self.model_controller = (
None # Will be TensorRTModelController or UltralyticsModelController
)
# Connections
self.connections: Dict[str, StreamConnection] = {}
@ -350,7 +394,7 @@ class StreamConnectionManager:
num_contexts: int = 4,
pt_input_shapes: Optional[Dict] = None,
pt_precision: Optional[Any] = None,
**pt_conversion_kwargs
**pt_conversion_kwargs,
):
"""
Initialize the manager with a model.
@ -382,28 +426,58 @@ class StreamConnectionManager:
)
"""
logger.info(f"Initializing StreamConnectionManager on GPU {self.gpu_id}")
logger.info(f"Backend: {self.backend}")
# Load model (synchronous)
self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs
)
logger.info(f"Loaded model {model_id} from {model_path}")
# Initialize engine based on backend
if self.backend == "ultralytics":
# Use Ultralytics native inference
logger.info("Initializing Ultralytics YOLO engine...")
device = torch.device(f"cuda:{self.gpu_id}")
# Create model controller
self.model_controller = ModelController(
model_repository=self.model_repository,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_controller.start()
metadata = self.inference_engine.initialize(
model_path=model_path,
device=device,
batch=self.batch_size,
half=False, # Use FP32 for now
imgsz=640,
**pt_conversion_kwargs,
)
logger.info(f"Ultralytics engine initialized: {metadata}")
# Create Ultralytics model controller
self.model_controller = UltralyticsModelController(
inference_engine=self.inference_engine,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_controller.start()
else:
# Use native TensorRT with model repository
logger.info("Initializing TensorRT engine...")
self.model_repository.load_model(
model_id,
model_path,
num_contexts=num_contexts,
pt_input_shapes=pt_input_shapes,
pt_precision=pt_precision,
**pt_conversion_kwargs,
)
logger.info(f"Loaded model {model_id} from {model_path}")
# Create TensorRT model controller
self.model_controller = TensorRTModelController(
model_repository=self.model_repository,
model_id=model_id,
batch_size=self.batch_size,
force_timeout=self.force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_controller.start()
# Don't create a shared tracking controller here
# Each stream will get its own tracking controller to avoid track accumulation
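# Example (usage sketch, not part of this commit): one initialize() call per
# manager; the backend chosen in __init__ decides whether the model is loaded
# through TensorRTModelRepository or through UltralyticsEngine. The path is a
# placeholder, the pre/postprocess callables are stand-ins, and initialize()
# may be a coroutine in the surrounding codebase.
from services import StreamConnectionManager

manager = StreamConnectionManager(gpu_id=0, batch_size=16, backend="ultralytics")
manager.initialize(
    model_id="yolov8n",
    model_path="models/yolov8n.pt",          # placeholder path (.pt or .engine)
    preprocess_fn=lambda frame: frame,       # placeholder
    postprocess_fn=lambda outputs: outputs,  # placeholder
)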
@ -452,12 +526,13 @@ class StreamConnectionManager:
# Create lightweight tracker (NO model_repository dependency!)
from .tracking_controller import ObjectTracker
tracking_controller = ObjectTracker(
gpu_id=self.gpu_id,
tracker_type="iou",
max_age=30,
iou_threshold=0.3,
class_names=None, # TODO: pass class names if available
)
logger.info(f"Created lightweight ObjectTracker for stream {stream_id}")
@ -472,8 +547,7 @@ class StreamConnectionManager:
# Register callback with model controller
self.model_controller.register_callback(
stream_id, connection._handle_inference_result
)
# Start connection
@ -487,14 +561,12 @@ class StreamConnectionManager:
threading.Thread(
target=self._forward_results,
args=(connection, on_tracking_result),
daemon=True,
).start()
if on_error:
threading.Thread(
target=self._forward_errors, args=(connection, on_error), daemon=True
).start()
logger.info(f"Stream {stream_id} connected successfully")
@ -549,7 +621,10 @@ class StreamConnectionManager:
for result in connection.tracking_results():
callback(result)
except Exception as e:
logger.error(f"Error in result forwarding for {connection.stream_id}: {e}", exc_info=True)
logger.error(
f"Error in result forwarding for {connection.stream_id}: {e}",
exc_info=True,
)
def _forward_errors(self, connection: StreamConnection, callback: Callable):
"""
@ -563,7 +638,10 @@ class StreamConnectionManager:
for error in connection.errors():
callback(error)
except Exception as e:
logger.error(f"Error in error forwarding for {connection.stream_id}: {e}", exc_info=True)
logger.error(
f"Error in error forwarding for {connection.stream_id}: {e}",
exc_info=True,
)
def get_stats(self) -> Dict[str, Any]:
"""
@ -581,7 +659,9 @@ class StreamConnectionManager:
"force_timeout": self.force_timeout,
"poll_interval": self.poll_interval,
},
"model_controller": self.model_controller.get_stats() if self.model_controller else {},
"model_controller": self.model_controller.get_stats()
if self.model_controller
else {},
"connections": {
stream_id: conn.get_stats()
for stream_id, conn in self.connections.items()

tensorrt_model_controller.py (new file)

@ -0,0 +1,182 @@
"""
TensorRT Model Controller - Native TensorRT inference with batched processing.
"""
import logging
from typing import Any, Callable, Dict, List, Optional
import torch
from .base_model_controller import BaseModelController, BatchFrame
logger = logging.getLogger(__name__)
class TensorRTModelController(BaseModelController):
"""
Model controller for native TensorRT inference.
Uses TensorRTModelRepository for GPU-accelerated inference with
context pooling and deduplication.
"""
def __init__(
self,
model_repository,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
super().__init__(
model_id=model_id,
batch_size=batch_size,
force_timeout=force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.model_repository = model_repository
# Detect model's actual batch size from input shape
self.model_batch_size = self._detect_model_batch_size()
if self.model_batch_size == 1:
logger.warning(
f"Model '{model_id}' has fixed batch_size=1. "
f"Will process frames sequentially."
)
else:
logger.info(
f"Model '{model_id}' supports batch_size={self.model_batch_size}"
)
def _detect_model_batch_size(self) -> int:
"""Detect the model's batch size from its input shape"""
try:
metadata = self.model_repository.get_metadata(self.model_id)
first_input_name = metadata.input_names[0]
input_shape = metadata.input_shapes[first_input_name]
batch_dim = input_shape[0]
if batch_dim == -1:
return self.batch_size # Dynamic batch size
else:
return batch_dim # Fixed batch size
except Exception as e:
logger.warning(
f"Could not detect model batch size: {e}. Assuming batch_size=1"
)
return 1
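    # Illustration (comment only): metadata.input_shapes maps input names to
    # shapes, e.g. {"images": (-1, 3, 640, 640)} has batch dim -1 (dynamic), so
    # the controller batches up to self.batch_size, while (1, 3, 640, 640) is a
    # fixed batch of 1 and frames fall back to the sequential path below.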
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""Run TensorRT inference on a batch of frames"""
if self.model_batch_size == 1:
return self._run_sequential_inference(batch)
else:
return self._run_batched_inference(batch)
def _run_sequential_inference(
self, batch: List[BatchFrame]
) -> List[Dict[str, Any]]:
"""Run inference sequentially for batch_size=1 models"""
results = []
for batch_frame in batch:
# Preprocess frame
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
else:
processed = (
batch_frame.frame.unsqueeze(0)
if batch_frame.frame.dim() == 3
else batch_frame.frame
)
# Run inference
outputs = self.model_repository.infer(
self.model_id, {"images": processed}, synchronize=True
)
# Postprocess
if self.postprocess_fn:
try:
detections = self.postprocess_fn(outputs)
except Exception as e:
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
detections = torch.zeros(
(0, 6), device=list(outputs.values())[0].device
)
else:
detections = outputs
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
}
results.append(result)
return results
def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""Run true batched inference for models that support it"""
# Preprocess frames
preprocessed = []
for batch_frame in batch:
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
if processed.dim() == 4 and processed.shape[0] == 1:
processed = processed.squeeze(0)
else:
processed = batch_frame.frame
preprocessed.append(processed)
# Stack into batch tensor
batch_tensor = torch.stack(preprocessed, dim=0)
# Limit to model's max batch size
if batch_tensor.shape[0] > self.model_batch_size:
logger.warning(
f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}"
)
batch_tensor = batch_tensor[: self.model_batch_size]
batch = batch[: self.model_batch_size]
# Run inference
outputs = self.model_repository.infer(
self.model_id, {"images": batch_tensor}, synchronize=True
)
# Postprocess results (split batch back to individual results)
results = []
for i, batch_frame in enumerate(batch):
# Extract single frame output and clone for memory safety
frame_output = {}
for k, v in outputs.items():
frame_output[k] = v[i : i + 1].clone()
if self.postprocess_fn:
try:
detections = self.postprocess_fn(frame_output)
except Exception as e:
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
detections = torch.zeros(
(0, 6), device=list(outputs.values())[0].device
)
else:
detections = frame_output
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
}
results.append(result)
return results
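# Example (usage sketch, not part of this commit): driving the controller by
# hand. In this codebase it is normally created by StreamConnectionManager; the
# model id, engine path, and stream id below are placeholders, and the
# repository must hold a loaded engine before the controller is started.
from services import TensorRTModelController, TensorRTModelRepository

repo = TensorRTModelRepository(gpu_id=0, enable_pt_conversion=True)
repo.load_model("yolov8n", "models/yolov8n.engine", num_contexts=4)

controller = TensorRTModelController(
    model_repository=repo,
    model_id="yolov8n",
    batch_size=16,
    force_timeout=0.05,
)
controller.register_callback("camera_1", lambda r: print(r["stream_id"], r["timestamp"]))
controller.start()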

ultralytics_exporter.py (new file)

@ -0,0 +1,222 @@
"""
Ultralytics YOLO Model Exporter with Caching
Exports YOLO .pt models to TensorRT .engine format using Ultralytics library.
Provides proper NMS and postprocessing built into the engine.
Caches exported engines to avoid redundant exports.
"""
import hashlib
import json
import logging
from pathlib import Path
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
class UltralyticsExporter:
"""
Export YOLO models using Ultralytics with caching.
Features:
- Exports .pt models to TensorRT .engine format
- Caches exported engines by source file hash
- Saves metadata about exported models
- Reuses cached engines when available
"""
def __init__(self, cache_dir: str = ".ultralytics_cache"):
"""
Initialize exporter.
Args:
cache_dir: Directory for caching exported engines
"""
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Ultralytics exporter cache directory: {self.cache_dir}")
@staticmethod
def compute_file_hash(file_path: str) -> str:
"""
Compute SHA256 hash of a file.
Args:
file_path: Path to file
Returns:
Hexadecimal hash string
"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(65536), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
def export(
self,
model_path: str,
device: int = 0,
half: bool = False,
imgsz: int = 640,
batch: int = 1,
**export_kwargs,
) -> Tuple[str, str]:
"""
Export YOLO model to TensorRT engine with caching.
Args:
model_path: Path to .pt model file
device: GPU device ID
half: Use FP16 precision
imgsz: Input image size (default: 640)
batch: Maximum batch size for inference
**export_kwargs: Additional arguments for Ultralytics export
Returns:
Tuple of (engine_hash, engine_path)
Raises:
FileNotFoundError: If model file doesn't exist
RuntimeError: If export fails
"""
model_path = Path(model_path).resolve()
if not model_path.exists():
raise FileNotFoundError(f"Model file not found: {model_path}")
# Compute hash of source model
logger.info(f"Computing hash for {model_path}...")
model_hash = self.compute_file_hash(str(model_path))
logger.info(f"Model hash: {model_hash[:16]}...")
# Create export config hash (includes export parameters)
export_config = {
"model_hash": model_hash,
"device": device,
"half": half,
"imgsz": imgsz,
"batch": batch,
**export_kwargs,
}
config_str = json.dumps(export_config, sort_keys=True)
config_hash = hashlib.sha256(config_str.encode()).hexdigest()
# Check cache
cache_engine_path = self.cache_dir / f"{config_hash}.engine"
cache_metadata_path = self.cache_dir / f"{config_hash}_metadata.json"
if cache_engine_path.exists():
logger.info(f"Found cached engine: {cache_engine_path}")
logger.info(f"Reusing cached export (config hash: {config_hash[:16]}...)")
# Load and return metadata
if cache_metadata_path.exists():
with open(cache_metadata_path, "r") as f:
metadata = json.load(f)
logger.info(f"Cached engine metadata: {metadata}")
return config_hash, str(cache_engine_path)
# Export using Ultralytics
logger.info(f"Exporting YOLO model to TensorRT engine...")
logger.info(f" Source: {model_path}")
logger.info(f" Device: GPU {device}")
logger.info(f" Precision: {'FP16' if half else 'FP32'}")
logger.info(f" Image size: {imgsz}")
logger.info(f" Batch size: {batch}")
try:
from ultralytics import YOLO
# Load model
model = YOLO(str(model_path))
# Export to TensorRT
exported_path = model.export(
format="engine",
device=device,
half=half,
imgsz=imgsz,
batch=batch,
verbose=True,
**export_kwargs,
)
logger.info(f"Export complete: {exported_path}")
# Copy to cache
import shutil
shutil.copy(exported_path, cache_engine_path)
logger.info(f"Cached engine: {cache_engine_path}")
# Save metadata
metadata = {
"source_model": str(model_path),
"model_hash": model_hash,
"config_hash": config_hash,
"device": device,
"half": half,
"imgsz": imgsz,
"batch": batch,
"export_kwargs": export_kwargs,
"exported_path": str(exported_path),
"cached_path": str(cache_engine_path),
}
with open(cache_metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
logger.info(f"Saved metadata: {cache_metadata_path}")
return config_hash, str(cache_engine_path)
except Exception as e:
logger.error(f"Export failed: {e}")
raise RuntimeError(f"Failed to export YOLO model: {e}")
def get_cached_engine(self, model_path: str, **export_kwargs) -> Optional[str]:
"""
Get cached engine path if it exists.
Args:
model_path: Path to .pt model
**export_kwargs: Export parameters (must match cached export)
Returns:
Path to cached engine or None if not cached
"""
try:
model_path = Path(model_path).resolve()
if not model_path.exists():
return None
# Compute hashes
model_hash = self.compute_file_hash(str(model_path))
export_config = {"model_hash": model_hash, **export_kwargs}
config_str = json.dumps(export_config, sort_keys=True)
config_hash = hashlib.sha256(config_str.encode()).hexdigest()
cache_engine_path = self.cache_dir / f"{config_hash}.engine"
if cache_engine_path.exists():
return str(cache_engine_path)
return None
except Exception as e:
logger.warning(f"Failed to check cache: {e}")
return None
def clear_cache(self):
"""Clear all cached engines"""
import shutil
if self.cache_dir.exists():
shutil.rmtree(self.cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
logger.info("Cache cleared")

ultralytics_model_controller.py (new file)

@ -0,0 +1,217 @@
"""
Ultralytics Model Controller - YOLO inference with batched processing.
"""
import logging
from typing import Any, Callable, Dict, List, Optional
import torch
from .base_model_controller import BaseModelController, BatchFrame
logger = logging.getLogger(__name__)
class UltralyticsModelController(BaseModelController):
"""
Model controller for Ultralytics YOLO inference.
    Uses UltralyticsEngine, which wraps the Ultralytics YOLO model with a
    native TensorRT backend for GPU-accelerated inference.
"""
def __init__(
self,
inference_engine,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
# Auto-detect actual batch size from the YOLO engine
engine_batch_size = self._detect_engine_batch_size(inference_engine)
# If engine has fixed batch size, use it. Otherwise use user's batch_size
actual_batch_size = engine_batch_size if engine_batch_size > 0 else batch_size
super().__init__(
model_id=model_id,
batch_size=actual_batch_size,
force_timeout=force_timeout,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
)
self.inference_engine = inference_engine
self.engine_batch_size = engine_batch_size # Store for padding logic
if engine_batch_size > 0:
logger.info(
f"Ultralytics engine has fixed batch_size={engine_batch_size}, "
f"will pad batches to match"
)
else:
logger.info(
f"Ultralytics engine supports dynamic batching, "
f"using max batch_size={actual_batch_size}"
)
def _detect_engine_batch_size(self, inference_engine) -> int:
"""
Detect the batch size from Ultralytics engine.
Returns:
Fixed batch size (e.g., 2, 4, 8) or -1 for dynamic batching
"""
try:
# Get engine metadata
metadata = inference_engine.get_metadata()
# Check input shape for batch dimension
if "images" in metadata.input_shapes:
input_shape = metadata.input_shapes["images"]
batch_dim = input_shape[0]
if batch_dim > 0:
# Fixed batch size
return batch_dim
else:
# Dynamic batch size (-1)
return -1
# Fallback: try to get from model directly
if (
hasattr(inference_engine, "_model")
and inference_engine._model is not None
):
model = inference_engine._model
# Try to get batch info from Ultralytics model
if hasattr(model, "predictor") and model.predictor is not None:
predictor = model.predictor
if hasattr(predictor, "model") and hasattr(
predictor.model, "batch"
):
return predictor.model.batch
# Try to get from model.model (for .engine files)
if hasattr(model, "model"):
# For TensorRT engines, check input shape
if hasattr(model.model, "get_input_details"):
details = model.model.get_input_details()
if details and len(details) > 0:
shape = details[0].get("shape")
if shape and len(shape) > 0:
return shape[0] if shape[0] > 0 else -1
except Exception as e:
logger.warning(f"Could not detect engine batch size: {e}")
# Default: assume dynamic batching
return -1
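    # Illustration (comment only): an engine exported with batch=4 reports
    # input_shapes == {"images": (4, 3, 640, 640)}, i.e. a fixed batch of 4,
    # so every batch is padded or truncated to exactly 4 frames; a dynamic
    # export reports a leading -1 and batches grow up to batch_size instead.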
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run Ultralytics YOLO inference on a batch of frames.
Ultralytics handles batching natively and returns Results objects.
"""
# Preprocess frames
preprocessed = []
for batch_frame in batch:
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
# Ensure shape is (C, H, W) not (1, C, H, W)
if processed.dim() == 4 and processed.shape[0] == 1:
processed = processed.squeeze(0)
else:
processed = batch_frame.frame
preprocessed.append(processed)
# Stack into batch tensor: (B, C, H, W)
batch_tensor = torch.stack(preprocessed, dim=0)
actual_batch_size = len(batch)
# Handle fixed batch size engines (pad if needed)
if self.engine_batch_size > 0:
# Engine has fixed batch size
if batch_tensor.shape[0] > self.engine_batch_size:
# Truncate to engine's max batch size
logger.warning(
f"Batch size {batch_tensor.shape[0]} exceeds engine max {self.engine_batch_size}, truncating"
)
batch_tensor = batch_tensor[: self.engine_batch_size]
batch = batch[: self.engine_batch_size]
actual_batch_size = self.engine_batch_size
elif batch_tensor.shape[0] < self.engine_batch_size:
# Pad to match engine's fixed batch size
padding_size = self.engine_batch_size - batch_tensor.shape[0]
                # Pad by replicating the last frame; padded outputs are discarded after inference
padding = batch_tensor[-1:].repeat(padding_size, 1, 1, 1)
batch_tensor = torch.cat([batch_tensor, padding], dim=0)
logger.debug(
f"Padded batch from {actual_batch_size} to {self.engine_batch_size} frames"
)
else:
# Dynamic batching - just limit to max
if batch_tensor.shape[0] > self.batch_size:
logger.warning(
f"Batch size {batch_tensor.shape[0]} exceeds configured max {self.batch_size}"
)
batch_tensor = batch_tensor[: self.batch_size]
batch = batch[: self.batch_size]
actual_batch_size = self.batch_size
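        # Worked example (comment only): with engine_batch_size == 4 and three
        # live frames, the last frame is replicated once to make a (4, 3, H, W)
        # batch and only the first three results are kept (actual_batch_size
        # stays 3); with six live frames the batch is truncated to the first 4.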
# Run Ultralytics inference
# Input should be (B, 3, H, W) in range [0, 1], RGB format
outputs = self.inference_engine.infer(
inputs={"images": batch_tensor},
conf=0.25, # Confidence threshold
iou=0.45, # NMS IoU threshold
)
# Ultralytics returns Results objects in outputs["results"]
yolo_results = outputs["results"]
# Convert Results objects to our standard format
# Only process actual batch size (ignore padded results if any)
results = []
for i in range(actual_batch_size):
batch_frame = batch[i]
yolo_result = yolo_results[i]
# Extract detections from YOLO Results object
# yolo_result.boxes.data has format: [x1, y1, x2, y2, conf, cls]
if hasattr(yolo_result, "boxes") and yolo_result.boxes is not None:
detections = yolo_result.boxes.data # Already a tensor on GPU
else:
# No detections
detections = torch.zeros((0, 6), device=batch_tensor.device)
# Apply custom postprocessing if provided
if self.postprocess_fn:
try:
# For Ultralytics, postprocess_fn might do additional filtering
# Pass the raw boxes tensor in the same format as TensorRT output
detections = self.postprocess_fn(
{
"output0": detections.unsqueeze(
0
) # Add batch dim for compatibility
}
)
except Exception as e:
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
detections = torch.zeros((0, 6), device=batch_tensor.device)
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
"yolo_result": yolo_result, # Keep original Results object for debugging
}
results.append(result)
return results
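# Example (usage sketch, not part of this commit): pairing the controller with
# UltralyticsEngine directly. In this codebase the wiring is normally done by
# StreamConnectionManager(backend="ultralytics"); the model path is a
# placeholder and ultralytics plus a CUDA GPU are assumed.
import torch

from services import UltralyticsEngine, UltralyticsModelController

engine = UltralyticsEngine()
engine.initialize(
    model_path="models/yolov8n.pt",
    device=torch.device("cuda:0"),
    batch=4,
    half=False,
    imgsz=640,
)
controller = UltralyticsModelController(
    inference_engine=engine,
    model_id="yolov8n",
    batch_size=4,
    force_timeout=0.05,
)
controller.start()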

yolo.py

@ -101,14 +101,15 @@ class YOLOv8Utils:
# Get output tensor (first and only output)
output_name = list(outputs.keys())[0]
output = outputs[output_name] # (1, 84, 8400)
output = outputs[output_name] # (1, 4+num_classes, 8400)
# Transpose to (1, 8400, 84) for easier processing
output = output.transpose(1, 2).squeeze(0) # (8400, 84)
# Transpose to (1, 8400, 4+num_classes) for easier processing
output = output.transpose(1, 2).squeeze(0) # (8400, 4+num_classes)
# Split bbox coordinates and class scores (vectorized)
# Format: [cx, cy, w, h, class_scores...]
bboxes = output[:, :4] # (8400, 4) - (cx, cy, w, h)
class_scores = output[:, 4:] # (8400, 80)
class_scores = output[:, 4:] # (8400, num_classes) - dynamically sized
# Get max class score and corresponding class ID for all anchors (vectorized)
max_scores, class_ids = torch.max(class_scores, dim=1) # (8400,), (8400,)
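        # Worked example (comment only): for an 80-class COCO model the raw output
        # is (1, 84, 8400); after the transpose and squeeze it is (8400, 84), giving
        # bboxes of shape (8400, 4) and class_scores of shape (8400, 80). A 2-class
        # model would instead produce (1, 6, 8400) -> (8400, 6) with (8400, 2) scores.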