batch processing/event driven

Siwat Sirichai 2025-11-09 18:43:56 +07:00
parent e71316ef3d
commit dd57b5a246
7 changed files with 2673 additions and 2 deletions

@@ -0,0 +1,499 @@
"""
ModelController - Async batching layer with ping-pong buffers for inference.
This module provides batched inference coordination using ping-pong buffers
with a force-switch timeout mechanism.
"""
import asyncio
import torch
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import time
import logging
logger = logging.getLogger(__name__)
@dataclass
class BatchFrame:
"""Represents a frame in the batch buffer"""
stream_id: str
frame: torch.Tensor # GPU tensor (3, H, W)
timestamp: float
metadata: Dict = field(default_factory=dict)
class BufferState(Enum):
"""State of a ping-pong buffer"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
class ModelController:
"""
Manages batched inference with ping-pong buffers and force-switch timeout.
This controller accumulates frames from multiple streams into batches,
processes them through a model repository, and routes results back to
stream-specific callbacks.
Features:
- Ping-pong buffers (BufferA/BufferB) that alternate between filling and processing
- Force-switch timeout so partially filled batches are not left waiting indefinitely
- Async event-driven processing
- Coroutine-safe frame submission (submit_frame must be awaited on the controller's event loop)
Args:
model_repository: TensorRT model repository for inference
model_id: Model identifier in the repository
batch_size: Maximum frames per batch (default: 16)
force_timeout: Maximum wait time in seconds before forcing a buffer switch (default: 0.05)
preprocess_fn: Optional preprocessing function for frames
postprocess_fn: Optional postprocessing function for model outputs
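Example (illustrative sketch, not part of the original commit; "repo",
"detector", "cam-1", "on_result" and "gpu_frame" are placeholder names):
controller = ModelController(repo, model_id="detector", batch_size=8)
await controller.start()
controller.register_callback("cam-1", on_result)  # receives one result dict per frame
await controller.submit_frame("cam-1", gpu_frame)  # CUDA tensor of shape (3, H, W)
await controller.stop()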
"""
def __init__(
self,
model_repository,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
self.model_repository = model_repository
self.model_id = model_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn
# Detect model's actual batch size from input shape
self.model_batch_size = self._detect_model_batch_size()
if self.model_batch_size == 1:
logger.warning(
f"Model '{model_id}' has fixed batch_size=1. "
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
)
else:
logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
self.buffer_b: List[BatchFrame] = []
# Buffer states
self.active_buffer = "A" # Which buffer is currently active (filling)
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE
# Async coordination
self.buffer_lock = asyncio.Lock()
self.last_submit_time = time.time()
# Tasks
self.timeout_task: Optional[asyncio.Task] = None
self.processor_task: Optional[asyncio.Task] = None
self.running = False
# Result callbacks (stream_id -> callback)
self.result_callbacks: Dict[str, Callable] = {}
# Statistics
self.total_frames_processed = 0
self.total_batches_processed = 0
def _detect_model_batch_size(self) -> int:
"""
Detect the model's batch size from its input shape.
Returns:
Maximum usable batch size: the configured batch_size for dynamic-batch models, the model's fixed batch dimension otherwise, or 1 if detection fails
"""
try:
metadata = self.model_repository.get_metadata(self.model_id)
# Get first input tensor shape
first_input = list(metadata.inputs.values())[0]
batch_dim = first_input["shape"][0]
# batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
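# e.g. an engine exported with input shape (-1, 3, 640, 640) reports -1 here (dynamic),
# while (1, 3, 640, 640) reports a fixed batch size of 1; 640x640 is only an
# illustrative resolution, nothing in this controller depends on it.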
if batch_dim == -1:
# Dynamic batch size - use user-specified batch_size
return self.batch_size
else:
# Fixed batch size
return batch_dim
except Exception as e:
logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
return 1
async def start(self):
"""Start the controller background tasks"""
if self.running:
logger.warning("ModelController already running")
return
self.running = True
self.timeout_task = asyncio.create_task(self._timeout_monitor())
self.processor_task = asyncio.create_task(self._batch_processor())
logger.info("ModelController started")
async def stop(self):
"""Stop the controller and cleanup"""
if not self.running:
return
logger.info("Stopping ModelController...")
self.running = False
# Cancel tasks
if self.timeout_task:
self.timeout_task.cancel()
try:
await self.timeout_task
except asyncio.CancelledError:
pass
if self.processor_task:
self.processor_task.cancel()
try:
await self.processor_task
except asyncio.CancelledError:
pass
# Process any remaining frames
await self._process_remaining_buffers()
logger.info("ModelController stopped")
def register_callback(self, stream_id: str, callback: Callable):
"""
Register a callback for inference results from a stream.
Args:
stream_id: Unique stream identifier
callback: Callback function to receive results (can be sync or async)
"""
self.result_callbacks[stream_id] = callback
logger.debug(f"Registered callback for stream: {stream_id}")
def unregister_callback(self, stream_id: str):
"""
Unregister a stream callback.
Args:
stream_id: Stream identifier to unregister
"""
self.result_callbacks.pop(stream_id, None)
logger.debug(f"Unregistered callback for stream: {stream_id}")
async def submit_frame(
self,
stream_id: str,
frame: torch.Tensor,
metadata: Optional[Dict] = None
):
"""
Submit a frame for batched inference.
Args:
stream_id: Unique stream identifier
frame: GPU tensor (3, H, W) or (C, H, W)
metadata: Optional metadata to attach to the frame
"""
async with self.buffer_lock:
batch_frame = BatchFrame(
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
metadata=metadata or {}
)
# Add to active buffer
if self.active_buffer == "A":
self.buffer_a.append(batch_frame)
self.buffer_a_state = BufferState.FILLING
buffer_size = len(self.buffer_a)
else:
self.buffer_b.append(batch_frame)
self.buffer_b_state = BufferState.FILLING
buffer_size = len(self.buffer_b)
self.last_submit_time = time.time()
# Check if we should immediately swap (batch full)
if buffer_size >= self.batch_size:
await self._try_swap_buffers()
async def _timeout_monitor(self):
"""Monitor force-switch timeout"""
while self.running:
await asyncio.sleep(0.01) # Check every 10ms
async with self.buffer_lock:
time_since_submit = time.time() - self.last_submit_time
# Check if timeout expired and we have frames waiting
if time_since_submit >= self.force_timeout:
active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
if len(active_buffer) > 0:
await self._try_swap_buffers()
async def _try_swap_buffers(self):
"""
Attempt to swap ping-pong buffers.
Only swaps if the inactive buffer is not currently processing.
This method should be called with buffer_lock held.
"""
# Check if inactive buffer is available
inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
if inactive_state != BufferState.PROCESSING:
# Swap active buffer
old_active = self.active_buffer
self.active_buffer = "B" if old_active == "A" else "A"
# Mark old active buffer as ready for processing
if old_active == "A":
self.buffer_a_state = BufferState.PROCESSING
buffer_size = len(self.buffer_a)
else:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)
logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
async def _batch_processor(self):
"""Background task that processes batches when available"""
while self.running:
await asyncio.sleep(0.001) # Check every 1ms
# Check if buffer A needs processing
if self.buffer_a_state == BufferState.PROCESSING:
await self._process_buffer("A")
# Check if buffer B needs processing
if self.buffer_b_state == BufferState.PROCESSING:
await self._process_buffer("B")
async def _process_buffer(self, buffer_name: str):
"""
Process a buffer through inference.
Args:
buffer_name: "A" or "B"
"""
# Extract buffer to process
async with self.buffer_lock:
if buffer_name == "A":
batch = self.buffer_a.copy()
self.buffer_a.clear()
else:
batch = self.buffer_b.copy()
self.buffer_b.clear()
if len(batch) == 0:
# Mark as idle and return
async with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
return
# Process batch (outside lock to allow concurrent submissions)
try:
start_time = time.time()
results = await self._run_batch_inference(batch)
inference_time = time.time() - start_time
# Update statistics
self.total_frames_processed += len(batch)
self.total_batches_processed += 1
logger.debug(
f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
f"({inference_time*1000/len(batch):.2f}ms per frame)"
)
# Emit results to callbacks
for batch_frame, result in zip(batch, results):
callback = self.result_callbacks.get(batch_frame.stream_id)
if callback:
# Schedule callback asynchronously
if asyncio.iscoroutinefunction(callback):
asyncio.create_task(callback(result))
else:
# Run the sync callback in the default executor so it cannot block the event loop
loop = asyncio.get_event_loop()
loop.run_in_executor(None, lambda cb=callback, r=result: cb(r))
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
# TODO: Emit error events to streams
finally:
# Mark buffer as idle
async with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run inference on a batch of frames.
Args:
batch: List of BatchFrame objects
Returns:
List of detection results (one per frame)
"""
loop = asyncio.get_event_loop()
# Check if model supports batching
if self.model_batch_size == 1:
# Process frames one at a time for batch_size=1 models
return await self._run_sequential_inference(batch, loop)
else:
# Use true batching for models that support it
return await self._run_batched_inference(batch, loop)
async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
"""Run inference sequentially for batch_size=1 models"""
results = []
for batch_frame in batch:
# Preprocess frame
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
else:
# Ensure we have batch dimension
processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
# Run inference for this frame
outputs = await loop.run_in_executor(
None,
lambda p=processed: self.model_repository.infer(
self.model_id,
{"images": p},
synchronize=True
)
)
# Postprocess
if self.postprocess_fn:
try:
detections = self.postprocess_fn(outputs)
except Exception as e:
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
# Return empty detections on error
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
else:
detections = outputs
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
}
results.append(result)
return results
async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
"""Run true batched inference for models that support it"""
# Preprocess frames (on GPU)
preprocessed = []
for batch_frame in batch:
if self.preprocess_fn:
processed = self.preprocess_fn(batch_frame.frame)
# Preprocess may return (1, C, H, W), squeeze to (C, H, W)
if processed.dim() == 4 and processed.shape[0] == 1:
processed = processed.squeeze(0)
else:
processed = batch_frame.frame
preprocessed.append(processed)
# Stack into batch tensor: (N, C, H, W)
batch_tensor = torch.stack(preprocessed, dim=0)
# Limit batch size to model's max batch size
if batch_tensor.shape[0] > self.model_batch_size:
logger.warning(
f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}, "
f"will split into sub-batches"
)
# TODO: Handle splitting into sub-batches
batch_tensor = batch_tensor[:self.model_batch_size]
batch = batch[:self.model_batch_size]
# Run inference
outputs = await loop.run_in_executor(
None,
lambda: self.model_repository.infer(
self.model_id,
{"images": batch_tensor},
synchronize=True
)
)
# Postprocess results (split batch back to individual results)
results = []
for i, batch_frame in enumerate(batch):
# Extract single frame output from batch
frame_output = {}
for k, v in outputs.items():
# v has shape (N, ...), extract index i and keep batch dimension
frame_output[k] = v[i:i+1] # Shape: (1, ...)
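# e.g. a hypothetical output of shape (8, 84, 8400) yields a (1, 84, 8400) slice
# for frame i, so postprocess_fn sees the same tensor layout as in the
# single-frame (sequential) path.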
if self.postprocess_fn:
try:
detections = self.postprocess_fn(frame_output)
except Exception as e:
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
# Return empty detections on error
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
else:
detections = frame_output
result = {
"stream_id": batch_frame.stream_id,
"timestamp": batch_frame.timestamp,
"detections": detections,
"metadata": batch_frame.metadata,
}
results.append(result)
return results
async def _process_remaining_buffers(self):
"""Process any remaining frames in buffers during shutdown"""
if len(self.buffer_a) > 0:
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
await self._process_buffer("A")
if len(self.buffer_b) > 0:
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
await self._process_buffer("B")
def get_stats(self) -> Dict[str, Any]:
"""Get current buffer statistics"""
return {
"active_buffer": self.active_buffer,
"buffer_a_size": len(self.buffer_a),
"buffer_b_size": len(self.buffer_b),
"buffer_a_state": self.buffer_a_state.value,
"buffer_b_state": self.buffer_b_state.value,
"registered_streams": len(self.result_callbacks),
"total_frames_processed": self.total_frames_processed,
"total_batches_processed": self.total_batches_processed,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
if self.total_batches_processed > 0 else 0
),
}