ultralytic export
This commit is contained in:
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions
182 services/tensorrt_model_controller.py Normal file
@@ -0,0 +1,182 @@
"""
TensorRT Model Controller - Native TensorRT inference with batched processing.
"""

import logging
from typing import Any, Callable, Dict, List, Optional

import torch

from .base_model_controller import BaseModelController, BatchFrame

logger = logging.getLogger(__name__)


class TensorRTModelController(BaseModelController):
    """
    Model controller for native TensorRT inference.

    Uses TensorRTModelRepository for GPU-accelerated inference with
    context pooling and deduplication.
    """

    def __init__(
        self,
        model_repository,
        model_id: str,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
        super().__init__(
            model_id=model_id,
            batch_size=batch_size,
            force_timeout=force_timeout,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn,
        )
        self.model_repository = model_repository

        # Detect model's actual batch size from input shape
        self.model_batch_size = self._detect_model_batch_size()
        if self.model_batch_size == 1:
            logger.warning(
                f"Model '{model_id}' has fixed batch_size=1. "
                f"Will process frames sequentially."
            )
        else:
            logger.info(
                f"Model '{model_id}' supports batch_size={self.model_batch_size}"
            )

    def _detect_model_batch_size(self) -> int:
        """Detect the model's batch size from its input shape"""
        try:
            metadata = self.model_repository.get_metadata(self.model_id)
            first_input_name = metadata.input_names[0]
            input_shape = metadata.input_shapes[first_input_name]
            batch_dim = input_shape[0]

            if batch_dim == -1:
                return self.batch_size  # Dynamic batch size
            else:
                return batch_dim  # Fixed batch size
        except Exception as e:
            logger.warning(
                f"Could not detect model batch size: {e}. Assuming batch_size=1"
            )
            return 1

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run TensorRT inference on a batch of frames"""
        if self.model_batch_size == 1:
            return self._run_sequential_inference(batch)
        else:
            return self._run_batched_inference(batch)

    def _run_sequential_inference(
        self, batch: List[BatchFrame]
    ) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models"""
        results = []

        for batch_frame in batch:
            # Preprocess frame
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                processed = (
                    batch_frame.frame.unsqueeze(0)
                    if batch_frame.frame.dim() == 3
                    else batch_frame.frame
                )

            # Run inference
            outputs = self.model_repository.infer(
                self.model_id, {"images": processed}, synchronize=True
            )

            # Postprocess
            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = outputs

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run true batched inference for models that support it"""
        # Preprocess frames
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into batch tensor
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Limit to model's max batch size
        if batch_tensor.shape[0] > self.model_batch_size:
            logger.warning(
                f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}"
            )
            batch_tensor = batch_tensor[: self.model_batch_size]
            batch = batch[: self.model_batch_size]

        # Run inference
        outputs = self.model_repository.infer(
            self.model_id, {"images": batch_tensor}, synchronize=True
        )

        # Postprocess results (split batch back to individual results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract single frame output and clone for memory safety
            frame_output = {}
            for k, v in outputs.items():
                frame_output[k] = v[i : i + 1].clone()

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results
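For reference, the postprocess_fn hook is expected to turn the raw output dict returned by model_repository.infer() into a detections tensor; the torch.zeros((0, 6)) fallback above suggests rows of [x1, y1, x2, y2, conf, cls]. A minimal sketch of such a function follows; the output key name, tensor layout, and confidence threshold are illustrative assumptions, and a real Ultralytics export would typically also need box decoding and NMS on top of this.

from typing import Dict

import torch


def simple_postprocess(outputs: Dict[str, torch.Tensor], conf_threshold: float = 0.25) -> torch.Tensor:
    """Hypothetical postprocess_fn: raw output dict -> (N, 6) detections.

    Assumes a single output whose last dimension already holds
    [x1, y1, x2, y2, conf, cls]; adjust for the actual engine layout.
    """
    raw = next(iter(outputs.values()))        # e.g. shape (1, num_boxes, 6)
    dets = raw.reshape(-1, raw.shape[-1])     # drop the batch dimension
    return dets[dets[:, 4] >= conf_threshold]  # keep confident rows only


# Quick sanity check with dummy data
dummy = {"output0": torch.rand(1, 100, 6)}
print(simple_postprocess(dummy).shape)        # torch.Size([K, 6])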
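Below is a standalone illustration of the stack / infer / split pattern that _run_batched_inference implements, with a dummy output dict standing in for the repository call; the frame shapes and the output key name are assumptions for illustration only.

import torch

# Preprocessed CHW frames from several streams (shapes are illustrative)
frames = [torch.rand(3, 640, 640) for _ in range(4)]

# Stack into a single (B, C, H, W) batch tensor, as _run_batched_inference does
batch_tensor = torch.stack(frames, dim=0)             # (4, 3, 640, 640)

# Stand-in for model_repository.infer(): one batched output tensor
outputs = {"output0": torch.rand(batch_tensor.shape[0], 100, 6)}

# Split the batched outputs back into per-frame dicts, cloning each slice so
# later reuse of the output buffers cannot corrupt results handed downstream
per_frame = [
    {k: v[i : i + 1].clone() for k, v in outputs.items()}
    for i in range(batch_tensor.shape[0])
]
print(len(per_frame), per_frame[0]["output0"].shape)  # 4 torch.Size([1, 100, 6])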