# python-rtsp-worker/services/tensorrt_model_controller.py

"""
TensorRT Model Controller - Native TensorRT inference with batched processing.
"""
import logging
from typing import Any, Callable, Dict, List, Optional
import torch
from .base_model_controller import BaseModelController, BatchFrame
logger = logging.getLogger(__name__)
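
# Example wiring (illustrative sketch; repository construction details and the
# model id are assumptions, not the actual API of TensorRTModelRepository):
#
#     repo = TensorRTModelRepository(...)           # hypothetical construction
#     controller = TensorRTModelController(
#         model_repository=repo,
#         model_id="detector",                      # hypothetical model id
#         batch_size=16,
#         preprocess_fn=my_preprocess,              # hypothetical callable
#         postprocess_fn=my_postprocess,            # hypothetical callable
#     )
#     # Frames are queued and results delivered through the BaseModelController
#     # interface (not shown here).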


class TensorRTModelController(BaseModelController):
    """
    Model controller for native TensorRT inference.

    Uses TensorRTModelRepository for GPU-accelerated inference with
    context pooling and deduplication.
    """

    def __init__(
        self,
        model_repository,
        model_id: str,
        batch_size: int = 16,
        max_queue_size: int = 100,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
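        """
        Args:
            model_repository: TensorRTModelRepository used to run inference.
            model_id: Identifier of the model within the repository.
            batch_size: Maximum number of frames per inference batch
                (forwarded to BaseModelController).
            max_queue_size: Queue capacity forwarded to BaseModelController.
            preprocess_fn: Optional callable applied to each frame before inference.
            postprocess_fn: Optional callable applied to raw model outputs.
        """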
        super().__init__(
            model_id=model_id,
            batch_size=batch_size,
            max_queue_size=max_queue_size,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
        self.model_repository = model_repository

        # Detect model's actual batch size from input shape
        self.model_batch_size = self._detect_model_batch_size()
        if self.model_batch_size == 1:
            logger.warning(
                f"Model '{model_id}' has fixed batch_size=1. "
                f"Will process frames sequentially."
            )
        else:
            logger.info(
                f"Model '{model_id}' supports batch_size={self.model_batch_size}"
            )

    def _detect_model_batch_size(self) -> int:
        """Detect the model's batch size from its input shape."""
        try:
            metadata = self.model_repository.get_metadata(self.model_id)
            first_input_name = metadata.input_names[0]
            input_shape = metadata.input_shapes[first_input_name]
            batch_dim = input_shape[0]
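            # Engines built with dynamic shapes typically report -1 for the
            # batch dimension; fixed-shape engines report the concrete value.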
            if batch_dim == -1:
                return self.batch_size  # Dynamic batch size
            else:
                return batch_dim  # Fixed batch size
        except Exception as e:
            logger.warning(
                f"Could not detect model batch size: {e}. Assuming batch_size=1"
            )
            return 1

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run TensorRT inference on a batch of frames."""
        if self.model_batch_size == 1:
            return self._run_sequential_inference(batch)
        else:
            return self._run_batched_inference(batch)

    def _run_sequential_inference(
        self, batch: List[BatchFrame]
    ) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models."""
        results = []
        for batch_frame in batch:
            # Preprocess frame
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                processed = (
                    batch_frame.frame.unsqueeze(0)
                    if batch_frame.frame.dim() == 3
                    else batch_frame.frame
                )
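
            # NOTE: the engine is expected to expose an input binding named
            # "images" (a common convention for YOLO-style exports).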
            # Run inference
            outputs = self.model_repository.infer(
                self.model_id, {"images": processed}, synchronize=True
            )

            # Postprocess
            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
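                    # Fall back to an empty (0, 6) detections tensor so the
                    # stream still receives a result (assumed
                    # [x1, y1, x2, y2, score, class_id] layout).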
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = outputs

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "frame": batch_frame.frame,  # Include original frame tensor
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run true batched inference for models that support it."""
        # Preprocess frames
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)
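
        # All preprocessed frames must share the same shape and dtype here;
        # torch.stack raises an error otherwise.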
        # Stack into batch tensor
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Limit to the model's max batch size
        if batch_tensor.shape[0] > self.model_batch_size:
            logger.warning(
                f"Batch size {batch_tensor.shape[0]} exceeds model max "
                f"{self.model_batch_size}; excess frames will be dropped"
            )
            batch_tensor = batch_tensor[: self.model_batch_size]
            batch = batch[: self.model_batch_size]

        # Run inference
        outputs = self.model_repository.infer(
            self.model_id, {"images": batch_tensor}, synchronize=True
        )

        # Postprocess results (split batch back to individual results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract single frame output and clone for memory safety
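            # (cloning keeps each per-frame slice valid even if the repository
            # reuses its output buffers for a later inference; defensive assumption)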
            frame_output = {}
            for k, v in outputs.items():
                frame_output[k] = v[i : i + 1].clone()

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "frame": batch_frame.frame,  # Include original frame tensor
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results