""" TensorRT Model Controller - Native TensorRT inference with batched processing. """ import logging from typing import Any, Callable, Dict, List, Optional import torch from .base_model_controller import BaseModelController, BatchFrame logger = logging.getLogger(__name__) class TensorRTModelController(BaseModelController): """ Model controller for native TensorRT inference. Uses TensorRTModelRepository for GPU-accelerated inference with context pooling and deduplication. """ def __init__( self, model_repository, model_id: str, batch_size: int = 16, max_queue_size: int = 100, preprocess_fn: Optional[Callable] = None, postprocess_fn: Optional[Callable] = None, ): super().__init__( model_id=model_id, batch_size=batch_size, max_queue_size=max_queue_size, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn, ) self.model_repository = model_repository # Detect model's actual batch size from input shape self.model_batch_size = self._detect_model_batch_size() if self.model_batch_size == 1: logger.warning( f"Model '{model_id}' has fixed batch_size=1. " f"Will process frames sequentially." ) else: logger.info( f"Model '{model_id}' supports batch_size={self.model_batch_size}" ) def _detect_model_batch_size(self) -> int: """Detect the model's batch size from its input shape""" try: metadata = self.model_repository.get_metadata(self.model_id) first_input_name = metadata.input_names[0] input_shape = metadata.input_shapes[first_input_name] batch_dim = input_shape[0] if batch_dim == -1: return self.batch_size # Dynamic batch size else: return batch_dim # Fixed batch size except Exception as e: logger.warning( f"Could not detect model batch size: {e}. Assuming batch_size=1" ) return 1 def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]: """Run TensorRT inference on a batch of frames""" if self.model_batch_size == 1: return self._run_sequential_inference(batch) else: return self._run_batched_inference(batch) def _run_sequential_inference( self, batch: List[BatchFrame] ) -> List[Dict[str, Any]]: """Run inference sequentially for batch_size=1 models""" results = [] for batch_frame in batch: # Preprocess frame if self.preprocess_fn: processed = self.preprocess_fn(batch_frame.frame) else: processed = ( batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame ) # Run inference outputs = self.model_repository.infer( self.model_id, {"images": processed}, synchronize=True ) # Postprocess if self.postprocess_fn: try: detections = self.postprocess_fn(outputs) except Exception as e: logger.error( f"Error in postprocess for stream {batch_frame.stream_id}: {e}" ) detections = torch.zeros( (0, 6), device=list(outputs.values())[0].device ) else: detections = outputs result = { "stream_id": batch_frame.stream_id, "timestamp": batch_frame.timestamp, "detections": detections, "frame": batch_frame.frame, # Include original frame tensor "metadata": batch_frame.metadata, } results.append(result) return results def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]: """Run true batched inference for models that support it""" # Preprocess frames preprocessed = [] for batch_frame in batch: if self.preprocess_fn: processed = self.preprocess_fn(batch_frame.frame) if processed.dim() == 4 and processed.shape[0] == 1: processed = processed.squeeze(0) else: processed = batch_frame.frame preprocessed.append(processed) # Stack into batch tensor batch_tensor = torch.stack(preprocessed, dim=0) # Limit to model's max batch size if batch_tensor.shape[0] 
> self.model_batch_size: logger.warning( f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}" ) batch_tensor = batch_tensor[: self.model_batch_size] batch = batch[: self.model_batch_size] # Run inference outputs = self.model_repository.infer( self.model_id, {"images": batch_tensor}, synchronize=True ) # Postprocess results (split batch back to individual results) results = [] for i, batch_frame in enumerate(batch): # Extract single frame output and clone for memory safety frame_output = {} for k, v in outputs.items(): frame_output[k] = v[i : i + 1].clone() if self.postprocess_fn: try: detections = self.postprocess_fn(frame_output) except Exception as e: logger.error( f"Error in postprocess for stream {batch_frame.stream_id}: {e}" ) detections = torch.zeros( (0, 6), device=list(outputs.values())[0].device ) else: detections = frame_output result = { "stream_id": batch_frame.stream_id, "timestamp": batch_frame.timestamp, "detections": detections, "frame": batch_frame.frame, # Include original frame tensor "metadata": batch_frame.metadata, } results.append(result) return results
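

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the controller API).
# This is a minimal, self-contained example of the batched code path. It
# assumes that BatchFrame is constructible with the fields used above
# (stream_id, timestamp, frame, metadata), that BaseModelController stores
# batch_size and the pre/postprocess callables at construction, and it
# substitutes a hypothetical fake repository so no TensorRT engine is needed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import time
    from types import SimpleNamespace

    class _FakeRepository:
        """Stand-in for TensorRTModelRepository with a dynamic-batch model."""

        def get_metadata(self, model_id):
            # Dynamic batch dimension (-1) lets the controller use its own batch_size
            return SimpleNamespace(
                input_names=["images"],
                input_shapes={"images": (-1, 3, 640, 640)},
            )

        def infer(self, model_id, inputs, synchronize=True):
            # Fake detector output: (batch, num_detections, 6)
            batch = inputs["images"].shape[0]
            return {"output": torch.zeros((batch, 10, 6))}

    controller = TensorRTModelController(
        model_repository=_FakeRepository(),
        model_id="fake_detector",
        batch_size=4,
        # Per-frame postprocess receives a dict of (1, ...) tensors; here we
        # simply strip the batch dimension of the fake "output" tensor.
        postprocess_fn=lambda out: out["output"][0],
    )

    frames = [
        BatchFrame(
            stream_id=f"cam{i}",
            timestamp=time.time(),
            frame=torch.rand(3, 640, 640),
            metadata={},
        )
        for i in range(4)
    ]

    for result in controller._run_batch_inference(frames):
        print(result["stream_id"], result["detections"].shape)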