324 lines
11 KiB
Python
324 lines
11 KiB
Python
"""
|
|
Base Model Controller - Abstract base class for batched inference controllers.
|
|
|
|
Provides ping-pong buffer architecture with force-switch timeout mechanism.
|
|
Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
|
|
"""
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
import torch
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class BatchFrame:
    """One queued frame plus the bookkeeping needed to route its result.

    Instances are plain records: the controller in this file appends them to
    a buffer and later hands the whole list to the inference backend.
    """

    # Identifier of the stream that produced the frame; the controller uses
    # it to look up the result callback.
    stream_id: str

    # Image data. The original annotation comment describes it as a GPU
    # tensor shaped (3, H, W); nothing here enforces device or shape.
    frame: torch.Tensor

    # Float timestamp supplied by the creator (the controller in this file
    # passes time.time() at submission).
    timestamp: float

    # Free-form per-frame extras. default_factory gives every instance its
    # own dict rather than one shared mutable default.
    metadata: Dict = field(default_factory=dict)
|
|
|
|
|
|
class BufferState(Enum):
    """Lifecycle marker for one of the two ping-pong buffers."""

    # Nothing pending and not being processed.
    IDLE = "idle"
    # Currently receiving submitted frames.
    FILLING = "filling"
    # Handed over for inference; must not be swapped back in yet.
    PROCESSING = "processing"
