""" ModelController - Async batching layer with ping-pong buffers for inference. This module provides batched inference coordination using ping-pong circular buffers with force-switch timeout mechanism. """ import asyncio import torch from typing import Dict, List, Optional, Callable, Any from dataclasses import dataclass, field from enum import Enum import time import logging logger = logging.getLogger(__name__) @dataclass class BatchFrame: """Represents a frame in the batch buffer""" stream_id: str frame: torch.Tensor # GPU tensor (3, H, W) timestamp: float metadata: Dict = field(default_factory=dict) class BufferState(Enum): """State of a ping-pong buffer""" IDLE = "idle" FILLING = "filling" PROCESSING = "processing" class ModelController: """ Manages batched inference with ping-pong buffers and force-switch timeout. This controller accumulates frames from multiple streams into batches, processes them through a model repository, and routes results back to stream-specific callbacks. Features: - Ping-pong circular buffers (BufferA/BufferB) - Force-switch timeout to prevent batch starvation - Async event-driven processing - Thread-safe frame submission Args: model_repository: TensorRT model repository for inference model_id: Model identifier in the repository batch_size: Maximum frames per batch (default: 16) force_timeout: Max wait time before forcing buffer switch in seconds (default: 0.05) preprocess_fn: Optional preprocessing function for frames postprocess_fn: Optional postprocessing function for model outputs """ def __init__( self, model_repository, model_id: str, batch_size: int = 16, force_timeout: float = 0.05, preprocess_fn: Optional[Callable] = None, postprocess_fn: Optional[Callable] = None, ): self.model_repository = model_repository self.model_id = model_id self.batch_size = batch_size self.force_timeout = force_timeout self.preprocess_fn = preprocess_fn self.postprocess_fn = postprocess_fn # Detect model's actual batch size from input shape self.model_batch_size = self._detect_model_batch_size() if self.model_batch_size == 1: logger.warning( f"Model '{model_id}' has fixed batch_size=1. " f"Will process frames sequentially. Consider rebuilding model with dynamic batching." ) else: logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}") # Ping-pong buffers self.buffer_a: List[BatchFrame] = [] self.buffer_b: List[BatchFrame] = [] # Buffer states self.active_buffer = "A" # Which buffer is currently active (filling) self.buffer_a_state = BufferState.IDLE self.buffer_b_state = BufferState.IDLE # Async coordination self.buffer_lock = asyncio.Lock() self.last_submit_time = time.time() # Tasks self.timeout_task: Optional[asyncio.Task] = None self.processor_task: Optional[asyncio.Task] = None self.running = False # Result callbacks (stream_id -> callback) self.result_callbacks: Dict[str, Callable] = {} # Statistics self.total_frames_processed = 0 self.total_batches_processed = 0 def _detect_model_batch_size(self) -> int: """ Detect the model's batch size from its input shape. 
        Returns:
            Maximum batch size supported by the model: the fixed batch dimension
            for fixed-shape models, or the configured batch_size for models with
            a dynamic batch dimension.
        """
        try:
            metadata = self.model_repository.get_metadata(self.model_id)
            # Get first input tensor shape
            first_input = list(metadata.inputs.values())[0]
            batch_dim = first_input["shape"][0]

            # batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
            if batch_dim == -1:
                # Dynamic batch size - use user-specified batch_size
                return self.batch_size
            else:
                # Fixed batch size
                return batch_dim
        except Exception as e:
            logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
            return 1

    async def start(self):
        """Start the controller background tasks"""
        if self.running:
            logger.warning("ModelController already running")
            return

        self.running = True
        self.timeout_task = asyncio.create_task(self._timeout_monitor())
        self.processor_task = asyncio.create_task(self._batch_processor())
        logger.info("ModelController started")

    async def stop(self):
        """Stop the controller and clean up"""
        if not self.running:
            return

        logger.info("Stopping ModelController...")
        self.running = False

        # Cancel background tasks
        if self.timeout_task:
            self.timeout_task.cancel()
            try:
                await self.timeout_task
            except asyncio.CancelledError:
                pass

        if self.processor_task:
            self.processor_task.cancel()
            try:
                await self.processor_task
            except asyncio.CancelledError:
                pass

        # Process any remaining frames
        await self._process_remaining_buffers()
        logger.info("ModelController stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """
        Register a callback for inference results from a stream.

        Args:
            stream_id: Unique stream identifier
            callback: Callback function to receive results (can be sync or async)
        """
        self.result_callbacks[stream_id] = callback
        logger.debug(f"Registered callback for stream: {stream_id}")

    def unregister_callback(self, stream_id: str):
        """
        Unregister a stream callback.

        Args:
            stream_id: Stream identifier to unregister
        """
        self.result_callbacks.pop(stream_id, None)
        logger.debug(f"Unregistered callback for stream: {stream_id}")

    async def submit_frame(
        self,
        stream_id: str,
        frame: torch.Tensor,
        metadata: Optional[Dict] = None,
    ):
        """
        Submit a frame for batched inference.

        Args:
            stream_id: Unique stream identifier
            frame: GPU tensor (3, H, W) or (C, H, W)
            metadata: Optional metadata to attach to the frame
        """
        async with self.buffer_lock:
            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {},
            )

            # Add to active buffer
            if self.active_buffer == "A":
                self.buffer_a.append(batch_frame)
                self.buffer_a_state = BufferState.FILLING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b.append(batch_frame)
                self.buffer_b_state = BufferState.FILLING
                buffer_size = len(self.buffer_b)

            self.last_submit_time = time.time()

            # Check if we should immediately swap (batch full)
            if buffer_size >= self.batch_size:
                await self._try_swap_buffers()

    async def _timeout_monitor(self):
        """Monitor the force-switch timeout"""
        while self.running:
            await asyncio.sleep(0.01)  # Check every 10ms

            async with self.buffer_lock:
                time_since_submit = time.time() - self.last_submit_time

                # Check if timeout expired and we have frames waiting
                if time_since_submit >= self.force_timeout:
                    active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
                    if len(active_buffer) > 0:
                        await self._try_swap_buffers()

    async def _try_swap_buffers(self):
        """
        Attempt to swap ping-pong buffers.

        Only swaps if the inactive buffer is not currently processing.
        This method must be called with buffer_lock held.
""" # Check if inactive buffer is available inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state if inactive_state != BufferState.PROCESSING: # Swap active buffer old_active = self.active_buffer self.active_buffer = "B" if old_active == "A" else "A" # Mark old active buffer as ready for processing if old_active == "A": self.buffer_a_state = BufferState.PROCESSING buffer_size = len(self.buffer_a) else: self.buffer_b_state = BufferState.PROCESSING buffer_size = len(self.buffer_b) logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})") async def _batch_processor(self): """Background task that processes batches when available""" while self.running: await asyncio.sleep(0.001) # Check every 1ms # Check if buffer A needs processing if self.buffer_a_state == BufferState.PROCESSING: await self._process_buffer("A") # Check if buffer B needs processing if self.buffer_b_state == BufferState.PROCESSING: await self._process_buffer("B") async def _process_buffer(self, buffer_name: str): """ Process a buffer through inference. Args: buffer_name: "A" or "B" """ # Extract buffer to process async with self.buffer_lock: if buffer_name == "A": batch = self.buffer_a.copy() self.buffer_a.clear() else: batch = self.buffer_b.copy() self.buffer_b.clear() if len(batch) == 0: # Mark as idle and return async with self.buffer_lock: if buffer_name == "A": self.buffer_a_state = BufferState.IDLE else: self.buffer_b_state = BufferState.IDLE return # Process batch (outside lock to allow concurrent submissions) try: start_time = time.time() results = await self._run_batch_inference(batch) inference_time = time.time() - start_time # Update statistics self.total_frames_processed += len(batch) self.total_batches_processed += 1 logger.debug( f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms " f"({inference_time*1000/len(batch):.2f}ms per frame)" ) # Emit results to callbacks for batch_frame, result in zip(batch, results): callback = self.result_callbacks.get(batch_frame.stream_id) if callback: # Schedule callback asynchronously if asyncio.iscoroutinefunction(callback): asyncio.create_task(callback(result)) else: # Run sync callback in executor to avoid blocking loop = asyncio.get_event_loop() loop.call_soon(lambda cb=callback, r=result: cb(r)) except Exception as e: logger.error(f"Error processing batch: {e}", exc_info=True) # TODO: Emit error events to streams finally: # Mark buffer as idle async with self.buffer_lock: if buffer_name == "A": self.buffer_a_state = BufferState.IDLE else: self.buffer_b_state = BufferState.IDLE async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]: """ Run inference on a batch of frames. 
        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame)
        """
        loop = asyncio.get_running_loop()

        # Check if model supports batching
        if self.model_batch_size == 1:
            # Process frames one at a time for batch_size=1 models
            return await self._run_sequential_inference(batch, loop)
        else:
            # Use true batching for models that support it
            return await self._run_batched_inference(batch, loop)

    async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models"""
        results = []

        for batch_frame in batch:
            # Preprocess frame
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                # Ensure we have a batch dimension
                processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame

            # Run inference for this frame
            outputs = await loop.run_in_executor(
                None,
                lambda p=processed: self.model_repository.infer(
                    self.model_id,
                    {"images": p},
                    synchronize=True,
                ),
            )

            # Postprocess
            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
                    # Return empty detections on error
                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
            else:
                detections = outputs

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
        """Run true batched inference for models that support it"""
        # Preprocess frames (on GPU)
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                # Preprocess may return (1, C, H, W); squeeze to (C, H, W)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into batch tensor: (N, C, H, W)
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Limit batch size to model's max batch size
        if batch_tensor.shape[0] > self.model_batch_size:
            logger.warning(
                f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}; "
                f"truncating batch (excess frames are dropped)"
            )
            # TODO: Split oversized batches into sub-batches instead of truncating
            batch_tensor = batch_tensor[:self.model_batch_size]
            batch = batch[:self.model_batch_size]

        # Run inference
        outputs = await loop.run_in_executor(
            None,
            lambda: self.model_repository.infer(
                self.model_id,
                {"images": batch_tensor},
                synchronize=True,
            ),
        )

        # Postprocess results (split batch back to individual results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract single frame output from batch
            frame_output = {}
            for k, v in outputs.items():
                # v has shape (N, ...); extract index i and keep the batch dimension
                frame_output[k] = v[i:i+1]  # Shape: (1, ...)

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
                    # Return empty detections on error
                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    async def _process_remaining_buffers(self):
        """Process any remaining frames in buffers during shutdown"""
        if len(self.buffer_a) > 0:
            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
            await self._process_buffer("A")
        if len(self.buffer_b) > 0:
            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
            await self._process_buffer("B")

    def get_stats(self) -> Dict[str, Any]:
        """Get current buffer statistics"""
        return {
            "active_buffer": self.active_buffer,
            "buffer_a_size": len(self.buffer_a),
            "buffer_b_size": len(self.buffer_b),
            "buffer_a_state": self.buffer_a_state.value,
            "buffer_b_state": self.buffer_b_state.value,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0
                else 0
            ),
        }
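

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the controller API).
# It shows the intended call flow documented above: start(), register_callback(),
# submit_frame(), get_stats(), stop(). `_FakeRepository` is a hypothetical
# stand-in for the TensorRT model repository; it only mimics the
# get_metadata()/infer() calls this module actually makes, so the sketch can
# run on CPU without a real engine.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class _FakeRepository:
        """Minimal stand-in exposing the two methods ModelController uses."""

        def get_metadata(self, model_id):
            # Report a dynamic batch dimension so the batched path is exercised
            return SimpleNamespace(inputs={"images": {"shape": [-1, 3, 640, 640]}})

        def infer(self, model_id, inputs, synchronize=True):
            # Return a dummy output with the same batch dimension as the input
            n = inputs["images"].shape[0]
            return {"output": torch.zeros((n, 0, 6))}

    async def _demo():
        controller = ModelController(_FakeRepository(), model_id="demo", batch_size=4)
        # Sync callback; it will be dispatched via the default executor
        controller.register_callback(
            "cam0", lambda result: print("cam0 result keys:", sorted(result))
        )
        await controller.start()

        # CPU tensors are fine for this smoke test
        for _ in range(8):
            await controller.submit_frame("cam0", torch.zeros((3, 640, 640)))

        await asyncio.sleep(0.2)  # let the force-switch timeout flush the buffers
        print("stats:", controller.get_stats())
        await controller.stop()

    asyncio.run(_demo())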