""" Base Model Controller - Abstract base class for batched inference controllers. Provides ping-pong buffer architecture with force-switch timeout mechanism. Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.). """ import logging import threading import time from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum from typing import Any, Callable, Dict, List, Optional import torch logger = logging.getLogger(__name__) @dataclass class BatchFrame: """Represents a frame in the batch buffer""" stream_id: str frame: torch.Tensor # GPU tensor (3, H, W) timestamp: float metadata: Dict = field(default_factory=dict) class BufferState(Enum): """State of a ping-pong buffer""" IDLE = "idle" FILLING = "filling" PROCESSING = "processing" class BaseModelController(ABC): """ Abstract base class for batched inference with ping-pong buffers. This controller accumulates frames from multiple streams into batches, processes them through an inference backend, and routes results back to stream-specific callbacks. Features: - Ping-pong circular buffers (BufferA/BufferB) - Force-switch timeout to prevent batch starvation - Event-driven processing with callbacks - Thread-safe frame submission Subclasses must implement: - _run_batch_inference(): Backend-specific inference logic """ def __init__( self, model_id: str, batch_size: int = 16, force_timeout: float = 0.05, preprocess_fn: Optional[Callable] = None, postprocess_fn: Optional[Callable] = None, ): self.model_id = model_id self.batch_size = batch_size self.force_timeout = force_timeout self.preprocess_fn = preprocess_fn self.postprocess_fn = postprocess_fn # Ping-pong buffers self.buffer_a: List[BatchFrame] = [] self.buffer_b: List[BatchFrame] = [] # Buffer states self.active_buffer = "A" self.buffer_a_state = BufferState.IDLE self.buffer_b_state = BufferState.IDLE # Threading coordination self.buffer_lock = threading.RLock() self.last_submit_time = time.time() # Threads self.timeout_thread: Optional[threading.Thread] = None self.processor_threads: Dict[str, threading.Thread] = {} self.running = False self.stop_event = threading.Event() # Result callbacks (stream_id -> callback) self.result_callbacks: Dict[str, Callable] = {} # Statistics self.total_frames_processed = 0 self.total_batches_processed = 0 def start(self): """Start the controller background threads""" if self.running: logger.warning("ModelController already running") return self.running = True self.stop_event.clear() # Start timeout monitor thread self.timeout_thread = threading.Thread( target=self._timeout_monitor, daemon=True ) self.timeout_thread.start() # Start processor threads for each buffer self.processor_threads["A"] = threading.Thread( target=self._batch_processor, args=("A",), daemon=True ) self.processor_threads["B"] = threading.Thread( target=self._batch_processor, args=("B",), daemon=True ) self.processor_threads["A"].start() self.processor_threads["B"].start() logger.info(f"{self.__class__.__name__} started") def stop(self): """Stop the controller and cleanup""" if not self.running: return logger.info(f"Stopping {self.__class__.__name__}...") self.running = False self.stop_event.set() # Wait for threads to finish if self.timeout_thread and self.timeout_thread.is_alive(): self.timeout_thread.join(timeout=2.0) for thread in self.processor_threads.values(): if thread and thread.is_alive(): thread.join(timeout=2.0) # Process any remaining frames self._process_remaining_buffers() logger.info(f"{self.__class__.__name__} 
stopped") def register_callback(self, stream_id: str, callback: Callable): """Register a callback for inference results from a stream""" self.result_callbacks[stream_id] = callback logger.debug(f"Registered callback for stream: {stream_id}") def unregister_callback(self, stream_id: str): """Unregister a stream callback""" self.result_callbacks.pop(stream_id, None) logger.debug(f"Unregistered callback for stream: {stream_id}") def submit_frame( self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None ): """Submit a frame for batched inference""" with self.buffer_lock: batch_frame = BatchFrame( stream_id=stream_id, frame=frame, timestamp=time.time(), metadata=metadata or {}, ) # Add to active buffer if self.active_buffer == "A": self.buffer_a.append(batch_frame) self.buffer_a_state = BufferState.FILLING buffer_size = len(self.buffer_a) else: self.buffer_b.append(batch_frame) self.buffer_b_state = BufferState.FILLING buffer_size = len(self.buffer_b) self.last_submit_time = time.time() # Check if we should immediately swap (batch full) if buffer_size >= self.batch_size: self._try_swap_buffers() def _timeout_monitor(self): """Monitor force-switch timeout""" while self.running and not self.stop_event.wait(0.01): with self.buffer_lock: time_since_submit = time.time() - self.last_submit_time if time_since_submit >= self.force_timeout: active_buffer = ( self.buffer_a if self.active_buffer == "A" else self.buffer_b ) if len(active_buffer) > 0: self._try_swap_buffers() def _try_swap_buffers(self): """Attempt to swap ping-pong buffers (called with buffer_lock held)""" inactive_state = ( self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state ) if inactive_state != BufferState.PROCESSING: old_active = self.active_buffer self.active_buffer = "B" if old_active == "A" else "A" if old_active == "A": self.buffer_a_state = BufferState.PROCESSING buffer_size = len(self.buffer_a) else: self.buffer_b_state = BufferState.PROCESSING buffer_size = len(self.buffer_b) logger.debug( f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})" ) def _batch_processor(self, buffer_name: str): """Background thread that processes a specific buffer when available""" while self.running and not self.stop_event.is_set(): time.sleep(0.001) with self.buffer_lock: if buffer_name == "A": should_process = self.buffer_a_state == BufferState.PROCESSING else: should_process = self.buffer_b_state == BufferState.PROCESSING if should_process: self._process_buffer(buffer_name) def _process_buffer(self, buffer_name: str): """Process a buffer through inference""" # Extract buffer to process with self.buffer_lock: if buffer_name == "A": batch = self.buffer_a.copy() self.buffer_a.clear() else: batch = self.buffer_b.copy() self.buffer_b.clear() if len(batch) == 0: with self.buffer_lock: if buffer_name == "A": self.buffer_a_state = BufferState.IDLE else: self.buffer_b_state = BufferState.IDLE return # Process batch (outside lock to allow concurrent submissions) try: start_time = time.time() results = self._run_batch_inference(batch) inference_time = time.time() - start_time self.total_frames_processed += len(batch) self.total_batches_processed += 1 logger.debug( f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms " f"({inference_time * 1000 / len(batch):.2f}ms per frame)" ) # Emit results to callbacks for batch_frame, result in zip(batch, results): callback = self.result_callbacks.get(batch_frame.stream_id) if callback: try: callback(result) except Exception as e: 
                        logger.error(
                            f"Error in callback for {batch_frame.stream_id}: {e}",
                            exc_info=True,
                        )

        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)
        finally:
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE

    @abstractmethod
    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """
        Run inference on a batch of frames (backend-specific).

        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame)
        """
        pass

    def _process_remaining_buffers(self):
        """Process any remaining frames in buffers during shutdown"""
        if len(self.buffer_a) > 0:
            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
            self._process_buffer("A")
        if len(self.buffer_b) > 0:
            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
            self._process_buffer("B")

    def get_stats(self) -> Dict[str, Any]:
        """Get current buffer statistics"""
        return {
            "active_buffer": self.active_buffer,
            "buffer_a_size": len(self.buffer_a),
            "buffer_b_size": len(self.buffer_b),
            "buffer_a_state": self.buffer_a_state.value,
            "buffer_b_state": self.buffer_b_state.value,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0
                else 0
            ),
        }
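

# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumption: not part of the original module).
# It shows the subclass contract -- implement _run_batch_inference() to turn a
# list of BatchFrame objects into one result dict per frame -- plus the basic
# wiring of callbacks and frame submission. EchoController, on_result, and the
# "camera_0" stream id are hypothetical names; EchoController returns empty
# detections instead of calling a real backend (TensorRT, Ultralytics, etc.),
# so it only demonstrates the control flow, and it uses CPU tensors even
# though real deployments would submit GPU tensors.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class EchoController(BaseModelController):
        """Minimal concrete controller used only for demonstration."""

        def _run_batch_inference(
            self, batch: List[BatchFrame]
        ) -> List[Dict[str, Any]]:
            # A real backend would stack the frame tensors, run the model,
            # and post-process the outputs into per-frame detections here.
            return [
                {"stream_id": bf.stream_id, "timestamp": bf.timestamp, "detections": []}
                for bf in batch
            ]

    def on_result(result: Dict[str, Any]) -> None:
        logger.info(
            "Result for %s: %d detections",
            result["stream_id"],
            len(result["detections"]),
        )

    logging.basicConfig(level=logging.DEBUG)
    controller = EchoController(model_id="demo", batch_size=4, force_timeout=0.05)
    controller.register_callback("camera_0", on_result)
    controller.start()
    try:
        for _ in range(8):
            controller.submit_frame("camera_0", torch.zeros(3, 640, 640))
            time.sleep(0.01)
        time.sleep(0.2)  # let the force-switch timeout flush the last partial batch
    finally:
        controller.stop()
    logger.info("Stats: %s", controller.get_stats())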