"""
|
|
TensorRT Model Controller - Native TensorRT inference with batched processing.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
import torch
|
|
|
|
from .base_model_controller import BaseModelController, BatchFrame
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TensorRTModelController(BaseModelController):
    """
    Model controller for native TensorRT inference.

    Uses TensorRTModelRepository for GPU-accelerated inference with
    context pooling and deduplication.
    """

    def __init__(
        self,
        model_repository,
        model_id: str,
        batch_size: int = 16,
        max_queue_size: int = 100,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
        super().__init__(
            model_id=model_id,
            batch_size=batch_size,
            max_queue_size=max_queue_size,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
        self.model_repository = model_repository

        # Detect model's actual batch size from input shape
        self.model_batch_size = self._detect_model_batch_size()
        if self.model_batch_size == 1:
            logger.warning(
                f"Model '{model_id}' has fixed batch_size=1. "
                f"Will process frames sequentially."
            )
        else:
            logger.info(
                f"Model '{model_id}' supports batch_size={self.model_batch_size}"
            )

    def _detect_model_batch_size(self) -> int:
        """Detect the model's batch size from its input shape"""
        try:
            metadata = self.model_repository.get_metadata(self.model_id)
            first_input_name = metadata.input_names[0]
            input_shape = metadata.input_shapes[first_input_name]
            batch_dim = input_shape[0]

            if batch_dim == -1:
                return self.batch_size  # Dynamic batch size
            else:
                return batch_dim  # Fixed batch size
        except Exception as e:
            logger.warning(
                f"Could not detect model batch size: {e}. Assuming batch_size=1"
            )
            return 1

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run TensorRT inference on a batch of frames"""
        if self.model_batch_size == 1:
            return self._run_sequential_inference(batch)
        else:
            return self._run_batched_inference(batch)

    def _run_sequential_inference(
        self, batch: List[BatchFrame]
    ) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models"""
        results = []

        for batch_frame in batch:
            # Preprocess frame
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                processed = (
                    batch_frame.frame.unsqueeze(0)
                    if batch_frame.frame.dim() == 3
                    else batch_frame.frame
                )

            # Run inference
            outputs = self.model_repository.infer(
                self.model_id, {"images": processed}, synchronize=True
            )

            # Postprocess
            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = outputs

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "frame": batch_frame.frame,  # Include original frame tensor
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run true batched inference for models that support it"""
        # Preprocess frames
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into batch tensor
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Limit to model's max batch size (frames beyond the limit are dropped
        # from this batch and will not receive results)
        if batch_tensor.shape[0] > self.model_batch_size:
            logger.warning(
                f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}"
            )
            batch_tensor = batch_tensor[: self.model_batch_size]
            batch = batch[: self.model_batch_size]

        # Run inference
        outputs = self.model_repository.infer(
            self.model_id, {"images": batch_tensor}, synchronize=True
        )

        # Postprocess results (split batch back to individual results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract single frame output and clone for memory safety
            frame_output = {}
            for k, v in outputs.items():
                frame_output[k] = v[i : i + 1].clone()

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "frame": batch_frame.frame,  # Include original frame tensor
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results
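

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the controller implementation).
# A minimal, hedged example of exercising the controller in isolation. It
# assumes only what the code above already relies on:
#   * `BatchFrame` can be constructed with the fields read by this class
#     (stream_id, timestamp, frame, metadata);
#   * the repository only needs the `get_metadata()` / `infer()` calls used
#     above, so a small stand-in object is enough for a dry run.
# A real deployment would pass a TensorRTModelRepository and let the base
# controller's queueing drive `_run_batch_inference()`; here it is called
# directly for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import time
    from types import SimpleNamespace

    class _StubRepository:
        """Stand-in mimicking the two repository calls this controller makes."""

        def get_metadata(self, model_id: str):
            # Report a dynamic batch dimension (-1) for a 640x640 RGB input.
            return SimpleNamespace(
                input_names=["images"],
                input_shapes={"images": (-1, 3, 640, 640)},
            )

        def infer(
            self,
            model_id: str,
            inputs: Dict[str, torch.Tensor],
            synchronize: bool = True,
        ) -> Dict[str, torch.Tensor]:
            # Return an empty dummy detection tensor per frame in the batch.
            batch = inputs["images"].shape[0]
            return {"output": torch.zeros((batch, 0, 6))}

    controller = TensorRTModelController(
        model_repository=_StubRepository(),
        model_id="demo_model",  # hypothetical model id
        batch_size=4,
    )

    frames = [
        BatchFrame(
            stream_id=f"camera_{i}",
            timestamp=time.time(),
            frame=torch.rand(3, 640, 640),
            metadata={},
        )
        for i in range(2)
    ]

    # Without a postprocess_fn, "detections" is the raw per-frame output dict.
    for result in controller._run_batch_inference(frames):
        print(result["stream_id"], result["detections"]["output"].shape)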