Event-driven system
parent 0c5f56c8a6
commit 3a47920186
10 changed files with 782 additions and 253 deletions
@@ -1,17 +1,18 @@
 """
-ModelController - Async batching layer with ping-pong buffers for inference.
+ModelController - Event-driven batching layer with ping-pong buffers for inference.

 This module provides batched inference coordination using ping-pong circular buffers
-with force-switch timeout mechanism.
+with force-switch timeout mechanism using threading and callbacks.
 """

-import asyncio
+import threading
 import torch
 from typing import Dict, List, Optional, Callable, Any
 from dataclasses import dataclass, field
 from enum import Enum
 import time
 import logging
+import queue

 logger = logging.getLogger(__name__)

@@ -43,7 +44,7 @@ class ModelController:
     Features:
     - Ping-pong circular buffers (BufferA/BufferB)
     - Force-switch timeout to prevent batch starvation
-    - Async event-driven processing
+    - Event-driven processing with callbacks
    - Thread-safe frame submission

    Args:
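As a point of reference, here is a minimal, self-contained sketch of the ping-pong pattern the feature list describes: one buffer accepts new frames while the other is being processed, and they swap either when the active buffer reaches the batch size or when a force-switch timeout expires. The class and method names below (PingPong, submit, maybe_force_swap) are illustrative only and are not part of this module.

import threading
import time

class PingPong:
    """Illustrative sketch: two buffers, swap on 'batch full' or on force timeout."""
    def __init__(self, batch_size=4, force_timeout=0.1, process=print):
        self.buffers = {"A": [], "B": []}
        self.active = "A"
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.process = process                  # called with each completed batch
        self.lock = threading.RLock()
        self.last_submit = time.time()

    def submit(self, item):
        with self.lock:
            self.buffers[self.active].append(item)
            self.last_submit = time.time()
            if len(self.buffers[self.active]) >= self.batch_size:
                self._swap()

    def maybe_force_swap(self):
        # Call periodically (e.g. from a 10 ms monitor loop).
        with self.lock:
            waited = time.time() - self.last_submit
            if self.buffers[self.active] and waited >= self.force_timeout:
                self._swap()

    def _swap(self):
        batch, self.buffers[self.active] = self.buffers[self.active], []
        self.active = "B" if self.active == "A" else "A"
        self.process(batch)                     # the real controller hands this to a processor thread

p = PingPong()
for i in range(10):
    p.submit(i)        # flushes [0..3] and [4..7] as full batches
time.sleep(0.2)
p.maybe_force_swap()   # flushes the leftover [8, 9] after the timeout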
@@ -90,14 +91,15 @@ class ModelController:
         self.buffer_a_state = BufferState.IDLE
         self.buffer_b_state = BufferState.IDLE

-        # Async coordination
-        self.buffer_lock = asyncio.Lock()
+        # Threading coordination
+        self.buffer_lock = threading.RLock()
         self.last_submit_time = time.time()

-        # Tasks
-        self.timeout_task: Optional[asyncio.Task] = None
-        self.processor_task: Optional[asyncio.Task] = None
+        # Threads
+        self.timeout_thread: Optional[threading.Thread] = None
+        self.processor_threads: Dict[str, threading.Thread] = {}
         self.running = False
+        self.stop_event = threading.Event()

         # Result callbacks (stream_id -> callback)
         self.result_callbacks: Dict[str, Callable] = {}
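A note on threading.RLock versus threading.Lock, as I read this diff: _try_swap_buffers can be reached while buffer_lock is already held (for example from inside the locked section of _timeout_monitor), so if it takes the lock itself a plain Lock would deadlock, while an RLock lets the same thread re-acquire it. A toy illustration, using nothing beyond the standard library:

import threading

lock = threading.RLock()  # re-entrant: the owning thread may acquire it again

def inner():
    with lock:            # would block forever here with a plain threading.Lock
        print("inner section")

def outer():
    with lock:
        inner()           # same thread re-acquires the lock

outer()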
@@ -130,42 +132,46 @@ class ModelController:
             logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
             return 1

-    async def start(self):
-        """Start the controller background tasks"""
+    def start(self):
+        """Start the controller background threads"""
         if self.running:
             logger.warning("ModelController already running")
             return

         self.running = True
-        self.timeout_task = asyncio.create_task(self._timeout_monitor())
-        self.processor_task = asyncio.create_task(self._batch_processor())
+        self.stop_event.clear()
+
+        # Start timeout monitor thread
+        self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
+        self.timeout_thread.start()
+
+        # Start processor threads for each buffer
+        self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
+        self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
+        self.processor_threads['A'].start()
+        self.processor_threads['B'].start()
+
         logger.info("ModelController started")

-    async def stop(self):
+    def stop(self):
         """Stop the controller and cleanup"""
         if not self.running:
             return

         logger.info("Stopping ModelController...")
         self.running = False
+        self.stop_event.set()

-        # Cancel tasks
-        if self.timeout_task:
-            self.timeout_task.cancel()
-            try:
-                await self.timeout_task
-            except asyncio.CancelledError:
-                pass
+        # Wait for threads to finish
+        if self.timeout_thread and self.timeout_thread.is_alive():
+            self.timeout_thread.join(timeout=2.0)

-        if self.processor_task:
-            self.processor_task.cancel()
-            try:
-                await self.processor_task
-            except asyncio.CancelledError:
-                pass
+        for thread in self.processor_threads.values():
+            if thread and thread.is_alive():
+                thread.join(timeout=2.0)

         # Process any remaining frames
-        await self._process_remaining_buffers()
+        self._process_remaining_buffers()
         logger.info("ModelController stopped")

     def register_callback(self, stream_id: str, callback: Callable):
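The start/stop pair above follows a common lifecycle pattern: daemon threads so the workers cannot keep the interpreter alive on exit, a shared Event to signal shutdown, and join(timeout=...) so stop() cannot block indefinitely on a stuck worker. A stripped-down sketch of that pattern; the Service and _loop names are illustrative, not from this module:

import threading

class Service:
    def __init__(self):
        self.stop_event = threading.Event()
        self.worker = None

    def start(self):
        self.stop_event.clear()
        # daemon=True: the thread will not keep the process alive on interpreter exit
        self.worker = threading.Thread(target=self._loop, daemon=True)
        self.worker.start()

    def stop(self):
        self.stop_event.set()                 # ask the loop to exit
        if self.worker and self.worker.is_alive():
            self.worker.join(timeout=2.0)     # bounded wait; give up after 2 s

    def _loop(self):
        while not self.stop_event.is_set():
            self.stop_event.wait(0.01)        # periodic work would go here

svc = Service()
svc.start()
svc.stop()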
@@ -189,7 +195,7 @@ class ModelController:
         self.result_callbacks.pop(stream_id, None)
         logger.debug(f"Unregistered callback for stream: {stream_id}")

-    async def submit_frame(
+    def submit_frame(
         self,
         stream_id: str,
         frame: torch.Tensor,
@@ -203,7 +209,7 @@ class ModelController:
             frame: GPU tensor (3, H, W) or (C, H, W)
             metadata: Optional metadata to attach to the frame
         """
-        async with self.buffer_lock:
+        with self.buffer_lock:
             batch_frame = BatchFrame(
                 stream_id=stream_id,
                 frame=frame,
@@ -225,23 +231,21 @@ class ModelController:

         # Check if we should immediately swap (batch full)
         if buffer_size >= self.batch_size:
-            await self._try_swap_buffers()
+            self._try_swap_buffers()

-    async def _timeout_monitor(self):
+    def _timeout_monitor(self):
         """Monitor force-switch timeout"""
-        while self.running:
-            await asyncio.sleep(0.01)  # Check every 10ms
-
-            async with self.buffer_lock:
+        while self.running and not self.stop_event.wait(0.01):  # Check every 10ms
+            with self.buffer_lock:
                 time_since_submit = time.time() - self.last_submit_time

                 # Check if timeout expired and we have frames waiting
                 if time_since_submit >= self.force_timeout:
                     active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
                     if len(active_buffer) > 0:
-                        await self._try_swap_buffers()
+                        self._try_swap_buffers()

-    async def _try_swap_buffers(self):
+    def _try_swap_buffers(self):
         """
         Attempt to swap ping-pong buffers.
         Only swaps if the inactive buffer is not currently processing.
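The rewritten loop condition, while self.running and not self.stop_event.wait(0.01), folds the old asyncio.sleep into the shutdown check: Event.wait(0.01) sleeps for up to 10 ms but returns True immediately once the event is set, so the monitor both polls on schedule and wakes up as soon as stop() is called. A small, self-contained demonstration of that behaviour:

import threading
import time

stop = threading.Event()

def monitor():
    ticks = 0
    while not stop.wait(0.01):   # False after a 10 ms timeout, True as soon as stop is set
        ticks += 1
    print(f"stopped after {ticks} ticks")

t = threading.Thread(target=monitor)
t.start()
time.sleep(0.1)
stop.set()                       # the loop exits within one wait interval
t.join()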
@@ -266,20 +270,22 @@ class ModelController:

             logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")

-    async def _batch_processor(self):
-        """Background task that processes batches when available"""
-        while self.running:
-            await asyncio.sleep(0.001)  # Check every 1ms
+    def _batch_processor(self, buffer_name: str):
+        """Background thread that processes a specific buffer when available"""
+        while self.running and not self.stop_event.is_set():
+            time.sleep(0.001)  # Check every 1ms

-            # Check if buffer A needs processing
-            if self.buffer_a_state == BufferState.PROCESSING:
-                await self._process_buffer("A")
+            # Check if this buffer needs processing
+            with self.buffer_lock:
+                if buffer_name == "A":
+                    should_process = self.buffer_a_state == BufferState.PROCESSING
+                else:
+                    should_process = self.buffer_b_state == BufferState.PROCESSING

-            # Check if buffer B needs processing
-            if self.buffer_b_state == BufferState.PROCESSING:
-                await self._process_buffer("B")
+            if should_process:
+                self._process_buffer(buffer_name)

-    async def _process_buffer(self, buffer_name: str):
+    def _process_buffer(self, buffer_name: str):
         """
         Process a buffer through inference.

@@ -287,7 +293,7 @@ class ModelController:
             buffer_name: "A" or "B"
         """
         # Extract buffer to process
-        async with self.buffer_lock:
+        with self.buffer_lock:
             if buffer_name == "A":
                 batch = self.buffer_a.copy()
                 self.buffer_a.clear()
@@ -297,7 +303,7 @@ class ModelController:

         if len(batch) == 0:
             # Mark as idle and return
-            async with self.buffer_lock:
+            with self.buffer_lock:
                 if buffer_name == "A":
                     self.buffer_a_state = BufferState.IDLE
                 else:
@@ -307,7 +313,7 @@ class ModelController:
         # Process batch (outside lock to allow concurrent submissions)
         try:
             start_time = time.time()
-            results = await self._run_batch_inference(batch)
+            results = self._run_batch_inference(batch)
             inference_time = time.time() - start_time

             # Update statistics
@@ -323,27 +329,24 @@ class ModelController:
             for batch_frame, result in zip(batch, results):
                 callback = self.result_callbacks.get(batch_frame.stream_id)
                 if callback:
-                    # Schedule callback asynchronously
-                    if asyncio.iscoroutinefunction(callback):
-                        asyncio.create_task(callback(result))
-                    else:
-                        # Run sync callback in executor to avoid blocking
-                        loop = asyncio.get_event_loop()
-                        loop.call_soon(lambda cb=callback, r=result: cb(r))
+                    # Call callback directly (synchronous)
+                    try:
+                        callback(result)
+                    except Exception as e:
+                        logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)

         except Exception as e:
             logger.error(f"Error processing batch: {e}", exc_info=True)
             # TODO: Emit error events to streams

         finally:
             # Mark buffer as idle
-            async with self.buffer_lock:
+            with self.buffer_lock:
                 if buffer_name == "A":
                     self.buffer_a_state = BufferState.IDLE
                 else:
                     self.buffer_b_state = BufferState.IDLE

-    async def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
+    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
         """
         Run inference on a batch of frames.

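Callbacks are now invoked synchronously on the processor thread, so a slow callback delays the whole batch and everything queued behind it. One way to keep registered callbacks cheap is to have them only enqueue the result and let a separate consumer thread do the heavy work; this commit also adds import queue, though the diff does not show where it is used. A sketch with hypothetical names (on_result, consumer):

import queue
import threading

results: "queue.Queue[dict]" = queue.Queue()

def on_result(result: dict) -> None:
    # Intended for register_callback(); runs on the processor thread,
    # so it should only hand the result off and return immediately.
    results.put(result)

def consumer() -> None:
    while True:
        result = results.get()
        if result is None:          # sentinel to shut the consumer down
            break
        # ... slow work (drawing, encoding, network I/O) happens here ...

t = threading.Thread(target=consumer, daemon=True)
t.start()

# on_result would be passed to controller.register_callback("camera-1", on_result)
results.put(None)                   # stop the consumer in this standalone sketch
t.join()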
@@ -353,17 +356,15 @@ class ModelController:
         Returns:
             List of detection results (one per frame)
         """
-        loop = asyncio.get_event_loop()
-
         # Check if model supports batching
         if self.model_batch_size == 1:
             # Process frames one at a time for batch_size=1 models
-            return await self._run_sequential_inference(batch, loop)
+            return self._run_sequential_inference(batch)
         else:
             # Use true batching for models that support it
-            return await self._run_batched_inference(batch, loop)
+            return self._run_batched_inference(batch)

-    async def _run_sequential_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
+    def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
         """Run inference sequentially for batch_size=1 models"""
         results = []

@@ -376,13 +377,10 @@ class ModelController:
             processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame

             # Run inference for this frame
-            outputs = await loop.run_in_executor(
-                None,
-                lambda p=processed: self.model_repository.infer(
-                    self.model_id,
-                    {"images": p},
-                    synchronize=True
-                )
+            outputs = self.model_repository.infer(
+                self.model_id,
+                {"images": processed},
+                synchronize=True
             )

             # Postprocess
@@ -406,7 +404,7 @@ class ModelController:

         return results

-    async def _run_batched_inference(self, batch: List[BatchFrame], loop) -> List[Dict[str, Any]]:
+    def _run_batched_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
         """Run true batched inference for models that support it"""
         # Preprocess frames (on GPU)
         preprocessed = []
@@ -434,13 +432,10 @@ class ModelController:
             batch = batch[:self.model_batch_size]

         # Run inference
-        outputs = await loop.run_in_executor(
-            None,
-            lambda: self.model_repository.infer(
-                self.model_id,
-                {"images": batch_tensor},
-                synchronize=True
-            )
+        outputs = self.model_repository.infer(
+            self.model_id,
+            {"images": batch_tensor},
+            synchronize=True
         )

         # Postprocess results (split batch back to individual results)
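For the batched path, the shape contract is the usual one: N frames of shape (C, H, W) are stacked into an (N, C, H, W) tensor, a single forward pass runs, and the output is split back into one result per frame. A self-contained sketch of that stack/split step, with a dummy function standing in for model_repository.infer (whose implementation is not shown in this diff):

import torch

def dummy_infer(batch_tensor: torch.Tensor) -> torch.Tensor:
    # Stand-in for model_repository.infer(); returns one row of "detections" per frame.
    n = batch_tensor.shape[0]
    return torch.rand(n, 5)              # e.g. (x1, y1, x2, y2, score) per frame

frames = [torch.rand(3, 640, 640) for _ in range(4)]   # individual (C, H, W) frames
batch_tensor = torch.stack(frames, dim=0)               # -> (4, 3, 640, 640)

outputs = dummy_infer(batch_tensor)

# Split the batched output back into one result per submitted frame
results = [{"detections": outputs[i]} for i in range(len(frames))]
assert len(results) == len(frames)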
@@ -472,14 +467,14 @@ class ModelController:

         return results

-    async def _process_remaining_buffers(self):
+    def _process_remaining_buffers(self):
         """Process any remaining frames in buffers during shutdown"""
         if len(self.buffer_a) > 0:
             logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
-            await self._process_buffer("A")
+            self._process_buffer("A")
         if len(self.buffer_b) > 0:
             logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
-            await self._process_buffer("B")
+            self._process_buffer("B")

     def get_stats(self) -> Dict[str, Any]:
         """Get current buffer statistics"""