Ultralytics export

Siwat Sirichai 2025-11-11 01:28:19 +07:00
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions


@@ -5,21 +5,22 @@ This module provides batched inference coordination using ping-pong circular buf
with force-switch timeout mechanism using threading and callbacks.
"""
import threading
import torch
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import time
import logging
import queue
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
import torch
logger = logging.getLogger(__name__)
@dataclass
class BatchFrame:
"""Represents a frame in the batch buffer"""
stream_id: str
frame: torch.Tensor # GPU tensor (3, H, W)
timestamp: float
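The module docstring above describes ping-pong circular buffers with a force-switch timeout. As a minimal, self-contained sketch of just the double-buffer idea (the PingPongBuffer name and shape are illustrative, not the class in this diff):

import threading
from typing import Any, Dict, List

class PingPongBuffer:
    """Two lists: one fills while the other is handed off for processing."""

    def __init__(self) -> None:
        self.buffers: Dict[str, List[Any]] = {"A": [], "B": []}
        self.active = "A"
        self.lock = threading.Lock()

    def append(self, item: Any) -> None:
        # Producers always write into the currently active buffer.
        with self.lock:
            self.buffers[self.active].append(item)

    def swap(self) -> List[Any]:
        # Hand the filled buffer to a consumer and start filling the other one.
        with self.lock:
            filled = self.active
            self.active = "B" if filled == "A" else "A"
            batch, self.buffers[filled] = self.buffers[filled], []
            return batch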
@@ -28,6 +29,7 @@ class BatchFrame:
class BufferState(Enum):
"""State of a ping-pong buffer"""
IDLE = "idle"
FILLING = "filling"
PROCESSING = "processing"
@@ -80,7 +82,9 @@ class ModelController:
f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
)
else:
logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
logger.info(
f"Model '{model_id}' supports batch_size={self.model_batch_size}"
)
# Ping-pong buffers
self.buffer_a: List[BatchFrame] = []
@@ -130,7 +134,9 @@ class ModelController:
# Fixed batch size
return batch_dim
except Exception as e:
logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
logger.warning(
f"Could not detect model batch size: {e}. Assuming batch_size=1"
)
return 1
def start(self):
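The hunk above shows the fixed-batch detection path and its fallback. A condensed sketch of that logic, assuming a hypothetical read_input_shape() helper in place of the repository's real metadata call:

import logging

def detect_fixed_batch_size(read_input_shape) -> int:
    """Mirror of the fallback above; read_input_shape is a stand-in for the real metadata lookup."""
    try:
        batch_dim = read_input_shape()[0]   # e.g. (4, 3, 640, 640) -> 4
        return batch_dim                    # fixed batch size reported by the model
    except Exception as e:                  # detection failed: be conservative
        logging.getLogger(__name__).warning(
            "Could not detect model batch size: %s. Assuming batch_size=1", e
        )
        return 1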
@@ -143,14 +149,20 @@ class ModelController:
self.stop_event.clear()
# Start timeout monitor thread
self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
self.timeout_thread = threading.Thread(
target=self._timeout_monitor, daemon=True
)
self.timeout_thread.start()
# Start processor threads for each buffer
self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
self.processor_threads['A'].start()
self.processor_threads['B'].start()
self.processor_threads["A"] = threading.Thread(
target=self._batch_processor, args=("A",), daemon=True
)
self.processor_threads["B"] = threading.Thread(
target=self._batch_processor, args=("B",), daemon=True
)
self.processor_threads["A"].start()
self.processor_threads["B"].start()
logger.info("ModelController started")
@@ -197,10 +209,7 @@ class ModelController:
logger.debug(f"Unregistered callback for stream: {stream_id}")
def submit_frame(
self,
stream_id: str,
frame: torch.Tensor,
metadata: Optional[Dict] = None
self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
):
"""
Submit a frame for batched inference.
@@ -215,7 +224,7 @@ class ModelController:
stream_id=stream_id,
frame=frame,
timestamp=time.time(),
metadata=metadata or {}
metadata=metadata or {},
)
# Add to active buffer
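For context, a sketch of how a caller might drive this API. register_callback() is assumed as the counterpart of the unregister path shown earlier, and the controller is taken as already constructed because its arguments are not part of this diff:

import torch
from typing import Any, Dict

def feed_one_frame(controller: Any, stream_id: str = "camera-1") -> None:
    """Register a per-stream callback and submit a single GPU frame."""
    def on_result(result: Dict) -> None:
        # Invoked from a processor thread once the frame's batch has been inferred.
        print(f"result for {stream_id}:", result)

    controller.register_callback(stream_id, on_result)   # assumed counterpart of unregister_callback
    frame = torch.zeros(3, 720, 1280, device="cuda")      # (3, H, W) GPU tensor, as in BatchFrame
    controller.submit_frame(stream_id, frame, metadata={"seq": 42})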
@@ -242,7 +251,9 @@ class ModelController:
# Check if timeout expired and we have frames waiting
if time_since_submit >= self.force_timeout:
active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
active_buffer = (
self.buffer_a if self.active_buffer == "A" else self.buffer_b
)
if len(active_buffer) > 0:
self._try_swap_buffers()
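The force-switch idea in this hunk, reduced to its essentials (names and the polling interval are illustrative; the real monitor is _timeout_monitor above):

import time
import threading

def run_force_switch_monitor(last_submit_time, has_pending, flush,
                             force_timeout: float, stop_event: threading.Event):
    """Flush a partially filled buffer once frames have waited longer than force_timeout."""
    while not stop_event.is_set():
        waited = time.time() - last_submit_time()
        if waited >= force_timeout and has_pending():
            flush()                         # analogous to _try_swap_buffers()
        time.sleep(force_timeout / 5.0)     # poll several times per timeout window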
@@ -254,7 +265,9 @@ class ModelController:
This method should be called with buffer_lock held.
"""
# Check if inactive buffer is available
inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
inactive_state = (
self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
)
if inactive_state != BufferState.PROCESSING:
# Swap active buffer
@@ -269,7 +282,9 @@ class ModelController:
self.buffer_b_state = BufferState.PROCESSING
buffer_size = len(self.buffer_b)
logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
logger.debug(
f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
)
def _batch_processor(self, buffer_name: str):
"""Background thread that processes a specific buffer when available"""
@@ -322,8 +337,8 @@ class ModelController:
self.total_batches_processed += 1
logger.debug(
f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
f"({inference_time*1000/len(batch):.2f}ms per frame)"
f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
)
# Emit results to callbacks
@@ -334,7 +349,10 @@ class ModelController:
try:
callback(result)
except Exception as e:
logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
logger.error(
f"Error in callback for {batch_frame.stream_id}: {e}",
exc_info=True,
)
except Exception as e:
logger.error(f"Error processing batch: {e}", exc_info=True)
@@ -365,7 +383,9 @@ class ModelController:
# Use true batching for models that support it
return self._run_batched_inference(batch)
def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
def _run_sequential_inference(
self, batch: List[BatchFrame]
) -> List[Dict[str, Any]]:
"""Run inference sequentially for batch_size=1 models"""
results = []
@@ -375,13 +395,15 @@ class ModelController:
processed = self.preprocess_fn(batch_frame.frame)
else:
# Ensure we have batch dimension
processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
processed = (
batch_frame.frame.unsqueeze(0)
if batch_frame.frame.dim() == 3
else batch_frame.frame
)
# Run inference for this frame
outputs = self.model_repository.infer(
self.model_id,
{"images": processed},
synchronize=True
self.model_id, {"images": processed}, synchronize=True
)
# Postprocess
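The sequential path above promotes a single CHW frame to a one-element batch before inference. That shape handling in isolation:

import torch

frame = torch.zeros(3, 640, 640)                         # single (3, H, W) frame, no batch axis
batched = frame.unsqueeze(0) if frame.dim() == 3 else frame
assert batched.shape == (1, 3, 640, 640)                 # model input always carries a batch dimension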
@@ -389,9 +411,13 @@ class ModelController:
try:
detections = self.postprocess_fn(outputs)
except Exception as e:
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
# Return empty detections on error
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
detections = torch.zeros(
(0, 6), device=list(outputs.values())[0].device
)
else:
detections = outputs
@@ -429,32 +455,37 @@ class ModelController:
f"will split into sub-batches"
)
# TODO: Handle splitting into sub-batches
batch_tensor = batch_tensor[:self.model_batch_size]
batch = batch[:self.model_batch_size]
batch_tensor = batch_tensor[: self.model_batch_size]
batch = batch[: self.model_batch_size]
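# One way the TODO above could be handled (an illustrative sketch only, not the
# project's implementation): chunk the oversized batch into pieces of at most
# model_batch_size frames, run inference once per chunk, and re-align results
# by their start offset instead of truncating the batch.
import torch

def iter_sub_batches(batch_tensor: torch.Tensor, model_batch_size: int):
    """Yield (start_index, chunk) pairs with chunk.shape[0] <= model_batch_size."""
    for start in range(0, batch_tensor.shape[0], model_batch_size):
        yield start, batch_tensor[start : start + model_batch_size]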
# Run inference
outputs = self.model_repository.infer(
self.model_id,
{"images": batch_tensor},
synchronize=True
self.model_id, {"images": batch_tensor}, synchronize=True
)
# Postprocess results (split batch back to individual results)
results = []
for i, batch_frame in enumerate(batch):
# Extract single frame output from batch
# Extract single frame output from batch and clone to ensure memory safety
# This prevents potential race conditions if the output tensors are still
# in use when the next inference batch is processed
frame_output = {}
for k, v in outputs.items():
# v has shape (N, ...), extract index i and keep batch dimension
frame_output[k] = v[i:i+1] # Shape: (1, ...)
# Clone to decouple from shared batch output tensor
frame_output[k] = v[i : i + 1].clone() # Shape: (1, ...)
if self.postprocess_fn:
try:
detections = self.postprocess_fn(frame_output)
except Exception as e:
logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
logger.error(
f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
)
# Return empty detections on error
detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
detections = torch.zeros(
(0, 6), device=list(outputs.values())[0].device
)
else:
detections = frame_output
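The .clone() rationale in the comments above can be seen in isolation: slicing a tensor returns a view that shares storage with the batched output, so a later overwrite (for example when a buffer is reused for the next batch) is visible through the slice, while a clone owns its own memory. A minimal demonstration, unrelated to the repository's tensors:

import torch

batch_out = torch.arange(6.0).reshape(3, 2)   # pretend (N, ...) batched output
view = batch_out[1:2]                         # shares storage with batch_out
copy = batch_out[1:2].clone()                 # owns its own storage

batch_out[1] = -1.0                           # e.g. the output buffer is reused
print(view)   # tensor([[-1., -1.]])  -- the view reflects the overwrite
print(copy)   # tensor([[2., 3.]])    -- the clone is unaffected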
@@ -490,6 +521,8 @@ class ModelController:
"total_batches_processed": self.total_batches_processed,
"avg_batch_size": (
self.total_frames_processed / self.total_batches_processed
if self.total_batches_processed > 0
else 0
if self.total_batches_processed > 0 else 0
),
}
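A worked example of the guarded average in this last hunk, with illustrative numbers only:

total_frames_processed = 120
total_batches_processed = 30
avg_batch_size = (
    total_frames_processed / total_batches_processed
    if total_batches_processed > 0
    else 0
)
assert avg_batch_size == 4.0   # and 0, not a ZeroDivisionError, before the first batch completes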