# python-rtsp-worker/services/model_controller.py
"""
ModelController - Event-driven batching layer with ping-pong buffers for inference.
This module provides batched inference coordination using ping-pong circular buffers
with force-switch timeout mechanism using threading and callbacks.
"""

import logging
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

import torch

logger = logging.getLogger(__name__)


@dataclass
class BatchFrame:
    """Represents a frame in the batch buffer"""

    stream_id: str
    frame: torch.Tensor  # GPU tensor (3, H, W)
    timestamp: float
    metadata: Dict = field(default_factory=dict)


class BufferState(Enum):
    """State of a ping-pong buffer"""

    IDLE = "idle"
    FILLING = "filling"
    PROCESSING = "processing"


class ModelController:
    """
    Manages batched inference with ping-pong buffers and a force-switch timeout.

    This controller accumulates frames from multiple streams into batches,
    processes them through a model repository, and routes results back to
    stream-specific callbacks.

    Features:
    - Ping-pong circular buffers (BufferA/BufferB)
    - Force-switch timeout to prevent batch starvation
    - Event-driven processing with callbacks
    - Thread-safe frame submission

    Args:
        model_repository: TensorRT model repository for inference
        model_id: Model identifier in the repository
        batch_size: Maximum frames per batch (default: 16)
        force_timeout: Maximum wait time in seconds before forcing a buffer switch (default: 0.05)
        preprocess_fn: Optional preprocessing function applied to each frame
        postprocess_fn: Optional postprocessing function applied to model outputs
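
    Example (illustrative sketch; names like repo, handle and frame_gpu are
    placeholders, not fixed by this module):

        controller = ModelController(model_repository=repo, model_id="detector")
        controller.start()
        controller.register_callback("cam-1", lambda result: handle(result["detections"]))
        controller.submit_frame("cam-1", frame_gpu)  # frame_gpu: (3, H, W) GPU tensor
        controller.stop()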
"""
def __init__(
self,
model_repository,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
self.model_repository = model_repository
self.model_id = model_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn
# Detect model's actual batch size from input shape
self.model_batch_size = self._detect_model_batch_size()
if self.model_batch_size == 1:
logger.warning(
f"Model '{model_id}' has fixed batch_size=1. "
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
)
else:
logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
self.buffer_b: List[BatchFrame] = []
# Buffer states
self.active_buffer = "A" # Which buffer is currently active (filling)
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE
# Threading coordination
self.buffer_lock = threading.RLock()
self.last_submit_time = time.time()
# Threads
self.timeout_thread: Optional[threading.Thread] = None
self.processor_threads: Dict[str, threading.Thread] = {}
self.running = False
self.stop_event = threading.Event()
# Result callbacks (stream_id -> callback)
self.result_callbacks: Dict[str, Callable] = {}
# Statistics
self.total_frames_processed = 0
self.total_batches_processed = 0

    def _detect_model_batch_size(self) -> int:
        """
        Detect the model's batch size from its input shape.

        Returns:
            Maximum batch size supported by the model (1 for fixed batch-size models)
        """
        try:
            metadata = self.model_repository.get_metadata(self.model_id)
            # Get first input tensor shape (ModelMetadata has input_shapes, not inputs)
            first_input_name = metadata.input_names[0]
            input_shape = metadata.input_shapes[first_input_name]
            batch_dim = input_shape[0]
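            # Illustrative examples (shape values are assumptions, not read from this repo):
            #   input_shape = (-1, 3, 640, 640) -> dynamic batching, use the configured batch_size
            #   input_shape = (1, 3, 640, 640)  -> fixed batch of 1, frames are processed one by one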

            # batch_dim can be -1 (dynamic), 1 (fixed), or N (fixed batch size)
            if batch_dim == -1:
                # Dynamic batch size - use user-specified batch_size
                return self.batch_size
            else:
                # Fixed batch size
                return batch_dim
        except Exception as e:
            logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
            return 1

    def start(self):
        """Start the controller background threads"""
        if self.running:
            logger.warning("ModelController already running")
            return

        self.running = True
        self.stop_event.clear()

        # Start timeout monitor thread
        self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
        self.timeout_thread.start()

        # Start a processor thread for each buffer
        self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
        self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
        self.processor_threads['A'].start()
        self.processor_threads['B'].start()

        logger.info("ModelController started")

    def stop(self):
        """Stop the controller and clean up"""
        if not self.running:
            return

        logger.info("Stopping ModelController...")
        self.running = False
        self.stop_event.set()

        # Wait for threads to finish
        if self.timeout_thread and self.timeout_thread.is_alive():
            self.timeout_thread.join(timeout=2.0)
        for thread in self.processor_threads.values():
            if thread and thread.is_alive():
                thread.join(timeout=2.0)

        # Process any remaining frames
        self._process_remaining_buffers()

        logger.info("ModelController stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """
        Register a callback for inference results from a stream.

        Args:
            stream_id: Unique stream identifier
            callback: Function that receives the result dict; it is invoked
                synchronously from the processor thread, so it should return quickly
        """
        self.result_callbacks[stream_id] = callback
        logger.debug(f"Registered callback for stream: {stream_id}")

    def unregister_callback(self, stream_id: str):
        """
        Unregister a stream callback.

        Args:
            stream_id: Stream identifier to unregister
        """
        self.result_callbacks.pop(stream_id, None)
        logger.debug(f"Unregistered callback for stream: {stream_id}")

    def submit_frame(
        self,
        stream_id: str,
        frame: torch.Tensor,
        metadata: Optional[Dict] = None
    ):
        """
        Submit a frame for batched inference.

        Args:
            stream_id: Unique stream identifier
            frame: GPU tensor (3, H, W) or (C, H, W)
            metadata: Optional metadata to attach to the frame
        """
        with self.buffer_lock:
            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {}
            )

            # Add to the active buffer
            if self.active_buffer == "A":
                self.buffer_a.append(batch_frame)
                self.buffer_a_state = BufferState.FILLING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b.append(batch_frame)
                self.buffer_b_state = BufferState.FILLING
                buffer_size = len(self.buffer_b)

            self.last_submit_time = time.time()

            # Swap immediately if the batch is full
            if buffer_size >= self.batch_size:
                self._try_swap_buffers()

    def _timeout_monitor(self):
        """Monitor the force-switch timeout"""
        while self.running and not self.stop_event.wait(0.01):  # Check every 10ms
            with self.buffer_lock:
                time_since_submit = time.time() - self.last_submit_time

                # Check if the timeout expired and we have frames waiting
                if time_since_submit >= self.force_timeout:
                    active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
                    if len(active_buffer) > 0:
                        self._try_swap_buffers()

    def _try_swap_buffers(self):
        """
        Attempt to swap the ping-pong buffers.

        Only swaps if the inactive buffer is not currently processing.
        This method should be called with buffer_lock held.
        """
        # Check if the inactive buffer is available
        inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state

        if inactive_state != BufferState.PROCESSING:
            # Swap the active buffer
            old_active = self.active_buffer
            self.active_buffer = "B" if old_active == "A" else "A"

            # Mark the old active buffer as ready for processing
            if old_active == "A":
                self.buffer_a_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_b)

            logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")

    def _batch_processor(self, buffer_name: str):
        """Background thread that processes a specific buffer when it is marked for processing"""
        while self.running and not self.stop_event.is_set():
            time.sleep(0.001)  # Check every 1ms

            # Check if this buffer needs processing
            with self.buffer_lock:
                if buffer_name == "A":
                    should_process = self.buffer_a_state == BufferState.PROCESSING
                else:
                    should_process = self.buffer_b_state == BufferState.PROCESSING

            if should_process:
                self._process_buffer(buffer_name)

    def _process_buffer(self, buffer_name: str):
        """
        Process a buffer through inference.

        Args:
            buffer_name: "A" or "B"
        """
        # Extract the frames to process
        with self.buffer_lock:
            if buffer_name == "A":
                batch = self.buffer_a.copy()
                self.buffer_a.clear()
            else:
                batch = self.buffer_b.copy()
                self.buffer_b.clear()

        if len(batch) == 0:
            # Mark as idle and return
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE
            return

        # Process batch (outside the lock to allow concurrent submissions)
        try:
            start_time = time.time()
            results = self._run_batch_inference(batch)
            inference_time = time.time() - start_time

            # Update statistics
            self.total_frames_processed += len(batch)
            self.total_batches_processed += 1

            logger.debug(
                f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
                f"({inference_time*1000/len(batch):.2f}ms per frame)"
            )

            # Emit results to callbacks
            for batch_frame, result in zip(batch, results):
                callback = self.result_callbacks.get(batch_frame.stream_id)
                if callback:
                    # Call callback directly (synchronous)
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)
        finally:
            # Mark the buffer as idle
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """
        Run inference on a batch of frames.

        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame)
        """
        # Check if the model supports batching
        if self.model_batch_size == 1:
            # Process frames one at a time for batch_size=1 models
            return self._run_sequential_inference(batch)
        else:
            # Use true batching for models that support it
            return self._run_batched_inference(batch)

    def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run inference sequentially for batch_size=1 models"""
        results = []
        for batch_frame in batch:
            # Preprocess frame
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
            else:
                # Ensure we have a batch dimension
                processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame

            # Run inference for this frame
            outputs = self.model_repository.infer(
                self.model_id,
                {"images": processed},
                synchronize=True
            )

            # Postprocess
            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(outputs)
                except Exception as e:
                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
                    # Return empty detections on error
                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
            else:
                detections = outputs

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)
        return results

    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run true batched inference for models that support it"""
        # Preprocess frames (on GPU)
        preprocessed = []
        for batch_frame in batch:
            if self.preprocess_fn:
                processed = self.preprocess_fn(batch_frame.frame)
                # Preprocess may return (1, C, H, W); squeeze to (C, H, W)
                if processed.dim() == 4 and processed.shape[0] == 1:
                    processed = processed.squeeze(0)
            else:
                processed = batch_frame.frame
            preprocessed.append(processed)

        # Stack into a batch tensor: (N, C, H, W)
        batch_tensor = torch.stack(preprocessed, dim=0)

        # Limit batch size to the model's max batch size
        if batch_tensor.shape[0] > self.model_batch_size:
            logger.warning(
                f"Batch size {batch_tensor.shape[0]} exceeds model max {self.model_batch_size}; "
                f"truncating batch (frames beyond the limit are dropped)"
            )
            # TODO: Handle splitting into sub-batches
            batch_tensor = batch_tensor[:self.model_batch_size]
            batch = batch[:self.model_batch_size]

        # Run inference
        outputs = self.model_repository.infer(
            self.model_id,
            {"images": batch_tensor},
            synchronize=True
        )

        # Postprocess results (split the batch back into per-frame results)
        results = []
        for i, batch_frame in enumerate(batch):
            # Extract this frame's outputs from the batch
            frame_output = {}
            for k, v in outputs.items():
                # v has shape (N, ...); take index i and keep the batch dimension
                frame_output[k] = v[i:i+1]  # Shape: (1, ...)

            if self.postprocess_fn:
                try:
                    detections = self.postprocess_fn(frame_output)
                except Exception as e:
                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
                    # Return empty detections on error
                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
            else:
                detections = frame_output

            result = {
                "stream_id": batch_frame.stream_id,
                "timestamp": batch_frame.timestamp,
                "detections": detections,
                "metadata": batch_frame.metadata,
            }
            results.append(result)
        return results

    def _process_remaining_buffers(self):
        """Process any remaining frames in the buffers during shutdown"""
        if len(self.buffer_a) > 0:
            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
            self._process_buffer("A")
        if len(self.buffer_b) > 0:
            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
            self._process_buffer("B")

    def get_stats(self) -> Dict[str, Any]:
        """Get current buffer statistics"""
        return {
            "active_buffer": self.active_buffer,
            "buffer_a_size": len(self.buffer_a),
            "buffer_b_size": len(self.buffer_b),
            "buffer_a_state": self.buffer_a_state.value,
            "buffer_b_state": self.buffer_b_state.value,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0 else 0
            ),
        }
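

if __name__ == "__main__":
    # Illustrative smoke test: a minimal sketch of the batching flow, assuming only
    # the repository interface used above (get_metadata() exposing input_names /
    # input_shapes, and infer() returning a dict of output tensors). The dummy
    # repository, shapes, and CPU tensors below are placeholders; real usage passes
    # GPU frames and a TensorRT-backed repository.
    from types import SimpleNamespace

    logging.basicConfig(level=logging.DEBUG)

    class _DummyRepository:
        """Stand-in repository; shapes and output names are assumptions for this demo."""

        def get_metadata(self, model_id):
            # Pretend the model accepts dynamic batches of 3x64x64 images
            return SimpleNamespace(
                input_names=["images"],
                input_shapes={"images": (-1, 3, 64, 64)},
            )

        def infer(self, model_id, inputs, synchronize=True):
            batch = inputs["images"]
            # Fake per-frame "detections": zeros with shape (N, 1, 6)
            return {"output": torch.zeros((batch.shape[0], 1, 6))}

    def _print_result(result):
        shapes = {k: tuple(v.shape) for k, v in result["detections"].items()}
        print(f"result for {result['stream_id']}: {shapes}")

    controller = ModelController(
        model_repository=_DummyRepository(),
        model_id="dummy-detector",
        batch_size=4,
        force_timeout=0.05,
    )
    controller.start()
    controller.register_callback("cam-1", _print_result)

    # Fill one full batch plus a partial one; the partial batch is flushed by the timeout
    for _ in range(6):
        controller.submit_frame("cam-1", torch.rand(3, 64, 64))
        time.sleep(0.01)

    time.sleep(0.5)
    print(controller.get_stats())
    controller.stop()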