new buffer paradigm

Siwat Sirichai 2025-11-11 02:02:12 +07:00
parent fdaeb9981c
commit a519dea130
6 changed files with 341 additions and 327 deletions

@ -1,16 +1,19 @@
"""
Base Model Controller - Abstract base class for batched inference controllers.
Base Model Controller - Simple circular buffer with continuous batch processing.
Provides ping-pong buffer architecture with force-switch timeout mechanism.
Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
Replaces the complex ping-pong buffer architecture with a simple queue:
- Frames arrive and go into a single deque (circular buffer)
- When batch_size frames are ready, process them
- Continue consuming batches until queue is empty
- Drop oldest frames if queue is full
"""
import logging
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
import torch
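The "circular buffer" here is just collections.deque with maxlen: once the queue is full, appending evicts the oldest entry silently. A minimal standalone sketch of that behaviour (independent of the controller code):

    from collections import deque

    q = deque(maxlen=3)        # circular buffer with capacity 3
    for frame_id in range(5):
        q.append(frame_id)     # when full, the oldest entry is evicted
    print(list(q))             # [2, 3, 4] -- frames 0 and 1 were dropped

Because the eviction is silent, submit_frame checks the length before appending so dropped frames can still be counted in total_frames_dropped.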
@ -28,62 +31,37 @@ class BatchFrame:
metadata: Dict = field(default_factory=dict)
class BufferState(Enum):
"""State of a ping-pong buffer"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
class BaseModelController(ABC):
"""
Abstract base class for batched inference with ping-pong buffers.
Simple batched inference with circular buffer.
This controller accumulates frames from multiple streams into batches,
processes them through an inference backend, and routes results back to
stream-specific callbacks.
Features:
- Ping-pong circular buffers (BufferA/BufferB)
- Force-switch timeout to prevent batch starvation
- Event-driven processing with callbacks
- Thread-safe frame submission
Subclasses must implement:
- _run_batch_inference(): Backend-specific inference logic
Architecture:
- Single deque (circular buffer) for incoming frames
- Batch processor thread continuously consumes batches
- Frames come in fast, batches go out as fast as inference allows
- Automatic frame dropping when queue is full
"""
def __init__(
self,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
max_queue_size: int = 100,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
self.model_id = model_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.max_queue_size = max_queue_size
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn
# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
self.buffer_b: List[BatchFrame] = []
# Single circular buffer
self.frame_queue = deque(maxlen=max_queue_size)
self.queue_lock = threading.Lock()
# Buffer states
self.active_buffer = "A"
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE
# Threading coordination
self.buffer_lock = threading.RLock()
self.last_submit_time = time.time()
# Threads
self.timeout_thread: Optional[threading.Thread] = None
self.processor_threads: Dict[str, threading.Thread] = {}
# Processing thread
self.processor_thread: Optional[threading.Thread] = None
self.running = False
self.stop_event = threading.Event()
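A rough lifecycle sketch for a concrete controller built on this base class (the subclass name, model id, and tensor shape are illustrative, and the exact callback payload depends on the backend's postprocessing):

    import torch

    controller = MyTensorRTController(   # hypothetical subclass
        model_id="detector",
        batch_size=16,
        max_queue_size=100,
    )
    controller.start()                   # launches the single processor thread

    def on_result(result):
        print("camera_1 result:", result)   # payload format is backend-specific

    controller.register_callback("camera_1", on_result)
    controller.submit_frame("camera_1", torch.zeros(3, 640, 640), metadata={"seq": 1})

    controller.stop()                    # drains and processes any queued frames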
@ -93,33 +71,24 @@ class BaseModelController(ABC):
# Statistics
self.total_frames_processed = 0
self.total_batches_processed = 0
self.total_frames_dropped = 0
def start(self):
"""Start the controller background threads"""
"""Start the controller background thread"""
if self.running:
logger.warning("ModelController already running")
logger.warning(f"{self.__class__.__name__} already running")
return
self.running = True
self.stop_event.clear()
# Start timeout monitor thread
self.timeout_thread = threading.Thread(
target=self._timeout_monitor, daemon=True
# Start single processor thread
self.processor_thread = threading.Thread(
target=self._batch_processor, daemon=True
)
self.timeout_thread.start()
self.processor_thread.start()
# Start processor threads for each buffer
self.processor_threads["A"] = threading.Thread(
target=self._batch_processor, args=("A",), daemon=True
)
self.processor_threads["B"] = threading.Thread(
target=self._batch_processor, args=("B",), daemon=True
)
self.processor_threads["A"].start()
self.processor_threads["B"].start()
logger.info(f"{self.__class__.__name__} started")
logger.info(f"{self.__class__.__name__} started (batch_size={self.batch_size})")
def stop(self):
"""Stop the controller and cleanup"""
@ -130,16 +99,12 @@ class BaseModelController(ABC):
self.running = False
self.stop_event.set()
# Wait for threads to finish
if self.timeout_thread and self.timeout_thread.is_alive():
self.timeout_thread.join(timeout=2.0)
# Wait for thread to finish
if self.processor_thread and self.processor_thread.is_alive():
self.processor_thread.join(timeout=2.0)
for thread in self.processor_threads.values():
if thread and thread.is_alive():
thread.join(timeout=2.0)
# Process any remaining frames
self._process_remaining_buffers()
# Process remaining frames
self._process_remaining_frames()
logger.info(f"{self.__class__.__name__} stopped")
def register_callback(self, stream_id: str, callback: Callable):
@ -156,98 +121,48 @@ class BaseModelController(ABC):
self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
):
"""Submit a frame for batched inference"""
with self.buffer_lock:
with self.queue_lock:
# If queue is full, oldest frame is automatically dropped (deque with maxlen)
if len(self.frame_queue) >= self.max_queue_size:
self.total_frames_dropped += 1
batch_frame = BatchFrame(
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
metadata=metadata or {},
)
self.frame_queue.append(batch_frame)
# Add to active buffer
if self.active_buffer == "A":
self.buffer_a.append(batch_frame)
self.buffer_a_state = BufferState.FILLING
buffer_size = len(self.buffer_a)
else:
self.buffer_b.append(batch_frame)
self.buffer_b_state = BufferState.FILLING
buffer_size = len(self.buffer_b)
def _batch_processor(self):
"""Background thread that continuously processes batches"""
logger.info(f"{self.__class__.__name__} batch processor started")
self.last_submit_time = time.time()
# Check if we should immediately swap (batch full)
if buffer_size >= self.batch_size:
self._try_swap_buffers()
def _timeout_monitor(self):
"""Monitor force-switch timeout"""
while self.running and not self.stop_event.wait(0.01):
with self.buffer_lock:
time_since_submit = time.time() - self.last_submit_time
if time_since_submit >= self.force_timeout:
active_buffer = (
self.buffer_a if self.active_buffer == "A" else self.buffer_b
)
if len(active_buffer) > 0:
self._try_swap_buffers()
def _try_swap_buffers(self):
"""Attempt to swap ping-pong buffers (called with buffer_lock held)"""
inactive_state = (
self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
)
if inactive_state != BufferState.PROCESSING:
old_active = self.active_buffer
self.active_buffer = "B" if old_active == "A" else "A"
if old_active == "A":
self.buffer_a_state = BufferState.PROCESSING
buffer_size = len(self.buffer_a)
else:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)
logger.debug(
f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
)
def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""
while self.running and not self.stop_event.is_set():
time.sleep(0.001)
# Check if we have enough frames for a batch
with self.queue_lock:
queue_size = len(self.frame_queue)
with self.buffer_lock:
if buffer_name == "A":
should_process = self.buffer_a_state == BufferState.PROCESSING
else:
should_process = self.buffer_b_state == BufferState.PROCESSING
if queue_size > 0 and queue_size % 10 == 0:
logger.info(f"Queue size: {queue_size}/{self.batch_size}")
if should_process:
self._process_buffer(buffer_name)
if queue_size >= self.batch_size:
# Extract batch
with self.queue_lock:
batch = []
for _ in range(min(self.batch_size, len(self.frame_queue))):
if self.frame_queue:
batch.append(self.frame_queue.popleft())
def _process_buffer(self, buffer_name: str):
"""Process a buffer through inference"""
# Extract buffer to process
with self.buffer_lock:
if buffer_name == "A":
batch = self.buffer_a.copy()
self.buffer_a.clear()
if batch:
logger.info(f"Processing batch of {len(batch)} frames")
self._process_batch(batch)
else:
batch = self.buffer_b.copy()
self.buffer_b.clear()
# Not enough frames, sleep briefly
time.sleep(0.001) # 1ms
if len(batch) == 0:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
return
# Process batch (outside lock to allow concurrent submissions)
def _process_batch(self, batch: List[BatchFrame]):
"""Process a batch through inference"""
try:
start_time = time.time()
results = self._run_batch_inference(batch)
@ -257,8 +172,7 @@ class BaseModelController(ABC):
self.total_batches_processed += 1
logger.debug(
f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms"
)
# Emit results to callbacks
@ -276,49 +190,58 @@ class BaseModelController(ABC):
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
finally:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
@abstractmethod
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run inference on a batch of frames (backend-specific).
Args:
batch: List of BatchFrame objects
Returns:
List of detection results (one per frame)
"""
"""Run inference on a batch of frames (backend-specific)"""
pass
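A subclass only needs to supply _run_batch_inference; a minimal sketch (the class name and result fields are illustrative):

    class DummyController(BaseModelController):
        def _run_batch_inference(self, batch):
            # One result dict per BatchFrame, in the same order as the batch,
            # so each result can be routed back to its stream's callback.
            return [{"stream_id": bf.stream_id, "detections": []} for bf in batch]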
def _process_remaining_buffers(self):
"""Process any remaining frames in buffers during shutdown"""
if len(self.buffer_a) > 0:
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
self._process_buffer("A")
if len(self.buffer_b) > 0:
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
self._process_buffer("B")
def _process_remaining_frames(self):
"""Process any remaining frames in queue during shutdown"""
with self.queue_lock:
remaining = len(self.frame_queue)
if remaining > 0:
logger.info(f"Processing remaining {remaining} frames")
while True:
with self.queue_lock:
if not self.frame_queue:
break
batch = []
for _ in range(min(self.batch_size, len(self.frame_queue))):
if self.frame_queue:
batch.append(self.frame_queue.popleft())
if batch:
self._process_batch(batch)
def get_stats(self) -> Dict[str, Any]:
"""Get current buffer statistics"""
"""Get current statistics"""
with self.queue_lock:
queue_size = len(self.frame_queue)
return {
"active_buffer": self.active_buffer,
"buffer_a_size": len(self.buffer_a),
"buffer_b_size": len(self.buffer_b),
"buffer_a_state": self.buffer_a_state.value,
"buffer_b_state": self.buffer_b_state.value,
"queue_size": queue_size,
"max_queue_size": self.max_queue_size,
"batch_size": self.batch_size,
"registered_streams": len(self.result_callbacks),
"total_frames_processed": self.total_frames_processed,
"total_batches_processed": self.total_batches_processed,
"total_frames_dropped": self.total_frames_dropped,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
if self.total_batches_processed > 0
else 0
),
}
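The flattened stats dict is easy to monitor; for example, an approximate drop ratio can be derived from it (controller as in the earlier sketch):

    stats = controller.get_stats()
    seen = stats["total_frames_processed"] + stats["total_frames_dropped"]
    drop_ratio = stats["total_frames_dropped"] / seen if seen else 0.0
    print(f"queue {stats['queue_size']}/{stats['max_queue_size']}, "
          f"avg batch {stats['avg_batch_size']:.1f}, dropped {drop_ratio:.1%}")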
# Keep old BufferState enum for backwards compatibility
from enum import Enum
class BufferState(Enum):
"""Deprecated - kept for backwards compatibility"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
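Callers that still import BufferState keep working even though the new controller no longer reads it; e.g. (module path is an assumption):

    from base_model_controller import BufferState   # import path illustrative
    assert BufferState.PROCESSING.value == "processing"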