""" Base Model Controller - Simple circular buffer with continuous batch processing. Replaces the complex ping-pong buffer architecture with a simple queue: - Frames arrive and go into a single deque (circular buffer) - When batch_size frames are ready, process them - Continue consuming batches until queue is empty - Drop oldest frames if queue is full """ import logging import threading import time from abc import ABC, abstractmethod from collections import deque from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional import torch logger = logging.getLogger(__name__) @dataclass class BatchFrame: """Represents a frame in the batch buffer""" stream_id: str frame: torch.Tensor # GPU tensor (3, H, W) timestamp: float metadata: Dict = field(default_factory=dict) class BaseModelController(ABC): """ Simple batched inference with circular buffer. Architecture: - Single deque (circular buffer) for incoming frames - Batch processor thread continuously consumes batches - Frames come in fast, batches go out as fast as inference allows - Automatic frame dropping when queue is full """ def __init__( self, model_id: str, batch_size: int = 16, max_queue_size: int = 100, preprocess_fn: Optional[Callable] = None, postprocess_fn: Optional[Callable] = None, ): self.model_id = model_id self.batch_size = batch_size self.max_queue_size = max_queue_size self.preprocess_fn = preprocess_fn self.postprocess_fn = postprocess_fn # Single circular buffer self.frame_queue = deque(maxlen=max_queue_size) self.queue_lock = threading.Lock() # Processing thread self.processor_thread: Optional[threading.Thread] = None self.running = False self.stop_event = threading.Event() # Result callbacks (stream_id -> callback) self.result_callbacks: Dict[str, Callable] = {} # Statistics self.total_frames_processed = 0 self.total_batches_processed = 0 self.total_frames_dropped = 0 def start(self): """Start the controller background thread""" if self.running: logger.warning(f"{self.__class__.__name__} already running") return self.running = True self.stop_event.clear() # Start single processor thread self.processor_thread = threading.Thread( target=self._batch_processor, daemon=True ) self.processor_thread.start() logger.info(f"{self.__class__.__name__} started (batch_size={self.batch_size})") def stop(self): """Stop the controller and cleanup""" if not self.running: return logger.info(f"Stopping {self.__class__.__name__}...") self.running = False self.stop_event.set() # Wait for thread to finish if self.processor_thread and self.processor_thread.is_alive(): self.processor_thread.join(timeout=2.0) # Process remaining frames self._process_remaining_frames() logger.info(f"{self.__class__.__name__} stopped") def register_callback(self, stream_id: str, callback: Callable): """Register a callback for inference results from a stream""" self.result_callbacks[stream_id] = callback logger.debug(f"Registered callback for stream: {stream_id}") def unregister_callback(self, stream_id: str): """Unregister a stream callback""" self.result_callbacks.pop(stream_id, None) logger.debug(f"Unregistered callback for stream: {stream_id}") def submit_frame( self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None ): """Submit a frame for batched inference""" with self.queue_lock: # If queue is full, oldest frame is automatically dropped (deque with maxlen) if len(self.frame_queue) >= self.max_queue_size: self.total_frames_dropped += 1 batch_frame = BatchFrame( stream_id=stream_id, frame=frame, 
timestamp=time.time(), metadata=metadata or {}, ) self.frame_queue.append(batch_frame) def _batch_processor(self): """Background thread that continuously processes batches""" logger.info(f"{self.__class__.__name__} batch processor started") while self.running and not self.stop_event.is_set(): # Check if we have enough frames for a batch with self.queue_lock: queue_size = len(self.frame_queue) if queue_size > 0 and queue_size % 10 == 0: logger.info(f"Queue size: {queue_size}/{self.batch_size}") if queue_size >= self.batch_size: # Extract batch with self.queue_lock: batch = [] for _ in range(min(self.batch_size, len(self.frame_queue))): if self.frame_queue: batch.append(self.frame_queue.popleft()) if batch: logger.info(f"Processing batch of {len(batch)} frames") self._process_batch(batch) else: # Not enough frames, sleep briefly time.sleep(0.001) # 1ms def _process_batch(self, batch: List[BatchFrame]): """Process a batch through inference""" try: start_time = time.time() results = self._run_batch_inference(batch) inference_time = time.time() - start_time self.total_frames_processed += len(batch) self.total_batches_processed += 1 logger.debug( f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms" ) # Emit results to callbacks for batch_frame, result in zip(batch, results): callback = self.result_callbacks.get(batch_frame.stream_id) if callback: try: callback(result) except Exception as e: logger.error( f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True, ) except Exception as e: logger.error(f"Error processing batch: {e}", exc_info=True) @abstractmethod def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]: """Run inference on a batch of frames (backend-specific)""" pass def _process_remaining_frames(self): """Process any remaining frames in queue during shutdown""" with self.queue_lock: remaining = len(self.frame_queue) if remaining > 0: logger.info(f"Processing remaining {remaining} frames") while True: with self.queue_lock: if not self.frame_queue: break batch = [] for _ in range(min(self.batch_size, len(self.frame_queue))): if self.frame_queue: batch.append(self.frame_queue.popleft()) if batch: self._process_batch(batch) def get_stats(self) -> Dict[str, Any]: """Get current statistics""" with self.queue_lock: queue_size = len(self.frame_queue) return { "queue_size": queue_size, "max_queue_size": self.max_queue_size, "batch_size": self.batch_size, "registered_streams": len(self.result_callbacks), "total_frames_processed": self.total_frames_processed, "total_batches_processed": self.total_batches_processed, "total_frames_dropped": self.total_frames_dropped, "avg_batch_size": ( self.total_frames_processed / self.total_batches_processed if self.total_batches_processed > 0 else 0 ), } # Keep old BufferState enum for backwards compatibility from enum import Enum class BufferState(Enum): """Deprecated - kept for backwards compatibility""" IDLE = "idle" FILLING = "filling" PROCESSING = "processing"
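

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the controller API).
# `EchoController`, the stream id "cam0", and the dummy CPU tensors below are
# hypothetical placeholders showing how a backend might subclass
# BaseModelController and how a caller would register callbacks, submit
# frames, and read stats. A real backend would run an actual model on GPU
# tensors inside _run_batch_inference.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class EchoController(BaseModelController):
        """Toy backend: returns per-frame metadata instead of real inference."""

        def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
            # A real backend would stack the frames and call the model here,
            # e.g. model(torch.stack([f.frame for f in batch])).
            return [
                {
                    "stream_id": f.stream_id,
                    "timestamp": f.timestamp,
                    "shape": tuple(f.frame.shape),
                }
                for f in batch
            ]

    controller = EchoController(model_id="echo", batch_size=4, max_queue_size=32)
    controller.register_callback("cam0", lambda result: print("result:", result))
    controller.start()

    # Submit a handful of dummy frames; real callers would pass decoded video frames.
    for _ in range(8):
        controller.submit_frame("cam0", torch.zeros(3, 224, 224), metadata={"source": "demo"})

    time.sleep(0.5)  # give the processor thread time to drain the queue
    controller.stop()
    print("stats:", controller.get_stats())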