""" ModelController - Event-driven batching layer with ping-pong buffers for inference. This module provides batched inference coordination using ping-pong circular buffers with force-switch timeout mechanism using threading and callbacks. """ import threading import torch from typing import Dict, List, Optional, Callable, Any from dataclasses import dataclass, field from enum import Enum import time import logging import queue logger = logging.getLogger(__name__) @dataclass class BatchFrame: """Represents a frame in the batch buffer""" stream_id: str frame: torch.Tensor # GPU tensor (3, H, W) timestamp: float metadata: Dict = field(default_factory=dict) class BufferState(Enum): """State of a ping-pong buffer""" IDLE = "idle" FILLING = "filling" PROCESSING = "processing" class ModelController: """ Manages batched inference with ping-pong buffers and force-switch timeout. This controller accumulates frames from multiple streams into batches, processes them through a model repository, and routes results back to stream-specific callbacks. Features: - Ping-pong circular buffers (BufferA/BufferB) - Force-switch timeout to prevent batch starvation - Event-driven processing with callbacks - Thread-safe frame submission Args: model_repository: TensorRT model repository for inference model_id: Model identifier in the repository batch_size: Maximum frames per batch (default: 16) force_timeout: Max wait time before forcing buffer switch in seconds (default: 0.05) preprocess_fn: Optional preprocessing function for frames postprocess_fn: Optional postprocessing function for model outputs """ def __init__( self, model_repository, model_id: str, batch_size: int = 16, force_timeout: float = 0.05, preprocess_fn: Optional[Callable] = None, postprocess_fn: Optional[Callable] = None, ): self.model_repository = model_repository self.model_id = model_id self.batch_size = batch_size self.force_timeout = force_timeout self.preprocess_fn = preprocess_fn self.postprocess_fn = postprocess_fn # Detect model's actual batch size from input shape self.model_batch_size = self._detect_model_batch_size() if self.model_batch_size == 1: logger.warning( f"Model '{model_id}' has fixed batch_size=1. " f"Will process frames sequentially. Consider rebuilding model with dynamic batching." ) else: logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}") # Ping-pong buffers self.buffer_a: List[BatchFrame] = [] self.buffer_b: List[BatchFrame] = [] # Buffer states self.active_buffer = "A" # Which buffer is currently active (filling) self.buffer_a_state = BufferState.IDLE self.buffer_b_state = BufferState.IDLE # Threading coordination self.buffer_lock = threading.RLock() self.last_submit_time = time.time() # Threads self.timeout_thread: Optional[threading.Thread] = None self.processor_threads: Dict[str, threading.Thread] = {} self.running = False self.stop_event = threading.Event() # Result callbacks (stream_id -> callback) self.result_callbacks: Dict[str, Callable] = {} # Statistics self.total_frames_processed = 0 self.total_batches_processed = 0 def _detect_model_batch_size(self) -> int: """ Detect the model's batch size from its input shape. 
    def _detect_model_batch_size(self) -> int:
        """
        Detect the model's batch size from its input shape.

        Returns:
            Maximum batch size supported by the model (1 for fixed-batch-size models)
        """
        try:
            metadata = self.model_repository.get_metadata(self.model_id)

            # Get the first input tensor's shape
            first_input = list(metadata.inputs.values())[0]
            batch_dim = first_input["shape"][0]

            # batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
            if batch_dim == -1:
                # Dynamic batch size - use the user-specified batch_size
                return self.batch_size
            else:
                # Fixed batch size
                return batch_dim
        except Exception as e:
            logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
            return 1

    def start(self):
        """Start the controller background threads"""
        if self.running:
            logger.warning("ModelController already running")
            return

        self.running = True
        self.stop_event.clear()

        # Start the timeout monitor thread
        self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
        self.timeout_thread.start()

        # Start a processor thread for each buffer
        self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
        self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
        self.processor_threads['A'].start()
        self.processor_threads['B'].start()

        logger.info("ModelController started")

    def stop(self):
        """Stop the controller and clean up"""
        if not self.running:
            return

        logger.info("Stopping ModelController...")
        self.running = False
        self.stop_event.set()

        # Wait for threads to finish
        if self.timeout_thread and self.timeout_thread.is_alive():
            self.timeout_thread.join(timeout=2.0)

        for thread in self.processor_threads.values():
            if thread and thread.is_alive():
                thread.join(timeout=2.0)

        # Process any remaining frames
        self._process_remaining_buffers()

        logger.info("ModelController stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """
        Register a callback for inference results from a stream.

        Args:
            stream_id: Unique stream identifier
            callback: Callback invoked synchronously from the processing thread,
                once per frame, with the result dict for that frame
        """
        self.result_callbacks[stream_id] = callback
        logger.debug(f"Registered callback for stream: {stream_id}")

    def unregister_callback(self, stream_id: str):
        """
        Unregister a stream callback.

        Args:
            stream_id: Stream identifier to unregister
        """
        self.result_callbacks.pop(stream_id, None)
        logger.debug(f"Unregistered callback for stream: {stream_id}")
    def submit_frame(
        self,
        stream_id: str,
        frame: torch.Tensor,
        metadata: Optional[Dict] = None
    ):
        """
        Submit a frame for batched inference.

        Args:
            stream_id: Unique stream identifier
            frame: GPU tensor (3, H, W) or (C, H, W)
            metadata: Optional metadata to attach to the frame
        """
        with self.buffer_lock:
            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {}
            )

            # Add to the active buffer
            if self.active_buffer == "A":
                self.buffer_a.append(batch_frame)
                self.buffer_a_state = BufferState.FILLING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b.append(batch_frame)
                self.buffer_b_state = BufferState.FILLING
                buffer_size = len(self.buffer_b)

            self.last_submit_time = time.time()

            # Check if we should immediately swap (batch full)
            if buffer_size >= self.batch_size:
                self._try_swap_buffers()

    def _timeout_monitor(self):
        """Monitor the force-switch timeout"""
        while self.running and not self.stop_event.wait(0.01):  # Check every 10ms
            with self.buffer_lock:
                time_since_submit = time.time() - self.last_submit_time

                # Check if the timeout expired and we have frames waiting
                if time_since_submit >= self.force_timeout:
                    active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
                    if len(active_buffer) > 0:
                        self._try_swap_buffers()

    def _try_swap_buffers(self):
        """
        Attempt to swap the ping-pong buffers.

        Only swaps if the inactive buffer is not currently processing.
        This method should be called with buffer_lock held.
        """
        # Check if the inactive buffer is available
        inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state

        if inactive_state != BufferState.PROCESSING:
            # Swap the active buffer
            old_active = self.active_buffer
            self.active_buffer = "B" if old_active == "A" else "A"

            # Mark the old active buffer as ready for processing
            if old_active == "A":
                self.buffer_a_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_b)

            logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")

    def _batch_processor(self, buffer_name: str):
        """Background thread that processes a specific buffer when available"""
        while self.running and not self.stop_event.is_set():
            time.sleep(0.001)  # Check every 1ms

            # Check if this buffer needs processing
            with self.buffer_lock:
                if buffer_name == "A":
                    should_process = self.buffer_a_state == BufferState.PROCESSING
                else:
                    should_process = self.buffer_b_state == BufferState.PROCESSING

            if should_process:
                self._process_buffer(buffer_name)
    def _process_buffer(self, buffer_name: str):
        """
        Process a buffer through inference.

        Args:
            buffer_name: "A" or "B"
        """
        # Extract the buffer to process
        with self.buffer_lock:
            if buffer_name == "A":
                batch = self.buffer_a.copy()
                self.buffer_a.clear()
            else:
                batch = self.buffer_b.copy()
                self.buffer_b.clear()

        if len(batch) == 0:
            # Mark as idle and return
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE
            return

        # Process the batch (outside the lock to allow concurrent submissions)
        try:
            start_time = time.time()
            results = self._run_batch_inference(batch)
            inference_time = time.time() - start_time

            # Update statistics
            self.total_frames_processed += len(batch)
            self.total_batches_processed += 1

            logger.debug(
                f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
                f"({inference_time*1000/len(batch):.2f}ms per frame)"
            )

            # Emit results to callbacks
            for batch_frame, result in zip(batch, results):
                callback = self.result_callbacks.get(batch_frame.stream_id)
                if callback:
                    # Call the callback directly (synchronous)
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)

        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)

        finally:
            # Mark the buffer as idle
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """
        Run inference on a batch of frames.

        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame)
        """
        # Check if the model supports batching
        if self.model_batch_size == 1:
            # Process frames one at a time for batch_size=1 models
            return self._run_sequential_inference(batch)
        else:
            # Use true batching for models that support it
            return self._run_batched_inference(batch)

    def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models"""
        results = []

        for batch_frame in batch:
            # Preprocess the frame
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                # Ensure we have a batch dimension
                processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame

            # Run inference for this frame
            outputs = self.model_repository.infer(
                self.model_id,
                {"images": processed},
                synchronize=True
            )

            # Postprocess
            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
                    # Return empty detections on error
                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
            else:
                detections = outputs

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results
    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run true batched inference for models that support it"""
        # Preprocess frames (on GPU)
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                # Preprocess may return (1, C, H, W); squeeze to (C, H, W)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into a batch tensor: (N, C, H, W)
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Split into sub-batches if the batch exceeds the model's max batch size,
        # so no frame is silently dropped
        if batch_tensor.shape[0] > self.model_batch_size:
            logger.warning(
                f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}, "
                f"splitting into sub-batches"
            )
            results = []
            for start in range(0, len(batch), self.model_batch_size):
                results.extend(self._run_batched_inference(batch[start:start + self.model_batch_size]))
            return results

        # Run inference
        outputs = self.model_repository.infer(
            self.model_id,
            {"images": batch_tensor},
            synchronize=True
        )

        # Postprocess results (split the batch back into individual results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract the single-frame output from the batch
            frame_output = {}
            for k, v in outputs.items():
                # v has shape (N, ...); extract index i and keep the batch dimension
                frame_output[k] = v[i:i+1]  # Shape: (1, ...)

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
                    # Return empty detections on error
                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    def _process_remaining_buffers(self):
        """Process any remaining frames in the buffers during shutdown"""
        if len(self.buffer_a) > 0:
            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
            self._process_buffer("A")
        if len(self.buffer_b) > 0:
            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
            self._process_buffer("B")

    def get_stats(self) -> Dict[str, Any]:
        """Get current buffer statistics"""
        return {
            "active_buffer": self.active_buffer,
            "buffer_a_size": len(self.buffer_a),
            "buffer_b_size": len(self.buffer_b),
            "buffer_a_state": self.buffer_a_state.value,
            "buffer_b_state": self.buffer_b_state.value,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0 else 0
            ),
        }
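

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only).
#
# The real TensorRT model repository is not defined in this module, so the
# stub below only mimics the two calls ModelController actually makes on it:
# get_metadata(model_id) and infer(model_id, inputs, synchronize=True). The
# input/output tensor shapes, the "output0" key, and the (N, 6) detection
# layout in _example_postprocess are assumptions for this sketch, not the real
# model's contract. Frames are created on CPU here for portability; in
# production they would be GPU tensors.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class _StubRepository:
        """Minimal stand-in for the real model repository (for this sketch only)."""

        def get_metadata(self, model_id: str):
            # Dynamic batch dimension (-1), so the controller uses its own batch_size
            return SimpleNamespace(inputs={"images": {"shape": [-1, 3, 640, 640]}})

        def infer(self, model_id: str, inputs: Dict[str, torch.Tensor], synchronize: bool = True):
            n = inputs["images"].shape[0]
            # Fake raw output with a batch dimension: one entry per input frame
            return {"output0": torch.zeros((n, 84, 100))}

    def _example_postprocess(outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Receives per-frame outputs of shape (1, ...) and returns an (N, 6)
        # detections tensor, matching the controller's empty-result fallback
        return torch.zeros((0, 6), device=outputs["output0"].device)

    def _print_result(result: Dict[str, Any]):
        print(f"[{result['stream_id']}] detections: {tuple(result['detections'].shape)}")

    controller = ModelController(
        model_repository=_StubRepository(),
        model_id="demo_model",
        batch_size=4,
        force_timeout=0.05,
        postprocess_fn=_example_postprocess,
    )
    controller.start()
    controller.register_callback("camera_0", _print_result)

    # Submit a partial batch; the force-switch timeout flushes it after 50ms
    for _ in range(3):
        controller.submit_frame("camera_0", torch.zeros((3, 640, 640)))
    time.sleep(0.2)

    print(controller.get_stats())
    controller.stop()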