batch processing/event driven

parent e71316ef3d
commit dd57b5a246

7 changed files with 2673 additions and 2 deletions

services/model_controller.py (new file, 499 additions)

@@ -0,0 +1,499 @@
"""
|
||||
ModelController - Async batching layer with ping-pong buffers for inference.
|
||||
|
||||
This module provides batched inference coordination using ping-pong circular buffers
|
||||
with force-switch timeout mechanism.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Callable, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchFrame:
|
||||
"""Represents a frame in the batch buffer"""
|
||||
stream_id: str
|
||||
frame: torch.Tensor # GPU tensor (3, H, W)
|
||||
timestamp: float
|
||||
metadata: Dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class BufferState(Enum):
|
||||
"""State of a ping-pong buffer"""
|
||||
IDLE = "idle"
|
||||
FILLING = "filling"
|
||||
PROCESSING = "processing"
|
||||
|
||||
|
||||
class ModelController:
|
||||
"""
|
||||
Manages batched inference with ping-pong buffers and force-switch timeout.
|
||||
|
||||
This controller accumulates frames from multiple streams into batches,
|
||||
processes them through a model repository, and routes results back to
|
||||
stream-specific callbacks.
|
||||
|
||||
Features:
|
||||
- Ping-pong circular buffers (BufferA/BufferB)
|
||||
- Force-switch timeout to prevent batch starvation
|
||||
- Async event-driven processing
|
||||
- Thread-safe frame submission
|
||||
|
||||
Args:
|
||||
model_repository: TensorRT model repository for inference
|
||||
model_id: Model identifier in the repository
|
||||
batch_size: Maximum frames per batch (default: 16)
|
||||
force_timeout: Max wait time before forcing buffer switch in seconds (default: 0.05)
|
||||
preprocess_fn: Optional preprocessing function for frames
|
||||
postprocess_fn: Optional postprocessing function for model outputs
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_repository,
|
||||
model_id: str,
|
||||
batch_size: int = 16,
|
||||
force_timeout: float = 0.05,
|
||||
preprocess_fn: Optional[Callable] = None,
|
||||
postprocess_fn: Optional[Callable] = None,
|
||||
):
|
||||
self.model_repository = model_repository
|
||||
self.model_id = model_id
|
||||
self.batch_size = batch_size
|
||||
self.force_timeout = force_timeout
|
||||
self.preprocess_fn = preprocess_fn
|
||||
self.postprocess_fn = postprocess_fn
|
||||
|
||||
# Detect model's actual batch size from input shape
|
||||
self.model_batch_size = self._detect_model_batch_size()
|
||||
if self.model_batch_size == 1:
|
||||
logger.warning(
|
||||
f"Model '{model_id}' has fixed batch_size=1. "
|
||||
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
|
||||
)
|
||||
else:
|
||||
logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
|
||||
|
||||
# Ping-pong buffers
|
||||
self.buffer_a: List[BatchFrame] = []
|
||||
self.buffer_b: List[BatchFrame] = []
|
||||
|
||||
# Buffer states
|
||||
self.active_buffer = "A" # Which buffer is currently active (filling)
|
||||
self.buffer_a_state = BufferState.IDLE
|
||||
self.buffer_b_state = BufferState.IDLE
|
||||
|
||||
# Async coordination
|
||||
self.buffer_lock = asyncio.Lock()
|
||||
self.last_submit_time = time.time()
|
||||
|
||||
# Tasks
|
||||
self.timeout_task: Optional[asyncio.Task] = None
|
||||
self.processor_task: Optional[asyncio.Task] = None
|
||||
self.running = False
|
||||
|
||||
# Result callbacks (stream_id -> callback)
|
||||
self.result_callbacks: Dict[str, Callable] = {}
|
||||
|
||||
# Statistics
|
||||
self.total_frames_processed = 0
|
||||
self.total_batches_processed = 0
|
||||
|
||||
def _detect_model_batch_size(self) -> int:
|
||||
"""
|
||||
Detect the model's batch size from its input shape.
|
||||
|
||||
Returns:
|
||||
Maximum batch size supported by the model (1 for fixed batch size models)
|
||||
"""
|
||||
try:
|
||||
metadata = self.model_repository.get_metadata(self.model_id)
|
||||
# Get first input tensor shape
|
||||
first_input = list(metadata.inputs.values())[0]
|
||||
batch_dim = first_input["shape"][0]
|
||||
|
||||
# batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
|
||||
if batch_dim == -1:
|
||||
# Dynamic batch size - use user-specified batch_size
|
||||
return self.batch_size
|
||||
else:
|
||||
# Fixed batch size
|
||||
return batch_dim
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
|
||||
return 1
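    # Illustration (hypothetical metadata, not taken from a real model): if
    # metadata.inputs == {"images": {"shape": (-1, 3, 640, 640)}}, the batch
    # dimension is -1 (dynamic) and the user-supplied batch_size is used; a
    # shape of (1, 3, 640, 640) means fixed batch_size=1 (sequential path),
    # and (8, 3, 640, 640) caps every batch at 8 frames.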

    async def start(self):
        """Start the controller background tasks"""
        if self.running:
            logger.warning("ModelController already running")
            return

        self.running = True
        self.timeout_task = asyncio.create_task(self._timeout_monitor())
        self.processor_task = asyncio.create_task(self._batch_processor())
        logger.info("ModelController started")

    async def stop(self):
        """Stop the controller and cleanup"""
        if not self.running:
            return

        logger.info("Stopping ModelController...")
        self.running = False

        # Cancel tasks
        if self.timeout_task:
            self.timeout_task.cancel()
            try:
                await self.timeout_task
            except asyncio.CancelledError:
                pass

        if self.processor_task:
            self.processor_task.cancel()
            try:
                await self.processor_task
            except asyncio.CancelledError:
                pass

        # Process any remaining frames
        await self._process_remaining_buffers()
        logger.info("ModelController stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """
        Register a callback for inference results from a stream.

        Args:
            stream_id: Unique stream identifier
            callback: Callback function to receive results (can be sync or async)
        """
        self.result_callbacks[stream_id] = callback
        logger.debug(f"Registered callback for stream: {stream_id}")

    def unregister_callback(self, stream_id: str):
        """
        Unregister a stream callback.

        Args:
            stream_id: Stream identifier to unregister
        """
        self.result_callbacks.pop(stream_id, None)
        logger.debug(f"Unregistered callback for stream: {stream_id}")

    async def submit_frame(
        self,
        stream_id: str,
        frame: torch.Tensor,
        metadata: Optional[Dict] = None
    ):
        """
        Submit a frame for batched inference.

        Args:
            stream_id: Unique stream identifier
            frame: GPU tensor (3, H, W) or (C, H, W)
            metadata: Optional metadata to attach to the frame
        """
        async with self.buffer_lock:
            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {}
            )

            # Add to active buffer
            if self.active_buffer == "A":
                self.buffer_a.append(batch_frame)
                self.buffer_a_state = BufferState.FILLING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b.append(batch_frame)
                self.buffer_b_state = BufferState.FILLING
                buffer_size = len(self.buffer_b)

            self.last_submit_time = time.time()

            # Check if we should immediately swap (batch full)
            if buffer_size >= self.batch_size:
                await self._try_swap_buffers()

    async def _timeout_monitor(self):
        """Monitor force-switch timeout"""
        while self.running:
            await asyncio.sleep(0.01)  # Check every 10ms

            async with self.buffer_lock:
                time_since_submit = time.time() - self.last_submit_time

                # Check if timeout expired and we have frames waiting
                if time_since_submit >= self.force_timeout:
                    active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
                    if len(active_buffer) > 0:
                        await self._try_swap_buffers()

    async def _try_swap_buffers(self):
        """
        Attempt to swap ping-pong buffers.
        Only swaps if the inactive buffer is not currently processing.

        This method should be called with buffer_lock held.
        """
        # Check if inactive buffer is available
        inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state

        if inactive_state != BufferState.PROCESSING:
            # Swap active buffer
            old_active = self.active_buffer
            self.active_buffer = "B" if old_active == "A" else "A"

            # Mark old active buffer as ready for processing
            if old_active == "A":
                self.buffer_a_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_b)

            logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")

    async def _batch_processor(self):
        """Background task that processes batches when available"""
        while self.running:
            await asyncio.sleep(0.001)  # Check every 1ms

            # Check if buffer A needs processing
            if self.buffer_a_state == BufferState.PROCESSING:
                await self._process_buffer("A")

            # Check if buffer B needs processing
            if self.buffer_b_state == BufferState.PROCESSING:
                await self._process_buffer("B")

    async def _process_buffer(self, buffer_name: str):
        """
        Process a buffer through inference.

        Args:
            buffer_name: "A" or "B"
        """
        # Extract buffer to process
        async with self.buffer_lock:
            if buffer_name == "A":
                batch = self.buffer_a.copy()
                self.buffer_a.clear()
            else:
                batch = self.buffer_b.copy()
                self.buffer_b.clear()

        if len(batch) == 0:
            # Mark as idle and return
            async with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE
            return

        # Process batch (outside lock to allow concurrent submissions)
        try:
            start_time = time.time()
            results = await self._run_batch_inference(batch)
            inference_time = time.time() - start_time

            # Update statistics
            self.total_frames_processed += len(batch)
            self.total_batches_processed += 1

            logger.debug(
                f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
                f"({inference_time*1000/len(batch):.2f}ms per frame)"
            )

            # Emit results to callbacks
            for batch_frame, result in zip(batch, results):
                callback = self.result_callbacks.get(batch_frame.stream_id)
                if callback:
                    # Schedule callback asynchronously
                    if asyncio.iscoroutinefunction(callback):
                        asyncio.create_task(callback(result))
                    else:
                        # Schedule the sync callback on the event loop; call_soon defers it
                        # without blocking this task, but the callback should stay lightweight
                        # because it runs on the event loop thread
                        loop = asyncio.get_event_loop()
                        loop.call_soon(lambda cb=callback, r=result: cb(r))

        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)
            # TODO: Emit error events to streams

        finally:
            # Mark buffer as idle
            async with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE

    async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """
        Run inference on a batch of frames.

        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame)
        """
        loop = asyncio.get_event_loop()

        # Check if model supports batching
        if self.model_batch_size == 1:
            # Process frames one at a time for batch_size=1 models
            return await self._run_sequential_inference(batch, loop)
        else:
            # Use true batching for models that support it
            return await self._run_batched_inference(batch, loop)

    async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models"""
        results = []

        for batch_frame in batch:
            # Preprocess frame
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                # Ensure we have batch dimension
                processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame

            # Run inference for this frame
            outputs = await loop.run_in_executor(
                None,
                lambda p=processed: self.model_repository.infer(
                    self.model_id,
                    {"images": p},
                    synchronize=True
                )
            )

            # Postprocess
            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
                    # Return empty detections on error
                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
            else:
                detections = outputs

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
        """Run true batched inference for models that support it"""
        # Preprocess frames (on GPU)
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                # Preprocess may return (1, C, H, W), squeeze to (C, H, W)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into batch tensor: (N, C, H, W)
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Limit batch size to model's max batch size
        if batch_tensor.shape[0] > self.model_batch_size:
            logger.warning(
                f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}, "
                f"will split into sub-batches"
            )
            # TODO: Handle splitting into sub-batches
            batch_tensor = batch_tensor[:self.model_batch_size]
            batch = batch[:self.model_batch_size]

        # Run inference
        outputs = await loop.run_in_executor(
            None,
            lambda: self.model_repository.infer(
                self.model_id,
                {"images": batch_tensor},
                synchronize=True
            )
        )

        # Postprocess results (split batch back to individual results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract single frame output from batch
            frame_output = {}
            for k, v in outputs.items():
                # v has shape (N, ...), extract index i and keep batch dimension
                frame_output[k] = v[i:i+1]  # Shape: (1, ...)

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
                    # Return empty detections on error
                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)

        return results

    async def _process_remaining_buffers(self):
        """Process any remaining frames in buffers during shutdown"""
        if len(self.buffer_a) > 0:
            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
            await self._process_buffer("A")
        if len(self.buffer_b) > 0:
            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
            await self._process_buffer("B")

    def get_stats(self) -> Dict[str, Any]:
        """Get current buffer statistics"""
        return {
            "active_buffer": self.active_buffer,
            "buffer_a_size": len(self.buffer_a),
            "buffer_b_size": len(self.buffer_b),
            "buffer_a_state": self.buffer_a_state.value,
            "buffer_b_state": self.buffer_b_state.value,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0 else 0
            ),
        }
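
The preprocess_fn and postprocess_fn hooks are left abstract above. A minimal sketch of compatible hooks, assuming a detector that takes 640x640 float input and whose decoded output is filtered to an (N, 6) detection tensor; the function names, the "output0" key, the input size, and the [x1, y1, x2, y2, score, class] layout are illustrative assumptions, not part of the module.

import torch
import torch.nn.functional as F

def example_preprocess(frame: torch.Tensor) -> torch.Tensor:
    """(3, H, W) uint8 tensor -> (1, 3, 640, 640) float32 in [0, 1] (assumed model input)."""
    x = frame.unsqueeze(0).float() / 255.0
    return F.interpolate(x, size=(640, 640), mode="bilinear", align_corners=False)

def example_postprocess(outputs: dict) -> torch.Tensor:
    """Model output dict -> (N, 6) detections; here just a confidence filter on decoded boxes."""
    raw = list(outputs.values())[0][0]   # assumed shape (num_boxes, 6) after removing the batch dim
    return raw[raw[:, 4] > 0.25]         # keep boxes above a score threshold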
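A minimal end-to-end usage sketch of the controller lifecycle (start, register_callback, submit_frame, get_stats, stop), assuming a stub repository that exposes the get_metadata()/infer() interface the controller calls; StubRepository, the "detector" model_id, and the 640x640 CPU frame are illustrative stand-ins.

import asyncio
import torch

class StubRepository:
    """Illustrative stand-in for the model repository interface used by ModelController."""
    class _Meta:
        inputs = {"images": {"shape": (-1, 3, 640, 640)}}  # dynamic batch dimension

    def get_metadata(self, model_id):
        return self._Meta()

    def infer(self, model_id, inputs, synchronize=True):
        n = inputs["images"].shape[0]
        # Fake output: zero detections per frame in the batch
        return {"output0": torch.zeros((n, 0, 6))}

async def main():
    controller = ModelController(StubRepository(), model_id="detector", batch_size=4, force_timeout=0.05)
    await controller.start()

    async def on_result(result):
        # Without a postprocess_fn, result["detections"] is the raw output dict
        print(result["stream_id"], list(result["detections"].keys()))

    controller.register_callback("cam-1", on_result)
    # A CPU tensor is fine for this stub; real frames would already live on the GPU
    await controller.submit_frame("cam-1", torch.zeros((3, 640, 640)))
    await asyncio.sleep(0.2)  # let the force-switch timeout fire and the batch get processed
    print(controller.get_stats())
    await controller.stop()

if __name__ == "__main__":
    asyncio.run(main())

With a single submitted frame the batch never fills, so the result is emitted by the force-switch timeout path rather than the batch-full path.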