"""
|
|
Base Model Controller - Simple circular buffer with continuous batch processing.
|
|
|
|
Replaces the complex ping-pong buffer architecture with a simple queue:
|
|
- Frames arrive and go into a single deque (circular buffer)
|
|
- When batch_size frames are ready, process them
|
|
- Continue consuming batches until queue is empty
|
|
- Drop oldest frames if queue is full
|
|
"""

import logging
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

import torch

logger = logging.getLogger(__name__)


@dataclass
class BatchFrame:
    """Represents a frame in the batch buffer"""

    stream_id: str
    frame: torch.Tensor  # GPU tensor (3, H, W)
    timestamp: float
    metadata: Dict = field(default_factory=dict)


class BaseModelController(ABC):
    """
    Simple batched inference with a circular buffer.

    Architecture:
    - Single deque (circular buffer) for incoming frames
    - Batch processor thread continuously consumes batches
    - Frames come in as fast as streams produce them; batches go out as fast as inference allows
    - Automatic frame dropping when the queue is full
    """

    def __init__(
        self,
        model_id: str,
        batch_size: int = 16,
        max_queue_size: int = 100,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
        self.model_id = model_id
        self.batch_size = batch_size
        self.max_queue_size = max_queue_size
        self.preprocess_fn = preprocess_fn
        self.postprocess_fn = postprocess_fn

        # Single circular buffer
        self.frame_queue = deque(maxlen=max_queue_size)
        self.queue_lock = threading.Lock()

        # Processing thread
        self.processor_thread: Optional[threading.Thread] = None
        self.running = False
        self.stop_event = threading.Event()

        # Result callbacks (stream_id -> callback)
        self.result_callbacks: Dict[str, Callable] = {}

        # Statistics
        self.total_frames_processed = 0
        self.total_batches_processed = 0
        self.total_frames_dropped = 0

    def start(self):
        """Start the controller background thread"""
        if self.running:
            logger.warning(f"{self.__class__.__name__} already running")
            return

        self.running = True
        self.stop_event.clear()

        # Start single processor thread
        self.processor_thread = threading.Thread(
            target=self._batch_processor, daemon=True
        )
        self.processor_thread.start()

        logger.info(f"{self.__class__.__name__} started (batch_size={self.batch_size})")

    def stop(self):
        """Stop the controller and clean up"""
        if not self.running:
            return

        logger.info(f"Stopping {self.__class__.__name__}...")
        self.running = False
        self.stop_event.set()

        # Wait for the processor thread to finish
        if self.processor_thread and self.processor_thread.is_alive():
            self.processor_thread.join(timeout=2.0)

        # Process remaining frames
        self._process_remaining_frames()
        logger.info(f"{self.__class__.__name__} stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """Register a callback for inference results from a stream"""
        self.result_callbacks[stream_id] = callback
        logger.debug(f"Registered callback for stream: {stream_id}")

    def unregister_callback(self, stream_id: str):
        """Unregister a stream callback"""
        self.result_callbacks.pop(stream_id, None)
        logger.debug(f"Unregistered callback for stream: {stream_id}")

    def submit_frame(
        self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
    ):
        """Submit a frame for batched inference"""
        with self.queue_lock:
            # If the queue is full, the oldest frame is dropped automatically
            # (deque with maxlen); count it so the drop shows up in the stats.
            if len(self.frame_queue) >= self.max_queue_size:
                self.total_frames_dropped += 1

            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {},
            )
            self.frame_queue.append(batch_frame)

    def _batch_processor(self):
        """Background thread that continuously processes batches"""
        logger.info(f"{self.__class__.__name__} batch processor started")

        while self.running and not self.stop_event.is_set():
            # Check if we have enough frames for a batch
            with self.queue_lock:
                queue_size = len(self.frame_queue)

            if queue_size > 0 and queue_size % 10 == 0:
                logger.info(f"Queue size: {queue_size}/{self.batch_size}")

            if queue_size >= self.batch_size:
                # Extract a batch under the lock
                with self.queue_lock:
                    batch = []
                    for _ in range(min(self.batch_size, len(self.frame_queue))):
                        if self.frame_queue:
                            batch.append(self.frame_queue.popleft())

                if batch:
                    logger.info(f"Processing batch of {len(batch)} frames")
                    self._process_batch(batch)
            else:
                # Not enough frames yet, sleep briefly
                time.sleep(0.001)  # 1 ms

    def _process_batch(self, batch: List[BatchFrame]):
        """Process a batch through inference"""
        try:
            start_time = time.time()
            results = self._run_batch_inference(batch)
            inference_time = time.time() - start_time

            self.total_frames_processed += len(batch)
            self.total_batches_processed += 1

            logger.debug(
                f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms"
            )

            # Emit results to callbacks
            for batch_frame, result in zip(batch, results):
                callback = self.result_callbacks.get(batch_frame.stream_id)
                if callback:
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(
                            f"Error in callback for {batch_frame.stream_id}: {e}",
                            exc_info=True,
                        )

        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)

    @abstractmethod
    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run inference on a batch of frames (backend-specific)"""
        pass

    def _process_remaining_frames(self):
        """Process any frames still in the queue during shutdown"""
        with self.queue_lock:
            remaining = len(self.frame_queue)

        if remaining > 0:
            logger.info(f"Processing remaining {remaining} frames")
            while True:
                with self.queue_lock:
                    if not self.frame_queue:
                        break
                    batch = []
                    for _ in range(min(self.batch_size, len(self.frame_queue))):
                        if self.frame_queue:
                            batch.append(self.frame_queue.popleft())

                if batch:
                    self._process_batch(batch)

    def get_stats(self) -> Dict[str, Any]:
        """Get current statistics"""
        with self.queue_lock:
            queue_size = len(self.frame_queue)

        return {
            "queue_size": queue_size,
            "max_queue_size": self.max_queue_size,
            "batch_size": self.batch_size,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "total_frames_dropped": self.total_frames_dropped,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0
                else 0
            ),
        }


# Keep old BufferState enum for backwards compatibility
from enum import Enum


class BufferState(Enum):
    """Deprecated - kept for backwards compatibility"""

    IDLE = "idle"
    FILLING = "filling"
    PROCESSING = "processing"
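

# A minimal usage sketch, not part of the production API: `EchoModelController`
# and the "cam0" stream name below are hypothetical, and the "inference" only
# echoes frame shapes. It illustrates the intended flow described in the module
# docstring: subclass BaseModelController, implement _run_batch_inference,
# register a callback, start the controller, then submit frames.
if __name__ == "__main__":

    class EchoModelController(BaseModelController):
        """Toy controller that returns each frame's shape instead of running a model."""

        def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
            # A real backend would stack the frames and run model inference here.
            return [
                {"stream_id": bf.stream_id, "shape": tuple(bf.frame.shape)}
                for bf in batch
            ]

    controller = EchoModelController(model_id="echo", batch_size=4)
    controller.register_callback("cam0", lambda result: print("result:", result))
    controller.start()

    # Submit a few dummy frames; with batch_size=4 the processor thread
    # drains them in a single batch.
    for _ in range(4):
        controller.submit_frame("cam0", torch.zeros(3, 224, 224))

    time.sleep(0.5)  # give the background thread time to process
    print("stats:", controller.get_stats())
    controller.stop()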