event-driven system: replace asyncio batching in ModelController with threads and callbacks

Siwat Sirichai 2025-11-10 11:51:06 +07:00
parent 0c5f56c8a6
commit 3a47920186
10 changed files with 782 additions and 253 deletions
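The diff below converts ModelController's public API from awaited coroutines to plain synchronous calls backed by threads and callbacks. A minimal usage sketch of the new interface follows; the constructor arguments (model_repository, model_id, batch_size, force_timeout) are assumed from fields the changed code references, since the constructor itself is not part of this diff, and the repository object and stream ids are placeholders:

import torch

# Placeholder wiring; model_repository and the model id come from outside this commit.
controller = ModelController(
    model_repository=model_repository,
    model_id="detector",
    batch_size=16,
    force_timeout=0.05,
)

def on_result(result):
    # Invoked synchronously from a buffer processor thread.
    print("detections for camera_1:", result)

controller.start()                                   # spawns timeout monitor + per-buffer processor threads
controller.register_callback("camera_1", on_result)  # results delivered via callback, not awaited

frame = torch.zeros(3, 640, 640, device="cuda")      # GPU tensor (C, H, W)
controller.submit_frame("camera_1", frame)           # plain call, no await

controller.stop()                                    # joins threads and flushes remaining buffers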


@@ -1,17 +1,18 @@
"""
ModelController - Async batching layer with ping-pong buffers for inference.
ModelController - Event-driven batching layer with ping-pong buffers for inference.
This module provides batched inference coordination using ping-pong circular buffers
with force-switch timeout mechanism.
with force-switch timeout mechanism using threading and callbacks.
"""
import asyncio
import threading
import torch
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import time
import logging
import queue
logger = logging.getLogger(__name__)
@@ -43,7 +44,7 @@ class ModelController:
Features:
- Ping-pong circular buffers (BufferA/BufferB)
- Force-switch timeout to prevent batch starvation
- Async event-driven processing
- Event-driven processing with callbacks
- Thread-safe frame submission
Args:
@@ -90,14 +91,15 @@ class ModelController:
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE
# Async coordination
self.buffer_lock = asyncio.Lock()
# Threading coordination
self.buffer_lock = threading.RLock()
self.last_submit_time = time.time()
# Tasks
self.timeout_task: Optional[asyncio.Task] = None
self.processor_task: Optional[asyncio.Task] = None
# Threads
self.timeout_thread: Optional[threading.Thread] = None
self.processor_threads: Dict[str, threading.Thread] = {}
self.running = False
self.stop_event = threading.Event()
# Result callbacks (stream_id -> callback)
self.result_callbacks: Dict[str, Callable] = {}
@@ -130,42 +132,46 @@ class ModelController:
logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
return 1
async def start(self):
"""Start the controller background tasks"""
def start(self):
"""Start the controller background threads"""
if self.running:
logger.warning("ModelController already running")
return
self.running = True
self.timeout_task = asyncio.create_task(self._timeout_monitor())
self.processor_task = asyncio.create_task(self._batch_processor())
self.stop_event.clear()
# Start timeout monitor thread
self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
self.timeout_thread.start()
# Start processor threads for each buffer
self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
self.processor_threads['A'].start()
self.processor_threads['B'].start()
logger.info("ModelController started")
async def stop(self):
def stop(self):
"""Stop the controller and cleanup"""
if not self.running:
return
logger.info("Stopping ModelController...")
self.running = False
self.stop_event.set()
# Cancel tasks
if self.timeout_task:
self.timeout_task.cancel()
try:
await self.timeout_task
except asyncio.CancelledError:
pass
# Wait for threads to finish
if self.timeout_thread and self.timeout_thread.is_alive():
self.timeout_thread.join(timeout=2.0)
if self.processor_task:
self.processor_task.cancel()
try:
await self.processor_task
except asyncio.CancelledError:
pass
for thread in self.processor_threads.values():
if thread and thread.is_alive():
thread.join(timeout=2.0)
# Process any remaining frames
await self._process_remaining_buffers()
self._process_remaining_buffers()
logger.info("ModelController stopped")
def register_callback(self, stream_id: str, callback: Callable):
@@ -189,7 +195,7 @@ class ModelController:
self.result_callbacks.pop(stream_id, None)
logger.debug(f"Unregistered callback for stream: {stream_id}")
async def submit_frame(
def submit_frame(
self,
stream_id: str,
frame: torch.Tensor,
@@ -203,7 +209,7 @@ class ModelController:
frame: GPU tensor (3, H, W) or (C, H, W)
metadata: Optional metadata to attach to the frame
"""
async with self.buffer_lock:
with self.buffer_lock:
batch_frame = BatchFrame(
stream_id=stream_id,
frame=frame,
@@ -225,23 +231,21 @@ class ModelController:
# Check if we should immediately swap (batch full)
if buffer_size >= self.batch_size:
await self._try_swap_buffers()
self._try_swap_buffers()
async def _timeout_monitor(self):
def _timeout_monitor(self):
"""Monitor force-switch timeout"""
while self.running:
await asyncio.sleep(0.01) # Check every 10ms
async with self.buffer_lock:
while self.running and not self.stop_event.wait(0.01): # Check every 10ms
with self.buffer_lock:
time_since_submit = time.time() - self.last_submit_time
# Check if timeout expired and we have frames waiting
if time_since_submit >= self.force_timeout:
active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
if len(active_buffer) > 0:
await self._try_swap_buffers()
self._try_swap_buffers()
async def _try_swap_buffers(self):
def _try_swap_buffers(self):
"""
Attempt to swap ping-pong buffers.
Only swaps if the inactive buffer is not currently processing.
@@ -266,20 +270,22 @@ class ModelController:
logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
async def _batch_processor(self):
"""Background task that processes batches when available"""
while self.running:
await asyncio.sleep(0.001) # Check every 1ms
def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""
while self.running and not self.stop_event.is_set():
time.sleep(0.001) # Check every 1ms
# Check if buffer A needs processing
if self.buffer_a_state == BufferState.PROCESSING:
await self._process_buffer("A")
# Check if this buffer needs processing
with self.buffer_lock:
if buffer_name == "A":
should_process = self.buffer_a_state == BufferState.PROCESSING
else:
should_process = self.buffer_b_state == BufferState.PROCESSING
# Check if buffer B needs processing
if self.buffer_b_state == BufferState.PROCESSING:
await self._process_buffer("B")
if should_process:
self._process_buffer(buffer_name)
async def _process_buffer(self, buffer_name: str):
def _process_buffer(self, buffer_name: str):
"""
Process a buffer through inference.
@@ -287,7 +293,7 @@ class ModelController:
buffer_name: "A" or "B"
"""
# Extract buffer to process
async with self.buffer_lock:
with self.buffer_lock:
if buffer_name == "A":
batch = self.buffer_a.copy()
self.buffer_a.clear()
@@ -297,7 +303,7 @@ class ModelController:
if len(batch) == 0:
# Mark as idle and return
async with self.buffer_lock:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
@@ -307,7 +313,7 @@ class ModelController:
# Process batch (outside lock to allow concurrent submissions)
try:
start_time = time.time()
results = await self._run_batch_inference(batch)
results = self._run_batch_inference(batch)
inference_time = time.time() - start_time
# Update statistics
@@ -323,27 +329,24 @@ class ModelController:
for batch_frame, result in zip(batch, results):
callback = self.result_callbacks.get(batch_frame.stream_id)
if callback:
# Schedule callback asynchronously
if asyncio.iscoroutinefunction(callback):
asyncio.create_task(callback(result))
else:
# Run sync callback in executor to avoid blocking
loop = asyncio.get_event_loop()
loop.call_soon(lambda cb=callback, r=result: cb(r))
# Call callback directly (synchronous)
try:
callback(result)
except Exception as e:
logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
# TODO: Emit error events to streams
finally:
# Mark buffer as idle
async with self.buffer_lock:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run inference on a batch of frames.
@@ -353,17 +356,15 @@ class ModelController:
Returns:
List of detection results (one per frame)
"""
loop = asyncio.get_event_loop()
# Check if model supports batching
if self.model_batch_size == 1:
# Process frames one at a time for batch_size=1 models
return await self._run_sequential_inference(batch, loop)
return self._run_sequential_inference(batch)
else:
# Use true batching for models that support it
return await self._run_batched_inference(batch, loop)
return self._run_batched_inference(batch)
async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""Run inference sequentially for batch_size=1 models"""
results = []
@@ -376,13 +377,10 @@ class ModelController:
processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
# Run inference for this frame
outputs = await loop.run_in_executor(
None,
lambda p=processed: self.model_repository.infer(
self.model_id,
{"images": p},
synchronize=True
)
outputs = self.model_repository.infer(
self.model_id,
{"images": processed},
synchronize=True
)
# Postprocess
@@ -406,7 +404,7 @@ class ModelController:
return results
async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""Run true batched inference for models that support it"""
# Preprocess frames (on GPU)
preprocessed = []
@@ -434,13 +432,10 @@ class ModelController:
batch = batch[:self.model_batch_size]
# Run inference
outputs = await loop.run_in_executor(
None,
lambda: self.model_repository.infer(
self.model_id,
{"images": batch_tensor},
synchronize=True
)
outputs = self.model_repository.infer(
self.model_id,
{"images": batch_tensor},
synchronize=True
)
# Postprocess results (split batch back to individual results)
@@ -472,14 +467,14 @@ class ModelController:
return results
async def _process_remaining_buffers(self):
def _process_remaining_buffers(self):
"""Process any remaining frames in buffers during shutdown"""
if len(self.buffer_a) > 0:
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
await self._process_buffer("A")
self._process_buffer("A")
if len(self.buffer_b) > 0:
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
await self._process_buffer("B")
self._process_buffer("B")
def get_stats(self) -> Dict[str, Any]:
"""Get current buffer statistics"""