"""
|
|
ModelController - Event-driven batching layer with ping-pong buffers for inference.
|
|
|
|
This module provides batched inference coordination using ping-pong circular buffers
|
|
with force-switch timeout mechanism using threading and callbacks.
|
|
"""
|
|
|
|
import logging
|
|
import queue
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
import torch
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|


@dataclass
class BatchFrame:
    """A single frame queued for batched inference."""

    stream_id: str
    frame: torch.Tensor  # GPU tensor (3, H, W)
    timestamp: float
    metadata: Dict = field(default_factory=dict)


class BufferState(Enum):
    """State of a ping-pong buffer"""

    IDLE = "idle"
    FILLING = "filling"
    PROCESSING = "processing"


class ModelController:
    """
    Manages batched inference with ping-pong buffers and a force-switch timeout.

    This controller accumulates frames from multiple streams into batches,
    processes them through a model repository, and routes results back to
    stream-specific callbacks.

    Features:
        - Ping-pong circular buffers (BufferA/BufferB)
        - Force-switch timeout to prevent batch starvation
        - Event-driven processing with callbacks
        - Thread-safe frame submission

    Args:
        model_repository: TensorRT model repository for inference
        model_id: Model identifier in the repository
        batch_size: Maximum frames per batch (default: 16)
        force_timeout: Maximum wait time in seconds before forcing a buffer switch (default: 0.05)
        preprocess_fn: Optional preprocessing function for frames
        postprocess_fn: Optional postprocessing function for model outputs
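
    Example:
        Minimal usage sketch (illustrative, not taken from this module): the
        repository construction and the model_id below are placeholders; any
        repository object exposing the get_metadata() and infer() methods this
        controller calls will do.

            repo = TensorRTModelRepository()  # hypothetical repository setup
            controller = ModelController(repo, model_id="detector", batch_size=16)

            controller.register_callback("cam-1", lambda result: print(result["detections"]))
            controller.start()
            controller.submit_frame("cam-1", frame_gpu)  # frame_gpu: (3, H, W) CUDA tensor
            controller.stop()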
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model_repository,
|
|
model_id: str,
|
|
batch_size: int = 16,
|
|
force_timeout: float = 0.05,
|
|
preprocess_fn: Optional[Callable] = None,
|
|
postprocess_fn: Optional[Callable] = None,
|
|
):
|
|
self.model_repository = model_repository
|
|
self.model_id = model_id
|
|
self.batch_size = batch_size
|
|
self.force_timeout = force_timeout
|
|
self.preprocess_fn = preprocess_fn
|
|
self.postprocess_fn = postprocess_fn
|
|
|
|
# Detect model's actual batch size from input shape
|
|
self.model_batch_size = self._detect_model_batch_size()
|
|
if self.model_batch_size == 1:
|
|
logger.warning(
|
|
f"Model '{model_id}' has fixed batch_size=1. "
|
|
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
|
|
)
|
|
else:
|
|
logger.info(
|
|
f"Model '{model_id}' supports batch_size={self.model_batch_size}"
|
|
)
|
|
|
|
# Ping-pong buffers
|
|
self.buffer_a: List[BatchFrame] = []
|
|
self.buffer_b: List[BatchFrame] = []
|
|
|
|
# Buffer states
|
|
self.active_buffer = "A" # Which buffer is currently active (filling)
|
|
self.buffer_a_state = BufferState.IDLE
|
|
self.buffer_b_state = BufferState.IDLE
|
|
|
|
# Threading coordination
|
|
self.buffer_lock = threading.RLock()
|
|
self.last_submit_time = time.time()
|
|
|
|
# Threads
|
|
self.timeout_thread: Optional[threading.Thread] = None
|
|
self.processor_threads: Dict[str, threading.Thread] = {}
|
|
self.running = False
|
|
self.stop_event = threading.Event()
|
|
|
|
# Result callbacks (stream_id -> callback)
|
|
self.result_callbacks: Dict[str, Callable] = {}
|
|
|
|
# Statistics
|
|
self.total_frames_processed = 0
|
|
self.total_batches_processed = 0
|
|
|
|

    def _detect_model_batch_size(self) -> int:
        """
        Detect the model's batch size from its input shape.

        Returns:
            Maximum batch size supported by the model (1 for fixed batch size models)
        """
        try:
            metadata = self.model_repository.get_metadata(self.model_id)
            # Get first input tensor shape (ModelMetadata has input_shapes, not inputs)
            first_input_name = metadata.input_names[0]
            input_shape = metadata.input_shapes[first_input_name]
            batch_dim = input_shape[0]

            # batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
            if batch_dim == -1:
                # Dynamic batch size - use user-specified batch_size
                return self.batch_size
            else:
                # Fixed batch size
                return batch_dim
        except Exception as e:
            logger.warning(
                f"Could not detect model batch size: {e}. Assuming batch_size=1"
            )
            return 1

    def start(self):
        """Start the controller background threads"""
        if self.running:
            logger.warning("ModelController already running")
            return

        self.running = True
        self.stop_event.clear()

        # Start timeout monitor thread
        self.timeout_thread = threading.Thread(
            target=self._timeout_monitor, daemon=True
        )
        self.timeout_thread.start()

        # Start processor threads for each buffer
        self.processor_threads["A"] = threading.Thread(
            target=self._batch_processor, args=("A",), daemon=True
        )
        self.processor_threads["B"] = threading.Thread(
            target=self._batch_processor, args=("B",), daemon=True
        )
        self.processor_threads["A"].start()
        self.processor_threads["B"].start()

        logger.info("ModelController started")

    def stop(self):
        """Stop the controller and clean up"""
        if not self.running:
            return

        logger.info("Stopping ModelController...")
        self.running = False
        self.stop_event.set()

        # Wait for threads to finish
        if self.timeout_thread and self.timeout_thread.is_alive():
            self.timeout_thread.join(timeout=2.0)

        for thread in self.processor_threads.values():
            if thread and thread.is_alive():
                thread.join(timeout=2.0)

        # Process any remaining frames
        self._process_remaining_buffers()
        logger.info("ModelController stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """
        Register a callback for inference results from a stream.

        Args:
            stream_id: Unique stream identifier
            callback: Callback function to receive results (invoked synchronously from the processing thread)
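
        Example:
            Sketch of a result callback. Each result is a dict with the keys
            produced by the batch processor: "stream_id", "timestamp",
            "detections", and "metadata". The stream name "cam-1" is only an
            illustration.

                def on_result(result):
                    # Assumes postprocess_fn returns a detection tensor
                    print(result["stream_id"], result["detections"].shape)

                controller.register_callback("cam-1", on_result)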
"""
|
|
self.result_callbacks[stream_id] = callback
|
|
logger.debug(f"Registered callback for stream: {stream_id}")
|
|
|
|
def unregister_callback(self, stream_id: str):
|
|
"""
|
|
Unregister a stream callback.
|
|
|
|
Args:
|
|
stream_id: Stream identifier to unregister
|
|
"""
|
|
self.result_callbacks.pop(stream_id, None)
|
|
logger.debug(f"Unregistered callback for stream: {stream_id}")
|
|
|
|

    def submit_frame(
        self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
    ):
        """
        Submit a frame for batched inference.

        Args:
            stream_id: Unique stream identifier
            frame: GPU tensor (3, H, W) or (C, H, W)
            metadata: Optional metadata to attach to the frame
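
        Example:
            Illustrative call; frame_gpu stands in for any CUDA tensor of shape
            (3, H, W) matching what the model (or preprocess_fn) expects.

                controller.submit_frame("cam-1", frame_gpu, metadata={"frame_idx": 42})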
"""
|
|
with self.buffer_lock:
|
|
batch_frame = BatchFrame(
|
|
stream_id=stream_id,
|
|
frame=frame,
|
|
timestamp=time.time(),
|
|
metadata=metadata or {},
|
|
)
|
|
|
|
# Add to active buffer
|
|
if self.active_buffer == "A":
|
|
self.buffer_a.append(batch_frame)
|
|
self.buffer_a_state = BufferState.FILLING
|
|
buffer_size = len(self.buffer_a)
|
|
else:
|
|
self.buffer_b.append(batch_frame)
|
|
self.buffer_b_state = BufferState.FILLING
|
|
buffer_size = len(self.buffer_b)
|
|
|
|
self.last_submit_time = time.time()
|
|
|
|
# Check if we should immediately swap (batch full)
|
|
if buffer_size >= self.batch_size:
|
|
self._try_swap_buffers()
|
|
|
|

    def _timeout_monitor(self):
        """Monitor force-switch timeout"""
        while self.running and not self.stop_event.wait(0.01):  # Check every 10ms
            with self.buffer_lock:
                time_since_submit = time.time() - self.last_submit_time

                # Check if timeout expired and we have frames waiting
                if time_since_submit >= self.force_timeout:
                    active_buffer = (
                        self.buffer_a if self.active_buffer == "A" else self.buffer_b
                    )
                    if len(active_buffer) > 0:
                        self._try_swap_buffers()

    def _try_swap_buffers(self):
        """
        Attempt to swap ping-pong buffers.

        Only swaps if the inactive buffer is not currently processing.
        This method should be called with buffer_lock held.
        """
        # Check if inactive buffer is available
        inactive_state = (
            self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
        )

        if inactive_state != BufferState.PROCESSING:
            # Swap active buffer
            old_active = self.active_buffer
            self.active_buffer = "B" if old_active == "A" else "A"

            # Mark old active buffer as ready for processing
            if old_active == "A":
                self.buffer_a_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_b)

            logger.debug(
                f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
            )

    def _batch_processor(self, buffer_name: str):
        """Background thread that processes a specific buffer when available"""
        while self.running and not self.stop_event.is_set():
            time.sleep(0.001)  # Check every 1ms

            # Check if this buffer needs processing
            with self.buffer_lock:
                if buffer_name == "A":
                    should_process = self.buffer_a_state == BufferState.PROCESSING
                else:
                    should_process = self.buffer_b_state == BufferState.PROCESSING

            if should_process:
                self._process_buffer(buffer_name)

    def _process_buffer(self, buffer_name: str):
        """
        Process a buffer through inference.

        Args:
            buffer_name: "A" or "B"
        """
        # Extract buffer to process
        with self.buffer_lock:
            if buffer_name == "A":
                batch = self.buffer_a.copy()
                self.buffer_a.clear()
            else:
                batch = self.buffer_b.copy()
                self.buffer_b.clear()

        if len(batch) == 0:
            # Mark as idle and return
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE
            return

        # Process batch (outside lock to allow concurrent submissions)
        try:
            start_time = time.time()
            results = self._run_batch_inference(batch)
            inference_time = time.time() - start_time

            # Update statistics
            self.total_frames_processed += len(batch)
            self.total_batches_processed += 1

            logger.debug(
                f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
                f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
            )

            # Emit results to callbacks
            for batch_frame, result in zip(batch, results):
                callback = self.result_callbacks.get(batch_frame.stream_id)
                if callback:
                    # Call callback directly (synchronous)
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(
                            f"Error in callback for {batch_frame.stream_id}: {e}",
                            exc_info=True,
                        )

        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)

        finally:
            # Mark buffer as idle
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """
        Run inference on a batch of frames.

        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame)
        """
        # Check if model supports batching
        if self.model_batch_size == 1:
            # Process frames one at a time for batch_size=1 models
            return self._run_sequential_inference(batch)
        else:
            # Use true batching for models that support it
            return self._run_batched_inference(batch)

    def _run_sequential_inference(
        self, batch: List[BatchFrame]
    ) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models"""
        results = []

        for batch_frame in batch:
            # Preprocess frame
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                # Ensure we have batch dimension
                processed = (
                    batch_frame.frame.unsqueeze(0)
                    if batch_frame.frame.dim() == 3
                    else batch_frame.frame
                )

            # Run inference for this frame
            outputs = self.model_repository.infer(
                self.model_id, {"images": processed}, synchronize=True
            )

            # Postprocess
            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    # Return empty detections on error
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = outputs

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run true batched inference for models that support it"""
        # Preprocess frames (on GPU)
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                # Preprocess may return (1, C, H, W), squeeze to (C, H, W)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into batch tensor: (N, C, H, W)
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Limit batch size to model's max batch size
        if batch_tensor.shape[0] > self.model_batch_size:
            logger.warning(
                f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}; "
                f"truncating to the first {self.model_batch_size} frames"
            )
            # TODO: Handle splitting into sub-batches
            batch_tensor = batch_tensor[: self.model_batch_size]
            batch = batch[: self.model_batch_size]

        # Run inference
        outputs = self.model_repository.infer(
            self.model_id, {"images": batch_tensor}, synchronize=True
        )

        # Postprocess results (split batch back to individual results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract single frame output from batch and clone to ensure memory safety.
            # This prevents potential race conditions if the output tensors are still
            # in use when the next inference batch is processed.
            frame_output = {}
            for k, v in outputs.items():
                # v has shape (N, ...); extract index i and keep the batch dimension.
                # Clone to decouple from the shared batch output tensor.
                frame_output[k] = v[i : i + 1].clone()  # Shape: (1, ...)

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
                    logger.error(
                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
                    )
                    # Return empty detections on error
                    detections = torch.zeros(
                        (0, 6), device=list(outputs.values())[0].device
                    )
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    def _process_remaining_buffers(self):
        """Process any remaining frames in buffers during shutdown"""
        if len(self.buffer_a) > 0:
            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
            self._process_buffer("A")
        if len(self.buffer_b) > 0:
            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
            self._process_buffer("B")

    def get_stats(self) -> Dict[str, Any]:
        """Get current buffer statistics"""
        return {
            "active_buffer": self.active_buffer,
            "buffer_a_size": len(self.buffer_a),
            "buffer_b_size": len(self.buffer_b),
            "buffer_a_state": self.buffer_a_state.value,
            "buffer_b_state": self.buffer_b_state.value,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0
                else 0
            ),
        }