|
|
|
|
|
|
class BaseModelController(ABC):
    """
    Abstract base class for batched inference with ping-pong buffers.

    This controller accumulates frames from multiple streams into batches,
    processes them through an inference backend, and routes results back to
    stream-specific callbacks.

    Features:
    - Ping-pong circular buffers (BufferA/BufferB)
    - Force-switch timeout to prevent batch starvation
    - Event-driven processing with callbacks
    - Thread-safe frame submission

    Subclasses must implement:
    - _run_batch_inference(): Backend-specific inference logic

    Threading model (as implemented here):
    - ``start()`` launches one timeout-monitor thread and one processor
      thread per buffer ("A" and "B"), all daemon threads.
    - All buffer contents, buffer states and statistics counters are
      read/written under ``buffer_lock`` (a re-entrant lock).
    - Inference and result callbacks run on a processor thread *outside*
      the lock so that ``submit_frame`` is not blocked by inference.
    """

    def __init__(
        self,
        model_id: str,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
        """
        Initialise buffers, state and counters; no threads are started here.

        Args:
            model_id: Identifier stored on the instance.
            batch_size: Number of frames in the active buffer that triggers
                an immediate swap attempt.
            force_timeout: Seconds without a new submission after which a
                non-empty active buffer is force-swapped.
            preprocess_fn: Stored on the instance; not invoked anywhere in
                this base class.
            postprocess_fn: Stored on the instance; not invoked anywhere in
                this base class.
        """
        self.model_id = model_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.preprocess_fn = preprocess_fn
        self.postprocess_fn = postprocess_fn

        # Ping-pong buffers
        self.buffer_a: List[BatchFrame] = []
        self.buffer_b: List[BatchFrame] = []

        # Buffer states
        self.active_buffer = "A"
        self.buffer_a_state = BufferState.IDLE
        self.buffer_b_state = BufferState.IDLE

        # Threading coordination. RLock because helpers that expect the lock
        # to be held are also reached from code paths that acquire it again.
        self.buffer_lock = threading.RLock()
        self.last_submit_time = time.time()

        # Threads
        self.timeout_thread: Optional[threading.Thread] = None
        self.processor_threads: Dict[str, threading.Thread] = {}
        self.running = False
        self.stop_event = threading.Event()

        # Result callbacks (stream_id -> callback)
        self.result_callbacks: Dict[str, Callable] = {}

        # Statistics (guarded by buffer_lock)
        self.total_frames_processed = 0
        self.total_batches_processed = 0

    # ------------------------------------------------------------------
    # Small accessors that map a buffer name to its list / state attribute.
    # Any name other than "A" selects buffer B, matching the if/else used
    # throughout the original implementation.
    # ------------------------------------------------------------------

    def _get_buffer(self, buffer_name: str) -> List[BatchFrame]:
        """Return the frame list backing ``buffer_name``."""
        return self.buffer_a if buffer_name == "A" else self.buffer_b

    def _get_buffer_state(self, buffer_name: str) -> BufferState:
        """Return the current state of ``buffer_name``."""
        return self.buffer_a_state if buffer_name == "A" else self.buffer_b_state

    def _set_buffer_state(self, buffer_name: str, state: BufferState) -> None:
        """Set the state of ``buffer_name`` (caller holds ``buffer_lock``)."""
        if buffer_name == "A":
            self.buffer_a_state = state
        else:
            self.buffer_b_state = state

    def start(self):
        """Start the controller background threads (no-op if already running)."""
        if self.running:
            logger.warning("ModelController already running")
            return

        self.running = True
        self.stop_event.clear()

        # Start timeout monitor thread
        self.timeout_thread = threading.Thread(
            target=self._timeout_monitor, daemon=True
        )
        self.timeout_thread.start()

        # Create both processor threads first, then start them, so the
        # mapping is fully populated before either thread runs.
        for name in ("A", "B"):
            self.processor_threads[name] = threading.Thread(
                target=self._batch_processor, args=(name,), daemon=True
            )
        for name in ("A", "B"):
            self.processor_threads[name].start()

        logger.info(f"{self.__class__.__name__} started")

    def stop(self):
        """Stop the background threads, then drain whatever is still buffered."""
        if not self.running:
            return

        logger.info(f"Stopping {self.__class__.__name__}...")
        self.running = False
        self.stop_event.set()

        # Wait (bounded) for threads to finish
        if self.timeout_thread and self.timeout_thread.is_alive():
            self.timeout_thread.join(timeout=2.0)

        for thread in self.processor_threads.values():
            if thread and thread.is_alive():
                thread.join(timeout=2.0)

        # Process any remaining frames on the calling thread
        self._process_remaining_buffers()
        logger.info(f"{self.__class__.__name__} stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """Register the callable that receives each result for ``stream_id``.

        The callback is invoked with a single positional argument: the
        per-frame result returned by ``_run_batch_inference``.
        """
        self.result_callbacks[stream_id] = callback
        logger.debug(f"Registered callback for stream: {stream_id}")

    def unregister_callback(self, stream_id: str):
        """Remove the callback for ``stream_id``; unknown ids are ignored."""
        self.result_callbacks.pop(stream_id, None)
        logger.debug(f"Unregistered callback for stream: {stream_id}")

    def submit_frame(
        self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
    ):
        """Append a frame to the active buffer and swap if the batch is full.

        Args:
            stream_id: Stream the frame belongs to (used for callback lookup).
            frame: Frame tensor, stored as-is.
            metadata: Optional extras; a falsy value becomes a fresh dict.
        """
        with self.buffer_lock:
            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {},
            )

            # Add to active buffer
            active_name = self.active_buffer
            target = self._get_buffer(active_name)
            target.append(batch_frame)
            self._set_buffer_state(active_name, BufferState.FILLING)

            self.last_submit_time = time.time()

            # Check if we should immediately swap (batch full)
            if len(target) >= self.batch_size:
                self._try_swap_buffers()

    def _timeout_monitor(self):
        """Force a swap when submissions have paused for ``force_timeout``.

        NOTE(review): the timeout is measured from the most recent
        submission, not from the oldest buffered frame, so a steady trickle
        of frames arriving faster than ``force_timeout`` postpones the
        flush until the batch fills. Confirm this is the intended policy.
        """
        # Event.wait doubles as the 10 ms poll interval and the stop signal.
        while self.running and not self.stop_event.wait(0.01):
            with self.buffer_lock:
                time_since_submit = time.time() - self.last_submit_time
                if time_since_submit < self.force_timeout:
                    continue
                if len(self._get_buffer(self.active_buffer)) > 0:
                    self._try_swap_buffers()

    def _try_swap_buffers(self):
        """Attempt to swap ping-pong buffers (called with buffer_lock held).

        The swap is skipped when the other buffer is still PROCESSING; in
        that case frames keep accumulating in the current active buffer.
        """
        old_active = self.active_buffer
        new_active = "B" if old_active == "A" else "A"

        if self._get_buffer_state(new_active) == BufferState.PROCESSING:
            return

        self.active_buffer = new_active
        # Marking the old buffer PROCESSING is what its processor thread
        # polls for.
        self._set_buffer_state(old_active, BufferState.PROCESSING)

        logger.debug(
            "Swapped buffers: %s -> %s (size: %s)",
            old_active,
            new_active,
            len(self._get_buffer(old_active)),
        )

    def _batch_processor(self, buffer_name: str):
        """Background loop that processes ``buffer_name`` once it is PROCESSING."""
        # Waiting on the stop event (rather than sleeping) keeps the 1 ms
        # poll while letting stop() wake the thread immediately.
        while self.running and not self.stop_event.wait(0.001):
            with self.buffer_lock:
                should_process = (
                    self._get_buffer_state(buffer_name) == BufferState.PROCESSING
                )

            # Run outside the lock so submissions are not blocked.
            if should_process:
                self._process_buffer(buffer_name)

    def _process_buffer(self, buffer_name: str):
        """Drain ``buffer_name``, run inference on it and dispatch results.

        Exceptions from inference are logged and swallowed; the buffer is
        always returned to IDLE afterwards.
        """
        # Extract buffer to process
        with self.buffer_lock:
            source = self._get_buffer(buffer_name)
            batch = list(source)
            source.clear()

            if not batch:
                self._set_buffer_state(buffer_name, BufferState.IDLE)
                return

        # Process batch (outside lock to allow concurrent submissions)
        try:
            start_time = time.time()
            results = self._run_batch_inference(batch)
            inference_time = time.time() - start_time

            # Two processor threads update these counters, so the
            # read-modify-write must happen under the lock.
            with self.buffer_lock:
                self.total_frames_processed += len(batch)
                self.total_batches_processed += 1

            logger.debug(
                "Processed batch of %d frames in %.2fms (%.2fms per frame)",
                len(batch),
                inference_time * 1000,
                inference_time * 1000 / len(batch),
            )

            self._emit_results(batch, results)

        except Exception as e:
            logger.error("Error processing batch: %s", e, exc_info=True)

        finally:
            with self.buffer_lock:
                self._set_buffer_state(buffer_name, BufferState.IDLE)

    def _emit_results(self, batch: List[BatchFrame], results) -> None:
        """Deliver each result to the callback registered for its stream.

        Frames whose stream has no (truthy) callback are skipped. A failing
        callback is logged and does not stop delivery to the others.
        """
        delivered = 0
        for batch_frame, result in zip(batch, results):
            delivered += 1
            callback = self.result_callbacks.get(batch_frame.stream_id)
            if not callback:
                continue
            try:
                callback(result)
            except Exception as e:
                logger.error(
                    "Error in callback for %s: %s",
                    batch_frame.stream_id,
                    e,
                    exc_info=True,
                )

        # zip() stops at the shorter input; make dropped frames visible.
        if delivered != len(batch):
            logger.warning(
                "Inference returned %d results for %d frames; "
                "%d frames received no result",
                delivered,
                len(batch),
                len(batch) - delivered,
            )

    @abstractmethod
    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """
        Run inference on a batch of frames (backend-specific).

        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame), in the same order as
            ``batch``; results are paired with frames positionally.
        """

    def _process_remaining_buffers(self):
        """Process any remaining frames in buffers during shutdown"""
        for name in ("A", "B"):
            pending = len(self._get_buffer(name))
            if pending > 0:
                logger.info(f"Processing remaining {pending} frames in buffer {name}")
                self._process_buffer(name)

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of buffer sizes, states and throughput counters.

        The snapshot is taken under ``buffer_lock`` so the values are
        mutually consistent. ``avg_batch_size`` is 0 until a batch has been
        processed.
        """
        with self.buffer_lock:
            frames_done = self.total_frames_processed
            batches_done = self.total_batches_processed
            return {
                "active_buffer": self.active_buffer,
                "buffer_a_size": len(self.buffer_a),
                "buffer_b_size": len(self.buffer_b),
                "buffer_a_state": self.buffer_a_state.value,
                "buffer_b_state": self.buffer_b_state.value,
                "registered_streams": len(self.result_callbacks),
                "total_frames_processed": frames_done,
                "total_batches_processed": batches_done,
                "avg_batch_size": (
                    frames_done / batches_done if batches_done > 0 else 0
                ),
            }
|