python-rtsp-worker/services/base_model_controller.py

"""
Base Model Controller - Simple circular buffer with continuous batch processing.
Replaces the complex ping-pong buffer architecture with a simple queue:
- Frames arrive and go into a single deque (circular buffer)
- When batch_size frames are ready, process them
- Continue consuming batches until queue is empty
- Drop oldest frames if queue is full
"""

import logging
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

import torch

logger = logging.getLogger(__name__)


@dataclass
class BatchFrame:
    """Represents a frame in the batch buffer"""

    stream_id: str
    frame: torch.Tensor  # GPU tensor (3, H, W)
    timestamp: float
    metadata: Dict = field(default_factory=dict)


class BaseModelController(ABC):
    """
    Simple batched inference with circular buffer.

    Architecture:
    - Single deque (circular buffer) for incoming frames
    - Batch processor thread continuously consumes batches
    - Frames come in fast, batches go out as fast as inference allows
    - Automatic frame dropping when queue is full
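
    Example (illustrative sketch; ``DummyController``, ``frame_tensor`` and the
    empty "detections" results are hypothetical, not part of this module):

        class DummyController(BaseModelController):
            def _run_batch_inference(self, batch):
                # One result dict per BatchFrame, in the same order as the batch
                return [{"stream_id": bf.stream_id, "detections": []} for bf in batch]

        controller = DummyController(model_id="dummy", batch_size=4)
        controller.register_callback("cam-1", print)
        controller.start()
        controller.submit_frame("cam-1", frame_tensor)  # frame_tensor: (3, H, W) tensor
        controller.stop()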
"""

    def __init__(
        self,
        model_id: str,
        batch_size: int = 16,
        max_queue_size: int = 100,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
        self.model_id = model_id
        self.batch_size = batch_size
        self.max_queue_size = max_queue_size
        self.preprocess_fn = preprocess_fn
        self.postprocess_fn = postprocess_fn

        # Single circular buffer
        self.frame_queue = deque(maxlen=max_queue_size)
        self.queue_lock = threading.Lock()

        # Processing thread
        self.processor_thread: Optional[threading.Thread] = None
        self.running = False
        self.stop_event = threading.Event()

        # Result callbacks (stream_id -> callback)
        self.result_callbacks: Dict[str, Callable] = {}

        # Statistics
        self.total_frames_processed = 0
        self.total_batches_processed = 0
        self.total_frames_dropped = 0

    def start(self):
        """Start the controller background thread"""
        if self.running:
            logger.warning(f"{self.__class__.__name__} already running")
            return

        self.running = True
        self.stop_event.clear()

        # Start single processor thread
        self.processor_thread = threading.Thread(
            target=self._batch_processor, daemon=True
        )
        self.processor_thread.start()

        logger.info(f"{self.__class__.__name__} started (batch_size={self.batch_size})")

    def stop(self):
        """Stop the controller and clean up"""
        if not self.running:
            return

        logger.info(f"Stopping {self.__class__.__name__}...")
        self.running = False
        self.stop_event.set()

        # Wait for thread to finish
        if self.processor_thread and self.processor_thread.is_alive():
            self.processor_thread.join(timeout=2.0)

        # Process remaining frames
        self._process_remaining_frames()

        logger.info(f"{self.__class__.__name__} stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """Register a callback for inference results from a stream"""
        self.result_callbacks[stream_id] = callback
        logger.debug(f"Registered callback for stream: {stream_id}")

    def unregister_callback(self, stream_id: str):
        """Unregister a stream callback"""
        self.result_callbacks.pop(stream_id, None)
        logger.debug(f"Unregistered callback for stream: {stream_id}")

    def submit_frame(
        self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
    ):
        """Submit a frame for batched inference"""
        with self.queue_lock:
            # If queue is full, oldest frame is automatically dropped (deque with maxlen)
            if len(self.frame_queue) >= self.max_queue_size:
                self.total_frames_dropped += 1

            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {},
            )
            self.frame_queue.append(batch_frame)

    def _batch_processor(self):
        """Background thread that continuously processes batches"""
        logger.info(f"{self.__class__.__name__} batch processor started")

        while self.running and not self.stop_event.is_set():
            # Check if we have enough frames for a batch
            with self.queue_lock:
                queue_size = len(self.frame_queue)

            if queue_size > 0 and queue_size % 10 == 0:
                logger.info(f"Queue size: {queue_size}/{self.batch_size}")

            if queue_size >= self.batch_size:
                # Extract batch
                with self.queue_lock:
                    batch = []
                    for _ in range(min(self.batch_size, len(self.frame_queue))):
                        if self.frame_queue:
                            batch.append(self.frame_queue.popleft())

                if batch:
                    logger.info(f"Processing batch of {len(batch)} frames")
                    self._process_batch(batch)
            else:
                # Not enough frames, sleep briefly
                time.sleep(0.001)  # 1ms

    def _process_batch(self, batch: List[BatchFrame]):
        """Process a batch through inference"""
        try:
            start_time = time.time()
            results = self._run_batch_inference(batch)
            inference_time = time.time() - start_time

            self.total_frames_processed += len(batch)
            self.total_batches_processed += 1

            logger.debug(
                f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms"
            )

            # Emit results to callbacks
            for batch_frame, result in zip(batch, results):
                callback = self.result_callbacks.get(batch_frame.stream_id)
                if callback:
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(
                            f"Error in callback for {batch_frame.stream_id}: {e}",
                            exc_info=True,
                        )
        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)

    @abstractmethod
    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """Run inference on a batch of frames (backend-specific).

        Must return one result dict per frame, in the same order as ``batch``.
        """
        pass

    def _process_remaining_frames(self):
        """Process any remaining frames in queue during shutdown"""
        with self.queue_lock:
            remaining = len(self.frame_queue)
        if remaining > 0:
            logger.info(f"Processing remaining {remaining} frames")

        while True:
            with self.queue_lock:
                if not self.frame_queue:
                    break
                batch = []
                for _ in range(min(self.batch_size, len(self.frame_queue))):
                    if self.frame_queue:
                        batch.append(self.frame_queue.popleft())

            if batch:
                self._process_batch(batch)

    def get_stats(self) -> Dict[str, Any]:
        """Get current statistics"""
        with self.queue_lock:
            queue_size = len(self.frame_queue)

        return {
            "queue_size": queue_size,
            "max_queue_size": self.max_queue_size,
            "batch_size": self.batch_size,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "total_frames_dropped": self.total_frames_dropped,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0
                else 0
            ),
        }


# Keep old BufferState enum for backwards compatibility
from enum import Enum


class BufferState(Enum):
    """Deprecated - kept for backwards compatibility"""

    IDLE = "idle"
    FILLING = "filling"
    PROCESSING = "processing"
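

# --- Illustrative smoke test (a minimal sketch, not part of the original module) ---
# The _EchoController subclass, the "demo-stream" id, and the CPU tensors below are
# assumptions made purely for demonstration; a real subclass would run an actual
# model on GPU tensors inside _run_batch_inference.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    class _EchoController(BaseModelController):
        """Hypothetical subclass that returns a dummy result per frame."""

        def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
            # One result dict per BatchFrame, in submission order
            return [
                {"stream_id": bf.stream_id, "timestamp": bf.timestamp, "detections": []}
                for bf in batch
            ]

    controller = _EchoController(model_id="echo", batch_size=4, max_queue_size=16)
    controller.register_callback("demo-stream", lambda result: print("result:", result))
    controller.start()

    # Submit two batches' worth of dummy frames (CPU tensors are fine for the demo)
    for _ in range(8):
        controller.submit_frame("demo-stream", torch.zeros(3, 480, 640))

    time.sleep(0.5)  # Give the processor thread time to drain the queue
    print("stats:", controller.get_stats())
    controller.stop()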