new buffer paradigm
parent fdaeb9981c
commit a519dea130
6 changed files with 341 additions and 327 deletions
@@ -1,16 +1,19 @@
 """
-Base Model Controller - Abstract base class for batched inference controllers.
+Base Model Controller - Simple circular buffer with continuous batch processing.

-Provides ping-pong buffer architecture with force-switch timeout mechanism.
-Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
+Replaces the complex ping-pong buffer architecture with a simple queue:
+- Frames arrive and go into a single deque (circular buffer)
+- When batch_size frames are ready, process them
+- Continue consuming batches until queue is empty
+- Drop oldest frames if queue is full
 """

 import logging
 import threading
 import time
 from abc import ABC, abstractmethod
+from collections import deque
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import Any, Callable, Dict, List, Optional

 import torch
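The drop-oldest behaviour described in the new module docstring falls out of `deque(maxlen=...)` directly. A minimal standalone sketch (toy integer "frames", not part of this commit) of the queue discipline:

```python
from collections import deque

# maxlen makes the deque a circular buffer: appending to a full deque
# silently evicts the oldest entry (the drop-oldest policy above).
queue = deque(maxlen=3)
for frame in range(5):      # pretend frames 0..4 arrive faster than we drain
    queue.append(frame)
print(list(queue))          # [2, 3, 4] -- frames 0 and 1 were dropped

# Draining one batch: pop from the left until batch_size frames are taken.
batch_size = 2
batch = [queue.popleft() for _ in range(min(batch_size, len(queue)))]
print(batch, list(queue))   # [2, 3] [4]
```

Note that `deque(maxlen=...)` drops silently, which is why `submit_frame` below checks the length before appending, purely to increment the dropped-frames counter.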
@@ -28,62 +31,37 @@ class BatchFrame:
     metadata: Dict = field(default_factory=dict)


-class BufferState(Enum):
-    """State of a ping-pong buffer"""
-
-    IDLE = "idle"
-    FILLING = "filling"
-    PROCESSING = "processing"
-
-
 class BaseModelController(ABC):
     """
-    Abstract base class for batched inference with ping-pong buffers.
+    Simple batched inference with circular buffer.

     This controller accumulates frames from multiple streams into batches,
     processes them through an inference backend, and routes results back to
     stream-specific callbacks.

-    Features:
-    - Ping-pong circular buffers (BufferA/BufferB)
-    - Force-switch timeout to prevent batch starvation
-    - Event-driven processing with callbacks
-    - Thread-safe frame submission
-
-    Subclasses must implement:
-    - _run_batch_inference(): Backend-specific inference logic
+    Architecture:
+    - Single deque (circular buffer) for incoming frames
+    - Batch processor thread continuously consumes batches
+    - Frames come in fast, batches go out as fast as inference allows
+    - Automatic frame dropping when queue is full
     """

     def __init__(
         self,
         model_id: str,
         batch_size: int = 16,
         force_timeout: float = 0.05,
         max_queue_size: int = 100,
         preprocess_fn: Optional[Callable] = None,
         postprocess_fn: Optional[Callable] = None,
     ):
         self.model_id = model_id
         self.batch_size = batch_size
         self.force_timeout = force_timeout
         self.max_queue_size = max_queue_size
         self.preprocess_fn = preprocess_fn
         self.postprocess_fn = postprocess_fn

-        # Ping-pong buffers
-        self.buffer_a: List[BatchFrame] = []
-        self.buffer_b: List[BatchFrame] = []
+        # Single circular buffer
+        self.frame_queue = deque(maxlen=max_queue_size)
+        self.queue_lock = threading.Lock()

-        # Buffer states
-        self.active_buffer = "A"
-        self.buffer_a_state = BufferState.IDLE
-        self.buffer_b_state = BufferState.IDLE
-
-        # Threading coordination
-        self.buffer_lock = threading.RLock()
-        self.last_submit_time = time.time()
-
-        # Threads
-        self.timeout_thread: Optional[threading.Thread] = None
-        self.processor_threads: Dict[str, threading.Thread] = {}
+        # Processing thread
+        self.processor_thread: Optional[threading.Thread] = None
         self.running = False
         self.stop_event = threading.Event()

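To make the new contract concrete, here is a hedged sketch of a minimal subclass; `DummyController`, its fake results, and the constructor values are illustrative only, not part of this commit:

```python
from typing import Any, Dict, List

class DummyController(BaseModelController):
    """Illustrative backend: returns an empty detection list for every frame."""

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        # A real backend (TensorRT, Ultralytics, ...) would stack the tensors in
        # batch[i].frame, run the model once, and return one result per frame.
        return [{"detections": []} for _ in batch]

controller = DummyController(
    model_id="demo-model",   # arbitrary identifier
    batch_size=8,            # frames per inference call
    max_queue_size=100,      # oldest frames are dropped beyond this
)
```

The `force_timeout` parameter and the preprocess/postprocess hooks are left at their defaults here.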
@@ -93,33 +71,24 @@ class BaseModelController(ABC):
         # Statistics
         self.total_frames_processed = 0
         self.total_batches_processed = 0
         self.total_frames_dropped = 0

     def start(self):
-        """Start the controller background threads"""
+        """Start the controller background thread"""
         if self.running:
-            logger.warning("ModelController already running")
+            logger.warning(f"{self.__class__.__name__} already running")
             return

         self.running = True
         self.stop_event.clear()

-        # Start timeout monitor thread
-        self.timeout_thread = threading.Thread(
-            target=self._timeout_monitor, daemon=True
+        # Start single processor thread
+        self.processor_thread = threading.Thread(
+            target=self._batch_processor, daemon=True
         )
-        self.timeout_thread.start()
+        self.processor_thread.start()

-        # Start processor threads for each buffer
-        self.processor_threads["A"] = threading.Thread(
-            target=self._batch_processor, args=("A",), daemon=True
-        )
-        self.processor_threads["B"] = threading.Thread(
-            target=self._batch_processor, args=("B",), daemon=True
-        )
-        self.processor_threads["A"].start()
-        self.processor_threads["B"].start()

-        logger.info(f"{self.__class__.__name__} started")
+        logger.info(f"{self.__class__.__name__} started (batch_size={self.batch_size})")

     def stop(self):
         """Stop the controller and cleanup"""
@@ -130,16 +99,12 @@ class BaseModelController(ABC):
         self.running = False
         self.stop_event.set()

-        # Wait for threads to finish
-        if self.timeout_thread and self.timeout_thread.is_alive():
-            self.timeout_thread.join(timeout=2.0)
+        # Wait for thread to finish
+        if self.processor_thread and self.processor_thread.is_alive():
+            self.processor_thread.join(timeout=2.0)

-        for thread in self.processor_threads.values():
-            if thread and thread.is_alive():
-                thread.join(timeout=2.0)
-
-        # Process any remaining frames
-        self._process_remaining_buffers()
+        # Process remaining frames
+        self._process_remaining_frames()
         logger.info(f"{self.__class__.__name__} stopped")

     def register_callback(self, stream_id: str, callback: Callable):
@@ -156,98 +121,48 @@ class BaseModelController(ABC):
         self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
     ):
         """Submit a frame for batched inference"""
-        with self.buffer_lock:
+        with self.queue_lock:
+            # If queue is full, oldest frame is automatically dropped (deque with maxlen)
+            if len(self.frame_queue) >= self.max_queue_size:
+                self.total_frames_dropped += 1

             batch_frame = BatchFrame(
                 stream_id=stream_id,
                 frame=frame,
                 timestamp=time.time(),
                 metadata=metadata or {},
             )
+            self.frame_queue.append(batch_frame)

-            # Add to active buffer
-            if self.active_buffer == "A":
-                self.buffer_a.append(batch_frame)
-                self.buffer_a_state = BufferState.FILLING
-                buffer_size = len(self.buffer_a)
-            else:
-                self.buffer_b.append(batch_frame)
-                self.buffer_b_state = BufferState.FILLING
-                buffer_size = len(self.buffer_b)
+    def _batch_processor(self):
+        """Background thread that continuously processes batches"""
+        logger.info(f"{self.__class__.__name__} batch processor started")

-            self.last_submit_time = time.time()

-            # Check if we should immediately swap (batch full)
-            if buffer_size >= self.batch_size:
-                self._try_swap_buffers()

-    def _timeout_monitor(self):
-        """Monitor force-switch timeout"""
-        while self.running and not self.stop_event.wait(0.01):
-            with self.buffer_lock:
-                time_since_submit = time.time() - self.last_submit_time

-                if time_since_submit >= self.force_timeout:
-                    active_buffer = (
-                        self.buffer_a if self.active_buffer == "A" else self.buffer_b
-                    )
-                    if len(active_buffer) > 0:
-                        self._try_swap_buffers()

-    def _try_swap_buffers(self):
-        """Attempt to swap ping-pong buffers (called with buffer_lock held)"""
-        inactive_state = (
-            self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
-        )

-        if inactive_state != BufferState.PROCESSING:
-            old_active = self.active_buffer
-            self.active_buffer = "B" if old_active == "A" else "A"

-            if old_active == "A":
-                self.buffer_a_state = BufferState.PROCESSING
-                buffer_size = len(self.buffer_a)
-            else:
-                self.buffer_b_state = BufferState.PROCESSING
-                buffer_size = len(self.buffer_b)

-            logger.debug(
-                f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
-            )

-    def _batch_processor(self, buffer_name: str):
-        """Background thread that processes a specific buffer when available"""
         while self.running and not self.stop_event.is_set():
-            time.sleep(0.001)
+            # Check if we have enough frames for a batch
+            with self.queue_lock:
+                queue_size = len(self.frame_queue)

-            with self.buffer_lock:
-                if buffer_name == "A":
-                    should_process = self.buffer_a_state == BufferState.PROCESSING
-                else:
-                    should_process = self.buffer_b_state == BufferState.PROCESSING
+                if queue_size > 0 and queue_size % 10 == 0:
+                    logger.info(f"Queue size: {queue_size}/{self.batch_size}")

-            if should_process:
-                self._process_buffer(buffer_name)
+            if queue_size >= self.batch_size:
+                # Extract batch
+                with self.queue_lock:
+                    batch = []
+                    for _ in range(min(self.batch_size, len(self.frame_queue))):
+                        if self.frame_queue:
+                            batch.append(self.frame_queue.popleft())

-    def _process_buffer(self, buffer_name: str):
-        """Process a buffer through inference"""
-        # Extract buffer to process
-        with self.buffer_lock:
-            if buffer_name == "A":
-                batch = self.buffer_a.copy()
-                self.buffer_a.clear()
+                if batch:
+                    logger.info(f"Processing batch of {len(batch)} frames")
+                    self._process_batch(batch)
             else:
-                batch = self.buffer_b.copy()
-                self.buffer_b.clear()
+                # Not enough frames, sleep briefly
+                time.sleep(0.001)  # 1ms

-        if len(batch) == 0:
-            with self.buffer_lock:
-                if buffer_name == "A":
-                    self.buffer_a_state = BufferState.IDLE
-                else:
-                    self.buffer_b_state = BufferState.IDLE
-            return

-        # Process batch (outside lock to allow concurrent submissions)
+    def _process_batch(self, batch: List[BatchFrame]):
+        """Process a batch through inference"""
         try:
             start_time = time.time()
             results = self._run_batch_inference(batch)
@@ -257,8 +172,7 @@ class BaseModelController(ABC):
             self.total_batches_processed += 1

             logger.debug(
-                f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
-                f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
+                f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms"
             )

             # Emit results to callbacks
@@ -276,49 +190,58 @@ class BaseModelController(ABC):
         except Exception as e:
             logger.error(f"Error processing batch: {e}", exc_info=True)

-        finally:
-            with self.buffer_lock:
-                if buffer_name == "A":
-                    self.buffer_a_state = BufferState.IDLE
-                else:
-                    self.buffer_b_state = BufferState.IDLE

     @abstractmethod
     def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
-        """
-        Run inference on a batch of frames (backend-specific).
-
-        Args:
-            batch: List of BatchFrame objects
-
-        Returns:
-            List of detection results (one per frame)
-        """
+        """Run inference on a batch of frames (backend-specific)"""
         pass

-    def _process_remaining_buffers(self):
-        """Process any remaining frames in buffers during shutdown"""
-        if len(self.buffer_a) > 0:
-            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
-            self._process_buffer("A")
-        if len(self.buffer_b) > 0:
-            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
-            self._process_buffer("B")
+    def _process_remaining_frames(self):
+        """Process any remaining frames in queue during shutdown"""
+        with self.queue_lock:
+            remaining = len(self.frame_queue)

+        if remaining > 0:
+            logger.info(f"Processing remaining {remaining} frames")
+            while True:
+                with self.queue_lock:
+                    if not self.frame_queue:
+                        break
+                    batch = []
+                    for _ in range(min(self.batch_size, len(self.frame_queue))):
+                        if self.frame_queue:
+                            batch.append(self.frame_queue.popleft())

+                if batch:
+                    self._process_batch(batch)

     def get_stats(self) -> Dict[str, Any]:
-        """Get current buffer statistics"""
+        """Get current statistics"""
+        with self.queue_lock:
+            queue_size = len(self.frame_queue)

         return {
-            "active_buffer": self.active_buffer,
-            "buffer_a_size": len(self.buffer_a),
-            "buffer_b_size": len(self.buffer_b),
-            "buffer_a_state": self.buffer_a_state.value,
-            "buffer_b_state": self.buffer_b_state.value,
+            "queue_size": queue_size,
+            "max_queue_size": self.max_queue_size,
             "batch_size": self.batch_size,
             "registered_streams": len(self.result_callbacks),
             "total_frames_processed": self.total_frames_processed,
             "total_batches_processed": self.total_batches_processed,
             "total_frames_dropped": self.total_frames_dropped,
+            "avg_batch_size": (
+                self.total_frames_processed / self.total_batches_processed
+                if self.total_batches_processed > 0
+                else 0
+            ),
         }


+# Keep old BufferState enum for backwards compatibility
+from enum import Enum


+class BufferState(Enum):
+    """Deprecated - kept for backwards compatibility"""

+    IDLE = "idle"
+    FILLING = "filling"
+    PROCESSING = "processing"
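Finally, a hedged end-to-end sketch of the lifecycle this commit leaves in place, reusing the hypothetical `DummyController` from the sketch above; the single-argument callback signature and the dummy tensors are assumptions, not taken from this diff:

```python
import torch

def on_result(result):
    # Assumed callback shape: invoked from the processor thread, once per frame.
    print("stream-0:", result)

controller = DummyController(model_id="demo-model", batch_size=8)
controller.start()
controller.register_callback("stream-0", on_result)

# Submit a burst of dummy frames; real callers would pass decoded video frames.
for _ in range(16):
    controller.submit_frame("stream-0", torch.zeros(3, 640, 640))

stats = controller.get_stats()
print(stats["queue_size"], "queued,", stats["total_frames_dropped"], "dropped")

controller.stop()  # drains the queue via _process_remaining_frames()
```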