Ultralytics export
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions
services/base_model_controller.py (new file, 324 lines)
@@ -0,0 +1,324 @@
"""
|
||||
Base Model Controller - Abstract base class for batched inference controllers.
|
||||
|
||||
Provides ping-pong buffer architecture with force-switch timeout mechanism.
|
||||
Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BatchFrame:
    """Represents a frame in the batch buffer"""

    stream_id: str
    frame: torch.Tensor  # GPU tensor (3, H, W)
    timestamp: float
    metadata: Dict = field(default_factory=dict)


class BufferState(Enum):
    """State of a ping-pong buffer"""

    IDLE = "idle"
    FILLING = "filling"
    PROCESSING = "processing"


class BaseModelController(ABC):
    """
    Abstract base class for batched inference with ping-pong buffers.

    This controller accumulates frames from multiple streams into batches,
    processes them through an inference backend, and routes results back to
    stream-specific callbacks.

    Features:
    - Ping-pong circular buffers (buffer A / buffer B)
    - Force-switch timeout to prevent batch starvation
    - Event-driven processing with callbacks
    - Thread-safe frame submission

    Subclasses must implement:
    - _run_batch_inference(): backend-specific inference logic
    """

    def __init__(
        self,
        model_id: str,
        batch_size: int = 16,
        force_timeout: float = 0.05,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
    ):
        self.model_id = model_id
        self.batch_size = batch_size
        self.force_timeout = force_timeout
        self.preprocess_fn = preprocess_fn
        self.postprocess_fn = postprocess_fn

        # Ping-pong buffers
        self.buffer_a: List[BatchFrame] = []
        self.buffer_b: List[BatchFrame] = []

        # Buffer states
        self.active_buffer = "A"
        self.buffer_a_state = BufferState.IDLE
        self.buffer_b_state = BufferState.IDLE

        # Threading coordination
        self.buffer_lock = threading.RLock()
        self.last_submit_time = time.time()

        # Threads
        self.timeout_thread: Optional[threading.Thread] = None
        self.processor_threads: Dict[str, threading.Thread] = {}
        self.running = False
        self.stop_event = threading.Event()

        # Result callbacks (stream_id -> callback)
        self.result_callbacks: Dict[str, Callable] = {}

        # Statistics
        self.total_frames_processed = 0
        self.total_batches_processed = 0

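    # Ping-pong invariant: one buffer is "active" and accepts new frames
    # while the other is idle or being processed. Buffers are swapped only
    # when the inactive buffer is not PROCESSING (see _try_swap_buffers),
    # so a slow inference call back-pressures frames into the still-filling
    # active buffer instead of dropping them.
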
    def start(self):
        """Start the controller background threads"""
        if self.running:
            logger.warning("ModelController already running")
            return

        self.running = True
        self.stop_event.clear()

        # Start timeout monitor thread
        self.timeout_thread = threading.Thread(
            target=self._timeout_monitor, daemon=True
        )
        self.timeout_thread.start()

        # Start processor threads for each buffer
        self.processor_threads["A"] = threading.Thread(
            target=self._batch_processor, args=("A",), daemon=True
        )
        self.processor_threads["B"] = threading.Thread(
            target=self._batch_processor, args=("B",), daemon=True
        )
        self.processor_threads["A"].start()
        self.processor_threads["B"].start()

        logger.info(f"{self.__class__.__name__} started")

    def stop(self):
        """Stop the controller and clean up"""
        if not self.running:
            return

        logger.info(f"Stopping {self.__class__.__name__}...")
        self.running = False
        self.stop_event.set()

        # Wait for threads to finish
        if self.timeout_thread and self.timeout_thread.is_alive():
            self.timeout_thread.join(timeout=2.0)

        for thread in self.processor_threads.values():
            if thread and thread.is_alive():
                thread.join(timeout=2.0)

        # Process any remaining frames
        self._process_remaining_buffers()
        logger.info(f"{self.__class__.__name__} stopped")

    def register_callback(self, stream_id: str, callback: Callable):
        """Register a callback for inference results from a stream"""
        self.result_callbacks[stream_id] = callback
        logger.debug(f"Registered callback for stream: {stream_id}")

    def unregister_callback(self, stream_id: str):
        """Unregister a stream callback"""
        self.result_callbacks.pop(stream_id, None)
        logger.debug(f"Unregistered callback for stream: {stream_id}")

    def submit_frame(
        self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
    ):
        """Submit a frame for batched inference"""
        with self.buffer_lock:
            batch_frame = BatchFrame(
                stream_id=stream_id,
                frame=frame,
                timestamp=time.time(),
                metadata=metadata or {},
            )

            # Add to active buffer
            if self.active_buffer == "A":
                self.buffer_a.append(batch_frame)
                self.buffer_a_state = BufferState.FILLING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b.append(batch_frame)
                self.buffer_b_state = BufferState.FILLING
                buffer_size = len(self.buffer_b)

            self.last_submit_time = time.time()

            # Check if we should immediately swap (batch full)
            if buffer_size >= self.batch_size:
                self._try_swap_buffers()

    def _timeout_monitor(self):
        """Monitor force-switch timeout"""
        while self.running and not self.stop_event.wait(0.01):
            with self.buffer_lock:
                time_since_submit = time.time() - self.last_submit_time

                if time_since_submit >= self.force_timeout:
                    active_buffer = (
                        self.buffer_a if self.active_buffer == "A" else self.buffer_b
                    )
                    if len(active_buffer) > 0:
                        self._try_swap_buffers()

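    # Note: the force-switch timeout is measured from the most recent
    # submission (last_submit_time), not from when the buffer started
    # filling. A partially filled buffer is therefore flushed once the
    # streams go quiet for `force_timeout` seconds rather than waiting
    # to accumulate a full `batch_size`.
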
    def _try_swap_buffers(self):
        """Attempt to swap ping-pong buffers (called with buffer_lock held)"""
        inactive_state = (
            self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
        )

        if inactive_state != BufferState.PROCESSING:
            old_active = self.active_buffer
            self.active_buffer = "B" if old_active == "A" else "A"

            if old_active == "A":
                self.buffer_a_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_a)
            else:
                self.buffer_b_state = BufferState.PROCESSING
                buffer_size = len(self.buffer_b)

            logger.debug(
                f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
            )

    def _batch_processor(self, buffer_name: str):
        """Background thread that processes a specific buffer when available"""
        while self.running and not self.stop_event.is_set():
            time.sleep(0.001)

            with self.buffer_lock:
                if buffer_name == "A":
                    should_process = self.buffer_a_state == BufferState.PROCESSING
                else:
                    should_process = self.buffer_b_state == BufferState.PROCESSING

            if should_process:
                self._process_buffer(buffer_name)

    def _process_buffer(self, buffer_name: str):
        """Process a buffer through inference"""
        # Extract buffer to process
        with self.buffer_lock:
            if buffer_name == "A":
                batch = self.buffer_a.copy()
                self.buffer_a.clear()
            else:
                batch = self.buffer_b.copy()
                self.buffer_b.clear()

        if len(batch) == 0:
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE
            return

        # Process batch (outside lock to allow concurrent submissions)
        try:
            start_time = time.time()
            results = self._run_batch_inference(batch)
            inference_time = time.time() - start_time

            self.total_frames_processed += len(batch)
            self.total_batches_processed += 1

            logger.debug(
                f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
                f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
            )

            # Emit results to callbacks
            for batch_frame, result in zip(batch, results):
                callback = self.result_callbacks.get(batch_frame.stream_id)
                if callback:
                    try:
                        callback(result)
                    except Exception as e:
                        logger.error(
                            f"Error in callback for {batch_frame.stream_id}: {e}",
                            exc_info=True,
                        )

        except Exception as e:
            logger.error(f"Error processing batch: {e}", exc_info=True)

        finally:
            with self.buffer_lock:
                if buffer_name == "A":
                    self.buffer_a_state = BufferState.IDLE
                else:
                    self.buffer_b_state = BufferState.IDLE

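    # Callbacks run synchronously on the buffer's processor thread, so a
    # slow callback delays that buffer's return to IDLE and, indirectly,
    # the next buffer swap. Keep callbacks lightweight or hand results
    # off to a queue.
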
    @abstractmethod
    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        """
        Run inference on a batch of frames (backend-specific).

        Args:
            batch: List of BatchFrame objects

        Returns:
            List of detection results (one per frame)
        """
        pass

    def _process_remaining_buffers(self):
        """Process any remaining frames in buffers during shutdown"""
        if len(self.buffer_a) > 0:
            logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
            self._process_buffer("A")
        if len(self.buffer_b) > 0:
            logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
            self._process_buffer("B")

    def get_stats(self) -> Dict[str, Any]:
        """Get current buffer statistics"""
        return {
            "active_buffer": self.active_buffer,
            "buffer_a_size": len(self.buffer_a),
            "buffer_b_size": len(self.buffer_b),
            "buffer_a_state": self.buffer_a_state.value,
            "buffer_b_state": self.buffer_b_state.value,
            "registered_streams": len(self.result_callbacks),
            "total_frames_processed": self.total_frames_processed,
            "total_batches_processed": self.total_batches_processed,
            "avg_batch_size": (
                self.total_frames_processed / self.total_batches_processed
                if self.total_batches_processed > 0
                else 0
            ),
        }
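
Since _run_batch_inference() is the only abstract method, wiring in a backend takes little code. Below is a minimal sketch of an Ultralytics-backed subclass; the class name UltralyticsModelController, the model path, and the result-dict layout are illustrative assumptions, not part of this commit.

# Minimal sketch of an Ultralytics-backed controller (illustrative, not
# part of this commit). Assumes frames were preprocessed to a common
# resolution and normalized to [0, 1], as Ultralytics expects for
# torch.Tensor input.
from typing import Any, Dict, List

import torch
from ultralytics import YOLO

from services.base_model_controller import BaseModelController, BatchFrame


class UltralyticsModelController(BaseModelController):
    def __init__(self, model_path: str, **kwargs):
        super().__init__(model_id=model_path, **kwargs)
        self.model = YOLO(model_path)

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        # Stack per-frame (3, H, W) tensors into a single (B, 3, H, W) batch.
        images = torch.stack([bf.frame for bf in batch])
        results = self.model(images, verbose=False)
        # Return one result dict per frame, matching the zip() in
        # _process_buffer.
        return [
            {
                "stream_id": bf.stream_id,
                "timestamp": bf.timestamp,
                "boxes": r.boxes.xyxy.cpu().numpy(),
                "scores": r.boxes.conf.cpu().numpy(),
                "classes": r.boxes.cls.cpu().numpy(),
            }
            for bf, r in zip(batch, results)
        ]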
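
And a hypothetical wiring, with "cam-01" and frame_tensor standing in for a real stream id and a preprocessed GPU tensor:

# Hypothetical usage: one shared controller serving several streams.
controller = UltralyticsModelController(
    "yolov8n.pt", batch_size=8, force_timeout=0.05
)
controller.start()
controller.register_callback("cam-01", lambda result: print(result["boxes"].shape))
controller.submit_frame("cam-01", frame_tensor)  # frame_tensor: (3, H, W) GPU tensor
# ... frames keep arriving; batches flush on size or timeout ...
controller.stop()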