ultralytic export
parent bf7b68edb1
commit fdaeb9981c
14 changed files with 2241 additions and 507 deletions
@@ -5,21 +5,22 @@ This module provides batched inference coordination using ping-pong circular buf
 with force-switch timeout mechanism using threading and callbacks.
 """
 
-import threading
-import torch
-from typing import Dict, List, Optional, Callable, Any
-from dataclasses import dataclass, field
-from enum import Enum
-import time
 import logging
+import queue
+import threading
+import time
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
 
 logger = logging.getLogger(__name__)
 
 
 @dataclass
 class BatchFrame:
     """Represents a frame in the batch buffer"""
 
     stream_id: str
     frame: torch.Tensor  # GPU tensor (3, H, W)
     timestamp: float
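For readers skimming the hunks below: the module described in the docstring keeps two frame buffers so one side can fill while the other is being processed, and a timeout thread force-swaps them when frames wait too long. The following is a minimal sketch of that ping-pong idea only; the class and attribute names (PingPongBuffer, force_timeout) are illustrative and are not taken from this commit.

import threading
import time


class PingPongBuffer:
    """Illustrative double buffer: the active side fills while the other drains."""

    def __init__(self, force_timeout: float = 0.05):
        self.buffers = {"A": [], "B": []}
        self.active = "A"                  # side currently accepting frames
        self.lock = threading.Lock()
        self.force_timeout = force_timeout
        self.last_submit = time.time()

    def submit(self, item) -> None:
        with self.lock:
            self.buffers[self.active].append(item)
            self.last_submit = time.time()

    def swap_if_ready(self):
        """Return the filled side as a batch once the force-switch timeout expires."""
        with self.lock:
            if not self.buffers[self.active]:
                return None
            if time.time() - self.last_submit < self.force_timeout:
                return None
            filled = self.active
            self.active = "B" if filled == "A" else "A"
            batch, self.buffers[filled] = self.buffers[filled], []
            return batch  # a processor thread would run batched inference on this

A consumer thread polling swap_if_ready() would run inference on the returned list; the controller in this commit layers per-buffer states (IDLE/FILLING/PROCESSING), registered callbacks, and dedicated processor threads on top of the same idea.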
@@ -28,6 +29,7 @@ class BatchFrame:
 
 class BufferState(Enum):
     """State of a ping-pong buffer"""
+
     IDLE = "idle"
     FILLING = "filling"
     PROCESSING = "processing"
@@ -80,7 +82,9 @@ class ModelController:
                 f"Will process frames sequentially. Consider rebuilding model with dynamic batching."
             )
         else:
-            logger.info(f"Model '{model_id}' supports batch_size={self.model_batch_size}")
+            logger.info(
+                f"Model '{model_id}' supports batch_size={self.model_batch_size}"
+            )
 
         # Ping-pong buffers
         self.buffer_a: List[BatchFrame] = []
@@ -130,7 +134,9 @@ class ModelController:
                 # Fixed batch size
                 return batch_dim
         except Exception as e:
-            logger.warning(f"Could not detect model batch size: {e}. Assuming batch_size=1")
+            logger.warning(
+                f"Could not detect model batch size: {e}. Assuming batch_size=1"
+            )
             return 1
 
     def start(self):
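The hunk above only reflows the fallback branch of batch-size detection. The underlying idea, sketched here under the assumption that the model's input shape is available as a plain tuple such as (N, 3, H, W) (the real code goes through model_repository, whose API is not shown in this commit), is to read the leading dimension and treat -1 or None as a dynamic batch:

from typing import Optional, Sequence, Tuple


def detect_batch_size(input_shape: Sequence[Optional[int]]) -> Tuple[bool, Optional[int]]:
    """Return (supports_batching, batch_size) from a shape like (N, 3, H, W)."""
    try:
        batch_dim = input_shape[0]
        if batch_dim in (-1, None):
            return True, None                    # dynamic batch dimension
        return batch_dim > 1, int(batch_dim)     # fixed batch size
    except Exception:
        # Mirrors the fallback above: if detection fails, assume batch_size=1
        return False, 1

With this sketch, detect_batch_size((-1, 3, 640, 640)) yields (True, None), while detect_batch_size((1, 3, 640, 640)) yields (False, 1).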
@@ -143,14 +149,20 @@ class ModelController:
         self.stop_event.clear()
 
         # Start timeout monitor thread
-        self.timeout_thread = threading.Thread(target=self._timeout_monitor, daemon=True)
+        self.timeout_thread = threading.Thread(
+            target=self._timeout_monitor, daemon=True
+        )
         self.timeout_thread.start()
 
         # Start processor threads for each buffer
-        self.processor_threads['A'] = threading.Thread(target=self._batch_processor, args=('A',), daemon=True)
-        self.processor_threads['B'] = threading.Thread(target=self._batch_processor, args=('B',), daemon=True)
-        self.processor_threads['A'].start()
-        self.processor_threads['B'].start()
+        self.processor_threads["A"] = threading.Thread(
+            target=self._batch_processor, args=("A",), daemon=True
+        )
+        self.processor_threads["B"] = threading.Thread(
+            target=self._batch_processor, args=("B",), daemon=True
+        )
+        self.processor_threads["A"].start()
+        self.processor_threads["B"].start()
 
         logger.info("ModelController started")
@@ -197,10 +209,7 @@ class ModelController:
         logger.debug(f"Unregistered callback for stream: {stream_id}")
 
     def submit_frame(
-        self,
-        stream_id: str,
-        frame: torch.Tensor,
-        metadata: Optional[Dict] = None
+        self, stream_id: str, frame: torch.Tensor, metadata: Optional[Dict] = None
     ):
         """
         Submit a frame for batched inference.
@@ -215,7 +224,7 @@ class ModelController:
             stream_id=stream_id,
             frame=frame,
             timestamp=time.time(),
-            metadata=metadata or {}
+            metadata=metadata or {},
         )
 
         # Add to active buffer
@@ -242,7 +251,9 @@ class ModelController:
 
             # Check if timeout expired and we have frames waiting
             if time_since_submit >= self.force_timeout:
-                active_buffer = self.buffer_a if self.active_buffer == "A" else self.buffer_b
+                active_buffer = (
+                    self.buffer_a if self.active_buffer == "A" else self.buffer_b
+                )
                 if len(active_buffer) > 0:
                     self._try_swap_buffers()
@@ -254,7 +265,9 @@ class ModelController:
         This method should be called with buffer_lock held.
         """
         # Check if inactive buffer is available
-        inactive_state = self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
+        inactive_state = (
+            self.buffer_b_state if self.active_buffer == "A" else self.buffer_a_state
+        )
 
         if inactive_state != BufferState.PROCESSING:
             # Swap active buffer
@@ -269,7 +282,9 @@ class ModelController:
             self.buffer_b_state = BufferState.PROCESSING
             buffer_size = len(self.buffer_b)
 
-        logger.debug(f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})")
+        logger.debug(
+            f"Swapped buffers: {old_active} -> {self.active_buffer} (size: {buffer_size})"
+        )
 
     def _batch_processor(self, buffer_name: str):
         """Background thread that processes a specific buffer when available"""
@@ -322,8 +337,8 @@ class ModelController:
                 self.total_batches_processed += 1
 
                 logger.debug(
-                    f"Processed batch of {len(batch)} frames in {inference_time*1000:.2f}ms "
-                    f"({inference_time*1000/len(batch):.2f}ms per frame)"
+                    f"Processed batch of {len(batch)} frames in {inference_time * 1000:.2f}ms "
+                    f"({inference_time * 1000 / len(batch):.2f}ms per frame)"
                 )
 
                 # Emit results to callbacks
@@ -334,7 +349,10 @@ class ModelController:
                     try:
                         callback(result)
                     except Exception as e:
-                        logger.error(f"Error in callback for {batch_frame.stream_id}: {e}", exc_info=True)
+                        logger.error(
+                            f"Error in callback for {batch_frame.stream_id}: {e}",
+                            exc_info=True,
+                        )
 
             except Exception as e:
                 logger.error(f"Error processing batch: {e}", exc_info=True)
@@ -365,7 +383,9 @@ class ModelController:
             # Use true batching for models that support it
             return self._run_batched_inference(batch)
 
-    def _run_sequential_inference(self, batch: List[BatchFrame]) -> List[Dict[str, Any]]:
+    def _run_sequential_inference(
+        self, batch: List[BatchFrame]
+    ) -> List[Dict[str, Any]]:
         """Run inference sequentially for batch_size=1 models"""
         results = []
@@ -375,13 +395,15 @@ class ModelController:
                 processed = self.preprocess_fn(batch_frame.frame)
             else:
                 # Ensure we have batch dimension
-                processed = batch_frame.frame.unsqueeze(0) if batch_frame.frame.dim() == 3 else batch_frame.frame
+                processed = (
+                    batch_frame.frame.unsqueeze(0)
+                    if batch_frame.frame.dim() == 3
+                    else batch_frame.frame
+                )
 
             # Run inference for this frame
             outputs = self.model_repository.infer(
-                self.model_id,
-                {"images": processed},
-                synchronize=True
+                self.model_id, {"images": processed}, synchronize=True
             )
 
             # Postprocess
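The reflowed conditional above guards the batch dimension: a single frame is stored as (3, H, W) in BatchFrame, while the model expects (1, 3, H, W). A quick illustration of that unsqueeze pattern (the 480x640 size is arbitrary):

import torch

frame = torch.rand(3, 480, 640)                               # (C, H, W), as in BatchFrame.frame
batched = frame.unsqueeze(0) if frame.dim() == 3 else frame   # add batch dim only when missing
print(batched.shape)                                          # torch.Size([1, 3, 480, 640])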
@@ -389,9 +411,13 @@ class ModelController:
                 try:
                     detections = self.postprocess_fn(outputs)
                 except Exception as e:
-                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
+                    logger.error(
+                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
+                    )
                     # Return empty detections on error
-                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
+                    detections = torch.zeros(
+                        (0, 6), device=list(outputs.values())[0].device
+                    )
             else:
                 detections = outputs
@@ -429,32 +455,37 @@ class ModelController:
                 f"will split into sub-batches"
             )
             # TODO: Handle splitting into sub-batches
-            batch_tensor = batch_tensor[:self.model_batch_size]
-            batch = batch[:self.model_batch_size]
+            batch_tensor = batch_tensor[: self.model_batch_size]
+            batch = batch[: self.model_batch_size]
 
         # Run inference
         outputs = self.model_repository.infer(
-            self.model_id,
-            {"images": batch_tensor},
-            synchronize=True
+            self.model_id, {"images": batch_tensor}, synchronize=True
         )
 
         # Postprocess results (split batch back to individual results)
         results = []
         for i, batch_frame in enumerate(batch):
-            # Extract single frame output from batch
+            # Extract single frame output from batch and clone to ensure memory safety
+            # This prevents potential race conditions if the output tensors are still
+            # in use when the next inference batch is processed
            frame_output = {}
             for k, v in outputs.items():
-                # v has shape (N, ...), extract index i and keep batch dimension
-                frame_output[k] = v[i:i+1]  # Shape: (1, ...)
+                # Clone to decouple from shared batch output tensor
+                frame_output[k] = v[i : i + 1].clone()  # Shape: (1, ...)
 
             if self.postprocess_fn:
                 try:
                     detections = self.postprocess_fn(frame_output)
                 except Exception as e:
-                    logger.error(f"Error in postprocess for stream {batch_frame.stream_id}: {e}")
+                    logger.error(
+                        f"Error in postprocess for stream {batch_frame.stream_id}: {e}"
+                    )
                     # Return empty detections on error
-                    detections = torch.zeros((0, 6), device=list(outputs.values())[0].device)
+                    detections = torch.zeros(
+                        (0, 6), device=list(outputs.values())[0].device
+                    )
             else:
                 detections = frame_output
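The only behavioural change in the hunk above is the .clone() on each per-frame slice. The motivation, as the new comments state, is that v[i : i + 1] is a view sharing storage with the batched output tensor; if that output buffer is reused or overwritten by a later batch, un-cloned views would silently change. A small demonstration of the difference (toy tensor, not this model's actual outputs):

import torch

batch_out = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # pretend (N=4, ...) model output

view = batch_out[1:2]            # shares storage with batch_out
copy = batch_out[1:2].clone()    # independent copy

batch_out.zero_()                # e.g. the output buffer gets reused for the next batch

print(view)   # tensor([[0., 0., 0.]]) -- the view reflects the overwrite
print(copy)   # tensor([[3., 4., 5.]]) -- the clone kept the original values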
@@ -490,6 +521,8 @@ class ModelController:
             "total_batches_processed": self.total_batches_processed,
             "avg_batch_size": (
                 self.total_frames_processed / self.total_batches_processed
-                if self.total_batches_processed > 0 else 0
+                if self.total_batches_processed > 0
+                else 0
             ),
         }