ultralytics export

Siwat Sirichai 2025-11-11 01:28:19 +07:00
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions

@@ -0,0 +1,324 @@
"""
Base Model Controller - Abstract base class for batched inference controllers.
Provides ping-pong buffer architecture with force-switch timeout mechanism.
Implementations handle backend-specific inference (TensorRT, Ultralytics, etc.).
"""
import logging
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
import torch
logger = logging.getLogger(__name__)
@dataclass
class BatchFrame:
"""Represents a frame in the batch buffer"""
stream_id: str
frame: torch.Tensor # GPU tensor (3, H, W)
timestamp: float
metadata: Dict = field(default_factory=dict)
class BufferState(Enum):
"""State of a ping-pong buffer"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
class BaseModelController(ABC):
"""
Abstract base class for batched inference with ping-pong buffers.
This controller accumulates frames from multiple streams into batches,
processes them through an inference backend, and routes results back to
stream-specific callbacks.
Features:
    - Ping-pong double buffers (buffer A / buffer B)
- Force-switch timeout to prevent batch starvation
- Event-driven processing with callbacks
- Thread-safe frame submission
Subclasses must implement:
- _run_batch_inference(): Backend-specific inference logic
"""
def __init__(
self,
model_id: str,
batch_size: int = 16,
force_timeout: float = 0.05,
preprocess_fn: Optional[Callable] = None,
postprocess_fn: Optional[Callable] = None,
):
self.model_id = model_id
self.batch_size = batch_size
self.force_timeout = force_timeout
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn
# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
self.buffer_b: List[BatchFrame] = []
# Buffer states
self.active_buffer = "A"
self.buffer_a_state = BufferState.IDLE
self.buffer_b_state = BufferState.IDLE
        # Threading coordination: one re-entrant lock guards both buffers,
        # their states, and last_submit_time
self.buffer_lock = threading.RLock()
self.last_submit_time = time.time()
# Threads
self.timeout_thread: Optional[threading.Thread] = None
self.processor_threads: Dict[str, threading.Thread] = {}
self.running = False
self.stop_event = threading.Event()
# Result callbacks (stream_id -> callback)
self.result_callbacks: Dict[str, Callable] = {}
# Statistics
self.total_frames_processed = 0
self.total_batches_processed = 0
def start(self):
"""Start the controller background threads"""
if self.running:
            logger.warning(f"{self.__class__.__name__} already running")
return
self.running = True
self.stop_event.clear()
# Start timeout monitor thread
self.timeout_thread = threading.Thread(
target=self._timeout_monitor, daemon=True
)
self.timeout_thread.start()
# Start processor threads for each buffer
self.processor_threads["A"] = threading.Thread(
target=self._batch_processor, args=("A",), daemon=True
)
self.processor_threads["B"] = threading.Thread(
target=self._batch_processor, args=("B",), daemon=True
)
self.processor_threads["A"].start()
self.processor_threads["B"].start()
logger.info(f"{self.__class__.__name__} started")
def stop(self):
"""Stop the controller and cleanup"""
if not self.running:
return
logger.info(f"Stopping {self.__class__.__name__}...")
self.running = False
self.stop_event.set()
# Wait for threads to finish
if self.timeout_thread and self.timeout_thread.is_alive():
self.timeout_thread.join(timeout=2.0)
for thread in self.processor_threads.values():
if thread and thread.is_alive():
thread.join(timeout=2.0)
# Process any remaining frames
self._process_remaining_buffers()
logger.info(f"{self.__class__.__name__} stopped")
def register_callback(self, stream_id: str, callback: Callable):
"""Register a callback for inference results from a stream"""
self.result_callbacks[stream_id] = callback
logger.debug(f"Registered callback for stream: {stream_id}")
def unregister_callback(self, stream_id: str):
"""Unregister a stream callback"""
self.result_callbacks.pop(stream_id, None)
logger.debug(f"Unregistered callback for stream: {stream_id}")
def submit_frame(
self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
):
"""Submit a frame for batched inference"""
with self.buffer_lock:
batch_frame = BatchFrame(
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
metadata=metadata or {},
)
# Add to active buffer
if self.active_buffer == "A":
self.buffer_a.append(batch_frame)
self.buffer_a_state = BufferState.FILLING
buffer_size = len(self.buffer_a)
else:
self.buffer_b.append(batch_frame)
self.buffer_b_state = BufferState.FILLING
buffer_size = len(self.buffer_b)
self.last_submit_time = time.time()
# Check if we should immediately swap (batch full)
if buffer_size >= self.batch_size:
self._try_swap_buffers()
def _timeout_monitor(self):
"""Monitor force-switch timeout"""
while self.running and not self.stop_event.wait(0.01):
with self.buffer_lock:
time_since_submit = time.time() - self.last_submit_time
if time_since_submit >= self.force_timeout:
active_buffer = (
self.buffer_a if self.active_buffer == "A" else self.buffer_b
)
if len(active_buffer) > 0:
self._try_swap_buffers()
def _try_swap_buffers(self):
"""Attempt to swap ping-pong buffers (called with buffer_lock held)"""
inactive_state = (
self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
)
if inactive_state != BufferState.PROCESSING:
old_active = self.active_buffer
self.active_buffer = "B" if old_active == "A" else "A"
if old_active == "A":
self.buffer_a_state = BufferState.PROCESSING
buffer_size = len(self.buffer_a)
else:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)
logger.debug(
f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
)
def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""
while self.running and not self.stop_event.is_set():
time.sleep(0.001)
with self.buffer_lock:
if buffer_name == "A":
should_process = self.buffer_a_state == BufferState.PROCESSING
else:
should_process = self.buffer_b_state == BufferState.PROCESSING
if should_process:
self._process_buffer(buffer_name)
def _process_buffer(self, buffer_name: str):
"""Process a buffer through inference"""
# Extract buffer to process
with self.buffer_lock:
if buffer_name == "A":
batch = self.buffer_a.copy()
self.buffer_a.clear()
else:
batch = self.buffer_b.copy()
self.buffer_b.clear()
if len(batch) == 0:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
return
# Process batch (outside lock to allow concurrent submissions)
try:
start_time = time.time()
results = self._run_batch_inference(batch)
inference_time = time.time() - start_time
self.total_frames_processed += len(batch)
self.total_batches_processed += 1
logger.debug(
f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
)
# Emit results to callbacks
for batch_frame, result in zip(batch, results):
callback = self.result_callbacks.get(batch_frame.stream_id)
if callback:
try:
callback(result)
except Exception as e:
logger.error(
f"Error in callback for {batch_frame.stream_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
finally:
with self.buffer_lock:
if buffer_name == "A":
self.buffer_a_state = BufferState.IDLE
else:
self.buffer_b_state = BufferState.IDLE
@abstractmethod
def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
"""
Run inference on a batch of frames (backend-specific).
Args:
batch: List of BatchFrame objects
Returns:
List of detection results (one per frame)
"""
pass
def _process_remaining_buffers(self):
"""Process any remaining frames in buffers during shutdown"""
if len(self.buffer_a) > 0:
logger.info(f"Processing remaining {len(self.buffer_a)} frames in buffer A")
self._process_buffer("A")
if len(self.buffer_b) > 0:
logger.info(f"Processing remaining {len(self.buffer_b)} frames in buffer B")
self._process_buffer("B")
def get_stats(self) -> Dict[str, Any]:
"""Get current buffer statistics"""
return {
"active_buffer": self.active_buffer,
"buffer_a_size": len(self.buffer_a),
"buffer_b_size": len(self.buffer_b),
"buffer_a_state": self.buffer_a_state.value,
"buffer_b_state": self.buffer_b_state.value,
"registered_streams": len(self.result_callbacks),
"total_frames_processed": self.total_frames_processed,
"total_batches_processed": self.total_batches_processed,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
if self.total_batches_processed > 0
else 0
),
}
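

# Illustrative usage sketch (not prescribed by this module): a toy subclass
# that satisfies the _run_batch_inference() contract, plus a minimal lifecycle.
# `_EchoModelController` and the result fields below are hypothetical; real
# backends (TensorRT, Ultralytics) would run an actual model on the batched
# frames instead.
class _EchoModelController(BaseModelController):
    """Toy backend: echoes frame shape and timestamp as the per-frame result."""

    def _run_batch_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
        # A real implementation would stack the frames into a (B, 3, H, W)
        # tensor and run the model once for the whole batch.
        return [
            {
                "stream_id": bf.stream_id,
                "shape": tuple(bf.frame.shape),
                "timestamp": bf.timestamp,
            }
            for bf in batch
        ]


if __name__ == "__main__":
    controller = _EchoModelController(model_id="echo", batch_size=4, force_timeout=0.05)
    controller.register_callback("cam-1", lambda result: print("cam-1:", result))
    controller.start()
    for _ in range(6):
        # Production code would pass GPU tensors; CPU zeros suffice for the echo example.
        controller.submit_frame("cam-1", torch.zeros(3, 640, 640))
    time.sleep(0.2)  # let the force-switch timeout flush the partial batch
    controller.stop()
    print(controller.get_stats())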